diff --git a/.asf.yaml b/.asf.yaml
index d71e7def36ad..99fd6fac22c7 100644
--- a/.asf.yaml
+++ b/.asf.yaml
@@ -41,6 +41,7 @@ github:
     - sql
   enabled_merge_buttons:
     squash: true
+    squash_commit_message: PR_TITLE_AND_DESC
     merge: false
     rebase: false
   features:
diff --git a/.github/workflows/audit.yml b/.github/workflows/audit.yml
index 5d5e9e270a65..3685bb2f9a78 100644
--- a/.github/workflows/audit.yml
+++ b/.github/workflows/audit.yml
@@ -42,8 +42,13 @@ jobs:
     steps:
       - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
       - name: Install cargo-audit
-        uses: taiki-e/install-action@5b5de1b4da26ad411330c0454bdd72929bfcbeb2 # v2.62.29
+        uses: taiki-e/install-action@c5b1b6f479c32f356cc6f4ba672a47f63853b13b # v2.62.38
        with:
          tool: cargo-audit
      - name: Run audit check
-       run: cargo audit
+       # RUSTSEC-2025-0111: tokio-tar is used by testcontainers for orchestration
+       # of testing, so it does not impact DataFusion's security
+       # See https://github.com/apache/datafusion/issues/18288
+       # NOTE: can remove this once testcontainers releases a version that includes
+       # https://github.com/testcontainers/testcontainers-rs/pull/852
+       run: cargo audit --ignore RUSTSEC-2025-0111
diff --git a/.github/workflows/extended.yml b/.github/workflows/extended.yml
index 9343997e0568..23bd66a0cf35 100644
--- a/.github/workflows/extended.yml
+++ b/.github/workflows/extended.yml
@@ -36,6 +36,14 @@ on:
      # it is not expected to have many changes in these branches,
      # so running extended tests is not a burden
      - 'branch-*'
+  # Also run for changes to some critical areas that are most likely
+  # to trigger errors in extended tests
+  pull_request:
+    branches: [ '**' ]
+    paths:
+      - 'datafusion/physical*/**/*.rs'
+      - 'datafusion/expr*/**/*.rs'
+      - 'datafusion/optimizer/**/*.rs'
   workflow_dispatch:
     inputs:
       pr_number:
diff --git a/.github/workflows/labeler/labeler-config.yml b/.github/workflows/labeler/labeler-config.yml
index e40813072521..38d88059dab7 100644
--- a/.github/workflows/labeler/labeler-config.yml
+++ b/.github/workflows/labeler/labeler-config.yml
@@ -58,7 +58,7 @@ execution:

 datasource:
   - changed-files:
-      - any-glob-to-any-file: ['datafusion/datasource/**/*', 'datafusion/datasource-avro/**/*', 'datafusion/datasource-csv/**/*', 'datafusion/datasource-json/**/*', 'datafusion/datasource-parquet/**/*']
+      - any-glob-to-any-file: ['datafusion/datasource/**/*', 'datafusion/datasource-avro/**/*', 'datafusion/datasource-arrow/**/*', 'datafusion/datasource-csv/**/*', 'datafusion/datasource-json/**/*', 'datafusion/datasource-parquet/**/*']

 functions:
   - changed-files:
diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index ecdbf031b45b..4b61a04bfb14 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -353,6 +353,19 @@ jobs:
        with:
          save-if: ${{ github.ref_name == 'main' }}
          shared-key: "amd-ci-linux-test-example"
+      - name: Remove unnecessary preinstalled software
+        run: |
+          echo "Disk space before cleanup:"
+          df -h
+          apt-get clean
+          rm -rf /__t/CodeQL
+          rm -rf /__t/PyPy
+          rm -rf /__t/Java_Temurin-Hotspot_jdk
+          rm -rf /__t/Python
+          rm -rf /__t/go
+          rm -rf /__t/Ruby
+          echo "Disk space after cleanup:"
+          df -h
      - name: Run examples
        run: |
          # test datafusion-sql examples
@@ -412,7 +425,7 @@ jobs:
          sudo apt-get update -qq
          sudo apt-get install -y -qq clang
      - name: Setup wasm-pack
-       uses: taiki-e/install-action@5b5de1b4da26ad411330c0454bdd72929bfcbeb2 # v2.62.29
+       uses: taiki-e/install-action@c5b1b6f479c32f356cc6f4ba672a47f63853b13b # v2.62.38
        with:
          tool: wasm-pack
      - name: Run tests with headless mode
@@ -739,7 +752,7 @@ jobs: - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Install cargo-msrv - uses: taiki-e/install-action@5b5de1b4da26ad411330c0454bdd72929bfcbeb2 # v2.62.29 + uses: taiki-e/install-action@c5b1b6f479c32f356cc6f4ba672a47f63853b13b # v2.62.38 with: tool: cargo-msrv diff --git a/Cargo.lock b/Cargo.lock index 00bd64f21eb1..aaa75ecf3247 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -50,15 +50,6 @@ dependencies = [ "core_extensions", ] -[[package]] -name = "addr2line" -version = "0.24.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" -dependencies = [ - "gimli", -] - [[package]] name = "adler2" version = "2.0.1" @@ -84,7 +75,7 @@ checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" dependencies = [ "cfg-if", "const-random", - "getrandom 0.3.3", + "getrandom 0.3.4", "once_cell", "version_check", "zerocopy", @@ -199,7 +190,7 @@ checksum = "3a033b4ced7c585199fb78ef50fca7fe2f444369ec48080c5fd072efa1a03cc7" dependencies = [ "bigdecimal", "bon", - "bzip2 0.6.0", + "bzip2 0.6.1", "crc32fast", "digest", "log", @@ -234,9 +225,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "56.2.0" +version = "57.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e833808ff2d94ed40d9379848a950d995043c7fb3e81a30b383f4c6033821cc" +checksum = "4df8bb5b0bd64c0b9bc61317fcc480bad0f00e56d3bc32c69a4c8dada4786bae" dependencies = [ "arrow-arith", "arrow-array", @@ -258,23 +249,23 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "56.2.0" +version = "57.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad08897b81588f60ba983e3ca39bda2b179bdd84dced378e7df81a5313802ef8" +checksum = "a1a640186d3bd30a24cb42264c2dafb30e236a6f50d510e56d40b708c9582491" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "chrono", - "num", + "num-traits", ] [[package]] name = "arrow-array" -version = "56.2.0" +version = "57.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8548ca7c070d8db9ce7aa43f37393e4bfcf3f2d3681df278490772fd1673d08d" +checksum = "219fe420e6800979744c8393b687afb0252b3f8a89b91027d27887b72aa36d31" dependencies = [ "ahash 0.8.12", "arrow-buffer", @@ -284,25 +275,28 @@ dependencies = [ "chrono-tz", "half", "hashbrown 0.16.0", - "num", + "num-complex", + "num-integer", + "num-traits", ] [[package]] name = "arrow-buffer" -version = "56.2.0" +version = "57.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e003216336f70446457e280807a73899dd822feaf02087d31febca1363e2fccc" +checksum = "76885a2697a7edf6b59577f568b456afc94ce0e2edc15b784ce3685b6c3c5c27" dependencies = [ "bytes", "half", - "num", + "num-bigint", + "num-traits", ] [[package]] name = "arrow-cast" -version = "56.2.0" +version = "57.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "919418a0681298d3a77d1a315f625916cb5678ad0d74b9c60108eb15fd083023" +checksum = "9c9ebb4c987e6b3b236fb4a14b20b34835abfdd80acead3ccf1f9bf399e1f168" dependencies = [ "arrow-array", "arrow-buffer", @@ -315,15 +309,15 @@ dependencies = [ "comfy-table", "half", "lexical-core", - "num", + "num-traits", "ryu", ] [[package]] name = "arrow-csv" -version = "56.2.0" +version = "57.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"bfa9bf02705b5cf762b6f764c65f04ae9082c7cfc4e96e0c33548ee3f67012eb" +checksum = "92386159c8d4bce96f8bd396b0642a0d544d471bdc2ef34d631aec80db40a09c" dependencies = [ "arrow-array", "arrow-cast", @@ -336,21 +330,22 @@ dependencies = [ [[package]] name = "arrow-data" -version = "56.2.0" +version = "57.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5c64fff1d142f833d78897a772f2e5b55b36cb3e6320376f0961ab0db7bd6d0" +checksum = "727681b95de313b600eddc2a37e736dcb21980a40f640314dcf360e2f36bc89b" dependencies = [ "arrow-buffer", "arrow-schema", "half", - "num", + "num-integer", + "num-traits", ] [[package]] name = "arrow-flight" -version = "56.2.0" +version = "57.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c8b0ba0784d56bc6266b79f5de7a24b47024e7b3a0045d2ad4df3d9b686099f" +checksum = "f70bb56412a007b0cfc116d15f24dda6adeed9611a213852a004cda20085a3b9" dependencies = [ "arrow-arith", "arrow-array", @@ -368,16 +363,17 @@ dependencies = [ "futures", "once_cell", "paste", - "prost 0.13.5", - "prost-types 0.13.5", + "prost", + "prost-types", "tonic", + "tonic-prost", ] [[package]] name = "arrow-ipc" -version = "56.2.0" +version = "57.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d3594dcddccc7f20fd069bc8e9828ce37220372680ff638c5e00dea427d88f5" +checksum = "da9ba92e3de170295c98a84e5af22e2b037f0c7b32449445e6c493b5fca27f27" dependencies = [ "arrow-array", "arrow-buffer", @@ -391,9 +387,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "56.2.0" +version = "57.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88cf36502b64a127dc659e3b305f1d993a544eab0d48cce704424e62074dc04b" +checksum = "b969b4a421ae83828591c6bf5450bd52e6d489584142845ad6a861f42fe35df8" dependencies = [ "arrow-array", "arrow-buffer", @@ -402,20 +398,22 @@ dependencies = [ "arrow-schema", "chrono", "half", - "indexmap 2.11.4", + "indexmap 2.12.0", + "itoa", "lexical-core", "memchr", - "num", - "serde", + "num-traits", + "ryu", + "serde_core", "serde_json", "simdutf8", ] [[package]] name = "arrow-ord" -version = "56.2.0" +version = "57.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c8f82583eb4f8d84d4ee55fd1cb306720cddead7596edce95b50ee418edf66f" +checksum = "141c05298b21d03e88062317a1f1a73f5ba7b6eb041b350015b1cd6aabc0519b" dependencies = [ "arrow-array", "arrow-buffer", @@ -426,9 +424,9 @@ dependencies = [ [[package]] name = "arrow-pyarrow" -version = "56.2.0" +version = "57.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d924b32e96f8bb74d94cd82bd97b313c432fcb0ea331689ef9e7c6b8be4b258" +checksum = "cfcfb2be2e9096236f449c11f425cddde18c4cc540f516d90f066f10a29ed515" dependencies = [ "arrow-array", "arrow-data", @@ -438,9 +436,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "56.2.0" +version = "57.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d07ba24522229d9085031df6b94605e0f4b26e099fb7cdeec37abd941a73753" +checksum = "c5f3c06a6abad6164508ed283c7a02151515cef3de4b4ff2cebbcaeb85533db2" dependencies = [ "arrow-array", "arrow-buffer", @@ -451,34 +449,35 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "56.2.0" +version = "57.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3aa9e59c611ebc291c28582077ef25c97f1975383f1479b12f3b9ffee2ffabe" +checksum = "9cfa7a03d1eee2a4d061476e1840ad5c9867a544ca6c4c59256496af5d0a8be5" dependencies 
= [ "bitflags 2.9.4", "serde", + "serde_core", "serde_json", ] [[package]] name = "arrow-select" -version = "56.2.0" +version = "57.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c41dbbd1e97bfcaee4fcb30e29105fb2c75e4d82ae4de70b792a5d3f66b2e7a" +checksum = "bafa595babaad59f2455f4957d0f26448fb472722c186739f4fac0823a1bdb47" dependencies = [ "ahash 0.8.12", "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", - "num", + "num-traits", ] [[package]] name = "arrow-string" -version = "56.2.0" +version = "57.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53f5183c150fbc619eede22b861ea7c0eebed8eaac0333eaa7f6da5205fd504d" +checksum = "32f46457dbbb99f2650ff3ac23e46a929e0ab81db809b02aa5511c258348bef2" dependencies = [ "arrow-array", "arrow-buffer", @@ -486,7 +485,7 @@ dependencies = [ "arrow-schema", "arrow-select", "memchr", - "num", + "num-traits", "regex", "regex-syntax", ] @@ -537,7 +536,7 @@ checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -548,7 +547,7 @@ checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -965,21 +964,6 @@ dependencies = [ "tower-service", ] -[[package]] -name = "backtrace" -version = "0.3.75" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6806a6321ec58106fea15becdad98371e28d92ccbc7c8f1b3b6dd724fe8f1002" -dependencies = [ - "addr2line", - "cfg-if", - "libc", - "miniz_oxide", - "object", - "rustc-demangle", - "windows-targets 0.52.6", -] - [[package]] name = "base64" version = "0.21.7" @@ -1033,7 +1017,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -1163,7 +1147,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -1186,7 +1170,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -1281,9 +1265,9 @@ dependencies = [ [[package]] name = "bzip2" -version = "0.6.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bea8dcd42434048e4f7a304411d9273a411f647446c1234a65ce0554923f4cff" +checksum = "f3a53fac24f34a81bc9954b5d6cfce0c21e18ec6959f44f56e8e90e4bb7c346c" dependencies = [ "libbz2-rs-sys", ] @@ -1412,9 +1396,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.48" +version = "4.5.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2134bb3ea021b78629caa971416385309e0131b351b25e01dc16fb54e1b5fae" +checksum = "0c2cfd7bf8a6017ddaa4e32ffe7403d547790db06bd171c1c53926faab501623" dependencies = [ "clap_builder", "clap_derive", @@ -1422,9 +1406,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.48" +version = "4.5.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2ba64afa3c0a6df7fa517765e31314e983f51dda798ffba27b988194fb65dc9" +checksum = "0a4c05b9e80c5ccd3a7ef080ad7b6ba7d6fc00a985b8b157197075677c82c7a0" dependencies = [ "anstream", "anstyle", @@ -1434,14 +1418,14 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.47" +version = "4.5.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbfd7eae0b0f1a6e63d4b13c9c478de77c2eb546fba158ad50b4203dc24b9f9c" +checksum = 
"2a0b5487afeab2deb2ff4e03a807ad1a03ac532ff5a2cee5d86884440c7f7671" dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -1613,7 +1597,7 @@ dependencies = [ "anes", "cast", "ciborium", - "clap 4.5.48", + "clap 4.5.50", "criterion-plot", "futures", "is-terminal", @@ -1756,7 +1740,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -1767,7 +1751,7 @@ checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" dependencies = [ "darling_core", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -1786,14 +1770,13 @@ dependencies = [ [[package]] name = "datafusion" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", - "arrow-ipc", "arrow-schema", "async-trait", "bytes", - "bzip2 0.6.0", + "bzip2 0.6.1", "chrono", "criterion", "ctor", @@ -1803,6 +1786,7 @@ dependencies = [ "datafusion-common", "datafusion-common-runtime", "datafusion-datasource", + "datafusion-datasource-arrow", "datafusion-datasource-avro", "datafusion-datasource-csv", "datafusion-datasource-json", @@ -1858,7 +1842,7 @@ dependencies = [ [[package]] name = "datafusion-benchmarks" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "datafusion", @@ -1883,7 +1867,7 @@ dependencies = [ [[package]] name = "datafusion-catalog" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "async-trait", @@ -1906,19 +1890,22 @@ dependencies = [ [[package]] name = "datafusion-catalog-listing" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "async-trait", "datafusion-catalog", "datafusion-common", "datafusion-datasource", + "datafusion-datasource-parquet", "datafusion-execution", "datafusion-expr", "datafusion-physical-expr", + "datafusion-physical-expr-adapter", "datafusion-physical-expr-common", "datafusion-physical-plan", "futures", + "itertools 0.14.0", "log", "object_store", "tokio", @@ -1926,16 +1913,17 @@ dependencies = [ [[package]] name = "datafusion-cli" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "async-trait", "aws-config", "aws-credential-types", "chrono", - "clap 4.5.48", + "clap 4.5.50", "ctor", "datafusion", + "datafusion-common", "dirs", "env_logger", "futures", @@ -1957,7 +1945,7 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "50.2.0" +version = "50.3.0" dependencies = [ "ahash 0.8.12", "apache-avro", @@ -1967,7 +1955,7 @@ dependencies = [ "half", "hashbrown 0.14.5", "hex", - "indexmap 2.11.4", + "indexmap 2.12.0", "insta", "libc", "log", @@ -1984,7 +1972,7 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "50.2.0" +version = "50.3.0" dependencies = [ "futures", "log", @@ -1993,13 +1981,13 @@ dependencies = [ [[package]] name = "datafusion-datasource" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "async-compression", "async-trait", "bytes", - "bzip2 0.6.0", + "bzip2 0.6.1", "chrono", "criterion", "datafusion-common", @@ -2026,9 +2014,32 @@ dependencies = [ "zstd", ] +[[package]] +name = "datafusion-datasource-arrow" +version = "50.3.0" +dependencies = [ + "arrow", + "arrow-ipc", + "async-trait", + "bytes", + "chrono", + "datafusion-common", + "datafusion-common-runtime", + "datafusion-datasource", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-expr-common", + "datafusion-physical-plan", + "datafusion-session", + "futures", + "itertools 0.14.0", + "object_store", + "tokio", +] + [[package]] name = "datafusion-datasource-avro" 
-version = "50.2.0" +version = "50.3.0" dependencies = [ "apache-avro", "arrow", @@ -2047,7 +2058,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-csv" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "async-trait", @@ -2068,7 +2079,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-json" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "async-trait", @@ -2088,7 +2099,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-parquet" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "async-trait", @@ -2117,11 +2128,11 @@ dependencies = [ [[package]] name = "datafusion-doc" -version = "50.2.0" +version = "50.3.0" [[package]] name = "datafusion-examples" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "arrow-flight", @@ -2140,7 +2151,7 @@ dependencies = [ "mimalloc", "nix", "object_store", - "prost 0.13.5", + "prost", "rand 0.9.2", "serde_json", "tempfile", @@ -2155,7 +2166,7 @@ dependencies = [ [[package]] name = "datafusion-execution" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "async-trait", @@ -2176,7 +2187,7 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "async-trait", @@ -2189,7 +2200,7 @@ dependencies = [ "datafusion-functions-window-common", "datafusion-physical-expr-common", "env_logger", - "indexmap 2.11.4", + "indexmap 2.12.0", "insta", "itertools 0.14.0", "paste", @@ -2200,18 +2211,18 @@ dependencies = [ [[package]] name = "datafusion-expr-common" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "datafusion-common", - "indexmap 2.11.4", + "indexmap 2.12.0", "itertools 0.14.0", "paste", ] [[package]] name = "datafusion-ffi" -version = "50.2.0" +version = "50.3.0" dependencies = [ "abi_stable", "arrow", @@ -2226,14 +2237,14 @@ dependencies = [ "doc-comment", "futures", "log", - "prost 0.13.5", + "prost", "semver", "tokio", ] [[package]] name = "datafusion-functions" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "arrow-buffer", @@ -2264,7 +2275,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "50.2.0" +version = "50.3.0" dependencies = [ "ahash 0.8.12", "arrow", @@ -2285,7 +2296,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate-common" -version = "50.2.0" +version = "50.3.0" dependencies = [ "ahash 0.8.12", "arrow", @@ -2298,7 +2309,7 @@ dependencies = [ [[package]] name = "datafusion-functions-nested" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "arrow-ord", @@ -2321,7 +2332,7 @@ dependencies = [ [[package]] name = "datafusion-functions-table" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "async-trait", @@ -2335,7 +2346,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "datafusion-common", @@ -2351,7 +2362,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window-common" -version = "50.2.0" +version = "50.3.0" dependencies = [ "datafusion-common", "datafusion-physical-expr-common", @@ -2359,16 +2370,16 @@ dependencies = [ [[package]] name = "datafusion-macros" -version = "50.2.0" +version = "50.3.0" dependencies = [ "datafusion-doc", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] name = "datafusion-optimizer" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "async-trait", @@ -2384,7 +2395,7 @@ dependencies = [ 
"datafusion-physical-expr", "datafusion-sql", "env_logger", - "indexmap 2.11.4", + "indexmap 2.12.0", "insta", "itertools 0.14.0", "log", @@ -2395,7 +2406,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "50.2.0" +version = "50.3.0" dependencies = [ "ahash 0.8.12", "arrow", @@ -2408,7 +2419,7 @@ dependencies = [ "datafusion-physical-expr-common", "half", "hashbrown 0.14.5", - "indexmap 2.11.4", + "indexmap 2.12.0", "insta", "itertools 0.14.0", "parking_lot", @@ -2420,7 +2431,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-adapter" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "datafusion-common", @@ -2433,7 +2444,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" -version = "50.2.0" +version = "50.3.0" dependencies = [ "ahash 0.8.12", "arrow", @@ -2445,13 +2456,14 @@ dependencies = [ [[package]] name = "datafusion-physical-optimizer" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "datafusion-common", "datafusion-execution", "datafusion-expr", "datafusion-expr-common", + "datafusion-functions", "datafusion-physical-expr", "datafusion-physical-expr-common", "datafusion-physical-plan", @@ -2464,7 +2476,7 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "50.2.0" +version = "50.3.0" dependencies = [ "ahash 0.8.12", "arrow", @@ -2486,7 +2498,7 @@ dependencies = [ "futures", "half", "hashbrown 0.14.5", - "indexmap 2.11.4", + "indexmap 2.12.0", "insta", "itertools 0.14.0", "log", @@ -2500,22 +2512,35 @@ dependencies = [ [[package]] name = "datafusion-proto" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "chrono", "datafusion", + "datafusion-catalog", + "datafusion-catalog-listing", "datafusion-common", + "datafusion-datasource", + "datafusion-datasource-arrow", + "datafusion-datasource-avro", + "datafusion-datasource-csv", + "datafusion-datasource-json", + "datafusion-datasource-parquet", + "datafusion-execution", "datafusion-expr", "datafusion-functions", "datafusion-functions-aggregate", + "datafusion-functions-table", "datafusion-functions-window-common", + "datafusion-physical-expr", + "datafusion-physical-expr-common", + "datafusion-physical-plan", "datafusion-proto-common", "doc-comment", "object_store", "pbjson", "pretty_assertions", - "prost 0.13.5", + "prost", "serde", "serde_json", "tokio", @@ -2523,19 +2548,19 @@ dependencies = [ [[package]] name = "datafusion-proto-common" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "datafusion-common", "doc-comment", "pbjson", - "prost 0.13.5", + "prost", "serde", ] [[package]] name = "datafusion-pruning" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "datafusion-common", @@ -2553,7 +2578,7 @@ dependencies = [ [[package]] name = "datafusion-session" -version = "50.2.0" +version = "50.3.0" dependencies = [ "async-trait", "datafusion-common", @@ -2565,7 +2590,7 @@ dependencies = [ [[package]] name = "datafusion-spark" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "bigdecimal", @@ -2585,7 +2610,7 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "50.2.0" +version = "50.3.0" dependencies = [ "arrow", "bigdecimal", @@ -2598,7 +2623,7 @@ dependencies = [ "datafusion-functions-nested", "datafusion-functions-window", "env_logger", - "indexmap 2.11.4", + "indexmap 2.12.0", "insta", "itertools 0.14.0", "log", @@ -2611,14 +2636,14 @@ dependencies = [ [[package]] name = "datafusion-sqllogictest" -version = "50.2.0" +version = 
"50.3.0" dependencies = [ "arrow", "async-trait", "bigdecimal", "bytes", "chrono", - "clap 4.5.48", + "clap 4.5.50", "datafusion", "datafusion-spark", "datafusion-substrait", @@ -2645,18 +2670,19 @@ dependencies = [ [[package]] name = "datafusion-substrait" -version = "50.2.0" +version = "50.3.0" dependencies = [ "async-recursion", "async-trait", "chrono", "datafusion", "datafusion-functions-aggregate", + "half", "insta", "itertools 0.14.0", "object_store", "pbjson-types", - "prost 0.13.5", + "prost", "serde_json", "substrait", "tokio", @@ -2666,7 +2692,7 @@ dependencies = [ [[package]] name = "datafusion-wasmtest" -version = "50.2.0" +version = "50.3.0" dependencies = [ "chrono", "console_error_panic_hook", @@ -2677,7 +2703,7 @@ dependencies = [ "datafusion-optimizer", "datafusion-physical-plan", "datafusion-sql", - "getrandom 0.3.3", + "getrandom 0.3.4", "object_store", "tokio", "url", @@ -2741,7 +2767,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -2797,7 +2823,7 @@ dependencies = [ "enum-ordinalize", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -2835,7 +2861,7 @@ checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -3096,7 +3122,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -3139,16 +3165,16 @@ dependencies = [ name = "gen" version = "0.1.0" dependencies = [ - "pbjson-build 0.8.0", - "prost-build 0.14.1", + "pbjson-build", + "prost-build", ] [[package]] name = "gen-common" version = "0.1.0" dependencies = [ - "pbjson-build 0.8.0", - "prost-build 0.14.1", + "pbjson-build", + "prost-build", ] [[package]] @@ -3179,30 +3205,24 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi 0.11.1+wasi-snapshot-preview1", + "wasi", "wasm-bindgen", ] [[package]] name = "getrandom" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", "js-sys", "libc", "r-efi", - "wasi 0.14.7+wasi-0.2.4", + "wasip2", "wasm-bindgen", ] -[[package]] -name = "gimli" -version = "0.31.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" - [[package]] name = "glob" version = "0.3.3" @@ -3234,7 +3254,7 @@ dependencies = [ "futures-core", "futures-sink", "http 1.3.1", - "indexmap 2.11.4", + "indexmap 2.12.0", "slab", "tokio", "tokio-util", @@ -3243,9 +3263,9 @@ dependencies = [ [[package]] name = "half" -version = "2.7.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e54c115d4f30f52c67202f079c5f9d8b49db4691f460fdb0b4c2e838261b2ba5" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" dependencies = [ "cfg-if", "crunchy", @@ -3494,7 +3514,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.6.0", + "socket2", "tokio", "tower-service", "tracing", @@ -3665,9 +3685,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.11.4" +version = "2.12.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b0f83760fb341a774ed326568e19f5a863af4a952def8c39f9ab92fd95b88e5" +checksum = "6717a8d2a5a929a1a2eb43a12812498ed141a0bcfb7e8f7844fbdbe4303bba9f" dependencies = [ "equivalent", "hashbrown 0.16.0", @@ -3726,17 +3746,6 @@ version = "3.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" -[[package]] -name = "io-uring" -version = "0.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "046fa2d4d00aea763528b4950358d0ead425372445dc8ff86312b3c69ff7727b" -dependencies = [ - "bitflags 2.9.4", - "cfg-if", - "libc", -] - [[package]] name = "ipnet" version = "2.11.0" @@ -3824,7 +3833,7 @@ checksum = "03343451ff899767262ec32146f6d559dd759fdadf42ff0e227c7c48f72594b4" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -3833,7 +3842,7 @@ version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" dependencies = [ - "getrandom 0.3.3", + "getrandom 0.3.4", "libc", ] @@ -3978,7 +3987,7 @@ checksum = "5297962ef19edda4ce33aaa484386e0a5b3d7f2f4e037cbeee00503ef6b29d33" dependencies = [ "anstream", "anstyle", - "clap 4.5.48", + "clap 4.5.50", "escape8259", ] @@ -4124,7 +4133,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" dependencies = [ "libc", - "wasi 0.11.1+wasi-snapshot-preview1", + "wasi", "windows-sys 0.59.0", ] @@ -4183,20 +4192,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "num" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" -dependencies = [ - "num-bigint", - "num-complex", - "num-integer", - "num-iter", - "num-rational", - "num-traits", -] - [[package]] name = "num-bigint" version = "0.4.6" @@ -4232,28 +4227,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "num-iter" -version = "0.1.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" -dependencies = [ - "autocfg", - "num-integer", - "num-traits", -] - -[[package]] -name = "num-rational" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" -dependencies = [ - "num-bigint", - "num-integer", - "num-traits", -] - [[package]] name = "num-traits" version = "0.2.19" @@ -4283,15 +4256,6 @@ dependencies = [ "objc2-core-foundation", ] -[[package]] -name = "object" -version = "0.36.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" -dependencies = [ - "memchr", -] - [[package]] name = "object_store" version = "0.12.4" @@ -4405,9 +4369,9 @@ dependencies = [ [[package]] name = "parquet" -version = "56.2.0" +version = "57.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0dbd48ad52d7dccf8ea1b90a3ddbfaea4f69878dd7683e51c507d4bc52b5b27" +checksum = "7a0f31027ef1af7549f7cec603a9a21dce706d3f8d7c2060a68f43c1773be95a" dependencies = [ "ahash 0.8.12", "arrow-array", @@ -4426,8 +4390,9 @@ dependencies = [ "half", "hashbrown 0.16.0", "lz4_flex", - "num", "num-bigint", 
+ "num-integer", + "num-traits", "object_store", "paste", "ring", @@ -4462,7 +4427,7 @@ dependencies = [ "regex", "regex-syntax", "structmeta", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -4473,26 +4438,14 @@ checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" [[package]] name = "pbjson" -version = "0.7.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7e6349fa080353f4a597daffd05cb81572a9c031a6d4fff7e504947496fcc68" +checksum = "898bac3fa00d0ba57a4e8289837e965baa2dee8c3749f3b11d45a64b4223d9c3" dependencies = [ - "base64 0.21.7", + "base64 0.22.1", "serde", ] -[[package]] -name = "pbjson-build" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6eea3058763d6e656105d1403cb04e0a41b7bbac6362d413e7c33be0c32279c9" -dependencies = [ - "heck 0.5.0", - "itertools 0.13.0", - "prost 0.13.5", - "prost-types 0.13.5", -] - [[package]] name = "pbjson-build" version = "0.8.0" @@ -4501,22 +4454,22 @@ checksum = "af22d08a625a2213a78dbb0ffa253318c5c79ce3133d32d296655a7bdfb02095" dependencies = [ "heck 0.5.0", "itertools 0.14.0", - "prost 0.14.1", - "prost-types 0.14.1", + "prost", + "prost-types", ] [[package]] name = "pbjson-types" -version = "0.7.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e54e5e7bfb1652f95bc361d76f3c780d8e526b134b85417e774166ee941f0887" +checksum = "8e748e28374f10a330ee3bb9f29b828c0ac79831a32bab65015ad9b661ead526" dependencies = [ "bytes", "chrono", "pbjson", - "pbjson-build 0.7.0", - "prost 0.13.5", - "prost-build 0.13.5", + "pbjson-build", + "prost", + "prost-build", "serde", ] @@ -4533,7 +4486,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ "fixedbitset", - "indexmap 2.11.4", + "indexmap 2.12.0", ] [[package]] @@ -4544,7 +4497,7 @@ checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" dependencies = [ "fixedbitset", "hashbrown 0.15.5", - "indexmap 2.11.4", + "indexmap 2.12.0", "serde", ] @@ -4602,7 +4555,7 @@ checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -4675,7 +4628,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -4750,7 +4703,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -4795,16 +4748,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "prost" -version = "0.13.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" -dependencies = [ - "bytes", - "prost-derive 0.13.5", -] - [[package]] name = "prost" version = "0.14.1" @@ -4812,27 +4755,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7231bd9b3d3d33c86b58adbac74b5ec0ad9f496b19d22801d773636feaa95f3d" dependencies = [ "bytes", - "prost-derive 0.14.1", -] - -[[package]] -name = "prost-build" -version = "0.13.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" -dependencies = [ - "heck 0.5.0", - "itertools 0.14.0", - "log", 
- "multimap", - "once_cell", - "petgraph 0.7.1", - "prettyplease", - "prost 0.13.5", - "prost-types 0.13.5", - "regex", - "syn 2.0.106", - "tempfile", + "prost-derive", ] [[package]] @@ -4848,26 +4771,13 @@ dependencies = [ "once_cell", "petgraph 0.7.1", "prettyplease", - "prost 0.14.1", - "prost-types 0.14.1", + "prost", + "prost-types", "regex", - "syn 2.0.106", + "syn 2.0.108", "tempfile", ] -[[package]] -name = "prost-derive" -version = "0.13.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" -dependencies = [ - "anyhow", - "itertools 0.14.0", - "proc-macro2", - "quote", - "syn 2.0.106", -] - [[package]] name = "prost-derive" version = "0.14.1" @@ -4878,16 +4788,7 @@ dependencies = [ "itertools 0.14.0", "proc-macro2", "quote", - "syn 2.0.106", -] - -[[package]] -name = "prost-types" -version = "0.13.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" -dependencies = [ - "prost 0.13.5", + "syn 2.0.108", ] [[package]] @@ -4896,7 +4797,7 @@ version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9b4db3d6da204ed77bb26ba83b6122a73aeb2e87e25fbf7ad2e84c4ccbf8f72" dependencies = [ - "prost 0.14.1", + "prost", ] [[package]] @@ -4939,9 +4840,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.25.1" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8970a78afe0628a3e3430376fc5fd76b6b45c4d43360ffd6cdd40bdde72b682a" +checksum = "7ba0117f4212101ee6544044dae45abe1083d30ce7b29c4b5cbdfa2354e07383" dependencies = [ "indoc", "libc", @@ -4956,19 +4857,18 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.25.1" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "458eb0c55e7ece017adeba38f2248ff3ac615e53660d7c71a238d7d2a01c7598" +checksum = "4fc6ddaf24947d12a9aa31ac65431fb1b851b8f4365426e182901eabfb87df5f" dependencies = [ - "once_cell", "target-lexicon", ] [[package]] name = "pyo3-ffi" -version = "0.25.1" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7114fe5457c61b276ab77c5055f206295b812608083644a5c5b2640c3102565c" +checksum = "025474d3928738efb38ac36d4744a74a400c901c7596199e20e45d98eb194105" dependencies = [ "libc", "pyo3-build-config", @@ -4976,27 +4876,27 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.25.1" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8725c0a622b374d6cb051d11a0983786448f7785336139c3c94f5aa6bef7e50" +checksum = "2e64eb489f22fe1c95911b77c44cc41e7c19f3082fc81cce90f657cdc42ffded" dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] name = "pyo3-macros-backend" -version = "0.25.1" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4109984c22491085343c05b0dbc54ddc405c3cf7b4374fc533f5c3313a572ccc" +checksum = "100246c0ecf400b475341b8455a9213344569af29a3c841d29270e53102e0fcf" dependencies = [ "heck 0.5.0", "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -5028,7 +4928,7 @@ dependencies = [ "quinn-udp", "rustc-hash", "rustls", - "socket2 0.6.0", + "socket2", "thiserror", "tokio", "tracing", @@ -5042,7 +4942,7 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" dependencies = [ "bytes", - "getrandom 0.3.3", + "getrandom 0.3.4", "lru-slab", "rand 0.9.2", "ring", @@ -5065,7 +4965,7 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2 0.6.0", + "socket2", "tracing", "windows-sys 0.60.2", ] @@ -5157,7 +5057,7 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom 0.3.3", + "getrandom 0.3.4", ] [[package]] @@ -5207,7 +5107,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -5256,14 +5156,14 @@ checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] name = "regex" -version = "1.11.3" +version = "1.12.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b5288124840bee7b386bc413c487869b360b2b4ec421ea56425128692f2a82c" +checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4" dependencies = [ "aho-corasick", "memchr", @@ -5273,9 +5173,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.11" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "833eb9ce86d40ef33cb1306d8accf7bc8ec2bfea4355cbdebb3df68b40925cad" +checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c" dependencies = [ "aho-corasick", "memchr", @@ -5439,7 +5339,7 @@ dependencies = [ "regex", "relative-path", "rustc_version", - "syn 2.0.106", + "syn 2.0.108", "unicode-ident", ] @@ -5451,7 +5351,7 @@ checksum = "b3a8fb4672e840a587a66fc577a5491375df51ddb88f2a2c2a792598c326fe14" dependencies = [ "quote", "rand 0.8.5", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -5471,12 +5371,6 @@ dependencies = [ "serde_json", ] -[[package]] -name = "rustc-demangle" -version = "0.1.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" - [[package]] name = "rustc-hash" version = "2.1.1" @@ -5660,7 +5554,7 @@ dependencies = [ "proc-macro2", "quote", "serde_derive_internals", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -5751,7 +5645,7 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -5762,7 +5656,7 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -5786,7 +5680,7 @@ checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -5798,7 +5692,7 @@ dependencies = [ "proc-macro2", "quote", "serde", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -5823,7 +5717,7 @@ dependencies = [ "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.11.4", + "indexmap 2.12.0", "schemars 0.9.0", "schemars 1.0.4", "serde", @@ -5842,7 +5736,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -5851,7 +5745,7 @@ version = 
"0.9.34+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ - "indexmap 2.11.4", + "indexmap 2.12.0", "itoa", "ryu", "serde", @@ -5964,16 +5858,6 @@ dependencies = [ "cmake", ] -[[package]] -name = "socket2" -version = "0.5.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" -dependencies = [ - "libc", - "windows-sys 0.52.0", -] - [[package]] name = "socket2" version = "0.6.0" @@ -6028,7 +5912,7 @@ checksum = "da5fc6819faabb412da764b99d3b713bb55083c11e7e0c00144d386cd6a1939c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -6076,7 +5960,7 @@ dependencies = [ "proc-macro2", "quote", "structmeta-derive", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -6087,7 +5971,7 @@ checksum = "152a0b65a590ff6c3da95cabe2353ee04e6167c896b28e3b14478c2636c922fc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -6136,7 +6020,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -6148,7 +6032,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -6163,18 +6047,18 @@ dependencies = [ [[package]] name = "substrait" -version = "0.58.0" +version = "0.59.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de6d24c270c6c672a86c183c3a8439ba46c1936f93cf7296aa692de3b0ff0228" +checksum = "540683f325ab9ab1a2008bc24588f3e76f63b6a3f52bc47e121122376a063639" dependencies = [ "heck 0.5.0", "pbjson", - "pbjson-build 0.7.0", + "pbjson-build", "pbjson-types", "prettyplease", - "prost 0.13.5", - "prost-build 0.13.5", - "prost-types 0.13.5", + "prost", + "prost-build", + "prost-types", "protobuf-src", "regress", "schemars 0.8.22", @@ -6182,7 +6066,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", - "syn 2.0.106", + "syn 2.0.108", "typify", "walkdir", ] @@ -6206,9 +6090,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.106" +version = "2.0.108" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" +checksum = "da58917d35242480a05c2897064da0a80589a2a0476c9a3f2fdc83b53502e917" dependencies = [ "proc-macro2", "quote", @@ -6232,7 +6116,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -6268,7 +6152,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" dependencies = [ "fastrand", - "getrandom 0.3.3", + "getrandom 0.3.4", "once_cell", "rustix", "windows-sys 0.61.0", @@ -6349,7 +6233,7 @@ checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -6449,33 +6333,30 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.47.1" +version = "1.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89e49afdadebb872d3145a5638b59eb0691ea23e46ca484037cfab3b76b95038" +checksum = "ff360e02eab121e0bc37a2d3b4d4dc622e6eda3a8e5253d5435ecf5bd4c68408" dependencies = [ - "backtrace", 
"bytes", - "io-uring", "libc", "mio", "parking_lot", "pin-project-lite", "signal-hook-registry", - "slab", - "socket2 0.6.0", + "socket2", "tokio-macros", - "windows-sys 0.59.0", + "windows-sys 0.61.0", ] [[package]] name = "tokio-macros" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" +checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -6498,7 +6379,7 @@ dependencies = [ "postgres-protocol", "postgres-types", "rand 0.9.2", - "socket2 0.6.0", + "socket2", "tokio", "tokio-util", "whoami", @@ -6568,7 +6449,7 @@ version = "0.23.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3effe7c0e86fdff4f69cdd2ccc1b96f933e24811c5441d44904e8683e27184b" dependencies = [ - "indexmap 2.11.4", + "indexmap 2.12.0", "toml_datetime", "toml_parser", "winnow", @@ -6585,9 +6466,9 @@ dependencies = [ [[package]] name = "tonic" -version = "0.13.1" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e581ba15a835f4d9ea06c55ab1bd4dce26fc53752c69a04aac00703bfb49ba9" +checksum = "eb7613188ce9f7df5bfe185db26c5814347d110db17920415cf2fbcad85e7203" dependencies = [ "async-trait", "axum", @@ -6602,8 +6483,8 @@ dependencies = [ "hyper-util", "percent-encoding", "pin-project", - "prost 0.13.5", - "socket2 0.5.10", + "socket2", + "sync_wrapper", "tokio", "tokio-stream", "tower", @@ -6612,6 +6493,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "tonic-prost" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66bd50ad6ce1252d87ef024b3d64fe4c3cf54a86fb9ef4c631fdd0ded7aeaa67" +dependencies = [ + "bytes", + "prost", + "tonic", +] + [[package]] name = "tower" version = "0.5.2" @@ -6620,7 +6512,7 @@ checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" dependencies = [ "futures-core", "futures-util", - "indexmap 2.11.4", + "indexmap 2.12.0", "pin-project-lite", "slab", "sync_wrapper", @@ -6680,7 +6572,7 @@ checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -6788,7 +6680,7 @@ dependencies = [ "semver", "serde", "serde_json", - "syn 2.0.106", + "syn 2.0.108", "thiserror", "unicode-ident", ] @@ -6806,7 +6698,7 @@ dependencies = [ "serde", "serde_json", "serde_tokenstream", - "syn 2.0.106", + "syn 2.0.108", "typify-impl", ] @@ -6915,7 +6807,7 @@ version = "1.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" dependencies = [ - "getrandom 0.3.3", + "getrandom 0.3.4", "js-sys", "serde", "wasm-bindgen", @@ -6964,15 +6856,6 @@ version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" -[[package]] -name = "wasi" -version = "0.14.7+wasi-0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "883478de20367e224c0090af9cf5f9fa85bed63a95c1abf3afc5c083ebc06e8c" -dependencies = [ - "wasip2", -] - [[package]] name = "wasip2" version = "1.0.1+wasi-0.2.4" @@ -7011,7 +6894,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", "wasm-bindgen-shared", ] @@ 
-7046,7 +6929,7 @@ checksum = "9f07d2f20d4da7b26400c9f4a0511e6e0345b040694e8a75bd41d578fa4421d7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -7081,7 +6964,7 @@ checksum = "b673bca3298fe582aeef8352330ecbad91849f85090805582400850f8270a2e8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -7226,7 +7109,7 @@ checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -7237,7 +7120,7 @@ checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -7553,7 +7436,7 @@ checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", "synstructure", ] @@ -7574,7 +7457,7 @@ checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] @@ -7594,7 +7477,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", "synstructure", ] @@ -7634,7 +7517,7 @@ checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.108", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index dd0b20de528a..1cfb23bb183d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ members = [ "datafusion/catalog", "datafusion/catalog-listing", "datafusion/datasource", + "datafusion/datasource-arrow", "datafusion/datasource-avro", "datafusion/datasource-csv", "datafusion/datasource-json", @@ -78,7 +79,7 @@ repository = "https://github.com/apache/datafusion" # Define Minimum Supported Rust Version (MSRV) rust-version = "1.87.0" # Define DataFusion version -version = "50.2.0" +version = "50.3.0" [workspace.dependencies] # We turn off default-features for some dependencies here so the workspaces which inherit them can @@ -90,19 +91,19 @@ ahash = { version = "0.8", default-features = false, features = [ "runtime-rng", ] } apache-avro = { version = "0.20", default-features = false } -arrow = { version = "56.2.0", features = [ +arrow = { version = "57.0.0", features = [ "prettyprint", "chrono-tz", ] } -arrow-buffer = { version = "56.2.0", default-features = false } -arrow-flight = { version = "56.2.0", features = [ +arrow-buffer = { version = "57.0.0", default-features = false } +arrow-flight = { version = "57.0.0", features = [ "flight-sql-experimental", ] } -arrow-ipc = { version = "56.2.0", default-features = false, features = [ +arrow-ipc = { version = "57.0.0", default-features = false, features = [ "lz4", ] } -arrow-ord = { version = "56.2.0", default-features = false } -arrow-schema = { version = "56.2.0", default-features = false } +arrow-ord = { version = "57.0.0", default-features = false } +arrow-schema = { version = "57.0.0", default-features = false } async-trait = "0.1.89" bigdecimal = "0.4.8" bytes = "1.10" @@ -110,73 +111,75 @@ chrono = { version = "0.4.42", default-features = false } criterion = "0.5.1" ctor = "0.4.3" dashmap = "6.0.1" -datafusion = { path = "datafusion/core", version = "50.2.0", default-features = false } -datafusion-catalog = { path = "datafusion/catalog", version = "50.2.0" } 
-datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "50.2.0" } -datafusion-common = { path = "datafusion/common", version = "50.2.0", default-features = false } -datafusion-common-runtime = { path = "datafusion/common-runtime", version = "50.2.0" } -datafusion-datasource = { path = "datafusion/datasource", version = "50.2.0", default-features = false } -datafusion-datasource-avro = { path = "datafusion/datasource-avro", version = "50.2.0", default-features = false } -datafusion-datasource-csv = { path = "datafusion/datasource-csv", version = "50.2.0", default-features = false } -datafusion-datasource-json = { path = "datafusion/datasource-json", version = "50.2.0", default-features = false } -datafusion-datasource-parquet = { path = "datafusion/datasource-parquet", version = "50.2.0", default-features = false } -datafusion-doc = { path = "datafusion/doc", version = "50.2.0" } -datafusion-execution = { path = "datafusion/execution", version = "50.2.0", default-features = false } -datafusion-expr = { path = "datafusion/expr", version = "50.2.0", default-features = false } -datafusion-expr-common = { path = "datafusion/expr-common", version = "50.2.0" } -datafusion-ffi = { path = "datafusion/ffi", version = "50.2.0" } -datafusion-functions = { path = "datafusion/functions", version = "50.2.0" } -datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "50.2.0" } -datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "50.2.0" } -datafusion-functions-nested = { path = "datafusion/functions-nested", version = "50.2.0", default-features = false } -datafusion-functions-table = { path = "datafusion/functions-table", version = "50.2.0" } -datafusion-functions-window = { path = "datafusion/functions-window", version = "50.2.0" } -datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "50.2.0" } -datafusion-macros = { path = "datafusion/macros", version = "50.2.0" } -datafusion-optimizer = { path = "datafusion/optimizer", version = "50.2.0", default-features = false } -datafusion-physical-expr = { path = "datafusion/physical-expr", version = "50.2.0", default-features = false } -datafusion-physical-expr-adapter = { path = "datafusion/physical-expr-adapter", version = "50.2.0", default-features = false } -datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "50.2.0", default-features = false } -datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "50.2.0" } -datafusion-physical-plan = { path = "datafusion/physical-plan", version = "50.2.0" } -datafusion-proto = { path = "datafusion/proto", version = "50.2.0" } -datafusion-proto-common = { path = "datafusion/proto-common", version = "50.2.0" } -datafusion-pruning = { path = "datafusion/pruning", version = "50.2.0" } -datafusion-session = { path = "datafusion/session", version = "50.2.0" } -datafusion-spark = { path = "datafusion/spark", version = "50.2.0" } -datafusion-sql = { path = "datafusion/sql", version = "50.2.0" } -datafusion-substrait = { path = "datafusion/substrait", version = "50.2.0" } +datafusion = { path = "datafusion/core", version = "50.3.0", default-features = false } +datafusion-catalog = { path = "datafusion/catalog", version = "50.3.0" } +datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "50.3.0" } +datafusion-common = { path = "datafusion/common", version = "50.3.0", default-features = false } 
+datafusion-common-runtime = { path = "datafusion/common-runtime", version = "50.3.0" } +datafusion-datasource = { path = "datafusion/datasource", version = "50.3.0", default-features = false } +datafusion-datasource-arrow = { path = "datafusion/datasource-arrow", version = "50.3.0", default-features = false } +datafusion-datasource-avro = { path = "datafusion/datasource-avro", version = "50.3.0", default-features = false } +datafusion-datasource-csv = { path = "datafusion/datasource-csv", version = "50.3.0", default-features = false } +datafusion-datasource-json = { path = "datafusion/datasource-json", version = "50.3.0", default-features = false } +datafusion-datasource-parquet = { path = "datafusion/datasource-parquet", version = "50.3.0", default-features = false } +datafusion-doc = { path = "datafusion/doc", version = "50.3.0" } +datafusion-execution = { path = "datafusion/execution", version = "50.3.0", default-features = false } +datafusion-expr = { path = "datafusion/expr", version = "50.3.0", default-features = false } +datafusion-expr-common = { path = "datafusion/expr-common", version = "50.3.0" } +datafusion-ffi = { path = "datafusion/ffi", version = "50.3.0" } +datafusion-functions = { path = "datafusion/functions", version = "50.3.0" } +datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "50.3.0" } +datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "50.3.0" } +datafusion-functions-nested = { path = "datafusion/functions-nested", version = "50.3.0", default-features = false } +datafusion-functions-table = { path = "datafusion/functions-table", version = "50.3.0" } +datafusion-functions-window = { path = "datafusion/functions-window", version = "50.3.0" } +datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "50.3.0" } +datafusion-macros = { path = "datafusion/macros", version = "50.3.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "50.3.0", default-features = false } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "50.3.0", default-features = false } +datafusion-physical-expr-adapter = { path = "datafusion/physical-expr-adapter", version = "50.3.0", default-features = false } +datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "50.3.0", default-features = false } +datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "50.3.0" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "50.3.0" } +datafusion-proto = { path = "datafusion/proto", version = "50.3.0" } +datafusion-proto-common = { path = "datafusion/proto-common", version = "50.3.0" } +datafusion-pruning = { path = "datafusion/pruning", version = "50.3.0" } +datafusion-session = { path = "datafusion/session", version = "50.3.0" } +datafusion-spark = { path = "datafusion/spark", version = "50.3.0" } +datafusion-sql = { path = "datafusion/sql", version = "50.3.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "50.3.0" } + doc-comment = "0.3" env_logger = "0.11" futures = "0.3" half = { version = "2.7.0", default-features = false } hashbrown = { version = "0.14.5", features = ["raw"] } hex = { version = "0.4.3" } -indexmap = "2.11.4" +indexmap = "2.12.0" +insta = { version = "1.43.2", features = ["glob", "filters"] } itertools = "0.14" log = "^0.4" object_store = { version = "0.12.4", default-features = false } parking_lot = "0.12" -parquet 
= { version = "56.2.0", default-features = false, features = [ +parquet = { version = "57.0.0", default-features = false, features = [ "arrow", "async", "object_store", ] } -pbjson = { version = "0.7.0" } -pbjson-types = "0.7" +pbjson = { version = "0.8.0" } +pbjson-types = "0.8" # Should match arrow-flight's version of prost. -insta = { version = "1.43.2", features = ["glob", "filters"] } -prost = "0.13.1" +prost = "0.14.1" rand = "0.9" recursive = "0.1.1" -regex = "1.11" +regex = "1.12" rstest = "0.25.0" serde_json = "1" sqlparser = { version = "0.59.0", default-features = false, features = ["std", "visitor"] } tempfile = "3" testcontainers = { version = "0.24", features = ["default"] } testcontainers-modules = { version = "0.12" } -tokio = { version = "1.47", features = ["macros", "rt", "sync"] } +tokio = { version = "1.48", features = ["macros", "rt", "sync"] } url = "2.5.7" [workspace.lints.clippy] diff --git a/README.md b/README.md index 4c4b955176b2..5191496eaafe 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ See [use cases] for examples. The following related subprojects target end users DataFusion. "Out of the box," -DataFusion offers [SQL] and [`Dataframe`] APIs, excellent [performance], +DataFusion offers [SQL](https://datafusion.apache.org/user-guide/sql/index.html) and [Dataframe](https://datafusion.apache.org/user-guide/dataframe.html) APIs, excellent [performance], built-in support for CSV, Parquet, JSON, and Avro, extensive customization, and a great community. diff --git a/benchmarks/src/imdb/run.rs b/benchmarks/src/imdb/run.rs index 3d58d5f54d4b..11bd424ba686 100644 --- a/benchmarks/src/imdb/run.rs +++ b/benchmarks/src/imdb/run.rs @@ -534,7 +534,7 @@ mod tests { let plan = ctx.sql(&query).await?; let plan = plan.into_optimized_plan()?; let bytes = logical_plan_to_bytes(&plan)?; - let plan2 = logical_plan_from_bytes(&bytes, &ctx)?; + let plan2 = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; let plan_formatted = format!("{}", plan.display_indent()); let plan2_formatted = format!("{}", plan2.display_indent()); assert_eq!(plan_formatted, plan2_formatted); diff --git a/benchmarks/src/nlj.rs b/benchmarks/src/nlj.rs index e412c0ade8a8..7d1e14f69439 100644 --- a/benchmarks/src/nlj.rs +++ b/benchmarks/src/nlj.rs @@ -146,6 +146,45 @@ const NLJ_QUERIES: &[&str] = &[ FULL JOIN range(30000) AS t2 ON (t1.value > t2.value); "#, + // Q13: LEFT SEMI 30K x 30K | HIGH 99.9% + r#" + SELECT t1.* + FROM range(30000) AS t1 + LEFT SEMI JOIN range(30000) AS t2 + ON t1.value < t2.value; + "#, + // Q14: LEFT ANTI 30K x 30K | LOW 0.003% + r#" + SELECT t1.* + FROM range(30000) AS t1 + LEFT ANTI JOIN range(30000) AS t2 + ON t1.value < t2.value; + "#, + // Q15: RIGHT SEMI 30K x 30K | HIGH 99.9% + r#" + SELECT t1.* + FROM range(30000) AS t2 + RIGHT SEMI JOIN range(30000) AS t1 + ON t2.value < t1.value; + "#, + // Q16: RIGHT ANTI 30K x 30K | LOW 0.003% + r#" + SELECT t1.* + FROM range(30000) AS t2 + RIGHT ANTI JOIN range(30000) AS t1 + ON t2.value < t1.value; + "#, + // Q17: LEFT MARK | HIGH 99.9% + r#" + SELECT * + FROM range(30000) AS t2(k2) + WHERE k2 > 0 + OR EXISTS ( + SELECT 1 + FROM range(30000) AS t1(k1) + WHERE t2.k2 > t1.k1 + ); + "#, ]; impl RunOpt { diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs index b93bdf254a27..cc59b7803036 100644 --- a/benchmarks/src/tpch/run.rs +++ b/benchmarks/src/tpch/run.rs @@ -92,6 +92,15 @@ pub struct RunOpt { #[structopt(short = "j", long = "prefer_hash_join", default_value = "true")] prefer_hash_join: BoolDefaultTrue, + /// If 
true, Piecewise Merge Join can be used; otherwise Nested Loop Join is used. + /// False by default. + #[structopt( + long = "enable_piecewise_merge_join", + default_value = "false" + )] + enable_piecewise_merge_join: BoolDefaultTrue, + /// Mark the first column of each table as sorted in ascending order. /// The tables should have been created with the `--sort` option for this to have any effect. #[structopt(short = "t", long = "sorted")] @@ -112,6 +121,8 @@ impl RunOpt { .config()? .with_collect_statistics(!self.disable_statistics); config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; + config.options_mut().optimizer.enable_piecewise_merge_join = + self.enable_piecewise_merge_join; let rt_builder = self.common.runtime_env_builder()?; let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); // register tables @@ -379,6 +390,7 @@ mod tests { output_path: None, disable_statistics: false, prefer_hash_join: true, + enable_piecewise_merge_join: false, sorted: false, }; opt.register_tables(&ctx).await?; @@ -387,7 +399,7 @@ mod tests { let plan = ctx.sql(&query).await?; let plan = plan.into_optimized_plan()?; let bytes = logical_plan_to_bytes(&plan)?; - let plan2 = logical_plan_from_bytes(&bytes, &ctx)?; + let plan2 = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; let plan_formatted = format!("{}", plan.display_indent()); let plan2_formatted = format!("{}", plan2.display_indent()); assert_eq!(plan_formatted, plan2_formatted); @@ -416,6 +428,7 @@ mod tests { output_path: None, disable_statistics: false, prefer_hash_join: true, + enable_piecewise_merge_join: false, sorted: false, }; opt.register_tables(&ctx).await?; diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index d186cd711945..f3069b492352 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -40,7 +40,7 @@ async-trait = { workspace = true } aws-config = "1.8.7" aws-credential-types = "1.2.7" chrono = { workspace = true } -clap = { version = "4.5.47", features = ["derive", "cargo"] } +clap = { version = "4.5.50", features = ["cargo", "derive"] } datafusion = { workspace = true, features = [ "avro", "compression", @@ -55,6 +55,7 @@ datafusion = { workspace = true, features = [ "sql", "unicode_expressions", ] } +datafusion-common = { workspace = true } dirs = "6.0.0" env_logger = { workspace = true } futures = { workspace = true } @@ -65,7 +66,7 @@ parking_lot = { workspace = true } parquet = { workspace = true, default-features = false } regex = { workspace = true } rustyline = "17.0" -tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot", "signal"] } +tokio = { workspace = true, features = ["macros", "parking_lot", "rt", "rt-multi-thread", "signal", "sync"] } url = { workspace = true } [dev-dependencies] diff --git a/datafusion-cli/src/command.rs b/datafusion-cli/src/command.rs index 48fb37e8a888..3fbfe5680cfc 100644 --- a/datafusion-cli/src/command.rs +++ b/datafusion-cli/src/command.rs @@ -128,7 +128,7 @@ impl Command { let profile_mode = mode .parse() .map_err(|_| - exec_datafusion_err!("Failed to parse input: {mode}. 
Valid options are disabled, summary, trace") )?; print_options .instrumented_registry @@ -165,7 +165,7 @@ impl Command { ("\\pset [NAME [VALUE]]", "set table output option\n(format)") } Self::ObjectStoreProfileMode(_) => ( - "\\object_store_profiling (disabled|enabled)", + "\\object_store_profiling (disabled|summary|trace)", "print or set object store profile mode", ), } @@ -312,13 +312,22 @@ mod tests { InstrumentedObjectStoreMode::default() ); - cmd = "object_store_profiling enabled" + cmd = "object_store_profiling summary" .parse() .expect("expected parse to succeed"); assert!(cmd.execute(&ctx, &mut print_options).await.is_ok()); assert_eq!( print_options.instrumented_registry.instrument_mode(), - InstrumentedObjectStoreMode::Enabled + InstrumentedObjectStoreMode::Summary + ); + + cmd = "object_store_profiling trace" + .parse() + .expect("expected parse to succeed"); + assert!(cmd.execute(&ctx, &mut print_options).await.is_ok()); + assert_eq!( + print_options.instrumented_registry.instrument_mode(), + InstrumentedObjectStoreMode::Trace ); cmd = "object_store_profiling does_not_exist" diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index 3ec446c51583..d23b12469e38 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -419,7 +419,9 @@ impl TableFunctionImpl for ParquetMetadataFunc { stats_max_value_arr.push(None); }; compression_arr.push(format!("{:?}", column.compression())); - encodings_arr.push(format!("{:?}", column.encodings())); + // need to collect into Vec to format + let encodings: Vec<_> = column.encodings().collect(); + encodings_arr.push(format!("{:?}", encodings)); index_page_offset_arr.push(column.index_page_offset()); dictionary_page_offset_arr.push(column.dictionary_page_offset()); data_page_offset_arr.push(column.data_page_offset()); diff --git a/datafusion-cli/src/helper.rs b/datafusion-cli/src/helper.rs index 64c34c473736..219637b3460e 100644 --- a/datafusion-cli/src/helper.rs +++ b/datafusion-cli/src/helper.rs @@ -24,6 +24,7 @@ use crate::highlighter::{NoSyntaxHighlighter, SyntaxHighlighter}; use datafusion::sql::parser::{DFParser, Statement}; use datafusion::sql::sqlparser::dialect::dialect_from_str; +use datafusion_common::config::Dialect; use rustyline::completion::{Completer, FilenameCompleter, Pair}; use rustyline::error::ReadlineError; @@ -34,12 +35,12 @@ use rustyline::{Context, Helper, Result}; pub struct CliHelper { completer: FilenameCompleter, - dialect: String, + dialect: Dialect, highlighter: Box, } impl CliHelper { - pub fn new(dialect: &str, color: bool) -> Self { + pub fn new(dialect: &Dialect, color: bool) -> Self { let highlighter: Box = if !color { Box::new(NoSyntaxHighlighter {}) } else { @@ -47,20 +48,20 @@ impl CliHelper { }; Self { completer: FilenameCompleter::new(), - dialect: dialect.into(), + dialect: *dialect, highlighter, } } - pub fn set_dialect(&mut self, dialect: &str) { - if dialect != self.dialect { - self.dialect = dialect.to_string(); + pub fn set_dialect(&mut self, dialect: &Dialect) { + if *dialect != self.dialect { + self.dialect = *dialect; } } fn validate_input(&self, input: &str) -> Result { if let Some(sql) = input.strip_suffix(';') { - let dialect = match dialect_from_str(&self.dialect) { + let dialect = match dialect_from_str(self.dialect) { Some(dialect) => dialect, None => { return Ok(ValidationResult::Invalid(Some(format!( @@ -97,7 +98,7 @@ impl CliHelper { impl Default for CliHelper { fn default() -> Self { - Self::new("generic", false) + 
Self::new(&Dialect::Generic, false) } } @@ -289,7 +290,7 @@ mod tests { ); // valid in postgresql dialect - validator.set_dialect("postgresql"); + validator.set_dialect(&Dialect::PostgreSQL); let result = readline_direct(Cursor::new(r"select 1 # 2;".as_bytes()), &validator)?; assert!(matches!(result, ValidationResult::Valid(None))); diff --git a/datafusion-cli/src/highlighter.rs b/datafusion-cli/src/highlighter.rs index 7a886b94740b..f4e57a2e3593 100644 --- a/datafusion-cli/src/highlighter.rs +++ b/datafusion-cli/src/highlighter.rs @@ -27,6 +27,7 @@ use datafusion::sql::sqlparser::{ keywords::Keyword, tokenizer::{Token, Tokenizer}, }; +use datafusion_common::config; use rustyline::highlight::{CmdKind, Highlighter}; /// The syntax highlighter. @@ -36,7 +37,7 @@ pub struct SyntaxHighlighter { } impl SyntaxHighlighter { - pub fn new(dialect: &str) -> Self { + pub fn new(dialect: &config::Dialect) -> Self { let dialect = dialect_from_str(dialect).unwrap_or(Box::new(GenericDialect {})); Self { dialect } } @@ -93,13 +94,14 @@ impl Color { #[cfg(test)] mod tests { + use super::config::Dialect; use super::SyntaxHighlighter; use rustyline::highlight::Highlighter; #[test] fn highlighter_valid() { let s = "SElect col_a from tab_1;"; - let highlighter = SyntaxHighlighter::new("generic"); + let highlighter = SyntaxHighlighter::new(&Dialect::Generic); let out = highlighter.highlight(s, s.len()); assert_eq!( "\u{1b}[91mSElect\u{1b}[0m col_a \u{1b}[91mfrom\u{1b}[0m tab_1;", @@ -110,7 +112,7 @@ mod tests { #[test] fn highlighter_valid_with_new_line() { let s = "SElect col_a from tab_1\n WHERE col_b = 'なにか';"; - let highlighter = SyntaxHighlighter::new("generic"); + let highlighter = SyntaxHighlighter::new(&Dialect::Generic); let out = highlighter.highlight(s, s.len()); assert_eq!( "\u{1b}[91mSElect\u{1b}[0m col_a \u{1b}[91mfrom\u{1b}[0m tab_1\n \u{1b}[91mWHERE\u{1b}[0m col_b = \u{1b}[92m'なにか'\u{1b}[0m;", @@ -121,7 +123,7 @@ mod tests { #[test] fn highlighter_invalid() { let s = "SElect col_a from tab_1 WHERE col_b = ';"; - let highlighter = SyntaxHighlighter::new("generic"); + let highlighter = SyntaxHighlighter::new(&Dialect::Generic); let out = highlighter.highlight(s, s.len()); assert_eq!("SElect col_a from tab_1 WHERE col_b = ';", out); } diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 3dbe839d3c9b..09fa8ef15af8 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -151,7 +151,7 @@ struct Args { #[clap( long, - help = "Specify the default object_store_profiling mode, defaults to 'disabled'.\n[possible values: disabled, enabled]", + help = "Specify the default object_store_profiling mode, defaults to 'disabled'.\n[possible values: disabled, summary, trace]", default_value_t = InstrumentedObjectStoreMode::Disabled )] object_store_profiling: InstrumentedObjectStoreMode, @@ -497,7 +497,7 @@ mod tests { +-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+ | filename | row_group_id | row_group_num_rows | row_group_num_columns | row_group_bytes | column_id | file_offset | num_values | path_in_schema | type | stats_min | stats_max | stats_null_count | 
stats_distinct_count | stats_min_value | stats_max_value | compression | encodings | index_page_offset | dictionary_page_offset | data_page_offset | total_compressed_size | total_uncompressed_size | +-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+ - | ../datafusion/core/tests/data/fixed_size_list_array.parquet | 0 | 2 | 1 | 123 | 0 | 125 | 4 | "f0.list.item" | INT64 | 1 | 4 | 0 | | 1 | 4 | SNAPPY | [RLE_DICTIONARY, PLAIN, RLE] | | 4 | 46 | 121 | 123 | + | ../datafusion/core/tests/data/fixed_size_list_array.parquet | 0 | 2 | 1 | 123 | 0 | 125 | 4 | "f0.list.item" | INT64 | 1 | 4 | 0 | | 1 | 4 | SNAPPY | [PLAIN, RLE, RLE_DICTIONARY] | | 4 | 46 | 121 | 123 | +-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+ "#); @@ -510,7 +510,7 @@ mod tests { +-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+ | filename | row_group_id | row_group_num_rows | row_group_num_columns | row_group_bytes | column_id | file_offset | num_values | path_in_schema | type | stats_min | stats_max | stats_null_count | stats_distinct_count | stats_min_value | stats_max_value | compression | encodings | index_page_offset | dictionary_page_offset | data_page_offset | total_compressed_size | total_uncompressed_size | +-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+ - | ../datafusion/core/tests/data/fixed_size_list_array.parquet | 0 | 2 | 1 | 123 | 0 | 125 | 4 | "f0.list.item" | INT64 | 1 | 4 | 0 | | 1 | 4 | SNAPPY | [RLE_DICTIONARY, PLAIN, RLE] | | 4 | 46 | 121 | 123 | + | ../datafusion/core/tests/data/fixed_size_list_array.parquet | 0 | 2 | 1 | 123 | 0 | 125 | 4 | "f0.list.item" | INT64 | 1 | 4 | 0 | | 1 | 4 | SNAPPY | [PLAIN, RLE, RLE_DICTIONARY] | | 4 | 46 | 121 | 123 | 
+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+ "#); @@ -532,7 +532,7 @@ mod tests { +-----------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+------------+-----------+-----------+------------------+----------------------+-----------------+-----------------+--------------------+--------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+ | filename | row_group_id | row_group_num_rows | row_group_num_columns | row_group_bytes | column_id | file_offset | num_values | path_in_schema | type | stats_min | stats_max | stats_null_count | stats_distinct_count | stats_min_value | stats_max_value | compression | encodings | index_page_offset | dictionary_page_offset | data_page_offset | total_compressed_size | total_uncompressed_size | +-----------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+------------+-----------+-----------+------------------+----------------------+-----------------+-----------------+--------------------+--------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+ - | ../parquet-testing/data/data_index_bloom_encoding_stats.parquet | 0 | 14 | 1 | 163 | 0 | 4 | 14 | "String" | BYTE_ARRAY | Hello | today | 0 | | Hello | today | GZIP(GzipLevel(6)) | [BIT_PACKED, RLE, PLAIN] | | | 4 | 152 | 163 | + | ../parquet-testing/data/data_index_bloom_encoding_stats.parquet | 0 | 14 | 1 | 163 | 0 | 4 | 14 | "String" | BYTE_ARRAY | Hello | today | 0 | | Hello | today | GZIP(GzipLevel(6)) | [PLAIN, RLE, BIT_PACKED] | | | 4 | 152 | 163 | +-----------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+------------+-----------+-----------+------------------+----------------------+-----------------+-----------------+--------------------+--------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+ "#); @@ -592,9 +592,9 @@ mod tests { +-----------------------------------+-----------------+---------------------+------+------------------+ | filename | file_size_bytes | metadata_size_bytes | hits | extra | +-----------------------------------+-----------------+---------------------+------+------------------+ - | alltypes_plain.parquet | 1851 | 10181 | 2 | page_index=false | - | alltypes_tiny_pages.parquet | 454233 | 881418 | 2 | page_index=true | - | lz4_raw_compressed_larger.parquet | 380836 | 2939 | 2 | page_index=false | + | alltypes_plain.parquet | 1851 | 6957 | 2 | page_index=false | + | alltypes_tiny_pages.parquet | 454233 | 267014 | 2 | page_index=true | + | lz4_raw_compressed_larger.parquet | 380836 | 996 | 2 | page_index=false | 
+-----------------------------------+-----------------+---------------------+------+------------------+ "); @@ -623,9 +623,9 @@ mod tests { +-----------------------------------+-----------------+---------------------+------+------------------+ | filename | file_size_bytes | metadata_size_bytes | hits | extra | +-----------------------------------+-----------------+---------------------+------+------------------+ - | alltypes_plain.parquet | 1851 | 10181 | 5 | page_index=false | - | alltypes_tiny_pages.parquet | 454233 | 881418 | 2 | page_index=true | - | lz4_raw_compressed_larger.parquet | 380836 | 2939 | 3 | page_index=false | + | alltypes_plain.parquet | 1851 | 6957 | 5 | page_index=false | + | alltypes_tiny_pages.parquet | 454233 | 267014 | 2 | page_index=true | + | lz4_raw_compressed_larger.parquet | 380836 | 996 | 3 | page_index=false | +-----------------------------------+-----------------+---------------------+------+------------------+ "); diff --git a/datafusion-cli/src/object_storage/instrumented.rs b/datafusion-cli/src/object_storage/instrumented.rs index 9252e0688c35..c4b63b417fe4 100644 --- a/datafusion-cli/src/object_storage/instrumented.rs +++ b/datafusion-cli/src/object_storage/instrumented.rs @@ -26,6 +26,8 @@ use std::{ time::Duration, }; +use arrow::array::{ArrayRef, RecordBatch, StringArray}; +use arrow::util::pretty::pretty_format_batches; use async_trait::async_trait; use chrono::Utc; use datafusion::{ @@ -48,13 +50,15 @@ pub enum InstrumentedObjectStoreMode { /// Disable collection of profiling data #[default] Disabled, - /// Enable collection of profiling data - Enabled, + /// Enable collection of profiling data and output a summary + Summary, + /// Enable collection of profiling data and output a summary and all details + Trace, } impl fmt::Display for InstrumentedObjectStoreMode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) + write!(f, "{self:?}") } } @@ -64,7 +68,8 @@ impl FromStr for InstrumentedObjectStoreMode { fn from_str(s: &str) -> std::result::Result { match s.to_lowercase().as_str() { "disabled" => Ok(Self::Disabled), - "enabled" => Ok(Self::Enabled), + "summary" => Ok(Self::Summary), + "trace" => Ok(Self::Trace), _ => Err(DataFusionError::Execution(format!("Unrecognized mode {s}"))), } } @@ -73,7 +78,8 @@ impl FromStr for InstrumentedObjectStoreMode { impl From for InstrumentedObjectStoreMode { fn from(value: u8) -> Self { match value { - 1 => InstrumentedObjectStoreMode::Enabled, + 1 => InstrumentedObjectStoreMode::Summary, + 2 => InstrumentedObjectStoreMode::Trace, _ => InstrumentedObjectStoreMode::Disabled, } } @@ -110,6 +116,59 @@ impl InstrumentedObjectStore { req.drain(..).collect() } + fn enabled(&self) -> bool { + self.instrument_mode.load(Ordering::Relaxed) + != InstrumentedObjectStoreMode::Disabled as u8 + } + + async fn instrumented_put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + let timestamp = Utc::now(); + let start = Instant::now(); + let size = payload.content_length(); + let ret = self.inner.put_opts(location, payload, opts).await?; + let elapsed = start.elapsed(); + + self.requests.lock().push(RequestDetails { + op: Operation::Put, + path: location.clone(), + timestamp, + duration: Some(elapsed), + size: Some(size), + range: None, + extra_display: None, + }); + + Ok(ret) + } + + async fn instrumented_put_multipart( + &self, + location: &Path, + opts: PutMultipartOptions, + ) -> Result> { + let timestamp = Utc::now(); + let start = 
Instant::now(); + let ret = self.inner.put_multipart_opts(location, opts).await?; + let elapsed = start.elapsed(); + + self.requests.lock().push(RequestDetails { + op: Operation::Put, + path: location.clone(), + timestamp, + duration: Some(elapsed), + size: None, + range: None, + extra_display: None, + }); + + Ok(ret) + } + async fn instrumented_get_opts( &self, location: &Path, @@ -134,6 +193,128 @@ impl InstrumentedObjectStore { Ok(ret) } + + async fn instrumented_delete(&self, location: &Path) -> Result<()> { + let timestamp = Utc::now(); + let start = Instant::now(); + self.inner.delete(location).await?; + let elapsed = start.elapsed(); + + self.requests.lock().push(RequestDetails { + op: Operation::Delete, + path: location.clone(), + timestamp, + duration: Some(elapsed), + size: None, + range: None, + extra_display: None, + }); + + Ok(()) + } + + fn instrumented_list( + &self, + prefix: Option<&Path>, + ) -> BoxStream<'static, Result> { + let timestamp = Utc::now(); + let ret = self.inner.list(prefix); + + self.requests.lock().push(RequestDetails { + op: Operation::List, + path: prefix.cloned().unwrap_or_else(|| Path::from("")), + timestamp, + duration: None, // list returns a stream, so the duration isn't meaningful + size: None, + range: None, + extra_display: None, + }); + + ret + } + + async fn instrumented_list_with_delimiter( + &self, + prefix: Option<&Path>, + ) -> Result { + let timestamp = Utc::now(); + let start = Instant::now(); + let ret = self.inner.list_with_delimiter(prefix).await?; + let elapsed = start.elapsed(); + + self.requests.lock().push(RequestDetails { + op: Operation::List, + path: prefix.cloned().unwrap_or_else(|| Path::from("")), + timestamp, + duration: Some(elapsed), + size: None, + range: None, + extra_display: None, + }); + + Ok(ret) + } + + async fn instrumented_copy(&self, from: &Path, to: &Path) -> Result<()> { + let timestamp = Utc::now(); + let start = Instant::now(); + self.inner.copy(from, to).await?; + let elapsed = start.elapsed(); + + self.requests.lock().push(RequestDetails { + op: Operation::Copy, + path: from.clone(), + timestamp, + duration: Some(elapsed), + size: None, + range: None, + extra_display: Some(format!("copy_to: {to}")), + }); + + Ok(()) + } + + async fn instrumented_copy_if_not_exists( + &self, + from: &Path, + to: &Path, + ) -> Result<()> { + let timestamp = Utc::now(); + let start = Instant::now(); + self.inner.copy_if_not_exists(from, to).await?; + let elapsed = start.elapsed(); + + self.requests.lock().push(RequestDetails { + op: Operation::Copy, + path: from.clone(), + timestamp, + duration: Some(elapsed), + size: None, + range: None, + extra_display: Some(format!("copy_to: {to}")), + }); + + Ok(()) + } + + async fn instrumented_head(&self, location: &Path) -> Result { + let timestamp = Utc::now(); + let start = Instant::now(); + let ret = self.inner.head(location).await?; + let elapsed = start.elapsed(); + + self.requests.lock().push(RequestDetails { + op: Operation::Head, + path: location.clone(), + timestamp, + duration: Some(elapsed), + size: None, + range: None, + extra_display: None, + }); + + Ok(ret) + } } impl fmt::Display for InstrumentedObjectStore { @@ -156,6 +337,10 @@ impl ObjectStore for InstrumentedObjectStore { payload: PutPayload, opts: PutOptions, ) -> Result { + if self.enabled() { + return self.instrumented_put_opts(location, payload, opts).await; + } + self.inner.put_opts(location, payload, opts).await } @@ -164,13 +349,15 @@ impl ObjectStore for InstrumentedObjectStore { location: &Path, opts: 
PutMultipartOptions, ) -> Result> { + if self.enabled() { + return self.instrumented_put_multipart(location, opts).await; + } + self.inner.put_multipart_opts(location, opts).await } async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { - if self.instrument_mode.load(Ordering::Relaxed) - != InstrumentedObjectStoreMode::Disabled as u8 - { + if self.enabled() { return self.instrumented_get_opts(location, options).await; } @@ -178,39 +365,69 @@ impl ObjectStore for InstrumentedObjectStore { } async fn delete(&self, location: &Path) -> Result<()> { + if self.enabled() { + return self.instrumented_delete(location).await; + } + self.inner.delete(location).await } fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result> { + if self.enabled() { + return self.instrumented_list(prefix); + } + self.inner.list(prefix) } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + if self.enabled() { + return self.instrumented_list_with_delimiter(prefix).await; + } + self.inner.list_with_delimiter(prefix).await } async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + if self.enabled() { + return self.instrumented_copy(from, to).await; + } + self.inner.copy(from, to).await } async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + if self.enabled() { + return self.instrumented_copy_if_not_exists(from, to).await; + } + self.inner.copy_if_not_exists(from, to).await } async fn head(&self, location: &Path) -> Result { + if self.enabled() { + return self.instrumented_head(location).await; + } + self.inner.head(location).await } } /// Object store operation types tracked by [`InstrumentedObjectStore`] -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum Operation { - _Copy, - _Delete, + Copy, + Delete, Get, - _Head, - _List, - _Put, + Head, + List, + Put, +} + +impl fmt::Display for Operation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self:?}") + } } /// Holds profiling details about individual requests made through an [`InstrumentedObjectStore`] @@ -252,35 +469,172 @@ impl fmt::Display for RequestDetails { } } -/// Summary statistics for an [`InstrumentedObjectStore`]'s [`RequestDetails`] +/// Summary statistics for all requests recorded in an [`InstrumentedObjectStore`] #[derive(Default)] -pub struct RequestSummary { - count: usize, - duration_stats: Option>, - size_stats: Option>, +pub struct RequestSummaries { + summaries: Vec, } -impl RequestSummary { - /// Generates a set of [RequestSummaries](RequestSummary) from the input [`RequestDetails`] - /// grouped by the input's [`Operation`] - pub fn summarize_by_operation( - requests: &[RequestDetails], - ) -> HashMap { - let mut summaries: HashMap = HashMap::new(); +/// Display the summary as a table +impl fmt::Display for RequestSummaries { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // Don't expect an error, but avoid panicking if it happens + match pretty_format_batches(&[self.to_batch()]) { + Err(e) => { + write!(f, "Error formatting summary: {e}") + } + Ok(displayable) => { + write!(f, "{displayable}") + } + } + } +} + +impl RequestSummaries { + /// Summarizes input [`RequestDetails`] + pub fn new(requests: &[RequestDetails]) -> Self { + let mut summaries: HashMap = HashMap::new(); for rd in requests { match summaries.get_mut(&rd.op) { Some(rs) => rs.push(rd), None => { - let mut rs = RequestSummary::default(); + let mut rs = 
RequestSummary::new(rd.op); rs.push(rd); summaries.insert(rd.op, rs); } } } + // Convert to a Vec with consistent ordering + let mut summaries: Vec = summaries.into_values().collect(); + summaries.sort_by_key(|s| s.operation); + Self { summaries } + } + + /// Convert the summaries into a `RecordBatch` for display + /// + /// Results in a table like: + /// ```text + /// +-----------+----------+-----------+-----------+-----------+-----------+-----------+ + /// | Operation | Metric | min | max | avg | sum | count | + /// +-----------+----------+-----------+-----------+-----------+-----------+-----------+ + /// | Get | duration | 5.000000s | 5.000000s | 5.000000s | | 1 | + /// | Get | size | 100 B | 100 B | 100 B | 100 B | 1 | + /// +-----------+----------+-----------+-----------+-----------+-----------+-----------+ + /// ``` + pub fn to_batch(&self) -> RecordBatch { + let operations: StringArray = self + .iter() + .flat_map(|s| std::iter::repeat_n(Some(s.operation.to_string()), 2)) + .collect(); + let metrics: StringArray = self + .iter() + .flat_map(|_s| [Some("duration"), Some("size")]) + .collect(); + let mins: StringArray = self + .stats_iter() + .flat_map(|(duration_stats, size_stats)| { + let dur_min = + duration_stats.map(|d| format!("{:.6}s", d.min.as_secs_f32())); + let size_min = size_stats.map(|s| format!("{} B", s.min)); + [dur_min, size_min] + }) + .collect(); + let maxs: StringArray = self + .stats_iter() + .flat_map(|(duration_stats, size_stats)| { + let dur_max = + duration_stats.map(|d| format!("{:.6}s", d.max.as_secs_f32())); + let size_max = size_stats.map(|s| format!("{} B", s.max)); + [dur_max, size_max] + }) + .collect(); + let avgs: StringArray = self + .iter() + .flat_map(|s| { + let count = s.count as f32; + let duration_stats = s.duration_stats.as_ref(); + let size_stats = s.size_stats.as_ref(); + let dur_avg = duration_stats.map(|d| { + let avg = d.sum.as_secs_f32() / count; + format!("{avg:.6}s") + }); + let size_avg = size_stats.map(|s| { + let avg = s.sum as f32 / count; + format!("{avg} B") + }); + [dur_avg, size_avg] + }) + .collect(); + let sums: StringArray = self + .stats_iter() + .flat_map(|(duration_stats, size_stats)| { + // Omit a sum stat for duration in the initial + // implementation because it can be a bit misleading (at least + // at first glance). For example, particularly large queries the + // sum of the durations was often larger than the total time of + // the query itself, can be confusing without additional + // explanation (e.g. that the sum is of individual requests, + // which may be concurrent). 
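For reference, the summary table that `to_batch` produces is an ordinary Arrow `RecordBatch` of string columns rendered through `arrow::util::pretty::pretty_format_batches`. A minimal standalone sketch of that pattern (illustrative column values only, not the PR's code):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, RecordBatch, StringArray};
use arrow::util::pretty::pretty_format_batches;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Two string columns, one row per (operation, metric) pair.
    let operations: StringArray = [Some("Get"), Some("Get")].into_iter().collect();
    let metrics: StringArray = [Some("duration"), Some("size")].into_iter().collect();

    let batch = RecordBatch::try_from_iter(vec![
        ("Operation", Arc::new(operations) as ArrayRef),
        ("Metric", Arc::new(metrics) as ArrayRef),
    ])?;

    // pretty_format_batches returns a Display-able ASCII table,
    // like the example shown in the doc comment above.
    println!("{}", pretty_format_batches(&[batch])?);
    Ok(())
}
```

The `to_batch` implementation here fills the same kind of string columns, emitting one duration row and one size row per summarized operation.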
+ let dur_sum = + duration_stats.map(|d| format!("{:.6}s", d.sum.as_secs_f32())); + let size_sum = size_stats.map(|s| format!("{} B", s.sum)); + [dur_sum, size_sum] + }) + .collect(); + let counts: StringArray = self + .iter() + .flat_map(|s| { + let count = s.count.to_string(); + [Some(count.clone()), Some(count)] + }) + .collect(); + + RecordBatch::try_from_iter(vec![ + ("Operation", Arc::new(operations) as ArrayRef), + ("Metric", Arc::new(metrics) as ArrayRef), + ("min", Arc::new(mins) as ArrayRef), + ("max", Arc::new(maxs) as ArrayRef), + ("avg", Arc::new(avgs) as ArrayRef), + ("sum", Arc::new(sums) as ArrayRef), + ("count", Arc::new(counts) as ArrayRef), + ]) + .expect("Created the batch correctly") + } - summaries + /// Return an iterator over the summaries + fn iter(&self) -> impl Iterator { + self.summaries.iter() + } + + /// Return an iterator over (duration_stats, size_stats) tuples + /// for each summary + fn stats_iter( + &self, + ) -> impl Iterator>, Option<&Stats>)> { + self.summaries + .iter() + .map(|s| (s.duration_stats.as_ref(), s.size_stats.as_ref())) } +} +/// Summary statistics for a particular type of [`Operation`] (e.g. `GET` or `PUT`) +/// in an [`InstrumentedObjectStore`]'s [`RequestDetails`] +pub struct RequestSummary { + operation: Operation, + count: usize, + duration_stats: Option>, + size_stats: Option>, +} + +impl RequestSummary { + fn new(operation: Operation) -> Self { + Self { + operation, + count: 0, + duration_stats: None, + size_stats: None, + } + } fn push(&mut self, request: &RequestDetails) { self.count += 1; if let Some(dur) = request.duration { @@ -292,29 +646,6 @@ impl RequestSummary { } } -impl fmt::Display for RequestSummary { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - writeln!(f, "count: {}", self.count)?; - - if let Some(dur_stats) = &self.duration_stats { - writeln!(f, "duration min: {:.6}s", dur_stats.min.as_secs_f32())?; - writeln!(f, "duration max: {:.6}s", dur_stats.max.as_secs_f32())?; - let avg = dur_stats.sum.as_secs_f32() / (self.count as f32); - writeln!(f, "duration avg: {:.6}s", avg)?; - } - - if let Some(size_stats) = &self.size_stats { - writeln!(f, "size min: {} B", size_stats.min)?; - writeln!(f, "size max: {} B", size_stats.max)?; - let avg = size_stats.sum / self.count; - writeln!(f, "size avg: {} B", avg)?; - writeln!(f, "size sum: {} B", size_stats.sum)?; - } - - Ok(()) - } -} - struct Stats> { min: T, max: T, @@ -413,6 +744,13 @@ impl ObjectStoreRegistry for InstrumentedObjectStoreRegistry { self.inner.register_store(url, instrumented) } + fn deregister_store( + &self, + url: &Url, + ) -> datafusion::common::Result> { + self.inner.deregister_store(url) + } + fn get_store(&self, url: &Url) -> datafusion::common::Result> { self.inner.get_store(url) } @@ -420,7 +758,10 @@ impl ObjectStoreRegistry for InstrumentedObjectStoreRegistry { #[cfg(test)] mod tests { + use object_store::WriteMultipart; + use super::*; + use insta::assert_snapshot; #[test] fn instrumented_mode() { @@ -434,16 +775,21 @@ mod tests { InstrumentedObjectStoreMode::Disabled )); assert!(matches!( - "EnABlEd".parse().unwrap(), - InstrumentedObjectStoreMode::Enabled + "SUmMaRy".parse().unwrap(), + InstrumentedObjectStoreMode::Summary + )); + assert!(matches!( + "TRaCe".parse().unwrap(), + InstrumentedObjectStoreMode::Trace )); assert!("does_not_exist" .parse::() .is_err()); assert!(matches!(0.into(), InstrumentedObjectStoreMode::Disabled)); - assert!(matches!(1.into(), InstrumentedObjectStoreMode::Enabled)); - 
assert!(matches!(2.into(), InstrumentedObjectStoreMode::Disabled)); + assert!(matches!(1.into(), InstrumentedObjectStoreMode::Summary)); + assert!(matches!(2.into(), InstrumentedObjectStoreMode::Trace)); + assert!(matches!(3.into(), InstrumentedObjectStoreMode::Disabled)); } #[test] @@ -455,8 +801,8 @@ mod tests { InstrumentedObjectStoreMode::default() ); - reg = reg.with_profile_mode(InstrumentedObjectStoreMode::Enabled); - assert_eq!(reg.instrument_mode(), InstrumentedObjectStoreMode::Enabled); + reg = reg.with_profile_mode(InstrumentedObjectStoreMode::Trace); + assert_eq!(reg.instrument_mode(), InstrumentedObjectStoreMode::Trace); let store = object_store::memory::InMemory::new(); let url = "mem://test".parse().unwrap(); @@ -468,8 +814,9 @@ mod tests { assert_eq!(reg.stores().len(), 1); } - #[tokio::test] - async fn instrumented_store() { + // Returns an `InstrumentedObjectStore` with some data loaded for testing and the path to + // access the data + async fn setup_test_store() -> (InstrumentedObjectStore, Path) { let store = Arc::new(object_store::memory::InMemory::new()); let mode = AtomicU8::new(InstrumentedObjectStoreMode::default() as u8); let instrumented = InstrumentedObjectStore::new(store, mode); @@ -479,12 +826,19 @@ mod tests { let payload = PutPayload::from_static(b"test_data"); instrumented.put(&path, payload).await.unwrap(); + (instrumented, path) + } + + #[tokio::test] + async fn instrumented_store_get() { + let (instrumented, path) = setup_test_store().await; + // By default no requests should be instrumented/stored assert!(instrumented.requests.lock().is_empty()); let _ = instrumented.get(&path).await.unwrap(); assert!(instrumented.requests.lock().is_empty()); - instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Enabled); + instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace); assert!(instrumented.requests.lock().is_empty()); let _ = instrumented.get(&path).await.unwrap(); assert_eq!(instrumented.requests.lock().len(), 1); @@ -502,6 +856,244 @@ mod tests { assert!(request.extra_display.is_none()); } + #[tokio::test] + async fn instrumented_store_delete() { + let (instrumented, path) = setup_test_store().await; + + // By default no requests should be instrumented/stored + assert!(instrumented.requests.lock().is_empty()); + instrumented.delete(&path).await.unwrap(); + assert!(instrumented.requests.lock().is_empty()); + + // We need a new store so we have data to delete again + let (instrumented, path) = setup_test_store().await; + instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace); + assert!(instrumented.requests.lock().is_empty()); + instrumented.delete(&path).await.unwrap(); + assert_eq!(instrumented.requests.lock().len(), 1); + + let mut requests = instrumented.take_requests(); + assert_eq!(requests.len(), 1); + assert!(instrumented.requests.lock().is_empty()); + + let request = requests.pop().unwrap(); + assert_eq!(request.op, Operation::Delete); + assert_eq!(request.path, path); + assert!(request.duration.is_some()); + assert!(request.size.is_none()); + assert!(request.range.is_none()); + assert!(request.extra_display.is_none()); + } + + #[tokio::test] + async fn instrumented_store_list() { + let (instrumented, path) = setup_test_store().await; + + // By default no requests should be instrumented/stored + assert!(instrumented.requests.lock().is_empty()); + let _ = instrumented.list(Some(&path)); + assert!(instrumented.requests.lock().is_empty()); + + instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace); + 
assert!(instrumented.requests.lock().is_empty()); + let _ = instrumented.list(Some(&path)); + assert_eq!(instrumented.requests.lock().len(), 1); + + let request = instrumented.take_requests().pop().unwrap(); + assert_eq!(request.op, Operation::List); + assert_eq!(request.path, path); + assert!(request.duration.is_none()); + assert!(request.size.is_none()); + assert!(request.range.is_none()); + assert!(request.extra_display.is_none()); + } + + #[tokio::test] + async fn instrumented_store_list_with_delimiter() { + let (instrumented, path) = setup_test_store().await; + + // By default no requests should be instrumented/stored + assert!(instrumented.requests.lock().is_empty()); + let _ = instrumented.list_with_delimiter(Some(&path)).await.unwrap(); + assert!(instrumented.requests.lock().is_empty()); + + instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace); + assert!(instrumented.requests.lock().is_empty()); + let _ = instrumented.list_with_delimiter(Some(&path)).await.unwrap(); + assert_eq!(instrumented.requests.lock().len(), 1); + + let request = instrumented.take_requests().pop().unwrap(); + assert_eq!(request.op, Operation::List); + assert_eq!(request.path, path); + assert!(request.duration.is_some()); + assert!(request.size.is_none()); + assert!(request.range.is_none()); + assert!(request.extra_display.is_none()); + } + + #[tokio::test] + async fn instrumented_store_put_opts() { + // The `setup_test_store()` method comes with data already `put` into it, so we'll setup + // manually for this test + let store = Arc::new(object_store::memory::InMemory::new()); + let mode = AtomicU8::new(InstrumentedObjectStoreMode::default() as u8); + let instrumented = InstrumentedObjectStore::new(store, mode); + + let path = Path::from("test/data"); + let payload = PutPayload::from_static(b"test_data"); + let size = payload.content_length(); + + // By default no requests should be instrumented/stored + assert!(instrumented.requests.lock().is_empty()); + instrumented.put(&path, payload.clone()).await.unwrap(); + assert!(instrumented.requests.lock().is_empty()); + + instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace); + assert!(instrumented.requests.lock().is_empty()); + instrumented.put(&path, payload).await.unwrap(); + assert_eq!(instrumented.requests.lock().len(), 1); + + let request = instrumented.take_requests().pop().unwrap(); + assert_eq!(request.op, Operation::Put); + assert_eq!(request.path, path); + assert!(request.duration.is_some()); + assert_eq!(request.size.unwrap(), size); + assert!(request.range.is_none()); + assert!(request.extra_display.is_none()); + } + + #[tokio::test] + async fn instrumented_store_put_multipart() { + // The `setup_test_store()` method comes with data already `put` into it, so we'll setup + // manually for this test + let store = Arc::new(object_store::memory::InMemory::new()); + let mode = AtomicU8::new(InstrumentedObjectStoreMode::default() as u8); + let instrumented = InstrumentedObjectStore::new(store, mode); + + let path = Path::from("test/data"); + + // By default no requests should be instrumented/stored + assert!(instrumented.requests.lock().is_empty()); + let mp = instrumented.put_multipart(&path).await.unwrap(); + let mut write = WriteMultipart::new(mp); + write.write(b"test_data"); + write.finish().await.unwrap(); + assert!(instrumented.requests.lock().is_empty()); + + instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace); + assert!(instrumented.requests.lock().is_empty()); + let mp = 
instrumented.put_multipart(&path).await.unwrap(); + let mut write = WriteMultipart::new(mp); + write.write(b"test_data"); + write.finish().await.unwrap(); + assert_eq!(instrumented.requests.lock().len(), 1); + + let request = instrumented.take_requests().pop().unwrap(); + assert_eq!(request.op, Operation::Put); + assert_eq!(request.path, path); + assert!(request.duration.is_some()); + assert!(request.size.is_none()); + assert!(request.range.is_none()); + assert!(request.extra_display.is_none()); + } + + #[tokio::test] + async fn instrumented_store_copy() { + let (instrumented, path) = setup_test_store().await; + let copy_to = Path::from("test/copied"); + + // By default no requests should be instrumented/stored + assert!(instrumented.requests.lock().is_empty()); + instrumented.copy(&path, ©_to).await.unwrap(); + assert!(instrumented.requests.lock().is_empty()); + + instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace); + assert!(instrumented.requests.lock().is_empty()); + instrumented.copy(&path, ©_to).await.unwrap(); + assert_eq!(instrumented.requests.lock().len(), 1); + + let mut requests = instrumented.take_requests(); + assert_eq!(requests.len(), 1); + assert!(instrumented.requests.lock().is_empty()); + + let request = requests.pop().unwrap(); + assert_eq!(request.op, Operation::Copy); + assert_eq!(request.path, path); + assert!(request.duration.is_some()); + assert!(request.size.is_none()); + assert!(request.range.is_none()); + assert_eq!( + request.extra_display.unwrap(), + format!("copy_to: {copy_to}") + ); + } + + #[tokio::test] + async fn instrumented_store_copy_if_not_exists() { + let (instrumented, path) = setup_test_store().await; + let mut copy_to = Path::from("test/copied"); + + // By default no requests should be instrumented/stored + assert!(instrumented.requests.lock().is_empty()); + instrumented + .copy_if_not_exists(&path, ©_to) + .await + .unwrap(); + assert!(instrumented.requests.lock().is_empty()); + + // Use a new destination since the previous one already exists + copy_to = Path::from("test/copied_again"); + instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace); + assert!(instrumented.requests.lock().is_empty()); + instrumented + .copy_if_not_exists(&path, ©_to) + .await + .unwrap(); + assert_eq!(instrumented.requests.lock().len(), 1); + + let mut requests = instrumented.take_requests(); + assert_eq!(requests.len(), 1); + assert!(instrumented.requests.lock().is_empty()); + + let request = requests.pop().unwrap(); + assert_eq!(request.op, Operation::Copy); + assert_eq!(request.path, path); + assert!(request.duration.is_some()); + assert!(request.size.is_none()); + assert!(request.range.is_none()); + assert_eq!( + request.extra_display.unwrap(), + format!("copy_to: {copy_to}") + ); + } + + #[tokio::test] + async fn instrumented_store_head() { + let (instrumented, path) = setup_test_store().await; + + // By default no requests should be instrumented/stored + assert!(instrumented.requests.lock().is_empty()); + let _ = instrumented.head(&path).await.unwrap(); + assert!(instrumented.requests.lock().is_empty()); + + instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace); + assert!(instrumented.requests.lock().is_empty()); + let _ = instrumented.head(&path).await.unwrap(); + assert_eq!(instrumented.requests.lock().len(), 1); + + let mut requests = instrumented.take_requests(); + assert_eq!(requests.len(), 1); + assert!(instrumented.requests.lock().is_empty()); + + let request = requests.pop().unwrap(); + assert_eq!(request.op, 
Operation::Head); + assert_eq!(request.path, path); + assert!(request.duration.is_some()); + assert!(request.size.is_none()); + assert!(request.range.is_none()); + assert!(request.extra_display.is_none()); + } + #[test] fn request_details() { let rd = RequestDetails { @@ -524,8 +1116,12 @@ mod tests { fn request_summary() { // Test empty request list let mut requests = Vec::new(); - let summaries = RequestSummary::summarize_by_operation(&requests); - assert!(summaries.is_empty()); + assert_snapshot!(RequestSummaries::new(&requests), @r" + +-----------+--------+-----+-----+-----+-----+-------+ + | Operation | Metric | min | max | avg | sum | count | + +-----------+--------+-----+-----+-----+-----+-------+ + +-----------+--------+-----+-----+-----+-----+-------+ + "); requests.push(RequestDetails { op: Operation::Get, @@ -537,26 +1133,14 @@ mod tests { extra_display: None, }); - let summaries = RequestSummary::summarize_by_operation(&requests); - assert_eq!(summaries.len(), 1); - - let summary = summaries.get(&Operation::Get).unwrap(); - assert_eq!(summary.count, 1); - assert_eq!( - summary.duration_stats.as_ref().unwrap().min, - Duration::from_secs(5) - ); - assert_eq!( - summary.duration_stats.as_ref().unwrap().max, - Duration::from_secs(5) - ); - assert_eq!( - summary.duration_stats.as_ref().unwrap().sum, - Duration::from_secs(5) - ); - assert_eq!(summary.size_stats.as_ref().unwrap().min, 100); - assert_eq!(summary.size_stats.as_ref().unwrap().max, 100); - assert_eq!(summary.size_stats.as_ref().unwrap().sum, 100); + assert_snapshot!(RequestSummaries::new(&requests), @r" + +-----------+----------+-----------+-----------+-----------+-----------+-------+ + | Operation | Metric | min | max | avg | sum | count | + +-----------+----------+-----------+-----------+-----------+-----------+-------+ + | Get | duration | 5.000000s | 5.000000s | 5.000000s | 5.000000s | 1 | + | Get | size | 100 B | 100 B | 100 B | 100 B | 1 | + +-----------+----------+-----------+-----------+-----------+-----------+-------+ + "); // Add more Get requests to test aggregation requests.push(RequestDetails { @@ -577,31 +1161,18 @@ mod tests { range: None, extra_display: None, }); - - let summaries = RequestSummary::summarize_by_operation(&requests); - assert_eq!(summaries.len(), 1); - - let summary = summaries.get(&Operation::Get).unwrap(); - assert_eq!(summary.count, 3); - assert_eq!( - summary.duration_stats.as_ref().unwrap().min, - Duration::from_secs(2) - ); - assert_eq!( - summary.duration_stats.as_ref().unwrap().max, - Duration::from_secs(8) - ); - assert_eq!( - summary.duration_stats.as_ref().unwrap().sum, - Duration::from_secs(15) - ); - assert_eq!(summary.size_stats.as_ref().unwrap().min, 50); - assert_eq!(summary.size_stats.as_ref().unwrap().max, 150); - assert_eq!(summary.size_stats.as_ref().unwrap().sum, 300); + assert_snapshot!(RequestSummaries::new(&requests), @r" + +-----------+----------+-----------+-----------+-----------+------------+-------+ + | Operation | Metric | min | max | avg | sum | count | + +-----------+----------+-----------+-----------+-----------+------------+-------+ + | Get | duration | 2.000000s | 8.000000s | 5.000000s | 15.000000s | 3 | + | Get | size | 50 B | 150 B | 100 B | 300 B | 3 | + +-----------+----------+-----------+-----------+-----------+------------+-------+ + "); // Add Put requests to test grouping requests.push(RequestDetails { - op: Operation::_Put, + op: Operation::Put, path: Path::from("test4"), timestamp: chrono::DateTime::from_timestamp(3, 0).unwrap(), duration: 
Some(Duration::from_millis(200)), @@ -610,20 +1181,20 @@ mod tests { extra_display: None, }); - let summaries = RequestSummary::summarize_by_operation(&requests); - assert_eq!(summaries.len(), 2); - - let get_summary = summaries.get(&Operation::Get).unwrap(); - assert_eq!(get_summary.count, 3); - - let put_summary = summaries.get(&Operation::_Put).unwrap(); - assert_eq!(put_summary.count, 1); - assert_eq!( - put_summary.duration_stats.as_ref().unwrap().min, - Duration::from_millis(200) - ); - assert_eq!(put_summary.size_stats.as_ref().unwrap().sum, 75); + assert_snapshot!(RequestSummaries::new(&requests), @r" + +-----------+----------+-----------+-----------+-----------+------------+-------+ + | Operation | Metric | min | max | avg | sum | count | + +-----------+----------+-----------+-----------+-----------+------------+-------+ + | Get | duration | 2.000000s | 8.000000s | 5.000000s | 15.000000s | 3 | + | Get | size | 50 B | 150 B | 100 B | 300 B | 3 | + | Put | duration | 0.200000s | 0.200000s | 0.200000s | 0.200000s | 1 | + | Put | size | 75 B | 75 B | 75 B | 75 B | 1 | + +-----------+----------+-----------+-----------+-----------+------------+-------+ + "); + } + #[test] + fn request_summary_only_duration() { // Test request with only duration (no size) let only_duration = vec![RequestDetails { op: Operation::Get, @@ -634,12 +1205,18 @@ mod tests { range: None, extra_display: None, }]; - let summaries = RequestSummary::summarize_by_operation(&only_duration); - let summary = summaries.get(&Operation::Get).unwrap(); - assert_eq!(summary.count, 1); - assert!(summary.duration_stats.is_some()); - assert!(summary.size_stats.is_none()); + assert_snapshot!(RequestSummaries::new(&only_duration), @r" + +-----------+----------+-----------+-----------+-----------+-----------+-------+ + | Operation | Metric | min | max | avg | sum | count | + +-----------+----------+-----------+-----------+-----------+-----------+-------+ + | Get | duration | 3.000000s | 3.000000s | 3.000000s | 3.000000s | 1 | + | Get | size | | | | | 1 | + +-----------+----------+-----------+-----------+-----------+-----------+-------+ + "); + } + #[test] + fn request_summary_only_size() { // Test request with only size (no duration) let only_size = vec![RequestDetails { op: Operation::Get, @@ -650,13 +1227,18 @@ mod tests { range: None, extra_display: None, }]; - let summaries = RequestSummary::summarize_by_operation(&only_size); - let summary = summaries.get(&Operation::Get).unwrap(); - assert_eq!(summary.count, 1); - assert!(summary.duration_stats.is_none()); - assert!(summary.size_stats.is_some()); - assert_eq!(summary.size_stats.as_ref().unwrap().sum, 200); + assert_snapshot!(RequestSummaries::new(&only_size), @r" + +-----------+----------+-------+-------+-------+-------+-------+ + | Operation | Metric | min | max | avg | sum | count | + +-----------+----------+-------+-------+-------+-------+-------+ + | Get | duration | | | | | 1 | + | Get | size | 200 B | 200 B | 200 B | 200 B | 1 | + +-----------+----------+-------+-------+-------+-------+-------+ + "); + } + #[test] + fn request_summary_neither_duration_or_size() { // Test request with neither duration nor size let no_stats = vec![RequestDetails { op: Operation::Get, @@ -667,10 +1249,13 @@ mod tests { range: None, extra_display: None, }]; - let summaries = RequestSummary::summarize_by_operation(&no_stats); - let summary = summaries.get(&Operation::Get).unwrap(); - assert_eq!(summary.count, 1); - assert!(summary.duration_stats.is_none()); - 
assert!(summary.size_stats.is_none()); + assert_snapshot!(RequestSummaries::new(&no_stats), @r" + +-----------+----------+-----+-----+-----+-----+-------+ + | Operation | Metric | min | max | avg | sum | count | + +-----------+----------+-----+-----+-----+-----+-------+ + | Get | duration | | | | | 1 | + | Get | size | | | | | 1 | + +-----------+----------+-----+-----+-----+-----+-------+ + "); } } diff --git a/datafusion-cli/src/print_options.rs b/datafusion-cli/src/print_options.rs index f54de189b4ef..93d1d450fd82 100644 --- a/datafusion-cli/src/print_options.rs +++ b/datafusion-cli/src/print_options.rs @@ -22,7 +22,7 @@ use std::str::FromStr; use std::sync::Arc; use crate::object_storage::instrumented::{ - InstrumentedObjectStoreMode, InstrumentedObjectStoreRegistry, RequestSummary, + InstrumentedObjectStoreMode, InstrumentedObjectStoreRegistry, RequestSummaries, }; use crate::print_format::PrintFormat; @@ -188,27 +188,25 @@ impl PrintOptions { if !self.quiet { writeln!(writer, "{formatted_exec_details}")?; - if self.instrumented_registry.instrument_mode() - != InstrumentedObjectStoreMode::Disabled - { + let instrument_mode = self.instrumented_registry.instrument_mode(); + if instrument_mode != InstrumentedObjectStoreMode::Disabled { writeln!(writer, "{OBJECT_STORE_PROFILING_HEADER}")?; for store in self.instrumented_registry.stores() { let requests = store.take_requests(); if !requests.is_empty() { writeln!(writer, "{store}")?; - for req in requests.iter() { - writeln!(writer, "{req}")?; + if instrument_mode == InstrumentedObjectStoreMode::Trace { + for req in requests.iter() { + writeln!(writer, "{req}")?; + } + // Add an extra blank line to help visually organize the output + writeln!(writer)?; } - // Add an extra blank line to help visually organize the output - writeln!(writer)?; writeln!(writer, "Summaries:")?; - let summaries = RequestSummary::summarize_by_operation(&requests); - for (op, summary) in summaries { - writeln!(writer, "{op:?}")?; - writeln!(writer, "{summary}")?; - } + let summaries = RequestSummaries::new(&requests); + writeln!(writer, "{summaries}")?; } } } @@ -252,7 +250,7 @@ mod tests { print_output.clear(); print_options .instrumented_registry - .set_instrument_mode(InstrumentedObjectStoreMode::Enabled); + .set_instrument_mode(InstrumentedObjectStoreMode::Trace); print_options.write_output(&mut print_output, exec_out.clone())?; let out_str: String = print_output .clone() diff --git a/datafusion-cli/tests/cli_integration.rs b/datafusion-cli/tests/cli_integration.rs index a67924fef253..c1395aa4f562 100644 --- a/datafusion-cli/tests/cli_integration.rs +++ b/datafusion-cli/tests/cli_integration.rs @@ -402,7 +402,6 @@ async fn test_object_store_profiling() { let container = setup_minio_container().await; let mut settings = make_settings(); - settings.set_snapshot_suffix("s3_url_fallback"); // as the object store profiling contains timestamps and durations, we must // filter them out to have stable snapshots @@ -412,18 +411,17 @@ async fn test_object_store_profiling() { // Output: // operation=Get duration=[DURATION] size=1006 path=cars.csv settings.add_filter( - r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?[+-]\d{2}:\d{2} operation=(Get|Put|Delete|List|Head) duration=\d+\.\d{6}s size=(\d+) path=(.*)", - " operation=$1 duration=[DURATION] size=$2 path=$3", + r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?[+-]\d{2}:\d{2} operation=(Get|Put|Delete|List|Head) duration=\d+\.\d{6}s (size=\d+\s+)?path=(.*)", + " operation=$1 duration=[DURATION] ${2}path=$3", ); - // We 
also need to filter out the durations reported in the summary output - // + // We also need to filter out the summary statistics (anything with an 's' at the end) // Example line(s) to filter: - // - // duration min: 0.000729s - // duration max: 0.000729s - // duration avg: 0.000729s - settings.add_filter(r"duration (min|max|avg): \d+\.\d{6}s", "[SUMMARY_DURATION]"); + // | Get | duration | 5.000000s | 5.000000s | 5.000000s | | 1 | + settings.add_filter( + r"\| (Get|Put|Delete|List|Head)( +)\| duration \| .*? \| .*? \| .*? \| .*? \| (.*?) \|", + "| $1$2 | duration | ...NORMALIZED...| $3 |", + ); let _bound = settings.bind_to_scope(); @@ -434,8 +432,11 @@ LOCATION 's3://data/cars.csv'; -- Initial query should not show any profiling as the object store is not instrumented yet SELECT * from CARS LIMIT 1; -\object_store_profiling enabled --- Query again to see the profiling output +\object_store_profiling trace +-- Query again to see the full profiling output +SELECT * from CARS LIMIT 1; +\object_store_profiling summary +-- Query again to see the summarized profiling output SELECT * from CARS LIMIT 1; \object_store_profiling disabled -- Final query should not show any profiling as we disabled it again diff --git a/datafusion-cli/tests/snapshots/object_store_profiling.snap b/datafusion-cli/tests/snapshots/object_store_profiling.snap new file mode 100644 index 000000000000..029b07c324f5 --- /dev/null +++ b/datafusion-cli/tests/snapshots/object_store_profiling.snap @@ -0,0 +1,83 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: [] + env: + AWS_ACCESS_KEY_ID: TEST-DataFusionLogin + AWS_ALLOW_HTTP: "true" + AWS_ENDPOINT: "http://localhost:55057" + AWS_SECRET_ACCESS_KEY: TEST-DataFusionPassword + stdin: "\n CREATE EXTERNAL TABLE CARS\nSTORED AS CSV\nLOCATION 's3://data/cars.csv';\n\n-- Initial query should not show any profiling as the object store is not instrumented yet\nSELECT * from CARS LIMIT 1;\n\\object_store_profiling trace\n-- Query again to see the full profiling output\nSELECT * from CARS LIMIT 1;\n\\object_store_profiling summary\n-- Query again to see the summarized profiling output\nSELECT * from CARS LIMIT 1;\n\\object_store_profiling disabled\n-- Final query should not show any profiling as we disabled it again\nSELECT * from CARS LIMIT 1;\n" +snapshot_kind: text +--- +success: true +exit_code: 0 +----- stdout ----- +[CLI_VERSION] +0 row(s) fetched. +[ELAPSED] + ++-----+-------+---------------------+ +| car | speed | time | ++-----+-------+---------------------+ +| red | 20.0 | 1996-04-12T12:05:03 | ++-----+-------+---------------------+ +1 row(s) fetched. +[ELAPSED] + +ObjectStore Profile mode set to Trace ++-----+-------+---------------------+ +| car | speed | time | ++-----+-------+---------------------+ +| red | 20.0 | 1996-04-12T12:05:03 | ++-----+-------+---------------------+ +1 row(s) fetched. 
+[ELAPSED] + +Object Store Profiling +Instrumented Object Store: instrument_mode: Trace, inner: AmazonS3(data) + operation=Head duration=[DURATION] path=cars.csv + operation=Get duration=[DURATION] size=1006 path=cars.csv + +Summaries: ++-----------+----------+-----------+-----------+-----------+-----------+-------+ +| Operation | Metric | min | max | avg | sum | count | ++-----------+----------+-----------+-----------+-----------+-----------+-------+ +| Get | duration | ...NORMALIZED...| 1 | +| Get | size | 1006 B | 1006 B | 1006 B | 1006 B | 1 | +| Head | duration | ...NORMALIZED...| 1 | +| Head | size | | | | | 1 | ++-----------+----------+-----------+-----------+-----------+-----------+-------+ +ObjectStore Profile mode set to Summary ++-----+-------+---------------------+ +| car | speed | time | ++-----+-------+---------------------+ +| red | 20.0 | 1996-04-12T12:05:03 | ++-----+-------+---------------------+ +1 row(s) fetched. +[ELAPSED] + +Object Store Profiling +Instrumented Object Store: instrument_mode: Summary, inner: AmazonS3(data) +Summaries: ++-----------+----------+-----------+-----------+-----------+-----------+-------+ +| Operation | Metric | min | max | avg | sum | count | ++-----------+----------+-----------+-----------+-----------+-----------+-------+ +| Get | duration | ...NORMALIZED...| 1 | +| Get | size | 1006 B | 1006 B | 1006 B | 1006 B | 1 | +| Head | duration | ...NORMALIZED...| 1 | +| Head | size | | | | | 1 | ++-----------+----------+-----------+-----------+-----------+-----------+-------+ +ObjectStore Profile mode set to Disabled ++-----+-------+---------------------+ +| car | speed | time | ++-----+-------+---------------------+ +| red | 20.0 | 1996-04-12T12:05:03 | ++-----+-------+---------------------+ +1 row(s) fetched. +[ELAPSED] + +\q + +----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/object_store_profiling@s3_url_fallback.snap b/datafusion-cli/tests/snapshots/object_store_profiling@s3_url_fallback.snap deleted file mode 100644 index 50c6cc8eab99..000000000000 --- a/datafusion-cli/tests/snapshots/object_store_profiling@s3_url_fallback.snap +++ /dev/null @@ -1,64 +0,0 @@ ---- -source: datafusion-cli/tests/cli_integration.rs -info: - program: datafusion-cli - args: [] - env: - AWS_ACCESS_KEY_ID: TEST-DataFusionLogin - AWS_ALLOW_HTTP: "true" - AWS_ENDPOINT: "http://localhost:55031" - AWS_SECRET_ACCESS_KEY: TEST-DataFusionPassword - stdin: "\n CREATE EXTERNAL TABLE CARS\nSTORED AS CSV\nLOCATION 's3://data/cars.csv';\n\n-- Initial query should not show any profiling as the object store is not instrumented yet\nSELECT * from CARS LIMIT 1;\n\\object_store_profiling enabled\n-- Query again to see the profiling output\nSELECT * from CARS LIMIT 1;\n\\object_store_profiling disabled\n-- Final query should not show any profiling as we disabled it again\nSELECT * from CARS LIMIT 1;\n" -snapshot_kind: text ---- -success: true -exit_code: 0 ------ stdout ----- -[CLI_VERSION] -0 row(s) fetched. -[ELAPSED] - -+-----+-------+---------------------+ -| car | speed | time | -+-----+-------+---------------------+ -| red | 20.0 | 1996-04-12T12:05:03 | -+-----+-------+---------------------+ -1 row(s) fetched. -[ELAPSED] - -ObjectStore Profile mode set to Enabled -+-----+-------+---------------------+ -| car | speed | time | -+-----+-------+---------------------+ -| red | 20.0 | 1996-04-12T12:05:03 | -+-----+-------+---------------------+ -1 row(s) fetched. 
-[ELAPSED] - -Object Store Profiling -Instrumented Object Store: instrument_mode: Enabled, inner: AmazonS3(data) - operation=Get duration=[DURATION] size=1006 path=cars.csv - -Summaries: -Get -count: 1 -[SUMMARY_DURATION] -[SUMMARY_DURATION] -[SUMMARY_DURATION] -size min: 1006 B -size max: 1006 B -size avg: 1006 B -size sum: 1006 B - -ObjectStore Profile mode set to Disabled -+-----+-------+---------------------+ -| car | speed | time | -+-----+-------+---------------------+ -| red | 20.0 | 1996-04-12T12:05:03 | -+-----+-------+---------------------+ -1 row(s) fetched. -[ELAPSED] - -\q - ------ stderr ----- diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index 68bb5376a1ac..bb0525e57753 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -81,7 +81,7 @@ serde_json = { workspace = true } tempfile = { workspace = true } test-utils = { path = "../test-utils" } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } -tonic = "0.13.1" +tonic = "0.14" tracing = { version = "0.1" } tracing-subscriber = { version = "0.3" } url = { workspace = true } diff --git a/datafusion-examples/examples/custom_file_casts.rs b/datafusion-examples/examples/custom_file_casts.rs index 65ca09682064..4d97ecd91dc6 100644 --- a/datafusion-examples/examples/custom_file_casts.rs +++ b/datafusion-examples/examples/custom_file_casts.rs @@ -25,7 +25,7 @@ use datafusion::common::not_impl_err; use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion::common::{Result, ScalarValue}; use datafusion::datasource::listing::{ - ListingTable, ListingTableConfig, ListingTableUrl, + ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, }; use datafusion::execution::context::SessionContext; use datafusion::execution::object_store::ObjectStoreUrl; diff --git a/datafusion-examples/examples/flight/flight_client.rs b/datafusion-examples/examples/flight/flight_client.rs index e3237284b430..ff4b5903ad88 100644 --- a/datafusion-examples/examples/flight/flight_client.rs +++ b/datafusion-examples/examples/flight/flight_client.rs @@ -17,6 +17,7 @@ use std::collections::HashMap; use std::sync::Arc; +use tonic::transport::Endpoint; use datafusion::arrow::datatypes::Schema; @@ -34,7 +35,9 @@ async fn main() -> Result<(), Box> { let testdata = datafusion::test_util::parquet_test_data(); // Create Flight client - let mut client = FlightServiceClient::connect("http://localhost:50051").await?; + let endpoint = Endpoint::new("http://localhost:50051")?; + let channel = endpoint.connect().await?; + let mut client = FlightServiceClient::new(channel); // Call get_schema to get the schema of a Parquet file let request = tonic::Request::new(FlightDescriptor { diff --git a/datafusion-examples/examples/flight/flight_server.rs b/datafusion-examples/examples/flight/flight_server.rs index 58bfb7a341c1..22265e415fbd 100644 --- a/datafusion-examples/examples/flight/flight_server.rs +++ b/datafusion-examples/examples/flight/flight_server.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
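The Flight client change above now builds the transport channel explicitly before handing it to the generated client, which also leaves room for transport tuning. A minimal sketch under the assumption of tonic 0.14; the helper function and the timeout value are illustrative, not part of the example:

```rust
use std::time::Duration;

use arrow_flight::flight_service_client::FlightServiceClient;
use tonic::transport::{Channel, Endpoint};

async fn connect_client(
    url: &'static str,
) -> Result<FlightServiceClient<Channel>, Box<dyn std::error::Error>> {
    // Build and configure the channel first, then wrap it in the client.
    let channel = Endpoint::new(url)?
        .connect_timeout(Duration::from_secs(5)) // illustrative tuning knob
        .connect()
        .await?;
    Ok(FlightServiceClient::new(channel))
}
```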
-use arrow::ipc::writer::{DictionaryTracker, IpcDataGenerator}; +use arrow::ipc::writer::{CompressionContext, DictionaryTracker, IpcDataGenerator}; use std::sync::Arc; use arrow_flight::{PollInfo, SchemaAsIpc}; @@ -106,6 +106,7 @@ impl FlightService for FlightServiceImpl { // add an initial FlightData message that sends schema let options = arrow::ipc::writer::IpcWriteOptions::default(); + let mut compression_context = CompressionContext::default(); let schema_flight_data = SchemaAsIpc::new(&schema, &options); let mut flights = vec![FlightData::from(schema_flight_data)]; @@ -115,7 +116,7 @@ impl FlightService for FlightServiceImpl { for batch in &results { let (flight_dictionaries, flight_batch) = encoder - .encoded_batch(batch, &mut tracker, &options) + .encode(batch, &mut tracker, &options, &mut compression_context) .map_err(|e: ArrowError| Status::internal(e.to_string()))?; flights.extend(flight_dictionaries.into_iter().map(Into::into)); diff --git a/datafusion-examples/examples/json_shredding.rs b/datafusion-examples/examples/json_shredding.rs index c7d0146a001f..a2e83bc9510a 100644 --- a/datafusion-examples/examples/json_shredding.rs +++ b/datafusion-examples/examples/json_shredding.rs @@ -27,7 +27,7 @@ use datafusion::common::tree_node::{ }; use datafusion::common::{assert_contains, exec_datafusion_err, Result}; use datafusion::datasource::listing::{ - ListingTable, ListingTableConfig, ListingTableUrl, + ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, }; use datafusion::execution::context::SessionContext; use datafusion::execution::object_store::ObjectStoreUrl; diff --git a/datafusion-examples/examples/parquet_encrypted.rs b/datafusion-examples/examples/parquet_encrypted.rs index e9e239b7a1c3..690d9f2a5f14 100644 --- a/datafusion-examples/examples/parquet_encrypted.rs +++ b/datafusion-examples/examples/parquet_encrypted.rs @@ -16,12 +16,13 @@ // under the License. 
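For reference, the arrow 57 IPC encoding loop from the server change above, pulled out into a self-contained sketch (the free-standing helper and its signature are illustrative):

```rust
use arrow::error::ArrowError;
use arrow::ipc::writer::{
    CompressionContext, DictionaryTracker, IpcDataGenerator, IpcWriteOptions,
};
use arrow::record_batch::RecordBatch;
use arrow_flight::FlightData;

fn encode_batches(batches: &[RecordBatch]) -> Result<Vec<FlightData>, ArrowError> {
    let encoder = IpcDataGenerator::default();
    let options = IpcWriteOptions::default();
    let mut tracker = DictionaryTracker::new(false);
    // Compression state is now threaded through explicitly so it can be
    // reused across batches, instead of being hidden inside `encoded_batch`.
    let mut compression_context = CompressionContext::default();

    let mut flights = Vec::new();
    for batch in batches {
        let (dictionaries, encoded_batch) =
            encoder.encode(batch, &mut tracker, &options, &mut compression_context)?;
        flights.extend(dictionaries.into_iter().map(Into::into));
        flights.push(encoded_batch.into());
    }
    Ok(flights)
}
```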
use datafusion::common::DataFusionError; -use datafusion::config::TableParquetOptions; +use datafusion::config::{ConfigFileEncryptionProperties, TableParquetOptions}; use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; use datafusion::logical_expr::{col, lit}; use datafusion::parquet::encryption::decrypt::FileDecryptionProperties; use datafusion::parquet::encryption::encrypt::FileEncryptionProperties; use datafusion::prelude::{ParquetReadOptions, SessionContext}; +use std::sync::Arc; use tempfile::TempDir; #[tokio::main] @@ -55,7 +56,7 @@ async fn main() -> datafusion::common::Result<()> { // Write encrypted parquet let mut options = TableParquetOptions::default(); - options.crypto.file_encryption = Some((&encrypt).into()); + options.crypto.file_encryption = Some(ConfigFileEncryptionProperties::from(&encrypt)); parquet_df .write_parquet( tempfile_str.as_str(), @@ -100,7 +101,8 @@ async fn query_dataframe(df: &DataFrame) -> Result<(), DataFusionError> { // Setup encryption and decryption properties fn setup_encryption( parquet_df: &DataFrame, -) -> Result<(FileEncryptionProperties, FileDecryptionProperties), DataFusionError> { +) -> Result<(Arc, Arc), DataFusionError> +{ let schema = parquet_df.schema(); let footer_key = b"0123456789012345".to_vec(); // 128bit/16 let column_key = b"1234567890123450".to_vec(); // 128bit/16 diff --git a/datafusion-examples/examples/parquet_encrypted_with_kms.rs b/datafusion-examples/examples/parquet_encrypted_with_kms.rs index 19b0e8d0b199..45bfd183773a 100644 --- a/datafusion-examples/examples/parquet_encrypted_with_kms.rs +++ b/datafusion-examples/examples/parquet_encrypted_with_kms.rs @@ -226,7 +226,7 @@ impl EncryptionFactory for TestEncryptionFactory { options: &EncryptionFactoryOptions, schema: &SchemaRef, _file_path: &Path, - ) -> Result> { + ) -> Result>> { let config: EncryptionConfig = options.to_extension_options()?; // Generate a random encryption key for this file. @@ -268,7 +268,7 @@ impl EncryptionFactory for TestEncryptionFactory { &self, _options: &EncryptionFactoryOptions, _file_path: &Path, - ) -> Result> { + ) -> Result>> { let decryption_properties = FileDecryptionProperties::with_key_retriever(Arc::new(TestKeyRetriever {})) .build()?; diff --git a/datafusion-examples/examples/remote_catalog.rs b/datafusion-examples/examples/remote_catalog.rs index 70c0963545e0..74575554ec0a 100644 --- a/datafusion-examples/examples/remote_catalog.rs +++ b/datafusion-examples/examples/remote_catalog.rs @@ -75,8 +75,8 @@ async fn main() -> Result<()> { let state = ctx.state(); // First, parse the SQL (but don't plan it / resolve any table references) - let dialect = state.config().options().sql_parser.dialect.as_str(); - let statement = state.sql_to_statement(sql, dialect)?; + let dialect = state.config().options().sql_parser.dialect; + let statement = state.sql_to_statement(sql, &dialect)?; // Find all `TableReferences` in the parsed queries. 
These correspond to the // tables referred to by the query (in this case diff --git a/datafusion-testing b/datafusion-testing index 905df5f65cc9..eccb0e4a4263 160000 --- a/datafusion-testing +++ b/datafusion-testing @@ -1 +1 @@ -Subproject commit 905df5f65cc9d0851719c21f5a4dd5cd77621f19 +Subproject commit eccb0e4a426344ef3faf534cd60e02e9c3afd3ac diff --git a/datafusion/catalog-listing/Cargo.toml b/datafusion/catalog-listing/Cargo.toml index 69f952ae9840..4eaeed675a20 100644 --- a/datafusion/catalog-listing/Cargo.toml +++ b/datafusion/catalog-listing/Cargo.toml @@ -39,14 +39,17 @@ datafusion-datasource = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true } +datafusion-physical-expr-adapter = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } futures = { workspace = true } +itertools = { workspace = true } log = { workspace = true } object_store = { workspace = true } tokio = { workspace = true } [dev-dependencies] +datafusion-datasource-parquet = { workspace = true } [lints] workspace = true @@ -54,3 +57,6 @@ workspace = true [lib] name = "datafusion_catalog_listing" path = "src/mod.rs" + +[package.metadata.cargo-machete] +ignored = ["datafusion-datasource-parquet"] diff --git a/datafusion/catalog-listing/src/config.rs b/datafusion/catalog-listing/src/config.rs new file mode 100644 index 000000000000..90f44de4fdbc --- /dev/null +++ b/datafusion/catalog-listing/src/config.rs @@ -0,0 +1,360 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::options::ListingOptions; +use arrow::datatypes::{DataType, Schema, SchemaRef}; +use datafusion_catalog::Session; +use datafusion_common::{config_err, internal_err}; +use datafusion_datasource::file_compression_type::FileCompressionType; +use datafusion_datasource::schema_adapter::SchemaAdapterFactory; +use datafusion_datasource::ListingTableUrl; +use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory; +use std::str::FromStr; +use std::sync::Arc; + +/// Indicates the source of the schema for a [`crate::ListingTable`] +// PartialEq required for assert_eq! in tests +#[derive(Debug, Clone, Copy, PartialEq, Default)] +pub enum SchemaSource { + /// Schema is not yet set (initial state) + #[default] + Unset, + /// Schema was inferred from first table_path + Inferred, + /// Schema was specified explicitly via with_schema + Specified, +} + +/// Configuration for creating a [`crate::ListingTable`] +/// +/// # Schema Evolution Support +/// +/// This configuration supports schema evolution through the optional +/// [`SchemaAdapterFactory`]. 
You might want to override the default factory when you need: +/// +/// - **Type coercion requirements**: When you need custom logic for converting between +/// different Arrow data types (e.g., Int32 ↔ Int64, Utf8 ↔ LargeUtf8) +/// - **Column mapping**: You need to map columns with a legacy name to a new name +/// - **Custom handling of missing columns**: By default they are filled in with nulls, but you may e.g. want to fill them in with `0` or `""`. +/// +/// If not specified, a [`datafusion_datasource::schema_adapter::DefaultSchemaAdapterFactory`] +/// will be used, which handles basic schema compatibility cases. +/// +#[derive(Debug, Clone, Default)] +pub struct ListingTableConfig { + /// Paths on the `ObjectStore` for creating [`crate::ListingTable`]. + /// They should share the same schema and object store. + pub table_paths: Vec, + /// Optional `SchemaRef` for the to be created [`crate::ListingTable`]. + /// + /// See details on [`ListingTableConfig::with_schema`] + pub file_schema: Option, + /// Optional [`ListingOptions`] for the to be created [`crate::ListingTable`]. + /// + /// See details on [`ListingTableConfig::with_listing_options`] + pub options: Option, + /// Tracks the source of the schema information + pub(crate) schema_source: SchemaSource, + /// Optional [`SchemaAdapterFactory`] for creating schema adapters + pub(crate) schema_adapter_factory: Option>, + /// Optional [`PhysicalExprAdapterFactory`] for creating physical expression adapters + pub(crate) expr_adapter_factory: Option>, +} + +impl ListingTableConfig { + /// Creates new [`ListingTableConfig`] for reading the specified URL + pub fn new(table_path: ListingTableUrl) -> Self { + Self { + table_paths: vec![table_path], + ..Default::default() + } + } + + /// Creates new [`ListingTableConfig`] with multiple table paths. + /// + /// See `ListingTableConfigExt::infer_options` for details on what happens with multiple paths + pub fn new_with_multi_paths(table_paths: Vec) -> Self { + Self { + table_paths, + ..Default::default() + } + } + + /// Returns the source of the schema for this configuration + pub fn schema_source(&self) -> SchemaSource { + self.schema_source + } + /// Set the `schema` for the overall [`crate::ListingTable`] + /// + /// [`crate::ListingTable`] will automatically coerce, when possible, the schema + /// for individual files to match this schema. + /// + /// If a schema is not provided, it is inferred using + /// [`Self::infer_schema`]. + /// + /// If the schema is provided, it must contain only the fields in the file + /// without the table partitioning columns. 
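Unlike `new`, the multi-path constructor added above carries no inline doc example; a minimal hedged sketch (the paths are made up for illustration, and both prefixes must share a schema and an object store):

```rust
use datafusion_catalog_listing::ListingTableConfig;
use datafusion_datasource::ListingTableUrl;

let paths = vec![
    ListingTableUrl::parse("file:///data/events/2024/").unwrap(),
    ListingTableUrl::parse("file:///data/events/2025/").unwrap(),
];
let config = ListingTableConfig::new_with_multi_paths(paths);
```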
+ /// + /// # Example: Specifying Table Schema + /// ```rust + /// # use std::sync::Arc; + /// # use datafusion_catalog_listing::{ListingTableConfig, ListingOptions}; + /// # use datafusion_datasource::ListingTableUrl; + /// # use datafusion_datasource_parquet::file_format::ParquetFormat; + /// # use arrow::datatypes::{Schema, Field, DataType}; + /// # let table_paths = ListingTableUrl::parse("file:///path/to/data").unwrap(); + /// # let listing_options = ListingOptions::new(Arc::new(ParquetFormat::default())); + /// let schema = Arc::new(Schema::new(vec![ + /// Field::new("id", DataType::Int64, false), + /// Field::new("name", DataType::Utf8, true), + /// ])); + /// + /// let config = ListingTableConfig::new(table_paths) + /// .with_listing_options(listing_options) // Set options first + /// .with_schema(schema); // Then set schema + /// ``` + pub fn with_schema(self, schema: SchemaRef) -> Self { + // Note: We preserve existing options state, but downstream code may expect + // options to be set. Consider calling with_listing_options() or infer_options() + // before operations that require options to be present. + debug_assert!( + self.options.is_some() || cfg!(test), + "ListingTableConfig::with_schema called without options set. \ + Consider calling with_listing_options() or infer_options() first to avoid panics in downstream code." + ); + + Self { + file_schema: Some(schema), + schema_source: SchemaSource::Specified, + ..self + } + } + + /// Add `listing_options` to [`ListingTableConfig`] + /// + /// If not provided, format and other options are inferred via + /// `ListingTableConfigExt::infer_options`. + /// + /// # Example: Configuring Parquet Files with Custom Options + /// ```rust + /// # use std::sync::Arc; + /// # use datafusion_catalog_listing::{ListingTableConfig, ListingOptions}; + /// # use datafusion_datasource::ListingTableUrl; + /// # use datafusion_datasource_parquet::file_format::ParquetFormat; + /// # let table_paths = ListingTableUrl::parse("file:///path/to/data").unwrap(); + /// let options = ListingOptions::new(Arc::new(ParquetFormat::default())) + /// .with_file_extension(".parquet") + /// .with_collect_stat(true); + /// + /// let config = ListingTableConfig::new(table_paths) + /// .with_listing_options(options); // Configure file format and options + /// ``` + pub fn with_listing_options(self, listing_options: ListingOptions) -> Self { + // Note: This method properly sets options, but be aware that downstream + // methods like infer_schema() and try_new() require both schema and options + // to be set to function correctly. + debug_assert!( + !self.table_paths.is_empty() || cfg!(test), + "ListingTableConfig::with_listing_options called without table_paths set. \ + Consider calling new() or new_with_multi_paths() first to establish table paths." 
+ ); + + Self { + options: Some(listing_options), + ..self + } + } + + /// Returns a tuple of `(file_extension, optional compression_extension)` + /// + /// For example a path ending with blah.test.csv.gz returns `("csv", Some("gz"))` + /// For example a path ending with blah.test.csv returns `("csv", None)` + pub fn infer_file_extension_and_compression_type( + path: &str, + ) -> datafusion_common::Result<(String, Option)> { + let mut exts = path.rsplit('.'); + + let split = exts.next().unwrap_or(""); + + let file_compression_type = FileCompressionType::from_str(split) + .unwrap_or(FileCompressionType::UNCOMPRESSED); + + if file_compression_type.is_compressed() { + let split2 = exts.next().unwrap_or(""); + Ok((split2.to_string(), Some(split.to_string()))) + } else { + Ok((split.to_string(), None)) + } + } + + /// Infer the [`SchemaRef`] based on `table_path`s. + /// + /// This method infers the table schema using the first `table_path`. + /// See [`ListingOptions::infer_schema`] for more details + /// + /// # Errors + /// * if `self.options` is not set. See [`Self::with_listing_options`] + pub async fn infer_schema( + self, + state: &dyn Session, + ) -> datafusion_common::Result { + match self.options { + Some(options) => { + let ListingTableConfig { + table_paths, + file_schema, + options: _, + schema_source, + schema_adapter_factory, + expr_adapter_factory: physical_expr_adapter_factory, + } = self; + + let (schema, new_schema_source) = match file_schema { + Some(schema) => (schema, schema_source), // Keep existing source if schema exists + None => { + if let Some(url) = table_paths.first() { + ( + options.infer_schema(state, url).await?, + SchemaSource::Inferred, + ) + } else { + (Arc::new(Schema::empty()), SchemaSource::Inferred) + } + } + }; + + Ok(Self { + table_paths, + file_schema: Some(schema), + options: Some(options), + schema_source: new_schema_source, + schema_adapter_factory, + expr_adapter_factory: physical_expr_adapter_factory, + }) + } + None => internal_err!("No `ListingOptions` set for inferring schema"), + } + } + + /// Infer the partition columns from `table_paths`. + /// + /// # Errors + /// * if `self.options` is not set. See [`Self::with_listing_options`] + pub async fn infer_partitions_from_path( + self, + state: &dyn Session, + ) -> datafusion_common::Result { + match self.options { + Some(options) => { + let Some(url) = self.table_paths.first() else { + return config_err!("No table path found"); + }; + let partitions = options + .infer_partitions(state, url) + .await? + .into_iter() + .map(|col_name| { + ( + col_name, + DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Utf8), + ), + ) + }) + .collect::>(); + let options = options.with_table_partition_cols(partitions); + Ok(Self { + table_paths: self.table_paths, + file_schema: self.file_schema, + options: Some(options), + schema_source: self.schema_source, + schema_adapter_factory: self.schema_adapter_factory, + expr_adapter_factory: self.expr_adapter_factory, + }) + } + None => config_err!("No `ListingOptions` set for inferring schema"), + } + } + + /// Set the [`SchemaAdapterFactory`] for the [`crate::ListingTable`] + /// + /// The schema adapter factory is used to create schema adapters that can + /// handle schema evolution and type conversions when reading files with + /// different schemas than the table schema. + /// + /// If not provided, a default schema adapter factory will be used. 
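The extension helper added earlier in this file is easiest to read through its return values; a quick hedged sketch restating the documented cases:

```rust
use datafusion_catalog_listing::ListingTableConfig;

let (ext, compression) =
    ListingTableConfig::infer_file_extension_and_compression_type("blah.test.csv.gz")
        .unwrap();
assert_eq!(ext, "csv");
assert_eq!(compression, Some("gz".to_string()));

let (ext, compression) =
    ListingTableConfig::infer_file_extension_and_compression_type("blah.test.csv")
        .unwrap();
assert_eq!(ext, "csv");
assert_eq!(compression, None);
```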
+ /// + /// # Example: Custom Schema Adapter for Type Coercion + /// ```rust + /// # use std::sync::Arc; + /// # use datafusion_catalog_listing::{ListingTableConfig, ListingOptions}; + /// # use datafusion_datasource::schema_adapter::{SchemaAdapterFactory, SchemaAdapter}; + /// # use datafusion_datasource::ListingTableUrl; + /// # use datafusion_datasource_parquet::file_format::ParquetFormat; + /// # use arrow::datatypes::{SchemaRef, Schema, Field, DataType}; + /// # + /// # #[derive(Debug)] + /// # struct MySchemaAdapterFactory; + /// # impl SchemaAdapterFactory for MySchemaAdapterFactory { + /// # fn create(&self, _projected_table_schema: SchemaRef, _file_schema: SchemaRef) -> Box { + /// # unimplemented!() + /// # } + /// # } + /// # let table_paths = ListingTableUrl::parse("file:///path/to/data").unwrap(); + /// # let listing_options = ListingOptions::new(Arc::new(ParquetFormat::default())); + /// # let table_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); + /// let config = ListingTableConfig::new(table_paths) + /// .with_listing_options(listing_options) + /// .with_schema(table_schema) + /// .with_schema_adapter_factory(Arc::new(MySchemaAdapterFactory)); + /// ``` + pub fn with_schema_adapter_factory( + self, + schema_adapter_factory: Arc, + ) -> Self { + Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self + } + } + + /// Get the [`SchemaAdapterFactory`] for this configuration + pub fn schema_adapter_factory(&self) -> Option<&Arc> { + self.schema_adapter_factory.as_ref() + } + + /// Set the [`PhysicalExprAdapterFactory`] for the [`crate::ListingTable`] + /// + /// The expression adapter factory is used to create physical expression adapters that can + /// handle schema evolution and type conversions when evaluating expressions + /// with different schemas than the table schema. + /// + /// If not provided, a default physical expression adapter factory will be used unless a custom + /// `SchemaAdapterFactory` is set, in which case only the `SchemaAdapterFactory` will be used. + /// + /// See for details on this transition. 
+ pub fn with_expr_adapter_factory( + self, + expr_adapter_factory: Arc, + ) -> Self { + Self { + expr_adapter_factory: Some(expr_adapter_factory), + ..self + } + } +} diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index 00e9c71df348..82cc36867939 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -156,6 +156,7 @@ pub fn split_files( chunks } +#[derive(Debug)] pub struct Partition { /// The path to the partition, including the table prefix path: Path, @@ -245,7 +246,16 @@ async fn prune_partitions( partition_cols: &[(String, DataType)], ) -> Result> { if filters.is_empty() { - return Ok(partitions); + // prune partitions which don't contain the partition columns + return Ok(partitions + .into_iter() + .filter(|p| { + let cols = partition_cols.iter().map(|x| x.0.as_str()); + !parse_partitions_for_path(table_path, &p.path, cols) + .unwrap_or_default() + .is_empty() + }) + .collect()); } let mut builders: Vec<_> = (0..partition_cols.len()) @@ -432,6 +442,7 @@ pub async fn pruned_partition_list<'a>( } let partition_prefix = evaluate_partition_prefix(partition_cols, filters); + let partitions = list_partitions(store, table_path, partition_cols.len(), partition_prefix) .await?; @@ -502,12 +513,12 @@ where let subpath = table_path.strip_prefix(file_path)?; let mut part_values = vec![]; - for (part, pn) in subpath.zip(table_partition_cols) { + for (part, expected_partition) in subpath.zip(table_partition_cols) { match part.split_once('=') { - Some((name, val)) if name == pn => part_values.push(val), + Some((name, val)) if name == expected_partition => part_values.push(val), _ => { debug!( - "Ignoring file: file_path='{file_path}', table_path='{table_path}', part='{part}', partition_col='{pn}'", + "Ignoring file: file_path='{file_path}', table_path='{table_path}', part='{part}', partition_col='{expected_partition}'", ); return None; } @@ -594,6 +605,8 @@ mod tests { ("tablepath/mypartition=val1/notparquetfile", 100), ("tablepath/mypartition=val1/ignoresemptyfile.parquet", 0), ("tablepath/file.parquet", 100), + ("tablepath/notapartition/file.parquet", 100), + ("tablepath/notmypartition=val1/file.parquet", 100), ]); let filter = Expr::eq(col("mypartition"), lit("val1")); let pruned = pruned_partition_list( @@ -619,6 +632,8 @@ mod tests { ("tablepath/mypartition=val2/file.parquet", 100), ("tablepath/mypartition=val1/ignoresemptyfile.parquet", 0), ("tablepath/mypartition=val1/other=val3/file.parquet", 100), + ("tablepath/notapartition/file.parquet", 100), + ("tablepath/notmypartition=val1/file.parquet", 100), ]); let filter = Expr::eq(col("mypartition"), lit("val1")); let pruned = pruned_partition_list( diff --git a/datafusion/catalog-listing/src/mod.rs b/datafusion/catalog-listing/src/mod.rs index 1322577b207a..90d04b46b806 100644 --- a/datafusion/catalog-listing/src/mod.rs +++ b/datafusion/catalog-listing/src/mod.rs @@ -24,4 +24,11 @@ // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] +mod config; pub mod helpers; +mod options; +mod table; + +pub use config::{ListingTableConfig, SchemaSource}; +pub use options::ListingOptions; +pub use table::ListingTable; diff --git a/datafusion/catalog-listing/src/options.rs b/datafusion/catalog-listing/src/options.rs new file mode 100644 index 000000000000..3cbf3573e951 --- /dev/null +++ b/datafusion/catalog-listing/src/options.rs @@ -0,0 +1,411 @@ +// Licensed to the Apache Software Foundation (ASF) 
under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, SchemaRef}; +use datafusion_catalog::Session; +use datafusion_common::plan_err; +use datafusion_datasource::file_format::FileFormat; +use datafusion_datasource::ListingTableUrl; +use datafusion_execution::config::SessionConfig; +use datafusion_expr::SortExpr; +use futures::StreamExt; +use futures::{future, TryStreamExt}; +use itertools::Itertools; +use std::sync::Arc; + +/// Options for creating a [`crate::ListingTable`] +#[derive(Clone, Debug)] +pub struct ListingOptions { + /// A suffix on which files should be filtered (leave empty to + /// keep all files on the path) + pub file_extension: String, + /// The file format + pub format: Arc, + /// The expected partition column names in the folder structure. + /// See [Self::with_table_partition_cols] for details + pub table_partition_cols: Vec<(String, DataType)>, + /// Set true to try to guess statistics from the files. + /// This can add a lot of overhead as it will usually require files + /// to be opened and at least partially parsed. + pub collect_stat: bool, + /// Group files to avoid that the number of partitions exceeds + /// this limit + pub target_partitions: usize, + /// Optional pre-known sort order(s). Must be `SortExpr`s. + /// + /// DataFusion may take advantage of this ordering to omit sorts + /// or use more efficient algorithms. Currently sortedness must be + /// provided if it is known by some external mechanism, but may in + /// the future be automatically determined, for example using + /// parquet metadata. + /// + /// See + /// + /// NOTE: This attribute stores all equivalent orderings (the outer `Vec`) + /// where each ordering consists of an individual lexicographic + /// ordering (encapsulated by a `Vec`). If there aren't + /// multiple equivalent orderings, the outer `Vec` will have a + /// single element. + pub file_sort_order: Vec>, +} + +impl ListingOptions { + /// Creates an options instance with the given format + /// Default values: + /// - use default file extension filter + /// - no input partition to discover + /// - one target partition + /// - do not collect statistics + pub fn new(format: Arc) -> Self { + Self { + file_extension: format.get_ext(), + format, + table_partition_cols: vec![], + collect_stat: false, + target_partitions: 1, + file_sort_order: vec![], + } + } + + /// Set options from [`SessionConfig`] and returns self. + /// + /// Currently this sets `target_partitions` and `collect_stat` + /// but if more options are added in the future that need to be coordinated + /// they will be synchronized through this method. 
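A hedged sketch of the config-synchronizing builder described above; the configuration values are illustrative:

```rust
use std::sync::Arc;

use datafusion_catalog_listing::ListingOptions;
use datafusion_datasource_parquet::file_format::ParquetFormat;
use datafusion_execution::config::SessionConfig;

let config = SessionConfig::new()
    .with_target_partitions(8)
    .with_collect_statistics(true);

let listing_options = ListingOptions::new(Arc::new(ParquetFormat::default()))
    .with_session_config_options(&config);

assert_eq!(listing_options.target_partitions, 8);
assert!(listing_options.collect_stat);
```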
+ pub fn with_session_config_options(mut self, config: &SessionConfig) -> Self { + self = self.with_target_partitions(config.target_partitions()); + self = self.with_collect_stat(config.collect_statistics()); + self + } + + /// Set file extension on [`ListingOptions`] and returns self. + /// + /// # Example + /// ``` + /// # use std::sync::Arc; + /// # use datafusion_catalog_listing::ListingOptions; + /// # use datafusion_datasource_parquet::file_format::ParquetFormat; + /// + /// let listing_options = ListingOptions::new(Arc::new( + /// ParquetFormat::default() + /// )) + /// .with_file_extension(".parquet"); + /// + /// assert_eq!(listing_options.file_extension, ".parquet"); + /// ``` + pub fn with_file_extension(mut self, file_extension: impl Into) -> Self { + self.file_extension = file_extension.into(); + self + } + + /// Optionally set file extension on [`ListingOptions`] and returns self. + /// + /// If `file_extension` is `None`, the file extension will not be changed + /// + /// # Example + /// ``` + /// # use std::sync::Arc; + /// # use datafusion_catalog_listing::ListingOptions; + /// # use datafusion_datasource_parquet::file_format::ParquetFormat; + /// + /// let extension = Some(".parquet"); + /// let listing_options = ListingOptions::new(Arc::new( + /// ParquetFormat::default() + /// )) + /// .with_file_extension_opt(extension); + /// + /// assert_eq!(listing_options.file_extension, ".parquet"); + /// ``` + pub fn with_file_extension_opt(mut self, file_extension: Option) -> Self + where + S: Into, + { + if let Some(file_extension) = file_extension { + self.file_extension = file_extension.into(); + } + self + } + + /// Set `table partition columns` on [`ListingOptions`] and returns self. + /// + /// "partition columns," used to support [Hive Partitioning], are + /// columns added to the data that is read, based on the folder + /// structure where the data resides. + /// + /// For example, give the following files in your filesystem: + /// + /// ```text + /// /mnt/nyctaxi/year=2022/month=01/tripdata.parquet + /// /mnt/nyctaxi/year=2021/month=12/tripdata.parquet + /// /mnt/nyctaxi/year=2021/month=11/tripdata.parquet + /// ``` + /// + /// A [`crate::ListingTable`] created at `/mnt/nyctaxi/` with partition + /// columns "year" and "month" will include new `year` and `month` + /// columns while reading the files. The `year` column would have + /// value `2022` and the `month` column would have value `01` for + /// the rows read from + /// `/mnt/nyctaxi/year=2022/month=01/tripdata.parquet` + /// + ///# Notes + /// + /// - If only one level (e.g. `year` in the example above) is + /// specified, the other levels are ignored but the files are + /// still read. + /// + /// - Files that don't follow this partitioning scheme will be + /// ignored. + /// + /// - Since the columns have the same value for all rows read from + /// each individual file (such as dates), they are typically + /// dictionary encoded for efficiency. You may use + /// [`wrap_partition_type_in_dict`] to request a + /// dictionary-encoded type. + /// + /// - The partition columns are solely extracted from the file path. Especially they are NOT part of the parquet files itself. 
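As the notes above suggest, partition columns are usually declared with a dictionary type; a small sketch using the helper those notes reference (the `year`/`month` names are illustrative), with the result passed to `with_table_partition_cols` as in the example that follows:

```rust
use arrow::datatypes::DataType;
use datafusion_datasource::file_scan_config::wrap_partition_type_in_dict;

// Each file repeats the same partition value on every row, so request a
// dictionary-encoded column rather than plain Utf8.
let partition_cols = vec![
    ("year".to_string(), wrap_partition_type_in_dict(DataType::Utf8)),
    ("month".to_string(), wrap_partition_type_in_dict(DataType::Utf8)),
];
```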
+ /// + /// # Example + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow::datatypes::DataType; + /// # use datafusion_expr::col; + /// # use datafusion_catalog_listing::ListingOptions; + /// # use datafusion_datasource_parquet::file_format::ParquetFormat; + /// + /// // listing options for files with paths such as `/mnt/data/col_a=x/col_b=y/data.parquet` + /// // `col_a` and `col_b` will be included in the data read from those files + /// let listing_options = ListingOptions::new(Arc::new( + /// ParquetFormat::default() + /// )) + /// .with_table_partition_cols(vec![("col_a".to_string(), DataType::Utf8), + /// ("col_b".to_string(), DataType::Utf8)]); + /// + /// assert_eq!(listing_options.table_partition_cols, vec![("col_a".to_string(), DataType::Utf8), + /// ("col_b".to_string(), DataType::Utf8)]); + /// ``` + /// + /// [Hive Partitioning]: https://docs.cloudera.com/HDPDocuments/HDP2/HDP-2.1.3/bk_system-admin-guide/content/hive_partitioned_tables.html + /// [`wrap_partition_type_in_dict`]: datafusion_datasource::file_scan_config::wrap_partition_type_in_dict + pub fn with_table_partition_cols( + mut self, + table_partition_cols: Vec<(String, DataType)>, + ) -> Self { + self.table_partition_cols = table_partition_cols; + self + } + + /// Set stat collection on [`ListingOptions`] and returns self. + /// + /// ``` + /// # use std::sync::Arc; + /// # use datafusion_catalog_listing::ListingOptions; + /// # use datafusion_datasource_parquet::file_format::ParquetFormat; + /// + /// let listing_options = ListingOptions::new(Arc::new( + /// ParquetFormat::default() + /// )) + /// .with_collect_stat(true); + /// + /// assert_eq!(listing_options.collect_stat, true); + /// ``` + pub fn with_collect_stat(mut self, collect_stat: bool) -> Self { + self.collect_stat = collect_stat; + self + } + + /// Set number of target partitions on [`ListingOptions`] and returns self. + /// + /// ``` + /// # use std::sync::Arc; + /// # use datafusion_catalog_listing::ListingOptions; + /// # use datafusion_datasource_parquet::file_format::ParquetFormat; + /// + /// let listing_options = ListingOptions::new(Arc::new( + /// ParquetFormat::default() + /// )) + /// .with_target_partitions(8); + /// + /// assert_eq!(listing_options.target_partitions, 8); + /// ``` + pub fn with_target_partitions(mut self, target_partitions: usize) -> Self { + self.target_partitions = target_partitions; + self + } + + /// Set file sort order on [`ListingOptions`] and returns self. + /// + /// ``` + /// # use std::sync::Arc; + /// # use datafusion_expr::col; + /// # use datafusion_catalog_listing::ListingOptions; + /// # use datafusion_datasource_parquet::file_format::ParquetFormat; + /// + /// // Tell datafusion that the files are sorted by column "a" + /// let file_sort_order = vec![vec![ + /// col("a").sort(true, true) + /// ]]; + /// + /// let listing_options = ListingOptions::new(Arc::new( + /// ParquetFormat::default() + /// )) + /// .with_file_sort_order(file_sort_order.clone()); + /// + /// assert_eq!(listing_options.file_sort_order, file_sort_order); + /// ``` + pub fn with_file_sort_order(mut self, file_sort_order: Vec>) -> Self { + self.file_sort_order = file_sort_order; + self + } + + /// Infer the schema of the files at the given path on the provided object store. + /// + /// If the table_path contains one or more files (i.e. it is a directory / + /// prefix of files) their schema is merged by calling [`FileFormat::infer_schema`] + /// + /// Note: The inferred schema does not include any partitioning columns. 
+ /// + /// This method is called as part of creating a [`crate::ListingTable`]. + pub async fn infer_schema<'a>( + &'a self, + state: &dyn Session, + table_path: &'a ListingTableUrl, + ) -> datafusion_common::Result { + let store = state.runtime_env().object_store(table_path)?; + + let files: Vec<_> = table_path + .list_all_files(state, store.as_ref(), &self.file_extension) + .await? + // Empty files cannot affect schema but may throw when trying to read for it + .try_filter(|object_meta| future::ready(object_meta.size > 0)) + .try_collect() + .await?; + + let schema = self.format.infer_schema(state, &store, &files).await?; + + Ok(schema) + } + + /// Infers the partition columns stored in `LOCATION` and compares + /// them with the columns provided in `PARTITIONED BY` to help prevent + /// accidental corrupts of partitioned tables. + /// + /// Allows specifying partial partitions. + pub async fn validate_partitions( + &self, + state: &dyn Session, + table_path: &ListingTableUrl, + ) -> datafusion_common::Result<()> { + if self.table_partition_cols.is_empty() { + return Ok(()); + } + + if !table_path.is_collection() { + return plan_err!( + "Can't create a partitioned table backed by a single file, \ + perhaps the URL is missing a trailing slash?" + ); + } + + let inferred = self.infer_partitions(state, table_path).await?; + + // no partitioned files found on disk + if inferred.is_empty() { + return Ok(()); + } + + let table_partition_names = self + .table_partition_cols + .iter() + .map(|(col_name, _)| col_name.clone()) + .collect_vec(); + + if inferred.len() < table_partition_names.len() { + return plan_err!( + "Inferred partitions to be {:?}, but got {:?}", + inferred, + table_partition_names + ); + } + + // match prefix to allow creating tables with partial partitions + for (idx, col) in table_partition_names.iter().enumerate() { + if &inferred[idx] != col { + return plan_err!( + "Inferred partitions to be {:?}, but got {:?}", + inferred, + table_partition_names + ); + } + } + + Ok(()) + } + + /// Infer the partitioning at the given path on the provided object store. + /// For performance reasons, it doesn't read all the files on disk + /// and therefore may fail to detect invalid partitioning. + pub async fn infer_partitions( + &self, + state: &dyn Session, + table_path: &ListingTableUrl, + ) -> datafusion_common::Result> { + let store = state.runtime_env().object_store(table_path)?; + + // only use 10 files for inference + // This can fail to detect inconsistent partition keys + // A DFS traversal approach of the store can help here + let files: Vec<_> = table_path + .list_all_files(state, store.as_ref(), &self.file_extension) + .await? 
+ .take(10) + .try_collect() + .await?; + + let stripped_path_parts = files.iter().map(|file| { + table_path + .strip_prefix(&file.location) + .unwrap() + .collect_vec() + }); + + let partition_keys = stripped_path_parts + .map(|path_parts| { + path_parts + .into_iter() + .rev() + .skip(1) // get parents only; skip the file itself + .rev() + // Partitions are expected to follow the format "column_name=value", so we + // should ignore any path part that cannot be parsed into the expected format + .filter(|s| s.contains('=')) + .map(|s| s.split('=').take(1).collect()) + .collect_vec() + }) + .collect_vec(); + + match partition_keys.into_iter().all_equal_value() { + Ok(v) => Ok(v), + Err(None) => Ok(vec![]), + Err(Some(diff)) => { + let mut sorted_diff = [diff.0, diff.1]; + sorted_diff.sort(); + plan_err!("Found mixed partition values on disk {:?}", sorted_diff) + } + } + } +} diff --git a/datafusion/catalog-listing/src/table.rs b/datafusion/catalog-listing/src/table.rs new file mode 100644 index 000000000000..e9ac1bf097a2 --- /dev/null +++ b/datafusion/catalog-listing/src/table.rs @@ -0,0 +1,788 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::config::SchemaSource; +use crate::helpers::{expr_applicable_for_cols, pruned_partition_list}; +use crate::{ListingOptions, ListingTableConfig}; +use arrow::datatypes::{Field, Schema, SchemaBuilder, SchemaRef}; +use async_trait::async_trait; +use datafusion_catalog::{ScanArgs, ScanResult, Session, TableProvider}; +use datafusion_common::stats::Precision; +use datafusion_common::{ + internal_datafusion_err, plan_err, project_schema, Constraints, DataFusionError, + SchemaExt, Statistics, +}; +use datafusion_datasource::file::FileSource; +use datafusion_datasource::file_groups::FileGroup; +use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; +use datafusion_datasource::file_sink_config::FileSinkConfig; +use datafusion_datasource::schema_adapter::{ + DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, +}; +use datafusion_datasource::{ + compute_all_files_statistics, ListingTableUrl, PartitionedFile, +}; +use datafusion_execution::cache::cache_manager::FileStatisticsCache; +use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; +use datafusion_expr::dml::InsertOp; +use datafusion_expr::execution_props::ExecutionProps; +use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType}; +use datafusion_physical_expr::create_lex_ordering; +use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_plan::empty::EmptyExec; +use datafusion_physical_plan::ExecutionPlan; +use futures::{future, stream, Stream, StreamExt, TryStreamExt}; +use object_store::ObjectStore; +use std::any::Any; +use std::collections::HashMap; +use std::sync::Arc; + +/// Built in [`TableProvider`] that reads data from one or more files as a single table. +/// +/// The files are read using an [`ObjectStore`] instance, for example from +/// local files or objects from AWS S3. +/// +/// # Features: +/// * Reading multiple files as a single table +/// * Hive style partitioning (e.g., directories named `date=2024-06-01`) +/// * Merges schemas from files with compatible but not identical schemas (see [`ListingTableConfig::file_schema`]) +/// * `limit`, `filter` and `projection` pushdown for formats that support it (e.g., +/// Parquet) +/// * Statistics collection and pruning based on file metadata +/// * Pre-existing sort order (see [`ListingOptions::file_sort_order`]) +/// * Metadata caching to speed up repeated queries (see [`FileMetadataCache`]) +/// * Statistics caching (see [`FileStatisticsCache`]) +/// +/// [`FileMetadataCache`]: datafusion_execution::cache::cache_manager::FileMetadataCache +/// +/// # Reading Directories and Hive Style Partitioning +/// +/// For example, given the `table1` directory (or object store prefix) +/// +/// ```text +/// table1 +/// ├── file1.parquet +/// └── file2.parquet +/// ``` +/// +/// A `ListingTable` would read the files `file1.parquet` and `file2.parquet` as +/// a single table, merging the schemas if the files have compatible but not +/// identical schemas. +/// +/// Given the `table2` directory (or object store prefix) +/// +/// ```text +/// table2 +/// ├── date=2024-06-01 +/// │ ├── file3.parquet +/// │ └── file4.parquet +/// └── date=2024-06-02 +/// └── file5.parquet +/// ``` +/// +/// A `ListingTable` would read the files `file3.parquet`, `file4.parquet`, and +/// `file5.parquet` as a single table, again merging schemas if necessary. +/// +/// Given the hive style partitioning structure (e.g,. 
directories named +/// `date=2024-06-01` and `date=2026-06-02`), `ListingTable` also adds a `date` +/// column when reading the table: +/// * The files in `table2/date=2024-06-01` will have the value `2024-06-01` +/// * The files in `table2/date=2024-06-02` will have the value `2024-06-02`. +/// +/// If the query has a predicate like `WHERE date = '2024-06-01'` +/// only the corresponding directory will be read. +/// +/// # See Also +/// +/// 1. [`ListingTableConfig`]: Configuration options +/// 1. [`DataSourceExec`]: `ExecutionPlan` used by `ListingTable` +/// +/// [`DataSourceExec`]: datafusion_datasource::source::DataSourceExec +/// +/// # Caching Metadata +/// +/// Some formats, such as Parquet, use the `FileMetadataCache` to cache file +/// metadata that is needed to execute but expensive to read, such as row +/// groups and statistics. The cache is scoped to the `SessionContext` and can +/// be configured via the [runtime config options]. +/// +/// [runtime config options]: https://datafusion.apache.org/user-guide/configs.html#runtime-configuration-settings +/// +/// # Example: Read a directory of parquet files using a [`ListingTable`] +/// +/// ```no_run +/// # use datafusion_common::Result; +/// # use std::sync::Arc; +/// # use datafusion_catalog::TableProvider; +/// # use datafusion_catalog_listing::{ListingOptions, ListingTable, ListingTableConfig}; +/// # use datafusion_datasource::ListingTableUrl; +/// # use datafusion_datasource_parquet::file_format::ParquetFormat;/// # +/// # use datafusion_catalog::Session; +/// async fn get_listing_table(session: &dyn Session) -> Result> { +/// let table_path = "/path/to/parquet"; +/// +/// // Parse the path +/// let table_path = ListingTableUrl::parse(table_path)?; +/// +/// // Create default parquet options +/// let file_format = ParquetFormat::new(); +/// let listing_options = ListingOptions::new(Arc::new(file_format)) +/// .with_file_extension(".parquet"); +/// +/// // Resolve the schema +/// let resolved_schema = listing_options +/// .infer_schema(session, &table_path) +/// .await?; +/// +/// let config = ListingTableConfig::new(table_path) +/// .with_listing_options(listing_options) +/// .with_schema(resolved_schema); +/// +/// // Create a new TableProvider +/// let provider = Arc::new(ListingTable::try_new(config)?); +/// +/// # Ok(provider) +/// # } +/// ``` +#[derive(Debug, Clone)] +pub struct ListingTable { + table_paths: Vec, + /// `file_schema` contains only the columns physically stored in the data files themselves. + /// - Represents the actual fields found in files like Parquet, CSV, etc. 
+ /// - Used when reading the raw data from files + file_schema: SchemaRef, + /// `table_schema` combines `file_schema` + partition columns + /// - Partition columns are derived from directory paths (not stored in files) + /// - These are columns like "year=2022/month=01" in paths like `/data/year=2022/month=01/file.parquet` + table_schema: SchemaRef, + /// Indicates how the schema was derived (inferred or explicitly specified) + schema_source: SchemaSource, + /// Options used to configure the listing table such as the file format + /// and partitioning information + options: ListingOptions, + /// The SQL definition for this table, if any + definition: Option, + /// Cache for collected file statistics + collected_statistics: FileStatisticsCache, + /// Constraints applied to this table + constraints: Constraints, + /// Column default expressions for columns that are not physically present in the data files + column_defaults: HashMap, + /// Optional [`SchemaAdapterFactory`] for creating schema adapters + schema_adapter_factory: Option>, + /// Optional [`PhysicalExprAdapterFactory`] for creating physical expression adapters + expr_adapter_factory: Option>, +} + +impl ListingTable { + /// Create new [`ListingTable`] + /// + /// See documentation and example on [`ListingTable`] and [`ListingTableConfig`] + pub fn try_new(config: ListingTableConfig) -> datafusion_common::Result { + // Extract schema_source before moving other parts of the config + let schema_source = config.schema_source(); + + let file_schema = config + .file_schema + .ok_or_else(|| internal_datafusion_err!("No schema provided."))?; + + let options = config + .options + .ok_or_else(|| internal_datafusion_err!("No ListingOptions provided"))?; + + // Add the partition columns to the file schema + let mut builder = SchemaBuilder::from(file_schema.as_ref().to_owned()); + for (part_col_name, part_col_type) in &options.table_partition_cols { + builder.push(Field::new(part_col_name, part_col_type.clone(), false)); + } + + let table_schema = Arc::new( + builder + .finish() + .with_metadata(file_schema.metadata().clone()), + ); + + let table = Self { + table_paths: config.table_paths, + file_schema, + table_schema, + schema_source, + options, + definition: None, + collected_statistics: Arc::new(DefaultFileStatisticsCache::default()), + constraints: Constraints::default(), + column_defaults: HashMap::new(), + schema_adapter_factory: config.schema_adapter_factory, + expr_adapter_factory: config.expr_adapter_factory, + }; + + Ok(table) + } + + /// Assign constraints + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.constraints = constraints; + self + } + + /// Assign column defaults + pub fn with_column_defaults( + mut self, + column_defaults: HashMap, + ) -> Self { + self.column_defaults = column_defaults; + self + } + + /// Set the [`FileStatisticsCache`] used to cache parquet file statistics. + /// + /// Setting a statistics cache on the `SessionContext` can avoid refetching statistics + /// multiple times in the same session. + /// + /// If `None`, creates a new [`DefaultFileStatisticsCache`] scoped to this query. 
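One way to use the cache setter described above is to share a single statistics cache between tables so repeated scans of the same files do not re-collect statistics. A hedged sketch; `table_a` and `table_b` are assumed to be `ListingTable`s built as in the doc example earlier:

```rust
use std::sync::Arc;

use datafusion_execution::cache::cache_manager::FileStatisticsCache;
use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache;

// Coerce to the `FileStatisticsCache` alias (an `Arc`'d trait object) up front
// so the same cache instance can be handed to multiple tables.
let shared_cache: FileStatisticsCache =
    Arc::new(DefaultFileStatisticsCache::default());
let table_a = table_a.with_cache(Some(Arc::clone(&shared_cache)));
let table_b = table_b.with_cache(Some(shared_cache));
```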
+ pub fn with_cache(mut self, cache: Option) -> Self { + self.collected_statistics = + cache.unwrap_or_else(|| Arc::new(DefaultFileStatisticsCache::default())); + self + } + + /// Specify the SQL definition for this table, if any + pub fn with_definition(mut self, definition: Option) -> Self { + self.definition = definition; + self + } + + /// Get paths ref + pub fn table_paths(&self) -> &Vec { + &self.table_paths + } + + /// Get options ref + pub fn options(&self) -> &ListingOptions { + &self.options + } + + /// Get the schema source + pub fn schema_source(&self) -> SchemaSource { + self.schema_source + } + + /// Set the [`SchemaAdapterFactory`] for this [`ListingTable`] + /// + /// The schema adapter factory is used to create schema adapters that can + /// handle schema evolution and type conversions when reading files with + /// different schemas than the table schema. + /// + /// # Example: Adding Schema Evolution Support + /// ```rust + /// # use std::sync::Arc; + /// # use datafusion_catalog_listing::{ListingTable, ListingTableConfig, ListingOptions}; + /// # use datafusion_datasource::ListingTableUrl; + /// # use datafusion_datasource::schema_adapter::{DefaultSchemaAdapterFactory, SchemaAdapter}; + /// # use datafusion_datasource_parquet::file_format::ParquetFormat; + /// # use arrow::datatypes::{SchemaRef, Schema, Field, DataType}; + /// # let table_path = ListingTableUrl::parse("file:///path/to/data").unwrap(); + /// # let options = ListingOptions::new(Arc::new(ParquetFormat::default())); + /// # let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); + /// # let config = ListingTableConfig::new(table_path).with_listing_options(options).with_schema(schema); + /// # let table = ListingTable::try_new(config).unwrap(); + /// let table_with_evolution = table + /// .with_schema_adapter_factory(Arc::new(DefaultSchemaAdapterFactory)); + /// ``` + /// See [`ListingTableConfig::with_schema_adapter_factory`] for an example of custom SchemaAdapterFactory. + pub fn with_schema_adapter_factory( + self, + schema_adapter_factory: Arc, + ) -> Self { + Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self + } + } + + /// Get the [`SchemaAdapterFactory`] for this table + pub fn schema_adapter_factory(&self) -> Option<&Arc> { + self.schema_adapter_factory.as_ref() + } + + /// Creates a schema adapter for mapping between file and table schemas + /// + /// Uses the configured schema adapter factory if available, otherwise falls back + /// to the default implementation. + fn create_schema_adapter(&self) -> Box { + let table_schema = self.schema(); + match &self.schema_adapter_factory { + Some(factory) => { + factory.create_with_projected_schema(Arc::clone(&table_schema)) + } + None => DefaultSchemaAdapterFactory::from_schema(Arc::clone(&table_schema)), + } + } + + /// Creates a file source and applies schema adapter factory if available + fn create_file_source_with_schema_adapter( + &self, + ) -> datafusion_common::Result> { + let mut source = self.options.format.file_source(); + // Apply schema adapter to source if available + // + // The source will use this SchemaAdapter to adapt data batches as they flow up the plan. + // Note: ListingTable also creates a SchemaAdapter in `scan()` but that is only used to adapt collected statistics. 
+ if let Some(factory) = &self.schema_adapter_factory { + source = source.with_schema_adapter_factory(Arc::clone(factory))?; + } + Ok(source) + } + + /// If file_sort_order is specified, creates the appropriate physical expressions + pub fn try_create_output_ordering( + &self, + execution_props: &ExecutionProps, + ) -> datafusion_common::Result> { + create_lex_ordering( + &self.table_schema, + &self.options.file_sort_order, + execution_props, + ) + } +} + +// Expressions can be used for partition pruning if they can be evaluated using +// only the partition columns and there are partition columns. +fn can_be_evaluated_for_partition_pruning( + partition_column_names: &[&str], + expr: &Expr, +) -> bool { + !partition_column_names.is_empty() + && expr_applicable_for_cols(partition_column_names, expr) +} + +#[async_trait] +impl TableProvider for ListingTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.table_schema) + } + + fn constraints(&self) -> Option<&Constraints> { + Some(&self.constraints) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + state: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> datafusion_common::Result> { + let options = ScanArgs::default() + .with_projection(projection.map(|p| p.as_slice())) + .with_filters(Some(filters)) + .with_limit(limit); + Ok(self.scan_with_args(state, options).await?.into_inner()) + } + + async fn scan_with_args<'a>( + &self, + state: &dyn Session, + args: ScanArgs<'a>, + ) -> datafusion_common::Result { + let projection = args.projection().map(|p| p.to_vec()); + let filters = args.filters().map(|f| f.to_vec()).unwrap_or_default(); + let limit = args.limit(); + + // extract types of partition columns + let table_partition_cols = self + .options + .table_partition_cols + .iter() + .map(|col| Ok(self.table_schema.field_with_name(&col.0)?.clone())) + .collect::>>()?; + + let table_partition_col_names = table_partition_cols + .iter() + .map(|field| field.name().as_str()) + .collect::>(); + + // If the filters can be resolved using only partition cols, there is no need to + // pushdown it to TableScan, otherwise, `unhandled` pruning predicates will be generated + let (partition_filters, filters): (Vec<_>, Vec<_>) = + filters.iter().cloned().partition(|filter| { + can_be_evaluated_for_partition_pruning(&table_partition_col_names, filter) + }); + + // We should not limit the number of partitioned files to scan if there are filters and limit + // at the same time. This is because the limit should be applied after the filters are applied. 
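Editorial aside (not part of this diff): the filter split above keeps a predicate on the partition-pruning side only if it can be evaluated from partition columns alone. The helper below is a simplified stand-in for `can_be_evaluated_for_partition_pruning` / `expr_applicable_for_cols`, checking just the referenced column names:

```rust
use datafusion_expr::{col, lit, Expr};

/// Simplified check: a predicate can prune partitions only if every column it
/// references is a partition column (and there is at least one partition column).
fn only_uses_partition_cols(partition_cols: &[&str], expr: &Expr) -> bool {
    !partition_cols.is_empty()
        && expr
            .column_refs()
            .iter()
            .all(|c| partition_cols.contains(&c.name.as_str()))
}

fn main() {
    let partition_cols = ["year", "month"];
    let by_partition = col("year").eq(lit(2022)); // prunable from the path alone
    let by_data = col("price").gt(lit(100)); // needs the file contents
    let (partition_filters, remaining): (Vec<Expr>, Vec<Expr>) =
        vec![by_partition, by_data]
            .into_iter()
            .partition(|e| only_uses_partition_cols(&partition_cols, e));
    assert_eq!(partition_filters.len(), 1);
    assert_eq!(remaining.len(), 1);
}
```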
+ let statistic_file_limit = if filters.is_empty() { limit } else { None }; + + let (mut partitioned_file_lists, statistics) = self + .list_files_for_scan(state, &partition_filters, statistic_file_limit) + .await?; + + // if no files need to be read, return an `EmptyExec` + if partitioned_file_lists.is_empty() { + let projected_schema = project_schema(&self.schema(), projection.as_ref())?; + return Ok(ScanResult::new(Arc::new(EmptyExec::new(projected_schema)))); + } + + let output_ordering = self.try_create_output_ordering(state.execution_props())?; + match state + .config_options() + .execution + .split_file_groups_by_statistics + .then(|| { + output_ordering.first().map(|output_ordering| { + FileScanConfig::split_groups_by_statistics_with_target_partitions( + &self.table_schema, + &partitioned_file_lists, + output_ordering, + self.options.target_partitions, + ) + }) + }) + .flatten() + { + Some(Err(e)) => log::debug!("failed to split file groups by statistics: {e}"), + Some(Ok(new_groups)) => { + if new_groups.len() <= self.options.target_partitions { + partitioned_file_lists = new_groups; + } else { + log::debug!("attempted to split file groups by statistics, but there were more file groups than target_partitions; falling back to unordered") + } + } + None => {} // no ordering required + }; + + let Some(object_store_url) = + self.table_paths.first().map(ListingTableUrl::object_store) + else { + return Ok(ScanResult::new(Arc::new(EmptyExec::new(Arc::new( + Schema::empty(), + ))))); + }; + + let file_source = self.create_file_source_with_schema_adapter()?; + + // create the execution plan + let plan = self + .options + .format + .create_physical_plan( + state, + FileScanConfigBuilder::new( + object_store_url, + Arc::clone(&self.file_schema), + file_source, + ) + .with_file_groups(partitioned_file_lists) + .with_constraints(self.constraints.clone()) + .with_statistics(statistics) + .with_projection(projection) + .with_limit(limit) + .with_output_ordering(output_ordering) + .with_table_partition_cols(table_partition_cols) + .with_expr_adapter(self.expr_adapter_factory.clone()) + .build(), + ) + .await?; + + Ok(ScanResult::new(plan)) + } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> datafusion_common::Result> { + let partition_column_names = self + .options + .table_partition_cols + .iter() + .map(|col| col.0.as_str()) + .collect::>(); + filters + .iter() + .map(|filter| { + if can_be_evaluated_for_partition_pruning(&partition_column_names, filter) + { + // if filter can be handled by partition pruning, it is exact + return Ok(TableProviderFilterPushDown::Exact); + } + + Ok(TableProviderFilterPushDown::Inexact) + }) + .collect() + } + + fn get_table_definition(&self) -> Option<&str> { + self.definition.as_deref() + } + + async fn insert_into( + &self, + state: &dyn Session, + input: Arc, + insert_op: InsertOp, + ) -> datafusion_common::Result> { + // Check that the schema of the plan matches the schema of this table. + self.schema() + .logically_equivalent_names_and_types(&input.schema())?; + + let table_path = &self.table_paths()[0]; + if !table_path.is_collection() { + return plan_err!( + "Inserting into a ListingTable backed by a single file is not supported, URL is possibly missing a trailing `/`. \ + To append to an existing file use StreamTable, e.g. by using CREATE UNBOUNDED EXTERNAL TABLE" + ); + } + + // Get the object store for the table path. 
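Editorial aside (not part of this diff): `insert_into` above requires the table URL to be a directory ("collection"), which is signalled by a trailing `/`. A small sketch using `ListingTableUrl` (the import path matches the doc example earlier in this diff); the paths are hypothetical:

```rust
use datafusion_datasource::ListingTableUrl;

fn main() -> datafusion_common::Result<()> {
    // A trailing `/` marks a directory ("collection") that can accept inserts
    let dir = ListingTableUrl::parse("file:///data/sales/")?;
    assert!(dir.is_collection());

    // A single file is not a collection, so `insert_into` rejects it
    let file = ListingTableUrl::parse("file:///data/sales/part-0.parquet")?;
    assert!(!file.is_collection());
    Ok(())
}
```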
+ let store = state.runtime_env().object_store(table_path)?; + + let file_list_stream = pruned_partition_list( + state, + store.as_ref(), + table_path, + &[], + &self.options.file_extension, + &self.options.table_partition_cols, + ) + .await?; + + let file_group = file_list_stream.try_collect::>().await?.into(); + let keep_partition_by_columns = + state.config_options().execution.keep_partition_by_columns; + + // Sink related option, apart from format + let config = FileSinkConfig { + original_url: String::default(), + object_store_url: self.table_paths()[0].object_store(), + table_paths: self.table_paths().clone(), + file_group, + output_schema: self.schema(), + table_partition_cols: self.options.table_partition_cols.clone(), + insert_op, + keep_partition_by_columns, + file_extension: self.options().format.get_ext(), + }; + + let orderings = self.try_create_output_ordering(state.execution_props())?; + // It is sufficient to pass only one of the equivalent orderings: + let order_requirements = orderings.into_iter().next().map(Into::into); + + self.options() + .format + .create_writer_physical_plan(input, state, config, order_requirements) + .await + } + + fn get_column_default(&self, column: &str) -> Option<&Expr> { + self.column_defaults.get(column) + } +} + +impl ListingTable { + /// Get the list of files for a scan as well as the file level statistics. + /// The list is grouped to let the execution plan know how the files should + /// be distributed to different threads / executors. + pub async fn list_files_for_scan<'a>( + &'a self, + ctx: &'a dyn Session, + filters: &'a [Expr], + limit: Option, + ) -> datafusion_common::Result<(Vec, Statistics)> { + let store = if let Some(url) = self.table_paths.first() { + ctx.runtime_env().object_store(url)? + } else { + return Ok((vec![], Statistics::new_unknown(&self.file_schema))); + }; + // list files (with partitions) + let file_list = future::try_join_all(self.table_paths.iter().map(|table_path| { + pruned_partition_list( + ctx, + store.as_ref(), + table_path, + filters, + &self.options.file_extension, + &self.options.table_partition_cols, + ) + })) + .await?; + let meta_fetch_concurrency = + ctx.config_options().execution.meta_fetch_concurrency; + let file_list = stream::iter(file_list).flatten_unordered(meta_fetch_concurrency); + // collect the statistics if required by the config + let files = file_list + .map(|part_file| async { + let part_file = part_file?; + let statistics = if self.options.collect_stat { + self.do_collect_statistics(ctx, &store, &part_file).await? 
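Editorial aside (not part of this diff): the `do_collect_statistics` call above is a get-or-compute lookup against the file statistics cache. A simplified sketch of that pattern; `StatsCache` and `Stats` are hypothetical stand-ins for `FileStatisticsCache` and `Statistics`, which in the real code are keyed by object path and validated against `ObjectMeta`:

```rust
use std::collections::HashMap;
use std::sync::Arc;

struct Stats {
    num_rows: usize,
}

struct StatsCache {
    inner: HashMap<String, Arc<Stats>>,
}

impl StatsCache {
    fn get_or_infer(&mut self, path: &str, infer: impl FnOnce() -> Stats) -> Arc<Stats> {
        if let Some(stats) = self.inner.get(path) {
            return Arc::clone(stats); // cache hit: skip re-reading file metadata
        }
        let stats = Arc::new(infer()); // cache miss: infer once and remember
        self.inner.insert(path.to_string(), Arc::clone(&stats));
        stats
    }
}

fn main() {
    let mut cache = StatsCache { inner: HashMap::new() };
    let s = cache.get_or_infer("file.parquet", || Stats { num_rows: 100 });
    assert_eq!(s.num_rows, 100);
    // Second lookup is served from the cache; the closure is never called
    let s2 = cache.get_or_infer("file.parquet", || unreachable!());
    assert_eq!(s2.num_rows, 100);
}
```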
+ } else { + Arc::new(Statistics::new_unknown(&self.file_schema)) + }; + Ok(part_file.with_statistics(statistics)) + }) + .boxed() + .buffer_unordered(ctx.config_options().execution.meta_fetch_concurrency); + + let (file_group, inexact_stats) = + get_files_with_limit(files, limit, self.options.collect_stat).await?; + + let file_groups = file_group.split_files(self.options.target_partitions); + let (mut file_groups, mut stats) = compute_all_files_statistics( + file_groups, + self.schema(), + self.options.collect_stat, + inexact_stats, + )?; + + let schema_adapter = self.create_schema_adapter(); + let (schema_mapper, _) = schema_adapter.map_schema(self.file_schema.as_ref())?; + + stats.column_statistics = + schema_mapper.map_column_statistics(&stats.column_statistics)?; + file_groups.iter_mut().try_for_each(|file_group| { + if let Some(stat) = file_group.statistics_mut() { + stat.column_statistics = + schema_mapper.map_column_statistics(&stat.column_statistics)?; + } + Ok::<_, DataFusionError>(()) + })?; + Ok((file_groups, stats)) + } + + /// Collects statistics for a given partitioned file. + /// + /// This method first checks if the statistics for the given file are already cached. + /// If they are, it returns the cached statistics. + /// If they are not, it infers the statistics from the file and stores them in the cache. + async fn do_collect_statistics( + &self, + ctx: &dyn Session, + store: &Arc, + part_file: &PartitionedFile, + ) -> datafusion_common::Result> { + match self + .collected_statistics + .get_with_extra(&part_file.object_meta.location, &part_file.object_meta) + { + Some(statistics) => Ok(statistics), + None => { + let statistics = self + .options + .format + .infer_stats( + ctx, + store, + Arc::clone(&self.file_schema), + &part_file.object_meta, + ) + .await?; + let statistics = Arc::new(statistics); + self.collected_statistics.put_with_extra( + &part_file.object_meta.location, + Arc::clone(&statistics), + &part_file.object_meta, + ); + Ok(statistics) + } + } + } +} + +/// Processes a stream of partitioned files and returns a `FileGroup` containing the files. +/// +/// This function collects files from the provided stream until either: +/// 1. The stream is exhausted +/// 2. The accumulated number of rows exceeds the provided `limit` (if specified) +/// +/// # Arguments +/// * `files` - A stream of `Result` items to process +/// * `limit` - An optional row count limit. If provided, the function will stop collecting files +/// once the accumulated number of rows exceeds this limit +/// * `collect_stats` - Whether to collect and accumulate statistics from the files +/// +/// # Returns +/// A `Result` containing a `FileGroup` with the collected files +/// and a boolean indicating whether the statistics are inexact. +/// +/// # Note +/// The function will continue processing files if statistics are not available or if the +/// limit is not provided. If `collect_stats` is false, statistics won't be accumulated +/// but files will still be collected. +async fn get_files_with_limit( + files: impl Stream>, + limit: Option, + collect_stats: bool, +) -> datafusion_common::Result<(FileGroup, bool)> { + let mut file_group = FileGroup::default(); + // Fusing the stream allows us to call next safely even once it is finished. 
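Editorial aside (not part of this diff): `get_files_with_limit` accumulates per-file row counts with `Precision`. A small sketch of that arithmetic (seeding from `Precision::Exact(0)` is a simplification; the real code takes the first file's count directly):

```rust
use datafusion_common::stats::Precision;

fn main() {
    // Accumulate row counts across files, as get_files_with_limit does
    let mut num_rows: Precision<usize> = Precision::Exact(0);
    for file_rows in [Precision::Exact(40_usize), Precision::Exact(70)] {
        num_rows = num_rows.add(&file_rows);
    }
    // 110 rows collected so far; a LIMIT 100 could stop listing further files here
    assert_eq!(num_rows, Precision::Exact(110));

    // If any file lacks statistics, the total degrades and the limit
    // short-circuit cannot apply
    assert_eq!(num_rows.add(&Precision::Absent), Precision::Absent);
}
```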
+ let mut all_files = Box::pin(files.fuse()); + enum ProcessingState { + ReadingFiles, + ReachedLimit, + } + + let mut state = ProcessingState::ReadingFiles; + let mut num_rows = Precision::Absent; + + while let Some(file_result) = all_files.next().await { + // Early exit if we've already reached our limit + if matches!(state, ProcessingState::ReachedLimit) { + break; + } + + let file = file_result?; + + // Update file statistics regardless of state + if collect_stats { + if let Some(file_stats) = &file.statistics { + num_rows = if file_group.is_empty() { + // For the first file, just take its row count + file_stats.num_rows + } else { + // For subsequent files, accumulate the counts + num_rows.add(&file_stats.num_rows) + }; + } + } + + // Always add the file to our group + file_group.push(file); + + // Check if we've hit the limit (if one was specified) + if let Some(limit) = limit { + if let Precision::Exact(row_count) = num_rows { + if row_count > limit { + state = ProcessingState::ReachedLimit; + } + } + } + } + // If we still have files in the stream, it means that the limit kicked + // in, and the statistic could have been different had we processed the + // files in a different order. + let inexact_stats = all_files.next().await.is_some(); + Ok((file_group, inexact_stats)) +} diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index f5e51cb236d4..abeb4e66a269 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -71,7 +71,7 @@ log = { workspace = true } object_store = { workspace = true, optional = true } parquet = { workspace = true, optional = true, default-features = true } paste = "1.0.15" -pyo3 = { version = "0.25", optional = true } +pyo3 = { version = "0.26", optional = true } recursive = { workspace = true, optional = true } sqlparser = { workspace = true, optional = true } tokio = { workspace = true } diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 39d730eaafb4..1713377f8d4d 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -22,18 +22,19 @@ use arrow_ipc::CompressionType; #[cfg(feature = "parquet_encryption")] use crate::encryption::{FileDecryptionProperties, FileEncryptionProperties}; use crate::error::_config_err; -use crate::format::ExplainFormat; +use crate::format::{ExplainAnalyzeLevel, ExplainFormat}; use crate::parsers::CompressionTypeVariant; use crate::utils::get_available_parallelism; use crate::{DataFusionError, Result}; +#[cfg(feature = "parquet_encryption")] +use hex; use std::any::Any; use std::collections::{BTreeMap, HashMap}; use std::error::Error; use std::fmt::{self, Display}; use std::str::FromStr; - #[cfg(feature = "parquet_encryption")] -use hex; +use std::sync::Arc; /// A macro that wraps a configuration struct and automatically derives /// [`Default`] and [`ConfigField`] for it, allowing it to be used @@ -258,7 +259,7 @@ config_namespace! { /// Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, /// MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi, DuckDB and Databricks. - pub dialect: String, default = "generic".to_string() + pub dialect: Dialect, default = Dialect::Generic // no need to lowercase because `sqlparser::dialect_from_str`] is case-insensitive /// If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but @@ -292,6 +293,94 @@ config_namespace! { } } +/// This is the SQL dialect used by DataFusion's parser. 
+/// This mirrors [sqlparser::dialect::Dialect](https://docs.rs/sqlparser/latest/sqlparser/dialect/trait.Dialect.html) +/// trait in order to offer an easier API and avoid adding the `sqlparser` dependency +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +pub enum Dialect { + #[default] + Generic, + MySQL, + PostgreSQL, + Hive, + SQLite, + Snowflake, + Redshift, + MsSQL, + ClickHouse, + BigQuery, + Ansi, + DuckDB, + Databricks, +} + +impl AsRef for Dialect { + fn as_ref(&self) -> &str { + match self { + Self::Generic => "generic", + Self::MySQL => "mysql", + Self::PostgreSQL => "postgresql", + Self::Hive => "hive", + Self::SQLite => "sqlite", + Self::Snowflake => "snowflake", + Self::Redshift => "redshift", + Self::MsSQL => "mssql", + Self::ClickHouse => "clickhouse", + Self::BigQuery => "bigquery", + Self::Ansi => "ansi", + Self::DuckDB => "duckdb", + Self::Databricks => "databricks", + } + } +} + +impl FromStr for Dialect { + type Err = DataFusionError; + + fn from_str(s: &str) -> Result { + let value = match s.to_ascii_lowercase().as_str() { + "generic" => Self::Generic, + "mysql" => Self::MySQL, + "postgresql" | "postgres" => Self::PostgreSQL, + "hive" => Self::Hive, + "sqlite" => Self::SQLite, + "snowflake" => Self::Snowflake, + "redshift" => Self::Redshift, + "mssql" => Self::MsSQL, + "clickhouse" => Self::ClickHouse, + "bigquery" => Self::BigQuery, + "ansi" => Self::Ansi, + "duckdb" => Self::DuckDB, + "databricks" => Self::Databricks, + other => { + let error_message = format!( + "Invalid Dialect: {other}. Expected one of: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi, DuckDB, Databricks" + ); + return Err(DataFusionError::Configuration(error_message)); + } + }; + Ok(value) + } +} + +impl ConfigField for Dialect { + fn visit(&self, v: &mut V, key: &str, description: &'static str) { + v.some(key, self, description) + } + + fn set(&mut self, _: &str, value: &str) -> Result<()> { + *self = Self::from_str(value)?; + Ok(()) + } +} + +impl Display for Dialect { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let str = self.as_ref(); + write!(f, "{str}") + } +} + #[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] pub enum SpillCompression { Zstd, @@ -745,11 +834,21 @@ config_namespace! { /// past window functions, if possible pub enable_window_limits: bool, default = true - /// When set to true attempts to push down dynamic filters generated by operators into the file scan phase. + /// When set to true, the optimizer will attempt to push down TopK dynamic filters + /// into the file scan phase. + pub enable_topk_dynamic_filter_pushdown: bool, default = true + + /// When set to true, the optimizer will attempt to push down Join dynamic filters + /// into the file scan phase. + pub enable_join_dynamic_filter_pushdown: bool, default = true + + /// When set to true attempts to push down dynamic filters generated by operators (topk & join) into the file scan phase. /// For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer /// will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans. /// This means that if we already have 10 timestamps in the year 2025 /// any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan. 
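Editorial aside (not part of this diff): a round-trip sketch for the new `Dialect` enum, assuming it is exported from `datafusion_common::config` as defined above:

```rust
use std::str::FromStr;
use datafusion_common::config::Dialect;

fn main() -> datafusion_common::Result<()> {
    // Parsing is case-insensitive and accepts the "postgres" alias
    let dialect = Dialect::from_str("Postgres")?;
    assert_eq!(dialect, Dialect::PostgreSQL);

    // Display / AsRef render the canonical lowercase name
    assert_eq!(dialect.to_string(), "postgresql");

    // Unknown names produce a configuration error listing the valid values
    assert!(Dialect::from_str("oracle").is_err());
    Ok(())
}
```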
+ /// The config will suppress `enable_join_dynamic_filter_pushdown` & `enable_topk_dynamic_filter_pushdown` + /// So if you disable `enable_topk_dynamic_filter_pushdown`, then enable `enable_dynamic_filter_pushdown`, the `enable_topk_dynamic_filter_pushdown` will be overridden. pub enable_dynamic_filter_pushdown: bool, default = true /// When set to true, the optimizer will insert filters before a join between @@ -840,6 +939,11 @@ config_namespace! { /// HashJoin can work more efficiently than SortMergeJoin but consumes more memory pub prefer_hash_join: bool, default = true + /// When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently + /// experimental. Physical planner will opt for PiecewiseMergeJoin when there is only + /// one range filter. + pub enable_piecewise_merge_join: bool, default = false + /// The maximum estimated size in bytes for one input side of a HashJoin /// will be collected into a single partition pub hash_join_single_partition_threshold: usize, default = 1024 * 1024 @@ -893,6 +997,11 @@ config_namespace! { /// (format=tree only) Maximum total width of the rendered tree. /// When set to 0, the tree will have no width limit. pub tree_maximum_render_width: usize, default = 240 + + /// Verbosity level for "EXPLAIN ANALYZE". Default is "dev" + /// "summary" shows common metrics for high-level insights. + /// "dev" provides deep operator-level introspection for developers. + pub analyze_level: ExplainAnalyzeLevel, default = ExplainAnalyzeLevel::Dev } } @@ -1039,6 +1148,20 @@ impl ConfigOptions { }; if prefix == "datafusion" { + if key == "optimizer.enable_dynamic_filter_pushdown" { + let bool_value = value.parse::().map_err(|e| { + DataFusionError::Configuration(format!( + "Failed to parse '{value}' as bool: {e}", + )) + })?; + + { + self.optimizer.enable_dynamic_filter_pushdown = bool_value; + self.optimizer.enable_topk_dynamic_filter_pushdown = bool_value; + self.optimizer.enable_join_dynamic_filter_pushdown = bool_value; + } + return Ok(()); + } return ConfigField::set(self, key, value); } @@ -2287,13 +2410,13 @@ impl From for FileEncryptionProperties { hex::decode(&val.aad_prefix_as_hex).expect("Invalid AAD prefix"); fep = fep.with_aad_prefix(aad_prefix); } - fep.build().unwrap() + Arc::unwrap_or_clone(fep.build().unwrap()) } } #[cfg(feature = "parquet_encryption")] -impl From<&FileEncryptionProperties> for ConfigFileEncryptionProperties { - fn from(f: &FileEncryptionProperties) -> Self { +impl From<&Arc> for ConfigFileEncryptionProperties { + fn from(f: &Arc) -> Self { let (column_names_vec, column_keys_vec, column_metas_vec) = f.column_keys(); let mut column_encryption_properties: HashMap< @@ -2435,13 +2558,13 @@ impl From for FileDecryptionProperties { fep = fep.with_aad_prefix(aad_prefix); } - fep.build().unwrap() + Arc::unwrap_or_clone(fep.build().unwrap()) } } #[cfg(feature = "parquet_encryption")] -impl From<&FileDecryptionProperties> for ConfigFileDecryptionProperties { - fn from(f: &FileDecryptionProperties) -> Self { +impl From<&Arc> for ConfigFileDecryptionProperties { + fn from(f: &Arc) -> Self { let (column_names_vec, column_keys_vec) = f.column_keys(); let mut column_decryption_properties: HashMap< String, @@ -2712,6 +2835,7 @@ mod tests { }; use std::any::Any; use std::collections::HashMap; + use std::sync::Arc; #[derive(Default, Debug, Clone)] pub struct TestExtensionConfig { @@ -2868,16 +2992,15 @@ mod tests { .unwrap(); // Test round-trip - let config_encrypt: ConfigFileEncryptionProperties = - 
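Editorial aside (not part of this diff): the override semantics described above, made concrete with `ConfigOptions::set`. Writing the umbrella key rewrites both operator-specific flags, while each specific key can still be toggled afterwards:

```rust
use datafusion_common::config::ConfigOptions;

fn main() -> datafusion_common::Result<()> {
    let mut options = ConfigOptions::default();

    // The umbrella flag overrides both operator-specific flags...
    options.set("datafusion.optimizer.enable_dynamic_filter_pushdown", "false")?;
    assert!(!options.optimizer.enable_topk_dynamic_filter_pushdown);
    assert!(!options.optimizer.enable_join_dynamic_filter_pushdown);

    // ...while the operator-specific flags remain individually adjustable
    options.set("datafusion.optimizer.enable_topk_dynamic_filter_pushdown", "true")?;
    assert!(options.optimizer.enable_topk_dynamic_filter_pushdown);
    assert!(!options.optimizer.enable_join_dynamic_filter_pushdown);
    Ok(())
}
```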
(&file_encryption_properties).into(); - let encryption_properties_built: FileEncryptionProperties = - config_encrypt.clone().into(); + let config_encrypt = + ConfigFileEncryptionProperties::from(&file_encryption_properties); + let encryption_properties_built = + Arc::new(FileEncryptionProperties::from(config_encrypt.clone())); assert_eq!(file_encryption_properties, encryption_properties_built); - let config_decrypt: ConfigFileDecryptionProperties = - (&decryption_properties).into(); - let decryption_properties_built: FileDecryptionProperties = - config_decrypt.clone().into(); + let config_decrypt = ConfigFileDecryptionProperties::from(&decryption_properties); + let decryption_properties_built = + Arc::new(FileDecryptionProperties::from(config_decrypt.clone())); assert_eq!(decryption_properties, decryption_properties_built); /////////////////////////////////////////////////////////////////////////////////// diff --git a/datafusion/common/src/datatype.rs b/datafusion/common/src/datatype.rs new file mode 100644 index 000000000000..544ec0c2468c --- /dev/null +++ b/datafusion/common/src/datatype.rs @@ -0,0 +1,175 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`DataTypeExt`] and [`FieldExt`] extension trait for working with DataTypes to Fields + +use crate::arrow::datatypes::{DataType, Field, FieldRef}; +use std::sync::Arc; + +/// DataFusion extension methods for Arrow [`DataType`] +pub trait DataTypeExt { + /// Convert the type to field with nullable type and "" name + /// + /// This is used to track the places where we convert a [`DataType`] + /// into a nameless field to interact with an API that is + /// capable of representing an extension type and/or nullability. + /// + /// For example, it will convert a `DataType::Int32` into + /// `Field::new("", DataType::Int32, true)`. + /// + /// ``` + /// # use datafusion_common::datatype::DataTypeExt; + /// # use arrow::datatypes::DataType; + /// let dt = DataType::Utf8; + /// let field = dt.into_nullable_field(); + /// // result is a nullable Utf8 field with "" name + /// assert_eq!(field.name(), ""); + /// assert_eq!(field.data_type(), &DataType::Utf8); + /// assert!(field.is_nullable()); + /// ``` + fn into_nullable_field(self) -> Field; + + /// Convert the type to [`FieldRef`] with nullable type and "" name + /// + /// Concise wrapper around [`DataTypeExt::into_nullable_field`] that + /// constructs a [`FieldRef`]. 
+ fn into_nullable_field_ref(self) -> FieldRef; +} + +impl DataTypeExt for DataType { + fn into_nullable_field(self) -> Field { + Field::new("", self, true) + } + + fn into_nullable_field_ref(self) -> FieldRef { + Arc::new(Field::new("", self, true)) + } +} + +/// DataFusion extension methods for Arrow [`Field`] and [`FieldRef`] +pub trait FieldExt { + /// Returns a new Field representing a List of this Field's DataType. + /// + /// For example if input represents an `Int32`, the return value will + /// represent a `List`. + /// + /// Example: + /// ``` + /// # use std::sync::Arc; + /// # use arrow::datatypes::{DataType, Field}; + /// # use datafusion_common::datatype::FieldExt; + /// // Int32 field + /// let int_field = Field::new("my_int", DataType::Int32, true); + /// // convert to a List field + /// let list_field = int_field.into_list(); + /// // List + /// // Note that the item field name has been renamed to "item" + /// assert_eq!(list_field.data_type(), &DataType::List(Arc::new( + /// Field::new("item", DataType::Int32, true) + /// ))); + /// + fn into_list(self) -> Self; + + /// Return a new Field representing this Field as the item type of a + /// [`DataType::FixedSizeList`] + /// + /// For example if input represents an `Int32`, the return value will + /// represent a `FixedSizeList`. + /// + /// Example: + /// ``` + /// # use std::sync::Arc; + /// # use arrow::datatypes::{DataType, Field}; + /// # use datafusion_common::datatype::FieldExt; + /// // Int32 field + /// let int_field = Field::new("my_int", DataType::Int32, true); + /// // convert to a FixedSizeList field of size 3 + /// let fixed_size_list_field = int_field.into_fixed_size_list(3); + /// // FixedSizeList + /// // Note that the item field name has been renamed to "item" + /// assert_eq!( + /// fixed_size_list_field.data_type(), + /// &DataType::FixedSizeList(Arc::new( + /// Field::new("item", DataType::Int32, true)), + /// 3 + /// )); + /// + fn into_fixed_size_list(self, list_size: i32) -> Self; + + /// Update the field to have the default list field name ("item") + /// + /// Lists are allowed to have an arbitrarily named field; however, a name + /// other than 'item' will cause it to fail an == check against a more + /// idiomatically created list in arrow-rs which causes issues. + /// + /// For example, if input represents an `Int32` field named "my_int", + /// the return value will represent an `Int32` field named "item". 
+ /// + /// Example: + /// ``` + /// # use arrow::datatypes::Field; + /// # use datafusion_common::datatype::FieldExt; + /// let my_field = Field::new("my_int", arrow::datatypes::DataType::Int32, true); + /// let item_field = my_field.into_list_item(); + /// assert_eq!(item_field.name(), Field::LIST_FIELD_DEFAULT_NAME); + /// assert_eq!(item_field.name(), "item"); + /// ``` + fn into_list_item(self) -> Self; +} + +impl FieldExt for Field { + fn into_list(self) -> Self { + DataType::List(Arc::new(self.into_list_item())).into_nullable_field() + } + + fn into_fixed_size_list(self, list_size: i32) -> Self { + DataType::FixedSizeList(self.into_list_item().into(), list_size) + .into_nullable_field() + } + + fn into_list_item(self) -> Self { + if self.name() != Field::LIST_FIELD_DEFAULT_NAME { + self.with_name(Field::LIST_FIELD_DEFAULT_NAME) + } else { + self + } + } +} + +impl FieldExt for Arc { + fn into_list(self) -> Self { + DataType::List(self.into_list_item()) + .into_nullable_field() + .into() + } + + fn into_fixed_size_list(self, list_size: i32) -> Self { + DataType::FixedSizeList(self.into_list_item(), list_size) + .into_nullable_field() + .into() + } + + fn into_list_item(self) -> Self { + if self.name() != Field::LIST_FIELD_DEFAULT_NAME { + Arc::unwrap_or_clone(self) + .with_name(Field::LIST_FIELD_DEFAULT_NAME) + .into() + } else { + self + } + } +} diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 6866b4011f9e..34a36f543657 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -1417,7 +1417,7 @@ mod tests { fn from_qualified_schema_into_arrow_schema() -> Result<()> { let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; let arrow_schema = schema.as_arrow(); - insta::assert_snapshot!(arrow_schema, @r#"Field { name: "c0", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }"#); + insta::assert_snapshot!(arrow_schema.to_string(), @r#"Field { "c0": nullable Boolean }, Field { "c1": nullable Boolean }"#); Ok(()) } diff --git a/datafusion/common/src/encryption.rs b/datafusion/common/src/encryption.rs index b764ad77cff1..2a8cfdbc8996 100644 --- a/datafusion/common/src/encryption.rs +++ b/datafusion/common/src/encryption.rs @@ -24,38 +24,10 @@ pub use parquet::encryption::decrypt::FileDecryptionProperties; pub use parquet::encryption::encrypt::FileEncryptionProperties; #[cfg(not(feature = "parquet_encryption"))] -#[derive(Default, Debug)] +#[derive(Default, Clone, Debug)] pub struct FileDecryptionProperties; #[cfg(not(feature = "parquet_encryption"))] -#[derive(Default, Debug)] +#[derive(Default, Clone, Debug)] pub struct FileEncryptionProperties; pub use crate::config::{ConfigFileDecryptionProperties, ConfigFileEncryptionProperties}; - -#[cfg(feature = "parquet_encryption")] -pub fn map_encryption_to_config_encryption( - encryption: Option<&FileEncryptionProperties>, -) -> Option { - encryption.map(|fe| fe.into()) -} - -#[cfg(not(feature = "parquet_encryption"))] -pub fn map_encryption_to_config_encryption( - _encryption: Option<&FileEncryptionProperties>, -) -> Option { - None -} - -#[cfg(feature = "parquet_encryption")] -pub fn map_config_decryption_to_decryption( - decryption: &ConfigFileDecryptionProperties, -) -> FileDecryptionProperties { - decryption.clone().into() -} - -#[cfg(not(feature = "parquet_encryption"))] -pub fn 
map_config_decryption_to_decryption( - _decryption: &ConfigFileDecryptionProperties, -) -> FileDecryptionProperties { - FileDecryptionProperties {} -} diff --git a/datafusion/common/src/file_options/parquet_writer.rs b/datafusion/common/src/file_options/parquet_writer.rs index 3977f2b489e1..564929c61bab 100644 --- a/datafusion/common/src/file_options/parquet_writer.rs +++ b/datafusion/common/src/file_options/parquet_writer.rs @@ -402,15 +402,14 @@ pub(crate) fn parse_statistics_string(str_setting: &str) -> Result Result { + match level.to_lowercase().as_str() { + "summary" => Ok(ExplainAnalyzeLevel::Summary), + "dev" => Ok(ExplainAnalyzeLevel::Dev), + other => Err(DataFusionError::Configuration(format!( + "Invalid explain analyze level. Expected 'summary' or 'dev'. Got '{other}'" + ))), + } + } +} + +impl Display for ExplainAnalyzeLevel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match self { + ExplainAnalyzeLevel::Summary => "summary", + ExplainAnalyzeLevel::Dev => "dev", + }; + write!(f, "{s}") + } +} + +impl ConfigField for ExplainAnalyzeLevel { + fn visit(&self, v: &mut V, key: &str, description: &'static str) { + v.some(key, self, description) + } + + fn set(&mut self, _: &str, value: &str) -> Result<()> { + *self = ExplainAnalyzeLevel::from_str(value)?; + Ok(()) + } +} diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 24ec9b7be323..76c7b46e3273 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -39,6 +39,7 @@ pub mod alias; pub mod cast; pub mod config; pub mod cse; +pub mod datatype; pub mod diagnostic; pub mod display; pub mod encryption; @@ -47,6 +48,7 @@ pub mod file_options; pub mod format; pub mod hash_utils; pub mod instant; +pub mod metadata; pub mod nested_struct; mod null_equality; pub mod parsers; @@ -108,6 +110,12 @@ pub use error::{ // The HashMap and HashSet implementations that should be used as the uniform defaults pub type HashMap = hashbrown::HashMap; pub type HashSet = hashbrown::HashSet; +pub mod hash_map { + pub use hashbrown::hash_map::Entry; +} +pub mod hash_set { + pub use hashbrown::hash_set::Entry; +} /// Downcast an Arrow Array to a concrete type, return an `DataFusionError::Internal` if the cast is /// not possible. In normal usage of DataFusion the downcast should always succeed. diff --git a/datafusion/common/src/metadata.rs b/datafusion/common/src/metadata.rs new file mode 100644 index 000000000000..39065808efb9 --- /dev/null +++ b/datafusion/common/src/metadata.rs @@ -0,0 +1,371 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
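Editorial aside (not part of this diff): selecting the new `EXPLAIN ANALYZE` verbosity defined earlier in this change, assuming the option is registered in the `explain` namespace as `datafusion.explain.analyze_level`:

```rust
use datafusion_common::config::ConfigOptions;

fn main() -> datafusion_common::Result<()> {
    let mut options = ConfigOptions::default();

    // Default is "dev"; switch EXPLAIN ANALYZE to the high-level summary output
    options.set("datafusion.explain.analyze_level", "summary")?;

    // Values other than "summary" or "dev" are rejected
    assert!(options.set("datafusion.explain.analyze_level", "verbose").is_err());
    Ok(())
}
```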
+ +use std::{collections::BTreeMap, sync::Arc}; + +use arrow::datatypes::{DataType, Field}; +use hashbrown::HashMap; + +use crate::{error::_plan_err, DataFusionError, ScalarValue}; + +/// A [`ScalarValue`] with optional [`FieldMetadata`] +#[derive(Debug, Clone)] +pub struct ScalarAndMetadata { + pub value: ScalarValue, + pub metadata: Option, +} + +impl ScalarAndMetadata { + /// Create a new Literal from a scalar value with optional [`FieldMetadata`] + pub fn new(value: ScalarValue, metadata: Option) -> Self { + Self { value, metadata } + } + + /// Access the underlying [ScalarValue] storage + pub fn value(&self) -> &ScalarValue { + &self.value + } + + /// Access the [FieldMetadata] attached to this value, if any + pub fn metadata(&self) -> Option<&FieldMetadata> { + self.metadata.as_ref() + } + + /// Consume self and return components + pub fn into_inner(self) -> (ScalarValue, Option) { + (self.value, self.metadata) + } + + /// Cast this values's storage type + /// + /// This operation assumes that if the underlying [ScalarValue] can be casted + /// to a given type that any extension type represented by the metadata is also + /// valid. + pub fn cast_storage_to( + &self, + target_type: &DataType, + ) -> Result { + let new_value = self.value().cast_to(target_type)?; + Ok(Self::new(new_value, self.metadata.clone())) + } +} + +/// create a new ScalarAndMetadata from a ScalarValue without +/// any metadata +impl From for ScalarAndMetadata { + fn from(value: ScalarValue) -> Self { + Self::new(value, None) + } +} + +/// Assert equality of data types where one or both sides may have field metadata +/// +/// This currently compares absent metadata (e.g., one side was a DataType) and +/// empty metadata (e.g., one side was a field where the field had no metadata) +/// as equal and uses byte-for-byte comparison for the keys and values of the +/// fields, even though this is potentially too strict for some cases (e.g., +/// extension types where extension metadata is represented by JSON, or cases +/// where field metadata is orthogonal to the interpretation of the data type). +/// +/// Returns a planning error with suitably formatted type representations if +/// actual and expected do not compare to equal. +pub fn check_metadata_with_storage_equal( + actual: ( + &DataType, + Option<&std::collections::HashMap>, + ), + expected: ( + &DataType, + Option<&std::collections::HashMap>, + ), + what: &str, + context: &str, +) -> Result<(), DataFusionError> { + if actual.0 != expected.0 { + return _plan_err!( + "Expected {what} of type {}, got {}{context}", + format_type_and_metadata(expected.0, expected.1), + format_type_and_metadata(actual.0, actual.1) + ); + } + + let metadata_equal = match (actual.1, expected.1) { + (None, None) => true, + (None, Some(expected_metadata)) => expected_metadata.is_empty(), + (Some(actual_metadata), None) => actual_metadata.is_empty(), + (Some(actual_metadata), Some(expected_metadata)) => { + actual_metadata == expected_metadata + } + }; + + if !metadata_equal { + return _plan_err!( + "Expected {what} of type {}, got {}{context}", + format_type_and_metadata(expected.0, expected.1), + format_type_and_metadata(actual.0, actual.1) + ); + } + + Ok(()) +} + +/// Given a data type represented by storage and optional metadata, generate +/// a user-facing string +/// +/// This function exists to reduce the number of Field debug strings that are +/// used to communicate type information in error messages and plan explain +/// renderings. 
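Editorial aside (not part of this diff): a small usage sketch for `ScalarAndMetadata`; the metadata key is purely illustrative:

```rust
use std::collections::BTreeMap;
use arrow::datatypes::DataType;
use datafusion_common::metadata::{FieldMetadata, ScalarAndMetadata};
use datafusion_common::ScalarValue;

fn main() -> datafusion_common::Result<()> {
    // Attach extension-type style metadata to a literal value
    let metadata = FieldMetadata::from(BTreeMap::from([(
        "ARROW:extension:name".to_string(),
        "my.ext".to_string(),
    )]));
    let literal = ScalarAndMetadata::new(ScalarValue::Int32(Some(5)), Some(metadata));

    // Casting changes only the storage type; the metadata is carried along
    let casted = literal.cast_storage_to(&DataType::Int64)?;
    assert_eq!(casted.value(), &ScalarValue::Int64(Some(5)));
    assert!(casted.metadata().is_some());
    Ok(())
}
```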
+pub fn format_type_and_metadata( + data_type: &DataType, + metadata: Option<&std::collections::HashMap>, +) -> String { + match metadata { + Some(metadata) if !metadata.is_empty() => { + format!("{data_type}<{metadata:?}>") + } + _ => data_type.to_string(), + } +} + +/// Literal metadata +/// +/// Stores metadata associated with a literal expressions +/// and is designed to be fast to `clone`. +/// +/// This structure is used to store metadata associated with a literal expression, and it +/// corresponds to the `metadata` field on [`Field`]. +/// +/// # Example: Create [`FieldMetadata`] from a [`Field`] +/// ``` +/// # use std::collections::HashMap; +/// # use datafusion_common::metadata::FieldMetadata; +/// # use arrow::datatypes::{Field, DataType}; +/// # let field = Field::new("c1", DataType::Int32, true) +/// # .with_metadata(HashMap::from([("foo".to_string(), "bar".to_string())])); +/// // Create a new `FieldMetadata` instance from a `Field` +/// let metadata = FieldMetadata::new_from_field(&field); +/// // There is also a `From` impl: +/// let metadata = FieldMetadata::from(&field); +/// ``` +/// +/// # Example: Update a [`Field`] with [`FieldMetadata`] +/// ``` +/// # use datafusion_common::metadata::FieldMetadata; +/// # use arrow::datatypes::{Field, DataType}; +/// # let field = Field::new("c1", DataType::Int32, true); +/// # let metadata = FieldMetadata::new_from_field(&field); +/// // Add any metadata from `FieldMetadata` to `Field` +/// let updated_field = metadata.add_to_field(field); +/// ``` +/// +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct FieldMetadata { + /// The inner metadata of a literal expression, which is a map of string + /// keys to string values. + /// + /// Note this is not a `HashMap` because `HashMap` does not provide + /// implementations for traits like `Debug` and `Hash`. + inner: Arc>, +} + +impl Default for FieldMetadata { + fn default() -> Self { + Self::new_empty() + } +} + +impl FieldMetadata { + /// Create a new empty metadata instance. + pub fn new_empty() -> Self { + Self { + inner: Arc::new(BTreeMap::new()), + } + } + + /// Merges two optional `FieldMetadata` instances, overwriting any existing + /// keys in `m` with keys from `n` if present. + /// + /// This function is commonly used in alias operations, particularly for literals + /// with metadata. When creating an alias expression, the metadata from the original + /// expression (such as a literal) is combined with any metadata specified on the alias. 
+ /// + /// # Arguments + /// + /// * `m` - The first metadata (typically from the original expression like a literal) + /// * `n` - The second metadata (typically from the alias definition) + /// + /// # Merge Strategy + /// + /// - If both metadata instances exist, they are merged with `n` taking precedence + /// - Keys from `n` will overwrite keys from `m` if they have the same name + /// - If only one metadata instance exists, it is returned unchanged + /// - If neither exists, `None` is returned + /// + /// # Example usage + /// ```rust + /// use datafusion_common::metadata::FieldMetadata; + /// use std::collections::BTreeMap; + /// + /// // Create metadata for a literal expression + /// let literal_metadata = Some(FieldMetadata::from(BTreeMap::from([ + /// ("source".to_string(), "constant".to_string()), + /// ("type".to_string(), "int".to_string()), + /// ]))); + /// + /// // Create metadata for an alias + /// let alias_metadata = Some(FieldMetadata::from(BTreeMap::from([ + /// ("description".to_string(), "answer".to_string()), + /// ("source".to_string(), "user".to_string()), // This will override literal's "source" + /// ]))); + /// + /// // Merge the metadata + /// let merged = FieldMetadata::merge_options( + /// literal_metadata.as_ref(), + /// alias_metadata.as_ref(), + /// ); + /// + /// // Result contains: {"source": "user", "type": "int", "description": "answer"} + /// assert!(merged.is_some()); + /// ``` + pub fn merge_options( + m: Option<&FieldMetadata>, + n: Option<&FieldMetadata>, + ) -> Option { + match (m, n) { + (Some(m), Some(n)) => { + let mut merged = m.clone(); + merged.extend(n.clone()); + Some(merged) + } + (Some(m), None) => Some(m.clone()), + (None, Some(n)) => Some(n.clone()), + (None, None) => None, + } + } + + /// Create a new metadata instance from a `Field`'s metadata. + pub fn new_from_field(field: &Field) -> Self { + let inner = field + .metadata() + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + Self { + inner: Arc::new(inner), + } + } + + /// Create a new metadata instance from a map of string keys to string values. + pub fn new(inner: BTreeMap) -> Self { + Self { + inner: Arc::new(inner), + } + } + + /// Get the inner metadata as a reference to a `BTreeMap`. + pub fn inner(&self) -> &BTreeMap { + &self.inner + } + + /// Return the inner metadata + pub fn into_inner(self) -> Arc> { + self.inner + } + + /// Adds metadata from `other` into `self`, overwriting any existing keys. + pub fn extend(&mut self, other: Self) { + if other.is_empty() { + return; + } + let other = Arc::unwrap_or_clone(other.into_inner()); + Arc::make_mut(&mut self.inner).extend(other); + } + + /// Returns true if the metadata is empty. + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Returns the number of key-value pairs in the metadata. + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Convert this `FieldMetadata` into a `HashMap` + pub fn to_hashmap(&self) -> std::collections::HashMap { + self.inner + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect() + } + + /// Updates the metadata on the Field with this metadata, if it is not empty. 
+ pub fn add_to_field(&self, field: Field) -> Field { + if self.inner.is_empty() { + return field; + } + + field.with_metadata(self.to_hashmap()) + } +} + +impl From<&Field> for FieldMetadata { + fn from(field: &Field) -> Self { + Self::new_from_field(field) + } +} + +impl From> for FieldMetadata { + fn from(inner: BTreeMap) -> Self { + Self::new(inner) + } +} + +impl From> for FieldMetadata { + fn from(map: std::collections::HashMap) -> Self { + Self::new(map.into_iter().collect()) + } +} + +/// From reference +impl From<&std::collections::HashMap> for FieldMetadata { + fn from(map: &std::collections::HashMap) -> Self { + let inner = map + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + Self::new(inner) + } +} + +/// From hashbrown map +impl From> for FieldMetadata { + fn from(map: HashMap) -> Self { + let inner = map.into_iter().collect(); + Self::new(inner) + } +} + +impl From<&HashMap> for FieldMetadata { + fn from(map: &HashMap) -> Self { + let inner = map + .into_iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + Self::new(inner) + } +} diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs index 7582cff56f87..ebf68e4dd210 100644 --- a/datafusion/common/src/param_value.rs +++ b/datafusion/common/src/param_value.rs @@ -16,22 +16,37 @@ // under the License. use crate::error::{_plan_datafusion_err, _plan_err}; +use crate::metadata::{check_metadata_with_storage_equal, ScalarAndMetadata}; use crate::{Result, ScalarValue}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, FieldRef}; use std::collections::HashMap; /// The parameter value corresponding to the placeholder #[derive(Debug, Clone)] pub enum ParamValues { /// For positional query parameters, like `SELECT * FROM test WHERE a > $1 AND b = $2` - List(Vec), + List(Vec), /// For named query parameters, like `SELECT * FROM test WHERE a > $foo AND b = $goo` - Map(HashMap), + Map(HashMap), } impl ParamValues { - /// Verify parameter list length and type + /// Verify parameter list length and DataType + /// + /// Use [`ParamValues::verify_fields`] to ensure field metadata is considered when + /// computing type equality. 
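Editorial aside (not part of this diff): a sketch of the new `ParamValues::verify_fields` API; the empty field names mirror the dummy fields built by the deprecated `verify`:

```rust
use std::sync::Arc;
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::{ParamValues, ScalarValue};

fn main() -> datafusion_common::Result<()> {
    // Positional parameters for `... WHERE a > $1 AND b = $2`
    let params = ParamValues::from(vec![
        ScalarValue::Int32(Some(10)),
        ScalarValue::Utf8(Some("x".to_string())),
    ]);

    // Expected parameter types are expressed as Fields, so metadata /
    // extension types can participate in the comparison
    let expected: Vec<FieldRef> = vec![
        Arc::new(Field::new("", DataType::Int32, true)),
        Arc::new(Field::new("", DataType::Utf8, true)),
    ];
    params.verify_fields(&expected)?;
    Ok(())
}
```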
+ #[deprecated(since = "51.0.0", note = "Use verify_fields instead")] pub fn verify(&self, expect: &[DataType]) -> Result<()> { + // make dummy Fields + let expect = expect + .iter() + .map(|dt| Field::new("", dt.clone(), true).into()) + .collect::>(); + self.verify_fields(&expect) + } + + /// Verify parameter list length and type + pub fn verify_fields(&self, expect: &[FieldRef]) -> Result<()> { match self { ParamValues::List(list) => { // Verify if the number of params matches the number of values @@ -45,15 +60,16 @@ impl ParamValues { // Verify if the types of the params matches the types of the values let iter = expect.iter().zip(list.iter()); - for (i, (param_type, value)) in iter.enumerate() { - if *param_type != value.data_type() { - return _plan_err!( - "Expected parameter of type {}, got {:?} at index {}", - param_type, - value.data_type(), - i - ); - } + for (i, (param_type, lit)) in iter.enumerate() { + check_metadata_with_storage_equal( + ( + &lit.value.data_type(), + lit.metadata.as_ref().map(|m| m.to_hashmap()).as_ref(), + ), + (param_type.data_type(), Some(param_type.metadata())), + "parameter", + &format!(" at index {i}"), + )?; } Ok(()) } @@ -65,7 +81,7 @@ impl ParamValues { } } - pub fn get_placeholders_with_values(&self, id: &str) -> Result { + pub fn get_placeholders_with_values(&self, id: &str) -> Result { match self { ParamValues::List(list) => { if id.is_empty() { @@ -99,7 +115,7 @@ impl ParamValues { impl From> for ParamValues { fn from(value: Vec) -> Self { - Self::List(value) + Self::List(value.into_iter().map(ScalarAndMetadata::from).collect()) } } @@ -108,8 +124,10 @@ where K: Into, { fn from(value: Vec<(K, ScalarValue)>) -> Self { - let value: HashMap = - value.into_iter().map(|(k, v)| (k.into(), v)).collect(); + let value: HashMap = value + .into_iter() + .map(|(k, v)| (k.into(), ScalarAndMetadata::from(v))) + .collect(); Self::Map(value) } } @@ -119,8 +137,10 @@ where K: Into, { fn from(value: HashMap) -> Self { - let value: HashMap = - value.into_iter().map(|(k, v)| (k.into(), v)).collect(); + let value: HashMap = value + .into_iter() + .map(|(k, v)| (k.into(), ScalarAndMetadata::from(v))) + .collect(); Self::Map(value) } } diff --git a/datafusion/common/src/pyarrow.rs b/datafusion/common/src/pyarrow.rs index ff413e08ab07..3b7d80b3da78 100644 --- a/datafusion/common/src/pyarrow.rs +++ b/datafusion/common/src/pyarrow.rs @@ -22,7 +22,7 @@ use arrow::pyarrow::{FromPyArrow, ToPyArrow}; use pyo3::exceptions::PyException; use pyo3::prelude::PyErr; use pyo3::types::{PyAnyMethods, PyList}; -use pyo3::{Bound, FromPyObject, IntoPyObject, PyAny, PyObject, PyResult, Python}; +use pyo3::{Bound, FromPyObject, IntoPyObject, PyAny, PyResult, Python}; use crate::{DataFusionError, ScalarValue}; @@ -52,11 +52,11 @@ impl FromPyArrow for ScalarValue { } impl ToPyArrow for ScalarValue { - fn to_pyarrow(&self, py: Python) -> PyResult { + fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult> { let array = self.to_array()?; // convert to pyarrow array using C data interface let pyarray = array.to_data().to_pyarrow(py)?; - let pyscalar = pyarray.call_method1(py, "__getitem__", (0,))?; + let pyscalar = pyarray.call_method1("__getitem__", (0,))?; Ok(pyscalar) } @@ -79,23 +79,22 @@ impl<'source> IntoPyObject<'source> for ScalarValue { let array = self.to_array()?; // convert to pyarrow array using C data interface let pyarray = array.to_data().to_pyarrow(py)?; - let pyarray_bound = pyarray.bind(py); - pyarray_bound.call_method1("__getitem__", (0,)) + 
pyarray.call_method1("__getitem__", (0,)) } } #[cfg(test)] mod tests { use pyo3::ffi::c_str; - use pyo3::prepare_freethreaded_python; use pyo3::py_run; use pyo3::types::PyDict; + use pyo3::Python; use super::*; fn init_python() { - prepare_freethreaded_python(); - Python::with_gil(|py| { + Python::initialize(); + Python::attach(|py| { if py.run(c_str!("import pyarrow"), None, None).is_err() { let locals = PyDict::new(py); py.run( @@ -135,12 +134,11 @@ mod tests { ScalarValue::Date32(Some(1234)), ]; - Python::with_gil(|py| { + Python::attach(|py| { for scalar in example_scalars.iter() { - let result = ScalarValue::from_pyarrow_bound( - scalar.to_pyarrow(py).unwrap().bind(py), - ) - .unwrap(); + let result = + ScalarValue::from_pyarrow_bound(&scalar.to_pyarrow(py).unwrap()) + .unwrap(); assert_eq!(scalar, &result); } }); @@ -150,7 +148,7 @@ mod tests { fn test_py_scalar() -> PyResult<()> { init_python(); - Python::with_gil(|py| -> PyResult<()> { + Python::attach(|py| -> PyResult<()> { let scalar_float = ScalarValue::Float64(Some(12.34)); let py_float = scalar_float .into_pyobject(py)? diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 60ff1f4b2ed4..a70a027a8fac 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -70,7 +70,7 @@ use arrow::array::{ TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, UnionArray, }; -use arrow::buffer::ScalarBuffer; +use arrow::buffer::{BooleanBuffer, ScalarBuffer}; use arrow::compute::kernels::cast::{cast_with_options, CastOptions}; use arrow::compute::kernels::numeric::{ add, add_wrapping, div, mul, mul_wrapping, rem, sub, sub_wrapping, @@ -2888,9 +2888,17 @@ impl ScalarValue { ScalarValue::Decimal256(e, precision, scale) => Arc::new( ScalarValue::build_decimal256_array(*e, *precision, *scale, size)?, ), - ScalarValue::Boolean(e) => { - Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef - } + ScalarValue::Boolean(e) => match e { + None => new_null_array(&DataType::Boolean, size), + Some(true) => { + Arc::new(BooleanArray::new(BooleanBuffer::new_set(size), None)) + as ArrayRef + } + Some(false) => { + Arc::new(BooleanArray::new(BooleanBuffer::new_unset(size), None)) + as ArrayRef + } + }, ScalarValue::Float64(e) => { build_array_from_option!(Float64, Float64Array, e, size) } @@ -2973,15 +2981,13 @@ impl ScalarValue { Some(value) => Arc::new( repeat_n(Some(value.as_slice()), size).collect::(), ), - None => Arc::new(repeat_n(None::<&str>, size).collect::()), + None => new_null_array(&DataType::Binary, size), }, ScalarValue::BinaryView(e) => match e { Some(value) => Arc::new( repeat_n(Some(value.as_slice()), size).collect::(), ), - None => { - Arc::new(repeat_n(None::<&str>, size).collect::()) - } + None => new_null_array(&DataType::BinaryView, size), }, ScalarValue::FixedSizeBinary(s, e) => match e { Some(value) => Arc::new( @@ -2991,21 +2997,13 @@ impl ScalarValue { ) .unwrap(), ), - None => Arc::new( - FixedSizeBinaryArray::try_from_sparse_iter_with_size( - repeat_n(None::<&[u8]>, size), - *s, - ) - .unwrap(), - ), + None => Arc::new(FixedSizeBinaryArray::new_null(*s, size)), }, ScalarValue::LargeBinary(e) => match e { Some(value) => Arc::new( repeat_n(Some(value.as_slice()), size).collect::(), ), - None => { - Arc::new(repeat_n(None::<&str>, size).collect::()) - } + None => new_null_array(&DataType::LargeBinary, size), }, ScalarValue::List(arr) => { if size == 1 { diff --git a/datafusion/common/src/table_reference.rs 
b/datafusion/common/src/table_reference.rs index 7cf8e7af1a79..574465856760 100644 --- a/datafusion/common/src/table_reference.rs +++ b/datafusion/common/src/table_reference.rs @@ -269,24 +269,41 @@ impl TableReference { } /// Forms a [`TableReference`] by parsing `s` as a multipart SQL - /// identifier. See docs on [`TableReference`] for more details. + /// identifier, normalizing `s` to lowercase. + /// See docs on [`TableReference`] for more details. pub fn parse_str(s: &str) -> Self { - let mut parts = parse_identifiers_normalized(s, false); + Self::parse_str_normalized(s, false) + } + + /// Forms a [`TableReference`] by parsing `s` as a multipart SQL + /// identifier, normalizing `s` to lowercase if `ignore_case` is `false`. + /// See docs on [`TableReference`] for more details. + pub fn parse_str_normalized(s: &str, ignore_case: bool) -> Self { + let table_parts = parse_identifiers_normalized(s, ignore_case); + Self::from_vec(table_parts).unwrap_or_else(|| Self::Bare { table: s.into() }) + } + + /// Consume a vector of identifier parts to compose a [`TableReference`]. The input vector + /// should contain 1 <= N <= 3 elements in the following sequence: + /// ```no_rust + /// [, , table] + /// ``` + fn from_vec(mut parts: Vec) -> Option { match parts.len() { - 1 => Self::Bare { - table: parts.remove(0).into(), - }, - 2 => Self::Partial { - schema: parts.remove(0).into(), - table: parts.remove(0).into(), - }, - 3 => Self::Full { - catalog: parts.remove(0).into(), - schema: parts.remove(0).into(), - table: parts.remove(0).into(), - }, - _ => Self::Bare { table: s.into() }, + 1 => Some(Self::Bare { + table: parts.pop()?.into(), + }), + 2 => Some(Self::Partial { + table: parts.pop()?.into(), + schema: parts.pop()?.into(), + }), + 3 => Some(Self::Full { + table: parts.pop()?.into(), + schema: parts.pop()?.into(), + catalog: parts.pop()?.into(), + }), + _ => None, } } diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index c72e3b3a8df7..045c02a5a2aa 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -285,6 +285,9 @@ pub(crate) fn parse_identifiers(s: &str) -> Result> { Ok(idents) } +/// Parse a string into a vector of identifiers. +/// +/// Note: If ignore_case is false, the string will be normalized to lowercase. 
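Editorial aside (not part of this diff): a usage sketch for `TableReference::parse_str`, assuming datafusion-common is built with the `sql` feature that provides identifier parsing:

```rust
use datafusion_common::TableReference;

fn main() {
    // Multipart identifiers are normalized to lowercase by default
    let full = TableReference::parse_str("MyCatalog.MySchema.MyTable");
    assert_eq!(
        full,
        TableReference::full("mycatalog", "myschema", "mytable")
    );

    // More than three parts cannot be interpreted, so the whole string
    // becomes a bare table name
    let bare = TableReference::parse_str("a.b.c.d");
    assert_eq!(bare, TableReference::bare("a.b.c.d"));
}
```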
#[cfg(feature = "sql")] pub(crate) fn parse_identifiers_normalized(s: &str, ignore_case: bool) -> Vec { parse_identifiers(s) diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index d3bc4546588d..22c9f43a902e 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -47,7 +47,7 @@ compression = [ "bzip2", "flate2", "zstd", - "arrow-ipc/zstd", + "datafusion-datasource-arrow/compression", "datafusion-datasource/compression", ] crypto_expressions = ["datafusion-functions/crypto_expressions"] @@ -109,17 +109,17 @@ extended_tests = [] [dependencies] arrow = { workspace = true } -arrow-ipc = { workspace = true } arrow-schema = { workspace = true, features = ["canonical_extension_types"] } async-trait = { workspace = true } bytes = { workspace = true } -bzip2 = { version = "0.6.0", optional = true } +bzip2 = { version = "0.6.1", optional = true } chrono = { workspace = true } datafusion-catalog = { workspace = true } datafusion-catalog-listing = { workspace = true } datafusion-common = { workspace = true, features = ["object_store"] } datafusion-common-runtime = { workspace = true } datafusion-datasource = { workspace = true } +datafusion-datasource-arrow = { workspace = true } datafusion-datasource-avro = { workspace = true, optional = true } datafusion-datasource-csv = { workspace = true } datafusion-datasource-json = { workspace = true } diff --git a/datafusion/core/benches/parquet_query_sql.rs b/datafusion/core/benches/parquet_query_sql.rs index 14dcdf15f173..e2b381048013 100644 --- a/datafusion/core/benches/parquet_query_sql.rs +++ b/datafusion/core/benches/parquet_query_sql.rs @@ -166,11 +166,12 @@ fn generate_file() -> NamedTempFile { } let metadata = writer.close().unwrap(); + let file_metadata = metadata.file_metadata(); assert_eq!( - metadata.num_rows as usize, + file_metadata.num_rows() as usize, WRITE_RECORD_BATCH_SIZE * NUM_BATCHES ); - assert_eq!(metadata.row_groups.len(), EXPECTED_ROW_GROUPS); + assert_eq!(metadata.row_groups().len(), EXPECTED_ROW_GROUPS); println!( "Generated parquet file in {} seconds", diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 3be8668b2b8c..83563099cad6 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -30,7 +30,7 @@ use criterion::Bencher; use datafusion::datasource::MemTable; use datafusion::execution::context::SessionContext; use datafusion::prelude::DataFrame; -use datafusion_common::ScalarValue; +use datafusion_common::{config::Dialect, ScalarValue}; use datafusion_expr::Expr::Literal; use datafusion_expr::{cast, col, lit, not, try_cast, when}; use datafusion_functions::expr_fn::{ @@ -288,7 +288,10 @@ fn benchmark_with_param_values_many_columns( } // SELECT max(attr0), ..., max(attrN) FROM t1. 
let query = format!("SELECT {aggregates} FROM t1"); - let statement = ctx.state().sql_to_statement(&query, "Generic").unwrap(); + let statement = ctx + .state() + .sql_to_statement(&query, &Dialect::Generic) + .unwrap(); let plan = rt.block_on(async { ctx.state().statement_to_plan(statement).await.unwrap() }); b.iter(|| { diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs index d46a902ca513..930b4fad1d9b 100644 --- a/datafusion/core/src/dataframe/parquet.rs +++ b/datafusion/core/src/dataframe/parquet.rs @@ -116,6 +116,8 @@ mod tests { use datafusion_execution::config::SessionConfig; use datafusion_expr::{col, lit}; + #[cfg(feature = "parquet_encryption")] + use datafusion_common::config::ConfigFileEncryptionProperties; use object_store::local::LocalFileSystem; use parquet::file::reader::FileReader; use tempfile::TempDir; @@ -280,7 +282,8 @@ mod tests { // Write encrypted parquet using write_parquet let mut options = TableParquetOptions::default(); - options.crypto.file_encryption = Some((&encrypt).into()); + options.crypto.file_encryption = + Some(ConfigFileEncryptionProperties::from(&encrypt)); options.global.allow_single_file_parallelism = allow_single_file_parallelism; df.write_parquet( diff --git a/datafusion/core/src/datasource/dynamic_file.rs b/datafusion/core/src/datasource/dynamic_file.rs index b30d53e58691..256a11ba693b 100644 --- a/datafusion/core/src/datasource/dynamic_file.rs +++ b/datafusion/core/src/datasource/dynamic_file.rs @@ -20,6 +20,7 @@ use std::sync::Arc; +use crate::datasource::listing::ListingTableConfigExt; use crate::datasource::listing::{ListingTable, ListingTableConfig, ListingTableUrl}; use crate::datasource::TableProvider; use crate::error::Result; diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 25bc166d657a..8701f96eb3b8 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -15,510 +15,5 @@ // specific language governing permissions and limitations // under the License. -//! [`ArrowFormat`]: Apache Arrow [`FileFormat`] abstractions -//! -//! 
Works with files following the [Arrow IPC format](https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format) - -use std::any::Any; -use std::borrow::Cow; -use std::collections::HashMap; -use std::fmt::{self, Debug}; -use std::sync::Arc; - -use super::file_compression_type::FileCompressionType; -use super::write::demux::DemuxedStreamReceiver; -use super::write::SharedBuffer; -use super::FileFormatFactory; -use crate::datasource::file_format::write::get_writer_schema; -use crate::datasource::file_format::FileFormat; -use crate::datasource::physical_plan::{ArrowSource, FileSink, FileSinkConfig}; -use crate::error::Result; -use crate::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; - -use arrow::datatypes::{Schema, SchemaRef}; -use arrow::error::ArrowError; -use arrow::ipc::convert::fb_to_schema; -use arrow::ipc::reader::FileReader; -use arrow::ipc::writer::IpcWriteOptions; -use arrow::ipc::{root_as_message, CompressionType}; -use datafusion_catalog::Session; -use datafusion_common::parsers::CompressionTypeVariant; -use datafusion_common::{ - internal_datafusion_err, not_impl_err, DataFusionError, GetExt, Statistics, - DEFAULT_ARROW_EXTENSION, -}; -use datafusion_common_runtime::{JoinSet, SpawnedTask}; -use datafusion_datasource::display::FileGroupDisplay; -use datafusion_datasource::file::FileSource; -use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; -use datafusion_datasource::sink::{DataSink, DataSinkExec}; -use datafusion_datasource::write::ObjectWriterBuilder; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_expr::dml::InsertOp; -use datafusion_physical_expr_common::sort_expr::LexRequirement; - -use async_trait::async_trait; -use bytes::Bytes; -use datafusion_datasource::source::DataSourceExec; -use futures::stream::BoxStream; -use futures::StreamExt; -use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; -use tokio::io::AsyncWriteExt; - -/// Initial writing buffer size. Note this is just a size hint for efficiency. It -/// will grow beyond the set value if needed. -const INITIAL_BUFFER_BYTES: usize = 1048576; - -/// If the buffered Arrow data exceeds this size, it is flushed to object store -const BUFFER_FLUSH_BYTES: usize = 1024000; - -#[derive(Default, Debug)] -/// Factory struct used to create [ArrowFormat] -pub struct ArrowFormatFactory; - -impl ArrowFormatFactory { - /// Creates an instance of [ArrowFormatFactory] - pub fn new() -> Self { - Self {} - } -} - -impl FileFormatFactory for ArrowFormatFactory { - fn create( - &self, - _state: &dyn Session, - _format_options: &HashMap, - ) -> Result> { - Ok(Arc::new(ArrowFormat)) - } - - fn default(&self) -> Arc { - Arc::new(ArrowFormat) - } - - fn as_any(&self) -> &dyn Any { - self - } -} - -impl GetExt for ArrowFormatFactory { - fn get_ext(&self) -> String { - // Removes the dot, i.e. ".parquet" -> "parquet" - DEFAULT_ARROW_EXTENSION[1..].to_string() - } -} - -/// Arrow `FileFormat` implementation. -#[derive(Default, Debug)] -pub struct ArrowFormat; - -#[async_trait] -impl FileFormat for ArrowFormat { - fn as_any(&self) -> &dyn Any { - self - } - - fn get_ext(&self) -> String { - ArrowFormatFactory::new().get_ext() - } - - fn get_ext_with_compression( - &self, - file_compression_type: &FileCompressionType, - ) -> Result { - let ext = self.get_ext(); - match file_compression_type.get_variant() { - CompressionTypeVariant::UNCOMPRESSED => Ok(ext), - _ => Err(internal_datafusion_err!( - "Arrow FileFormat does not support compression." 
- )), - } - } - - fn compression_type(&self) -> Option { - None - } - - async fn infer_schema( - &self, - _state: &dyn Session, - store: &Arc, - objects: &[ObjectMeta], - ) -> Result { - let mut schemas = vec![]; - for object in objects { - let r = store.as_ref().get(&object.location).await?; - let schema = match r.payload { - #[cfg(not(target_arch = "wasm32"))] - GetResultPayload::File(mut file, _) => { - let reader = FileReader::try_new(&mut file, None)?; - reader.schema() - } - GetResultPayload::Stream(stream) => { - infer_schema_from_file_stream(stream).await? - } - }; - schemas.push(schema.as_ref().clone()); - } - let merged_schema = Schema::try_merge(schemas)?; - Ok(Arc::new(merged_schema)) - } - - async fn infer_stats( - &self, - _state: &dyn Session, - _store: &Arc, - table_schema: SchemaRef, - _object: &ObjectMeta, - ) -> Result { - Ok(Statistics::new_unknown(&table_schema)) - } - - async fn create_physical_plan( - &self, - _state: &dyn Session, - conf: FileScanConfig, - ) -> Result> { - let source = Arc::new(ArrowSource::default()); - let config = FileScanConfigBuilder::from(conf) - .with_source(source) - .build(); - - Ok(DataSourceExec::from_data_source(config)) - } - - async fn create_writer_physical_plan( - &self, - input: Arc, - _state: &dyn Session, - conf: FileSinkConfig, - order_requirements: Option, - ) -> Result> { - if conf.insert_op != InsertOp::Append { - return not_impl_err!("Overwrites are not implemented yet for Arrow format"); - } - - let sink = Arc::new(ArrowFileSink::new(conf)); - - Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _) - } - - fn file_source(&self) -> Arc { - Arc::new(ArrowSource::default()) - } -} - -/// Implements [`FileSink`] for writing to arrow_ipc files -struct ArrowFileSink { - config: FileSinkConfig, -} - -impl ArrowFileSink { - fn new(config: FileSinkConfig) -> Self { - Self { config } - } -} - -#[async_trait] -impl FileSink for ArrowFileSink { - fn config(&self) -> &FileSinkConfig { - &self.config - } - - async fn spawn_writer_tasks_and_join( - &self, - context: &Arc, - demux_task: SpawnedTask>, - mut file_stream_rx: DemuxedStreamReceiver, - object_store: Arc, - ) -> Result { - let mut file_write_tasks: JoinSet> = - JoinSet::new(); - - let ipc_options = - IpcWriteOptions::try_new(64, false, arrow_ipc::MetadataVersion::V5)? 
- .try_with_compression(Some(CompressionType::LZ4_FRAME))?; - while let Some((path, mut rx)) = file_stream_rx.recv().await { - let shared_buffer = SharedBuffer::new(INITIAL_BUFFER_BYTES); - let mut arrow_writer = arrow_ipc::writer::FileWriter::try_new_with_options( - shared_buffer.clone(), - &get_writer_schema(&self.config), - ipc_options.clone(), - )?; - let mut object_store_writer = ObjectWriterBuilder::new( - FileCompressionType::UNCOMPRESSED, - &path, - Arc::clone(&object_store), - ) - .with_buffer_size(Some( - context - .session_config() - .options() - .execution - .objectstore_writer_buffer_size, - )) - .build()?; - file_write_tasks.spawn(async move { - let mut row_count = 0; - while let Some(batch) = rx.recv().await { - row_count += batch.num_rows(); - arrow_writer.write(&batch)?; - let mut buff_to_flush = shared_buffer.buffer.try_lock().unwrap(); - if buff_to_flush.len() > BUFFER_FLUSH_BYTES { - object_store_writer - .write_all(buff_to_flush.as_slice()) - .await?; - buff_to_flush.clear(); - } - } - arrow_writer.finish()?; - let final_buff = shared_buffer.buffer.try_lock().unwrap(); - - object_store_writer.write_all(final_buff.as_slice()).await?; - object_store_writer.shutdown().await?; - Ok(row_count) - }); - } - - let mut row_count = 0; - while let Some(result) = file_write_tasks.join_next().await { - match result { - Ok(r) => { - row_count += r?; - } - Err(e) => { - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()); - } else { - unreachable!(); - } - } - } - } - - demux_task - .join_unwind() - .await - .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; - Ok(row_count as u64) - } -} - -impl Debug for ArrowFileSink { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("ArrowFileSink").finish() - } -} - -impl DisplayAs for ArrowFileSink { - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match t { - DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "ArrowFileSink(file_groups=",)?; - FileGroupDisplay(&self.config.file_group).fmt_as(t, f)?; - write!(f, ")") - } - DisplayFormatType::TreeRender => { - writeln!(f, "format: arrow")?; - write!(f, "file={}", &self.config.original_url) - } - } - } -} - -#[async_trait] -impl DataSink for ArrowFileSink { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> &SchemaRef { - self.config.output_schema() - } - - async fn write_all( - &self, - data: SendableRecordBatchStream, - context: &Arc, - ) -> Result { - FileSink::write_all(self, data, context).await - } -} - -const ARROW_MAGIC: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1']; -const CONTINUATION_MARKER: [u8; 4] = [0xff; 4]; - -/// Custom implementation of inferring schema. Should eventually be moved upstream to arrow-rs. -/// See -async fn infer_schema_from_file_stream( - mut stream: BoxStream<'static, object_store::Result>, -) -> Result { - // Expected format: - // - 6 bytes - // - 2 bytes - // - 4 bytes, not present below v0.15.0 - // - 4 bytes - // - // - - // So in first read we need at least all known sized sections, - // which is 6 + 2 + 4 + 4 = 16 bytes. 
- let bytes = collect_at_least_n_bytes(&mut stream, 16, None).await?; - - // Files should start with these magic bytes - if bytes[0..6] != ARROW_MAGIC { - return Err(ArrowError::ParseError( - "Arrow file does not contain correct header".to_string(), - ))?; - } - - // Since continuation marker bytes added in later versions - let (meta_len, rest_of_bytes_start_index) = if bytes[8..12] == CONTINUATION_MARKER { - (&bytes[12..16], 16) - } else { - (&bytes[8..12], 12) - }; - - let meta_len = [meta_len[0], meta_len[1], meta_len[2], meta_len[3]]; - let meta_len = i32::from_le_bytes(meta_len); - - // Read bytes for Schema message - let block_data = if bytes[rest_of_bytes_start_index..].len() < meta_len as usize { - // Need to read more bytes to decode Message - let mut block_data = Vec::with_capacity(meta_len as usize); - // In case we had some spare bytes in our initial read chunk - block_data.extend_from_slice(&bytes[rest_of_bytes_start_index..]); - let size_to_read = meta_len as usize - block_data.len(); - let block_data = - collect_at_least_n_bytes(&mut stream, size_to_read, Some(block_data)).await?; - Cow::Owned(block_data) - } else { - // Already have the bytes we need - let end_index = meta_len as usize + rest_of_bytes_start_index; - let block_data = &bytes[rest_of_bytes_start_index..end_index]; - Cow::Borrowed(block_data) - }; - - // Decode Schema message - let message = root_as_message(&block_data).map_err(|err| { - ArrowError::ParseError(format!("Unable to read IPC message as metadata: {err:?}")) - })?; - let ipc_schema = message.header_as_schema().ok_or_else(|| { - ArrowError::IpcError("Unable to read IPC message as schema".to_string()) - })?; - let schema = fb_to_schema(ipc_schema); - - Ok(Arc::new(schema)) -} - -async fn collect_at_least_n_bytes( - stream: &mut BoxStream<'static, object_store::Result>, - n: usize, - extend_from: Option>, -) -> Result> { - let mut buf = extend_from.unwrap_or_else(|| Vec::with_capacity(n)); - // If extending existing buffer then ensure we read n additional bytes - let n = n + buf.len(); - while let Some(bytes) = stream.next().await.transpose()? 
{ - buf.extend_from_slice(&bytes); - if buf.len() >= n { - break; - } - } - if buf.len() < n { - return Err(ArrowError::ParseError( - "Unexpected end of byte stream for Arrow IPC file".to_string(), - ))?; - } - Ok(buf) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::execution::context::SessionContext; - - use chrono::DateTime; - use object_store::{chunked::ChunkedStore, memory::InMemory, path::Path}; - - #[tokio::test] - async fn test_infer_schema_stream() -> Result<()> { - let mut bytes = std::fs::read("tests/data/example.arrow")?; - bytes.truncate(bytes.len() - 20); // mangle end to show we don't need to read whole file - let location = Path::parse("example.arrow")?; - let in_memory_store: Arc = Arc::new(InMemory::new()); - in_memory_store.put(&location, bytes.into()).await?; - - let session_ctx = SessionContext::new(); - let state = session_ctx.state(); - let object_meta = ObjectMeta { - location, - last_modified: DateTime::default(), - size: u64::MAX, - e_tag: None, - version: None, - }; - - let arrow_format = ArrowFormat {}; - let expected = vec!["f0: Int64", "f1: Utf8", "f2: Boolean"]; - - // Test chunk sizes where too small so we keep having to read more bytes - // And when large enough that first read contains all we need - for chunk_size in [7, 3000] { - let store = Arc::new(ChunkedStore::new(in_memory_store.clone(), chunk_size)); - let inferred_schema = arrow_format - .infer_schema( - &state, - &(store.clone() as Arc), - std::slice::from_ref(&object_meta), - ) - .await?; - let actual_fields = inferred_schema - .fields() - .iter() - .map(|f| format!("{}: {:?}", f.name(), f.data_type())) - .collect::>(); - assert_eq!(expected, actual_fields); - } - - Ok(()) - } - - #[tokio::test] - async fn test_infer_schema_short_stream() -> Result<()> { - let mut bytes = std::fs::read("tests/data/example.arrow")?; - bytes.truncate(20); // should cause error that file shorter than expected - let location = Path::parse("example.arrow")?; - let in_memory_store: Arc = Arc::new(InMemory::new()); - in_memory_store.put(&location, bytes.into()).await?; - - let session_ctx = SessionContext::new(); - let state = session_ctx.state(); - let object_meta = ObjectMeta { - location, - last_modified: DateTime::default(), - size: u64::MAX, - e_tag: None, - version: None, - }; - - let arrow_format = ArrowFormat {}; - - let store = Arc::new(ChunkedStore::new(in_memory_store.clone(), 7)); - let err = arrow_format - .infer_schema( - &state, - &(store.clone() as Arc), - std::slice::from_ref(&object_meta), - ) - .await; - - assert!(err.is_err()); - assert_eq!( - "Arrow error: Parser error: Unexpected end of byte stream for Arrow IPC file", - err.unwrap_err().to_string().lines().next().unwrap() - ); - - Ok(()) - } -} +//! Re-exports the [`datafusion_datasource_arrow::file_format`] module, and contains tests for it. 
+pub use datafusion_datasource_arrow::file_format::*; diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 088c4408fff5..1781ea569d90 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -154,7 +154,6 @@ mod tests { use futures::stream::BoxStream; use futures::StreamExt; use insta::assert_snapshot; - use log::error; use object_store::local::LocalFileSystem; use object_store::ObjectMeta; use object_store::{ @@ -163,9 +162,10 @@ mod tests { }; use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::ParquetRecordBatchStreamBuilder; - use parquet::file::metadata::{KeyValue, ParquetColumnIndex, ParquetOffsetIndex}; - use parquet::file::page_index::index::Index; - use parquet::format::FileMetaData; + use parquet::file::metadata::{ + KeyValue, ParquetColumnIndex, ParquetMetaData, ParquetOffsetIndex, + }; + use parquet::file::page_index::column_index::ColumnIndexMetaData; use tokio::fs::File; enum ForceViews { @@ -1144,18 +1144,14 @@ mod tests { // 325 pages in int_col assert_eq!(int_col_offset.len(), 325); - match int_col_index { - Index::INT32(index) => { - assert_eq!(index.indexes.len(), 325); - for min_max in index.clone().indexes { - assert!(min_max.min.is_some()); - assert!(min_max.max.is_some()); - assert!(min_max.null_count.is_some()); - } - } - _ => { - error!("fail to read page index.") - } + let ColumnIndexMetaData::INT32(index) = int_col_index else { + panic!("fail to read page index.") + }; + assert_eq!(index.min_values().len(), 325); + assert_eq!(index.max_values().len(), 325); + // all values are non null + for idx in 0..325 { + assert_eq!(index.null_count(idx), Some(0)); } } @@ -1556,7 +1552,7 @@ mod tests { Ok(parquet_sink) } - fn get_written(parquet_sink: Arc) -> Result<(Path, FileMetaData)> { + fn get_written(parquet_sink: Arc) -> Result<(Path, ParquetMetaData)> { let mut written = parquet_sink.written(); let written = written.drain(); assert_eq!( @@ -1566,28 +1562,33 @@ mod tests { written.len() ); - let (path, file_metadata) = written.take(1).next().unwrap(); - Ok((path, file_metadata)) + let (path, parquet_meta_data) = written.take(1).next().unwrap(); + Ok((path, parquet_meta_data)) } - fn assert_file_metadata(file_metadata: FileMetaData, expected_kv: &Vec) { - let FileMetaData { - num_rows, - schema, - key_value_metadata, - .. 
- } = file_metadata; - assert_eq!(num_rows, 2, "file metadata to have 2 rows"); + fn assert_file_metadata( + parquet_meta_data: ParquetMetaData, + expected_kv: &Vec, + ) { + let file_metadata = parquet_meta_data.file_metadata(); + let schema_descr = file_metadata.schema_descr(); + assert_eq!(file_metadata.num_rows(), 2, "file metadata to have 2 rows"); assert!( - schema.iter().any(|col_schema| col_schema.name == "a"), + schema_descr + .columns() + .iter() + .any(|col_schema| col_schema.name() == "a"), "output file metadata should contain col a" ); assert!( - schema.iter().any(|col_schema| col_schema.name == "b"), + schema_descr + .columns() + .iter() + .any(|col_schema| col_schema.name() == "b"), "output file metadata should contain col b" ); - let mut key_value_metadata = key_value_metadata.unwrap(); + let mut key_value_metadata = file_metadata.key_value_metadata().unwrap().clone(); key_value_metadata.sort_by(|a, b| a.key.cmp(&b.key)); assert_eq!(&key_value_metadata, expected_kv); } @@ -1644,13 +1645,11 @@ mod tests { // check the file metadata includes partitions let mut expected_partitions = std::collections::HashSet::from(["a=foo", "a=bar"]); - for ( - path, - FileMetaData { - num_rows, schema, .. - }, - ) in written.take(2) - { + for (path, parquet_metadata) in written.take(2) { + let file_metadata = parquet_metadata.file_metadata(); + let schema = file_metadata.schema_descr(); + let num_rows = file_metadata.num_rows(); + let path_parts = path.parts().collect::>(); assert_eq!(path_parts.len(), 2, "should have path prefix"); @@ -1663,11 +1662,17 @@ mod tests { assert_eq!(num_rows, 1, "file metadata to have 1 row"); assert!( - !schema.iter().any(|col_schema| col_schema.name == "a"), + !schema + .columns() + .iter() + .any(|col_schema| col_schema.name() == "a"), "output file metadata will not contain partitioned col a" ); assert!( - schema.iter().any(|col_schema| col_schema.name == "b"), + schema + .columns() + .iter() + .any(|col_schema| col_schema.name() == "b"), "output file metadata should contain col b" ); } diff --git a/datafusion/core/src/datasource/listing/mod.rs b/datafusion/core/src/datasource/listing/mod.rs index a58db55bccb6..c206566a6594 100644 --- a/datafusion/core/src/datasource/listing/mod.rs +++ b/datafusion/core/src/datasource/listing/mod.rs @@ -20,7 +20,8 @@ mod table; pub use datafusion_catalog_listing::helpers; +pub use datafusion_catalog_listing::{ListingOptions, ListingTable, ListingTableConfig}; pub use datafusion_datasource::{ FileRange, ListingTableUrl, PartitionedFile, PartitionedFileStream, }; -pub use table::{ListingOptions, ListingTable, ListingTableConfig}; +pub use table::ListingTableConfigExt; diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 3ce58938d77e..3333b7067620 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -15,226 +15,42 @@ // specific language governing permissions and limitations // under the License. -//! The table implementation. 
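The parquet test rewrites above move from the thrift `FileMetaData`/`Index` structs to the accessor-based `ParquetMetaData` API. A hedged sketch built only from the calls visible in those hunks:

```rust
use parquet::file::metadata::ParquetMetaData;
use parquet::file::page_index::column_index::ColumnIndexMetaData;

// Column names now come from the schema descriptor on `file_metadata()`.
fn has_column(metadata: &ParquetMetaData, name: &str) -> bool {
    metadata
        .file_metadata()
        .schema_descr()
        .columns()
        .iter()
        .any(|col| col.name() == name)
}

// The typed column index exposes per-page statistics; other physical types
// are elided in this sketch.
fn int32_pages_all_non_null(index: &ColumnIndexMetaData) -> bool {
    let ColumnIndexMetaData::INT32(index) = index else {
        return false;
    };
    (0..index.min_values().len()).all(|page| index.null_count(page) == Some(0))
}
```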
- -use super::{ - helpers::{expr_applicable_for_cols, pruned_partition_list}, - ListingTableUrl, PartitionedFile, -}; -use crate::{ - datasource::file_format::{file_compression_type::FileCompressionType, FileFormat}, - datasource::physical_plan::FileSinkConfig, - execution::context::SessionState, -}; -use arrow::datatypes::{DataType, Field, SchemaBuilder, SchemaRef}; -use arrow_schema::Schema; +use crate::execution::SessionState; use async_trait::async_trait; -use datafusion_catalog::{ScanArgs, ScanResult, Session, TableProvider}; -use datafusion_common::{ - config_datafusion_err, config_err, internal_datafusion_err, internal_err, plan_err, - project_schema, stats::Precision, Constraints, DataFusionError, Result, SchemaExt, -}; -use datafusion_datasource::{ - compute_all_files_statistics, - file::FileSource, - file_groups::FileGroup, - file_scan_config::{FileScanConfig, FileScanConfigBuilder}, - schema_adapter::{DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory}, -}; -use datafusion_execution::{ - cache::{cache_manager::FileStatisticsCache, cache_unit::DefaultFileStatisticsCache}, - config::SessionConfig, -}; -use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::{ - dml::InsertOp, Expr, SortExpr, TableProviderFilterPushDown, TableType, -}; -use datafusion_physical_expr::create_lex_ordering; -use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::{empty::EmptyExec, ExecutionPlan, Statistics}; -use futures::{future, stream, Stream, StreamExt, TryStreamExt}; -use itertools::Itertools; -use object_store::ObjectStore; -use std::{any::Any, collections::HashMap, str::FromStr, sync::Arc}; - -/// Indicates the source of the schema for a [`ListingTable`] -// PartialEq required for assert_eq! in tests -#[derive(Debug, Clone, Copy, PartialEq, Default)] -pub enum SchemaSource { - /// Schema is not yet set (initial state) - #[default] - Unset, - /// Schema was inferred from first table_path - Inferred, - /// Schema was specified explicitly via with_schema - Specified, -} +use datafusion_catalog_listing::{ListingOptions, ListingTableConfig}; +use datafusion_common::{config_datafusion_err, internal_datafusion_err}; +use datafusion_session::Session; +use futures::StreamExt; +use std::collections::HashMap; -/// Configuration for creating a [`ListingTable`] -/// -/// # Schema Evolution Support -/// -/// This configuration supports schema evolution through the optional -/// [`SchemaAdapterFactory`]. You might want to override the default factory when you need: -/// -/// - **Type coercion requirements**: When you need custom logic for converting between -/// different Arrow data types (e.g., Int32 ↔ Int64, Utf8 ↔ LargeUtf8) -/// - **Column mapping**: You need to map columns with a legacy name to a new name -/// - **Custom handling of missing columns**: By default they are filled in with nulls, but you may e.g. want to fill them in with `0` or `""`. +/// Extension trait for [`ListingTableConfig`] that supports inferring schemas /// -/// If not specified, a [`DefaultSchemaAdapterFactory`] will be used, which handles -/// basic schema compatibility cases. -/// -#[derive(Debug, Clone, Default)] -pub struct ListingTableConfig { - /// Paths on the `ObjectStore` for creating `ListingTable`. - /// They should share the same schema and object store. - pub table_paths: Vec, - /// Optional `SchemaRef` for the to be created `ListingTable`. 
- /// - /// See details on [`ListingTableConfig::with_schema`] - pub file_schema: Option, - /// Optional [`ListingOptions`] for the to be created [`ListingTable`]. - /// - /// See details on [`ListingTableConfig::with_listing_options`] - pub options: Option, - /// Tracks the source of the schema information - schema_source: SchemaSource, - /// Optional [`SchemaAdapterFactory`] for creating schema adapters - schema_adapter_factory: Option>, - /// Optional [`PhysicalExprAdapterFactory`] for creating physical expression adapters - expr_adapter_factory: Option>, -} - -impl ListingTableConfig { - /// Creates new [`ListingTableConfig`] for reading the specified URL - pub fn new(table_path: ListingTableUrl) -> Self { - Self { - table_paths: vec![table_path], - ..Default::default() - } - } - - /// Creates new [`ListingTableConfig`] with multiple table paths. - /// - /// See [`Self::infer_options`] for details on what happens with multiple paths - pub fn new_with_multi_paths(table_paths: Vec) -> Self { - Self { - table_paths, - ..Default::default() - } - } - - /// Returns the source of the schema for this configuration - pub fn schema_source(&self) -> SchemaSource { - self.schema_source - } - /// Set the `schema` for the overall [`ListingTable`] - /// - /// [`ListingTable`] will automatically coerce, when possible, the schema - /// for individual files to match this schema. - /// - /// If a schema is not provided, it is inferred using - /// [`Self::infer_schema`]. - /// - /// If the schema is provided, it must contain only the fields in the file - /// without the table partitioning columns. - /// - /// # Example: Specifying Table Schema - /// ```rust - /// # use std::sync::Arc; - /// # use datafusion::datasource::listing::{ListingTableConfig, ListingOptions, ListingTableUrl}; - /// # use datafusion::datasource::file_format::parquet::ParquetFormat; - /// # use arrow::datatypes::{Schema, Field, DataType}; - /// # let table_paths = ListingTableUrl::parse("file:///path/to/data").unwrap(); - /// # let listing_options = ListingOptions::new(Arc::new(ParquetFormat::default())); - /// let schema = Arc::new(Schema::new(vec![ - /// Field::new("id", DataType::Int64, false), - /// Field::new("name", DataType::Utf8, true), - /// ])); - /// - /// let config = ListingTableConfig::new(table_paths) - /// .with_listing_options(listing_options) // Set options first - /// .with_schema(schema); // Then set schema - /// ``` - pub fn with_schema(self, schema: SchemaRef) -> Self { - // Note: We preserve existing options state, but downstream code may expect - // options to be set. Consider calling with_listing_options() or infer_options() - // before operations that require options to be present. - debug_assert!( - self.options.is_some() || cfg!(test), - "ListingTableConfig::with_schema called without options set. \ - Consider calling with_listing_options() or infer_options() first to avoid panics in downstream code." - ); - - Self { - file_schema: Some(schema), - schema_source: SchemaSource::Specified, - ..self - } - } - - /// Add `listing_options` to [`ListingTableConfig`] - /// - /// If not provided, format and other options are inferred via - /// [`Self::infer_options`]. 
- /// - /// # Example: Configuring Parquet Files with Custom Options - /// ```rust - /// # use std::sync::Arc; - /// # use datafusion::datasource::listing::{ListingTableConfig, ListingOptions, ListingTableUrl}; - /// # use datafusion::datasource::file_format::parquet::ParquetFormat; - /// # let table_paths = ListingTableUrl::parse("file:///path/to/data").unwrap(); - /// let options = ListingOptions::new(Arc::new(ParquetFormat::default())) - /// .with_file_extension(".parquet") - /// .with_collect_stat(true); - /// - /// let config = ListingTableConfig::new(table_paths) - /// .with_listing_options(options); // Configure file format and options - /// ``` - pub fn with_listing_options(self, listing_options: ListingOptions) -> Self { - // Note: This method properly sets options, but be aware that downstream - // methods like infer_schema() and try_new() require both schema and options - // to be set to function correctly. - debug_assert!( - !self.table_paths.is_empty() || cfg!(test), - "ListingTableConfig::with_listing_options called without table_paths set. \ - Consider calling new() or new_with_multi_paths() first to establish table paths." - ); - - Self { - options: Some(listing_options), - ..self - } - } - - /// Returns a tuple of `(file_extension, optional compression_extension)` - /// - /// For example a path ending with blah.test.csv.gz returns `("csv", Some("gz"))` - /// For example a path ending with blah.test.csv returns `("csv", None)` - fn infer_file_extension_and_compression_type( - path: &str, - ) -> Result<(String, Option)> { - let mut exts = path.rsplit('.'); - - let split = exts.next().unwrap_or(""); - - let file_compression_type = FileCompressionType::from_str(split) - .unwrap_or(FileCompressionType::UNCOMPRESSED); - - if file_compression_type.is_compressed() { - let split2 = exts.next().unwrap_or(""); - Ok((split2.to_string(), Some(split.to_string()))) - } else { - Ok((split.to_string(), None)) - } - } - +/// This trait exists because the following inference methods only +/// work for [`SessionState`] implementations of [`Session`]. +/// See [`ListingTableConfig`] for the remaining inference methods. +#[async_trait] +pub trait ListingTableConfigExt { /// Infer `ListingOptions` based on `table_path` and file suffix. /// /// The format is inferred based on the first `table_path`. - pub async fn infer_options(self, state: &dyn Session) -> Result { + async fn infer_options( + self, + state: &dyn Session, + ) -> datafusion_common::Result; + + /// Convenience method to call both [`Self::infer_options`] and [`ListingTableConfig::infer_schema`] + async fn infer( + self, + state: &dyn Session, + ) -> datafusion_common::Result; +} + +#[async_trait] +impl ListingTableConfigExt for ListingTableConfig { + async fn infer_options( + self, + state: &dyn Session, + ) -> datafusion_common::Result { let store = if let Some(url) = self.table_paths.first() { state.runtime_env().object_store(url)? } else { @@ -281,1299 +97,19 @@ impl ListingTableConfig { .with_target_partitions(state.config().target_partitions()) .with_collect_stat(state.config().collect_statistics()); - Ok(Self { - table_paths: self.table_paths, - file_schema: self.file_schema, - options: Some(listing_options), - schema_source: self.schema_source, - schema_adapter_factory: self.schema_adapter_factory, - expr_adapter_factory: self.expr_adapter_factory, - }) + Ok(self.with_listing_options(listing_options)) } - /// Infer the [`SchemaRef`] based on `table_path`s. 
- /// - /// This method infers the table schema using the first `table_path`. - /// See [`ListingOptions::infer_schema`] for more details - /// - /// # Errors - /// * if `self.options` is not set. See [`Self::with_listing_options`] - pub async fn infer_schema(self, state: &dyn Session) -> Result { - match self.options { - Some(options) => { - let ListingTableConfig { - table_paths, - file_schema, - options: _, - schema_source, - schema_adapter_factory, - expr_adapter_factory: physical_expr_adapter_factory, - } = self; - - let (schema, new_schema_source) = match file_schema { - Some(schema) => (schema, schema_source), // Keep existing source if schema exists - None => { - if let Some(url) = table_paths.first() { - ( - options.infer_schema(state, url).await?, - SchemaSource::Inferred, - ) - } else { - (Arc::new(Schema::empty()), SchemaSource::Inferred) - } - } - }; - - Ok(Self { - table_paths, - file_schema: Some(schema), - options: Some(options), - schema_source: new_schema_source, - schema_adapter_factory, - expr_adapter_factory: physical_expr_adapter_factory, - }) - } - None => internal_err!("No `ListingOptions` set for inferring schema"), - } - } - - /// Convenience method to call both [`Self::infer_options`] and [`Self::infer_schema`] - pub async fn infer(self, state: &dyn Session) -> Result { + async fn infer(self, state: &dyn Session) -> datafusion_common::Result { self.infer_options(state).await?.infer_schema(state).await } - - /// Infer the partition columns from `table_paths`. - /// - /// # Errors - /// * if `self.options` is not set. See [`Self::with_listing_options`] - pub async fn infer_partitions_from_path(self, state: &dyn Session) -> Result { - match self.options { - Some(options) => { - let Some(url) = self.table_paths.first() else { - return config_err!("No table path found"); - }; - let partitions = options - .infer_partitions(state, url) - .await? - .into_iter() - .map(|col_name| { - ( - col_name, - DataType::Dictionary( - Box::new(DataType::UInt16), - Box::new(DataType::Utf8), - ), - ) - }) - .collect::>(); - let options = options.with_table_partition_cols(partitions); - Ok(Self { - table_paths: self.table_paths, - file_schema: self.file_schema, - options: Some(options), - schema_source: self.schema_source, - schema_adapter_factory: self.schema_adapter_factory, - expr_adapter_factory: self.expr_adapter_factory, - }) - } - None => config_err!("No `ListingOptions` set for inferring schema"), - } - } - - /// Set the [`SchemaAdapterFactory`] for the [`ListingTable`] - /// - /// The schema adapter factory is used to create schema adapters that can - /// handle schema evolution and type conversions when reading files with - /// different schemas than the table schema. - /// - /// If not provided, a default schema adapter factory will be used. 
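For callers, the extension trait keeps the familiar inference flow; only the import changes. A hedged sketch, assuming the re-exports shown in `listing/mod.rs` above and that `infer` still yields a fully inferred `ListingTableConfig`:

```rust
use std::sync::Arc;

use datafusion::datasource::listing::{
    ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl,
};
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

// Infer format, options, and schema for a path, then build a ListingTable.
async fn listing_table_for(ctx: &SessionContext, path: &str) -> Result<Arc<ListingTable>> {
    let url = ListingTableUrl::parse(path)?;
    let state = ctx.state();
    // `infer` is now provided by ListingTableConfigExt for SessionState-backed sessions
    let config = ListingTableConfig::new(url).infer(&state).await?;
    Ok(Arc::new(ListingTable::try_new(config)?))
}
```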
- /// - /// # Example: Custom Schema Adapter for Type Coercion - /// ```rust - /// # use std::sync::Arc; - /// # use datafusion::datasource::listing::{ListingTableConfig, ListingOptions, ListingTableUrl}; - /// # use datafusion::datasource::schema_adapter::{SchemaAdapterFactory, SchemaAdapter}; - /// # use datafusion::datasource::file_format::parquet::ParquetFormat; - /// # use arrow::datatypes::{SchemaRef, Schema, Field, DataType}; - /// # - /// # #[derive(Debug)] - /// # struct MySchemaAdapterFactory; - /// # impl SchemaAdapterFactory for MySchemaAdapterFactory { - /// # fn create(&self, _projected_table_schema: SchemaRef, _file_schema: SchemaRef) -> Box { - /// # unimplemented!() - /// # } - /// # } - /// # let table_paths = ListingTableUrl::parse("file:///path/to/data").unwrap(); - /// # let listing_options = ListingOptions::new(Arc::new(ParquetFormat::default())); - /// # let table_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); - /// let config = ListingTableConfig::new(table_paths) - /// .with_listing_options(listing_options) - /// .with_schema(table_schema) - /// .with_schema_adapter_factory(Arc::new(MySchemaAdapterFactory)); - /// ``` - pub fn with_schema_adapter_factory( - self, - schema_adapter_factory: Arc, - ) -> Self { - Self { - schema_adapter_factory: Some(schema_adapter_factory), - ..self - } - } - - /// Get the [`SchemaAdapterFactory`] for this configuration - pub fn schema_adapter_factory(&self) -> Option<&Arc> { - self.schema_adapter_factory.as_ref() - } - - /// Set the [`PhysicalExprAdapterFactory`] for the [`ListingTable`] - /// - /// The expression adapter factory is used to create physical expression adapters that can - /// handle schema evolution and type conversions when evaluating expressions - /// with different schemas than the table schema. - /// - /// If not provided, a default physical expression adapter factory will be used unless a custom - /// `SchemaAdapterFactory` is set, in which case only the `SchemaAdapterFactory` will be used. - /// - /// See for details on this transition. - pub fn with_expr_adapter_factory( - self, - expr_adapter_factory: Arc, - ) -> Self { - Self { - expr_adapter_factory: Some(expr_adapter_factory), - ..self - } - } -} - -/// Options for creating a [`ListingTable`] -#[derive(Clone, Debug)] -pub struct ListingOptions { - /// A suffix on which files should be filtered (leave empty to - /// keep all files on the path) - pub file_extension: String, - /// The file format - pub format: Arc, - /// The expected partition column names in the folder structure. - /// See [Self::with_table_partition_cols] for details - pub table_partition_cols: Vec<(String, DataType)>, - /// Set true to try to guess statistics from the files. - /// This can add a lot of overhead as it will usually require files - /// to be opened and at least partially parsed. - pub collect_stat: bool, - /// Group files to avoid that the number of partitions exceeds - /// this limit - pub target_partitions: usize, - /// Optional pre-known sort order(s). Must be `SortExpr`s. - /// - /// DataFusion may take advantage of this ordering to omit sorts - /// or use more efficient algorithms. Currently sortedness must be - /// provided if it is known by some external mechanism, but may in - /// the future be automatically determined, for example using - /// parquet metadata. 
- /// - /// See - /// - /// NOTE: This attribute stores all equivalent orderings (the outer `Vec`) - /// where each ordering consists of an individual lexicographic - /// ordering (encapsulated by a `Vec`). If there aren't - /// multiple equivalent orderings, the outer `Vec` will have a - /// single element. - pub file_sort_order: Vec>, -} - -impl ListingOptions { - /// Creates an options instance with the given format - /// Default values: - /// - use default file extension filter - /// - no input partition to discover - /// - one target partition - /// - do not collect statistics - pub fn new(format: Arc) -> Self { - Self { - file_extension: format.get_ext(), - format, - table_partition_cols: vec![], - collect_stat: false, - target_partitions: 1, - file_sort_order: vec![], - } - } - - /// Set options from [`SessionConfig`] and returns self. - /// - /// Currently this sets `target_partitions` and `collect_stat` - /// but if more options are added in the future that need to be coordinated - /// they will be synchronized through this method. - pub fn with_session_config_options(mut self, config: &SessionConfig) -> Self { - self = self.with_target_partitions(config.target_partitions()); - self = self.with_collect_stat(config.collect_statistics()); - self - } - - /// Set file extension on [`ListingOptions`] and returns self. - /// - /// # Example - /// ``` - /// # use std::sync::Arc; - /// # use datafusion::prelude::SessionContext; - /// # use datafusion::datasource::{listing::ListingOptions, file_format::parquet::ParquetFormat}; - /// - /// let listing_options = ListingOptions::new(Arc::new( - /// ParquetFormat::default() - /// )) - /// .with_file_extension(".parquet"); - /// - /// assert_eq!(listing_options.file_extension, ".parquet"); - /// ``` - pub fn with_file_extension(mut self, file_extension: impl Into) -> Self { - self.file_extension = file_extension.into(); - self - } - - /// Optionally set file extension on [`ListingOptions`] and returns self. - /// - /// If `file_extension` is `None`, the file extension will not be changed - /// - /// # Example - /// ``` - /// # use std::sync::Arc; - /// # use datafusion::prelude::SessionContext; - /// # use datafusion::datasource::{listing::ListingOptions, file_format::parquet::ParquetFormat}; - /// let extension = Some(".parquet"); - /// let listing_options = ListingOptions::new(Arc::new( - /// ParquetFormat::default() - /// )) - /// .with_file_extension_opt(extension); - /// - /// assert_eq!(listing_options.file_extension, ".parquet"); - /// ``` - pub fn with_file_extension_opt(mut self, file_extension: Option) -> Self - where - S: Into, - { - if let Some(file_extension) = file_extension { - self.file_extension = file_extension.into(); - } - self - } - - /// Set `table partition columns` on [`ListingOptions`] and returns self. - /// - /// "partition columns," used to support [Hive Partitioning], are - /// columns added to the data that is read, based on the folder - /// structure where the data resides. - /// - /// For example, give the following files in your filesystem: - /// - /// ```text - /// /mnt/nyctaxi/year=2022/month=01/tripdata.parquet - /// /mnt/nyctaxi/year=2021/month=12/tripdata.parquet - /// /mnt/nyctaxi/year=2021/month=11/tripdata.parquet - /// ``` - /// - /// A [`ListingTable`] created at `/mnt/nyctaxi/` with partition - /// columns "year" and "month" will include new `year` and `month` - /// columns while reading the files. 
The `year` column would have - /// value `2022` and the `month` column would have value `01` for - /// the rows read from - /// `/mnt/nyctaxi/year=2022/month=01/tripdata.parquet` - /// - ///# Notes - /// - /// - If only one level (e.g. `year` in the example above) is - /// specified, the other levels are ignored but the files are - /// still read. - /// - /// - Files that don't follow this partitioning scheme will be - /// ignored. - /// - /// - Since the columns have the same value for all rows read from - /// each individual file (such as dates), they are typically - /// dictionary encoded for efficiency. You may use - /// [`wrap_partition_type_in_dict`] to request a - /// dictionary-encoded type. - /// - /// - The partition columns are solely extracted from the file path. Especially they are NOT part of the parquet files itself. - /// - /// # Example - /// - /// ``` - /// # use std::sync::Arc; - /// # use arrow::datatypes::DataType; - /// # use datafusion::prelude::col; - /// # use datafusion::datasource::{listing::ListingOptions, file_format::parquet::ParquetFormat}; - /// - /// // listing options for files with paths such as `/mnt/data/col_a=x/col_b=y/data.parquet` - /// // `col_a` and `col_b` will be included in the data read from those files - /// let listing_options = ListingOptions::new(Arc::new( - /// ParquetFormat::default() - /// )) - /// .with_table_partition_cols(vec![("col_a".to_string(), DataType::Utf8), - /// ("col_b".to_string(), DataType::Utf8)]); - /// - /// assert_eq!(listing_options.table_partition_cols, vec![("col_a".to_string(), DataType::Utf8), - /// ("col_b".to_string(), DataType::Utf8)]); - /// ``` - /// - /// [Hive Partitioning]: https://docs.cloudera.com/HDPDocuments/HDP2/HDP-2.1.3/bk_system-admin-guide/content/hive_partitioned_tables.html - /// [`wrap_partition_type_in_dict`]: crate::datasource::physical_plan::wrap_partition_type_in_dict - pub fn with_table_partition_cols( - mut self, - table_partition_cols: Vec<(String, DataType)>, - ) -> Self { - self.table_partition_cols = table_partition_cols; - self - } - - /// Set stat collection on [`ListingOptions`] and returns self. - /// - /// ``` - /// # use std::sync::Arc; - /// # use datafusion::datasource::{listing::ListingOptions, file_format::parquet::ParquetFormat}; - /// - /// let listing_options = ListingOptions::new(Arc::new( - /// ParquetFormat::default() - /// )) - /// .with_collect_stat(true); - /// - /// assert_eq!(listing_options.collect_stat, true); - /// ``` - pub fn with_collect_stat(mut self, collect_stat: bool) -> Self { - self.collect_stat = collect_stat; - self - } - - /// Set number of target partitions on [`ListingOptions`] and returns self. - /// - /// ``` - /// # use std::sync::Arc; - /// # use datafusion::datasource::{listing::ListingOptions, file_format::parquet::ParquetFormat}; - /// - /// let listing_options = ListingOptions::new(Arc::new( - /// ParquetFormat::default() - /// )) - /// .with_target_partitions(8); - /// - /// assert_eq!(listing_options.target_partitions, 8); - /// ``` - pub fn with_target_partitions(mut self, target_partitions: usize) -> Self { - self.target_partitions = target_partitions; - self - } - - /// Set file sort order on [`ListingOptions`] and returns self. 
- /// - /// ``` - /// # use std::sync::Arc; - /// # use datafusion::prelude::col; - /// # use datafusion::datasource::{listing::ListingOptions, file_format::parquet::ParquetFormat}; - /// - /// // Tell datafusion that the files are sorted by column "a" - /// let file_sort_order = vec![vec![ - /// col("a").sort(true, true) - /// ]]; - /// - /// let listing_options = ListingOptions::new(Arc::new( - /// ParquetFormat::default() - /// )) - /// .with_file_sort_order(file_sort_order.clone()); - /// - /// assert_eq!(listing_options.file_sort_order, file_sort_order); - /// ``` - pub fn with_file_sort_order(mut self, file_sort_order: Vec>) -> Self { - self.file_sort_order = file_sort_order; - self - } - - /// Infer the schema of the files at the given path on the provided object store. - /// - /// If the table_path contains one or more files (i.e. it is a directory / - /// prefix of files) their schema is merged by calling [`FileFormat::infer_schema`] - /// - /// Note: The inferred schema does not include any partitioning columns. - /// - /// This method is called as part of creating a [`ListingTable`]. - pub async fn infer_schema<'a>( - &'a self, - state: &dyn Session, - table_path: &'a ListingTableUrl, - ) -> Result { - let store = state.runtime_env().object_store(table_path)?; - - let files: Vec<_> = table_path - .list_all_files(state, store.as_ref(), &self.file_extension) - .await? - // Empty files cannot affect schema but may throw when trying to read for it - .try_filter(|object_meta| future::ready(object_meta.size > 0)) - .try_collect() - .await?; - - let schema = self.format.infer_schema(state, &store, &files).await?; - - Ok(schema) - } - - /// Infers the partition columns stored in `LOCATION` and compares - /// them with the columns provided in `PARTITIONED BY` to help prevent - /// accidental corrupts of partitioned tables. - /// - /// Allows specifying partial partitions. - pub async fn validate_partitions( - &self, - state: &dyn Session, - table_path: &ListingTableUrl, - ) -> Result<()> { - if self.table_partition_cols.is_empty() { - return Ok(()); - } - - if !table_path.is_collection() { - return plan_err!( - "Can't create a partitioned table backed by a single file, \ - perhaps the URL is missing a trailing slash?" - ); - } - - let inferred = self.infer_partitions(state, table_path).await?; - - // no partitioned files found on disk - if inferred.is_empty() { - return Ok(()); - } - - let table_partition_names = self - .table_partition_cols - .iter() - .map(|(col_name, _)| col_name.clone()) - .collect_vec(); - - if inferred.len() < table_partition_names.len() { - return plan_err!( - "Inferred partitions to be {:?}, but got {:?}", - inferred, - table_partition_names - ); - } - - // match prefix to allow creating tables with partial partitions - for (idx, col) in table_partition_names.iter().enumerate() { - if &inferred[idx] != col { - return plan_err!( - "Inferred partitions to be {:?}, but got {:?}", - inferred, - table_partition_names - ); - } - } - - Ok(()) - } - - /// Infer the partitioning at the given path on the provided object store. - /// For performance reasons, it doesn't read all the files on disk - /// and therefore may fail to detect invalid partitioning. 
- pub(crate) async fn infer_partitions( - &self, - state: &dyn Session, - table_path: &ListingTableUrl, - ) -> Result> { - let store = state.runtime_env().object_store(table_path)?; - - // only use 10 files for inference - // This can fail to detect inconsistent partition keys - // A DFS traversal approach of the store can help here - let files: Vec<_> = table_path - .list_all_files(state, store.as_ref(), &self.file_extension) - .await? - .take(10) - .try_collect() - .await?; - - let stripped_path_parts = files.iter().map(|file| { - table_path - .strip_prefix(&file.location) - .unwrap() - .collect_vec() - }); - - let partition_keys = stripped_path_parts - .map(|path_parts| { - path_parts - .into_iter() - .rev() - .skip(1) // get parents only; skip the file itself - .rev() - // Partitions are expected to follow the format "column_name=value", so we - // should ignore any path part that cannot be parsed into the expected format - .filter(|s| s.contains('=')) - .map(|s| s.split('=').take(1).collect()) - .collect_vec() - }) - .collect_vec(); - - match partition_keys.into_iter().all_equal_value() { - Ok(v) => Ok(v), - Err(None) => Ok(vec![]), - Err(Some(diff)) => { - let mut sorted_diff = [diff.0, diff.1]; - sorted_diff.sort(); - plan_err!("Found mixed partition values on disk {:?}", sorted_diff) - } - } - } -} - -/// Built in [`TableProvider`] that reads data from one or more files as a single table. -/// -/// The files are read using an [`ObjectStore`] instance, for example from -/// local files or objects from AWS S3. -/// -/// # Features: -/// * Reading multiple files as a single table -/// * Hive style partitioning (e.g., directories named `date=2024-06-01`) -/// * Merges schemas from files with compatible but not identical schemas (see [`ListingTableConfig::file_schema`]) -/// * `limit`, `filter` and `projection` pushdown for formats that support it (e.g., -/// Parquet) -/// * Statistics collection and pruning based on file metadata -/// * Pre-existing sort order (see [`ListingOptions::file_sort_order`]) -/// * Metadata caching to speed up repeated queries (see [`FileMetadataCache`]) -/// * Statistics caching (see [`FileStatisticsCache`]) -/// -/// [`FileMetadataCache`]: datafusion_execution::cache::cache_manager::FileMetadataCache -/// -/// # Reading Directories and Hive Style Partitioning -/// -/// For example, given the `table1` directory (or object store prefix) -/// -/// ```text -/// table1 -/// ├── file1.parquet -/// └── file2.parquet -/// ``` -/// -/// A `ListingTable` would read the files `file1.parquet` and `file2.parquet` as -/// a single table, merging the schemas if the files have compatible but not -/// identical schemas. -/// -/// Given the `table2` directory (or object store prefix) -/// -/// ```text -/// table2 -/// ├── date=2024-06-01 -/// │ ├── file3.parquet -/// │ └── file4.parquet -/// └── date=2024-06-02 -/// └── file5.parquet -/// ``` -/// -/// A `ListingTable` would read the files `file3.parquet`, `file4.parquet`, and -/// `file5.parquet` as a single table, again merging schemas if necessary. -/// -/// Given the hive style partitioning structure (e.g,. directories named -/// `date=2024-06-01` and `date=2026-06-02`), `ListingTable` also adds a `date` -/// column when reading the table: -/// * The files in `table2/date=2024-06-01` will have the value `2024-06-01` -/// * The files in `table2/date=2024-06-02` will have the value `2024-06-02`. -/// -/// If the query has a predicate like `WHERE date = '2024-06-01'` -/// only the corresponding directory will be read. 
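The partition inference being moved above keys off hive-style `name=value` path segments only. A simplified, hypothetical helper (not the relocated code) illustrating that convention:

```rust
// Keep only directory segments of the form `name=value` and return the names;
// the file name itself and non-conforming segments are ignored.
fn partition_cols_from_path(relative_path: &str) -> Vec<String> {
    let segments: Vec<&str> = relative_path.split('/').collect();
    let dirs = &segments[..segments.len().saturating_sub(1)]; // drop the file name
    dirs.iter()
        .filter(|seg| seg.contains('='))
        .map(|seg| seg.split('=').next().unwrap_or("").to_string())
        .collect()
}

#[test]
fn hive_style_segments() {
    assert_eq!(
        partition_cols_from_path("year=2022/month=01/data.parquet"),
        vec!["year".to_string(), "month".to_string()]
    );
}
```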
-/// -/// # See Also -/// -/// 1. [`ListingTableConfig`]: Configuration options -/// 1. [`DataSourceExec`]: `ExecutionPlan` used by `ListingTable` -/// -/// [`DataSourceExec`]: crate::datasource::source::DataSourceExec -/// -/// # Caching Metadata -/// -/// Some formats, such as Parquet, use the `FileMetadataCache` to cache file -/// metadata that is needed to execute but expensive to read, such as row -/// groups and statistics. The cache is scoped to the [`SessionContext`] and can -/// be configured via the [runtime config options]. -/// -/// [`SessionContext`]: crate::prelude::SessionContext -/// [runtime config options]: https://datafusion.apache.org/user-guide/configs.html#runtime-configuration-settings -/// -/// # Example: Read a directory of parquet files using a [`ListingTable`] -/// -/// ```no_run -/// # use datafusion::prelude::SessionContext; -/// # use datafusion::error::Result; -/// # use std::sync::Arc; -/// # use datafusion::datasource::{ -/// # listing::{ -/// # ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, -/// # }, -/// # file_format::parquet::ParquetFormat, -/// # }; -/// # #[tokio::main] -/// # async fn main() -> Result<()> { -/// let ctx = SessionContext::new(); -/// let session_state = ctx.state(); -/// let table_path = "/path/to/parquet"; -/// -/// // Parse the path -/// let table_path = ListingTableUrl::parse(table_path)?; -/// -/// // Create default parquet options -/// let file_format = ParquetFormat::new(); -/// let listing_options = ListingOptions::new(Arc::new(file_format)) -/// .with_file_extension(".parquet"); -/// -/// // Resolve the schema -/// let resolved_schema = listing_options -/// .infer_schema(&session_state, &table_path) -/// .await?; -/// -/// let config = ListingTableConfig::new(table_path) -/// .with_listing_options(listing_options) -/// .with_schema(resolved_schema); -/// -/// // Create a new TableProvider -/// let provider = Arc::new(ListingTable::try_new(config)?); -/// -/// // This provider can now be read as a dataframe: -/// let df = ctx.read_table(provider.clone()); -/// -/// // or registered as a named table: -/// ctx.register_table("my_table", provider); -/// -/// # Ok(()) -/// # } -/// ``` -#[derive(Debug, Clone)] -pub struct ListingTable { - table_paths: Vec, - /// `file_schema` contains only the columns physically stored in the data files themselves. - /// - Represents the actual fields found in files like Parquet, CSV, etc. 
- /// - Used when reading the raw data from files - file_schema: SchemaRef, - /// `table_schema` combines `file_schema` + partition columns - /// - Partition columns are derived from directory paths (not stored in files) - /// - These are columns like "year=2022/month=01" in paths like `/data/year=2022/month=01/file.parquet` - table_schema: SchemaRef, - /// Indicates how the schema was derived (inferred or explicitly specified) - schema_source: SchemaSource, - /// Options used to configure the listing table such as the file format - /// and partitioning information - options: ListingOptions, - /// The SQL definition for this table, if any - definition: Option, - /// Cache for collected file statistics - collected_statistics: FileStatisticsCache, - /// Constraints applied to this table - constraints: Constraints, - /// Column default expressions for columns that are not physically present in the data files - column_defaults: HashMap, - /// Optional [`SchemaAdapterFactory`] for creating schema adapters - schema_adapter_factory: Option>, - /// Optional [`PhysicalExprAdapterFactory`] for creating physical expression adapters - expr_adapter_factory: Option>, -} - -impl ListingTable { - /// Create new [`ListingTable`] - /// - /// See documentation and example on [`ListingTable`] and [`ListingTableConfig`] - pub fn try_new(config: ListingTableConfig) -> Result { - // Extract schema_source before moving other parts of the config - let schema_source = config.schema_source(); - - let file_schema = config - .file_schema - .ok_or_else(|| internal_datafusion_err!("No schema provided."))?; - - let options = config - .options - .ok_or_else(|| internal_datafusion_err!("No ListingOptions provided"))?; - - // Add the partition columns to the file schema - let mut builder = SchemaBuilder::from(file_schema.as_ref().to_owned()); - for (part_col_name, part_col_type) in &options.table_partition_cols { - builder.push(Field::new(part_col_name, part_col_type.clone(), false)); - } - - let table_schema = Arc::new( - builder - .finish() - .with_metadata(file_schema.metadata().clone()), - ); - - let table = Self { - table_paths: config.table_paths, - file_schema, - table_schema, - schema_source, - options, - definition: None, - collected_statistics: Arc::new(DefaultFileStatisticsCache::default()), - constraints: Constraints::default(), - column_defaults: HashMap::new(), - schema_adapter_factory: config.schema_adapter_factory, - expr_adapter_factory: config.expr_adapter_factory, - }; - - Ok(table) - } - - /// Assign constraints - pub fn with_constraints(mut self, constraints: Constraints) -> Self { - self.constraints = constraints; - self - } - - /// Assign column defaults - pub fn with_column_defaults( - mut self, - column_defaults: HashMap, - ) -> Self { - self.column_defaults = column_defaults; - self - } - - /// Set the [`FileStatisticsCache`] used to cache parquet file statistics. - /// - /// Setting a statistics cache on the `SessionContext` can avoid refetching statistics - /// multiple times in the same session. - /// - /// If `None`, creates a new [`DefaultFileStatisticsCache`] scoped to this query. 
- pub fn with_cache(mut self, cache: Option) -> Self { - self.collected_statistics = - cache.unwrap_or_else(|| Arc::new(DefaultFileStatisticsCache::default())); - self - } - - /// Specify the SQL definition for this table, if any - pub fn with_definition(mut self, definition: Option) -> Self { - self.definition = definition; - self - } - - /// Get paths ref - pub fn table_paths(&self) -> &Vec { - &self.table_paths - } - - /// Get options ref - pub fn options(&self) -> &ListingOptions { - &self.options - } - - /// Get the schema source - pub fn schema_source(&self) -> SchemaSource { - self.schema_source - } - - /// Set the [`SchemaAdapterFactory`] for this [`ListingTable`] - /// - /// The schema adapter factory is used to create schema adapters that can - /// handle schema evolution and type conversions when reading files with - /// different schemas than the table schema. - /// - /// # Example: Adding Schema Evolution Support - /// ```rust - /// # use std::sync::Arc; - /// # use datafusion::datasource::listing::{ListingTable, ListingTableConfig, ListingOptions, ListingTableUrl}; - /// # use datafusion::datasource::schema_adapter::{DefaultSchemaAdapterFactory, SchemaAdapter}; - /// # use datafusion::datasource::file_format::parquet::ParquetFormat; - /// # use arrow::datatypes::{SchemaRef, Schema, Field, DataType}; - /// # let table_path = ListingTableUrl::parse("file:///path/to/data").unwrap(); - /// # let options = ListingOptions::new(Arc::new(ParquetFormat::default())); - /// # let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); - /// # let config = ListingTableConfig::new(table_path).with_listing_options(options).with_schema(schema); - /// # let table = ListingTable::try_new(config).unwrap(); - /// let table_with_evolution = table - /// .with_schema_adapter_factory(Arc::new(DefaultSchemaAdapterFactory)); - /// ``` - /// See [`ListingTableConfig::with_schema_adapter_factory`] for an example of custom SchemaAdapterFactory. - pub fn with_schema_adapter_factory( - self, - schema_adapter_factory: Arc, - ) -> Self { - Self { - schema_adapter_factory: Some(schema_adapter_factory), - ..self - } - } - - /// Get the [`SchemaAdapterFactory`] for this table - pub fn schema_adapter_factory(&self) -> Option<&Arc> { - self.schema_adapter_factory.as_ref() - } - - /// Creates a schema adapter for mapping between file and table schemas - /// - /// Uses the configured schema adapter factory if available, otherwise falls back - /// to the default implementation. - fn create_schema_adapter(&self) -> Box { - let table_schema = self.schema(); - match &self.schema_adapter_factory { - Some(factory) => { - factory.create_with_projected_schema(Arc::clone(&table_schema)) - } - None => DefaultSchemaAdapterFactory::from_schema(Arc::clone(&table_schema)), - } - } - - /// Creates a file source and applies schema adapter factory if available - fn create_file_source_with_schema_adapter(&self) -> Result> { - let mut source = self.options.format.file_source(); - // Apply schema adapter to source if available - // - // The source will use this SchemaAdapter to adapt data batches as they flow up the plan. - // Note: ListingTable also creates a SchemaAdapter in `scan()` but that is only used to adapt collected statistics. 
- if let Some(factory) = &self.schema_adapter_factory { - source = source.with_schema_adapter_factory(Arc::clone(factory))?; - } - Ok(source) - } - - /// If file_sort_order is specified, creates the appropriate physical expressions - fn try_create_output_ordering( - &self, - execution_props: &ExecutionProps, - ) -> Result> { - create_lex_ordering( - &self.table_schema, - &self.options.file_sort_order, - execution_props, - ) - } -} - -// Expressions can be used for partition pruning if they can be evaluated using -// only the partition columns and there are partition columns. -fn can_be_evaluated_for_partition_pruning( - partition_column_names: &[&str], - expr: &Expr, -) -> bool { - !partition_column_names.is_empty() - && expr_applicable_for_cols(partition_column_names, expr) -} - -#[async_trait] -impl TableProvider for ListingTable { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - Arc::clone(&self.table_schema) - } - - fn constraints(&self) -> Option<&Constraints> { - Some(&self.constraints) - } - - fn table_type(&self) -> TableType { - TableType::Base - } - - async fn scan( - &self, - state: &dyn Session, - projection: Option<&Vec>, - filters: &[Expr], - limit: Option, - ) -> Result> { - let options = ScanArgs::default() - .with_projection(projection.map(|p| p.as_slice())) - .with_filters(Some(filters)) - .with_limit(limit); - Ok(self.scan_with_args(state, options).await?.into_inner()) - } - - async fn scan_with_args<'a>( - &self, - state: &dyn Session, - args: ScanArgs<'a>, - ) -> Result { - let projection = args.projection().map(|p| p.to_vec()); - let filters = args.filters().map(|f| f.to_vec()).unwrap_or_default(); - let limit = args.limit(); - - // extract types of partition columns - let table_partition_cols = self - .options - .table_partition_cols - .iter() - .map(|col| Ok(self.table_schema.field_with_name(&col.0)?.clone())) - .collect::>>()?; - - let table_partition_col_names = table_partition_cols - .iter() - .map(|field| field.name().as_str()) - .collect::>(); - - // If the filters can be resolved using only partition cols, there is no need to - // pushdown it to TableScan, otherwise, `unhandled` pruning predicates will be generated - let (partition_filters, filters): (Vec<_>, Vec<_>) = - filters.iter().cloned().partition(|filter| { - can_be_evaluated_for_partition_pruning(&table_partition_col_names, filter) - }); - - // We should not limit the number of partitioned files to scan if there are filters and limit - // at the same time. This is because the limit should be applied after the filters are applied. 
- let statistic_file_limit = if filters.is_empty() { limit } else { None }; - - let (mut partitioned_file_lists, statistics) = self - .list_files_for_scan(state, &partition_filters, statistic_file_limit) - .await?; - - // if no files need to be read, return an `EmptyExec` - if partitioned_file_lists.is_empty() { - let projected_schema = project_schema(&self.schema(), projection.as_ref())?; - return Ok(ScanResult::new(Arc::new(EmptyExec::new(projected_schema)))); - } - - let output_ordering = self.try_create_output_ordering(state.execution_props())?; - match state - .config_options() - .execution - .split_file_groups_by_statistics - .then(|| { - output_ordering.first().map(|output_ordering| { - FileScanConfig::split_groups_by_statistics_with_target_partitions( - &self.table_schema, - &partitioned_file_lists, - output_ordering, - self.options.target_partitions, - ) - }) - }) - .flatten() - { - Some(Err(e)) => log::debug!("failed to split file groups by statistics: {e}"), - Some(Ok(new_groups)) => { - if new_groups.len() <= self.options.target_partitions { - partitioned_file_lists = new_groups; - } else { - log::debug!("attempted to split file groups by statistics, but there were more file groups than target_partitions; falling back to unordered") - } - } - None => {} // no ordering required - }; - - let Some(object_store_url) = - self.table_paths.first().map(ListingTableUrl::object_store) - else { - return Ok(ScanResult::new(Arc::new(EmptyExec::new(Arc::new( - Schema::empty(), - ))))); - }; - - let file_source = self.create_file_source_with_schema_adapter()?; - - // create the execution plan - let plan = self - .options - .format - .create_physical_plan( - state, - FileScanConfigBuilder::new( - object_store_url, - Arc::clone(&self.file_schema), - file_source, - ) - .with_file_groups(partitioned_file_lists) - .with_constraints(self.constraints.clone()) - .with_statistics(statistics) - .with_projection(projection) - .with_limit(limit) - .with_output_ordering(output_ordering) - .with_table_partition_cols(table_partition_cols) - .with_expr_adapter(self.expr_adapter_factory.clone()) - .build(), - ) - .await?; - - Ok(ScanResult::new(plan)) - } - - fn supports_filters_pushdown( - &self, - filters: &[&Expr], - ) -> Result> { - let partition_column_names = self - .options - .table_partition_cols - .iter() - .map(|col| col.0.as_str()) - .collect::>(); - filters - .iter() - .map(|filter| { - if can_be_evaluated_for_partition_pruning(&partition_column_names, filter) - { - // if filter can be handled by partition pruning, it is exact - return Ok(TableProviderFilterPushDown::Exact); - } - - Ok(TableProviderFilterPushDown::Inexact) - }) - .collect() - } - - fn get_table_definition(&self) -> Option<&str> { - self.definition.as_deref() - } - - async fn insert_into( - &self, - state: &dyn Session, - input: Arc, - insert_op: InsertOp, - ) -> Result> { - // Check that the schema of the plan matches the schema of this table. - self.schema() - .logically_equivalent_names_and_types(&input.schema())?; - - let table_path = &self.table_paths()[0]; - if !table_path.is_collection() { - return plan_err!( - "Inserting into a ListingTable backed by a single file is not supported, URL is possibly missing a trailing `/`. \ - To append to an existing file use StreamTable, e.g. by using CREATE UNBOUNDED EXTERNAL TABLE" - ); - } - - // Get the object store for the table path. 
- let store = state.runtime_env().object_store(table_path)?; - - let file_list_stream = pruned_partition_list( - state, - store.as_ref(), - table_path, - &[], - &self.options.file_extension, - &self.options.table_partition_cols, - ) - .await?; - - let file_group = file_list_stream.try_collect::>().await?.into(); - let keep_partition_by_columns = - state.config_options().execution.keep_partition_by_columns; - - // Sink related option, apart from format - let config = FileSinkConfig { - original_url: String::default(), - object_store_url: self.table_paths()[0].object_store(), - table_paths: self.table_paths().clone(), - file_group, - output_schema: self.schema(), - table_partition_cols: self.options.table_partition_cols.clone(), - insert_op, - keep_partition_by_columns, - file_extension: self.options().format.get_ext(), - }; - - let orderings = self.try_create_output_ordering(state.execution_props())?; - // It is sufficient to pass only one of the equivalent orderings: - let order_requirements = orderings.into_iter().next().map(Into::into); - - self.options() - .format - .create_writer_physical_plan(input, state, config, order_requirements) - .await - } - - fn get_column_default(&self, column: &str) -> Option<&Expr> { - self.column_defaults.get(column) - } -} - -impl ListingTable { - /// Get the list of files for a scan as well as the file level statistics. - /// The list is grouped to let the execution plan know how the files should - /// be distributed to different threads / executors. - async fn list_files_for_scan<'a>( - &'a self, - ctx: &'a dyn Session, - filters: &'a [Expr], - limit: Option, - ) -> Result<(Vec, Statistics)> { - let store = if let Some(url) = self.table_paths.first() { - ctx.runtime_env().object_store(url)? - } else { - return Ok((vec![], Statistics::new_unknown(&self.file_schema))); - }; - // list files (with partitions) - let file_list = future::try_join_all(self.table_paths.iter().map(|table_path| { - pruned_partition_list( - ctx, - store.as_ref(), - table_path, - filters, - &self.options.file_extension, - &self.options.table_partition_cols, - ) - })) - .await?; - let meta_fetch_concurrency = - ctx.config_options().execution.meta_fetch_concurrency; - let file_list = stream::iter(file_list).flatten_unordered(meta_fetch_concurrency); - // collect the statistics if required by the config - let files = file_list - .map(|part_file| async { - let part_file = part_file?; - let statistics = if self.options.collect_stat { - self.do_collect_statistics(ctx, &store, &part_file).await? 
- } else { - Arc::new(Statistics::new_unknown(&self.file_schema)) - }; - Ok(part_file.with_statistics(statistics)) - }) - .boxed() - .buffer_unordered(ctx.config_options().execution.meta_fetch_concurrency); - - let (file_group, inexact_stats) = - get_files_with_limit(files, limit, self.options.collect_stat).await?; - - let file_groups = file_group.split_files(self.options.target_partitions); - let (mut file_groups, mut stats) = compute_all_files_statistics( - file_groups, - self.schema(), - self.options.collect_stat, - inexact_stats, - )?; - - let schema_adapter = self.create_schema_adapter(); - let (schema_mapper, _) = schema_adapter.map_schema(self.file_schema.as_ref())?; - - stats.column_statistics = - schema_mapper.map_column_statistics(&stats.column_statistics)?; - file_groups.iter_mut().try_for_each(|file_group| { - if let Some(stat) = file_group.statistics_mut() { - stat.column_statistics = - schema_mapper.map_column_statistics(&stat.column_statistics)?; - } - Ok::<_, DataFusionError>(()) - })?; - Ok((file_groups, stats)) - } - - /// Collects statistics for a given partitioned file. - /// - /// This method first checks if the statistics for the given file are already cached. - /// If they are, it returns the cached statistics. - /// If they are not, it infers the statistics from the file and stores them in the cache. - async fn do_collect_statistics( - &self, - ctx: &dyn Session, - store: &Arc, - part_file: &PartitionedFile, - ) -> Result> { - match self - .collected_statistics - .get_with_extra(&part_file.object_meta.location, &part_file.object_meta) - { - Some(statistics) => Ok(statistics), - None => { - let statistics = self - .options - .format - .infer_stats( - ctx, - store, - Arc::clone(&self.file_schema), - &part_file.object_meta, - ) - .await?; - let statistics = Arc::new(statistics); - self.collected_statistics.put_with_extra( - &part_file.object_meta.location, - Arc::clone(&statistics), - &part_file.object_meta, - ); - Ok(statistics) - } - } - } -} - -/// Processes a stream of partitioned files and returns a `FileGroup` containing the files. -/// -/// This function collects files from the provided stream until either: -/// 1. The stream is exhausted -/// 2. The accumulated number of rows exceeds the provided `limit` (if specified) -/// -/// # Arguments -/// * `files` - A stream of `Result` items to process -/// * `limit` - An optional row count limit. If provided, the function will stop collecting files -/// once the accumulated number of rows exceeds this limit -/// * `collect_stats` - Whether to collect and accumulate statistics from the files -/// -/// # Returns -/// A `Result` containing a `FileGroup` with the collected files -/// and a boolean indicating whether the statistics are inexact. -/// -/// # Note -/// The function will continue processing files if statistics are not available or if the -/// limit is not provided. If `collect_stats` is false, statistics won't be accumulated -/// but files will still be collected. -async fn get_files_with_limit( - files: impl Stream>, - limit: Option, - collect_stats: bool, -) -> Result<(FileGroup, bool)> { - let mut file_group = FileGroup::default(); - // Fusing the stream allows us to call next safely even once it is finished. 
- let mut all_files = Box::pin(files.fuse()); - enum ProcessingState { - ReadingFiles, - ReachedLimit, - } - - let mut state = ProcessingState::ReadingFiles; - let mut num_rows = Precision::Absent; - - while let Some(file_result) = all_files.next().await { - // Early exit if we've already reached our limit - if matches!(state, ProcessingState::ReachedLimit) { - break; - } - - let file = file_result?; - - // Update file statistics regardless of state - if collect_stats { - if let Some(file_stats) = &file.statistics { - num_rows = if file_group.is_empty() { - // For the first file, just take its row count - file_stats.num_rows - } else { - // For subsequent files, accumulate the counts - num_rows.add(&file_stats.num_rows) - }; - } - } - - // Always add the file to our group - file_group.push(file); - - // Check if we've hit the limit (if one was specified) - if let Some(limit) = limit { - if let Precision::Exact(row_count) = num_rows { - if row_count > limit { - state = ProcessingState::ReachedLimit; - } - } - } - } - // If we still have files in the stream, it means that the limit kicked - // in, and the statistic could have been different had we processed the - // files in a different order. - let inexact_stats = all_files.next().await.is_some(); - Ok((file_group, inexact_stats)) } #[cfg(test)] mod tests { - use super::*; #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; + use crate::datasource::listing::table::ListingTableConfigExt; use crate::prelude::*; use crate::{ datasource::{ @@ -1587,21 +123,34 @@ mod tests { }, }; use arrow::{compute::SortOptions, record_batch::RecordBatch}; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; + use datafusion_catalog::TableProvider; + use datafusion_catalog_listing::{ + ListingOptions, ListingTable, ListingTableConfig, SchemaSource, + }; use datafusion_common::{ - assert_contains, + assert_contains, plan_err, stats::Precision, test_util::{batches_to_string, datafusion_test_data}, - ColumnStatistics, ScalarValue, + ColumnStatistics, DataFusionError, Result, ScalarValue, }; + use datafusion_datasource::file_compression_type::FileCompressionType; + use datafusion_datasource::file_format::FileFormat; use datafusion_datasource::schema_adapter::{ SchemaAdapter, SchemaAdapterFactory, SchemaMapper, }; + use datafusion_datasource::ListingTableUrl; + use datafusion_expr::dml::InsertOp; use datafusion_expr::{BinaryExpr, LogicalPlanBuilder, Operator}; use datafusion_physical_expr::expressions::binary; use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_expr_common::sort_expr::LexOrdering; + use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::{collect, ExecutionPlanProperties}; use rstest::rstest; + use std::collections::HashMap; use std::io::Write; + use std::sync::Arc; use tempfile::TempDir; use url::Url; @@ -1638,10 +187,13 @@ mod tests { let ctx = SessionContext::new(); let testdata = datafusion_test_data(); let filename = format!("{testdata}/aggregate_simple.csv"); - let table_path = ListingTableUrl::parse(filename).unwrap(); + let table_path = ListingTableUrl::parse(filename)?; // Test default schema source - let config = ListingTableConfig::new(table_path.clone()); + let format = CsvFormat::default(); + let options = ListingOptions::new(Arc::new(format)); + let config = + ListingTableConfig::new(table_path.clone()).with_listing_options(options); assert_eq!(config.schema_source(), SchemaSource::Unset); // Test schema source after setting a schema explicitly @@ 
-1650,18 +202,13 @@ mod tests { assert_eq!(config_with_schema.schema_source(), SchemaSource::Specified); // Test schema source after inferring schema - let format = CsvFormat::default(); - let options = ListingOptions::new(Arc::new(format)); - let config_with_options = config.with_listing_options(options.clone()); - assert_eq!(config_with_options.schema_source(), SchemaSource::Unset); + assert_eq!(config.schema_source(), SchemaSource::Unset); - let config_with_inferred = config_with_options.infer_schema(&ctx.state()).await?; + let config_with_inferred = config.infer_schema(&ctx.state()).await?; assert_eq!(config_with_inferred.schema_source(), SchemaSource::Inferred); // Test schema preservation through operations - let config_with_schema_and_options = config_with_schema - .clone() - .with_listing_options(options.clone()); + let config_with_schema_and_options = config_with_schema.clone(); assert_eq!( config_with_schema_and_options.schema_source(), SchemaSource::Specified @@ -1836,7 +383,7 @@ mod tests { .with_table_partition_cols(vec![(String::from("p1"), DataType::Utf8)]) .with_target_partitions(4); - let table_path = ListingTableUrl::parse("test:///table/").unwrap(); + let table_path = ListingTableUrl::parse("test:///table/")?; let file_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, false)])); let config = ListingTableConfig::new(table_path) @@ -1872,7 +419,7 @@ mod tests { ) -> Result> { let testdata = crate::test_util::parquet_test_data(); let filename = format!("{testdata}/{name}"); - let table_path = ListingTableUrl::parse(filename).unwrap(); + let table_path = ListingTableUrl::parse(filename)?; let config = ListingTableConfig::new(table_path) .infer(&ctx.state()) @@ -1899,7 +446,7 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); - let table_path = ListingTableUrl::parse(table_prefix).unwrap(); + let table_path = ListingTableUrl::parse(table_prefix)?; let config = ListingTableConfig::new(table_path) .with_listing_options(opt) .with_schema(Arc::new(schema)); @@ -2458,7 +1005,7 @@ mod tests { async fn test_infer_options_compressed_csv() -> Result<()> { let testdata = crate::test_util::arrow_test_data(); let filename = format!("{testdata}/csv/aggregate_test_100.csv.gz"); - let table_path = ListingTableUrl::parse(filename).unwrap(); + let table_path = ListingTableUrl::parse(filename)?; let ctx = SessionContext::new(); @@ -2479,12 +1026,15 @@ mod tests { let testdata = datafusion_test_data(); let filename = format!("{testdata}/aggregate_simple.csv"); - let table_path = ListingTableUrl::parse(filename).unwrap(); + let table_path = ListingTableUrl::parse(filename)?; let provided_schema = create_test_schema(); - let config = - ListingTableConfig::new(table_path).with_schema(Arc::clone(&provided_schema)); + let format = CsvFormat::default(); + let options = ListingOptions::new(Arc::new(format)); + let config = ListingTableConfig::new(table_path) + .with_listing_options(options) + .with_schema(Arc::clone(&provided_schema)); let config = config.infer(&ctx.state()).await?; @@ -2549,8 +1099,8 @@ mod tests { table_path1.clone(), table_path2.clone(), ]) - .with_schema(schema_3cols) - .with_listing_options(options.clone()); + .with_listing_options(options.clone()) + .with_schema(schema_3cols); let config2 = config2.infer_schema(&ctx.state()).await?; assert_eq!(config2.schema_source(), SchemaSource::Specified); @@ -2573,8 +1123,8 @@ mod tests { table_path1.clone(), table_path2.clone(), ]) - .with_schema(schema_4cols) - 
.with_listing_options(options.clone()); + .with_listing_options(options.clone()) + .with_schema(schema_4cols); let config3 = config3.infer_schema(&ctx.state()).await?; assert_eq!(config3.schema_source(), SchemaSource::Specified); @@ -2732,6 +1282,52 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_listing_table_prunes_extra_files_in_hive() -> Result<()> { + let files = [ + "bucket/test/pid=1/file1", + "bucket/test/pid=1/file2", + "bucket/test/pid=2/file3", + "bucket/test/pid=2/file4", + "bucket/test/other/file5", + ]; + + let ctx = SessionContext::new(); + register_test_store(&ctx, &files.iter().map(|f| (*f, 10)).collect::>()); + + let opt = ListingOptions::new(Arc::new(JsonFormat::default())) + .with_file_extension_opt(Some("")) + .with_table_partition_cols(vec![("pid".to_string(), DataType::Int32)]); + + let table_path = ListingTableUrl::parse("test:///bucket/test/").unwrap(); + let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); + let config = ListingTableConfig::new(table_path) + .with_listing_options(opt) + .with_schema(Arc::new(schema)); + + let table = ListingTable::try_new(config)?; + + let (file_list, _) = table.list_files_for_scan(&ctx.state(), &[], None).await?; + assert_eq!(file_list.len(), 1); + + let files = file_list[0].clone(); + + assert_eq!( + files + .iter() + .map(|f| f.path().to_string()) + .collect::>(), + vec![ + "bucket/test/pid=1/file1", + "bucket/test/pid=1/file2", + "bucket/test/pid=2/file3", + "bucket/test/pid=2/file4", + ] + ); + + Ok(()) + } + #[cfg(feature = "parquet")] #[tokio::test] async fn test_table_stats_behaviors() -> Result<()> { @@ -2739,7 +1335,7 @@ mod tests { let testdata = crate::test_util::parquet_test_data(); let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); - let table_path = ListingTableUrl::parse(filename).unwrap(); + let table_path = ListingTableUrl::parse(filename)?; let ctx = SessionContext::new(); let state = ctx.state(); @@ -2750,6 +1346,7 @@ mod tests { let config_default = ListingTableConfig::new(table_path.clone()) .with_listing_options(opt_default) .with_schema(schema_default); + let table_default = ListingTable::try_new(config_default)?; let exec_default = table_default.scan(&state, None, &[], None).await?; @@ -2885,7 +1482,7 @@ mod tests { let format = JsonFormat::default(); let opt = ListingOptions::new(Arc::new(format)).with_collect_stat(false); let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); - let table_path = ListingTableUrl::parse("test:///table/").unwrap(); + let table_path = ListingTableUrl::parse("test:///table/")?; let config = ListingTableConfig::new(table_path) .with_listing_options(opt) @@ -3099,7 +1696,7 @@ mod tests { let format = JsonFormat::default(); let opt = ListingOptions::new(Arc::new(format)).with_collect_stat(collect_stat); let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); - let table_path = ListingTableUrl::parse("test:///table/").unwrap(); + let table_path = ListingTableUrl::parse("test:///table/")?; let config = ListingTableConfig::new(table_path) .with_listing_options(opt) diff --git a/datafusion/core/src/datasource/physical_plan/arrow.rs b/datafusion/core/src/datasource/physical_plan/arrow.rs new file mode 100644 index 000000000000..392eaa8c4be4 --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/arrow.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Reexports the [`datafusion_datasource_arrow::source`] module, containing [Arrow] based [`FileSource`]. +//! +//! [Arrow]: https://arrow.apache.org/docs/python/ipc.html +//! [`FileSource`]: datafusion_datasource::file::FileSource + +pub use datafusion_datasource_arrow::source::*; diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index e33761a0abb3..b2ef51a76f89 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -121,7 +121,7 @@ mod tests { .with_projection(Some(vec![0, 2, 4])) .build(); - assert_eq!(13, config.file_schema.fields().len()); + assert_eq!(13, config.file_schema().fields().len()); let csv = DataSourceExec::from_data_source(config); assert_eq!(3, csv.schema().fields().len()); @@ -185,7 +185,7 @@ mod tests { .with_file_compression_type(file_compression_type.to_owned()) .with_projection(Some(vec![4, 0, 2])) .build(); - assert_eq!(13, config.file_schema.fields().len()); + assert_eq!(13, config.file_schema().fields().len()); let csv = DataSourceExec::from_data_source(config); assert_eq!(3, csv.schema().fields().len()); @@ -250,7 +250,7 @@ mod tests { .with_file_compression_type(file_compression_type.to_owned()) .with_limit(Some(5)) .build(); - assert_eq!(13, config.file_schema.fields().len()); + assert_eq!(13, config.file_schema().fields().len()); let csv = DataSourceExec::from_data_source(config); assert_eq!(13, csv.schema().fields().len()); @@ -313,7 +313,7 @@ mod tests { .with_file_compression_type(file_compression_type.to_owned()) .with_limit(Some(5)) .build(); - assert_eq!(14, config.file_schema.fields().len()); + assert_eq!(14, config.file_schema().fields().len()); let csv = DataSourceExec::from_data_source(config); assert_eq!(14, csv.schema().fields().len()); @@ -349,7 +349,7 @@ mod tests { let filename = "aggregate_test_100.csv"; let tmp_dir = TempDir::new()?; - let file_groups = partitioned_file_groups( + let mut file_groups = partitioned_file_groups( path.as_str(), filename, 1, @@ -357,30 +357,29 @@ mod tests { file_compression_type.to_owned(), tmp_dir.path(), )?; + // Add partition columns / values + file_groups[0][0].partition_values = vec![ScalarValue::from("2021-10-26")]; + + let num_file_schema_fields = file_schema.fields().len(); let source = Arc::new(CsvSource::new(true, b',', b'"')); - let mut config = FileScanConfigBuilder::from(partitioned_csv_config( + let config = FileScanConfigBuilder::from(partitioned_csv_config( file_schema, file_groups, source, )) .with_newlines_in_values(false) .with_file_compression_type(file_compression_type.to_owned()) - .build(); - - // Add partition columns - config.table_partition_cols = - vec![Arc::new(Field::new("date", DataType::Utf8, false))]; - config.file_groups[0][0].partition_values = 
vec![ScalarValue::from("2021-10-26")]; - + .with_table_partition_cols(vec![Field::new("date", DataType::Utf8, false)]) // We should be able to project on the partition column // Which is supposed to be after the file fields - config.projection = Some(vec![0, config.file_schema.fields().len()]); + .with_projection(Some(vec![0, num_file_schema_fields])) + .build(); // we don't have `/date=xx/` in the path but that is ok because // partitions are resolved during scan anyway - assert_eq!(13, config.file_schema.fields().len()); + assert_eq!(13, config.file_schema().fields().len()); let csv = DataSourceExec::from_data_source(config); assert_eq!(2, csv.schema().fields().len()); diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 3a9dedaa028f..1ac292e260fd 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -17,7 +17,7 @@ //! Execution plans that read file formats -mod arrow_file; +pub mod arrow; pub mod csv; pub mod json; @@ -35,10 +35,9 @@ pub use datafusion_datasource_parquet::source::ParquetSource; #[cfg(feature = "parquet")] pub use datafusion_datasource_parquet::{ParquetFileMetrics, ParquetFileReaderFactory}; -pub use arrow_file::ArrowSource; - pub use json::{JsonOpener, JsonSource}; +pub use arrow::{ArrowOpener, ArrowSource}; pub use csv::{CsvOpener, CsvSource}; pub use datafusion_datasource::file::FileSource; pub use datafusion_datasource::file_groups::FileGroup; diff --git a/datafusion/core/src/datasource/physical_plan/parquet.rs b/datafusion/core/src/datasource/physical_plan/parquet.rs index d0774e57174e..10a475c1cc9a 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet.rs @@ -64,7 +64,9 @@ mod tests { use datafusion_physical_expr::planner::logical2physical; use datafusion_physical_plan::analyze::AnalyzeExec; use datafusion_physical_plan::collect; - use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; + use datafusion_physical_plan::metrics::{ + ExecutionPlanMetricsSet, MetricType, MetricsSet, + }; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use chrono::{TimeZone, Utc}; @@ -238,6 +240,7 @@ mod tests { let analyze_exec = Arc::new(AnalyzeExec::new( false, false, + vec![MetricType::SUMMARY, MetricType::DEV], // use a new ParquetSource to avoid sharing execution metrics self.build_parquet_exec( Arc::clone(table_schema), diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index a8148b80495e..448ee5264afd 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -64,12 +64,13 @@ use datafusion_catalog::{ DynamicFileCatalog, TableFunction, TableFunctionImpl, UrlTableFactory, }; use datafusion_common::config::ConfigOptions; +use datafusion_common::metadata::ScalarAndMetadata; use datafusion_common::{ config::{ConfigExtension, TableOptions}, exec_datafusion_err, exec_err, internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, tree_node::{TreeNodeRecursion, TreeNodeVisitor}, - DFSchema, DataFusionError, ParamValues, ScalarValue, SchemaReference, TableReference, + DFSchema, DataFusionError, ParamValues, SchemaReference, TableReference, }; pub use datafusion_execution::config::SessionConfig; use datafusion_execution::registry::SerializerRegistry; @@ -505,6 +506,13 @@ impl SessionContext { 
self.runtime_env().register_object_store(url, object_store) } + /// Deregisters an [`ObjectStore`] associated with the specific URL prefix. + /// + /// See [`RuntimeEnv::deregister_object_store`] for more details. + pub fn deregister_object_store(&self, url: &Url) -> Result> { + self.runtime_env().deregister_object_store(url) + } + /// Registers the [`RecordBatch`] as the specified table name pub fn register_batch( &self, @@ -708,15 +716,15 @@ impl SessionContext { LogicalPlan::Statement(Statement::Prepare(Prepare { name, input, - data_types, + fields, })) => { // The number of parameters must match the specified data types length. - if !data_types.is_empty() { + if !fields.is_empty() { let param_names = input.get_parameter_names()?; - if param_names.len() != data_types.len() { + if param_names.len() != fields.len() { return plan_err!( "Prepare specifies {} data types but query has {} parameters", - data_types.len(), + fields.len(), param_names.len() ); } @@ -726,7 +734,7 @@ impl SessionContext { // not currently feasible. This is because `now()` would be optimized to a // constant value, causing each EXECUTE to yield the same result, which is // incorrect behavior. - self.state.write().store_prepared(name, data_types, input)?; + self.state.write().store_prepared(name, fields, input)?; self.return_empty_dataframe() } LogicalPlan::Statement(Statement::Execute(execute)) => { @@ -1072,6 +1080,26 @@ impl SessionContext { } else { let mut state = self.state.write(); state.config_mut().options_mut().set(&variable, &value)?; + + // Re-initialize any UDFs that depend on configuration + // This allows both built-in and custom functions to respond to configuration changes + let config_options = state.config().options(); + + // Collect updated UDFs in a separate vector + let udfs_to_update: Vec<_> = state + .scalar_functions() + .values() + .filter_map(|udf| { + udf.inner() + .with_updated_config(config_options) + .map(Arc::new) + }) + .collect(); + + for udf in udfs_to_update { + state.register_udf(udf)?; + } + drop(state); } @@ -1238,28 +1266,30 @@ impl SessionContext { })?; // Only allow literals as parameters for now. - let mut params: Vec = parameters + let mut params: Vec = parameters .into_iter() .map(|e| match e { - Expr::Literal(scalar, _) => Ok(scalar), + Expr::Literal(scalar, metadata) => { + Ok(ScalarAndMetadata::new(scalar, metadata)) + } _ => not_impl_err!("Unsupported parameter type: {}", e), }) .collect::>()?; // If the prepared statement provides data types, cast the params to those types. 
- if !prepared.data_types.is_empty() { - if params.len() != prepared.data_types.len() { + if !prepared.fields.is_empty() { + if params.len() != prepared.fields.len() { return exec_err!( "Prepared statement '{}' expects {} parameters, but {} provided", name, - prepared.data_types.len(), + prepared.fields.len(), params.len() ); } params = params .into_iter() - .zip(prepared.data_types.iter()) - .map(|(e, dt)| e.cast_to(dt)) + .zip(prepared.fields.iter()) + .map(|(e, dt)| -> Result<_> { e.cast_storage_to(dt.data_type()) }) .collect::>()?; } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index b04004dd495c..561e0c363a37 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -30,15 +30,14 @@ use crate::datasource::provider_as_source; use crate::execution::context::{EmptySerializerRegistry, FunctionFactory, QueryPlanner}; use crate::execution::SessionStateDefaults; use crate::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; +use arrow_schema::{DataType, FieldRef}; use datafusion_catalog::information_schema::{ InformationSchemaProvider, INFORMATION_SCHEMA, }; - -use arrow::datatypes::DataType; use datafusion_catalog::MemoryCatalogProviderList; use datafusion_catalog::{TableFunction, TableFunctionImpl}; use datafusion_common::alias::AliasGenerator; -use datafusion_common::config::{ConfigExtension, ConfigOptions, TableOptions}; +use datafusion_common::config::{ConfigExtension, ConfigOptions, Dialect, TableOptions}; use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; use datafusion_common::tree_node::TreeNode; use datafusion_common::{ @@ -116,11 +115,11 @@ use uuid::Uuid; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let state = SessionStateBuilder::new() -/// .with_config(SessionConfig::new()) +/// .with_config(SessionConfig::new()) /// .with_runtime_env(Arc::new(RuntimeEnv::default())) /// .with_default_features() /// .build(); -/// Ok(()) +/// Ok(()) /// # } /// ``` /// @@ -374,7 +373,7 @@ impl SessionState { pub fn sql_to_statement( &self, sql: &str, - dialect: &str, + dialect: &Dialect, ) -> datafusion_common::Result { let dialect = dialect_from_str(dialect).ok_or_else(|| { plan_datafusion_err!( @@ -411,7 +410,7 @@ impl SessionState { pub fn sql_to_expr( &self, sql: &str, - dialect: &str, + dialect: &Dialect, ) -> datafusion_common::Result { self.sql_to_expr_with_alias(sql, dialect).map(|x| x.expr) } @@ -423,7 +422,7 @@ impl SessionState { pub fn sql_to_expr_with_alias( &self, sql: &str, - dialect: &str, + dialect: &Dialect, ) -> datafusion_common::Result { let dialect = dialect_from_str(dialect).ok_or_else(|| { plan_datafusion_err!( @@ -527,8 +526,8 @@ impl SessionState { &self, sql: &str, ) -> datafusion_common::Result { - let dialect = self.config.options().sql_parser.dialect.as_str(); - let statement = self.sql_to_statement(sql, dialect)?; + let dialect = self.config.options().sql_parser.dialect; + let statement = self.sql_to_statement(sql, &dialect)?; let plan = self.statement_to_plan(statement).await?; Ok(plan) } @@ -542,9 +541,9 @@ impl SessionState { sql: &str, df_schema: &DFSchema, ) -> datafusion_common::Result { - let dialect = self.config.options().sql_parser.dialect.as_str(); + let dialect = self.config.options().sql_parser.dialect; - let sql_expr = self.sql_to_expr_with_alias(sql, dialect)?; + let sql_expr = self.sql_to_expr_with_alias(sql, &dialect)?; let provider = SessionContextProvider { state: self, 
@@ -873,12 +872,12 @@ impl SessionState { pub(crate) fn store_prepared( &mut self, name: String, - data_types: Vec, + fields: Vec, plan: Arc, ) -> datafusion_common::Result<()> { match self.prepared_plans.entry(name) { Entry::Vacant(e) => { - e.insert(Arc::new(PreparedPlan { data_types, plan })); + e.insert(Arc::new(PreparedPlan { fields, plan })); Ok(()) } Entry::Occupied(e) => { @@ -1323,7 +1322,7 @@ impl SessionStateBuilder { /// let url = Url::try_from("file://").unwrap(); /// let object_store = object_store::local::LocalFileSystem::new(); /// let state = SessionStateBuilder::new() - /// .with_config(SessionConfig::new()) + /// .with_config(SessionConfig::new()) /// .with_object_store(&url, Arc::new(object_store)) /// .with_default_features() /// .build(); @@ -1419,12 +1418,31 @@ impl SessionStateBuilder { } if let Some(scalar_functions) = scalar_functions { - scalar_functions.into_iter().for_each(|udf| { - let existing_udf = state.register_udf(udf); - if let Ok(Some(existing_udf)) = existing_udf { - debug!("Overwrote an existing UDF: {}", existing_udf.name()); + for udf in scalar_functions { + let config_options = state.config().options(); + match udf.inner().with_updated_config(config_options) { + Some(new_udf) => { + if let Err(err) = state.register_udf(Arc::new(new_udf)) { + debug!( + "Failed to re-register updated UDF '{}': {}", + udf.name(), + err + ); + } + } + None => match state.register_udf(Arc::clone(&udf)) { + Ok(Some(existing)) => { + debug!("Overwrote existing UDF '{}'", existing.name()); + } + Ok(None) => { + debug!("Registered UDF '{}'", udf.name()); + } + Err(err) => { + debug!("Failed to register UDF '{}': {}", udf.name(), err); + } + }, } - }); + } } if let Some(aggregate_functions) = aggregate_functions { @@ -2012,7 +2030,7 @@ impl SimplifyInfo for SessionSimplifyProvider<'_> { #[derive(Debug)] pub(crate) struct PreparedPlan { /// Data types of the parameters - pub(crate) data_types: Vec, + pub(crate) fields: Vec, /// The prepared logical plan pub(crate) plan: Arc, } @@ -2034,6 +2052,7 @@ mod tests { use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_catalog::MemoryCatalogProviderList; + use datafusion_common::config::Dialect; use datafusion_common::DFSchema; use datafusion_common::Result; use datafusion_execution::config::SessionConfig; @@ -2059,8 +2078,8 @@ mod tests { let sql = "[1,2,3]"; let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); let df_schema = DFSchema::try_from(schema)?; - let dialect = state.config.options().sql_parser.dialect.as_str(); - let sql_expr = state.sql_to_expr(sql, dialect)?; + let dialect = state.config.options().sql_parser.dialect; + let sql_expr = state.sql_to_expr(sql, &dialect)?; let query = SqlToRel::new_with_options(&provider, state.get_parser_options()); query.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new()) @@ -2218,7 +2237,8 @@ mod tests { } let state = &context_provider.state; - let statement = state.sql_to_statement("select count(*) from t", "mysql")?; + let statement = + state.sql_to_statement("select count(*) from t", &Dialect::MySQL)?; let plan = SqlToRel::new(&context_provider).statement_to_plan(statement)?; state.create_physical_plan(&plan).await } diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index e7ace544a11c..78db28eaacc7 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -443,7 +443,30 @@ //! 
other operators read a single [`RecordBatch`] from their input to produce a //! single [`RecordBatch`] as output. //! -//! For example, given this SQL query: +//! For example, given this SQL: +//! +//! ```sql +//! SELECT name FROM 'data.parquet' WHERE id > 10 +//! ``` +//! +//! A simplified DataFusion execution plan is shown below. It first reads +//! data from the Parquet file, then applies the filter, then the projection, +//! and finally produces output. Each step processes one [`RecordBatch`] at a +//! time. Multiple batches are processed concurrently on different CPU cores +//! for plans with multiple partitions. +//! +//! ```text +//! ┌─────────────┐ ┌──────────────┐ ┌────────────────┐ ┌──────────────────┐ ┌──────────┐ +//! │ Parquet │───▶│ DataSource │───▶│ FilterExec │───▶│ ProjectionExec │───▶│ Results │ +//! │ File │ │ │ │ │ │ │ │ │ +//! └─────────────┘ └──────────────┘ └────────────────┘ └──────────────────┘ └──────────┘ +//! (reads data) (id > 10) (keeps "name" col) +//! RecordBatch ───▶ RecordBatch ────▶ RecordBatch ────▶ RecordBatch +//! ``` +//! +//! DataFusion uses the classic "pull" based control flow (explained more in the +//! next section) to implement streaming execution. As an example, +//! consider the following SQL query: //! //! ```sql //! SELECT date_trunc('month', time) FROM data WHERE id IN (10,20,30);
@@ -897,6 +920,12 @@ doc_comment::doctest!("../../../README.md", readme_example_test); // For example, if `user_guide_expressions(line 123)` fails, // go to `docs/source/user-guide/expressions.md` to find the relevant problem. // +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/arrow-introduction.md", + user_guide_arrow_introduction +); + #[cfg(doctest)] doc_comment::doctest!( "../../../docs/source/user-guide/concepts-readings-events.md",
diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index c28e56790e66..c280b50a9f07 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs
@@ -62,6 +62,7 @@ use arrow::compute::SortOptions; use arrow::datatypes::Schema; use datafusion_catalog::ScanArgs; use datafusion_common::display::ToStringifiedPlan; +use datafusion_common::format::ExplainAnalyzeLevel; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; use datafusion_common::TableReference; use datafusion_common::{
@@ -77,10 +78,11 @@ use datafusion_expr::expr::{ }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; +use datafusion_expr::utils::split_conjunction; use datafusion_expr::{ - Analyze, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, FetchType, - Filter, JoinType, RecursiveQuery, SkipType, StringifiedPlan, WindowFrame, - WindowFrameBound, WriteOp, + Analyze, BinaryExpr, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, + FetchType, Filter, JoinType, Operator, RecursiveQuery, SkipType, StringifiedPlan, + WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::Literal;
@@ -90,6 +92,8 @@ use datafusion_physical_expr::{ }; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::execution_plan::InvariantLevel; +use datafusion_physical_plan::joins::PiecewiseMergeJoinExec; +use datafusion_physical_plan::metrics::MetricType;
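// A minimal sketch of the "pull" based streaming execution described in the
// lib.rs documentation added above: the caller repeatedly polls the output
// stream and each poll pulls one RecordBatch through the plan. The file name
// `data.parquet`, the `id`/`name` columns, and the use of `#[tokio::main]`
// are illustrative assumptions for this sketch; only the SessionContext and
// DataFrame calls shown are existing DataFusion APIs.
use datafusion::error::Result;
use datafusion::prelude::{ParquetReadOptions, SessionContext};
use futures::StreamExt;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Assumes a local Parquet file containing `id` and `name` columns
    ctx.register_parquet("data", "data.parquet", ParquetReadOptions::default())
        .await?;

    let df = ctx.sql("SELECT name FROM data WHERE id > 10").await?;

    // `execute_stream` returns a stream of RecordBatches; each `next().await`
    // pulls one batch through DataSource -> FilterExec -> ProjectionExec,
    // matching the streaming model sketched in the diagram above
    let mut stream = df.execute_stream().await?;
    while let Some(batch) = stream.next().await {
        println!("got {} rows", batch?.num_rows());
    }
    Ok(())
}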
use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::recursive_query::RecursiveQueryExec; use datafusion_physical_plan::unnest::ListUnnest;
@@ -985,7 +989,7 @@ impl DefaultPhysicalPlanner { struct_type_columns.clone(), schema, options.clone(), - )) + )?) } // 2 Children
@@ -1131,8 +1135,42 @@ impl DefaultPhysicalPlanner { }) .collect::>()?; + // TODO: `num_range_filters` can be used later on for ASOF joins (`num_range_filters > 1`) + let mut num_range_filters = 0; + let mut range_filters: Vec<Expr> = Vec::new(); + let mut total_filters = 0; + + let join_filter = match filter { Some(expr) => { + let split_expr = split_conjunction(expr); + for expr in split_expr.iter() { + match *expr { + Expr::BinaryExpr(BinaryExpr { + left: _, + right: _, + op, + }) => { + if matches!( + op, + Operator::Lt + | Operator::LtEq + | Operator::Gt + | Operator::GtEq + ) { + range_filters.push((**expr).clone()); + num_range_filters += 1; + } + total_filters += 1; + } + // TODO: Handle `Expr::Between` for IEJoins; it counts as two range predicates, + // which is why it is not handled by PWMJ + // Expr::Between(_) => {}, + _ => { + total_filters += 1; + } + } + } + + // Extract columns from filter expression and saved in a HashSet let cols = expr.column_refs();
@@ -1188,6 +1226,7 @@ impl DefaultPhysicalPlanner { )?; let filter_schema = Schema::new_with_metadata(filter_fields, metadata); + let filter_expr = create_physical_expr( expr, &filter_df_schema,
@@ -1210,10 +1249,125 @@ impl DefaultPhysicalPlanner { let prefer_hash_join = session_state.config_options().optimizer.prefer_hash_join; + // TODO: Allow PWMJ to deal with residual equijoin conditions let join: Arc<dyn ExecutionPlan> = if join_on.is_empty() { if join_filter.is_none() && matches!(join_type, JoinType::Inner) { // cross join if there is no join conditions and no join filter set Arc::new(CrossJoinExec::new(physical_left, physical_right)) + } else if num_range_filters == 1 + && total_filters == 1 + && !matches!( + join_type, + JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::RightMark + ) + && session_state + .config_options() + .optimizer + .enable_piecewise_merge_join + { + let Expr::BinaryExpr(be) = &range_filters[0] else { + return plan_err!( + "Unsupported expression for PWMJ: Expected `Expr::BinaryExpr`" + ); + }; + + let mut op = be.op; + if !matches!( + op, + Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq + ) { + return plan_err!( + "Unsupported operator for PWMJ: {:?}. Expected one of <, <=, >, >=", + op + ); + } + + fn reverse_ineq(op: Operator) -> Operator { + match op { + Operator::Lt => Operator::Gt, + Operator::LtEq => Operator::GtEq, + Operator::Gt => Operator::Lt, + Operator::GtEq => Operator::LtEq, + _ => op, + } + } + + #[derive(Clone, Copy, Debug, PartialEq, Eq)] + enum Side { + Left, + Right, + Both, + } + + let side_of = |e: &Expr| -> Result<Side> { + let cols = e.column_refs(); + let any_left = cols + .iter() + .any(|c| left_df_schema.index_of_column(c).is_ok()); + let any_right = cols + .iter() + .any(|c| right_df_schema.index_of_column(c).is_ok()); + + Ok(match (any_left, any_right) { + (true, false) => Side::Left, + (false, true) => Side::Right, + (true, true) => Side::Both, + _ => unreachable!(), + }) + }; + + let mut lhs_logical = &be.left; + let mut rhs_logical = &be.right; + + let left_side = side_of(lhs_logical)?; + let right_side = side_of(rhs_logical)?; + if matches!(left_side, Side::Both) + || matches!(right_side, Side::Both) + { + return Ok(Arc::new(NestedLoopJoinExec::try_new( + physical_left, + physical_right, + join_filter, + join_type, + None, + )?)); + } + + if left_side == Side::Right && right_side == Side::Left { + std::mem::swap(&mut lhs_logical, &mut rhs_logical); + op = reverse_ineq(op); + } else if !(left_side == Side::Left && right_side == Side::Right) + { + return plan_err!( + "Unsupported range predicate for PWMJ: each side of the comparison must reference columns from only one join input" + ); + } + + let on_left = create_physical_expr( + lhs_logical, + left_df_schema, + session_state.execution_props(), + )?; + let on_right = create_physical_expr( + rhs_logical, + right_df_schema, + session_state.execution_props(), + )?; + + Arc::new(PiecewiseMergeJoinExec::try_new( + physical_left, + physical_right, + (on_left, on_right), + op, + *join_type, + session_state.config().target_partitions(), + )?)
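// A standalone sketch of the operand normalization performed above, using the
// hypothetical columns `l.a` (from the left input) and `r.b` (from the right
// input): when the single range predicate is written as `r.b < l.a`, the planner
// swaps the operands and reverses the inequality so that the left-hand expression
// always comes from the left input before building the PiecewiseMergeJoin keys.
use datafusion_expr::{col, Expr, Operator};

fn reverse_ineq(op: Operator) -> Operator {
    match op {
        Operator::Lt => Operator::Gt,
        Operator::LtEq => Operator::GtEq,
        Operator::Gt => Operator::Lt,
        Operator::GtEq => Operator::LtEq,
        _ => op,
    }
}

fn main() {
    // Predicate as written in the query: r.b < l.a
    let predicate = col("r.b").lt(col("l.a"));
    let Expr::BinaryExpr(be) = predicate else {
        unreachable!("comparison builders always produce Expr::BinaryExpr")
    };
    // The left operand references the right join input, so swap and reverse:
    // `r.b < l.a` becomes `l.a > r.b`
    let (lhs, rhs, op) = (be.right, be.left, reverse_ineq(be.op));
    assert_eq!(op, Operator::Gt);
    println!("normalized range predicate: {lhs} {op} {rhs}");
}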
} else { // there is no equal join condition, use the nested loop join Arc::new(NestedLoopJoinExec::try_new( @@ -2073,9 +2227,15 @@ impl DefaultPhysicalPlanner { let input = self.create_physical_plan(&a.input, session_state).await?; let schema = Arc::clone(a.schema.inner()); let show_statistics = session_state.config_options().explain.show_statistics; + let analyze_level = session_state.config_options().explain.analyze_level; + let metric_types = match analyze_level { + ExplainAnalyzeLevel::Summary => vec![MetricType::SUMMARY], + ExplainAnalyzeLevel::Dev => vec![MetricType::SUMMARY, MetricType::DEV], + }; Ok(Arc::new(AnalyzeExec::new( a.verbose, show_statistics, + metric_types, input, schema, ))) @@ -2484,7 +2644,7 @@ mod tests { // verify that the plan correctly casts u8 to i64 // the cast from u8 to i64 for literal will be simplified, and get lit(int64(5)) // the cast here is implicit so has CastOptions with safe=true - let expected = r#"BinaryExpr { left: Column { name: "c7", index: 2 }, op: Lt, right: Literal { value: Int64(5), field: Field { name: "lit", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }"#; + let expected = r#"BinaryExpr { left: Column { name: "c7", index: 2 }, op: Lt, right: Literal { value: Int64(5), field: Field { name: "lit", data_type: Int64 } }, fail_on_overflow: false"#; assert_contains!(format!("{exec_plan:?}"), expected); Ok(()) @@ -2544,9 +2704,6 @@ mod tests { name: "lit", data_type: Utf8, nullable: true, - dict_id: 0, - dict_is_ordered: false, - metadata: {}, }, }, "c1", @@ -2558,9 +2715,6 @@ mod tests { name: "lit", data_type: Int64, nullable: true, - dict_id: 0, - dict_is_ordered: false, - metadata: {}, }, }, "c2", @@ -2572,9 +2726,6 @@ mod tests { name: "lit", data_type: Int64, nullable: true, - dict_id: 0, - dict_is_ordered: false, - metadata: {}, }, }, "c3", @@ -2683,9 +2834,6 @@ mod tests { name: "lit", data_type: Utf8, nullable: true, - dict_id: 0, - dict_is_ordered: false, - metadata: {}, }, }, "c1", @@ -2697,9 +2845,6 @@ mod tests { name: "lit", data_type: Int64, nullable: true, - dict_id: 0, - dict_is_ordered: false, - metadata: {}, }, }, "c2", @@ -2711,9 +2856,6 @@ mod tests { name: "lit", data_type: Int64, nullable: true, - dict_id: 0, - dict_is_ordered: false, - metadata: {}, }, }, "c3", @@ -2887,7 +3029,7 @@ mod tests { .expect_err("planning error") .strip_backtrace(); - insta::assert_snapshot!(e, @r#"Error during planning: Extension planner for NoOp created an ExecutionPlan with mismatched schema. LogicalPlan schema: DFSchema { inner: Schema { fields: [Field { name: "a", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }], metadata: {} }, field_qualifiers: [None], functional_dependencies: FunctionalDependencies { deps: [] } }, ExecutionPlan schema: Schema { fields: [Field { name: "b", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }], metadata: {} }"#); + insta::assert_snapshot!(e, @r#"Error during planning: Extension planner for NoOp created an ExecutionPlan with mismatched schema. 
LogicalPlan schema: DFSchema { inner: Schema { fields: [Field { name: "a", data_type: Int32 }], metadata: {} }, field_qualifiers: [None], functional_dependencies: FunctionalDependencies { deps: [] } }, ExecutionPlan schema: Schema { fields: [Field { name: "b", data_type: Int32 }], metadata: {} }"#); } #[tokio::test] @@ -2903,7 +3045,7 @@ mod tests { let execution_plan = plan(&logical_plan).await?; // verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated. - let expected = "expr: [ProjectionExpr { expr: BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }, fail_on_overflow: false }"; + let expected = r#"expr: BinaryExpr { left: BinaryExpr { left: Column { name: "c1", index: 0 }, op: Eq, right: Literal { value: Utf8("a"), field: Field { name: "lit", data_type: Utf8 } }, fail_on_overflow: false }"#; assert_contains!(format!("{execution_plan:?}"), expected); @@ -2925,7 +3067,7 @@ mod tests { assert_contains!( &e, - r#"Error during planning: Can not find compatible types to compare Boolean with [Struct(foo Boolean), Utf8]"# + r#"Error during planning: Can not find compatible types to compare Boolean with [Struct("foo": Boolean), Utf8]"# ); Ok(()) diff --git a/datafusion/core/tests/catalog/memory.rs b/datafusion/core/tests/catalog/memory.rs index ea9e71fc3746..06ed141b2e8b 100644 --- a/datafusion/core/tests/catalog/memory.rs +++ b/datafusion/core/tests/catalog/memory.rs @@ -19,7 +19,7 @@ use arrow::datatypes::Schema; use datafusion::catalog::CatalogProvider; use datafusion::datasource::empty::EmptyTable; use datafusion::datasource::listing::{ - ListingTable, ListingTableConfig, ListingTableUrl, + ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, }; use datafusion::prelude::SessionContext; use datafusion_catalog::memory::*; diff --git a/datafusion/core/tests/core_integration.rs b/datafusion/core/tests/core_integration.rs index e37a368f0771..edcf039e4e70 100644 --- a/datafusion/core/tests/core_integration.rs +++ b/datafusion/core/tests/core_integration.rs @@ -21,6 +21,9 @@ mod sql; /// Run all tests that are found in the `dataframe` directory mod dataframe; +/// Run all tests that are found in the `datasource` directory +mod datasource; + /// Run all tests that are found in the `macro_hygiene` directory mod macro_hygiene; diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index b664fccdfa80..265862ff9af8 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -274,6 +274,33 @@ async fn test_nvl2() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_nvl2_short_circuit() -> Result<()> { + let expr = nvl2( + col("a"), + arrow_cast(lit("1"), lit("Int32")), + arrow_cast(col("a"), lit("Int32")), + ); + + let batches = get_batches(expr).await?; + + assert_snapshot!( + batches_to_string(&batches), + @r#" + +-----------------------------------------------------------------------------------+ + | 
nvl2(test.a,arrow_cast(Utf8("1"),Utf8("Int32")),arrow_cast(test.a,Utf8("Int32"))) | + +-----------------------------------------------------------------------------------+ + | 1 | + | 1 | + | 1 | + | 1 | + +-----------------------------------------------------------------------------------+ + "# + ); + + Ok(()) +} #[tokio::test] async fn test_fn_arrow_typeof() -> Result<()> { let expr = arrow_typeof(col("l")); @@ -282,16 +309,16 @@ async fn test_fn_arrow_typeof() -> Result<()> { assert_snapshot!( batches_to_string(&batches), - @r#" - +------------------------------------------------------------------------------------------------------------------+ - | arrow_typeof(test.l) | - +------------------------------------------------------------------------------------------------------------------+ - | List(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) | - | List(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) | - | List(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) | - | List(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) | - +------------------------------------------------------------------------------------------------------------------+ - "#); + @r" + +----------------------+ + | arrow_typeof(test.l) | + +----------------------+ + | List(nullable Int32) | + | List(nullable Int32) | + | List(nullable Int32) | + | List(nullable Int32) | + +----------------------+ + "); Ok(()) } diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index aa538f6dee81..17d1695478a5 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -33,6 +33,7 @@ use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_batches; use arrow_schema::{SortOptions, TimeUnit}; use datafusion::{assert_batches_eq, dataframe}; +use datafusion_common::metadata::FieldMetadata; use datafusion_functions_aggregate::count::{count_all, count_all_window}; use datafusion_functions_aggregate::expr_fn::{ array_agg, avg, avg_distinct, count, count_distinct, max, median, min, sum, @@ -65,15 +66,13 @@ use datafusion_catalog::TableProvider; use datafusion_common::test_util::{batches_to_sort_string, batches_to_string}; use datafusion_common::{ assert_contains, internal_datafusion_err, Constraint, Constraints, DFSchema, - DataFusionError, ParamValues, ScalarValue, TableReference, UnnestOptions, + DataFusionError, ScalarValue, TableReference, UnnestOptions, }; use datafusion_common_runtime::SpawnedTask; use datafusion_datasource::file_format::format_as_file_type; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; -use datafusion_expr::expr::{ - FieldMetadata, GroupingSet, NullTreatment, Sort, WindowFunction, -}; +use datafusion_expr::expr::{GroupingSet, NullTreatment, Sort, WindowFunction}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ cast, col, create_udf, exists, in_subquery, lit, out_ref_col, placeholder, @@ -2465,7 +2464,7 @@ async fn filtered_aggr_with_param_values() -> Result<()> { let df = ctx .sql("select count (c2) filter (where c3 > $1) from table1") .await? 
- .with_param_values(ParamValues::List(vec![ScalarValue::from(10u64)])); + .with_param_values(vec![ScalarValue::from(10u64)]); let df_results = df?.collect().await?; assert_snapshot!( @@ -2945,18 +2944,18 @@ async fn test_count_wildcard_on_window() -> Result<()> { assert_snapshot!( pretty_format_batches(&sql_results).unwrap(), @r#" - +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ - | logical_plan | Projection: count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING | - | | WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] | - | | TableScan: t1 projection=[a] | - | physical_plan | ProjectionExec: expr=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@1 as count(*) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING] | - | | BoundedWindowAggExec: wdw=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Field { name: "count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING], mode=[Sorted] | - | | SortExec: expr=[a@0 DESC], preserve_partitioning=[false] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | logical_plan | Projection: count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING | 
+ | | WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] | + | | TableScan: t1 projection=[a] | + | physical_plan | ProjectionExec: expr=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@1 as count(*) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING] | + | | BoundedWindowAggExec: wdw=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Field { "count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING": Int64 }, frame: RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING], mode=[Sorted] | + | | SortExec: expr=[a@0 DESC], preserve_partitioning=[false] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ "# ); @@ -2979,18 +2978,18 @@ async fn test_count_wildcard_on_window() -> Result<()> { assert_snapshot!( pretty_format_batches(&df_results).unwrap(), @r#" - +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ - | logical_plan | Projection: count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING | - | | WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] | - | | TableScan: t1 projection=[a] | - | physical_plan | ProjectionExec: expr=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@1 as count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING] | - | | BoundedWindowAggExec: wdw=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Field { name: "count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING], mode=[Sorted] | - | | SortExec: expr=[a@0 DESC], preserve_partitioning=[false] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - 
+---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | logical_plan | Projection: count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING | + | | WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] | + | | TableScan: t1 projection=[a] | + | physical_plan | ProjectionExec: expr=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@1 as count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING] | + | | BoundedWindowAggExec: wdw=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Field { "count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING": Int64 }, frame: RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING], mode=[Sorted] | + | | SortExec: expr=[a@0 DESC], preserve_partitioning=[false] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ "# ); @@ -4436,12 +4435,12 @@ async fn unnest_with_redundant_columns() -> Result<()> { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r" Projection: shapes.shape_id [shape_id:UInt32] Unnest: lists[shape_id2|depth=1] structs[] [shape_id:UInt32, shape_id2:UInt32;N] - Aggregate: groupBy=[[shapes.shape_id]], aggr=[[array_agg(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: "item", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N] + Aggregate: groupBy=[[shapes.shape_id]], aggr=[[array_agg(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { data_type: UInt32, nullable: true });N] TableScan: shapes projection=[shape_id] [shape_id:UInt32] - "### + " ); let results = df.collect().await?; diff --git a/datafusion/core/tests/csv_schema_fix_test.rs b/datafusion/core/tests/datasource/csv.rs similarity index 100% rename from datafusion/core/tests/csv_schema_fix_test.rs rename to datafusion/core/tests/datasource/csv.rs diff --git a/datafusion/core/tests/datasource/mod.rs b/datafusion/core/tests/datasource/mod.rs new file mode 100644 index 000000000000..3785aa076618 --- /dev/null +++ 
b/datafusion/core/tests/datasource/mod.rs @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for various DataSources +//! +//! Note tests for the Parquet format are in `parquet_integration` binary + +// Include tests in csv module +mod csv; +mod object_store_access; diff --git a/datafusion/core/tests/datasource/object_store_access.rs b/datafusion/core/tests/datasource/object_store_access.rs new file mode 100644 index 000000000000..6b9585f408a1 --- /dev/null +++ b/datafusion/core/tests/datasource/object_store_access.rs @@ -0,0 +1,616 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for object store access patterns with [`ListingTable`]\ +//! +//! These tests setup a `ListingTable` backed by an in-memory object store +//! that counts the number of requests made against it and then do +//! various operations (table creation, queries with and without predicates) +//! to verify the expected object store access patterns. +//! +//! 
[`ListingTable`]: datafusion::datasource::listing::ListingTable + +use arrow::array::{ArrayRef, Int32Array, RecordBatch}; +use async_trait::async_trait; +use bytes::Bytes; +use datafusion::prelude::{CsvReadOptions, SessionContext}; +use futures::stream::BoxStream; +use insta::assert_snapshot; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::{ + GetOptions, GetRange, GetResult, ListResult, MultipartUpload, ObjectMeta, + ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, +}; +use parking_lot::Mutex; +use std::fmt; +use std::fmt::{Display, Formatter}; +use std::ops::Range; +use std::sync::Arc; +use url::Url; + +#[tokio::test] +async fn create_single_csv_file() { + assert_snapshot!( + single_file_csv_test().await.requests(), + @r" + RequestCountingObjectStore() + Total Requests: 2 + - HEAD path=csv_table.csv + - GET path=csv_table.csv + " + ); +} + +#[tokio::test] +async fn query_single_csv_file() { + assert_snapshot!( + single_file_csv_test().await.query("select * from csv_table").await, + @r" + ------- Query Output (2 rows) ------- + +---------+-------+-------+ + | c1 | c2 | c3 | + +---------+-------+-------+ + | 0.00001 | 5e-12 | true | + | 0.00002 | 4e-12 | false | + +---------+-------+-------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 2 + - HEAD path=csv_table.csv + - GET (opts) path=csv_table.csv + " + ); +} + +#[tokio::test] +async fn create_multi_file_csv_file() { + assert_snapshot!( + multi_file_csv_test().await.requests(), + @r" + RequestCountingObjectStore() + Total Requests: 4 + - LIST prefix=data + - GET path=data/file_0.csv + - GET path=data/file_1.csv + - GET path=data/file_2.csv + " + ); +} + +#[tokio::test] +async fn query_multi_csv_file() { + assert_snapshot!( + multi_file_csv_test().await.query("select * from csv_table").await, + @r" + ------- Query Output (6 rows) ------- + +---------+-------+-------+ + | c1 | c2 | c3 | + +---------+-------+-------+ + | 0.0 | 0.0 | true | + | 0.00003 | 5e-12 | false | + | 0.00001 | 1e-12 | true | + | 0.00003 | 5e-12 | false | + | 0.00002 | 2e-12 | true | + | 0.00003 | 5e-12 | false | + +---------+-------+-------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 4 + - LIST prefix=data + - GET (opts) path=data/file_0.csv + - GET (opts) path=data/file_1.csv + - GET (opts) path=data/file_2.csv + " + ); +} + +#[tokio::test] +async fn create_single_parquet_file() { + assert_snapshot!( + single_file_parquet_test().await.requests(), + @r" + RequestCountingObjectStore() + Total Requests: 4 + - HEAD path=parquet_table.parquet + - GET (range) range=2986-2994 path=parquet_table.parquet + - GET (range) range=2264-2986 path=parquet_table.parquet + - GET (range) range=2124-2264 path=parquet_table.parquet + " + ); +} + +#[tokio::test] +async fn query_single_parquet_file() { + assert_snapshot!( + single_file_parquet_test().await.query("select count(distinct a), count(b) from parquet_table").await, + @r" + ------- Query Output (1 rows) ------- + +---------------------------------+------------------------+ + | count(DISTINCT parquet_table.a) | count(parquet_table.b) | + +---------------------------------+------------------------+ + | 200 | 200 | + +---------------------------------+------------------------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 3 + - HEAD path=parquet_table.parquet + - GET (ranges) path=parquet_table.parquet ranges=4-534,534-1064 + 
- GET (ranges) path=parquet_table.parquet ranges=1064-1594,1594-2124 + " + ); +} + +#[tokio::test] +async fn query_single_parquet_file_with_single_predicate() { + // Note that evaluating predicates requires additional object store requests + // to fetch the data used to evaluate the predicates + assert_snapshot!( + single_file_parquet_test().await.query("select min(a), max(b) from parquet_table WHERE a > 150").await, + @r" + ------- Query Output (1 rows) ------- + +----------------------+----------------------+ + | min(parquet_table.a) | max(parquet_table.b) | + +----------------------+----------------------+ + | 151 | 1199 | + +----------------------+----------------------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 2 + - HEAD path=parquet_table.parquet + - GET (ranges) path=parquet_table.parquet ranges=1064-1481,1481-1594,1594-2011,2011-2124 + " + ); +} + +#[tokio::test] +async fn query_single_parquet_file_multi_row_groups_multiple_predicates() { + // Note that evaluating predicates requires additional object store requests + // to fetch the data used to evaluate the predicates + assert_snapshot!( + single_file_parquet_test().await.query("select min(a), max(b) from parquet_table WHERE a > 50 AND b < 1150").await, + @r" + ------- Query Output (1 rows) ------- + +----------------------+----------------------+ + | min(parquet_table.a) | max(parquet_table.b) | + +----------------------+----------------------+ + | 51 | 1149 | + +----------------------+----------------------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 3 + - HEAD path=parquet_table.parquet + - GET (ranges) path=parquet_table.parquet ranges=4-421,421-534,534-951,951-1064 + - GET (ranges) path=parquet_table.parquet ranges=1064-1481,1481-1594,1594-2011,2011-2124 + " + ); +} + +/// Create a test with a single CSV file with three columns and two rows +async fn single_file_csv_test() -> Test { + // upload CSV data to object store + let csv_data = r#"c1,c2,c3 +0.00001,5e-12,true +0.00002,4e-12,false +"#; + + Test::new() + .with_bytes("/csv_table.csv", csv_data) + .await + .register_csv("csv_table", "/csv_table.csv") + .await +} + +/// Create a test with three CSV files in a directory +async fn multi_file_csv_test() -> Test { + let mut test = Test::new(); + // upload CSV data to object store + for i in 0..3 { + let csv_data1 = format!( + r#"c1,c2,c3 +0.0000{i},{i}e-12,true +0.00003,5e-12,false +"# + ); + test = test + .with_bytes(&format!("/data/file_{i}.csv"), csv_data1) + .await; + } + // register table + test.register_csv("csv_table", "/data/").await +} + +/// Create a test with a single parquet file that has two +/// columns and two row groups +/// +/// Column "a": Int32 with values [0-99] in row group 1 +/// and [100-199] in row group 2 +/// +/// Column "b": Int32 with values [1000-1099] in row group 1 +/// and [1100-1199] in row group 2 +async fn single_file_parquet_test() -> Test { + // Create parquet bytes + let a: ArrayRef = Arc::new(Int32Array::from_iter_values(0..200)); + let b: ArrayRef = Arc::new(Int32Array::from_iter_values(1000..1200)); + let batch = RecordBatch::try_from_iter([("a", a), ("b", b)]).unwrap(); + + let mut buffer = vec![]; + let props = parquet::file::properties::WriterProperties::builder() + .set_max_row_group_size(100) + .build(); + let mut writer = + parquet::arrow::ArrowWriter::try_new(&mut buffer, batch.schema(), Some(props)) + .unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + Test::new() 
.with_bytes("/parquet_table.parquet", buffer) + .await + .register_parquet("parquet_table", "/parquet_table.parquet") + .await +} + +/// Runs tests with a request counting object store +struct Test { + object_store: Arc<RequestCountingObjectStore>, + session_context: SessionContext, +} + +impl Test { + fn new() -> Self { + let object_store = Arc::new(RequestCountingObjectStore::new()); + let session_context = SessionContext::new(); + session_context + .runtime_env() + .register_object_store(&Url::parse("mem://").unwrap(), object_store.clone()); + Self { + object_store, + session_context, + } + } + + /// Returns a string representation of all recorded requests thus far + fn requests(&self) -> String { + format!("{}", self.object_store) + } + + /// Store the specified bytes at the given path + async fn with_bytes(self, path: &str, bytes: impl Into<Bytes>) -> Self { + let path = Path::from(path); + self.object_store + .inner + .put(&path, PutPayload::from(bytes.into())) + .await + .unwrap(); + self + } + + /// Register a CSV table at the given path in the in-memory object store + async fn register_csv(self, table_name: &str, path: &str) -> Self { + let mut options = CsvReadOptions::new(); + options.has_header = true; + let url = format!("mem://{path}"); + self.session_context + .register_csv(table_name, url, options) + .await + .unwrap(); + self + } + + /// Register a Parquet table at the given path in the in-memory object store + async fn register_parquet(self, table_name: &str, path: &str) -> Self { + let path = format!("mem://{path}"); + self.session_context + .register_parquet(table_name, path, Default::default()) + .await + .unwrap(); + self + } + + /// Runs the specified query and returns a string representation of the results + /// suitable for comparison with insta snapshots + /// + /// Clears all recorded requests before running the query + async fn query(&self, sql: &str) -> String { + self.object_store.clear_requests(); + let results = self + .session_context + .sql(sql) + .await + .unwrap() + .collect() + .await + .unwrap(); + + let num_rows = results.iter().map(|batch| batch.num_rows()).sum::<usize>(); + let formatted_result = + arrow::util::pretty::pretty_format_batches(&results).unwrap(); + + let object_store = &self.object_store; + + format!( + r#"------- Query Output ({num_rows} rows) ------- +{formatted_result} +------- Object Store Request Summary ------- +{object_store} +"# + ) + } +} + +/// Details of individual requests made through the [`RequestCountingObjectStore`] +#[derive(Clone, Debug)] +enum RequestDetails { + Get { path: Path }, + GetOpts { path: Path, get_options: GetOptions }, + GetRanges { path: Path, ranges: Vec<Range<u64>> }, + GetRange { path: Path, range: Range<u64> }, + Head { path: Path }, + List { prefix: Option<Path> }, + ListWithDelimiter { prefix: Option<Path> }, + ListWithOffset { prefix: Option<Path>, offset: Path }, +} + +fn display_range(range: &Range<u64>) -> impl Display + '_ { + struct Wrapper<'a>(&'a Range<u64>); + impl Display for Wrapper<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}-{}", self.0.start, self.0.end) + } + } + Wrapper(range) +} +impl Display for RequestDetails { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + RequestDetails::Get { path } => { + write!(f, "GET path={path}") + } + RequestDetails::GetOpts { path, get_options } => { + write!(f, "GET (opts) path={path}")?; + if let Some(range) = &get_options.range { + match range { + GetRange::Bounded(range) => { + let range = display_range(range); + write!(f, " 
range={range}")?; + } + GetRange::Offset(offset) => { + write!(f, " range=offset:{offset}")?; + } + GetRange::Suffix(suffix) => { + write!(f, " range=suffix:{suffix}")?; + } + } + } + if let Some(version) = &get_options.version { + write!(f, " version={version}")?; + } + if get_options.head { + write!(f, " head=true")?; + } + Ok(()) + } + RequestDetails::GetRanges { path, ranges } => { + write!(f, "GET (ranges) path={path}")?; + if !ranges.is_empty() { + write!(f, " ranges=")?; + for (i, range) in ranges.iter().enumerate() { + if i > 0 { + write!(f, ",")?; + } + write!(f, "{}", display_range(range))?; + } + } + Ok(()) + } + RequestDetails::GetRange { path, range } => { + let range = display_range(range); + write!(f, "GET (range) range={range} path={path}") + } + RequestDetails::Head { path } => { + write!(f, "HEAD path={path}") + } + RequestDetails::List { prefix } => { + write!(f, "LIST")?; + if let Some(prefix) = prefix { + write!(f, " prefix={prefix}")?; + } + Ok(()) + } + RequestDetails::ListWithDelimiter { prefix } => { + write!(f, "LIST (with delimiter)")?; + if let Some(prefix) = prefix { + write!(f, " prefix={prefix}")?; + } + Ok(()) + } + RequestDetails::ListWithOffset { prefix, offset } => { + write!(f, "LIST (with offset) offset={offset}")?; + if let Some(prefix) = prefix { + write!(f, " prefix={prefix}")?; + } + Ok(()) + } + } + } +} + +#[derive(Debug)] +struct RequestCountingObjectStore { + /// Inner (memory) store + inner: Arc<InMemory>, + requests: Mutex<Vec<RequestDetails>>, +} + +impl Display for RequestCountingObjectStore { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "RequestCountingObjectStore()")?; + let requests = self.recorded_requests(); + write!(f, "\nTotal Requests: {}", requests.len())?; + for request in requests { + write!(f, "\n- {request}")?; + } + Ok(()) + } +} + +impl RequestCountingObjectStore { + pub fn new() -> Self { + let inner = Arc::new(InMemory::new()); + Self { + inner, + requests: Mutex::new(vec![]), + } + } + + pub fn clear_requests(&self) { + self.requests.lock().clear(); + } + + /// Return a copy of the recorded requests + pub fn recorded_requests(&self) -> Vec<RequestDetails> { + self.requests.lock().to_vec() + } +} + +#[async_trait] +impl ObjectStore for RequestCountingObjectStore { + async fn put_opts( + &self, + _location: &Path, + _payload: PutPayload, + _opts: PutOptions, + ) -> object_store::Result<PutResult> { + Err(object_store::Error::NotImplemented) + } + + async fn put_multipart_opts( + &self, + _location: &Path, + _opts: PutMultipartOptions, + ) -> object_store::Result<Box<dyn MultipartUpload>> { + Err(object_store::Error::NotImplemented) + } + + async fn get(&self, location: &Path) -> object_store::Result<GetResult> { + let result = self.inner.get(location).await?; + self.requests.lock().push(RequestDetails::Get { + path: location.to_owned(), + }); + Ok(result) + } + + async fn get_opts( + &self, + location: &Path, + options: GetOptions, + ) -> object_store::Result<GetResult> { + let result = self.inner.get_opts(location, options.clone()).await?; + self.requests.lock().push(RequestDetails::GetOpts { + path: location.to_owned(), + get_options: options, + }); + Ok(result) + } + + async fn get_range( + &self, + location: &Path, + range: Range<u64>, + ) -> object_store::Result<Bytes> { + let result = self.inner.get_range(location, range.clone()).await?; + self.requests.lock().push(RequestDetails::GetRange { + path: location.to_owned(), + range: range.clone(), + }); + Ok(result) + } + + async fn get_ranges( + &self, + location: &Path, + ranges: &[Range<u64>], + ) -> object_store::Result<Vec<Bytes>> 
{ + let result = self.inner.get_ranges(location, ranges).await?; + self.requests.lock().push(RequestDetails::GetRanges { + path: location.to_owned(), + ranges: ranges.to_vec(), + }); + Ok(result) + } + + async fn head(&self, location: &Path) -> object_store::Result<ObjectMeta> { + let result = self.inner.head(location).await?; + self.requests.lock().push(RequestDetails::Head { + path: location.to_owned(), + }); + Ok(result) + } + + async fn delete(&self, _location: &Path) -> object_store::Result<()> { + Err(object_store::Error::NotImplemented) + } + + fn list( + &self, + prefix: Option<&Path>, + ) -> BoxStream<'static, object_store::Result<ObjectMeta>> { + self.requests.lock().push(RequestDetails::List { + prefix: prefix.map(|p| p.to_owned()), + }); + + self.inner.list(prefix) + } + + fn list_with_offset( + &self, + prefix: Option<&Path>, + offset: &Path, + ) -> BoxStream<'static, object_store::Result<ObjectMeta>> { + self.requests.lock().push(RequestDetails::ListWithOffset { + prefix: prefix.map(|p| p.to_owned()), + offset: offset.to_owned(), + }); + self.inner.list_with_offset(prefix, offset) + } + + async fn list_with_delimiter( + &self, + prefix: Option<&Path>, + ) -> object_store::Result<ListResult> { + self.requests + .lock() + .push(RequestDetails::ListWithDelimiter { + prefix: prefix.map(|p| p.to_owned()), + }); + self.inner.list_with_delimiter(prefix).await + } + + async fn copy(&self, _from: &Path, _to: &Path) -> object_store::Result<()> { + Err(object_store::Error::NotImplemented) + } + + async fn copy_if_not_exists( + &self, + _from: &Path, + _to: &Path, + ) -> object_store::Result<()> { + Err(object_store::Error::NotImplemented) + } +} diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index 4aee274de908..84e644480a4f 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -320,6 +320,26 @@ async fn test_create_physical_expr() { create_simplified_expr_test(lit(1i32) + lit(2i32), "3"); } +#[test] +fn test_create_physical_expr_nvl2() { + let batch = &TEST_BATCH; + let df_schema = DFSchema::try_from(batch.schema()).unwrap(); + let ctx = SessionContext::new(); + + let expect_err = |expr| { + let physical_expr = ctx.create_physical_expr(expr, &df_schema).unwrap(); + let err = physical_expr.evaluate(batch).unwrap_err(); + assert!( + err.to_string() + .contains("nvl2 should have been simplified to case"), + "unexpected error: {err:?}" + ); + }; + + expect_err(nvl2(col("i"), lit(1i64), lit(0i64))); + expect_err(nvl2(lit(1i64), col("i"), lit(0i64))); +} + #[tokio::test] async fn test_create_physical_expr_coercion() { // create_physical_expr does apply type coercion and unwrapping in cast diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 89651726a69a..572a7e2b335c 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -514,8 +514,7 @@ fn multiple_now() -> Result<()> { // expect the same timestamp appears in both exprs let actual = get_optimized_plan_formatted(plan, &time); let expected = format!( - "Projection: TimestampNanosecond({}, Some(\"+00:00\")) AS now(), TimestampNanosecond({}, Some(\"+00:00\")) AS t2\ - \n TableScan: test", + "Projection: TimestampNanosecond({}, Some(\"+00:00\")) AS now(), TimestampNanosecond({}, Some(\"+00:00\")) AS t2\n TableScan: test", time.timestamp_nanos_opt().unwrap(), time.timestamp_nanos_opt().unwrap() ); diff --git a/datafusion/core/tests/memory_limit/mod.rs 
b/datafusion/core/tests/memory_limit/mod.rs index 89bc48b1e634..5d8a1d24181c 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -23,6 +23,7 @@ use std::sync::{Arc, LazyLock}; #[cfg(feature = "extended_tests")] mod memory_limit_validation; +mod repartition_mem_limit; use arrow::array::{ArrayRef, DictionaryArray, Int32Array, RecordBatch, StringViewArray}; use arrow::compute::SortOptions; use arrow::datatypes::{Int32Type, SchemaRef}; diff --git a/datafusion/core/tests/memory_limit/repartition_mem_limit.rs b/datafusion/core/tests/memory_limit/repartition_mem_limit.rs new file mode 100644 index 000000000000..a7af2f01d1cc --- /dev/null +++ b/datafusion/core/tests/memory_limit/repartition_mem_limit.rs @@ -0,0 +1,116 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int32Array, RecordBatch}; +use datafusion::{ + assert_batches_sorted_eq, + prelude::{SessionConfig, SessionContext}, +}; +use datafusion_catalog::MemTable; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_execution::runtime_env::RuntimeEnvBuilder; +use datafusion_physical_plan::{repartition::RepartitionExec, ExecutionPlanProperties}; +use futures::TryStreamExt; +use itertools::Itertools; + +/// End to end test for spilling in RepartitionExec. +/// The idea is to run a real world query with a relatively low memory limit and +/// then drive one partition at a time, simulating dissimilar execution speeds across partitions. +/// Real world scenarios where this can happen include lopsided groups in a GROUP BY +/// (especially if one partition spills and others don't), or, in distributed systems, +/// an upstream node that is slower than the others. 
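+/// The test below asserts that `RepartitionExec` reports non-zero spill metrics (spilled bytes, spilled rows, and spill count) and that the results collected from both partitions still match the expected aggregate output.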
+#[tokio::test] +async fn test_repartition_memory_limit() { + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(1024 * 1024, 1.0) + .build() + .unwrap(); + let config = SessionConfig::new() + .with_batch_size(32) + .with_target_partitions(2); + let ctx = SessionContext::new_with_config_rt(config, Arc::new(runtime)); + let batches = vec![RecordBatch::try_from_iter(vec![( + "c1", + Arc::new(Int32Array::from_iter_values((0..10).cycle().take(100_000))) as ArrayRef, + )]) + .unwrap()]; + let table = Arc::new(MemTable::try_new(batches[0].schema(), vec![batches]).unwrap()); + ctx.register_table("t", table).unwrap(); + let plan = ctx + .state() + .create_logical_plan("SELECT c1, count(*) as c FROM t GROUP BY c1;") + .await + .unwrap(); + let plan = ctx.state().create_physical_plan(&plan).await.unwrap(); + assert_eq!(plan.output_partitioning().partition_count(), 2); + // Execute partition 0. This should cause items going to the other partitions to queue up + // and, because of the low memory limit, spill to disk. + let batches0 = Arc::clone(&plan) + .execute(0, ctx.task_ctx()) + .unwrap() + .try_collect::<Vec<_>>() + .await + .unwrap(); + + let mut metrics = None; + Arc::clone(&plan) + .transform_down(|node| { + if node.as_any().is::<RepartitionExec>() { + metrics = node.metrics(); + } + Ok(Transformed::no(node)) + }) + .unwrap(); + + let metrics = metrics.unwrap(); + assert!(metrics.spilled_bytes().unwrap() > 0); + assert!(metrics.spilled_rows().unwrap() > 0); + assert!(metrics.spill_count().unwrap() > 0); + + // Execute the other partition + let batches1 = Arc::clone(&plan) + .execute(1, ctx.task_ctx()) + .unwrap() + .try_collect::<Vec<_>>() + .await + .unwrap(); + + let all_batches = batches0 + .into_iter() + .chain(batches1.into_iter()) + .collect_vec(); + #[rustfmt::skip] + let expected = &[ + "+----+-------+", + "| c1 | c |", + "+----+-------+", + "| 0 | 10000 |", + "| 1 | 10000 |", + "| 2 | 10000 |", + "| 3 | 10000 |", + "| 4 | 10000 |", + "| 5 | 10000 |", + "| 6 | 10000 |", + "| 7 | 10000 |", + "| 8 | 10000 |", + "| 9 | 10000 |", + "+----+-------+", + ]; + assert_batches_sorted_eq!(expected, &all_batches); +} diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index 9899a0158fb8..aec32d05624c 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -144,8 +144,9 @@ fn test_sql(sql: &str) -> Result { let statement = &ast[0]; // create a logical query plan + let config = ConfigOptions::default(); let context_provider = MyContextProvider::default() - .with_udf(datetime::now()) + .with_udf(datetime::now(&config)) .with_udf(datafusion_functions::core::arrow_cast()) .with_udf(datafusion_functions::string::concat()) .with_udf(datafusion_functions::string::concat_ws()); diff --git a/datafusion/core/tests/parquet/encryption.rs b/datafusion/core/tests/parquet/encryption.rs index 819d8bf3a283..09b93f06ce85 100644 --- a/datafusion/core/tests/parquet/encryption.rs +++ b/datafusion/core/tests/parquet/encryption.rs @@ -314,7 +314,7 @@ async fn verify_file_encrypted( for col in row_group.columns() { assert!(matches!( col.crypto_metadata(), - Some(ColumnCryptoMetaData::EncryptionWithFooterKey) + Some(ColumnCryptoMetaData::ENCRYPTION_WITH_FOOTER_KEY) )); } } @@ -336,7 +336,7 @@ impl EncryptionFactory for MockEncryptionFactory { config: &EncryptionFactoryOptions, _schema: &SchemaRef, file_path: &object_store::path::Path, - ) -> datafusion_common::Result> { + ) -> datafusion_common::Result>> { assert_eq!( 
config.options.get("test_key"), Some(&"test value".to_string()) @@ -353,7 +353,7 @@ impl EncryptionFactory for MockEncryptionFactory { &self, config: &EncryptionFactoryOptions, file_path: &object_store::path::Path, - ) -> datafusion_common::Result> { + ) -> datafusion_common::Result>> { assert_eq!( config.options.get("test_key"), Some(&"test value".to_string()) diff --git a/datafusion/core/tests/parquet/filter_pushdown.rs b/datafusion/core/tests/parquet/filter_pushdown.rs index b769fec7d372..226497fe5824 100644 --- a/datafusion/core/tests/parquet/filter_pushdown.rs +++ b/datafusion/core/tests/parquet/filter_pushdown.rs @@ -631,8 +631,8 @@ async fn predicate_cache_pushdown_default() -> datafusion_common::Result<()> { #[tokio::test] async fn predicate_cache_pushdown_disable() -> datafusion_common::Result<()> { - // Can disable the cache even with filter pushdown by setting the size to 0. In this case we - // expect the inner records are reported but no records are read from the cache + // Can disable the cache even with filter pushdown by setting the size to 0. + // This results in no records read from the cache and no metrics reported let mut config = SessionConfig::new(); config.options_mut().execution.parquet.pushdown_filters = true; config @@ -641,13 +641,10 @@ async fn predicate_cache_pushdown_disable() -> datafusion_common::Result<()> { .parquet .max_predicate_cache_size = Some(0); let ctx = SessionContext::new_with_config(config); + // Since the cache is disabled, there is no reporting or use of the cache PredicateCacheTest { - // file has 8 rows, which need to be read twice, one for filter, one for - // final output - expected_inner_records: 16, - // Expect this to 0 records read as the cache is disabled. However, it is - // non zero due to https://github.com/apache/arrow-rs/issues/8307 - expected_records: 3, + expected_inner_records: 0, + expected_records: 0, } .run(&ctx) .await diff --git a/datafusion/core/tests/parquet/schema_adapter.rs b/datafusion/core/tests/parquet/schema_adapter.rs index 4ae2fa9b4c39..40fc6176e212 100644 --- a/datafusion/core/tests/parquet/schema_adapter.rs +++ b/datafusion/core/tests/parquet/schema_adapter.rs @@ -23,7 +23,9 @@ use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef}; use bytes::{BufMut, BytesMut}; use datafusion::assert_batches_eq; use datafusion::common::Result; -use datafusion::datasource::listing::{ListingTable, ListingTableConfig}; +use datafusion::datasource::listing::{ + ListingTable, ListingTableConfig, ListingTableConfigExt, +}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::DataFusionError; diff --git a/datafusion/core/tests/parquet_config.rs b/datafusion/core/tests/parquet_integration.rs similarity index 100% rename from datafusion/core/tests/parquet_config.rs rename to datafusion/core/tests/parquet_integration.rs diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs index a2c604a84e76..620259821871 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs @@ -33,16 +33,12 @@ use arrow::compute::SortOptions; use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{TreeNode, TransformedResult}; -use datafusion_common::{Result, ScalarValue, TableReference}; +use datafusion_common::{Result, 
TableReference}; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; use datafusion_expr_common::operator::Operator; -use datafusion_expr::{JoinType, SortExpr, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition}; +use datafusion_expr::{JoinType, SortExpr}; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_functions_aggregate::average::avg_udaf; -use datafusion_functions_aggregate::count::count_udaf; -use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf}; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{ LexOrdering, PhysicalSortExpr, PhysicalSortRequirement, OrderingRequirements }; @@ -52,8 +48,7 @@ use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::sorts::sort::SortExec; -use datafusion_physical_plan::windows::{create_window_expr, BoundedWindowAggExec, WindowAggExec}; -use datafusion_physical_plan::{displayable, get_plan_string, ExecutionPlan, InputOrderMode}; +use datafusion_physical_plan::{displayable, get_plan_string, ExecutionPlan}; use datafusion::datasource::physical_plan::CsvSource; use datafusion::datasource::listing::PartitionedFile; use datafusion_physical_optimizer::enforce_sorting::{EnforceSorting, PlanWithCorrespondingCoalescePartitions, PlanWithCorrespondingSort, parallelize_sorts, ensure_sorting}; @@ -93,13 +88,13 @@ fn csv_exec_sorted( /// Runs the sort enforcement optimizer and asserts the plan /// against the original and expected plans -struct EnforceSortingTest { +pub(crate) struct EnforceSortingTest { plan: Arc, repartition_sorts: bool, } impl EnforceSortingTest { - fn new(plan: Arc) -> Self { + pub(crate) fn new(plan: Arc) -> Self { Self { plan, repartition_sorts: false, @@ -107,14 +102,14 @@ impl EnforceSortingTest { } /// Set whether to repartition sorts - fn with_repartition_sorts(mut self, repartition_sorts: bool) -> Self { + pub(crate) fn with_repartition_sorts(mut self, repartition_sorts: bool) -> Self { self.repartition_sorts = repartition_sorts; self } /// Runs the enforce sorting test and returns a string with the input and /// optimized plan as strings for snapshot comparison using insta - fn run(&self) -> String { + pub(crate) fn run(&self) -> String { let mut config = ConfigOptions::new(); config.optimizer.repartition_sorts = self.repartition_sorts; @@ -672,12 +667,12 @@ async fn test_soft_hard_requirements_remove_soft_requirement() -> Result<()> { let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); assert_snapshot!(test.run(), @r#" Input Plan: - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet Optimized Plan: - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE 
BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet "#); @@ -721,13 +716,13 @@ async fn test_soft_hard_requirements_remove_soft_requirement_without_pushdowns( assert_snapshot!(test.run(), @r#" Input Plan: ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as count] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet Optimized Plan: ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as count] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet "#); @@ -768,13 +763,13 @@ async fn test_soft_hard_requirements_remove_soft_requirement_without_pushdowns( let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); assert_snapshot!(test.run(), @r#" Input Plan: - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet Optimized Plan: - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] @@ -829,15 +824,15 @@ async fn test_soft_hard_requirements_multiple_soft_requirements() -> Result<()> let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); assert_snapshot!(test.run(), @r#" Input Plan: - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 
0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet Optimized Plan: - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] @@ -894,17 +889,17 @@ async fn test_soft_hard_requirements_multiple_soft_requirements() -> Result<()> let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); assert_snapshot!(test.run(), @r#" Input Plan: - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet Optimized Plan: - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, 
metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] @@ -966,14 +961,14 @@ async fn test_soft_hard_requirements_multiple_sorts() -> Result<()> { Input Plan: SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet Optimized Plan: SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] @@ -1028,16 +1023,16 @@ async fn test_soft_hard_requirements_with_multiple_soft_requirements_and_output_ assert_snapshot!(test.run(), @r#" Input Plan: OutputRequirementExec: order_by=[(non_nullable_col@1, asc)], dist_by=SinglePartition - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet Optimized Plan: OutputRequirementExec: order_by=[(non_nullable_col@1, asc)], dist_by=SinglePartition - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 
UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet "#); @@ -1086,7 +1081,7 @@ async fn test_window_multi_path_sort() -> Result<()> { let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); assert_snapshot!(test.run(), @r#" Input Plan: - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortPreservingMergeExec: [nullable_col@0 DESC NULLS LAST] UnionExec SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] @@ -1095,7 +1090,7 @@ async fn test_window_multi_path_sort() -> Result<()> { DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet Optimized Plan: - WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] SortPreservingMergeExec: [nullable_col@0 ASC] UnionExec DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC], file_type=parquet @@ -1127,7 +1122,7 @@ async fn test_window_multi_path_sort2() -> Result<()> { let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); assert_snapshot!(test.run(), @r#" Input Plan: - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] UnionExec SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] @@ -1136,7 +1131,7 @@ async fn test_window_multi_path_sort2() -> Result<()> { DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet Optimized Plan: - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, 
frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortPreservingMergeExec: [nullable_col@0 ASC] UnionExec DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet @@ -1683,7 +1678,7 @@ async fn test_window_multi_layer_requirement() -> Result<()> { EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); assert_snapshot!(test.run(), @r#" Input Plan: - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortPreservingMergeExec: [a@0 ASC, b@1 ASC] RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC, b@1 ASC RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 @@ -1691,7 +1686,7 @@ async fn test_window_multi_layer_requirement() -> Result<()> { DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false Optimized Plan: - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortPreservingMergeExec: [a@0 ASC, b@1 ASC] SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true] RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10 @@ -1788,18 +1783,18 @@ async fn test_remove_unnecessary_sort_window_multilayer() -> Result<()> { EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); assert_snapshot!(test.run(), @r#" Input Plan: - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] FilterExec: NOT non_nullable_col@1 SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] CoalesceBatchesExec: target_batch_size=128 SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false] DataSourceExec: partitions=1, partition_sizes=[0] Optimized Plan: - WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Range, 
start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] FilterExec: NOT non_nullable_col@1 - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] CoalesceBatchesExec: target_batch_size=128 SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false] DataSourceExec: partitions=1, partition_sizes=[0] @@ -2243,17 +2238,17 @@ async fn test_multiple_sort_window_exec() -> Result<()> { EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); assert_snapshot!(test.run(), @r#" Input Plan: - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] DataSourceExec: partitions=1, partition_sizes=[0] Optimized Plan: - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] DataSourceExec: partitions=1, partition_sizes=[0] "#); @@ -2278,7 +2273,7 @@ async fn test_commutativity() -> Result<()> { assert_snapshot!(displayable(orig_plan.as_ref()).indent(true), @r#" SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] RepartitionExec: 
partitioning=RoundRobinBatch(10), input_partitions=1 - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] DataSourceExec: partitions=1, partition_sizes=[0] "#); @@ -2487,1203 +2482,7 @@ async fn test_not_replaced_with_partial_sort_for_unbounded_input() -> Result<()> "); Ok(()) } -// aal here -#[tokio::test] -async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { - let input_schema = create_test_schema()?; - let ordering = [sort_expr_options( - "nullable_col", - &input_schema, - SortOptions { - descending: false, - nulls_first: false, - }, - )] - .into(); - let source = parquet_exec_with_sort(input_schema.clone(), vec![ordering]) as _; - - // Macro for testing window function optimization with snapshots - macro_rules! test_window_case { - ( - partition_by: $partition_by:expr, - window_frame: $window_frame:expr, - func: ($func_def:expr, $func_name:expr, $func_args:expr), - required_sort: [$($col:expr, $asc:expr, $nulls_first:expr),*], - @ $expected:literal - ) => {{ - let partition_by_exprs = if $partition_by { - vec![col("nullable_col", &input_schema)?] - } else { - vec![] - }; - - let window_expr = create_window_expr( - &$func_def, - $func_name, - &$func_args, - &partition_by_exprs, - &[], - $window_frame, - Arc::clone(&input_schema), - false, - false, - None, - )?; - - let window_exec = if window_expr.uses_bounded_memory() { - Arc::new(BoundedWindowAggExec::try_new( - vec![window_expr], - Arc::clone(&source), - InputOrderMode::Sorted, - $partition_by, - )?) as Arc<dyn ExecutionPlan> - } else { - Arc::new(WindowAggExec::try_new( - vec![window_expr], - Arc::clone(&source), - $partition_by, - )?) 
as Arc<dyn ExecutionPlan> - }; - - let output_schema = window_exec.schema(); - let sort_expr = vec![ - $( - sort_expr_options( - $col, - &output_schema, - SortOptions { - descending: !$asc, - nulls_first: $nulls_first, - }, - ) - ),* - ]; - let ordering = LexOrdering::new(sort_expr).unwrap(); - let physical_plan = sort_exec(ordering, window_exec); - - let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); - - assert_snapshot!(test.run(), @ $expected); - - Result::<(), datafusion_common::DataFusionError>::Ok(()) - }}; - } - - // Function definition - Alias of the resulting column - Arguments of the function - #[derive(Clone)] - struct WindowFuncParam(WindowFunctionDefinition, String, Vec<Arc<dyn PhysicalExpr>>); - let function_arg_ordered = vec![col("nullable_col", &input_schema)?]; - let function_arg_unordered = vec![col("non_nullable_col", &input_schema)?]; - let fn_count_on_ordered = WindowFuncParam( - WindowFunctionDefinition::AggregateUDF(count_udaf()), - "count".to_string(), - function_arg_ordered.clone(), - ); - let fn_max_on_ordered = WindowFuncParam( - WindowFunctionDefinition::AggregateUDF(max_udaf()), - "max".to_string(), - function_arg_ordered.clone(), - ); - let fn_min_on_ordered = WindowFuncParam( - WindowFunctionDefinition::AggregateUDF(min_udaf()), - "min".to_string(), - function_arg_ordered.clone(), - ); - let fn_avg_on_ordered = WindowFuncParam( - WindowFunctionDefinition::AggregateUDF(avg_udaf()), - "avg".to_string(), - function_arg_ordered, - ); - let fn_count_on_unordered = WindowFuncParam( - WindowFunctionDefinition::AggregateUDF(count_udaf()), - "count".to_string(), - function_arg_unordered.clone(), - ); - let fn_max_on_unordered = WindowFuncParam( - WindowFunctionDefinition::AggregateUDF(max_udaf()), - "max".to_string(), - function_arg_unordered.clone(), - ); - let fn_min_on_unordered = WindowFuncParam( - WindowFunctionDefinition::AggregateUDF(min_udaf()), - "min".to_string(), - function_arg_unordered.clone(), - ); - let fn_avg_on_unordered = WindowFuncParam( - WindowFunctionDefinition::AggregateUDF(avg_udaf()), - "avg".to_string(), - function_arg_unordered, - ); - - // ============================================REGION STARTS============================================ - // WindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on ordered column - // Case 0: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: (fn_count_on_ordered.0.clone(), fn_count_on_ordered.1.clone(), fn_count_on_ordered.2.clone()), - required_sort: ["nullable_col", true, false, "count", true, false], - @ r#" - Input Plan: - SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - - Optimized Plan: - WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], 
output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 1: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: (fn_max_on_ordered.0.clone(), fn_max_on_ordered.1.clone(), fn_max_on_ordered.2.clone()), - required_sort: ["nullable_col", true, false, "max", false, false], - @ r#" - Input Plan: - SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - - Optimized Plan: - WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 2: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: (fn_min_on_ordered.0.clone(), fn_min_on_ordered.1.clone(), fn_min_on_ordered.2.clone()), - required_sort: ["min", false, false, "nullable_col", true, false], - @ r#" - Input Plan: - SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - - Optimized Plan: - WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 3: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: (fn_avg_on_ordered.0.clone(), fn_avg_on_ordered.1.clone(), fn_avg_on_ordered.2.clone()), - required_sort: ["avg", true, false, "nullable_col", true, false], - @ r#" -Input Plan: -SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: 
Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // WindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on unordered column - // Case 4: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: (fn_count_on_unordered.0.clone(), fn_count_on_unordered.1.clone(), fn_count_on_unordered.2.clone()), - required_sort: ["non_nullable_col", true, false, "count", true, false], - @ r#" -Input Plan: -SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - - // Case 5: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: (fn_max_on_unordered.0.clone(), fn_max_on_unordered.1.clone(), fn_max_on_unordered.2.clone()), - required_sort: ["non_nullable_col", false, false, "max", false, false], - @ r#" -Input Plan: -SortExec: expr=[non_nullable_col@1 DESC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -SortExec: expr=[non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], 
file_type=parquet -"# - )?; - - // Case 6: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: (fn_min_on_unordered.0.clone(), fn_min_on_unordered.1.clone(), fn_min_on_unordered.2.clone()), - required_sort: ["min", true, false, "non_nullable_col", true, false], - @ r#" -Input Plan: -SortExec: expr=[min@2 ASC NULLS LAST, non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - - // Case 7: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: (fn_avg_on_unordered.0.clone(), fn_avg_on_unordered.1.clone(), fn_avg_on_unordered.2.clone()), - required_sort: ["avg", false, false, "nullable_col", false, false], - @ r#" -Input Plan: -SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // WindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on ordered column - // Case 8: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: (fn_count_on_ordered.0.clone(), fn_count_on_ordered.1.clone(), fn_count_on_ordered.2.clone()), - required_sort: ["nullable_col", true, false, "count", true, false], - @ r#" -Input Plan: -SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 
ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - - // Case 9: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: (fn_max_on_ordered.0.clone(), fn_max_on_ordered.1.clone(), fn_max_on_ordered.2.clone()), - required_sort: ["nullable_col", true, false, "max", false, false], - @ r#" -Input Plan: -SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - - // Case 10: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: (fn_min_on_ordered.0.clone(), fn_min_on_ordered.1.clone(), fn_min_on_ordered.2.clone()), - required_sort: ["min", false, false, "nullable_col", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 11: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: (fn_avg_on_ordered.0.clone(), fn_avg_on_ordered.1.clone(), fn_avg_on_ordered.2.clone()), - required_sort: ["avg", true, false, "nullable_col", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: 
true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // WindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on unordered column - // Case 12: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: (fn_count_on_unordered.0.clone(), fn_count_on_unordered.1.clone(), fn_count_on_unordered.2.clone()), - required_sort: ["non_nullable_col", true, false, "count", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 13: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: (fn_max_on_unordered.0.clone(), fn_max_on_unordered.1.clone(), fn_max_on_unordered.2.clone()), - required_sort: ["non_nullable_col", true, false, "max", false, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 14: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: (fn_min_on_unordered.0.clone(), fn_min_on_unordered.1.clone(), fn_min_on_unordered.2.clone()), - required_sort: ["min", false, false, "non_nullable_col", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[min@2 DESC NULLS LAST, non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 15: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: (fn_avg_on_unordered.0.clone(), 
fn_avg_on_unordered.1.clone(), fn_avg_on_unordered.2.clone()), - required_sort: ["avg", true, false, "nullable_col", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // WindowAggExec + Sliding(current row, unbounded following) + no partition_by + on ordered column - // Case 16: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: (fn_count_on_ordered.0.clone(), fn_count_on_ordered.1.clone(), fn_count_on_ordered.2.clone()), - required_sort: ["nullable_col", true, false, "count", false, false], - @ r#" -Input Plan: -SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 DESC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - - // Case 17: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: (fn_max_on_ordered.0.clone(), fn_max_on_ordered.1.clone(), fn_max_on_ordered.2.clone()), - required_sort: ["max", false, true, "nullable_col", true, false], - @ r#" -Input Plan: -SortExec: expr=[max@2 DESC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, 
projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - - // Case 18: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: (fn_min_on_ordered.0.clone(), fn_min_on_ordered.1.clone(), fn_min_on_ordered.2.clone()), - required_sort: ["min", true, true, "nullable_col", true, false], - @ r#" -Input Plan: -SortExec: expr=[min@2 ASC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - - // Case 19: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: (fn_avg_on_ordered.0.clone(), fn_avg_on_ordered.1.clone(), fn_avg_on_ordered.2.clone()), - required_sort: ["avg", false, false, "nullable_col", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // WindowAggExec + Sliding(current row, unbounded following) + no partition_by + on unordered column - // Case 20: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: (fn_count_on_unordered.0.clone(), fn_count_on_unordered.1.clone(), fn_count_on_unordered.2.clone()), - required_sort: ["nullable_col", true, false, "count", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 21: - 
test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: (fn_max_on_unordered.0.clone(), fn_max_on_unordered.1.clone(), fn_max_on_unordered.2.clone()), - required_sort: ["nullable_col", true, false, "max", false, true], - @ r#" -Input Plan: -SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC], preserve_partitioning=[false] - WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - - // Case 22: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: (fn_min_on_unordered.0.clone(), fn_min_on_unordered.1.clone(), fn_min_on_unordered.2.clone()), - required_sort: ["min", true, false, "nullable_col", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[min@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 23: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: (fn_avg_on_unordered.0.clone(), fn_avg_on_unordered.1.clone(), fn_avg_on_unordered.2.clone()), - required_sort: ["avg", false, false, "nullable_col", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // WindowAggExec + Sliding(current row, unbounded following) + partition_by + on ordered column - // Case 24: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: (fn_count_on_ordered.0.clone(), 
fn_count_on_ordered.1.clone(), fn_count_on_ordered.2.clone()), - required_sort: ["nullable_col", true, false, "count", false, false], - @ r#" -Input Plan: -SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 DESC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - - // Case 25: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: (fn_max_on_ordered.0.clone(), fn_max_on_ordered.1.clone(), fn_max_on_ordered.2.clone()), - required_sort: ["nullable_col", true, false, "max", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 26: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: (fn_min_on_ordered.0.clone(), fn_min_on_ordered.1.clone(), fn_min_on_ordered.2.clone()), - required_sort: ["min", false, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[min@2 DESC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 27: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: (fn_avg_on_ordered.0.clone(), fn_avg_on_ordered.1.clone(), fn_avg_on_ordered.2.clone()), - required_sort: ["avg", false, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[avg@2 DESC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - 
)?; - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // WindowAggExec + Sliding(current row, unbounded following) + partition_by + on unordered column - // Case 28: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: (fn_count_on_unordered.0.clone(), fn_count_on_unordered.1.clone(), fn_count_on_unordered.2.clone()), - required_sort: ["count", false, false, "nullable_col", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[count@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 29: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: (fn_max_on_unordered.0.clone(), fn_max_on_unordered.1.clone(), fn_max_on_unordered.2.clone()), - required_sort: ["nullable_col", true, false, "max", false, true], - @ r#" -Input Plan: -SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC], preserve_partitioning=[false] - WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - - // Case 30: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: (fn_min_on_unordered.0.clone(), fn_min_on_unordered.1.clone(), fn_min_on_unordered.2.clone()), - required_sort: ["min", false, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[min@2 DESC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 31: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: (fn_avg_on_unordered.0.clone(), fn_avg_on_unordered.1.clone(), 
fn_avg_on_unordered.2.clone()), - required_sort: ["nullable_col", true, false, "avg", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false] - WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on ordered column - // Case 32: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: (fn_count_on_ordered.0.clone(), fn_count_on_ordered.1.clone(), fn_count_on_ordered.2.clone()), - required_sort: ["nullable_col", true, false, "count", true, false], - @ r#" -Input Plan: -SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - - // Case 33: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: (fn_max_on_ordered.0.clone(), fn_max_on_ordered.1.clone(), fn_max_on_ordered.2.clone()), - required_sort: ["max", false, false, "nullable_col", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[max@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[max: Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 34: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: (fn_min_on_ordered.0.clone(), fn_min_on_ordered.1.clone(), fn_min_on_ordered.2.clone()), - required_sort: ["min", false, false, "nullable_col", true, false], - @ r#" -Input Plan: -SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[min: Field { name: 
"min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - - // Case 35: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: (fn_avg_on_ordered.0.clone(), fn_avg_on_ordered.1.clone(), fn_avg_on_ordered.2.clone()), - required_sort: ["nullable_col", true, false, "avg", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on unordered column - // Case 36: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: (fn_count_on_unordered.0.clone(), fn_count_on_unordered.1.clone(), fn_count_on_unordered.2.clone()), - required_sort: ["nullable_col", true, false, "count", true, true], - @ r#" -Input Plan: -SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - - // Case 37: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: (fn_max_on_unordered.0.clone(), fn_max_on_unordered.1.clone(), fn_max_on_unordered.2.clone()), - required_sort: ["max", true, false, "nullable_col", true, false], - @ r#" -Input Plan: -SortExec: expr=[max@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], 
preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[max: Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -BoundedWindowAggExec: wdw=[max: Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - - // Case 38: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: (fn_min_on_unordered.0.clone(), fn_min_on_unordered.1.clone(), fn_min_on_unordered.2.clone()), - required_sort: ["min", false, true, "nullable_col", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[min@2 DESC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 39: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: (fn_avg_on_unordered.0.clone(), fn_avg_on_unordered.1.clone(), fn_avg_on_unordered.2.clone()), - required_sort: ["avg", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on ordered column - // Case 40: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: (fn_count_on_ordered.0.clone(), fn_count_on_ordered.1.clone(), fn_count_on_ordered.2.clone()), - required_sort: ["nullable_col", true, false, "count", true, false], - @ r#" -Input Plan: -SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], 
file_type=parquet - -Optimized Plan: -BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - - // Case 41: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: (fn_max_on_ordered.0.clone(), fn_max_on_ordered.1.clone(), fn_max_on_ordered.2.clone()), - required_sort: ["max", true, false, "nullable_col", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[max@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[max: Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 42: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: (fn_min_on_ordered.0.clone(), fn_min_on_ordered.1.clone(), fn_min_on_ordered.2.clone()), - required_sort: ["min", false, false, "nullable_col", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 43: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: (fn_avg_on_ordered.0.clone(), fn_avg_on_ordered.1.clone(), fn_avg_on_ordered.2.clone()), - required_sort: ["nullable_col", true, false, "avg", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on unordered column - // Case 44: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: (fn_count_on_unordered.0.clone(), fn_count_on_unordered.1.clone(), fn_count_on_unordered.2.clone()), - required_sort: ["count", true, true], - @ r#" - Input / Optimized 
Plan: - SortExec: expr=[count@2 ASC], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 45: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: (fn_max_on_unordered.0.clone(), fn_max_on_unordered.1.clone(), fn_max_on_unordered.2.clone()), - required_sort: ["nullable_col", true, false, "max", false, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[max: Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 46: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: (fn_min_on_unordered.0.clone(), fn_min_on_unordered.1.clone(), fn_min_on_unordered.2.clone()), - required_sort: ["nullable_col", true, false, "min", false, false], - @ r#" -Input Plan: -SortExec: expr=[nullable_col@0 ASC NULLS LAST, min@2 DESC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - - // Case 47: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: (fn_avg_on_unordered.0.clone(), fn_avg_on_unordered.1.clone(), fn_avg_on_unordered.2.clone()), - required_sort: ["nullable_col", true, false], - @ r#" -Input Plan: -SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC 
NULLS LAST], file_type=parquet -"# - )?; - - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + no partition_by + on ordered column - // Case 48: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: (fn_count_on_ordered.0.clone(), fn_count_on_ordered.1.clone(), fn_count_on_ordered.2.clone()), - required_sort: ["count", true, false, "nullable_col", true, false], - @ r#" -Input Plan: -SortExec: expr=[count@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - - // Case 49: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32)?))), - func: (fn_max_on_ordered.0.clone(), fn_max_on_ordered.1.clone(), fn_max_on_ordered.2.clone()), - required_sort: ["max", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[max@2 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[max: Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 50: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: (fn_min_on_ordered.0.clone(), fn_min_on_ordered.1.clone(), fn_min_on_ordered.2.clone()), - required_sort: ["nullable_col", true, false, "min", false, false], - @ r#" -Input Plan: -SortExec: expr=[nullable_col@0 ASC NULLS LAST, min@2 DESC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], 
file_type=parquet - -Optimized Plan: -BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - - // Case 51: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: (fn_avg_on_ordered.0.clone(), fn_avg_on_ordered.1.clone(), fn_avg_on_ordered.2.clone()), - required_sort: ["avg", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + no partition_by + on unordered column - // Case 52: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32)?))), - func: (fn_count_on_unordered.0.clone(), fn_count_on_unordered.1.clone(), fn_count_on_unordered.2.clone()), - required_sort: ["count", true, false, "nullable_col", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[count@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 53: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: (fn_max_on_unordered.0.clone(), fn_max_on_unordered.1.clone(), fn_max_on_unordered.2.clone()), - required_sort: ["nullable_col", true, false, "max", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[max: Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC 
NULLS LAST], file_type=parquet - "# - )?; - - // Case 54: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: (fn_min_on_unordered.0.clone(), fn_min_on_unordered.1.clone(), fn_min_on_unordered.2.clone()), - required_sort: ["min", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[min@2 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 55: - test_window_case!( - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32)?))), - func: (fn_avg_on_unordered.0.clone(), fn_avg_on_unordered.1.clone(), fn_avg_on_unordered.2.clone()), - required_sort: ["nullable_col", true, false], - @ r#" -Input Plan: -SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + partition_by + on ordered column - // Case 56: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: (fn_count_on_ordered.0.clone(), fn_count_on_ordered.1.clone(), fn_count_on_ordered.2.clone()), - required_sort: ["count", true, false, "nullable_col", true, false], - @ r#" -Input Plan: -SortExec: expr=[count@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: 
-BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - - // Case 57: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32)?))), - func: (fn_max_on_ordered.0.clone(), fn_max_on_ordered.1.clone(), fn_max_on_ordered.2.clone()), - required_sort: ["nullable_col", true, false, "max", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[max: Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 58: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: (fn_min_on_ordered.0.clone(), fn_min_on_ordered.1.clone(), fn_min_on_ordered.2.clone()), - required_sort: ["min", false, false, "nullable_col", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 59: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: (fn_avg_on_ordered.0.clone(), fn_avg_on_ordered.1.clone(), fn_avg_on_ordered.2.clone()), - required_sort: ["avg", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + partition_by + on unordered column - // Case 60: - 
test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: (fn_count_on_unordered.0.clone(), fn_count_on_unordered.1.clone(), fn_count_on_unordered.2.clone()), - required_sort: ["nullable_col", true, false, "count", true, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 61: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: (fn_max_on_unordered.0.clone(), fn_max_on_unordered.1.clone(), fn_max_on_unordered.2.clone()), - required_sort: ["nullable_col", true, false, "max", true, true], - @ r#" - Input / Optimized Plan: - SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[max: Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 62: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: (fn_min_on_unordered.0.clone(), fn_min_on_unordered.1.clone(), fn_min_on_unordered.2.clone()), - required_sort: ["nullable_col", true, false, "min", false, false], - @ r#" - Input / Optimized Plan: - SortExec: expr=[nullable_col@0 ASC NULLS LAST, min@2 DESC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - "# - )?; - - // Case 63: - test_window_case!( - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: (fn_avg_on_unordered.0.clone(), fn_avg_on_unordered.1.clone(), fn_avg_on_unordered.2.clone()), - required_sort: ["nullable_col", true, false], - @ r#" -Input Plan: -SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] - BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], 
output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet - -Optimized Plan: -BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet -"# - )?; - // =============================================REGION ENDS============================================= - - Ok(()) -} #[test] fn test_removes_unused_orthogonal_sort() -> Result<()> { let schema = create_test_schema3()?; diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting_monotonicity.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting_monotonicity.rs new file mode 100644 index 000000000000..ef233e222912 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting_monotonicity.rs @@ -0,0 +1,1715 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
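+
+//! Snapshot tests covering how the `EnforceSorting` physical optimizer rule
+//! interacts with window functions whose outputs are partially constant and/or
+//! set-monotonic (`count`, `max`, `min`, `avg`). Each case builds a
+//! `WindowAggExec` or `BoundedWindowAggExec` on top of a parquet source ordered
+//! on `nullable_col`, places a `SortExec` with the required output ordering
+//! above it, and snapshots the plan before and after optimization with `insta`.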
+
+use crate::physical_optimizer::test_utils::{
+    create_test_schema, parquet_exec_with_sort, sort_exec, sort_expr_options,
+};
+use arrow::datatypes::DataType;
+use arrow_schema::SortOptions;
+use datafusion::common::ScalarValue;
+use datafusion::logical_expr::WindowFrameBound;
+use datafusion::logical_expr::WindowFrameUnits;
+use datafusion_expr::{WindowFrame, WindowFunctionDefinition};
+use datafusion_functions_aggregate::average::avg_udaf;
+use datafusion_functions_aggregate::count::count_udaf;
+use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf};
+use datafusion_physical_expr::expressions::col;
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+use datafusion_physical_expr_common::sort_expr::LexOrdering;
+use datafusion_physical_plan::windows::{
+    create_window_expr, BoundedWindowAggExec, WindowAggExec,
+};
+use datafusion_physical_plan::{ExecutionPlan, InputOrderMode};
+use insta::assert_snapshot;
+use std::sync::{Arc, LazyLock};
+
+// Function definition - Alias of the resulting column - Arguments of the function
+#[derive(Clone)]
+struct WindowFuncParam(
+    WindowFunctionDefinition,
+    &'static str,
+    Vec<Arc<dyn PhysicalExpr>>,
+);
+
+fn function_arg_ordered() -> Vec<Arc<dyn PhysicalExpr>> {
+    let input_schema = create_test_schema().unwrap();
+    vec![col("nullable_col", &input_schema).unwrap()]
+}
+fn function_arg_unordered() -> Vec<Arc<dyn PhysicalExpr>> {
+    let input_schema = create_test_schema().unwrap();
+    vec![col("non_nullable_col", &input_schema).unwrap()]
+}
+
+fn fn_count_on_ordered() -> WindowFuncParam {
+    WindowFuncParam(
+        WindowFunctionDefinition::AggregateUDF(count_udaf()),
+        "count",
+        function_arg_ordered(),
+    )
+}
+
+fn fn_max_on_ordered() -> WindowFuncParam {
+    WindowFuncParam(
+        WindowFunctionDefinition::AggregateUDF(max_udaf()),
+        "max",
+        function_arg_ordered(),
+    )
+}
+
+fn fn_min_on_ordered() -> WindowFuncParam {
+    WindowFuncParam(
+        WindowFunctionDefinition::AggregateUDF(min_udaf()),
+        "min",
+        function_arg_ordered(),
+    )
+}
+
+fn fn_avg_on_ordered() -> WindowFuncParam {
+    WindowFuncParam(
+        WindowFunctionDefinition::AggregateUDF(avg_udaf()),
+        "avg",
+        function_arg_ordered(),
+    )
+}
+
+fn fn_count_on_unordered() -> WindowFuncParam {
+    WindowFuncParam(
+        WindowFunctionDefinition::AggregateUDF(count_udaf()),
+        "count",
+        function_arg_unordered(),
+    )
+}
+
+fn fn_max_on_unordered() -> WindowFuncParam {
+    WindowFuncParam(
+        WindowFunctionDefinition::AggregateUDF(max_udaf()),
+        "max",
+        function_arg_unordered(),
+    )
+}
+fn fn_min_on_unordered() -> WindowFuncParam {
+    WindowFuncParam(
+        WindowFunctionDefinition::AggregateUDF(min_udaf()),
+        "min",
+        function_arg_unordered(),
+    )
+}
+
+fn fn_avg_on_unordered() -> WindowFuncParam {
+    WindowFuncParam(
+        WindowFunctionDefinition::AggregateUDF(avg_udaf()),
+        "avg",
+        function_arg_unordered(),
+    )
+}
+
+struct TestWindowCase {
+    partition_by: bool,
+    window_frame: Arc<WindowFrame>,
+    func: WindowFuncParam,
+    required_sort: Vec<(&'static str, bool, bool)>, // (column name, ascending, nulls_first)
+}
+impl TestWindowCase {
+    fn source() -> Arc<dyn ExecutionPlan> {
+        static SOURCE: LazyLock<Arc<dyn ExecutionPlan>> = LazyLock::new(|| {
+            let input_schema = create_test_schema().unwrap();
+            let ordering = [sort_expr_options(
+                "nullable_col",
+                &input_schema,
+                SortOptions {
+                    descending: false,
+                    nulls_first: false,
+                },
+            )]
+            .into();
+            parquet_exec_with_sort(input_schema.clone(), vec![ordering])
+        });
+        Arc::clone(&SOURCE)
+    }
+
+    // runs the window test case and returns the string representation of the plan
+    fn run(self) -> String {
+        let input_schema = create_test_schema().unwrap();
+        let source =
Self::source();
+
+        let Self {
+            partition_by,
+            window_frame,
+            func: WindowFuncParam(func_def, func_name, func_args),
+            required_sort,
+        } = self;
+        let partition_by_exprs = if partition_by {
+            vec![col("nullable_col", &input_schema).unwrap()]
+        } else {
+            vec![]
+        };
+
+        let window_expr = create_window_expr(
+            &func_def,
+            func_name.to_string(),
+            &func_args,
+            &partition_by_exprs,
+            &[],
+            window_frame,
+            Arc::clone(&input_schema),
+            false,
+            false,
+            None,
+        )
+        .unwrap();
+
+        let window_exec = if window_expr.uses_bounded_memory() {
+            Arc::new(
+                BoundedWindowAggExec::try_new(
+                    vec![window_expr],
+                    Arc::clone(&source),
+                    InputOrderMode::Sorted,
+                    partition_by,
+                )
+                .unwrap(),
+            ) as Arc<dyn ExecutionPlan>
+        } else {
+            Arc::new(
+                WindowAggExec::try_new(
+                    vec![window_expr],
+                    Arc::clone(&source),
+                    partition_by,
+                )
+                .unwrap(),
+            ) as Arc<dyn ExecutionPlan>
+        };
+
+        let output_schema = window_exec.schema();
+        let sort_expr = required_sort.into_iter().map(|(col, asc, nulls_first)| {
+            sort_expr_options(
+                col,
+                &output_schema,
+                SortOptions {
+                    descending: !asc,
+                    nulls_first,
+                },
+            )
+        });
+        let ordering = LexOrdering::new(sort_expr).unwrap();
+        let physical_plan = sort_exec(ordering, window_exec);
+
+        crate::physical_optimizer::enforce_sorting::EnforceSortingTest::new(physical_plan)
+            .with_repartition_sorts(true)
+            .run()
+    }
+}
+#[test]
+fn test_window_partial_constant_and_set_monotonicity_0() {
+    // ============================================REGION STARTS============================================
+    // WindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on ordered column
+    // Case 0:
+    assert_snapshot!(TestWindowCase {
+        partition_by: false,
+        window_frame: Arc::new(WindowFrame::new(None)),
+        func: fn_count_on_ordered(),
+        required_sort: vec![
+            ("nullable_col", true, false),
+            ("count", true, false),
+        ],
+    }.run(),
+    @ r#"
+    Input Plan:
+      SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]
+        WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]
+          DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet
+
+    Optimized Plan:
+      WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]
+        DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet
+    "#
+    );
+}
+
+#[test]
+fn test_window_partial_constant_and_set_monotonicity_1() {
+    assert_snapshot!(TestWindowCase {
+        partition_by: false,
+        window_frame: Arc::new(WindowFrame::new(None)),
+        func: fn_max_on_ordered(),
+        required_sort: vec![
+            ("nullable_col", true, false),
+            ("max", false, false),
+        ],
+    }.run(),
+    @ r#"
+    Input Plan:
+      SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false]
+        WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]
+          DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet
+
+ Optimized Plan: + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +#[test] +fn test_window_partial_constant_and_set_monotonicity_2() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_min_on_ordered(), + required_sort: vec![ + ("min", false, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +#[test] +fn test_window_partial_constant_and_set_monotonicity_3() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_avg_on_ordered(), + required_sort: vec![ + ("avg", true, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +#[test] +fn test_window_partial_constant_and_set_monotonicity_4() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_count_on_unordered(), + required_sort: vec![ + ("non_nullable_col", true, false), + ("count", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, 
projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +#[test] +fn test_window_partial_constant_and_set_monotonicity_5() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_max_on_unordered(), + required_sort: vec![ + ("non_nullable_col", false, false), + ("max", false, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[non_nullable_col@1 DESC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + SortExec: expr=[non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +#[test] +fn test_window_partial_constant_and_set_monotonicity_6() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_min_on_unordered(), + required_sort: vec![ + ("min", true, false), + ("non_nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[min@2 ASC NULLS LAST, non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +#[test] +fn test_window_partial_constant_and_set_monotonicity_7() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_avg_on_unordered(), + required_sort: vec![ + ("avg", false, false), + ("nullable_col", false, false), + ], + }.run(), + @ r#" + Input Plan: + 
SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ + +#[test] +fn test_window_partial_constant_and_set_monotonicity_8() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_count_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("count", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +#[test] +fn test_window_partial_constant_and_set_monotonicity_9() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_max_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("max", false, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: 
file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +#[test] +fn test_window_partial_constant_and_set_monotonicity_10() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_min_on_ordered(), + required_sort: vec![ + ("min", false, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +#[test] +fn test_window_partial_constant_and_set_monotonicity_11() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_avg_on_ordered(), + required_sort: vec![ + ("avg", true, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// WindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on unordered column +// Case 12: +#[test] +fn test_window_partial_constant_and_set_monotonicity_12() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_count_on_unordered(), + required_sort: vec![ + ("non_nullable_col", true, false), + ("count", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 13: +#[test] +fn test_window_partial_constant_and_set_monotonicity_13() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_max_on_unordered(), + required_sort: vec![ + ("non_nullable_col", true, false), + ("max", false, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false] 
+ WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 14: +#[test] +fn test_window_partial_constant_and_set_monotonicity_14() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_min_on_unordered(), + required_sort: vec![ + ("min", false, false), + ("non_nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC NULLS LAST, non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 15: +#[test] +fn test_window_partial_constant_and_set_monotonicity_15() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_avg_on_unordered(), + required_sort: vec![ + ("avg", true, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// WindowAggExec + Sliding(current row, unbounded following) + no partition_by + on ordered column +// Case 16: +#[test] +fn test_window_partial_constant_and_set_monotonicity_16() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_count_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("count", false, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + 
DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 17: +#[test] +fn test_window_partial_constant_and_set_monotonicity_17() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_max_on_ordered(), + required_sort: vec![ + ("max", false, true), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[max@2 DESC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 18: +#[test] +fn test_window_partial_constant_and_set_monotonicity_18() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_min_on_ordered(), + required_sort: vec![ + ("min", true, true), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[min@2 ASC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 19: +#[test] +fn test_window_partial_constant_and_set_monotonicity_19() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_avg_on_ordered(), + required_sort: vec![ + ("avg", false, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = 
= = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// WindowAggExec + Sliding(current row, unbounded following) + no partition_by + on unordered column +// Case 20: +#[test] +fn test_window_partial_constant_and_set_monotonicity_20() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_count_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("count", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 21: +#[test] +fn test_window_partial_constant_and_set_monotonicity_21() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_max_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("max", false, true), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 22: +#[test] +fn test_window_partial_constant_and_set_monotonicity_22() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_min_on_unordered(), + required_sort: vec![ + ("min", true, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 23: +#[test] +fn test_window_partial_constant_and_set_monotonicity_23() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_avg_on_unordered(), + required_sort: vec![ + ("avg", false, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + 
Input / Optimized Plan: + SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// WindowAggExec + Sliding(current row, unbounded following) + partition_by + on ordered column +// Case 24: +#[test] +fn test_window_partial_constant_and_set_monotonicity_24() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_count_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("count", false, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 25: +#[test] +fn test_window_partial_constant_and_set_monotonicity_25() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_max_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("max", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 26: +#[test] +fn test_window_partial_constant_and_set_monotonicity_26() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_min_on_ordered(), + required_sort: vec![ + ("min", false, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: 
Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "#); +} + +// Case 27: +#[test] +fn test_window_partial_constant_and_set_monotonicity_27() { + assert_snapshot!( + TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_avg_on_ordered(), + required_sort: vec![ + ("avg", false, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "#); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// WindowAggExec + Sliding(current row, unbounded following) + partition_by + on unordered column + +// Case 28: +#[test] +fn test_window_partial_constant_and_set_monotonicity_28() { + assert_snapshot!( + TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_count_on_unordered(), + required_sort: vec![ + ("count", false, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[count@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 29: +#[test] +fn test_window_partial_constant_and_set_monotonicity_29() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_max_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("max", false, true), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "#) +} + +// Case 30: +#[test] +fn test_window_partial_constant_and_set_monotonicity_30() 
{ + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_min_on_unordered(), + required_sort: vec![ + ("min", false, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "#); +} + +// Case 31: +#[test] +fn test_window_partial_constant_and_set_monotonicity_31() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_avg_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("avg", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on ordered column + +// Case 32: +#[test] +fn test_window_partial_constant_and_set_monotonicity_32() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_count_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("count", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 33: +#[test] +fn test_window_partial_constant_and_set_monotonicity_33() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_max_on_ordered(), + required_sort: vec![ + ("max", false, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[max@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: 
wdw=[max: Field { "max": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 34: +#[test] +fn test_window_partial_constant_and_set_monotonicity_34() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_min_on_ordered(), + required_sort: vec![ + ("min", false, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} +// Case 35: +#[test] +fn test_window_partial_constant_and_set_monotonicity_35() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_avg_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("avg", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on unordered column + +// Case 36: +#[test] +fn test_window_partial_constant_and_set_monotonicity_36() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_count_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("count", true, true), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, 
non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 37: +#[test] +fn test_window_partial_constant_and_set_monotonicity_37() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_max_on_unordered(), + required_sort: vec![ + ("max", true, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[max@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { "max": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[max: Field { "max": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 38: +#[test] +fn test_window_partial_constant_and_set_monotonicity_38() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_min_on_unordered(), + required_sort: vec![ + ("min", false, true), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 39: +#[test] +fn test_window_partial_constant_and_set_monotonicity_39() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_avg_on_unordered(), + required_sort: vec![ + ("avg", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on ordered column + +// Case 40: +#[test] +fn test_window_partial_constant_and_set_monotonicity_40() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_count_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("count", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + 
BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 41: +#[test] +fn test_window_partial_constant_and_set_monotonicity_41() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_max_on_ordered(), + required_sort: vec![ + ("max", true, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[max@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { "max": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 42: +#[test] +fn test_window_partial_constant_and_set_monotonicity_42() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_min_on_ordered(), + required_sort: vec![ + ("min", false, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 43: +#[test] +fn test_window_partial_constant_and_set_monotonicity_43() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_avg_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("avg", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on unordered column + +// Case 44: +#[test] +fn test_window_partial_constant_and_set_monotonicity_44() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: 
Arc::new(WindowFrame::new(Some(true))), + func: fn_count_on_unordered(), + required_sort: vec![ + ("count", true, true), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[count@2 ASC], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 45: +#[test] +fn test_window_partial_constant_and_set_monotonicity_45() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_max_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("max", false, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { "max": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 46: +#[test] +fn test_window_partial_constant_and_set_monotonicity_46() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_min_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("min", false, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, min@2 DESC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 47: +#[test] +fn test_window_partial_constant_and_set_monotonicity_47() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_avg_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = 
= = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + no partition_by + on ordered column + +// Case 48: +#[test] +fn test_window_partial_constant_and_set_monotonicity_48() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_count_on_ordered(), + required_sort: vec![ + ("count", true, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[count@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 49: +#[test] +fn test_window_partial_constant_and_set_monotonicity_49() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32).unwrap()))), + func: fn_max_on_ordered(), + required_sort: vec![ + ("max", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[max@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { "max": nullable Int32 }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 50: +#[test] +fn test_window_partial_constant_and_set_monotonicity_50() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_min_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("min", false, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, min@2 DESC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], 
file_type=parquet + "# + ); +} + +// Case 51: +#[test] +fn test_window_partial_constant_and_set_monotonicity_51() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_avg_on_ordered(), + required_sort: vec![ + ("avg", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + no partition_by + on unordered column + +// Case 52: +#[test] +fn test_window_partial_constant_and_set_monotonicity_52() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32).unwrap()))), + func: fn_count_on_unordered(), + required_sort: vec![ + ("count", true, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[count@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 53: +#[test] +fn test_window_partial_constant_and_set_monotonicity_53() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_max_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("max", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { "max": nullable Int32 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 54: +#[test] +fn test_window_partial_constant_and_set_monotonicity_54() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_min_on_unordered(), + required_sort: vec![ + ("min", true, false), + ], + 
}.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 55: +#[test] +fn test_window_partial_constant_and_set_monotonicity_55() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32).unwrap()))), + func: fn_avg_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + partition_by + on ordered column + +// Case 56: +#[test] +fn test_window_partial_constant_and_set_monotonicity_56() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_count_on_ordered(), + required_sort: vec![ + ("count", true, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[count@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 57: +#[test] +fn test_window_partial_constant_and_set_monotonicity_57() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), 
WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32).unwrap()))), + func: fn_max_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("max", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { "max": nullable Int32 }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 58: +#[test] +fn test_window_partial_constant_and_set_monotonicity_58() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_min_on_ordered(), + required_sort: vec![ + ("min", false, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 59: +#[test] +fn test_window_partial_constant_and_set_monotonicity_59() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_avg_on_ordered(), + required_sort: vec![ + ("avg", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + partition_by + on unordered column + +// Case 60: +#[test] +fn test_window_partial_constant_and_set_monotonicity_60() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_count_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("count", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 
group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 61: +#[test] +fn test_window_partial_constant_and_set_monotonicity_61() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_max_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("max", true, true), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { "max": nullable Int32 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 62: +#[test] +fn test_window_partial_constant_and_set_monotonicity_62() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_min_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("min", false, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, min@2 DESC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 63: +#[test] +fn test_window_partial_constant_and_set_monotonicity_63() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_avg_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} +// =============================================REGION ENDS============================================= diff --git a/datafusion/core/tests/physical_optimizer/mod.rs b/datafusion/core/tests/physical_optimizer/mod.rs index 777c26e80e90..936c02eb2a02 100644 --- a/datafusion/core/tests/physical_optimizer/mod.rs +++ b/datafusion/core/tests/physical_optimizer/mod.rs @@ -21,6 +21,7 @@ mod aggregate_statistics; mod combine_partial_final_agg; mod enforce_distribution; mod 
enforce_sorting; +mod enforce_sorting_monotonicity; mod filter_pushdown; mod join_selection; mod limit_pushdown; diff --git a/datafusion/core/tests/physical_optimizer/partition_statistics.rs b/datafusion/core/tests/physical_optimizer/partition_statistics.rs index 62ab5cbc422b..49dc5b845605 100644 --- a/datafusion/core/tests/physical_optimizer/partition_statistics.rs +++ b/datafusion/core/tests/physical_optimizer/partition_statistics.rs @@ -17,6 +17,7 @@ #[cfg(test)] mod test { + use insta::assert_snapshot; use std::sync::Arc; use arrow::array::{Int32Array, RecordBatch}; @@ -606,21 +607,21 @@ mod test { .build() .map(Arc::new)?]; - let aggregate_exec_partial = Arc::new(AggregateExec::try_new( - AggregateMode::Partial, - group_by.clone(), - aggr_expr.clone(), - vec![None], - Arc::clone(&scan), - scan_schema.clone(), - )?) as _; - - let mut plan_string = get_plan_string(&aggregate_exec_partial); - let _ = plan_string.swap_remove(1); - let expected_plan = vec![ - "AggregateExec: mode=Partial, gby=[id@0 as id, 1 + id@0 as expr], aggr=[COUNT(c)]", - ]; - assert_eq!(plan_string, expected_plan); + let aggregate_exec_partial: Arc<dyn ExecutionPlan> = + Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggr_expr.clone(), + vec![None], + Arc::clone(&scan), + scan_schema.clone(), + )?) as _; + + let plan_string = get_plan_string(&aggregate_exec_partial).swap_remove(0); + assert_snapshot!( + plan_string, + @"AggregateExec: mode=Partial, gby=[id@0 as id, 1 + id@0 as expr], aggr=[COUNT(c)]" + ); let p0_statistics = aggregate_exec_partial.partition_statistics(Some(0))?; @@ -710,7 +711,10 @@ mod test { )?) as _; let agg_plan = get_plan_string(&agg_partial).remove(0); - assert_eq!("AggregateExec: mode=Partial, gby=[id@0 as id, 1 + id@0 as expr], aggr=[COUNT(c)]",agg_plan); + assert_snapshot!( + agg_plan, + @"AggregateExec: mode=Partial, gby=[id@0 as id, 1 + id@0 as expr], aggr=[COUNT(c)]" + ); let empty_stat = Statistics { num_rows: Precision::Exact(0), diff --git a/datafusion/core/tests/physical_optimizer/sanity_checker.rs b/datafusion/core/tests/physical_optimizer/sanity_checker.rs index ce6eb13c86c4..9867ed173341 100644 --- a/datafusion/core/tests/physical_optimizer/sanity_checker.rs +++ b/datafusion/core/tests/physical_optimizer/sanity_checker.rs @@ -421,7 +421,7 @@ async fn test_bounded_window_agg_sort_requirement() -> Result<()> { assert_snapshot!( actual, @r#" - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortExec: expr=[c9@0 ASC NULLS LAST], preserve_partitioning=[false] DataSourceExec: partitions=1, partition_sizes=[0] "# @@ -449,7 +449,7 @@ async fn test_bounded_window_agg_no_sort_requirement() -> Result<()> { assert_snapshot!( actual, @r#" - BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] DataSourceExec: partitions=1, partition_sizes=[0] "# ); diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index e082cabaadaf..43f79ead0257
100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -22,6 +22,7 @@ use rstest::rstest; use datafusion::config::ConfigOptions; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::metrics::Timestamp; +use datafusion_common::format::ExplainAnalyzeLevel; use object_store::path::Path; #[tokio::test] @@ -62,36 +63,59 @@ async fn explain_analyze_baseline_metrics() { "AggregateExec: mode=Partial, gby=[]", "metrics=[output_rows=3, elapsed_compute=" ); + assert_metrics!( + &formatted, + "AggregateExec: mode=Partial, gby=[]", + "output_bytes=" + ); assert_metrics!( &formatted, "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1]", "metrics=[output_rows=5, elapsed_compute=" ); + assert_metrics!( + &formatted, + "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1]", + "output_bytes=" + ); assert_metrics!( &formatted, "FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434", "metrics=[output_rows=99, elapsed_compute=" ); + assert_metrics!( + &formatted, + "FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434", + "output_bytes=" + ); assert_metrics!( &formatted, "ProjectionExec: expr=[]", "metrics=[output_rows=5, elapsed_compute=" ); + assert_metrics!(&formatted, "ProjectionExec: expr=[]", "output_bytes="); assert_metrics!( &formatted, "CoalesceBatchesExec: target_batch_size=4096", "metrics=[output_rows=5, elapsed_compute" ); + assert_metrics!( + &formatted, + "CoalesceBatchesExec: target_batch_size=4096", + "output_bytes=" + ); assert_metrics!( &formatted, "UnionExec", "metrics=[output_rows=3, elapsed_compute=" ); + assert_metrics!(&formatted, "UnionExec", "output_bytes="); assert_metrics!( &formatted, "WindowAggExec", "metrics=[output_rows=1, elapsed_compute=" ); + assert_metrics!(&formatted, "WindowAggExec", "output_bytes="); fn expected_to_have_metrics(plan: &dyn ExecutionPlan) -> bool { use datafusion::physical_plan; @@ -158,6 +182,81 @@ async fn explain_analyze_baseline_metrics() { fn nanos_from_timestamp(ts: &Timestamp) -> i64 { ts.value().unwrap().timestamp_nanos_opt().unwrap() } + +// Test different detail level for config `datafusion.explain.analyze_level` + +async fn collect_plan_with_context( + sql_str: &str, + ctx: &SessionContext, + level: ExplainAnalyzeLevel, +) -> String { + { + let state = ctx.state_ref(); + let mut state = state.write(); + state.config_mut().options_mut().explain.analyze_level = level; + } + let dataframe = ctx.sql(sql_str).await.unwrap(); + let batches = dataframe.collect().await.unwrap(); + arrow::util::pretty::pretty_format_batches(&batches) + .unwrap() + .to_string() +} + +async fn collect_plan(sql_str: &str, level: ExplainAnalyzeLevel) -> String { + let ctx = SessionContext::new(); + collect_plan_with_context(sql_str, &ctx, level).await +} + +#[tokio::test] +async fn explain_analyze_level() { + let sql = "EXPLAIN ANALYZE \ + SELECT * \ + FROM generate_series(10) as t1(v1) \ + ORDER BY v1 DESC"; + + for (level, needle, should_contain) in [ + (ExplainAnalyzeLevel::Summary, "spill_count", false), + (ExplainAnalyzeLevel::Summary, "output_rows", true), + (ExplainAnalyzeLevel::Dev, "spill_count", true), + (ExplainAnalyzeLevel::Dev, "output_rows", true), + ] { + let plan = collect_plan(sql, level).await; + assert_eq!( + plan.contains(needle), + should_contain, + "plan for level {level:?} unexpected content: {plan}" + ); + } +} + +#[tokio::test] +async fn explain_analyze_level_datasource_parquet() { + let table_name = "tpch_lineitem_small"; + let parquet_path = 
"tests/data/tpch_lineitem_small.parquet"; + let sql = format!("EXPLAIN ANALYZE SELECT * FROM {table_name}"); + + // Register test parquet file into context + let ctx = SessionContext::new(); + ctx.register_parquet(table_name, parquet_path, ParquetReadOptions::default()) + .await + .expect("register parquet table for explain analyze test"); + + for (level, needle, should_contain) in [ + (ExplainAnalyzeLevel::Summary, "metadata_load_time", true), + (ExplainAnalyzeLevel::Summary, "page_index_eval_time", false), + (ExplainAnalyzeLevel::Dev, "metadata_load_time", true), + (ExplainAnalyzeLevel::Dev, "page_index_eval_time", true), + ] { + let plan = collect_plan_with_context(&sql, &ctx, level).await; + + assert_eq!( + plan.contains(needle), + should_contain, + "plan for level {level:?} unexpected content: {plan}" + ); + } +} + #[tokio::test] async fn csv_explain_plans() { // This test verify the look of each plan in its full cycle plan creation diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 98c3e3ccee8a..8a0f62062738 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -15,8 +15,11 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashMap; + use super::*; -use datafusion_common::ScalarValue; +use datafusion::assert_batches_eq; +use datafusion_common::{metadata::ScalarAndMetadata, ParamValues, ScalarValue}; use insta::assert_snapshot; #[tokio::test] @@ -219,11 +222,11 @@ async fn test_parameter_invalid_types() -> Result<()> { .collect() .await; assert_snapshot!(results.unwrap_err().strip_backtrace(), - @r#" - type_coercion - caused by - Error during planning: Cannot infer common argument type for comparison operation List(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) = Int32 - "#); + @r" + type_coercion + caused by + Error during planning: Cannot infer common argument type for comparison operation List(nullable Int32) = Int32 + "); Ok(()) } @@ -317,6 +320,53 @@ async fn test_named_parameter_not_bound() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_query_parameters_with_metadata() -> Result<()> { + let ctx = SessionContext::new(); + + let df = ctx.sql("SELECT $1, $2").await.unwrap(); + + let metadata1 = HashMap::from([("some_key".to_string(), "some_value".to_string())]); + let metadata2 = + HashMap::from([("some_other_key".to_string(), "some_other_value".to_string())]); + + let df_with_params_replaced = df + .with_param_values(ParamValues::List(vec![ + ScalarAndMetadata::new( + ScalarValue::UInt32(Some(1)), + Some(metadata1.clone().into()), + ), + ScalarAndMetadata::new( + ScalarValue::Utf8(Some("two".to_string())), + Some(metadata2.clone().into()), + ), + ])) + .unwrap(); + + // df_with_params_replaced.schema() is not correct here + // https://github.com/apache/datafusion/issues/18102 + let batches = df_with_params_replaced.clone().collect().await.unwrap(); + let schema = batches[0].schema(); + + assert_eq!(schema.field(0).data_type(), &DataType::UInt32); + assert_eq!(schema.field(0).metadata(), &metadata1); + assert_eq!(schema.field(1).data_type(), &DataType::Utf8); + assert_eq!(schema.field(1).metadata(), &metadata2); + + assert_batches_eq!( + [ + "+----+-----+", + "| $1 | $2 |", + "+----+-----+", + "| 1 | two |", + "+----+-----+", + ], + &batches + ); + + Ok(()) +} + #[tokio::test] async fn test_version_function() { let expected_version = format!( diff --git 
a/datafusion/core/tests/user_defined/insert_operation.rs b/datafusion/core/tests/user_defined/insert_operation.rs index c8a4279a4211..e0a3e98604ae 100644 --- a/datafusion/core/tests/user_defined/insert_operation.rs +++ b/datafusion/core/tests/user_defined/insert_operation.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::{any::Any, sync::Arc}; +use std::{any::Any, str::FromStr, sync::Arc}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; @@ -24,6 +24,7 @@ use datafusion::{ prelude::{SessionConfig, SessionContext}, }; use datafusion_catalog::{Session, TableProvider}; +use datafusion_common::config::Dialect; use datafusion_expr::{dml::InsertOp, Expr, TableType}; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; use datafusion_physical_plan::execution_plan::SchedulingType; @@ -63,7 +64,7 @@ async fn assert_insert_op(ctx: &SessionContext, sql: &str, insert_op: InsertOp) fn session_ctx_with_dialect(dialect: impl Into<String>) -> SessionContext { let mut config = SessionConfig::new(); let options = config.options_mut(); - options.sql_parser.dialect = dialect.into(); + options.sql_parser.dialect = Dialect::from_str(&dialect.into()).unwrap(); SessionContext::new_with_config(config) } diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index f1af66de9b59..fb1371da6ceb 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -34,13 +34,13 @@ use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionS use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; use datafusion_common::cast::{as_float64_array, as_int32_array}; +use datafusion_common::metadata::FieldMetadata; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::utils::take_function_args; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_datafusion_err, exec_err, not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::expr::FieldMetadata; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ lit_with_metadata, Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, diff --git a/datafusion/datasource-arrow/Cargo.toml b/datafusion/datasource-arrow/Cargo.toml new file mode 100644 index 000000000000..b3d1e3f2accc --- /dev/null +++ b/datafusion/datasource-arrow/Cargo.toml @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+ +[package] +name = "datafusion-datasource-arrow" +description = "datafusion-datasource-arrow" +readme = "README.md" +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true +repository.workspace = true +rust-version.workspace = true +version.workspace = true + +[package.metadata.docs.rs] +all-features = true + +[dependencies] +arrow = { workspace = true } +arrow-ipc = { workspace = true } +async-trait = { workspace = true } +bytes = { workspace = true } +datafusion-common = { workspace = true, features = ["object_store"] } +datafusion-common-runtime = { workspace = true } +datafusion-datasource = { workspace = true } +datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-physical-expr-common = { workspace = true } +datafusion-physical-plan = { workspace = true } +datafusion-session = { workspace = true } +futures = { workspace = true } +itertools = { workspace = true } +object_store = { workspace = true } +tokio = { workspace = true } + +[dev-dependencies] +chrono = { workspace = true } + +[lints] +workspace = true + +[lib] +name = "datafusion_datasource_arrow" +path = "src/mod.rs" + +[features] +compression = [ + "arrow-ipc/zstd", +] diff --git a/datafusion/datasource-arrow/LICENSE.txt b/datafusion/datasource-arrow/LICENSE.txt new file mode 100644 index 000000000000..d74c6b599d2a --- /dev/null +++ b/datafusion/datasource-arrow/LICENSE.txt @@ -0,0 +1,212 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +This project includes code from Apache Aurora. + +* dev/release/{release,changelog,release-candidate} are based on the scripts from + Apache Aurora + +Copyright: 2016 The Apache Software Foundation. +Home page: https://aurora.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 diff --git a/datafusion/datasource-arrow/NOTICE.txt b/datafusion/datasource-arrow/NOTICE.txt new file mode 100644 index 000000000000..7f3c80d606c0 --- /dev/null +++ b/datafusion/datasource-arrow/NOTICE.txt @@ -0,0 +1,5 @@ +Apache DataFusion +Copyright 2019-2025 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). 
diff --git a/datafusion/datasource-arrow/README.md b/datafusion/datasource-arrow/README.md new file mode 100644 index 000000000000..9901b52105dd --- /dev/null +++ b/datafusion/datasource-arrow/README.md @@ -0,0 +1,34 @@ + + +# Apache DataFusion Arrow DataSource + +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. + +This crate is a submodule of DataFusion that defines a Arrow based file source. +It works with files following the [Arrow IPC format]. + +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion +[arrow ipc format]: https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format diff --git a/datafusion/datasource-arrow/src/file_format.rs b/datafusion/datasource-arrow/src/file_format.rs new file mode 100644 index 000000000000..3b8564080421 --- /dev/null +++ b/datafusion/datasource-arrow/src/file_format.rs @@ -0,0 +1,603 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ArrowFormat`]: Apache Arrow [`FileFormat`] abstractions +//! +//! 
Works with files following the [Arrow IPC format](https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format) + +use std::any::Any; +use std::borrow::Cow; +use std::collections::HashMap; +use std::fmt::{self, Debug}; +use std::sync::Arc; + +use arrow::datatypes::{Schema, SchemaRef}; +use arrow::error::ArrowError; +use arrow::ipc::convert::fb_to_schema; +use arrow::ipc::reader::FileReader; +use arrow::ipc::writer::IpcWriteOptions; +use arrow::ipc::{root_as_message, CompressionType}; +use datafusion_common::error::Result; +use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::{ + internal_datafusion_err, not_impl_err, DataFusionError, GetExt, Statistics, + DEFAULT_ARROW_EXTENSION, +}; +use datafusion_common_runtime::{JoinSet, SpawnedTask}; +use datafusion_datasource::display::FileGroupDisplay; +use datafusion_datasource::file::FileSource; +use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; +use datafusion_datasource::sink::{DataSink, DataSinkExec}; +use datafusion_datasource::write::{ + get_writer_schema, ObjectWriterBuilder, SharedBuffer, +}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_expr::dml::InsertOp; +use datafusion_physical_expr_common::sort_expr::LexRequirement; + +use crate::source::ArrowSource; +use async_trait::async_trait; +use bytes::Bytes; +use datafusion_datasource::file_compression_type::FileCompressionType; +use datafusion_datasource::file_format::{FileFormat, FileFormatFactory}; +use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig}; +use datafusion_datasource::source::DataSourceExec; +use datafusion_datasource::write::demux::DemuxedStreamReceiver; +use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; +use datafusion_session::Session; +use futures::stream::BoxStream; +use futures::StreamExt; +use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; +use tokio::io::AsyncWriteExt; + +/// Initial writing buffer size. Note this is just a size hint for efficiency. It +/// will grow beyond the set value if needed. +const INITIAL_BUFFER_BYTES: usize = 1048576; + +/// If the buffered Arrow data exceeds this size, it is flushed to object store +const BUFFER_FLUSH_BYTES: usize = 1024000; + +#[derive(Default, Debug)] +/// Factory struct used to create [ArrowFormat] +pub struct ArrowFormatFactory; + +impl ArrowFormatFactory { + /// Creates an instance of [ArrowFormatFactory] + pub fn new() -> Self { + Self {} + } +} + +impl FileFormatFactory for ArrowFormatFactory { + fn create( + &self, + _state: &dyn Session, + _format_options: &HashMap, + ) -> Result> { + Ok(Arc::new(ArrowFormat)) + } + + fn default(&self) -> Arc { + Arc::new(ArrowFormat) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +impl GetExt for ArrowFormatFactory { + fn get_ext(&self) -> String { + // Removes the dot, i.e. ".parquet" -> "parquet" + DEFAULT_ARROW_EXTENSION[1..].to_string() + } +} + +/// Arrow `FileFormat` implementation. 
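+/// 
+/// A sketch of plugging this format into a listing table (shown as an
+/// `ignore` doc example because it assumes the `ListingOptions` API from the
+/// higher-level `datafusion` crate, which is not a dependency of this crate):
+/// 
+/// ```ignore
+/// use std::sync::Arc;
+/// use datafusion::datasource::listing::ListingOptions;
+///
+/// // Use ArrowFormat for files ending in ".arrow"
+/// let listing_options = ListingOptions::new(Arc::new(ArrowFormat))
+///     .with_file_extension(".arrow");
+/// ```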
+#[derive(Default, Debug)] +pub struct ArrowFormat; + +#[async_trait] +impl FileFormat for ArrowFormat { + fn as_any(&self) -> &dyn Any { + self + } + + fn get_ext(&self) -> String { + ArrowFormatFactory::new().get_ext() + } + + fn get_ext_with_compression( + &self, + file_compression_type: &FileCompressionType, + ) -> Result { + let ext = self.get_ext(); + match file_compression_type.get_variant() { + CompressionTypeVariant::UNCOMPRESSED => Ok(ext), + _ => Err(internal_datafusion_err!( + "Arrow FileFormat does not support compression." + )), + } + } + + fn compression_type(&self) -> Option { + None + } + + async fn infer_schema( + &self, + _state: &dyn Session, + store: &Arc, + objects: &[ObjectMeta], + ) -> Result { + let mut schemas = vec![]; + for object in objects { + let r = store.as_ref().get(&object.location).await?; + let schema = match r.payload { + #[cfg(not(target_arch = "wasm32"))] + GetResultPayload::File(mut file, _) => { + let reader = FileReader::try_new(&mut file, None)?; + reader.schema() + } + GetResultPayload::Stream(stream) => { + infer_schema_from_file_stream(stream).await? + } + }; + schemas.push(schema.as_ref().clone()); + } + let merged_schema = Schema::try_merge(schemas)?; + Ok(Arc::new(merged_schema)) + } + + async fn infer_stats( + &self, + _state: &dyn Session, + _store: &Arc, + table_schema: SchemaRef, + _object: &ObjectMeta, + ) -> Result { + Ok(Statistics::new_unknown(&table_schema)) + } + + async fn create_physical_plan( + &self, + _state: &dyn Session, + conf: FileScanConfig, + ) -> Result> { + let source = Arc::new(ArrowSource::default()); + let config = FileScanConfigBuilder::from(conf) + .with_source(source) + .build(); + + Ok(DataSourceExec::from_data_source(config)) + } + + async fn create_writer_physical_plan( + &self, + input: Arc, + _state: &dyn Session, + conf: FileSinkConfig, + order_requirements: Option, + ) -> Result> { + if conf.insert_op != InsertOp::Append { + return not_impl_err!("Overwrites are not implemented yet for Arrow format"); + } + + let sink = Arc::new(ArrowFileSink::new(conf)); + + Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _) + } + + fn file_source(&self) -> Arc { + Arc::new(ArrowSource::default()) + } +} + +/// Implements [`FileSink`] for writing to arrow_ipc files +struct ArrowFileSink { + config: FileSinkConfig, +} + +impl ArrowFileSink { + fn new(config: FileSinkConfig) -> Self { + Self { config } + } +} + +#[async_trait] +impl FileSink for ArrowFileSink { + fn config(&self) -> &FileSinkConfig { + &self.config + } + + async fn spawn_writer_tasks_and_join( + &self, + context: &Arc, + demux_task: SpawnedTask>, + mut file_stream_rx: DemuxedStreamReceiver, + object_store: Arc, + ) -> Result { + let mut file_write_tasks: JoinSet> = + JoinSet::new(); + + let ipc_options = + IpcWriteOptions::try_new(64, false, arrow_ipc::MetadataVersion::V5)? 
+ .try_with_compression(Some(CompressionType::LZ4_FRAME))?; + while let Some((path, mut rx)) = file_stream_rx.recv().await { + let shared_buffer = SharedBuffer::new(INITIAL_BUFFER_BYTES); + let mut arrow_writer = arrow_ipc::writer::FileWriter::try_new_with_options( + shared_buffer.clone(), + &get_writer_schema(&self.config), + ipc_options.clone(), + )?; + let mut object_store_writer = ObjectWriterBuilder::new( + FileCompressionType::UNCOMPRESSED, + &path, + Arc::clone(&object_store), + ) + .with_buffer_size(Some( + context + .session_config() + .options() + .execution + .objectstore_writer_buffer_size, + )) + .build()?; + file_write_tasks.spawn(async move { + let mut row_count = 0; + while let Some(batch) = rx.recv().await { + row_count += batch.num_rows(); + arrow_writer.write(&batch)?; + let mut buff_to_flush = shared_buffer.buffer.try_lock().unwrap(); + if buff_to_flush.len() > BUFFER_FLUSH_BYTES { + object_store_writer + .write_all(buff_to_flush.as_slice()) + .await?; + buff_to_flush.clear(); + } + } + arrow_writer.finish()?; + let final_buff = shared_buffer.buffer.try_lock().unwrap(); + + object_store_writer.write_all(final_buff.as_slice()).await?; + object_store_writer.shutdown().await?; + Ok(row_count) + }); + } + + let mut row_count = 0; + while let Some(result) = file_write_tasks.join_next().await { + match result { + Ok(r) => { + row_count += r?; + } + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + } + + demux_task + .join_unwind() + .await + .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; + Ok(row_count as u64) + } +} + +impl Debug for ArrowFileSink { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ArrowFileSink").finish() + } +} + +impl DisplayAs for ArrowFileSink { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "ArrowFileSink(file_groups=",)?; + FileGroupDisplay(&self.config.file_group).fmt_as(t, f)?; + write!(f, ")") + } + DisplayFormatType::TreeRender => { + writeln!(f, "format: arrow")?; + write!(f, "file={}", &self.config.original_url) + } + } + } +} + +#[async_trait] +impl DataSink for ArrowFileSink { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> &SchemaRef { + self.config.output_schema() + } + + async fn write_all( + &self, + data: SendableRecordBatchStream, + context: &Arc, + ) -> Result { + FileSink::write_all(self, data, context).await + } +} + +const ARROW_MAGIC: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1']; +const CONTINUATION_MARKER: [u8; 4] = [0xff; 4]; + +/// Custom implementation of inferring schema. Should eventually be moved upstream to arrow-rs. +/// See +async fn infer_schema_from_file_stream( + mut stream: BoxStream<'static, object_store::Result>, +) -> Result { + // Expected format: + // - 6 bytes + // - 2 bytes + // - 4 bytes, not present below v0.15.0 + // - 4 bytes + // + // + + // So in first read we need at least all known sized sections, + // which is 6 + 2 + 4 + 4 = 16 bytes. 
+ let bytes = collect_at_least_n_bytes(&mut stream, 16, None).await?; + + // Files should start with these magic bytes + if bytes[0..6] != ARROW_MAGIC { + return Err(ArrowError::ParseError( + "Arrow file does not contain correct header".to_string(), + ))?; + } + + // Since continuation marker bytes added in later versions + let (meta_len, rest_of_bytes_start_index) = if bytes[8..12] == CONTINUATION_MARKER { + (&bytes[12..16], 16) + } else { + (&bytes[8..12], 12) + }; + + let meta_len = [meta_len[0], meta_len[1], meta_len[2], meta_len[3]]; + let meta_len = i32::from_le_bytes(meta_len); + + // Read bytes for Schema message + let block_data = if bytes[rest_of_bytes_start_index..].len() < meta_len as usize { + // Need to read more bytes to decode Message + let mut block_data = Vec::with_capacity(meta_len as usize); + // In case we had some spare bytes in our initial read chunk + block_data.extend_from_slice(&bytes[rest_of_bytes_start_index..]); + let size_to_read = meta_len as usize - block_data.len(); + let block_data = + collect_at_least_n_bytes(&mut stream, size_to_read, Some(block_data)).await?; + Cow::Owned(block_data) + } else { + // Already have the bytes we need + let end_index = meta_len as usize + rest_of_bytes_start_index; + let block_data = &bytes[rest_of_bytes_start_index..end_index]; + Cow::Borrowed(block_data) + }; + + // Decode Schema message + let message = root_as_message(&block_data).map_err(|err| { + ArrowError::ParseError(format!("Unable to read IPC message as metadata: {err:?}")) + })?; + let ipc_schema = message.header_as_schema().ok_or_else(|| { + ArrowError::IpcError("Unable to read IPC message as schema".to_string()) + })?; + let schema = fb_to_schema(ipc_schema); + + Ok(Arc::new(schema)) +} + +async fn collect_at_least_n_bytes( + stream: &mut BoxStream<'static, object_store::Result>, + n: usize, + extend_from: Option>, +) -> Result> { + let mut buf = extend_from.unwrap_or_else(|| Vec::with_capacity(n)); + // If extending existing buffer then ensure we read n additional bytes + let n = n + buf.len(); + while let Some(bytes) = stream.next().await.transpose()? 
{ + buf.extend_from_slice(&bytes); + if buf.len() >= n { + break; + } + } + if buf.len() < n { + return Err(ArrowError::ParseError( + "Unexpected end of byte stream for Arrow IPC file".to_string(), + ))?; + } + Ok(buf) +} + +#[cfg(test)] +mod tests { + use super::*; + + use chrono::DateTime; + use datafusion_common::config::TableOptions; + use datafusion_common::DFSchema; + use datafusion_execution::config::SessionConfig; + use datafusion_execution::runtime_env::RuntimeEnv; + use datafusion_expr::execution_props::ExecutionProps; + use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF}; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + use object_store::{chunked::ChunkedStore, memory::InMemory, path::Path}; + + struct MockSession { + config: SessionConfig, + runtime_env: Arc, + } + + impl MockSession { + fn new() -> Self { + Self { + config: SessionConfig::new(), + runtime_env: Arc::new(RuntimeEnv::default()), + } + } + } + + #[async_trait::async_trait] + impl Session for MockSession { + fn session_id(&self) -> &str { + unimplemented!() + } + + fn config(&self) -> &SessionConfig { + &self.config + } + + async fn create_physical_plan( + &self, + _logical_plan: &LogicalPlan, + ) -> Result> { + unimplemented!() + } + + fn create_physical_expr( + &self, + _expr: Expr, + _df_schema: &DFSchema, + ) -> Result> { + unimplemented!() + } + + fn scalar_functions(&self) -> &HashMap> { + unimplemented!() + } + + fn aggregate_functions(&self) -> &HashMap> { + unimplemented!() + } + + fn window_functions(&self) -> &HashMap> { + unimplemented!() + } + + fn runtime_env(&self) -> &Arc { + &self.runtime_env + } + + fn execution_props(&self) -> &ExecutionProps { + unimplemented!() + } + + fn as_any(&self) -> &dyn Any { + unimplemented!() + } + + fn table_options(&self) -> &TableOptions { + unimplemented!() + } + + fn table_options_mut(&mut self) -> &mut TableOptions { + unimplemented!() + } + + fn task_ctx(&self) -> Arc { + unimplemented!() + } + } + + #[tokio::test] + async fn test_infer_schema_stream() -> Result<()> { + let mut bytes = std::fs::read("tests/data/example.arrow")?; + bytes.truncate(bytes.len() - 20); // mangle end to show we don't need to read whole file + let location = Path::parse("example.arrow")?; + let in_memory_store: Arc = Arc::new(InMemory::new()); + in_memory_store.put(&location, bytes.into()).await?; + + let state = MockSession::new(); + let object_meta = ObjectMeta { + location, + last_modified: DateTime::default(), + size: u64::MAX, + e_tag: None, + version: None, + }; + + let arrow_format = ArrowFormat {}; + let expected = vec!["f0: Int64", "f1: Utf8", "f2: Boolean"]; + + // Test chunk sizes where too small so we keep having to read more bytes + // And when large enough that first read contains all we need + for chunk_size in [7, 3000] { + let store = Arc::new(ChunkedStore::new(in_memory_store.clone(), chunk_size)); + let inferred_schema = arrow_format + .infer_schema( + &state, + &(store.clone() as Arc), + std::slice::from_ref(&object_meta), + ) + .await?; + let actual_fields = inferred_schema + .fields() + .iter() + .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .collect::>(); + assert_eq!(expected, actual_fields); + } + + Ok(()) + } + + #[tokio::test] + async fn test_infer_schema_short_stream() -> Result<()> { + let mut bytes = std::fs::read("tests/data/example.arrow")?; + bytes.truncate(20); // should cause error that file shorter than expected + let location = Path::parse("example.arrow")?; + let in_memory_store: Arc = 
Arc::new(InMemory::new()); + in_memory_store.put(&location, bytes.into()).await?; + + let state = MockSession::new(); + let object_meta = ObjectMeta { + location, + last_modified: DateTime::default(), + size: u64::MAX, + e_tag: None, + version: None, + }; + + let arrow_format = ArrowFormat {}; + + let store = Arc::new(ChunkedStore::new(in_memory_store.clone(), 7)); + let err = arrow_format + .infer_schema( + &state, + &(store.clone() as Arc), + std::slice::from_ref(&object_meta), + ) + .await; + + assert!(err.is_err()); + assert_eq!( + "Arrow error: Parser error: Unexpected end of byte stream for Arrow IPC file", + err.unwrap_err().to_string().lines().next().unwrap() + ); + + Ok(()) + } +} diff --git a/datafusion/datasource-arrow/src/mod.rs b/datafusion/datasource-arrow/src/mod.rs new file mode 100644 index 000000000000..18bb8792c3ff --- /dev/null +++ b/datafusion/datasource-arrow/src/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Make sure fast / cheap clones on Arc are explicit: +// https://github.com/apache/datafusion/issues/11143 +#![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] + +pub mod file_format; +pub mod source; + +pub use file_format::*; diff --git a/datafusion/core/src/datasource/physical_plan/arrow_file.rs b/datafusion/datasource-arrow/src/source.rs similarity index 98% rename from datafusion/core/src/datasource/physical_plan/arrow_file.rs rename to datafusion/datasource-arrow/src/source.rs index b37dc499d403..f43f11880182 100644 --- a/datafusion/core/src/datasource/physical_plan/arrow_file.rs +++ b/datafusion/datasource-arrow/src/source.rs @@ -18,20 +18,21 @@ use std::any::Any; use std::sync::Arc; -use crate::datasource::physical_plan::{FileOpenFuture, FileOpener}; -use crate::error::Result; use datafusion_datasource::as_file_source; use datafusion_datasource::schema_adapter::SchemaAdapterFactory; use arrow::buffer::Buffer; use arrow::datatypes::SchemaRef; use arrow_ipc::reader::FileDecoder; +use datafusion_common::error::Result; use datafusion_common::{exec_datafusion_err, Statistics}; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; use datafusion_datasource::PartitionedFile; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; +use datafusion_datasource::file_stream::FileOpenFuture; +use datafusion_datasource::file_stream::FileOpener; use futures::StreamExt; use itertools::Itertools; use object_store::{GetOptions, GetRange, GetResultPayload, ObjectStore}; diff --git a/datafusion/core/tests/data/example.arrow b/datafusion/datasource-arrow/tests/data/example.arrow similarity index 100% rename from datafusion/core/tests/data/example.arrow rename to datafusion/datasource-arrow/tests/data/example.arrow diff --git 
a/datafusion/datasource-parquet/src/file_format.rs b/datafusion/datasource-parquet/src/file_format.rs index 963c1d77950c..f27bda387fda 100644 --- a/datafusion/datasource-parquet/src/file_format.rs +++ b/datafusion/datasource-parquet/src/file_format.rs @@ -38,8 +38,6 @@ use datafusion_datasource::write::demux::DemuxedStreamReceiver; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::config::{ConfigField, ConfigFileType, TableParquetOptions}; -#[cfg(feature = "parquet_encryption")] -use datafusion_common::encryption::map_config_decryption_to_decryption; use datafusion_common::encryption::FileDecryptionProperties; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ @@ -59,11 +57,13 @@ use datafusion_physical_expr_common::sort_expr::LexRequirement; use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; use datafusion_session::Session; +use crate::metadata::DFParquetMetadata; use crate::reader::CachedParquetFileReaderFactory; use crate::source::{parse_coerce_int96_string, ParquetSource}; use async_trait::async_trait; use bytes::Bytes; use datafusion_datasource::source::DataSourceExec; +use datafusion_execution::cache::cache_manager::FileMetadataCache; use datafusion_execution::runtime_env::RuntimeEnv; use futures::future::BoxFuture; use futures::{FutureExt, StreamExt, TryStreamExt}; @@ -77,14 +77,12 @@ use parquet::arrow::arrow_writer::{ use parquet::arrow::async_reader::MetadataFetch; use parquet::arrow::{ArrowWriter, AsyncArrowWriter}; use parquet::basic::Type; - -use crate::metadata::DFParquetMetadata; -use datafusion_execution::cache::cache_manager::FileMetadataCache; +#[cfg(feature = "parquet_encryption")] +use parquet::encryption::encrypt::FileEncryptionProperties; use parquet::errors::ParquetError; use parquet::file::metadata::ParquetMetaData; use parquet::file::properties::{WriterProperties, WriterPropertiesBuilder}; use parquet::file::writer::SerializedFileWriter; -use parquet::format::FileMetaData; use parquet::schema::types::SchemaDescriptor; use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver, Sender}; @@ -306,25 +304,23 @@ async fn get_file_decryption_properties( state: &dyn Session, options: &TableParquetOptions, file_path: &Path, -) -> Result> { - let file_decryption_properties: Option = - match &options.crypto.file_decryption { - Some(cfd) => Some(map_config_decryption_to_decryption(cfd)), - None => match &options.crypto.factory_id { - Some(factory_id) => { - let factory = - state.runtime_env().parquet_encryption_factory(factory_id)?; - factory - .get_file_decryption_properties( - &options.crypto.factory_options, - file_path, - ) - .await? - } - None => None, - }, - }; - Ok(file_decryption_properties) +) -> Result>> { + Ok(match &options.crypto.file_decryption { + Some(cfd) => Some(Arc::new(FileDecryptionProperties::from(cfd.clone()))), + None => match &options.crypto.factory_id { + Some(factory_id) => { + let factory = + state.runtime_env().parquet_encryption_factory(factory_id)?; + factory + .get_file_decryption_properties( + &options.crypto.factory_options, + file_path, + ) + .await? 
+ } + None => None, + }, + }) } #[cfg(not(feature = "parquet_encryption"))] @@ -332,7 +328,7 @@ async fn get_file_decryption_properties( _state: &dyn Session, _options: &TableParquetOptions, _file_path: &Path, -) -> Result> { +) -> Result>> { Ok(None) } @@ -385,7 +381,7 @@ impl FileFormat for ParquetFormat { .await?; let result = DFParquetMetadata::new(store.as_ref(), object) .with_metadata_size_hint(self.metadata_size_hint()) - .with_decryption_properties(file_decryption_properties.as_ref()) + .with_decryption_properties(file_decryption_properties) .with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))) .with_coerce_int96(coerce_int96) .fetch_schema_with_location() @@ -446,7 +442,7 @@ impl FileFormat for ParquetFormat { state.runtime_env().cache_manager.get_file_metadata_cache(); DFParquetMetadata::new(store, object) .with_metadata_size_hint(self.metadata_size_hint()) - .with_decryption_properties(file_decryption_properties.as_ref()) + .with_decryption_properties(file_decryption_properties) .with_file_metadata_cache(Some(file_metadata_cache)) .fetch_statistics(&table_schema) .await @@ -1027,9 +1023,10 @@ pub async fn fetch_parquet_metadata( store: &dyn ObjectStore, object_meta: &ObjectMeta, size_hint: Option, - #[allow(unused)] decryption_properties: Option<&FileDecryptionProperties>, + decryption_properties: Option<&FileDecryptionProperties>, file_metadata_cache: Option>, ) -> Result> { + let decryption_properties = decryption_properties.cloned().map(Arc::new); DFParquetMetadata::new(store, object_meta) .with_metadata_size_hint(size_hint) .with_decryption_properties(decryption_properties) @@ -1053,6 +1050,7 @@ pub async fn fetch_statistics( decryption_properties: Option<&FileDecryptionProperties>, file_metadata_cache: Option>, ) -> Result { + let decryption_properties = decryption_properties.cloned().map(Arc::new); DFParquetMetadata::new(store, file) .with_metadata_size_hint(metadata_size_hint) .with_decryption_properties(decryption_properties) @@ -1080,7 +1078,7 @@ pub struct ParquetSink { parquet_options: TableParquetOptions, /// File metadata from successfully produced parquet files. The Mutex is only used /// to allow inserting to HashMap from behind borrowed reference in DataSink::write_all. - written: Arc>>, + written: Arc>>, } impl Debug for ParquetSink { @@ -1117,7 +1115,7 @@ impl ParquetSink { /// Retrieve the file metadata for the written files, keyed to the path /// which may be partitioned (in the case of hive style partitioning). 
- pub fn written(&self) -> HashMap { + pub fn written(&self) -> HashMap { self.written.lock().clone() } @@ -1141,7 +1139,7 @@ impl ParquetSink { builder = set_writer_encryption_properties( builder, runtime, - &parquet_opts, + parquet_opts, schema, path, ) @@ -1189,14 +1187,15 @@ impl ParquetSink { async fn set_writer_encryption_properties( builder: WriterPropertiesBuilder, runtime: &Arc, - parquet_opts: &TableParquetOptions, + parquet_opts: TableParquetOptions, schema: &Arc, path: &Path, ) -> Result { - if let Some(file_encryption_properties) = &parquet_opts.crypto.file_encryption { + if let Some(file_encryption_properties) = parquet_opts.crypto.file_encryption { // Encryption properties have been specified directly - return Ok(builder - .with_file_encryption_properties(file_encryption_properties.clone().into())); + return Ok(builder.with_file_encryption_properties(Arc::new( + FileEncryptionProperties::from(file_encryption_properties), + ))); } else if let Some(encryption_factory_id) = &parquet_opts.crypto.factory_id.as_ref() { // Encryption properties will be generated by an encryption factory let encryption_factory = @@ -1221,7 +1220,7 @@ async fn set_writer_encryption_properties( async fn set_writer_encryption_properties( builder: WriterPropertiesBuilder, _runtime: &Arc, - _parquet_opts: &TableParquetOptions, + _parquet_opts: TableParquetOptions, _schema: &Arc, _path: &Path, ) -> Result { @@ -1244,7 +1243,7 @@ impl FileSink for ParquetSink { let parquet_opts = &self.parquet_options; let mut file_write_tasks: JoinSet< - std::result::Result<(Path, FileMetaData), DataFusionError>, + std::result::Result<(Path, ParquetMetaData), DataFusionError>, > = JoinSet::new(); let runtime = context.runtime_env(); @@ -1275,11 +1274,11 @@ impl FileSink for ParquetSink { writer.write(&batch).await?; reservation.try_resize(writer.memory_size())?; } - let file_metadata = writer + let parquet_meta_data = writer .close() .await .map_err(|e| DataFusionError::ParquetError(Box::new(e)))?; - Ok((path, file_metadata)) + Ok((path, parquet_meta_data)) }); } else { let writer = ObjectWriterBuilder::new( @@ -1303,7 +1302,7 @@ impl FileSink for ParquetSink { let parallel_options_clone = parallel_options.clone(); let pool = Arc::clone(context.memory_pool()); file_write_tasks.spawn(async move { - let file_metadata = output_single_parquet_file_parallelized( + let parquet_meta_data = output_single_parquet_file_parallelized( writer, rx, schema, @@ -1313,7 +1312,7 @@ impl FileSink for ParquetSink { pool, ) .await?; - Ok((path, file_metadata)) + Ok((path, parquet_meta_data)) }); } } @@ -1322,11 +1321,11 @@ impl FileSink for ParquetSink { while let Some(result) = file_write_tasks.join_next().await { match result { Ok(r) => { - let (path, file_metadata) = r?; - row_count += file_metadata.num_rows; + let (path, parquet_meta_data) = r?; + row_count += parquet_meta_data.file_metadata().num_rows(); let mut written_files = self.written.lock(); written_files - .try_insert(path.clone(), file_metadata) + .try_insert(path.clone(), parquet_meta_data) .map_err(|e| internal_datafusion_err!("duplicate entry detected for partitioned file {path}: {e}"))?; drop(written_files); } @@ -1589,7 +1588,7 @@ async fn concatenate_parallel_row_groups( mut serialize_rx: Receiver>, mut object_store_writer: Box, pool: Arc, -) -> Result { +) -> Result { let mut file_reservation = MemoryConsumer::new("ParquetSink(SerializedFileWriter)").register(&pool); @@ -1617,14 +1616,14 @@ async fn concatenate_parallel_row_groups( rg_out.close()?; } - let file_metadata = 
parquet_writer.close()?; + let parquet_meta_data = parquet_writer.close()?; let final_buff = merged_buff.buffer.try_lock().unwrap(); object_store_writer.write_all(final_buff.as_slice()).await?; object_store_writer.shutdown().await?; file_reservation.free(); - Ok(file_metadata) + Ok(parquet_meta_data) } /// Parallelizes the serialization of a single parquet file, by first serializing N @@ -1639,7 +1638,7 @@ async fn output_single_parquet_file_parallelized( skip_arrow_metadata: bool, parallel_options: ParallelParquetWriterOptions, pool: Arc, -) -> Result { +) -> Result { let max_rowgroups = parallel_options.max_parallel_row_groups; // Buffer size of this channel limits maximum number of RowGroups being worked on in parallel let (serialize_tx, serialize_rx) = @@ -1666,7 +1665,7 @@ async fn output_single_parquet_file_parallelized( parallel_options, Arc::clone(&pool), ); - let file_metadata = concatenate_parallel_row_groups( + let parquet_meta_data = concatenate_parallel_row_groups( writer, merged_buff, serialize_rx, @@ -1679,7 +1678,7 @@ async fn output_single_parquet_file_parallelized( .join_unwind() .await .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; - Ok(file_metadata) + Ok(parquet_meta_data) } #[cfg(test)] diff --git a/datafusion/datasource-parquet/src/metadata.rs b/datafusion/datasource-parquet/src/metadata.rs index 4de68793ce02..6505a447d7ce 100644 --- a/datafusion/datasource-parquet/src/metadata.rs +++ b/datafusion/datasource-parquet/src/metadata.rs @@ -58,7 +58,7 @@ pub struct DFParquetMetadata<'a> { store: &'a dyn ObjectStore, object_meta: &'a ObjectMeta, metadata_size_hint: Option, - decryption_properties: Option<&'a FileDecryptionProperties>, + decryption_properties: Option>, file_metadata_cache: Option>, /// timeunit to coerce INT96 timestamps to pub coerce_int96: Option, @@ -85,7 +85,7 @@ impl<'a> DFParquetMetadata<'a> { /// set decryption properties pub fn with_decryption_properties( mut self, - decryption_properties: Option<&'a FileDecryptionProperties>, + decryption_properties: Option>, ) -> Self { self.decryption_properties = decryption_properties; self @@ -145,7 +145,8 @@ impl<'a> DFParquetMetadata<'a> { #[cfg(feature = "parquet_encryption")] if let Some(decryption_properties) = decryption_properties { - reader = reader.with_decryption_properties(Some(decryption_properties)); + reader = reader + .with_decryption_properties(Some(Arc::clone(decryption_properties))); } if cache_metadata && file_metadata_cache.is_some() { @@ -299,7 +300,6 @@ impl<'a> DFParquetMetadata<'a> { summarize_min_max_null_counts( &mut accumulators, idx, - num_rows, &stats_converter, row_groups_metadata, ) @@ -417,7 +417,6 @@ struct StatisticsAccumulators<'a> { fn summarize_min_max_null_counts( accumulators: &mut StatisticsAccumulators, arrow_schema_index: usize, - num_rows: usize, stats_converter: &StatisticsConverter, row_groups_metadata: &[RowGroupMetaData], ) -> Result<()> { @@ -449,11 +448,14 @@ fn summarize_min_max_null_counts( ); } - accumulators.null_counts_array[arrow_schema_index] = - Precision::Exact(match sum(&null_counts) { - Some(null_count) => null_count as usize, - None => num_rows, - }); + accumulators.null_counts_array[arrow_schema_index] = match sum(&null_counts) { + Some(null_count) => Precision::Exact(null_count as usize), + None => match null_counts.len() { + // If sum() returned None we either have no rows or all values are null + 0 => Precision::Exact(0), + _ => Precision::Absent, + }, + }; Ok(()) } diff --git a/datafusion/datasource-parquet/src/metrics.rs 
b/datafusion/datasource-parquet/src/metrics.rs index d75a979d4cad..5f17fbb4b9ee 100644 --- a/datafusion/datasource-parquet/src/metrics.rs +++ b/datafusion/datasource-parquet/src/metrics.rs @@ -16,7 +16,7 @@ // under the License. use datafusion_physical_plan::metrics::{ - Count, ExecutionPlanMetricsSet, MetricBuilder, Time, + Count, ExecutionPlanMetricsSet, MetricBuilder, MetricType, Time, }; /// Stores metrics about the parquet execution for a particular parquet file. @@ -88,30 +88,59 @@ impl ParquetFileMetrics { filename: &str, metrics: &ExecutionPlanMetricsSet, ) -> Self { - let predicate_evaluation_errors = MetricBuilder::new(metrics) - .with_new_label("filename", filename.to_string()) - .counter("predicate_evaluation_errors", partition); - + // ----------------------- + // 'summary' level metrics + // ----------------------- let row_groups_matched_bloom_filter = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) + .with_type(MetricType::SUMMARY) .counter("row_groups_matched_bloom_filter", partition); let row_groups_pruned_bloom_filter = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) + .with_type(MetricType::SUMMARY) .counter("row_groups_pruned_bloom_filter", partition); let row_groups_matched_statistics = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) + .with_type(MetricType::SUMMARY) .counter("row_groups_matched_statistics", partition); let row_groups_pruned_statistics = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) + .with_type(MetricType::SUMMARY) .counter("row_groups_pruned_statistics", partition); + let page_index_rows_pruned = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .with_type(MetricType::SUMMARY) + .counter("page_index_rows_pruned", partition); + let page_index_rows_matched = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .with_type(MetricType::SUMMARY) + .counter("page_index_rows_matched", partition); + let bytes_scanned = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) + .with_type(MetricType::SUMMARY) .counter("bytes_scanned", partition); + let metadata_load_time = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .with_type(MetricType::SUMMARY) + .subset_time("metadata_load_time", partition); + + let files_ranges_pruned_statistics = MetricBuilder::new(metrics) + .with_type(MetricType::SUMMARY) + .counter("files_ranges_pruned_statistics", partition); + + // ----------------------- + // 'dev' level metrics + // ----------------------- + let predicate_evaluation_errors = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .counter("predicate_evaluation_errors", partition); + let pushdown_rows_pruned = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) .counter("pushdown_rows_pruned", partition); @@ -129,24 +158,10 @@ impl ParquetFileMetrics { .with_new_label("filename", filename.to_string()) .subset_time("bloom_filter_eval_time", partition); - let page_index_rows_pruned = MetricBuilder::new(metrics) - .with_new_label("filename", filename.to_string()) - .counter("page_index_rows_pruned", partition); - let page_index_rows_matched = MetricBuilder::new(metrics) - .with_new_label("filename", filename.to_string()) - .counter("page_index_rows_matched", partition); - let page_index_eval_time = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) 
.subset_time("page_index_eval_time", partition); - let metadata_load_time = MetricBuilder::new(metrics) - .with_new_label("filename", filename.to_string()) - .subset_time("metadata_load_time", partition); - - let files_ranges_pruned_statistics = MetricBuilder::new(metrics) - .counter("files_ranges_pruned_statistics", partition); - let predicate_cache_inner_records = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) .counter("predicate_cache_inner_records", partition); diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index 167fc3c5147e..af7a537ca6f4 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -208,7 +208,7 @@ impl FileOpener for ParquetOpener { let mut options = ArrowReaderOptions::new().with_page_index(false); #[cfg(feature = "parquet_encryption")] if let Some(fd_val) = file_decryption_properties { - options = options.with_file_decryption_properties((*fd_val).clone()); + options = options.with_file_decryption_properties(Arc::clone(&fd_val)); } let mut metadata_timer = file_metrics.metadata_load_time.timer(); @@ -581,8 +581,7 @@ impl EncryptionContext { None => match &self.encryption_factory { Some((encryption_factory, encryption_config)) => Ok(encryption_factory .get_file_decryption_properties(encryption_config, file_location) - .await? - .map(Arc::new)), + .await?), None => Ok(None), }, } diff --git a/datafusion/datasource-parquet/src/page_filter.rs b/datafusion/datasource-parquet/src/page_filter.rs index 5f3e05747d40..65d1affb44a9 100644 --- a/datafusion/datasource-parquet/src/page_filter.rs +++ b/datafusion/datasource-parquet/src/page_filter.rs @@ -36,7 +36,7 @@ use datafusion_pruning::PruningPredicate; use log::{debug, trace}; use parquet::arrow::arrow_reader::statistics::StatisticsConverter; use parquet::file::metadata::{ParquetColumnIndex, ParquetOffsetIndex}; -use parquet::format::PageLocation; +use parquet::file::page_index::offset_index::PageLocation; use parquet::schema::types::SchemaDescriptor; use parquet::{ arrow::arrow_reader::{RowSelection, RowSelector}, diff --git a/datafusion/datasource-parquet/src/reader.rs b/datafusion/datasource-parquet/src/reader.rs index 687a7f15fccc..88a3cea5623b 100644 --- a/datafusion/datasource-parquet/src/reader.rs +++ b/datafusion/datasource-parquet/src/reader.rs @@ -262,8 +262,9 @@ impl AsyncFileReader for CachedParquetFileReader { async move { #[cfg(feature = "parquet_encryption")] - let file_decryption_properties = - options.and_then(|o| o.file_decryption_properties()); + let file_decryption_properties = options + .and_then(|o| o.file_decryption_properties()) + .map(Arc::clone); #[cfg(not(feature = "parquet_encryption"))] let file_decryption_properties = None; diff --git a/datafusion/datasource-parquet/src/source.rs b/datafusion/datasource-parquet/src/source.rs index dd10363079f9..186d922fc373 100644 --- a/datafusion/datasource-parquet/src/source.rs +++ b/datafusion/datasource-parquet/src/source.rs @@ -52,12 +52,12 @@ use datafusion_physical_plan::metrics::Count; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion_physical_plan::DisplayFormatType; -#[cfg(feature = "parquet_encryption")] -use datafusion_common::encryption::map_config_decryption_to_decryption; #[cfg(feature = "parquet_encryption")] use datafusion_execution::parquet_encryption::EncryptionFactory; use itertools::Itertools; use object_store::ObjectStore; +#[cfg(feature = "parquet_encryption")] +use 
parquet::encryption::decrypt::FileDecryptionProperties; /// Execution plan for reading one or more Parquet files. /// @@ -497,7 +497,7 @@ impl FileSource for ParquetSource { ) -> Arc { let projection = base_config .file_column_projection_indices() - .unwrap_or_else(|| (0..base_config.file_schema.fields().len()).collect()); + .unwrap_or_else(|| (0..base_config.file_schema().fields().len()).collect()); let (expr_adapter_factory, schema_adapter_factory) = match ( base_config.expr_adapter_factory.as_ref(), @@ -547,8 +547,8 @@ impl FileSource for ParquetSource { .table_parquet_options() .crypto .file_decryption - .as_ref() - .map(map_config_decryption_to_decryption) + .clone() + .map(FileDecryptionProperties::from) .map(Arc::new); let coerce_int96 = self @@ -566,8 +566,8 @@ impl FileSource for ParquetSource { .expect("Batch size must set before creating ParquetOpener"), limit: base_config.limit, predicate: self.predicate.clone(), - logical_file_schema: Arc::clone(&base_config.file_schema), - partition_fields: base_config.table_partition_cols.clone(), + logical_file_schema: Arc::clone(base_config.file_schema()), + partition_fields: base_config.table_partition_cols().clone(), metadata_size_hint: self.metadata_size_hint, metrics: self.metrics().clone(), parquet_file_reader_factory, diff --git a/datafusion/datasource/Cargo.toml b/datafusion/datasource/Cargo.toml index afd0256ba972..8e0738448a75 100644 --- a/datafusion/datasource/Cargo.toml +++ b/datafusion/datasource/Cargo.toml @@ -45,7 +45,7 @@ async-compression = { version = "0.4.19", features = [ ], optional = true } async-trait = { workspace = true } bytes = { workspace = true } -bzip2 = { version = "0.6.0", optional = true } +bzip2 = { version = "0.6.1", optional = true } chrono = { workspace = true } datafusion-common = { workspace = true, features = ["object_store"] } datafusion-common-runtime = { workspace = true } diff --git a/datafusion/datasource/src/file_scan_config.rs b/datafusion/datasource/src/file_scan_config.rs index e67e1f827372..695252803bae 100644 --- a/datafusion/datasource/src/file_scan_config.rs +++ b/datafusion/datasource/src/file_scan_config.rs @@ -24,7 +24,7 @@ use crate::schema_adapter::SchemaAdapterFactory; use crate::{ display::FileGroupsDisplay, file::FileSource, file_compression_type::FileCompressionType, file_stream::FileStream, - source::DataSource, statistics::MinMaxStatistics, PartitionedFile, + source::DataSource, statistics::MinMaxStatistics, PartitionedFile, TableSchema, }; use arrow::datatypes::FieldRef; use arrow::{ @@ -153,15 +153,16 @@ pub struct FileScanConfig { /// [`RuntimeEnv::register_object_store`]: datafusion_execution::runtime_env::RuntimeEnv::register_object_store /// [`RuntimeEnv::object_store`]: datafusion_execution::runtime_env::RuntimeEnv::object_store pub object_store_url: ObjectStoreUrl, - /// Schema before `projection` is applied. It contains the all columns that may - /// appear in the files. It does not include table partition columns - /// that may be added. - /// Note that this is **not** the schema of the physical files. - /// This is the schema that the physical file schema will be - /// mapped onto, and the schema that the [`DataSourceExec`] will return. + /// Schema information including the file schema, table partition columns, + /// and the combined table schema. + /// + /// The table schema (file schema + partition columns) is the schema exposed + /// upstream of [`FileScanConfig`] (e.g. in [`DataSourceExec`]). + /// + /// See [`TableSchema`] for more information. 
/// /// [`DataSourceExec`]: crate::source::DataSourceExec - pub file_schema: SchemaRef, + pub table_schema: TableSchema, /// List of files to be processed, grouped into partitions /// /// Each file must have a schema of `file_schema` or a subset. If @@ -180,8 +181,6 @@ pub struct FileScanConfig { /// The maximum number of records to read from this plan. If `None`, /// all records after filtering are returned. pub limit: Option, - /// The partitioning columns - pub table_partition_cols: Vec, /// All equivalent lexicographical orderings that describe the schema. pub output_ordering: Vec, /// File compression type @@ -250,23 +249,19 @@ pub struct FileScanConfig { #[derive(Clone)] pub struct FileScanConfigBuilder { object_store_url: ObjectStoreUrl, - /// Table schema before any projections or partition columns are applied. + /// Schema information including the file schema, table partition columns, + /// and the combined table schema. /// - /// This schema is used to read the files, but is **not** necessarily the - /// schema of the physical files. Rather this is the schema that the + /// This schema is used to read the files, but the file schema is **not** necessarily + /// the schema of the physical files. Rather this is the schema that the /// physical file schema will be mapped onto, and the schema that the /// [`DataSourceExec`] will return. /// - /// This is usually the same as the table schema as specified by the `TableProvider` minus any partition columns. - /// - /// This probably would be better named `table_schema` - /// /// [`DataSourceExec`]: crate::source::DataSourceExec - file_schema: SchemaRef, + table_schema: TableSchema, file_source: Arc, limit: Option, projection: Option>, - table_partition_cols: Vec, constraints: Option, file_groups: Vec, statistics: Option, @@ -291,7 +286,7 @@ impl FileScanConfigBuilder { ) -> Self { Self { object_store_url, - file_schema, + table_schema: TableSchema::from_file_schema(file_schema), file_source, file_groups: vec![], statistics: None, @@ -300,7 +295,6 @@ impl FileScanConfigBuilder { new_lines_in_values: None, limit: None, projection: None, - table_partition_cols: vec![], constraints: None, batch_size: None, expr_adapter_factory: None, @@ -332,10 +326,13 @@ impl FileScanConfigBuilder { /// Set the partitioning columns pub fn with_table_partition_cols(mut self, table_partition_cols: Vec) -> Self { - self.table_partition_cols = table_partition_cols + let table_partition_cols: Vec = table_partition_cols .into_iter() .map(|f| Arc::new(f) as FieldRef) .collect(); + self.table_schema = self + .table_schema + .with_table_partition_cols(table_partition_cols); self } @@ -433,11 +430,10 @@ impl FileScanConfigBuilder { pub fn build(self) -> FileScanConfig { let Self { object_store_url, - file_schema, + table_schema, file_source, limit, projection, - table_partition_cols, constraints, file_groups, statistics, @@ -449,23 +445,22 @@ impl FileScanConfigBuilder { } = self; let constraints = constraints.unwrap_or_default(); - let statistics = - statistics.unwrap_or_else(|| Statistics::new_unknown(&file_schema)); + let statistics = statistics + .unwrap_or_else(|| Statistics::new_unknown(table_schema.file_schema())); let file_source = file_source .with_statistics(statistics.clone()) - .with_schema(Arc::clone(&file_schema)); + .with_schema(Arc::clone(table_schema.file_schema())); let file_compression_type = file_compression_type.unwrap_or(FileCompressionType::UNCOMPRESSED); let new_lines_in_values = new_lines_in_values.unwrap_or(false); FileScanConfig { 
object_store_url, - file_schema, + table_schema, file_source, limit, projection, - table_partition_cols, constraints, file_groups, output_ordering, @@ -481,7 +476,7 @@ impl From for FileScanConfigBuilder { fn from(config: FileScanConfig) -> Self { Self { object_store_url: config.object_store_url, - file_schema: config.file_schema, + table_schema: config.table_schema, file_source: Arc::::clone(&config.file_source), file_groups: config.file_groups, statistics: config.file_source.statistics().ok(), @@ -490,7 +485,6 @@ impl From for FileScanConfigBuilder { new_lines_in_values: Some(config.new_lines_in_values), limit: config.limit, projection: config.projection, - table_partition_cols: config.table_partition_cols, constraints: Some(config.constraints), batch_size: config.batch_size, expr_adapter_factory: config.expr_adapter_factory, @@ -604,8 +598,39 @@ impl DataSource for FileScanConfig { SchedulingType::Cooperative } - fn statistics(&self) -> Result { - Ok(self.projected_stats()) + fn partition_statistics(&self, partition: Option) -> Result { + if let Some(partition) = partition { + // Get statistics for a specific partition + if let Some(file_group) = self.file_groups.get(partition) { + if let Some(stat) = file_group.file_statistics(None) { + // Project the statistics based on the projection + let table_cols_stats = self + .projection_indices() + .into_iter() + .map(|idx| { + if idx < self.file_schema().fields().len() { + stat.column_statistics[idx].clone() + } else { + // TODO provide accurate stat for partition column + // See https://github.com/apache/datafusion/issues/1186 + ColumnStatistics::new_unknown() + } + }) + .collect(); + + return Ok(Statistics { + num_rows: stat.num_rows, + total_byte_size: stat.total_byte_size, + column_statistics: table_cols_stats, + }); + } + } + // If no statistics available for this partition, return unknown + Ok(Statistics::new_unknown(&self.projected_schema())) + } else { + // Return aggregate statistics across all partitions + Ok(self.projected_stats()) + } } fn with_fetch(&self, limit: Option) -> Option> { @@ -635,7 +660,7 @@ impl DataSource for FileScanConfig { .expr .as_any() .downcast_ref::() - .map(|expr| expr.index() >= self.file_schema.fields().len()) + .map(|expr| expr.index() >= self.file_schema().fields().len()) .unwrap_or(false) }); @@ -650,7 +675,7 @@ impl DataSource for FileScanConfig { &file_scan .projection .clone() - .unwrap_or_else(|| (0..self.file_schema.fields().len()).collect()), + .unwrap_or_else(|| (0..self.file_schema().fields().len()).collect()), ); Arc::new( @@ -691,11 +716,21 @@ impl DataSource for FileScanConfig { } impl FileScanConfig { + /// Get the file schema (schema of the files without partition columns) + pub fn file_schema(&self) -> &SchemaRef { + self.table_schema.file_schema() + } + + /// Get the table partition columns + pub fn table_partition_cols(&self) -> &Vec { + self.table_schema.table_partition_cols() + } + fn projection_indices(&self) -> Vec { match &self.projection { Some(proj) => proj.clone(), - None => (0..self.file_schema.fields().len() - + self.table_partition_cols.len()) + None => (0..self.file_schema().fields().len() + + self.table_partition_cols().len()) .collect(), } } @@ -707,7 +742,7 @@ impl FileScanConfig { .projection_indices() .into_iter() .map(|idx| { - if idx < self.file_schema.fields().len() { + if idx < self.file_schema().fields().len() { statistics.column_statistics[idx].clone() } else { // TODO provide accurate stat for partition column (#1186) @@ -729,12 +764,12 @@ impl FileScanConfig { 
.projection_indices() .into_iter() .map(|idx| { - if idx < self.file_schema.fields().len() { - self.file_schema.field(idx).clone() + if idx < self.file_schema().fields().len() { + self.file_schema().field(idx).clone() } else { - let partition_idx = idx - self.file_schema.fields().len(); + let partition_idx = idx - self.file_schema().fields().len(); Arc::unwrap_or_clone(Arc::clone( - &self.table_partition_cols[partition_idx], + &self.table_partition_cols()[partition_idx], )) } }) @@ -742,7 +777,7 @@ impl FileScanConfig { Arc::new(Schema::new_with_metadata( table_fields, - self.file_schema.metadata().clone(), + self.file_schema().metadata().clone(), )) } @@ -790,9 +825,9 @@ impl FileScanConfig { /// Project the schema, constraints, and the statistics on the given column indices pub fn project(&self) -> (SchemaRef, Constraints, Statistics, Vec) { - if self.projection.is_none() && self.table_partition_cols.is_empty() { + if self.projection.is_none() && self.table_partition_cols().is_empty() { return ( - Arc::clone(&self.file_schema), + Arc::clone(self.file_schema()), self.constraints.clone(), self.file_source.statistics().unwrap().clone(), self.output_ordering.clone(), @@ -811,8 +846,8 @@ impl FileScanConfig { pub fn projected_file_column_names(&self) -> Option> { self.projection.as_ref().map(|p| { p.iter() - .filter(|col_idx| **col_idx < self.file_schema.fields().len()) - .map(|col_idx| self.file_schema.field(*col_idx).name()) + .filter(|col_idx| **col_idx < self.file_schema().fields().len()) + .map(|col_idx| self.file_schema().field(*col_idx).name()) .cloned() .collect() }) @@ -823,17 +858,17 @@ impl FileScanConfig { let fields = self.file_column_projection_indices().map(|indices| { indices .iter() - .map(|col_idx| self.file_schema.field(*col_idx)) + .map(|col_idx| self.file_schema().field(*col_idx)) .cloned() .collect::>() }); fields.map_or_else( - || Arc::clone(&self.file_schema), + || Arc::clone(self.file_schema()), |f| { Arc::new(Schema::new_with_metadata( f, - self.file_schema.metadata.clone(), + self.file_schema().metadata.clone(), )) }, ) @@ -842,7 +877,7 @@ impl FileScanConfig { pub fn file_column_projection_indices(&self) -> Option> { self.projection.as_ref().map(|p| { p.iter() - .filter(|col_idx| **col_idx < self.file_schema.fields().len()) + .filter(|col_idx| **col_idx < self.file_schema().fields().len()) .copied() .collect() }) @@ -1599,7 +1634,7 @@ mod tests { ); let source_statistics = conf.file_source.statistics().unwrap(); - let conf_stats = conf.statistics().unwrap(); + let conf_stats = conf.partition_statistics(None).unwrap(); // projection should be reflected in the file source statistics assert_eq!(conf_stats.num_rows, Precision::Inexact(3)); @@ -2182,11 +2217,11 @@ mod tests { // Verify the built config has all the expected values assert_eq!(config.object_store_url, object_store_url); - assert_eq!(config.file_schema, file_schema); + assert_eq!(*config.file_schema(), file_schema); assert_eq!(config.limit, Some(1000)); assert_eq!(config.projection, Some(vec![0, 1])); - assert_eq!(config.table_partition_cols.len(), 1); - assert_eq!(config.table_partition_cols[0].name(), "date"); + assert_eq!(config.table_partition_cols().len(), 1); + assert_eq!(config.table_partition_cols()[0].name(), "date"); assert_eq!(config.file_groups.len(), 1); assert_eq!(config.file_groups[0].len(), 1); assert_eq!( @@ -2265,10 +2300,10 @@ mod tests { // Verify default values assert_eq!(config.object_store_url, object_store_url); - assert_eq!(config.file_schema, file_schema); + 
assert_eq!(*config.file_schema(), file_schema); assert_eq!(config.limit, None); assert_eq!(config.projection, None); - assert!(config.table_partition_cols.is_empty()); + assert!(config.table_partition_cols().is_empty()); assert!(config.file_groups.is_empty()); assert_eq!( config.file_compression_type, @@ -2339,10 +2374,10 @@ mod tests { // Verify properties match let partition_cols = partition_cols.into_iter().map(Arc::new).collect::>(); assert_eq!(new_config.object_store_url, object_store_url); - assert_eq!(new_config.file_schema, schema); + assert_eq!(*new_config.file_schema(), schema); assert_eq!(new_config.projection, Some(vec![0, 2])); assert_eq!(new_config.limit, Some(10)); - assert_eq!(new_config.table_partition_cols, partition_cols); + assert_eq!(*new_config.table_partition_cols(), partition_cols); assert_eq!(new_config.file_groups.len(), 1); assert_eq!(new_config.file_groups[0].len(), 1); assert_eq!( @@ -2506,4 +2541,91 @@ mod tests { Ok(()) } + + #[test] + fn test_partition_statistics_projection() { + // This test verifies that partition_statistics applies projection correctly. + // The old implementation had a bug where it returned file group statistics + // without applying the projection, returning all column statistics instead + // of just the projected ones. + + use crate::source::DataSourceExec; + use datafusion_physical_plan::ExecutionPlan; + + // Create a schema with 4 columns + let schema = Arc::new(Schema::new(vec![ + Field::new("col0", DataType::Int32, false), + Field::new("col1", DataType::Int32, false), + Field::new("col2", DataType::Int32, false), + Field::new("col3", DataType::Int32, false), + ])); + + // Create statistics for all 4 columns + let file_group_stats = Statistics { + num_rows: Precision::Exact(100), + total_byte_size: Precision::Exact(1024), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(0), + ..ColumnStatistics::new_unknown() + }, + ColumnStatistics { + null_count: Precision::Exact(5), + ..ColumnStatistics::new_unknown() + }, + ColumnStatistics { + null_count: Precision::Exact(10), + ..ColumnStatistics::new_unknown() + }, + ColumnStatistics { + null_count: Precision::Exact(15), + ..ColumnStatistics::new_unknown() + }, + ], + }; + + // Create a file group with statistics + let file_group = FileGroup::new(vec![PartitionedFile::new("test.parquet", 1024)]) + .with_statistics(Arc::new(file_group_stats)); + + // Create a FileScanConfig with projection: only keep columns 0 and 2 + let config = FileScanConfigBuilder::new( + ObjectStoreUrl::parse("test:///").unwrap(), + Arc::clone(&schema), + Arc::new(MockSource::default()), + ) + .with_projection(Some(vec![0, 2])) // Only project columns 0 and 2 + .with_file_groups(vec![file_group]) + .build(); + + // Create a DataSourceExec from the config + let exec = DataSourceExec::from_data_source(config); + + // Get statistics for partition 0 + let partition_stats = exec.partition_statistics(Some(0)).unwrap(); + + // Verify that only 2 columns are in the statistics (the projected ones) + assert_eq!( + partition_stats.column_statistics.len(), + 2, + "Expected 2 column statistics (projected), but got {}", + partition_stats.column_statistics.len() + ); + + // Verify the column statistics are for columns 0 and 2 + assert_eq!( + partition_stats.column_statistics[0].null_count, + Precision::Exact(0), + "First projected column should be col0 with 0 nulls" + ); + assert_eq!( + partition_stats.column_statistics[1].null_count, + Precision::Exact(10), + "Second projected column should be col2 with 
10 nulls" + ); + + // Verify row count and byte size are preserved + assert_eq!(partition_stats.num_rows, Precision::Exact(100)); + assert_eq!(partition_stats.total_byte_size, Precision::Exact(1024)); + } } diff --git a/datafusion/datasource/src/file_stream.rs b/datafusion/datasource/src/file_stream.rs index e0b6c25a1916..9fee5691beea 100644 --- a/datafusion/datasource/src/file_stream.rs +++ b/datafusion/datasource/src/file_stream.rs @@ -80,7 +80,7 @@ impl FileStream { let pc_projector = PartitionColumnProjector::new( Arc::clone(&projected_schema), &config - .table_partition_cols + .table_partition_cols() .iter() .map(|x| x.name().clone()) .collect::>(), diff --git a/datafusion/datasource/src/memory.rs b/datafusion/datasource/src/memory.rs index eb55aa9b0b0d..7d5c8c4834ea 100644 --- a/datafusion/datasource/src/memory.rs +++ b/datafusion/datasource/src/memory.rs @@ -21,6 +21,7 @@ use std::collections::BinaryHeap; use std::fmt; use std::fmt::Debug; use std::ops::Deref; +use std::slice::from_ref; use std::sync::Arc; use crate::sink::DataSink; @@ -192,12 +193,27 @@ impl DataSource for MemorySourceConfig { SchedulingType::Cooperative } - fn statistics(&self) -> Result { - Ok(common::compute_record_batch_statistics( - &self.partitions, - &self.schema, - self.projection.clone(), - )) + fn partition_statistics(&self, partition: Option) -> Result { + if let Some(partition) = partition { + // Compute statistics for a specific partition + if let Some(batches) = self.partitions.get(partition) { + Ok(common::compute_record_batch_statistics( + from_ref(batches), + &self.schema, + self.projection.clone(), + )) + } else { + // Invalid partition index + Ok(Statistics::new_unknown(&self.projected_schema)) + } + } else { + // Compute statistics across all partitions + Ok(common::compute_record_batch_statistics( + &self.partitions, + &self.schema, + self.projection.clone(), + )) + } } fn with_fetch(&self, limit: Option) -> Option> { diff --git a/datafusion/datasource/src/mod.rs b/datafusion/datasource/src/mod.rs index 1f47c0983ea1..80b44ad5949a 100644 --- a/datafusion/datasource/src/mod.rs +++ b/datafusion/datasource/src/mod.rs @@ -41,6 +41,7 @@ pub mod schema_adapter; pub mod sink; pub mod source; mod statistics; +pub mod table_schema; #[cfg(test)] pub mod test_util; @@ -57,6 +58,7 @@ use datafusion_common::{ScalarValue, Statistics}; use futures::{Stream, StreamExt}; use object_store::{path::Path, ObjectMeta}; use object_store::{GetOptions, GetRange, ObjectStore}; +pub use table_schema::TableSchema; // Remove when add_row_stats is remove #[allow(deprecated)] pub use statistics::add_row_stats; diff --git a/datafusion/datasource/src/source.rs b/datafusion/datasource/src/source.rs index 20d9a1d6e53f..11a8a3867b80 100644 --- a/datafusion/datasource/src/source.rs +++ b/datafusion/datasource/src/source.rs @@ -151,7 +151,21 @@ pub trait DataSource: Send + Sync + Debug { fn scheduling_type(&self) -> SchedulingType { SchedulingType::NonCooperative } - fn statistics(&self) -> Result; + + /// Returns statistics for a specific partition, or aggregate statistics + /// across all partitions if `partition` is `None`. + fn partition_statistics(&self, partition: Option) -> Result; + + /// Returns aggregate statistics across all partitions. + /// + /// # Deprecated + /// Use [`Self::partition_statistics`] instead, which provides more fine-grained + /// control over statistics retrieval (per-partition or aggregate). 
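A minimal caller-side sketch of the new API (the helper name `stats_for` and the fallback policy are illustrative only; the trait and method signatures are as defined in this patch):

use datafusion_common::{Result, Statistics};
use datafusion_datasource::source::DataSource;

/// Sketch: prefer per-partition statistics, falling back to the aggregate.
/// Passing `None` aggregates statistics across all partitions.
fn stats_for(source: &dyn DataSource, partition: usize) -> Result<Statistics> {
    match source.partition_statistics(Some(partition)) {
        Ok(stats) => Ok(stats),
        Err(_) => source.partition_statistics(None),
    }
}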
+ #[deprecated(since = "51.0.0", note = "Use partition_statistics instead")] + fn statistics(&self) -> Result { + self.partition_statistics(None) + } + /// Return a copy of this DataSource with a new fetch limit fn with_fetch(&self, _limit: Option) -> Option>; fn fetch(&self) -> Option; @@ -285,21 +299,7 @@ impl ExecutionPlan for DataSourceExec { } fn partition_statistics(&self, partition: Option) -> Result { - if let Some(partition) = partition { - let mut statistics = Statistics::new_unknown(&self.schema()); - if let Some(file_config) = - self.data_source.as_any().downcast_ref::() - { - if let Some(file_group) = file_config.file_groups.get(partition) { - if let Some(stat) = file_group.file_statistics(None) { - statistics = stat.clone(); - } - } - } - Ok(statistics) - } else { - Ok(self.data_source.statistics()?) - } + self.data_source.partition_statistics(partition) } fn with_fetch(&self, limit: Option) -> Option> { diff --git a/datafusion/datasource/src/table_schema.rs b/datafusion/datasource/src/table_schema.rs new file mode 100644 index 000000000000..8e95585ce873 --- /dev/null +++ b/datafusion/datasource/src/table_schema.rs @@ -0,0 +1,160 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Helper struct to manage table schemas with partition columns + +use arrow::datatypes::{FieldRef, SchemaBuilder, SchemaRef}; +use std::sync::Arc; + +/// Helper to hold table schema information for partitioned data sources. +/// +/// When reading partitioned data (such as Hive-style partitioning), a table's schema +/// consists of two parts: +/// 1. **File schema**: The schema of the actual data files on disk +/// 2. **Partition columns**: Columns that are encoded in the directory structure, +/// not stored in the files themselves +/// +/// # Example: Partitioned Table +/// +/// Consider a table with the following directory structure: +/// ```text +/// /data/date=2025-10-10/region=us-west/data.parquet +/// /data/date=2025-10-11/region=us-east/data.parquet +/// ``` +/// +/// In this case: +/// - **File schema**: The schema of `data.parquet` files (e.g., `[user_id, amount]`) +/// - **Partition columns**: `[date, region]` extracted from the directory path +/// - **Table schema**: The full schema combining both (e.g., `[user_id, amount, date, region]`) +/// +/// # When to Use +/// +/// Use `TableSchema` when: +/// - Reading partitioned data sources (Parquet, CSV, etc. with Hive-style partitioning) +/// - You need to efficiently access different schema representations without reconstructing them +/// - You want to avoid repeatedly concatenating file and partition schemas +/// +/// For non-partitioned data or when working with a single schema representation, +/// working directly with Arrow's `Schema` or `SchemaRef` is simpler. 
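For illustration, a small sketch exercising the constructor and accessors defined below (wrapped in `fn main` only so it compiles standalone):

use std::sync::Arc;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_datasource::TableSchema;

fn main() {
    // Columns stored in the data files themselves
    let file_schema = Arc::new(Schema::new(vec![Field::new(
        "user_id",
        DataType::Int64,
        false,
    )]));
    // Columns encoded in the directory structure
    let partition_cols = vec![Arc::new(Field::new("date", DataType::Utf8, false))];

    let table_schema = TableSchema::new(Arc::clone(&file_schema), partition_cols);
    assert_eq!(table_schema.file_schema().fields().len(), 1); // file columns only
    assert_eq!(table_schema.table_partition_cols().len(), 1); // partition columns only
    assert_eq!(table_schema.table_schema().fields().len(), 2); // what queries see
}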
+/// +/// # Performance +/// +/// This struct pre-computes and caches the full table schema, allowing cheap references +/// to any representation without repeated allocations or reconstructions. +#[derive(Debug, Clone)] +pub struct TableSchema { + /// The schema of the data files themselves, without partition columns. + /// + /// For example, if your Parquet files contain `[user_id, amount]`, + /// this field holds that schema. + file_schema: SchemaRef, + + /// Columns that are derived from the directory structure (partitioning scheme). + /// + /// For Hive-style partitioning like `/date=2025-10-10/region=us-west/`, + /// this contains the `date` and `region` fields. + /// + /// These columns are NOT present in the data files but are appended to each + /// row during query execution based on the file's location. + table_partition_cols: Vec, + + /// The complete table schema: file_schema columns followed by partition columns. + /// + /// This is pre-computed during construction by concatenating `file_schema` + /// and `table_partition_cols`, so it can be returned as a cheap reference. + table_schema: SchemaRef, +} + +impl TableSchema { + /// Create a new TableSchema from a file schema and partition columns. + /// + /// The table schema is automatically computed by appending the partition columns + /// to the file schema. + /// + /// # Arguments + /// + /// * `file_schema` - Schema of the data files (without partition columns) + /// * `table_partition_cols` - Partition columns to append to each row + /// + /// # Example + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow::datatypes::{Schema, Field, DataType}; + /// # use datafusion_datasource::TableSchema; + /// let file_schema = Arc::new(Schema::new(vec![ + /// Field::new("user_id", DataType::Int64, false), + /// Field::new("amount", DataType::Float64, false), + /// ])); + /// + /// let partition_cols = vec![ + /// Arc::new(Field::new("date", DataType::Utf8, false)), + /// Arc::new(Field::new("region", DataType::Utf8, false)), + /// ]; + /// + /// let table_schema = TableSchema::new(file_schema, partition_cols); + /// + /// // Table schema will have 4 columns: user_id, amount, date, region + /// assert_eq!(table_schema.table_schema().fields().len(), 4); + /// ``` + pub fn new(file_schema: SchemaRef, table_partition_cols: Vec) -> Self { + let mut builder = SchemaBuilder::from(file_schema.as_ref()); + builder.extend(table_partition_cols.iter().cloned()); + Self { + file_schema, + table_partition_cols, + table_schema: Arc::new(builder.finish()), + } + } + + /// Create a new TableSchema from a file schema with no partition columns. + pub fn from_file_schema(file_schema: SchemaRef) -> Self { + Self::new(file_schema, vec![]) + } + + /// Set the table partition columns and rebuild the table schema. + pub fn with_table_partition_cols( + mut self, + table_partition_cols: Vec, + ) -> TableSchema { + self.table_partition_cols = table_partition_cols; + self + } + + /// Get the file schema (without partition columns). + /// + /// This is the schema of the actual data files on disk. + pub fn file_schema(&self) -> &SchemaRef { + &self.file_schema + } + + /// Get the table partition columns. + /// + /// These are the columns derived from the directory structure that + /// will be appended to each row during query execution. + pub fn table_partition_cols(&self) -> &Vec { + &self.table_partition_cols + } + + /// Get the full table schema (file schema + partition columns). 
+ /// + /// This is the complete schema that will be seen by queries, combining + /// both the columns from the files and the partition columns. + pub fn table_schema(&self) -> &SchemaRef { + &self.table_schema + } +} diff --git a/datafusion/datasource/src/url.rs b/datafusion/datasource/src/url.rs index c87b307c5fb8..0f31eb7caf41 100644 --- a/datafusion/datasource/src/url.rs +++ b/datafusion/datasource/src/url.rs @@ -252,7 +252,10 @@ impl ListingTableUrl { .boxed(), // If the head command fails, it is likely that object doesn't exist. // Retry as though it were a prefix (aka a collection) - Err(_) => list_with_cache(ctx, store, &self.prefix).await?, + Err(object_store::Error::NotFound { .. }) => { + list_with_cache(ctx, store, &self.prefix).await? + } + Err(e) => return Err(e.into()), } }; @@ -405,6 +408,8 @@ fn split_glob_expression(path: &str) -> Option<(&str, &str)> { #[cfg(test)] mod tests { use super::*; + use async_trait::async_trait; + use bytes::Bytes; use datafusion_common::config::TableOptions; use datafusion_common::DFSchema; use datafusion_execution::config::SessionConfig; @@ -414,9 +419,13 @@ mod tests { use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; - use object_store::PutPayload; + use object_store::{ + GetOptions, GetResult, ListResult, MultipartUpload, PutMultipartOptions, + PutPayload, + }; use std::any::Any; use std::collections::HashMap; + use std::ops::Range; use tempfile::tempdir; #[test] @@ -632,48 +641,68 @@ mod tests { } #[tokio::test] - async fn test_list_files() { - let store = object_store::memory::InMemory::new(); + async fn test_list_files() -> Result<()> { + let store = MockObjectStore { + in_mem: object_store::memory::InMemory::new(), + forbidden_paths: vec!["forbidden/e.parquet".into()], + }; + // Create some files: create_file(&store, "a.parquet").await; create_file(&store, "/t/b.parquet").await; create_file(&store, "/t/c.csv").await; create_file(&store, "/t/d.csv").await; + // This file returns a permission error. + create_file(&store, "/forbidden/e.parquet").await; + assert_eq!( - list_all_files("/", &store, "parquet").await, + list_all_files("/", &store, "parquet").await?, vec!["a.parquet"], ); // test with and without trailing slash assert_eq!( - list_all_files("/t/", &store, "parquet").await, + list_all_files("/t/", &store, "parquet").await?, vec!["t/b.parquet"], ); assert_eq!( - list_all_files("/t", &store, "parquet").await, + list_all_files("/t", &store, "parquet").await?, vec!["t/b.parquet"], ); // test with and without trailing slash assert_eq!( - list_all_files("/t", &store, "csv").await, + list_all_files("/t", &store, "csv").await?, vec!["t/c.csv", "t/d.csv"], ); assert_eq!( - list_all_files("/t/", &store, "csv").await, + list_all_files("/t/", &store, "csv").await?, vec!["t/c.csv", "t/d.csv"], ); // Test a non existing prefix assert_eq!( - list_all_files("/NonExisting", &store, "csv").await, + list_all_files("/NonExisting", &store, "csv").await?, vec![] as Vec ); assert_eq!( - list_all_files("/NonExisting/", &store, "csv").await, + list_all_files("/NonExisting/", &store, "csv").await?, vec![] as Vec ); + + // Including forbidden.parquet generates an error. + let Err(DataFusionError::ObjectStore(err)) = + list_all_files("/forbidden/e.parquet", &store, "parquet").await + else { + panic!("Expected ObjectStore error"); + }; + + let object_store::Error::PermissionDenied { .. 
} = &*err else { + panic!("Expected PermissionDenied error"); + }; + + Ok(()) } /// Creates a file with "hello world" content at the specified path @@ -691,10 +720,8 @@ mod tests { url: &str, store: &dyn ObjectStore, file_extension: &str, - ) -> Vec { - try_list_all_files(url, store, file_extension) - .await - .unwrap() + ) -> Result> { + try_list_all_files(url, store, file_extension).await } /// Runs "list_all_files" and returns their paths @@ -716,6 +743,95 @@ mod tests { Ok(files) } + #[derive(Debug)] + struct MockObjectStore { + in_mem: object_store::memory::InMemory, + forbidden_paths: Vec, + } + + impl std::fmt::Display for MockObjectStore { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.in_mem.fmt(f) + } + } + + #[async_trait] + impl ObjectStore for MockObjectStore { + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: object_store::PutOptions, + ) -> object_store::Result { + self.in_mem.put_opts(location, payload, opts).await + } + + async fn put_multipart_opts( + &self, + location: &Path, + opts: PutMultipartOptions, + ) -> object_store::Result> { + self.in_mem.put_multipart_opts(location, opts).await + } + + async fn get_opts( + &self, + location: &Path, + options: GetOptions, + ) -> object_store::Result { + self.in_mem.get_opts(location, options).await + } + + async fn get_ranges( + &self, + location: &Path, + ranges: &[Range], + ) -> object_store::Result> { + self.in_mem.get_ranges(location, ranges).await + } + + async fn head(&self, location: &Path) -> object_store::Result { + if self.forbidden_paths.contains(location) { + Err(object_store::Error::PermissionDenied { + path: location.to_string(), + source: "forbidden".into(), + }) + } else { + self.in_mem.head(location).await + } + } + + async fn delete(&self, location: &Path) -> object_store::Result<()> { + self.in_mem.delete(location).await + } + + fn list( + &self, + prefix: Option<&Path>, + ) -> BoxStream<'static, object_store::Result> { + self.in_mem.list(prefix) + } + + async fn list_with_delimiter( + &self, + prefix: Option<&Path>, + ) -> object_store::Result { + self.in_mem.list_with_delimiter(prefix).await + } + + async fn copy(&self, from: &Path, to: &Path) -> object_store::Result<()> { + self.in_mem.copy(from, to).await + } + + async fn copy_if_not_exists( + &self, + from: &Path, + to: &Path, + ) -> object_store::Result<()> { + self.in_mem.copy_if_not_exists(from, to).await + } + } + struct MockSession { config: SessionConfig, runtime_env: Arc, diff --git a/datafusion/datasource/src/write/demux.rs b/datafusion/datasource/src/write/demux.rs index e80099823054..52cb17c10453 100644 --- a/datafusion/datasource/src/write/demux.rs +++ b/datafusion/datasource/src/write/demux.rs @@ -40,9 +40,9 @@ use datafusion_common::cast::{ }; use datafusion_common::{exec_datafusion_err, internal_datafusion_err, not_impl_err}; use datafusion_common_runtime::SpawnedTask; -use datafusion_execution::TaskContext; use chrono::NaiveDate; +use datafusion_execution::TaskContext; use futures::StreamExt; use object_store::path::Path; use rand::distr::SampleString; @@ -68,6 +68,11 @@ pub type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>; /// be written with the extension from the path. Otherwise the default extension /// will be used and the output will be split into multiple files. /// +/// Output file guarantees: +/// - Partitioned files: Files are created only for non-empty partitions. 
+/// - Single-file output: 1 file is always written, even when the stream is empty. +/// - Multi-file output: Depending on the number of record batches, 0 or more files are written. +/// /// Examples of `base_output_path` /// * `tmp/dataset/` -> is a folder since it ends in `/` /// * `tmp/dataset` -> is still a folder since it does not end in `/` but has no valid file extension @@ -171,6 +176,21 @@ async fn row_count_demuxer( max_rows_per_file }; + if single_file_output { + // ensure we have one file open, even when the input stream is empty + open_file_streams.push(create_new_file_stream( + &base_output_path, + &write_id, + part_idx, + &file_extension, + single_file_output, + max_buffered_batches, + &mut tx, + )?); + row_counts.push(0); + part_idx += 1; + } + while let Some(rb) = input.next().await.transpose()? { // ensure we have at least minimum_parallel_files open if open_file_streams.len() < minimum_parallel_files { diff --git a/datafusion/execution/src/object_store.rs b/datafusion/execution/src/object_store.rs index ef83128ac681..aedee7d44460 100644 --- a/datafusion/execution/src/object_store.rs +++ b/datafusion/execution/src/object_store.rs @@ -20,7 +20,9 @@ //! and query data inside these systems. use dashmap::DashMap; -use datafusion_common::{exec_err, internal_datafusion_err, DataFusionError, Result}; +use datafusion_common::{ + exec_err, internal_datafusion_err, not_impl_err, DataFusionError, Result, +}; #[cfg(not(target_arch = "wasm32"))] use object_store::local::LocalFileSystem; use object_store::ObjectStore; @@ -154,6 +156,13 @@ pub trait ObjectStoreRegistry: Send + Sync + std::fmt::Debug + 'static { store: Arc, ) -> Option>; + /// Deregister the store previously registered with the same key. Returns the + /// deregistered store if it existed. + #[allow(unused_variables)] + fn deregister_store(&self, url: &Url) -> Result> { + not_impl_err!("ObjectStoreRegistry::deregister_store is not implemented for this ObjectStoreRegistry") + } + /// Get a suitable store for the provided URL. For example: /// /// - URL with scheme `file:///` or no scheme will return the default LocalFS store @@ -230,6 +239,17 @@ impl ObjectStoreRegistry for DefaultObjectStoreRegistry { self.object_stores.insert(s, store) } + fn deregister_store(&self, url: &Url) -> Result> { + let s = get_url_key(url); + let (_, object_store) = self.object_stores + .remove(&s) + .ok_or_else(|| { + internal_datafusion_err!("Failed to deregister object store. No suitable object store found for {url}. See `RuntimeEnv::register_object_store`") + })?; + + Ok(object_store) + } + fn get_store(&self, url: &Url) -> Result> { let s = get_url_key(url); self.object_stores diff --git a/datafusion/execution/src/parquet_encryption.rs b/datafusion/execution/src/parquet_encryption.rs index 73881e11ca72..027421e08f54 100644 --- a/datafusion/execution/src/parquet_encryption.rs +++ b/datafusion/execution/src/parquet_encryption.rs @@ -41,14 +41,14 @@ pub trait EncryptionFactory: Send + Sync + std::fmt::Debug + 'static { config: &EncryptionFactoryOptions, schema: &SchemaRef, file_path: &Path, - ) -> Result>; + ) -> Result>>; /// Generate file decryption properties to use when reading a Parquet file. 
async fn get_file_decryption_properties( &self, config: &EncryptionFactoryOptions, file_path: &Path, - ) -> Result>; + ) -> Result>>; } /// Stores [`EncryptionFactory`] implementations that can be retrieved by a unique string identifier diff --git a/datafusion/execution/src/runtime_env.rs b/datafusion/execution/src/runtime_env.rs index db045a8b7e8a..b0d0a966b7a2 100644 --- a/datafusion/execution/src/runtime_env.rs +++ b/datafusion/execution/src/runtime_env.rs @@ -114,8 +114,6 @@ impl RuntimeEnv { /// ``` /// /// # Example: Register remote URL object store like [Github](https://github.com) - /// - /// /// ``` /// # use std::sync::Arc; /// # use url::Url; @@ -141,6 +139,12 @@ impl RuntimeEnv { self.object_store_registry.register_store(url, object_store) } + /// Deregisters a custom `ObjectStore` previously registered for a specific url. + /// See [`ObjectStoreRegistry::deregister_store`] for more details. + pub fn deregister_object_store(&self, url: &Url) -> Result> { + self.object_store_registry.deregister_store(url) + } + /// Retrieves a `ObjectStore` instance for a url by consulting the /// registry. See [`ObjectStoreRegistry::get_store`] for more /// details. diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs b/datafusion/expr-common/src/type_coercion/aggregates.rs index e77a072a84f3..55a8843394b5 100644 --- a/datafusion/expr-common/src/type_coercion/aggregates.rs +++ b/datafusion/expr-common/src/type_coercion/aggregates.rs @@ -16,31 +16,12 @@ // under the License. use crate::signature::TypeSignature; -use arrow::datatypes::{ - DataType, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, - DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, - DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE, -}; +use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::{internal_err, plan_err, Result}; -pub static STRINGS: &[DataType] = - &[DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View]; - -pub static SIGNED_INTEGERS: &[DataType] = &[ - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, -]; - -pub static UNSIGNED_INTEGERS: &[DataType] = &[ - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, -]; - +// TODO: remove usage of these (INTEGERS and NUMERICS) in favour of signatures +// see https://github.com/apache/datafusion/issues/18092 pub static INTEGERS: &[DataType] = &[ DataType::Int8, DataType::Int16, @@ -65,24 +46,6 @@ pub static NUMERICS: &[DataType] = &[ DataType::Float64, ]; -pub static TIMESTAMPS: &[DataType] = &[ - DataType::Timestamp(TimeUnit::Second, None), - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Timestamp(TimeUnit::Microsecond, None), - DataType::Timestamp(TimeUnit::Nanosecond, None), -]; - -pub static DATES: &[DataType] = &[DataType::Date32, DataType::Date64]; - -pub static BINARYS: &[DataType] = &[DataType::Binary, DataType::LargeBinary]; - -pub static TIMES: &[DataType] = &[ - DataType::Time32(TimeUnit::Second), - DataType::Time32(TimeUnit::Millisecond), - DataType::Time64(TimeUnit::Microsecond), - DataType::Time64(TimeUnit::Nanosecond), -]; - /// Validate the length of `input_fields` matches the `signature` for `agg_fun`. 
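A short sketch of registering and then deregistering a store through `RuntimeEnv` (the `mem://scratch/` URL is only an illustrative key, and `RuntimeEnv::default()` is assumed to build a runtime backed by the default registry shown above):

use std::sync::Arc;
use url::Url;
use datafusion_execution::runtime_env::RuntimeEnv;
use object_store::memory::InMemory;

fn main() -> datafusion_common::Result<()> {
    let env = RuntimeEnv::default();
    let url = Url::parse("mem://scratch/").unwrap();

    // Register an in-memory store for the prefix, then remove it again.
    env.register_object_store(&url, Arc::new(InMemory::new()));
    let _removed = env.deregister_object_store(&url)?; // returns the removed store

    // Nothing is registered for the key any more, so a second call errors.
    assert!(env.deregister_object_store(&url).is_err());
    Ok(())
}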
/// /// This method DOES NOT validate the argument fields - only that (at least one, @@ -144,260 +107,3 @@ pub fn check_arg_count( } Ok(()) } - -/// Function return type of a sum -pub fn sum_return_type(arg_type: &DataType) -> Result { - match arg_type { - DataType::Int64 => Ok(DataType::Int64), - DataType::UInt64 => Ok(DataType::UInt64), - DataType::Float64 => Ok(DataType::Float64), - DataType::Decimal32(precision, scale) => { - // in the spark, the result type is DECIMAL(min(38,precision+10), s) - // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 - let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10); - Ok(DataType::Decimal32(new_precision, *scale)) - } - DataType::Decimal64(precision, scale) => { - // in the spark, the result type is DECIMAL(min(38,precision+10), s) - // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 - let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10); - Ok(DataType::Decimal64(new_precision, *scale)) - } - DataType::Decimal128(precision, scale) => { - // In the spark, the result type is DECIMAL(min(38,precision+10), s) - // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 - let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); - Ok(DataType::Decimal128(new_precision, *scale)) - } - DataType::Decimal256(precision, scale) => { - // In the spark, the result type is DECIMAL(min(38,precision+10), s) - // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 - let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); - Ok(DataType::Decimal256(new_precision, *scale)) - } - other => plan_err!("SUM does not support type \"{other:?}\""), - } -} - -/// Function return type of variance -pub fn variance_return_type(arg_type: &DataType) -> Result { - if NUMERICS.contains(arg_type) { - Ok(DataType::Float64) - } else { - plan_err!("VAR does not support {arg_type}") - } -} - -/// Function return type of covariance -pub fn covariance_return_type(arg_type: &DataType) -> Result { - if NUMERICS.contains(arg_type) { - Ok(DataType::Float64) - } else { - plan_err!("COVAR does not support {arg_type}") - } -} - -/// Function return type of correlation -pub fn correlation_return_type(arg_type: &DataType) -> Result { - if NUMERICS.contains(arg_type) { - Ok(DataType::Float64) - } else { - plan_err!("CORR does not support {arg_type}") - } -} - -/// Function return type of an average -pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result { - match arg_type { - DataType::Decimal32(precision, scale) => { - // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). 
- // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 - let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4); - let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4); - Ok(DataType::Decimal32(new_precision, new_scale)) - } - DataType::Decimal64(precision, scale) => { - // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). - // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 - let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4); - let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4); - Ok(DataType::Decimal64(new_precision, new_scale)) - } - DataType::Decimal128(precision, scale) => { - // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). - // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 - let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4); - let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4); - Ok(DataType::Decimal128(new_precision, new_scale)) - } - DataType::Decimal256(precision, scale) => { - // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). - // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 - let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4); - let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4); - Ok(DataType::Decimal256(new_precision, new_scale)) - } - DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)), - arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), - DataType::Dictionary(_, dict_value_type) => { - avg_return_type(func_name, dict_value_type.as_ref()) - } - other => plan_err!("{func_name} does not support {other:?}"), - } -} - -/// Internal sum type of an average -pub fn avg_sum_type(arg_type: &DataType) -> Result { - match arg_type { - DataType::Decimal32(precision, scale) => { - // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s) - let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10); - Ok(DataType::Decimal32(new_precision, *scale)) - } - DataType::Decimal64(precision, scale) => { - // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s) - let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10); - Ok(DataType::Decimal64(new_precision, *scale)) - } - DataType::Decimal128(precision, scale) => { - // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s) - let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); - Ok(DataType::Decimal128(new_precision, *scale)) - } - DataType::Decimal256(precision, scale) => { - // In Spark the sum type of avg is DECIMAL(min(38,precision+10), s) - let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); - Ok(DataType::Decimal256(new_precision, *scale)) - } - DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)), - arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), - DataType::Dictionary(_, dict_value_type) => { - avg_sum_type(dict_value_type.as_ref()) - } - other => plan_err!("AVG does not support {other:?}"), - } -} - -pub fn 
is_sum_support_arg_type(arg_type: &DataType) -> bool { - match arg_type { - DataType::Dictionary(_, dict_value_type) => { - is_sum_support_arg_type(dict_value_type.as_ref()) - } - _ => matches!( - arg_type, - arg_type if NUMERICS.contains(arg_type) - || matches!(arg_type, DataType::Decimal32(_, _) | DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) - ), - } -} - -pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool { - match arg_type { - DataType::Dictionary(_, dict_value_type) => { - is_avg_support_arg_type(dict_value_type.as_ref()) - } - _ => matches!( - arg_type, - arg_type if NUMERICS.contains(arg_type) - || matches!(arg_type, DataType::Decimal32(_, _) | DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) - ), - } -} - -pub fn is_variance_support_arg_type(arg_type: &DataType) -> bool { - matches!( - arg_type, - arg_type if NUMERICS.contains(arg_type) - ) -} - -pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool { - matches!( - arg_type, - arg_type if NUMERICS.contains(arg_type) - ) -} - -pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool { - matches!( - arg_type, - arg_type if NUMERICS.contains(arg_type) - ) -} - -pub fn is_integer_arg_type(arg_type: &DataType) -> bool { - arg_type.is_integer() -} - -pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result> { - // Supported types smallint, int, bigint, real, double precision, decimal, or interval - // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc - fn coerced_type(func_name: &str, data_type: &DataType) -> Result { - match &data_type { - DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)), - DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)), - DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)), - DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)), - d if d.is_numeric() => Ok(DataType::Float64), - DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)), - DataType::Dictionary(_, v) => coerced_type(func_name, v.as_ref()), - _ => { - plan_err!( - "The function {:?} does not support inputs of type {}.", - func_name, - data_type - ) - } - } - } - Ok(vec![coerced_type(func_name, &arg_types[0])?]) -} -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_variance_return_data_type() -> Result<()> { - let data_type = DataType::Float64; - let result_type = variance_return_type(&data_type)?; - assert_eq!(DataType::Float64, result_type); - - let data_type = DataType::Decimal128(36, 10); - assert!(variance_return_type(&data_type).is_err()); - Ok(()) - } - - #[test] - fn test_sum_return_data_type() -> Result<()> { - let data_type = DataType::Decimal128(10, 5); - let result_type = sum_return_type(&data_type)?; - assert_eq!(DataType::Decimal128(20, 5), result_type); - - let data_type = DataType::Decimal128(36, 10); - let result_type = sum_return_type(&data_type)?; - assert_eq!(DataType::Decimal128(38, 10), result_type); - Ok(()) - } - - #[test] - fn test_covariance_return_data_type() -> Result<()> { - let data_type = DataType::Float64; - let result_type = covariance_return_type(&data_type)?; - assert_eq!(DataType::Float64, result_type); - - let data_type = DataType::Decimal128(36, 10); - assert!(covariance_return_type(&data_type).is_err()); - Ok(()) - } - - #[test] - fn test_correlation_return_data_type() -> Result<()> { - let data_type = DataType::Float64; - let result_type = correlation_return_type(&data_type)?; - 
assert_eq!(DataType::Float64, result_type); - - let data_type = DataType::Decimal128(36, 10); - assert!(correlation_return_type(&data_type).is_err()); - Ok(()) - } -} diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 52bb211d9b99..122e0f987b6f 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -866,6 +866,7 @@ pub fn comparison_coercion_numeric( return Some(lhs_type.clone()); } binary_numeric_coercion(lhs_type, rhs_type) + .or_else(|| dictionary_comparison_coercion_numeric(lhs_type, rhs_type, true)) .or_else(|| string_coercion(lhs_type, rhs_type)) .or_else(|| null_coercion(lhs_type, rhs_type)) .or_else(|| string_numeric_coercion_as_numeric(lhs_type, rhs_type)) @@ -1353,38 +1354,75 @@ fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) -> } } -/// Coercion rules for Dictionaries: the type that both lhs and rhs +/// Generic coercion rules for Dictionaries: the type that both lhs and rhs /// can be casted to for the purpose of a computation. /// /// Not all operators support dictionaries, if `preserve_dictionaries` is true -/// dictionaries will be preserved if possible -fn dictionary_comparison_coercion( +/// dictionaries will be preserved if possible. +/// +/// The `coerce_fn` parameter determines which comparison coercion function to use +/// for comparing the dictionary value types. +fn dictionary_comparison_coercion_generic( lhs_type: &DataType, rhs_type: &DataType, preserve_dictionaries: bool, + coerce_fn: fn(&DataType, &DataType) -> Option, ) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { ( Dictionary(_lhs_index_type, lhs_value_type), Dictionary(_rhs_index_type, rhs_value_type), - ) => comparison_coercion(lhs_value_type, rhs_value_type), + ) => coerce_fn(lhs_value_type, rhs_value_type), (d @ Dictionary(_, value_type), other_type) | (other_type, d @ Dictionary(_, value_type)) if preserve_dictionaries && value_type.as_ref() == other_type => { Some(d.clone()) } - (Dictionary(_index_type, value_type), _) => { - comparison_coercion(value_type, rhs_type) - } - (_, Dictionary(_index_type, value_type)) => { - comparison_coercion(lhs_type, value_type) - } + (Dictionary(_index_type, value_type), _) => coerce_fn(value_type, rhs_type), + (_, Dictionary(_index_type, value_type)) => coerce_fn(lhs_type, value_type), _ => None, } } +/// Coercion rules for Dictionaries: the type that both lhs and rhs +/// can be casted to for the purpose of a computation. +/// +/// Not all operators support dictionaries, if `preserve_dictionaries` is true +/// dictionaries will be preserved if possible +fn dictionary_comparison_coercion( + lhs_type: &DataType, + rhs_type: &DataType, + preserve_dictionaries: bool, +) -> Option { + dictionary_comparison_coercion_generic( + lhs_type, + rhs_type, + preserve_dictionaries, + comparison_coercion, + ) +} + +/// Coercion rules for Dictionaries with numeric preference: similar to +/// [`dictionary_comparison_coercion`] but uses [`comparison_coercion_numeric`] +/// which prefers numeric types over strings when both are present. +/// +/// This is used by [`comparison_coercion_numeric`] to maintain consistent +/// numeric-preferring semantics when dealing with dictionary types. 
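To make the effect concrete, a hedged sketch (assuming the function is reachable as `datafusion_expr_common::type_coercion::binary::comparison_coercion_numeric`; the exact coerced type follows the preserve-dictionary rules above, so only `is_some()` is asserted):

use arrow::datatypes::DataType;
use datafusion_expr_common::type_coercion::binary::comparison_coercion_numeric;

fn main() {
    // A dictionary-encoded numeric column compared against a plain numeric type
    let dict = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Int64));

    // With the dictionary branch added to comparison_coercion_numeric,
    // a common comparison type is now found for this pair.
    assert!(comparison_coercion_numeric(&dict, &DataType::Int64).is_some());
}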
+fn dictionary_comparison_coercion_numeric( + lhs_type: &DataType, + rhs_type: &DataType, + preserve_dictionaries: bool, +) -> Option { + dictionary_comparison_coercion_generic( + lhs_type, + rhs_type, + preserve_dictionaries, + comparison_coercion_numeric, + ) +} + /// Coercion rules for string concat. /// This is a union of string coercion rules and specified rules: /// 1. At least one side of lhs and rhs should be string type (Utf8 / LargeUtf8) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 282b3f6a0f55..6077b3c1e5bb 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -18,7 +18,7 @@ //! Logical Expressions: [`Expr`] use std::cmp::Ordering; -use std::collections::{BTreeMap, HashSet}; +use std::collections::HashSet; use std::fmt::{self, Display, Formatter, Write}; use std::hash::{Hash, Hasher}; use std::mem; @@ -45,6 +45,10 @@ use sqlparser::ast::{ RenameSelectItem, ReplaceSelectElement, }; +// Moved in 51.0.0 to datafusion_common +pub use datafusion_common::metadata::FieldMetadata; +use datafusion_common::metadata::ScalarAndMetadata; + // This mirrors sqlparser::ast::NullTreatment but we need our own variant // for when the sql feature is disabled. #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Ord, PartialOrd)] @@ -421,6 +425,14 @@ impl From for Expr { } } +/// Create an [`Expr`] from an [`ScalarAndMetadata`] +impl From for Expr { + fn from(value: ScalarAndMetadata) -> Self { + let (value, metadata) = value.into_inner(); + Expr::Literal(value, metadata) + } +} + /// Create an [`Expr`] from an optional qualifier and a [`FieldRef`]. This is /// useful for creating [`Expr`] from a [`DFSchema`]. /// @@ -447,235 +459,6 @@ impl<'a> TreeNodeContainer<'a, Self> for Expr { } } -/// Literal metadata -/// -/// Stores metadata associated with a literal expressions -/// and is designed to be fast to `clone`. -/// -/// This structure is used to store metadata associated with a literal expression, and it -/// corresponds to the `metadata` field on [`Field`]. -/// -/// # Example: Create [`FieldMetadata`] from a [`Field`] -/// ``` -/// # use std::collections::HashMap; -/// # use datafusion_expr::expr::FieldMetadata; -/// # use arrow::datatypes::{Field, DataType}; -/// # let field = Field::new("c1", DataType::Int32, true) -/// # .with_metadata(HashMap::from([("foo".to_string(), "bar".to_string())])); -/// // Create a new `FieldMetadata` instance from a `Field` -/// let metadata = FieldMetadata::new_from_field(&field); -/// // There is also a `From` impl: -/// let metadata = FieldMetadata::from(&field); -/// ``` -/// -/// # Example: Update a [`Field`] with [`FieldMetadata`] -/// ``` -/// # use datafusion_expr::expr::FieldMetadata; -/// # use arrow::datatypes::{Field, DataType}; -/// # let field = Field::new("c1", DataType::Int32, true); -/// # let metadata = FieldMetadata::new_from_field(&field); -/// // Add any metadata from `FieldMetadata` to `Field` -/// let updated_field = metadata.add_to_field(field); -/// ``` -/// -#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] -pub struct FieldMetadata { - /// The inner metadata of a literal expression, which is a map of string - /// keys to string values. - /// - /// Note this is not a `HashMap` because `HashMap` does not provide - /// implementations for traits like `Debug` and `Hash`. - inner: Arc>, -} - -impl Default for FieldMetadata { - fn default() -> Self { - Self::new_empty() - } -} - -impl FieldMetadata { - /// Create a new empty metadata instance. 
- pub fn new_empty() -> Self { - Self { - inner: Arc::new(BTreeMap::new()), - } - } - - /// Merges two optional `FieldMetadata` instances, overwriting any existing - /// keys in `m` with keys from `n` if present. - /// - /// This function is commonly used in alias operations, particularly for literals - /// with metadata. When creating an alias expression, the metadata from the original - /// expression (such as a literal) is combined with any metadata specified on the alias. - /// - /// # Arguments - /// - /// * `m` - The first metadata (typically from the original expression like a literal) - /// * `n` - The second metadata (typically from the alias definition) - /// - /// # Merge Strategy - /// - /// - If both metadata instances exist, they are merged with `n` taking precedence - /// - Keys from `n` will overwrite keys from `m` if they have the same name - /// - If only one metadata instance exists, it is returned unchanged - /// - If neither exists, `None` is returned - /// - /// # Example usage - /// ```rust - /// use datafusion_expr::expr::FieldMetadata; - /// use std::collections::BTreeMap; - /// - /// // Create metadata for a literal expression - /// let literal_metadata = Some(FieldMetadata::from(BTreeMap::from([ - /// ("source".to_string(), "constant".to_string()), - /// ("type".to_string(), "int".to_string()), - /// ]))); - /// - /// // Create metadata for an alias - /// let alias_metadata = Some(FieldMetadata::from(BTreeMap::from([ - /// ("description".to_string(), "answer".to_string()), - /// ("source".to_string(), "user".to_string()), // This will override literal's "source" - /// ]))); - /// - /// // Merge the metadata - /// let merged = FieldMetadata::merge_options( - /// literal_metadata.as_ref(), - /// alias_metadata.as_ref(), - /// ); - /// - /// // Result contains: {"source": "user", "type": "int", "description": "answer"} - /// assert!(merged.is_some()); - /// ``` - pub fn merge_options( - m: Option<&FieldMetadata>, - n: Option<&FieldMetadata>, - ) -> Option { - match (m, n) { - (Some(m), Some(n)) => { - let mut merged = m.clone(); - merged.extend(n.clone()); - Some(merged) - } - (Some(m), None) => Some(m.clone()), - (None, Some(n)) => Some(n.clone()), - (None, None) => None, - } - } - - /// Create a new metadata instance from a `Field`'s metadata. - pub fn new_from_field(field: &Field) -> Self { - let inner = field - .metadata() - .iter() - .map(|(k, v)| (k.to_string(), v.to_string())) - .collect(); - Self { - inner: Arc::new(inner), - } - } - - /// Create a new metadata instance from a map of string keys to string values. - pub fn new(inner: BTreeMap) -> Self { - Self { - inner: Arc::new(inner), - } - } - - /// Get the inner metadata as a reference to a `BTreeMap`. - pub fn inner(&self) -> &BTreeMap { - &self.inner - } - - /// Return the inner metadata - pub fn into_inner(self) -> Arc> { - self.inner - } - - /// Adds metadata from `other` into `self`, overwriting any existing keys. - pub fn extend(&mut self, other: Self) { - if other.is_empty() { - return; - } - let other = Arc::unwrap_or_clone(other.into_inner()); - Arc::make_mut(&mut self.inner).extend(other); - } - - /// Returns true if the metadata is empty. - pub fn is_empty(&self) -> bool { - self.inner.is_empty() - } - - /// Returns the number of key-value pairs in the metadata. 
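The block being removed here has moved rather than disappeared; a minimal sketch of the updated import path (assuming the moved type keeps the same constructors, as the updated doc examples elsewhere in this patch suggest):

use std::collections::BTreeMap;
use datafusion_common::metadata::FieldMetadata;

fn main() {
    // Same API as before, now imported from datafusion_common
    let metadata = FieldMetadata::from(BTreeMap::from([(
        "source".to_string(),
        "constant".to_string(),
    )]));
    assert_eq!(metadata.len(), 1);
    assert!(!metadata.is_empty());
}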
- pub fn len(&self) -> usize { - self.inner.len() - } - - /// Convert this `FieldMetadata` into a `HashMap` - pub fn to_hashmap(&self) -> std::collections::HashMap { - self.inner - .iter() - .map(|(k, v)| (k.to_string(), v.to_string())) - .collect() - } - - /// Updates the metadata on the Field with this metadata, if it is not empty. - pub fn add_to_field(&self, field: Field) -> Field { - if self.inner.is_empty() { - return field; - } - - field.with_metadata(self.to_hashmap()) - } -} - -impl From<&Field> for FieldMetadata { - fn from(field: &Field) -> Self { - Self::new_from_field(field) - } -} - -impl From> for FieldMetadata { - fn from(inner: BTreeMap) -> Self { - Self::new(inner) - } -} - -impl From> for FieldMetadata { - fn from(map: std::collections::HashMap) -> Self { - Self::new(map.into_iter().collect()) - } -} - -/// From reference -impl From<&std::collections::HashMap> for FieldMetadata { - fn from(map: &std::collections::HashMap) -> Self { - let inner = map - .iter() - .map(|(k, v)| (k.to_string(), v.to_string())) - .collect(); - Self::new(inner) - } -} - -/// From hashbrown map -impl From> for FieldMetadata { - fn from(map: HashMap) -> Self { - let inner = map.into_iter().collect(); - Self::new(inner) - } -} - -impl From<&HashMap> for FieldMetadata { - fn from(map: &HashMap) -> Self { - let inner = map - .into_iter() - .map(|(k, v)| (k.to_string(), v.to_string())) - .collect(); - Self::new(inner) - } -} - /// The metadata used in [`Field::metadata`]. /// /// This represents the metadata associated with an Arrow [`Field`]. The metadata consists of key-value pairs. @@ -1370,13 +1153,22 @@ pub struct Placeholder { /// The identifier of the parameter, including the leading `$` (e.g, `"$1"` or `"$foo"`) pub id: String, /// The type the parameter will be filled in with - pub data_type: Option, + pub field: Option, } impl Placeholder { /// Create a new Placeholder expression + #[deprecated(since = "51.0.0", note = "Use new_with_field instead")] pub fn new(id: String, data_type: Option) -> Self { - Self { id, data_type } + Self { + id, + field: data_type.map(|dt| Arc::new(Field::new("", dt, true))), + } + } + + /// Create a new Placeholder expression from a Field + pub fn new_with_field(id: String, field: Option) -> Self { + Self { id, field } } } @@ -1843,7 +1635,7 @@ impl Expr { /// ``` /// # use datafusion_expr::col; /// # use std::collections::HashMap; - /// # use datafusion_expr::expr::FieldMetadata; + /// # use datafusion_common::metadata::FieldMetadata; /// let metadata = HashMap::from([("key".to_string(), "value".to_string())]); /// let metadata = FieldMetadata::from(metadata); /// let expr = col("foo").alias_with_metadata("bar", Some(metadata)); @@ -1875,7 +1667,7 @@ impl Expr { /// ``` /// # use datafusion_expr::col; /// # use std::collections::HashMap; - /// # use datafusion_expr::expr::FieldMetadata; + /// # use datafusion_common::metadata::FieldMetadata; /// let metadata = HashMap::from([("key".to_string(), "value".to_string())]); /// let metadata = FieldMetadata::from(metadata); /// let expr = col("foo").alias_qualified_with_metadata(Some("tbl"), "bar", Some(metadata)); @@ -2886,19 +2678,23 @@ impl HashNode for Expr { } } -// Modifies expr if it is a placeholder with datatype of right +// Modifies expr to match the DataType, metadata, and nullability of other if it is +// a placeholder with previously unspecified type information (i.e., most placeholders) fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> { - if let 
Expr::Placeholder(Placeholder { id: _, data_type }) = expr { - if data_type.is_none() { - let other_dt = other.get_type(schema); - match other_dt { + if let Expr::Placeholder(Placeholder { id: _, field }) = expr { + if field.is_none() { + let other_field = other.to_field(schema); + match other_field { Err(e) => { Err(e.context(format!( "Can not find type of {other} needed to infer type of {expr}" )))?; } - Ok(dt) => { - *data_type = Some(dt); + Ok((_, other_field)) => { + // We can't infer the nullability of the future parameter that might + // be bound, so ensure this is set to true + *field = + Some(other_field.as_ref().clone().with_nullable(true).into()); } } }; @@ -3715,8 +3511,8 @@ pub fn physical_name(expr: &Expr) -> Result { mod test { use crate::expr_fn::col; use crate::{ - case, lit, qualified_wildcard, wildcard, wildcard_with_options, ColumnarValue, - ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility, + case, lit, placeholder, qualified_wildcard, wildcard, wildcard_with_options, + ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility, }; use arrow::datatypes::{Field, Schema}; use sqlparser::ast; @@ -3730,15 +3526,15 @@ mod test { let param_placeholders = vec![ Expr::Placeholder(Placeholder { id: "$1".to_string(), - data_type: None, + field: None, }), Expr::Placeholder(Placeholder { id: "$2".to_string(), - data_type: None, + field: None, }), Expr::Placeholder(Placeholder { id: "$3".to_string(), - data_type: None, + field: None, }), ]; let in_list = Expr::InList(InList { @@ -3764,8 +3560,8 @@ mod test { match expr { Expr::Placeholder(placeholder) => { assert_eq!( - placeholder.data_type, - Some(DataType::Int32), + placeholder.field.unwrap().data_type(), + &DataType::Int32, "Placeholder {} should infer Int32", placeholder.id ); @@ -3789,7 +3585,7 @@ mod test { expr: Box::new(col("name")), pattern: Box::new(Expr::Placeholder(Placeholder { id: "$1".to_string(), - data_type: None, + field: None, })), negated: false, case_insensitive: false, @@ -3802,7 +3598,7 @@ mod test { match inferred_expr { Expr::Like(like) => match *like.pattern { Expr::Placeholder(placeholder) => { - assert_eq!(placeholder.data_type, Some(DataType::Utf8)); + assert_eq!(placeholder.field.unwrap().data_type(), &DataType::Utf8); } _ => panic!("Expected Placeholder"), }, @@ -3817,8 +3613,8 @@ mod test { Expr::SimilarTo(like) => match *like.pattern { Expr::Placeholder(placeholder) => { assert_eq!( - placeholder.data_type, - Some(DataType::Utf8), + placeholder.field.unwrap().data_type(), + &DataType::Utf8, "Placeholder {} should infer Utf8", placeholder.id ); @@ -3829,6 +3625,39 @@ mod test { } } + #[test] + fn infer_placeholder_with_metadata() { + // name == $1, where name is a non-nullable string + let schema = + Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, false) + .with_metadata( + [("some_key".to_string(), "some_value".to_string())].into(), + )])); + let df_schema = DFSchema::try_from(schema).unwrap(); + + let expr = binary_expr(col("name"), Operator::Eq, placeholder("$1")); + + let (inferred_expr, _) = expr.infer_placeholder_types(&df_schema).unwrap(); + match inferred_expr { + Expr::BinaryExpr(BinaryExpr { right, .. 
}) => match *right { + Expr::Placeholder(placeholder) => { + assert_eq!( + placeholder.field.as_ref().unwrap().data_type(), + &DataType::Utf8 + ); + assert_eq!( + placeholder.field.as_ref().unwrap().metadata(), + df_schema.field(0).metadata() + ); + // Inferred placeholder should still be nullable + assert!(placeholder.field.as_ref().unwrap().is_nullable()); + } + _ => panic!("Expected Placeholder"), + }, + _ => panic!("Expected BinaryExpr"), + } + } + #[test] fn format_case_when() -> Result<()> { let expr = case(col("a")) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 4666411dd540..c777c4978f99 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -119,13 +119,13 @@ pub fn ident(name: impl Into) -> Expr { /// /// ```rust /// # use datafusion_expr::{placeholder}; -/// let p = placeholder("$0"); // $0, refers to parameter 1 -/// assert_eq!(p.to_string(), "$0") +/// let p = placeholder("$1"); // $1, refers to parameter 1 +/// assert_eq!(p.to_string(), "$1") /// ``` pub fn placeholder(id: impl Into) -> Expr { Expr::Placeholder(Placeholder { id: id.into(), - data_type: None, + field: None, }) } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index e803e3534130..8c557a5630f0 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -17,8 +17,8 @@ use super::{Between, Expr, Like}; use crate::expr::{ - AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, FieldMetadata, - InList, InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, + AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList, + InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }; use crate::type_coercion::functions::{ @@ -28,6 +28,7 @@ use crate::udf::ReturnFieldArgs; use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::metadata::FieldMetadata; use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, Result, Spans, TableReference, @@ -104,9 +105,9 @@ impl ExprSchemable for Expr { fn get_type(&self, schema: &dyn ExprSchema) -> Result { match self { Expr::Alias(Alias { expr, name, .. }) => match &**expr { - Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type { + Expr::Placeholder(Placeholder { field, .. }) => match &field { None => schema.data_type(&Column::from_name(name)).cloned(), - Some(dt) => Ok(dt.clone()), + Some(field) => Ok(field.data_type().clone()), }, _ => expr.get_type(schema), }, @@ -211,9 +212,9 @@ impl ExprSchemable for Expr { ) .get_result_type(), Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean), - Expr::Placeholder(Placeholder { data_type, .. }) => { - if let Some(dtype) = data_type { - Ok(dtype.clone()) + Expr::Placeholder(Placeholder { field, .. }) => { + if let Some(field) = field { + Ok(field.data_type().clone()) } else { // If the placeholder's type hasn't been specified, treat it as // null (unspecified placeholders generate an error during planning) @@ -309,10 +310,12 @@ impl ExprSchemable for Expr { window_function, ) .map(|(_, nullable)| nullable), - Expr::ScalarVariable(_, _) - | Expr::TryCast { .. 
} - | Expr::Unnest(_) - | Expr::Placeholder(_) => Ok(true), + Expr::Placeholder(Placeholder { id: _, field }) => { + Ok(field.as_ref().map(|f| f.is_nullable()).unwrap_or(true)) + } + Expr::ScalarVariable(_, _) | Expr::TryCast { .. } | Expr::Unnest(_) => { + Ok(true) + } Expr::IsNull(_) | Expr::IsNotNull(_) | Expr::IsTrue(_) @@ -428,25 +431,11 @@ impl ExprSchemable for Expr { let field = match self { Expr::Alias(Alias { expr, - name, + name: _, metadata, .. }) => { - let field = match &**expr { - Expr::Placeholder(Placeholder { data_type, .. }) => { - match &data_type { - None => schema - .data_type_and_nullable(&Column::from_name(name)) - .map(|(d, n)| Field::new(&schema_name, d.clone(), n)), - Some(dt) => Ok(Field::new( - &schema_name, - dt.clone(), - expr.nullable(schema)?, - )), - } - } - _ => expr.to_field(schema).map(|(_, f)| f.as_ref().clone()), - }?; + let field = expr.to_field(schema).map(|(_, f)| f.as_ref().clone())?; let mut combined_metadata = expr.metadata(schema)?; if let Some(metadata) = metadata { @@ -594,6 +583,10 @@ impl ExprSchemable for Expr { .to_field(schema) .map(|(_, f)| f.as_ref().clone().with_data_type(data_type.clone())) .map(Arc::new), + Expr::Placeholder(Placeholder { + id: _, + field: Some(field), + }) => Ok(field.as_ref().clone().with_name(&schema_name).into()), Expr::Like(_) | Expr::SimilarTo(_) | Expr::Not(_) @@ -776,10 +769,12 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result {{ @@ -905,7 +900,7 @@ mod tests { let schema = DFSchema::from_unqualified_fields( vec![meta.add_to_field(Field::new("foo", DataType::Int32, true))].into(), - std::collections::HashMap::new(), + HashMap::new(), ) .unwrap(); @@ -921,6 +916,52 @@ mod tests { assert_eq!(meta, outer_ref.metadata(&schema).unwrap()); } + #[test] + fn test_expr_placeholder() { + let schema = MockExprSchema::new(); + + let mut placeholder_meta = HashMap::new(); + placeholder_meta.insert("bar".to_string(), "buzz".to_string()); + let placeholder_meta = FieldMetadata::from(placeholder_meta); + + let expr = Expr::Placeholder(Placeholder::new_with_field( + "".to_string(), + Some( + Field::new("", DataType::Utf8, true) + .with_metadata(placeholder_meta.to_hashmap()) + .into(), + ), + )); + + assert_eq!( + expr.data_type_and_nullable(&schema).unwrap(), + (DataType::Utf8, true) + ); + assert_eq!(placeholder_meta, expr.metadata(&schema).unwrap()); + + let expr_alias = expr.alias("a placeholder by any other name"); + assert_eq!( + expr_alias.data_type_and_nullable(&schema).unwrap(), + (DataType::Utf8, true) + ); + assert_eq!(placeholder_meta, expr_alias.metadata(&schema).unwrap()); + + // Non-nullable placeholder field should remain non-nullable + let expr = Expr::Placeholder(Placeholder::new_with_field( + "".to_string(), + Some(Field::new("", DataType::Utf8, false).into()), + )); + assert_eq!( + expr.data_type_and_nullable(&schema).unwrap(), + (DataType::Utf8, false) + ); + let expr_alias = expr.alias("a placeholder by any other name"); + assert_eq!( + expr_alias.data_type_and_nullable(&schema).unwrap(), + (DataType::Utf8, false) + ); + } + #[derive(Debug)] struct MockExprSchema { field: Field, diff --git a/datafusion/expr/src/literal.rs b/datafusion/expr/src/literal.rs index c4bd43bc0a62..335d7b471f5f 100644 --- a/datafusion/expr/src/literal.rs +++ b/datafusion/expr/src/literal.rs @@ -17,9 +17,8 @@ //! Literal module contains foundational types that are used to represent literals in DataFusion. 
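A minimal usage sketch of the field-based placeholder exercised in the tests above, assuming only the APIs visible in this patch (`Placeholder::new_with_field`, the public `field` member, and arrow's `Field` builder methods); the key/value pair is illustrative:

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field};
use datafusion_expr::expr::Placeholder;
use datafusion_expr::Expr;

fn main() {
    // A field carrying the parameter's type, nullability, and extension metadata
    let field = Field::new("", DataType::Utf8, true).with_metadata(
        [("some_key".to_string(), "some_value".to_string())].into(),
    );

    // New-style construction: the placeholder keeps the whole Field,
    // not just a DataType
    let expr = Expr::Placeholder(Placeholder::new_with_field(
        "$1".to_string(),
        Some(Arc::new(field)),
    ));

    if let Expr::Placeholder(p) = &expr {
        assert_eq!(p.field.as_ref().unwrap().data_type(), &DataType::Utf8);
        assert!(p.field.as_ref().unwrap().is_nullable());
    }
}
```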
-use crate::expr::FieldMetadata; use crate::Expr; -use datafusion_common::ScalarValue; +use datafusion_common::{metadata::FieldMetadata, ScalarValue}; /// Create a literal expression pub fn lit(n: T) -> Expr { diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 7a283b0420d3..a430add3f786 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -25,7 +25,7 @@ use std::iter::once; use std::sync::Arc; use crate::dml::CopyTo; -use crate::expr::{Alias, FieldMetadata, PlannedReplaceSelectItem, Sort as SortExpr}; +use crate::expr::{Alias, PlannedReplaceSelectItem, Sort as SortExpr}; use crate::expr_rewriter::{ coerce_plan_expr_for_schema, normalize_col, normalize_col_with_schemas_and_ambiguity_check, normalize_cols, normalize_sorts, @@ -50,9 +50,10 @@ use crate::{ use super::dml::InsertOp; use arrow::compute::can_cast_types; -use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::file_options::file_type::FileType; +use datafusion_common::metadata::FieldMetadata; use datafusion_common::{ exec_err, get_target_functional_dependencies, internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, @@ -622,11 +623,11 @@ impl LogicalPlanBuilder { } /// Make a builder for a prepare logical plan from the builder's plan - pub fn prepare(self, name: String, data_types: Vec) -> Result { + pub fn prepare(self, name: String, fields: Vec) -> Result { Ok(Self::new(LogicalPlan::Statement(Statement::Prepare( Prepare { name, - data_types, + fields, input: self.plan, }, )))) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index b8200ab8a48c..9541f35e3062 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -51,9 +51,10 @@ use crate::{ WindowFunctionDefinition, }; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; use datafusion_common::cse::{NormalizeEq, Normalizeable}; use datafusion_common::format::ExplainFormat; +use datafusion_common::metadata::check_metadata_with_storage_equal; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, }; @@ -1098,15 +1099,13 @@ impl LogicalPlan { })) } LogicalPlan::Statement(Statement::Prepare(Prepare { - name, - data_types, - .. + name, fields, .. })) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; Ok(LogicalPlan::Statement(Statement::Prepare(Prepare { name: name.clone(), - data_types: data_types.clone(), + fields: fields.clone(), input: Arc::new(input), }))) } @@ -1282,7 +1281,7 @@ impl LogicalPlan { if let LogicalPlan::Statement(Statement::Prepare(prepare_lp)) = plan_with_values { - param_values.verify(&prepare_lp.data_types)?; + param_values.verify_fields(&prepare_lp.fields)?; // try and take ownership of the input if is not shared, clone otherwise Arc::unwrap_or_clone(prepare_lp.input) } else { @@ -1463,8 +1462,10 @@ impl LogicalPlan { let original_name = name_preserver.save(&e); let transformed_expr = e.transform_up(|e| { if let Expr::Placeholder(Placeholder { id, .. 
}) = e { - let value = param_values.get_placeholders_with_values(&id)?; - Ok(Transformed::yes(Expr::Literal(value, None))) + let (value, metadata) = param_values + .get_placeholders_with_values(&id)? + .into_inner(); + Ok(Transformed::yes(Expr::Literal(value, metadata))) } else { Ok(Transformed::no(e)) } @@ -1494,24 +1495,43 @@ impl LogicalPlan { } /// Walk the logical plan, find any `Placeholder` tokens, and return a map of their IDs and DataTypes + /// + /// Note that this will drop any extension or field metadata attached to parameters. Use + /// [`LogicalPlan::get_parameter_fields`] to keep extension metadata. pub fn get_parameter_types( &self, ) -> Result<HashMap<String, Option<DataType>>, DataFusionError> { - let mut param_types: HashMap<String, Option<DataType>> = HashMap::new(); + let mut parameter_fields = self.get_parameter_fields()?; + Ok(parameter_fields + .drain() + .map(|(name, maybe_field)| { + (name, maybe_field.map(|field| field.data_type().clone())) + }) + .collect()) + } + + /// Walk the logical plan, find any `Placeholder` tokens, and return a map of their IDs and FieldRefs + pub fn get_parameter_fields( + &self, + ) -> Result<HashMap<String, Option<FieldRef>>, DataFusionError> { + let mut param_types: HashMap<String, Option<FieldRef>> = HashMap::new(); self.apply_with_subqueries(|plan| { plan.apply_expressions(|expr| { expr.apply(|expr| { - if let Expr::Placeholder(Placeholder { id, data_type }) = expr { + if let Expr::Placeholder(Placeholder { id, field }) = expr { let prev = param_types.get(id); - match (prev, data_type) { - (Some(Some(prev)), Some(dt)) => { - if prev != dt { - plan_err!("Conflicting types for {id}")?; - } + match (prev, field) { + (Some(Some(prev)), Some(field)) => { + check_metadata_with_storage_equal( + (field.data_type(), Some(field.metadata())), + (prev.data_type(), Some(prev.metadata())), + "parameter", + &format!(": Conflicting types for id {id}"), + )?; } - (_, Some(dt)) => { - param_types.insert(id.clone(), Some(dt.clone())); + (_, Some(field)) => { + param_types.insert(id.clone(), Some(Arc::clone(field))); } _ => { param_types.insert(id.clone(), None); @@ -2753,7 +2773,8 @@ pub struct Union { impl Union { /// Constructs new Union instance deriving schema from inputs. - fn try_new(inputs: Vec<Arc<LogicalPlan>>) -> Result<Self> { + /// Schema data types must match exactly. 
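A sketch of how a caller might walk the new `get_parameter_fields` map; the plan construction reuses the `table_scan`/`placeholder` helpers that this patch's own tests rely on, and the `parameter_report` helper and its output format are illustrative only:

```rust
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_expr::logical_plan::{table_scan, LogicalPlan};
use datafusion_expr::{col, placeholder};

// Hypothetical helper: print each parameter with whatever type information the
// plan has, keeping field metadata visible instead of collapsing to a DataType.
fn parameter_report(plan: &LogicalPlan) -> Result<()> {
    for (id, maybe_field) in plan.get_parameter_fields()? {
        match maybe_field {
            Some(field) => println!(
                "{id}: {} (nullable: {}, metadata: {:?})",
                field.data_type(),
                field.is_nullable(),
                field.metadata()
            ),
            None => println!("{id}: type not yet inferred"),
        }
    }
    Ok(())
}

fn main() -> Result<()> {
    let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
    let plan = table_scan(Some("t"), &schema, None)?
        .filter(col("id").eq(placeholder("$1")))?
        .build()?;
    parameter_report(&plan)
}
```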
+ pub fn try_new(inputs: Vec>) -> Result { let schema = Self::derive_schema_from_inputs(&inputs, false, false)?; Ok(Union { inputs, schema }) } @@ -4230,6 +4251,7 @@ mod tests { binary_expr, col, exists, in_subquery, lit, placeholder, scalar_subquery, GroupingSet, }; + use datafusion_common::metadata::ScalarAndMetadata; use datafusion_common::tree_node::{ TransformedResult, TreeNodeRewriter, TreeNodeVisitor, }; @@ -4770,6 +4792,38 @@ mod tests { .expect_err("unexpectedly succeeded to replace an invalid placeholder"); } + #[test] + fn test_replace_placeholder_mismatched_metadata() { + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + + // Create a prepared statement with explicit fields that do not have metadata + let plan = table_scan(TableReference::none(), &schema, None) + .unwrap() + .filter(col("id").eq(placeholder("$1"))) + .unwrap() + .build() + .unwrap(); + let prepared_builder = LogicalPlanBuilder::new(plan) + .prepare( + "".to_string(), + vec![Field::new("", DataType::Int32, true).into()], + ) + .unwrap(); + + // Attempt to bind a parameter with metadata + let mut scalar_meta = HashMap::new(); + scalar_meta.insert("some_key".to_string(), "some_value".to_string()); + let param_values = ParamValues::List(vec![ScalarAndMetadata::new( + ScalarValue::Int32(Some(42)), + Some(scalar_meta.into()), + )]); + prepared_builder + .plan() + .clone() + .with_param_values(param_values) + .expect_err("prepared field metadata mismatch unexpectedly succeeded"); + } + #[test] fn test_nullable_schema_after_grouping_set() { let schema = Schema::new(vec![ @@ -5142,7 +5196,7 @@ mod tests { .unwrap(); // Check that the placeholder parameters have not received a DataType. - let params = plan.get_parameter_types().unwrap(); + let params = plan.get_parameter_fields().unwrap(); assert_eq!(params.len(), 1); let parameter_type = params.clone().get(placeholder_value).unwrap().clone(); diff --git a/datafusion/expr/src/logical_plan/statement.rs b/datafusion/expr/src/logical_plan/statement.rs index 6d3fe9fa75ac..bfc6b53d1136 100644 --- a/datafusion/expr/src/logical_plan/statement.rs +++ b/datafusion/expr/src/logical_plan/statement.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::DataType; +use arrow::datatypes::FieldRef; +use datafusion_common::metadata::format_type_and_metadata; use datafusion_common::{DFSchema, DFSchemaRef}; use itertools::Itertools as _; use std::fmt::{self, Display}; @@ -108,10 +109,18 @@ impl Statement { }) => { write!(f, "SetVariable: set {variable:?} to {value:?}") } - Statement::Prepare(Prepare { - name, data_types, .. - }) => { - write!(f, "Prepare: {name:?} [{}]", data_types.iter().join(", ")) + Statement::Prepare(Prepare { name, fields, .. }) => { + write!( + f, + "Prepare: {name:?} [{}]", + fields + .iter() + .map(|f| format_type_and_metadata( + f.data_type(), + Some(f.metadata()) + )) + .join(", ") + ) } Statement::Execute(Execute { name, parameters, .. 
@@ -192,7 +201,7 @@ pub struct Prepare { /// The name of the statement pub name: String, /// Data types of the parameters ([`Expr::Placeholder`]) - pub data_types: Vec, + pub fields: Vec, /// The logical plan of the statements pub input: Arc, } diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index 41bc64505807..8609afeae601 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -22,13 +22,15 @@ use std::any::Any; use arrow::datatypes::{ - DataType, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, - DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, + DataType, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, + DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE, }; +use datafusion_common::plan_err; use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result}; -use crate::type_coercion::aggregates::{avg_return_type, coerce_avg_type, NUMERICS}; +use crate::type_coercion::aggregates::NUMERICS; use crate::Volatility::Immutable; use crate::{ expr::AggregateFunction, @@ -488,8 +490,61 @@ impl AggregateUDFImpl for Avg { &self.signature } + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [args] = take_function_args(self.name(), arg_types)?; + + // Supported types smallint, int, bigint, real, double precision, decimal, or interval + // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc + fn coerced_type(data_type: &DataType) -> Result { + match &data_type { + DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)), + DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)), + DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)), + DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)), + d if d.is_numeric() => Ok(DataType::Float64), + DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)), + DataType::Dictionary(_, v) => coerced_type(v.as_ref()), + _ => { + plan_err!("Avg does not support inputs of type {data_type}.") + } + } + } + Ok(vec![coerced_type(args)?]) + } + fn return_type(&self, arg_types: &[DataType]) -> Result { - avg_return_type(self.name(), &arg_types[0]) + match &arg_types[0] { + DataType::Decimal32(precision, scale) => { + // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4); + let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4); + Ok(DataType::Decimal32(new_precision, new_scale)) + } + DataType::Decimal64(precision, scale) => { + // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4); + let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4); + Ok(DataType::Decimal64(new_precision, new_scale)) + } + DataType::Decimal128(precision, scale) => { + // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). 
+ // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4); + let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4); + Ok(DataType::Decimal128(new_precision, new_scale)) + } + DataType::Decimal256(precision, scale) => { + // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4); + let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4); + Ok(DataType::Decimal256(new_precision, new_scale)) + } + DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)), + _ => Ok(DataType::Float64), + } } fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { @@ -503,8 +558,4 @@ impl AggregateUDFImpl for Avg { fn aliases(&self) -> &[String] { &self.aliases } - - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - coerce_avg_type(self.name(), arg_types) - } } diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index bfd699d81485..b593f8411d24 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -746,21 +746,52 @@ pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync { true } - /// If this function is ordered-set aggregate function, return true - /// otherwise, return false + /// If this function is an ordered-set aggregate function, return `true`. + /// Otherwise, return `false` (default). /// - /// Ordered-set aggregate functions require an explicit `ORDER BY` clause - /// because the calculation performed by these functions is dependent on the - /// specific sequence of the input rows, unlike other aggregate functions - /// like `SUM`, `AVG`, or `COUNT`. + /// Ordered-set aggregate functions allow specifying a sort order that affects + /// how the function calculates its result, unlike other aggregate functions + /// like `SUM` or `COUNT`. For example, `percentile_cont` is an ordered-set + /// aggregate function that calculates the exact percentile value from a list + /// of values; the output of calculating the `0.75` percentile depends on if + /// you're calculating on an ascending or descending list of values. /// - /// An example of an ordered-set aggregate function is `percentile_cont` - /// which computes a specific percentile value from a sorted list of values, and - /// is only meaningful when the input data is ordered. + /// Setting this to return `true` affects only SQL parsing & planning; it allows + /// use of the `WITHIN GROUP` clause to specify this order, for example: /// - /// In SQL syntax, ordered-set aggregate functions are used with the - /// `WITHIN GROUP (ORDER BY ...)` clause to specify the ordering of the input - /// data. + /// ```sql + /// -- Ascending + /// SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c1 ASC) FROM table; + /// -- Default ordering is ascending if not explicitly specified + /// SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c1) FROM table; + /// -- Descending + /// SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c1 DESC) FROM table; + /// ``` + /// + /// This calculates the `0.75` percentile of the column `c1` from `table`, + /// according to the specific ordering. 
The column specified in the `WITHIN GROUP` + /// ordering clause is taken as the column to calculate values on; specifying + /// the `WITHIN GROUP` clause is optional so these queries are equivalent: + /// + /// ```sql + /// -- If no WITHIN GROUP is specified then default ordering is implementation + /// -- dependent; in this case ascending for percentile_cont + /// SELECT percentile_cont(c1, 0.75) FROM table; + /// SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c1 ASC) FROM table; + /// ``` + /// + /// Aggregate UDFs can define their default ordering if the function is called + /// without the `WITHIN GROUP` clause, though a default of ascending is the + /// standard practice. + /// + /// Note that setting this to `true` does not guarantee input sort order to + /// the aggregate function; it expects the function to handle ordering the + /// input values themselves (e.g. `percentile_cont` must buffer and sort + /// the values internally). That is, DataFusion does not introduce any kind + /// of sort into the plan for these functions. + /// + /// Setting this to `false` disallows calling this function with the `WITHIN GROUP` + /// clause. fn is_ordered_set_aggregate(&self) -> bool { false } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index d522158f7b6b..c4cd8c006d1f 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -252,7 +252,21 @@ impl ScalarUDF { Ok(result) } - /// Get the circuits of inner implementation + /// Determines which of the arguments passed to this function are evaluated eagerly + /// and which may be evaluated lazily. + /// + /// See [ScalarUDFImpl::conditional_arguments] for more information. + pub fn conditional_arguments<'a>( + &self, + args: &'a [Expr], + ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> { + self.inner.conditional_arguments(args) + } + + /// Returns true if some of this `exprs` subexpressions may not be evaluated + /// and thus any side effects (like divide by zero) may not be encountered. + /// + /// See [ScalarUDFImpl::short_circuits] for more information. pub fn short_circuits(&self) -> bool { self.inner.short_circuits() } @@ -532,6 +546,33 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { /// [`DataFusionError::Internal`]: datafusion_common::DataFusionError::Internal fn return_type(&self, arg_types: &[DataType]) -> Result; + /// Create a new instance of this function with updated configuration. + /// + /// This method is called when configuration options change at runtime + /// (e.g., via `SET` statements) to allow functions that depend on + /// configuration to update themselves accordingly. + /// + /// Note the current [`ConfigOptions`] are also passed to [`Self::invoke_with_args`] so + /// this API is not needed for functions where the values may + /// depend on the current options. + /// + /// This API is useful for functions where the return + /// **type** depends on the configuration options, such as the `now()` function + /// which depends on the current timezone. + /// + /// # Arguments + /// + /// * `config` - The updated configuration options + /// + /// # Returns + /// + /// * `Some(ScalarUDF)` - A new instance of this function configured with the new settings + /// * `None` - If this function does not change with new configuration settings (the default) + /// + fn with_updated_config(&self, _config: &ConfigOptions) -> Option { + None + } + /// What type will be returned by this function, given the arguments? 
/// /// By default, this function calls [`Self::return_type`] with the @@ -656,10 +697,42 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { /// /// Setting this to true prevents certain optimizations such as common /// subexpression elimination + /// + /// When overriding this function to return `true`, [ScalarUDFImpl::conditional_arguments] can also be + /// overridden to report more accurately which arguments are eagerly evaluated and which ones + /// lazily. fn short_circuits(&self) -> bool { false } + /// Determines which of the arguments passed to this function are evaluated eagerly + /// and which may be evaluated lazily. + /// + /// If this function returns `None`, all arguments are eagerly evaluated. + /// Returning `None` is a micro optimization that saves a needless `Vec` + /// allocation. + /// + /// If the function returns `Some`, returns (`eager`, `lazy`) where `eager` + /// are the arguments that are always evaluated, and `lazy` are the + /// arguments that may be evaluated lazily (i.e. may not be evaluated at all + /// in some cases). + /// + /// Implementations must ensure that the two returned `Vec`s are disjunct, + /// and that each argument from `args` is present in one the two `Vec`s. + /// + /// When overriding this function, [ScalarUDFImpl::short_circuits] must + /// be overridden to return `true`. + fn conditional_arguments<'a>( + &self, + args: &'a [Expr], + ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> { + if self.short_circuits() { + Some((vec![], args.iter().collect())) + } else { + None + } + } + /// Computes the output [`Interval`] for a [`ScalarUDFImpl`], given the input /// intervals. /// @@ -833,6 +906,10 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.invoke_with_args(args) } + fn with_updated_config(&self, _config: &ConfigOptions) -> Option { + None + } + fn aliases(&self) -> &[String] { &self.aliases } @@ -845,6 +922,13 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.simplify(args, info) } + fn conditional_arguments<'a>( + &self, + args: &'a [Expr], + ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> { + self.inner.conditional_arguments(args) + } + fn short_circuits(&self) -> bool { self.inner.short_circuits() } diff --git a/datafusion/functions-aggregate-common/src/utils.rs b/datafusion/functions-aggregate-common/src/utils.rs index b01f2c8629c9..7ce5f09373f5 100644 --- a/datafusion/functions-aggregate-common/src/utils.rs +++ b/datafusion/functions-aggregate-common/src/utils.rs @@ -95,6 +95,8 @@ pub struct DecimalAverager { target_mul: T::Native, /// the output precision target_precision: u8, + /// the output scale + target_scale: i8, } impl DecimalAverager { @@ -129,6 +131,7 @@ impl DecimalAverager { sum_mul, target_mul, target_precision, + target_scale, }) } else { // can't convert the lit decimal to the returned data type @@ -147,8 +150,11 @@ impl DecimalAverager { if let Ok(value) = sum.mul_checked(self.target_mul.div_wrapping(self.sum_mul)) { let new_value = value.div_wrapping(count); - let validate = - T::validate_decimal_precision(new_value, self.target_precision); + let validate = T::validate_decimal_precision( + new_value, + self.target_precision, + self.target_scale, + ); if validate.is_ok() { Ok(new_value) diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 0deb09184b3f..668280314e8d 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs 
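The `DecimalAverager` change above threads the output scale through so the final value can be validated against both precision and scale. A simplified, standalone illustration of the underlying scaled-integer arithmetic (plain `i128`, hypothetical `Decimal128(10, 2)` inputs averaged into `Decimal128(14, 6)`; not the averager's actual code):

```rust
fn main() {
    // Three Decimal128(10, 2) values: 1.25, 2.50, 3.75, stored as scaled integers
    let values: [i128; 3] = [125, 250, 375];
    let count = values.len() as i128;
    let sum: i128 = values.iter().sum(); // 750 == 7.50 at scale 2

    // Output type Decimal128(14, 6): bring the sum up to the output scale before
    // dividing, the same rescale-then-divide order the averager uses
    let sum_scale = 2u32;
    let target_scale = 6u32;
    let rescale = 10_i128.pow(target_scale - sum_scale); // target_mul / sum_mul

    let avg = (sum * rescale) / count; // 2_500_000 == 2.500000 at scale 6
    assert_eq!(avg, 2_500_000);

    // The result must still fit the output precision (14 digits here), which is
    // the check validate_decimal_precision now performs with the scale as well
    let target_precision = 14u32;
    assert!(avg.abs() < 10_i128.pow(target_precision));
}
```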
@@ -20,7 +20,7 @@ use std::fmt::{Debug, Formatter}; use std::mem::size_of_val; use std::sync::Arc; -use arrow::array::{Array, RecordBatch}; +use arrow::array::Array; use arrow::compute::{filter, is_not_null}; use arrow::datatypes::FieldRef; use arrow::{ @@ -28,19 +28,19 @@ use arrow::{ ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }, - datatypes::{DataType, Field, Schema}, + datatypes::{DataType, Field}, }; use datafusion_common::{ - downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err, - Result, ScalarValue, + downcast_value, internal_err, not_impl_err, plan_err, DataFusionError, Result, + ScalarValue, }; use datafusion_expr::expr::{AggregateFunction, Sort}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ColumnarValue, Documentation, Expr, Signature, - TypeSignature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature, + Volatility, }; use datafusion_functions_aggregate_common::tdigest::{ TDigest, TryIntoF64, DEFAULT_MAX_SIZE, @@ -48,6 +48,8 @@ use datafusion_functions_aggregate_common::tdigest::{ use datafusion_macros::user_doc; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use crate::utils::{get_scalar_value, validate_percentile_expr}; + create_func!(ApproxPercentileCont, approx_percentile_cont_udaf); /// Computes the approximate percentile continuous of a set of numbers @@ -164,7 +166,8 @@ impl ApproxPercentileCont { &self, args: AccumulatorArgs, ) -> Result { - let percentile = validate_input_percentile_expr(&args.exprs[1])?; + let percentile = + validate_percentile_expr(&args.exprs[1], "APPROX_PERCENTILE_CONT")?; let is_descending = args .order_bys @@ -214,45 +217,15 @@ impl ApproxPercentileCont { } } -fn get_scalar_value(expr: &Arc) -> Result { - let empty_schema = Arc::new(Schema::empty()); - let batch = RecordBatch::new_empty(Arc::clone(&empty_schema)); - if let ColumnarValue::Scalar(s) = expr.evaluate(&batch)? { - Ok(s) - } else { - internal_err!("Didn't expect ColumnarValue::Array") - } -} - -fn validate_input_percentile_expr(expr: &Arc) -> Result { - let percentile = match get_scalar_value(expr) - .map_err(|_| not_impl_datafusion_err!("Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? { - ScalarValue::Float32(Some(value)) => { - value as f64 - } - ScalarValue::Float64(Some(value)) => { - value - } - sv => { - return not_impl_err!( - "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})", - sv.data_type() - ) - } - }; - - // Ensure the percentile is between 0 and 1. - if !(0.0..=1.0).contains(&percentile) { - return plan_err!( - "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid" - ); - } - Ok(percentile) -} - fn validate_input_max_size_expr(expr: &Arc) -> Result { - let max_size = match get_scalar_value(expr) - .map_err(|_| not_impl_datafusion_err!("Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? 
{ + let scalar_value = get_scalar_value(expr).map_err(|_e| { + DataFusionError::Plan( + "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal" + .to_string(), + ) + })?; + + let max_size = match scalar_value { ScalarValue::UInt8(Some(q)) => q as usize, ScalarValue::UInt16(Some(q)) => q as usize, ScalarValue::UInt32(Some(q)) => q as usize, @@ -262,7 +235,7 @@ fn validate_input_max_size_expr(expr: &Arc) -> Result { ScalarValue::Int16(Some(q)) if q > 0 => q as usize, ScalarValue::Int8(Some(q)) if q > 0 => q as usize, sv => { - return not_impl_err!( + return plan_err!( "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).", sv.data_type() ) diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index d007163e7c08..11960779ed18 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -27,14 +27,15 @@ use arrow::datatypes::{ i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, DecimalType, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, - UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, - DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, + UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, + DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE, + DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE, }; +use datafusion_common::plan_err; use datafusion_common::{ exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue, }; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_type}; use datafusion_expr::utils::format_state_name; use datafusion_expr::Volatility::Immutable; use datafusion_expr::{ @@ -125,8 +126,61 @@ impl AggregateUDFImpl for Avg { &self.signature } + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [args] = take_function_args(self.name(), arg_types)?; + + // Supported types smallint, int, bigint, real, double precision, decimal, or interval + // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc + fn coerced_type(data_type: &DataType) -> Result { + match &data_type { + DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)), + DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)), + DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)), + DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)), + d if d.is_numeric() => Ok(DataType::Float64), + DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)), + DataType::Dictionary(_, v) => coerced_type(v.as_ref()), + _ => { + plan_err!("Avg does not support inputs of type {data_type}.") + } + } + } + Ok(vec![coerced_type(args)?]) + } + fn return_type(&self, arg_types: &[DataType]) -> Result { - avg_return_type(self.name(), &arg_types[0]) + match &arg_types[0] { + DataType::Decimal32(precision, scale) => { + // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). 
+ // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4); + let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4); + Ok(DataType::Decimal32(new_precision, new_scale)) + } + DataType::Decimal64(precision, scale) => { + // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4); + let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4); + Ok(DataType::Decimal64(new_precision, new_scale)) + } + DataType::Decimal128(precision, scale) => { + // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4); + let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4); + Ok(DataType::Decimal128(new_precision, new_scale)) + } + DataType::Decimal256(precision, scale) => { + // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4); + let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4); + Ok(DataType::Decimal256(new_precision, new_scale)) + } + DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)), + _ => Ok(DataType::Float64), + } } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -452,11 +506,6 @@ impl AggregateUDFImpl for Avg { ReversedUDAF::Identical } - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let [args] = take_function_args(self.name(), arg_types)?; - coerce_avg_type(self.name(), std::slice::from_ref(args)) - } - fn documentation(&self) -> Option<&Documentation> { self.doc() } diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 4f282301ce5b..056cd45fa2c3 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -81,6 +81,7 @@ pub mod hyperloglog; pub mod median; pub mod min_max; pub mod nth_value; +pub mod percentile_cont; pub mod regr; pub mod stddev; pub mod string_agg; @@ -88,6 +89,7 @@ pub mod sum; pub mod variance; pub mod planner; +mod utils; use crate::approx_percentile_cont::approx_percentile_cont_udaf; use crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf; @@ -123,6 +125,7 @@ pub mod expr_fn { pub use super::min_max::max; pub use super::min_max::min; pub use super::nth_value::nth_value; + pub use super::percentile_cont::percentile_cont; pub use super::regr::regr_avgx; pub use super::regr::regr_avgy; pub use super::regr::regr_count; @@ -171,6 +174,7 @@ pub fn all_default_aggregate_functions() -> Vec> { approx_distinct::approx_distinct_udaf(), approx_percentile_cont_udaf(), approx_percentile_cont_with_weight_udaf(), + percentile_cont::percentile_cont_udaf(), string_agg::string_agg_udaf(), 
bit_and_or_xor::bit_and_udaf(), bit_and_or_xor::bit_or_udaf(), @@ -207,13 +211,7 @@ mod tests { #[test] fn test_no_duplicate_name() -> Result<()> { let mut names = HashSet::new(); - let migrated_functions = ["array_agg", "count", "max", "min"]; for func in all_default_aggregate_functions() { - // TODO: remove this - // These functions are in intermediate migration state, skip them - if migrated_functions.contains(&func.name().to_lowercase().as_str()) { - continue; - } assert!( names.insert(func.name().to_string().to_lowercase()), "duplicate function name: {}", diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs index 05321c2ff52d..30b2739c08ed 100644 --- a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs +++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs @@ -20,7 +20,8 @@ use arrow::array::{ LargeBinaryBuilder, LargeStringBuilder, StringBuilder, StringViewBuilder, }; use arrow::datatypes::DataType; -use datafusion_common::{internal_err, Result}; +use datafusion_common::hash_map::Entry; +use datafusion_common::{internal_err, HashMap, Result}; use datafusion_expr::{EmitTo, GroupsAccumulator}; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls; use std::mem::size_of; @@ -391,14 +392,6 @@ struct MinMaxBytesState { total_data_bytes: usize, } -#[derive(Debug, Clone, Copy)] -enum MinMaxLocation<'a> { - /// the min/max value is stored in the existing `min_max` array - ExistingMinMax, - /// the min/max value is stored in the input array at the given index - Input(&'a [u8]), -} - /// Implement the MinMaxBytesAccumulator with a comparison function /// for comparing strings impl MinMaxBytesState { @@ -450,7 +443,7 @@ impl MinMaxBytesState { // Minimize value copies by calculating the new min/maxes for each group // in this batch (either the existing min/max or the new input value) // and updating the owned values in `self.min_maxes` at most once - let mut locations = vec![MinMaxLocation::ExistingMinMax; total_num_groups]; + let mut locations = HashMap::::with_capacity(group_indices.len()); // Figure out the new min value for each group for (new_val, group_index) in iter.into_iter().zip(group_indices.iter()) { @@ -459,32 +452,29 @@ impl MinMaxBytesState { continue; // skip nulls }; - let existing_val = match locations[group_index] { - // previous input value was the min/max, so compare it - MinMaxLocation::Input(existing_val) => existing_val, - MinMaxLocation::ExistingMinMax => { - let Some(existing_val) = self.min_max[group_index].as_ref() else { - // no existing min/max, so this is the new min/max - locations[group_index] = MinMaxLocation::Input(new_val); - continue; - }; - existing_val.as_ref() + match locations.entry(group_index) { + Entry::Occupied(mut occupied_entry) => { + if cmp(new_val, occupied_entry.get()) { + occupied_entry.insert(new_val); + } + } + Entry::Vacant(vacant_entry) => { + if let Some(old_val) = self.min_max[group_index].as_ref() { + if cmp(new_val, old_val) { + vacant_entry.insert(new_val); + } + } else { + vacant_entry.insert(new_val); + } } }; - - // Compare the new value to the existing value, replacing if necessary - if cmp(new_val, existing_val) { - locations[group_index] = MinMaxLocation::Input(new_val); - } } // Update self.min_max with any new min/max values we found in the input - for (group_index, location) in locations.iter().enumerate() { - match location { - MinMaxLocation::ExistingMinMax => {} - 
MinMaxLocation::Input(new_val) => self.set_value(group_index, new_val), - } + for (group_index, location) in locations.iter() { + self.set_value(*group_index, location); } + Ok(()) } diff --git a/datafusion/functions-aggregate/src/percentile_cont.rs b/datafusion/functions-aggregate/src/percentile_cont.rs new file mode 100644 index 000000000000..8e9e9a3144d4 --- /dev/null +++ b/datafusion/functions-aggregate/src/percentile_cont.rs @@ -0,0 +1,814 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fmt::{Debug, Formatter}; +use std::mem::{size_of, size_of_val}; +use std::sync::Arc; + +use arrow::array::{ + ArrowNumericType, BooleanArray, ListArray, PrimitiveArray, PrimitiveBuilder, +}; +use arrow::buffer::{OffsetBuffer, ScalarBuffer}; +use arrow::{ + array::{Array, ArrayRef, AsArray}, + datatypes::{ + ArrowNativeType, DataType, Decimal128Type, Decimal256Type, Decimal32Type, + Decimal64Type, Field, FieldRef, Float16Type, Float32Type, Float64Type, + }, +}; + +use arrow::array::ArrowNativeTypeOp; + +use datafusion_common::{ + internal_datafusion_err, internal_err, plan_err, DataFusionError, HashSet, Result, + ScalarValue, +}; +use datafusion_expr::expr::{AggregateFunction, Sort}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature, + Volatility, +}; +use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask; +use datafusion_functions_aggregate_common::utils::Hashable; +use datafusion_macros::user_doc; + +use crate::utils::validate_percentile_expr; + +/// Precision multiplier for linear interpolation calculations. +/// +/// This value of 1,000,000 was chosen to balance precision with overflow safety: +/// - Provides 6 decimal places of precision for the fractional component +/// - Small enough to avoid overflow when multiplied with typical numeric values +/// - Sufficient precision for most statistical applications +/// +/// The interpolation formula: `lower + (upper - lower) * fraction` +/// is computed as: `lower + ((upper - lower) * (fraction * PRECISION)) / PRECISION` +/// to avoid floating-point operations on integer types while maintaining precision. 
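A quick numeric check of the interpolation formula described in the comment above, using hypothetical `Decimal128(10, 2)` neighbour values and a fraction of 0.75:

```rust
fn main() {
    // Two adjacent sorted values of a hypothetical Decimal128(10, 2) column:
    // 10.00 and 11.00, stored as scaled integers
    let lower: i128 = 1_000;
    let upper: i128 = 1_100;

    // Fractional part of the percentile rank (0.75), scaled by the
    // 1_000_000 precision factor described above
    let precision: i128 = 1_000_000;
    let scaled_fraction = (0.75_f64 * precision as f64) as i128; // 750_000

    // lower + ((upper - lower) * (fraction * PRECISION)) / PRECISION, integers only
    let interpolated = lower + ((upper - lower) * scaled_fraction) / precision;

    assert_eq!(interpolated, 1_075); // 10.75 at scale 2
}
```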
+const INTERPOLATION_PRECISION: usize = 1_000_000; + +create_func!(PercentileCont, percentile_cont_udaf); + +/// Computes the exact percentile continuous of a set of numbers +pub fn percentile_cont(order_by: Sort, percentile: Expr) -> Expr { + let expr = order_by.expr.clone(); + let args = vec![expr, percentile]; + + Expr::AggregateFunction(AggregateFunction::new_udf( + percentile_cont_udaf(), + args, + false, + None, + vec![order_by], + None, + )) +} + +#[user_doc( + doc_section(label = "General Functions"), + description = "Returns the exact percentile of input values, interpolating between values if needed.", + syntax_example = "percentile_cont(percentile) WITHIN GROUP (ORDER BY expression)", + sql_example = r#"```sql +> SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++----------------------------------------------------------+ +| percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) | ++----------------------------------------------------------+ +| 45.5 | ++----------------------------------------------------------+ +``` + +An alternate syntax is also supported: +```sql +> SELECT percentile_cont(column_name, 0.75) FROM table_name; ++---------------------------------------+ +| percentile_cont(column_name, 0.75) | ++---------------------------------------+ +| 45.5 | ++---------------------------------------+ +```"#, + standard_argument(name = "expression", prefix = "The"), + argument( + name = "percentile", + description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)." + ) +)] +/// PERCENTILE_CONT aggregate expression. This uses an exact calculation and stores all values +/// in memory before computing the result. If an approximation is sufficient then +/// APPROX_PERCENTILE_CONT provides a much more efficient solution. +/// +/// If using the distinct variation, the memory usage will be similarly high if the +/// cardinality is high as it stores all distinct values in memory before computing the +/// result, but if cardinality is low then memory usage will also be lower. +#[derive(PartialEq, Eq, Hash)] +pub struct PercentileCont { + signature: Signature, + aliases: Vec, +} + +impl Debug for PercentileCont { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + f.debug_struct("PercentileCont") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for PercentileCont { + fn default() -> Self { + Self::new() + } +} + +impl PercentileCont { + pub fn new() -> Self { + let mut variants = Vec::with_capacity(NUMERICS.len()); + // Accept any numeric value paired with a float64 percentile + for num in NUMERICS { + variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64])); + } + Self { + signature: Signature::one_of(variants, Volatility::Immutable), + aliases: vec![String::from("quantile_cont")], + } + } + + fn create_accumulator(&self, args: AccumulatorArgs) -> Result> { + let percentile = validate_percentile_expr(&args.exprs[1], "PERCENTILE_CONT")?; + + let is_descending = args + .order_bys + .first() + .map(|sort_expr| sort_expr.options.descending) + .unwrap_or(false); + + let percentile = if is_descending { + 1.0 - percentile + } else { + percentile + }; + + macro_rules! 
helper { + ($t:ty, $dt:expr) => { + if args.is_distinct { + Ok(Box::new(DistinctPercentileContAccumulator::<$t> { + data_type: $dt.clone(), + distinct_values: HashSet::new(), + percentile, + })) + } else { + Ok(Box::new(PercentileContAccumulator::<$t> { + data_type: $dt.clone(), + all_values: vec![], + percentile, + })) + } + }; + } + + let input_dt = args.exprs[0].data_type(args.schema)?; + match input_dt { + // For integer types, use Float64 internally since percentile_cont returns Float64 + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => helper!(Float64Type, DataType::Float64), + DataType::Float16 => helper!(Float16Type, input_dt), + DataType::Float32 => helper!(Float32Type, input_dt), + DataType::Float64 => helper!(Float64Type, input_dt), + DataType::Decimal32(_, _) => helper!(Decimal32Type, input_dt), + DataType::Decimal64(_, _) => helper!(Decimal64Type, input_dt), + DataType::Decimal128(_, _) => helper!(Decimal128Type, input_dt), + DataType::Decimal256(_, _) => helper!(Decimal256Type, input_dt), + _ => Err(DataFusionError::NotImplemented(format!( + "PercentileContAccumulator not supported for {} with {}", + args.name, input_dt, + ))), + } + } +} + +impl AggregateUDFImpl for PercentileCont { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "percentile_cont" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if !arg_types[0].is_numeric() { + return plan_err!("percentile_cont requires numeric input types"); + } + // PERCENTILE_CONT performs linear interpolation and should return a float type + // For integer inputs, return Float64 (matching PostgreSQL/DuckDB behavior) + // For float inputs, preserve the float type + match &arg_types[0] { + DataType::Float16 | DataType::Float32 | DataType::Float64 => { + Ok(arg_types[0].clone()) + } + DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) => Ok(arg_types[0].clone()), + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 => Ok(DataType::Float64), + // Shouldn't happen due to signature check, but just in case + dt => plan_err!( + "percentile_cont does not support input type {}, must be numeric", + dt + ), + } + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + //Intermediate state is a list of the elements we have collected so far + let input_type = args.input_fields[0].data_type().clone(); + // For integer types, we store as Float64 internally + let storage_type = match &input_type { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => DataType::Float64, + _ => input_type, + }; + + let field = Field::new_list_field(storage_type, true); + let state_name = if args.is_distinct { + "distinct_percentile_cont" + } else { + "percentile_cont" + }; + + Ok(vec![Field::new( + format_state_name(args.name, state_name), + DataType::List(Arc::new(field)), + true, + ) + .into()]) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + self.create_accumulator(acc_args) + } + + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + !args.is_distinct + 
} + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + let num_args = args.exprs.len(); + if num_args != 2 { + return internal_err!( + "percentile_cont should have 2 args, but found num args:{}", + args.exprs.len() + ); + } + + let percentile = validate_percentile_expr(&args.exprs[1], "PERCENTILE_CONT")?; + + let is_descending = args + .order_bys + .first() + .map(|sort_expr| sort_expr.options.descending) + .unwrap_or(false); + + let percentile = if is_descending { + 1.0 - percentile + } else { + percentile + }; + + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(PercentileContGroupsAccumulator::<$t>::new( + $dt, percentile, + ))) + }; + } + + let input_dt = args.exprs[0].data_type(args.schema)?; + match input_dt { + // For integer types, use Float64 internally since percentile_cont returns Float64 + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => helper!(Float64Type, DataType::Float64), + DataType::Float16 => helper!(Float16Type, input_dt), + DataType::Float32 => helper!(Float32Type, input_dt), + DataType::Float64 => helper!(Float64Type, input_dt), + DataType::Decimal32(_, _) => helper!(Decimal32Type, input_dt), + DataType::Decimal64(_, _) => helper!(Decimal64Type, input_dt), + DataType::Decimal128(_, _) => helper!(Decimal128Type, input_dt), + DataType::Decimal256(_, _) => helper!(Decimal256Type, input_dt), + _ => Err(DataFusionError::NotImplemented(format!( + "PercentileContGroupsAccumulator not supported for {} with {}", + args.name, input_dt, + ))), + } + } + + fn supports_null_handling_clause(&self) -> bool { + false + } + + fn is_ordered_set_aggregate(&self) -> bool { + true + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +/// The percentile_cont accumulator accumulates the raw input values +/// as native types. +/// +/// The intermediate state is represented as a List of scalar values updated by +/// `merge_batch` and a `Vec` of native values that are converted to scalar values +/// in the final evaluation step so that we avoid expensive conversions and +/// allocations during `update_batch`. +struct PercentileContAccumulator { + data_type: DataType, + all_values: Vec, + percentile: f64, +} + +impl Debug for PercentileContAccumulator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "PercentileContAccumulator({}, percentile={})", + self.data_type, self.percentile + ) + } +} + +impl Accumulator for PercentileContAccumulator { + fn state(&mut self) -> Result> { + // Convert `all_values` to `ListArray` and return a single List ScalarValue + + // Build offsets + let offsets = + OffsetBuffer::new(ScalarBuffer::from(vec![0, self.all_values.len() as i32])); + + // Build inner array + let values_array = PrimitiveArray::::new( + ScalarBuffer::from(std::mem::take(&mut self.all_values)), + None, + ) + .with_data_type(self.data_type.clone()); + + // Build the result list array + let list_array = ListArray::new( + Arc::new(Field::new_list_field(self.data_type.clone(), true)), + offsets, + Arc::new(values_array), + None, + ); + + Ok(vec![ScalarValue::List(Arc::new(list_array))]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // Cast to target type if needed (e.g., integer to Float64) + let values = if values[0].data_type() != &self.data_type { + arrow::compute::cast(&values[0], &self.data_type)? 
+ } else { + Arc::clone(&values[0]) + }; + + let values = values.as_primitive::(); + self.all_values.reserve(values.len() - values.null_count()); + self.all_values.extend(values.iter().flatten()); + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let array = states[0].as_list::(); + for v in array.iter().flatten() { + self.update_batch(&[v])? + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let d = std::mem::take(&mut self.all_values); + let value = calculate_percentile::(d, self.percentile); + ScalarValue::new_primitive::(value, &self.data_type) + } + + fn size(&self) -> usize { + size_of_val(self) + self.all_values.capacity() * size_of::() + } +} + +/// The percentile_cont groups accumulator accumulates the raw input values +/// +/// For calculating the exact percentile of groups, we need to store all values +/// of groups before final evaluation. +/// So values in each group will be stored in a `Vec`, and the total group values +/// will be actually organized as a `Vec>`. +/// +#[derive(Debug)] +struct PercentileContGroupsAccumulator { + data_type: DataType, + group_values: Vec>, + percentile: f64, +} + +impl PercentileContGroupsAccumulator { + pub fn new(data_type: DataType, percentile: f64) -> Self { + Self { + data_type, + group_values: Vec::new(), + percentile, + } + } +} + +impl GroupsAccumulator + for PercentileContGroupsAccumulator +{ + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // For ordered-set aggregates, we only care about the ORDER BY column (first element) + // The percentile parameter is already stored in self.percentile + + // Cast to target type if needed (e.g., integer to Float64) + let values_array = if values[0].data_type() != &self.data_type { + arrow::compute::cast(&values[0], &self.data_type)? 
+ } else { + Arc::clone(&values[0]) + }; + + let values = values_array.as_primitive::(); + + // Push the `not nulls + not filtered` row into its group + self.group_values.resize(total_num_groups, Vec::new()); + accumulate( + group_indices, + values, + opt_filter, + |group_index, new_value| { + self.group_values[group_index].push(new_value); + }, + ); + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + // Since aggregate filter should be applied in partial stage, in final stage there should be no filter + _opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "one argument to merge_batch"); + + let input_group_values = values[0].as_list::(); + + // Ensure group values big enough + self.group_values.resize(total_num_groups, Vec::new()); + + // Extend values to related groups + group_indices + .iter() + .zip(input_group_values.iter()) + .for_each(|(&group_index, values_opt)| { + if let Some(values) = values_opt { + let values = values.as_primitive::(); + self.group_values[group_index].extend(values.values().iter()); + } + }); + + Ok(()) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + // Emit values + let emit_group_values = emit_to.take_needed(&mut self.group_values); + + // Build offsets + let mut offsets = Vec::with_capacity(self.group_values.len() + 1); + offsets.push(0); + let mut cur_len = 0_i32; + for group_value in &emit_group_values { + cur_len += group_value.len() as i32; + offsets.push(cur_len); + } + let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets)); + + // Build inner array + let flatten_group_values = + emit_group_values.into_iter().flatten().collect::>(); + let group_values_array = + PrimitiveArray::::new(ScalarBuffer::from(flatten_group_values), None) + .with_data_type(self.data_type.clone()); + + // Build the result list array + let result_list_array = ListArray::new( + Arc::new(Field::new_list_field(self.data_type.clone(), true)), + offsets, + Arc::new(group_values_array), + None, + ); + + Ok(vec![Arc::new(result_list_array)]) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + // Emit values + let emit_group_values = emit_to.take_needed(&mut self.group_values); + + // Calculate percentile for each group + let mut evaluate_result_builder = + PrimitiveBuilder::::new().with_data_type(self.data_type.clone()); + for values in emit_group_values { + let value = calculate_percentile::(values, self.percentile); + evaluate_result_builder.append_option(value); + } + + Ok(Arc::new(evaluate_result_builder.finish())) + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + assert_eq!(values.len(), 1, "one argument to merge_batch"); + + // Cast to target type if needed (e.g., integer to Float64) + let values_array = if values[0].data_type() != &self.data_type { + arrow::compute::cast(&values[0], &self.data_type)? + } else { + Arc::clone(&values[0]) + }; + + let input_array = values_array.as_primitive::(); + + // Directly convert the input array to states, each row will be + // seen as a respective group. + // For detail, the `input_array` will be converted to a `ListArray`. + // And if row is `not null + not filtered`, it will be converted to a list + // with only one element; otherwise, this row in `ListArray` will be set + // to null. 
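+        // For example (an illustrative sketch; the concrete values are assumed):
+        // an input of [1.0, 2.0, 3.0] with filter [true, false, true] is emitted
+        // as the ListArray [[1.0], null, [3.0]] -- one single-element list per
+        // surviving row, with filtered-out or null rows mapped to null entries.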
+ + // Reuse values buffer in `input_array` to build `values` in `ListArray` + let values = PrimitiveArray::::new(input_array.values().clone(), None) + .with_data_type(self.data_type.clone()); + + // `offsets` in `ListArray`, each row as a list element + let offset_end = i32::try_from(input_array.len()).map_err(|e| { + internal_datafusion_err!( + "cast array_len to i32 failed in convert_to_state of group percentile_cont, err:{e:?}" + ) + })?; + let offsets = (0..=offset_end).collect::>(); + // Safety: The offsets vector is constructed as a sequential range from 0 to input_array.len(), + // which guarantees all OffsetBuffer invariants: + // 1. Offsets are monotonically increasing (each element is prev + 1) + // 2. No offset exceeds the values array length (max offset = input_array.len()) + // 3. First offset is 0 and last offset equals the total length + // Therefore new_unchecked is safe to use here. + let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) }; + + // `nulls` for converted `ListArray` + let nulls = filtered_null_mask(opt_filter, input_array); + + let converted_list_array = ListArray::new( + Arc::new(Field::new_list_field(self.data_type.clone(), true)), + offsets, + Arc::new(values), + nulls, + ); + + Ok(vec![Arc::new(converted_list_array)]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + + fn size(&self) -> usize { + self.group_values + .iter() + .map(|values| values.capacity() * size_of::()) + .sum::() + // account for size of self.group_values too + + self.group_values.capacity() * size_of::>() + } +} + +/// The distinct percentile_cont accumulator accumulates the raw input values +/// using a HashSet to eliminate duplicates. +/// +/// The intermediate state is represented as a List of scalar values updated by +/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values +/// in the final evaluation step so that we avoid expensive conversions and +/// allocations during `update_batch`. +struct DistinctPercentileContAccumulator { + data_type: DataType, + distinct_values: HashSet>, + percentile: f64, +} + +impl Debug for DistinctPercentileContAccumulator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "DistinctPercentileContAccumulator({}, percentile={})", + self.data_type, self.percentile + ) + } +} + +impl Accumulator for DistinctPercentileContAccumulator { + fn state(&mut self) -> Result> { + let all_values = self + .distinct_values + .iter() + .map(|x| ScalarValue::new_primitive::(Some(x.0), &self.data_type)) + .collect::>>()?; + + let arr = ScalarValue::new_list_nullable(&all_values, &self.data_type); + Ok(vec![ScalarValue::List(arr)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + // Cast to target type if needed (e.g., integer to Float64) + let values = if values[0].data_type() != &self.data_type { + arrow::compute::cast(&values[0], &self.data_type)? + } else { + Arc::clone(&values[0]) + }; + + let array = values.as_primitive::(); + match array.nulls().filter(|x| x.null_count() > 0) { + Some(n) => { + for idx in n.valid_indices() { + self.distinct_values.insert(Hashable(array.value(idx))); + } + } + None => array.values().iter().for_each(|x| { + self.distinct_values.insert(Hashable(*x)); + }), + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let array = states[0].as_list::(); + for v in array.iter().flatten() { + self.update_batch(&[v])? 
+ } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let d = std::mem::take(&mut self.distinct_values) + .into_iter() + .map(|v| v.0) + .collect::>(); + let value = calculate_percentile::(d, self.percentile); + ScalarValue::new_primitive::(value, &self.data_type) + } + + fn size(&self) -> usize { + size_of_val(self) + self.distinct_values.capacity() * size_of::() + } +} + +/// Calculate the percentile value for a given set of values. +/// This function performs an exact calculation by sorting all values. +/// +/// The percentile is calculated using linear interpolation between closest ranks. +/// For percentile p and n values: +/// - If p * (n-1) is an integer, return the value at that position +/// - Otherwise, interpolate between the two closest values +fn calculate_percentile( + mut values: Vec, + percentile: f64, +) -> Option { + let cmp = |x: &T::Native, y: &T::Native| x.compare(*y); + + let len = values.len(); + if len == 0 { + None + } else if len == 1 { + Some(values[0]) + } else if percentile == 0.0 { + // Get minimum value + Some( + *values + .iter() + .min_by(|a, b| cmp(a, b)) + .expect("we checked for len > 0 a few lines above"), + ) + } else if percentile == 1.0 { + // Get maximum value + Some( + *values + .iter() + .max_by(|a, b| cmp(a, b)) + .expect("we checked for len > 0 a few lines above"), + ) + } else { + // Calculate the index using the formula: p * (n - 1) + let index = percentile * ((len - 1) as f64); + let lower_index = index.floor() as usize; + let upper_index = index.ceil() as usize; + + if lower_index == upper_index { + // Exact index, return the value at that position + let (_, value, _) = values.select_nth_unstable_by(lower_index, cmp); + Some(*value) + } else { + // Need to interpolate between two values + // First, partition at lower_index to get the lower value + let (_, lower_value, _) = values.select_nth_unstable_by(lower_index, cmp); + let lower_value = *lower_value; + + // Then partition at upper_index to get the upper value + let (_, upper_value, _) = values.select_nth_unstable_by(upper_index, cmp); + let upper_value = *upper_value; + + // Linear interpolation using wrapping arithmetic + // We use wrapping operations here (matching the approach in median.rs) because: + // 1. Both values come from the input data, so diff is bounded by the value range + // 2. fraction is between 0 and 1, and INTERPOLATION_PRECISION is small enough + // to prevent overflow when combined with typical numeric ranges + // 3. The result is guaranteed to be between lower_value and upper_value + // 4. For floating-point types, wrapping ops behave the same as standard ops + let fraction = index - (lower_index as f64); + let diff = upper_value.sub_wrapping(lower_value); + let interpolated = lower_value.add_wrapping( + diff.mul_wrapping(T::Native::usize_as( + (fraction * INTERPOLATION_PRECISION as f64) as usize, + )) + .div_wrapping(T::Native::usize_as(INTERPOLATION_PRECISION)), + ); + Some(interpolated) + } + } +} diff --git a/datafusion/functions-aggregate/src/utils.rs b/datafusion/functions-aggregate/src/utils.rs new file mode 100644 index 000000000000..c058b64f9572 --- /dev/null +++ b/datafusion/functions-aggregate/src/utils.rs @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::RecordBatch; +use arrow::datatypes::Schema; +use datafusion_common::{internal_err, plan_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::ColumnarValue; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + +/// Evaluates a physical expression to extract its scalar value. +/// +/// This is used to extract constant values from expressions (like percentile parameters) +/// by evaluating them against an empty record batch. +pub(crate) fn get_scalar_value(expr: &Arc) -> Result { + let empty_schema = Arc::new(Schema::empty()); + let batch = RecordBatch::new_empty(Arc::clone(&empty_schema)); + if let ColumnarValue::Scalar(s) = expr.evaluate(&batch)? { + Ok(s) + } else { + internal_err!("Didn't expect ColumnarValue::Array") + } +} + +/// Validates that a percentile expression is a literal float value between 0.0 and 1.0. +/// +/// Used by both `percentile_cont` and `approx_percentile_cont` to validate their +/// percentile parameters. +pub(crate) fn validate_percentile_expr( + expr: &Arc, + fn_name: &str, +) -> Result { + let scalar_value = get_scalar_value(expr).map_err(|_e| { + DataFusionError::Plan(format!( + "Percentile value for '{fn_name}' must be a literal" + )) + })?; + + let percentile = match scalar_value { + ScalarValue::Float32(Some(value)) => value as f64, + ScalarValue::Float64(Some(value)) => value, + sv => { + return plan_err!( + "Percentile value for '{fn_name}' must be Float32 or Float64 literal (got data type {})", + sv.data_type() + ) + } + }; + + // Ensure the percentile is between 0 and 1. + if !(0.0..=1.0).contains(&percentile) { + return plan_err!( + "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid" + ); + } + Ok(percentile) +} diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index f34fea0c4ba0..080b2f16d92f 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -132,23 +132,26 @@ impl ScalarUDFImpl for ArrayHas { // if the haystack is a constant list, we can use an inlist expression which is more // efficient because the haystack is not varying per-row match haystack { + Expr::Literal(scalar, _) if scalar.is_null() => { + return Ok(ExprSimplifyResult::Simplified(Expr::Literal( + ScalarValue::Boolean(None), + None, + ))) + } Expr::Literal( // FixedSizeList gets coerced to List scalar @ ScalarValue::List(_) | scalar @ ScalarValue::LargeList(_), _, ) => { - let array = scalar.to_array().unwrap(); // guarantee of ScalarValue if let Ok(scalar_values) = - ScalarValue::convert_array_to_scalar_vec(&array) + ScalarValue::convert_array_to_scalar_vec(&scalar.to_array()?) { assert_eq!(scalar_values.len(), 1); let list = scalar_values .into_iter() - // If the vec is a singular null, `list` will be empty due to this flatten(). 
- // It would be more clear if we handled the None separately, but this is more performant. .flatten() .flatten() - .map(|v| Expr::Literal(v.clone(), None)) + .map(|v| Expr::Literal(v, None)) .collect(); return Ok(ExprSimplifyResult::Simplified(in_list( @@ -178,6 +181,12 @@ impl ScalarUDFImpl for ArrayHas { args: datafusion_expr::ScalarFunctionArgs, ) -> Result { let [first_arg, second_arg] = take_function_args(self.name(), &args.args)?; + if first_arg.data_type().is_null() { + // Always return null if the first argument is null + // i.e. array_has(null, element) -> null + return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); + } + match &second_arg { ColumnarValue::Array(array_needle) => { // the needle is already an array, convert the haystack to an array of the same length @@ -663,6 +672,7 @@ fn general_array_has_all_and_any_kernel( mod tests { use std::sync::Arc; + use arrow::datatypes::Int32Type; use arrow::{ array::{create_array, Array, ArrayRef, AsArray, Int32Array, ListArray}, buffer::OffsetBuffer, @@ -733,6 +743,40 @@ mod tests { ); } + #[test] + fn test_simplify_array_has_with_null_to_null() { + let haystack = Expr::Literal(ScalarValue::Null, None); + let needle = col("c"); + + let props = ExecutionProps::new(); + let context = datafusion_expr::simplify::SimplifyContext::new(&props); + let Ok(ExprSimplifyResult::Simplified(simplified)) = + ArrayHas::new().simplify(vec![haystack, needle], &context) + else { + panic!("Expected simplified expression"); + }; + + assert_eq!(simplified, Expr::Literal(ScalarValue::Boolean(None), None)); + } + + #[test] + fn test_simplify_array_has_with_null_list_to_null() { + let haystack = + ListArray::from_iter_primitive::; 0], _>([None]); + let haystack = Expr::Literal(ScalarValue::List(Arc::new(haystack)), None); + let needle = col("c"); + + let props = ExecutionProps::new(); + let context = datafusion_expr::simplify::SimplifyContext::new(&props); + let Ok(ExprSimplifyResult::Simplified(simplified)) = + ArrayHas::new().simplify(vec![haystack, needle], &context) + else { + panic!("Expected simplified expression"); + }; + + assert_eq!(simplified, Expr::Literal(ScalarValue::Boolean(None), None)); + } + #[test] fn test_array_has_complex_list_not_simplified() { let haystack = col("c1"); @@ -757,13 +801,9 @@ mod tests { Field::new_list("", Field::new("", DataType::Int32, true), true), true, )); - let needle_field = Arc::new(Field::new("needle", DataType::Int32, true)); - let return_field = Arc::new(Field::new_list( - "return", - Field::new("", DataType::Boolean, true), - true, - )); + let needle_field = Arc::new(Field::new("needle", DataType::Int32, true)); + let return_field = Arc::new(Field::new("return", DataType::Boolean, true)); let haystack = ListArray::new( Field::new_list_field(DataType::Int32, true).into(), OffsetBuffer::new(vec![0, 0].into()), @@ -773,7 +813,6 @@ mod tests { let haystack = ColumnarValue::Array(Arc::new(haystack)); let needle = ColumnarValue::Scalar(ScalarValue::Int32(Some(1))); - let result = ArrayHas::new().invoke_with_args(ScalarFunctionArgs { args: vec![haystack, needle], arg_fields: vec![haystack_field, needle_field], @@ -789,4 +828,34 @@ mod tests { Ok(()) } + + #[test] + fn test_array_has_list_null_haystack() -> Result<(), DataFusionError> { + let haystack_field = Arc::new(Field::new("haystack", DataType::Null, true)); + let needle_field = Arc::new(Field::new("needle", DataType::Int32, true)); + let return_field = Arc::new(Field::new("return", DataType::Boolean, true)); + let haystack = + 
ListArray::from_iter_primitive::; 0], _>([ + None, None, None, + ]); + + let haystack = ColumnarValue::Array(Arc::new(haystack)); + let needle = ColumnarValue::Scalar(ScalarValue::Int32(Some(1))); + let result = ArrayHas::new().invoke_with_args(ScalarFunctionArgs { + args: vec![haystack, needle], + arg_fields: vec![haystack_field, needle_field], + number_rows: 1, + return_field, + config_options: Arc::new(ConfigOptions::default()), + })?; + + let output = result.into_array(1)?; + let output = output.as_boolean(); + assert_eq!(output.len(), 3); + for i in 0..3 { + assert!(output.is_null(i)); + } + + Ok(()) + } } diff --git a/datafusion/functions-nested/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs index 555767f8f070..53642bf1622b 100644 --- a/datafusion/functions-nested/src/set_ops.rs +++ b/datafusion/functions-nested/src/set_ops.rs @@ -29,9 +29,7 @@ use arrow::datatypes::{DataType, Field, FieldRef}; use arrow::row::{RowConverter, SortField}; use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::utils::ListCoercion; -use datafusion_common::{ - exec_err, internal_err, plan_err, utils::take_function_args, Result, -}; +use datafusion_common::{exec_err, internal_err, utils::take_function_args, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -289,13 +287,7 @@ impl ScalarUDFImpl for ArrayDistinct { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match &arg_types[0] { - List(field) => Ok(DataType::new_list(field.data_type().clone(), true)), - LargeList(field) => { - Ok(DataType::new_large_list(field.data_type().clone(), true)) - } - arg_type => plan_err!("{} does not support type {arg_type}", self.name()), - } + Ok(arg_types[0].clone()) } fn invoke_with_args( @@ -563,3 +555,54 @@ fn general_array_distinct( array.nulls().cloned(), )?)) } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::{ + array::{Int32Array, ListArray}, + buffer::OffsetBuffer, + datatypes::{DataType, Field}, + }; + use datafusion_common::{config::ConfigOptions, DataFusionError}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; + + use crate::set_ops::array_distinct_udf; + + #[test] + fn test_array_distinct_inner_nullability_result_type_match_return_type( + ) -> Result<(), DataFusionError> { + let udf = array_distinct_udf(); + + for inner_nullable in [true, false] { + let inner_field = Field::new_list_field(DataType::Int32, inner_nullable); + let input_field = + Field::new_list("input", Arc::new(inner_field.clone()), true); + + // [[1, 1, 2]] + let input_array = ListArray::new( + inner_field.into(), + OffsetBuffer::new(vec![0, 3].into()), + Arc::new(Int32Array::new(vec![1, 1, 2].into(), None)), + None, + ); + + let input_array = ColumnarValue::Array(Arc::new(input_array)); + + let result = udf.invoke_with_args(ScalarFunctionArgs { + args: vec![input_array], + arg_fields: vec![input_field.clone().into()], + number_rows: 1, + return_field: input_field.clone().into(), + config_options: Arc::new(ConfigOptions::default()), + })?; + + assert_eq!( + result.data_type(), + udf.return_type(&[input_field.data_type().clone()])? + ); + } + Ok(()) + } +} diff --git a/datafusion/functions-nested/src/string.rs b/datafusion/functions-nested/src/string.rs index 3373f7a9838e..61caa3ac7076 100644 --- a/datafusion/functions-nested/src/string.rs +++ b/datafusion/functions-nested/src/string.rs @@ -369,27 +369,38 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { List(..) 
=> { let list_array = as_list_array(&arr)?; for i in 0..list_array.len() { - compute_array_to_string( - arg, - list_array.value(i), - delimiter.clone(), - null_string.clone(), - with_null_string, - )?; + if !list_array.is_null(i) { + compute_array_to_string( + arg, + list_array.value(i), + delimiter.clone(), + null_string.clone(), + with_null_string, + )?; + } else if with_null_string { + arg.push_str(&null_string); + arg.push_str(&delimiter); + } } Ok(arg) } FixedSizeList(..) => { let list_array = as_fixed_size_list_array(&arr)?; + for i in 0..list_array.len() { - compute_array_to_string( - arg, - list_array.value(i), - delimiter.clone(), - null_string.clone(), - with_null_string, - )?; + if !list_array.is_null(i) { + compute_array_to_string( + arg, + list_array.value(i), + delimiter.clone(), + null_string.clone(), + with_null_string, + )?; + } else if with_null_string { + arg.push_str(&null_string); + arg.push_str(&delimiter); + } } Ok(arg) @@ -397,13 +408,18 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { LargeList(..) => { let list_array = as_large_list_array(&arr)?; for i in 0..list_array.len() { - compute_array_to_string( - arg, - list_array.value(i), - delimiter.clone(), - null_string.clone(), - with_null_string, - )?; + if !list_array.is_null(i) { + compute_array_to_string( + arg, + list_array.value(i), + delimiter.clone(), + null_string.clone(), + with_null_string, + )?; + } else if with_null_string { + arg.push_str(&null_string); + arg.push_str(&delimiter); + } } Ok(arg) diff --git a/datafusion/functions-window-common/src/expr.rs b/datafusion/functions-window-common/src/expr.rs index 774cd5182b30..d72cd412f017 100644 --- a/datafusion/functions-window-common/src/expr.rs +++ b/datafusion/functions-window-common/src/expr.rs @@ -37,7 +37,7 @@ impl<'a> ExpressionArgs<'a> { /// /// * `input_exprs` - The expressions passed as arguments /// to the user-defined window function. - /// * `input_types` - The data types corresponding to the + /// * `input_fields` - The fields corresponding to the /// arguments to the user-defined window function. /// pub fn new( diff --git a/datafusion/functions-window-common/src/partition.rs b/datafusion/functions-window-common/src/partition.rs index 61125e596130..df0a81540117 100644 --- a/datafusion/functions-window-common/src/partition.rs +++ b/datafusion/functions-window-common/src/partition.rs @@ -42,7 +42,7 @@ impl<'a> PartitionEvaluatorArgs<'a> { /// /// * `input_exprs` - The expressions passed as arguments /// to the user-defined window function. - /// * `input_types` - The data types corresponding to the + /// * `input_fields` - The fields corresponding to the /// arguments to the user-defined window function. /// * `is_reversed` - Set to `true` if and only if the user-defined /// window function is reversible and is reversed. 
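A hedged illustration of the array_to_string change above (the literal syntax and values are assumed, not taken from the patch): a null inner element now contributes the null string instead of being silently skipped, so `array_to_string([[1, 2], NULL, [3]], ',', '*')` yields `1,2,*,3`; when no null string argument is supplied, the NULL element still contributes nothing.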
diff --git a/datafusion/functions-window/src/nth_value.rs b/datafusion/functions-window/src/nth_value.rs index 329d8aa5ab17..1ba6ad5ce0d4 100644 --- a/datafusion/functions-window/src/nth_value.rs +++ b/datafusion/functions-window/src/nth_value.rs @@ -40,39 +40,28 @@ use std::hash::Hash; use std::ops::Range; use std::sync::{Arc, LazyLock}; -get_or_init_udwf!( +define_udwf_and_expr!( First, first_value, - "returns the first value in the window frame", + [arg], + "Returns the first value in the window frame", NthValue::first ); -get_or_init_udwf!( +define_udwf_and_expr!( Last, last_value, - "returns the last value in the window frame", + [arg], + "Returns the last value in the window frame", NthValue::last ); get_or_init_udwf!( NthValue, nth_value, - "returns the nth value in the window frame", + "Returns the nth value in the window frame", NthValue::nth ); -/// Create an expression to represent the `first_value` window function -/// -pub fn first_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr { - first_value_udwf().call(vec![arg]) -} - -/// Create an expression to represent the `last_value` window function -/// -pub fn last_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr { - last_value_udwf().call(vec![arg]) -} - /// Create an expression to represent the `nth_value` window function -/// pub fn nth_value(arg: datafusion_expr::Expr, n: i64) -> datafusion_expr::Expr { nth_value_udwf().call(vec![arg, n.lit()]) } diff --git a/datafusion/functions-window/src/ntile.rs b/datafusion/functions-window/src/ntile.rs index d188db3bbf59..008caaa848aa 100644 --- a/datafusion/functions-window/src/ntile.rs +++ b/datafusion/functions-window/src/ntile.rs @@ -25,8 +25,7 @@ use datafusion_common::arrow::array::{ArrayRef, UInt64Array}; use datafusion_common::arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_datafusion_err, exec_err, Result}; use datafusion_expr::{ - Documentation, Expr, LimitEffect, PartitionEvaluator, Signature, Volatility, - WindowUDFImpl, + Documentation, LimitEffect, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, }; use datafusion_functions_window_common::field; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; @@ -37,16 +36,13 @@ use std::any::Any; use std::fmt::Debug; use std::sync::Arc; -get_or_init_udwf!( +define_udwf_and_expr!( Ntile, ntile, - "integer ranging from 1 to the argument value, dividing the partition as equally as possible" + [arg], + "Integer ranging from 1 to the argument value, dividing the partition as equally as possible." 
); -pub fn ntile(arg: Expr) -> Expr { - ntile_udwf().call(vec![arg]) -} - #[user_doc( doc_section(label = "Ranking Functions"), description = "Integer ranging from 1 to the argument value, dividing the partition as equally as possible", diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 94a41ba4bb25..c4e58601cd10 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -60,16 +60,26 @@ use datafusion_macros::user_doc; description = "Casts a value to a specific Arrow data type.", syntax_example = "arrow_cast(expression, datatype)", sql_example = r#"```sql -> select arrow_cast(-5, 'Int8') as a, +> select + arrow_cast(-5, 'Int8') as a, arrow_cast('foo', 'Dictionary(Int32, Utf8)') as b, - arrow_cast('bar', 'LargeUtf8') as c, - arrow_cast('2023-01-02T12:53:02', 'Timestamp(Microsecond, Some("+08:00"))') as d - ; -+----+-----+-----+---------------------------+ -| a | b | c | d | -+----+-----+-----+---------------------------+ -| -5 | foo | bar | 2023-01-02T12:53:02+08:00 | -+----+-----+-----+---------------------------+ + arrow_cast('bar', 'LargeUtf8') as c; + ++----+-----+-----+ +| a | b | c | ++----+-----+-----+ +| -5 | foo | bar | ++----+-----+-----+ + +> select + arrow_cast('2023-01-02T12:53:02', 'Timestamp(µs, "+08:00")') as d, + arrow_cast('2023-01-02T12:53:02', 'Timestamp(µs)') as e; + ++---------------------------+---------------------+ +| d | e | ++---------------------------+---------------------+ +| 2023-01-02T12:53:02+08:00 | 2023-01-02T12:53:02 | ++---------------------------+---------------------+ ```"#, argument( name = "expression", diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index 3fba539dd04b..aab1f445d559 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -47,7 +47,7 @@ use std::any::Any; )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct CoalesceFunc { - signature: Signature, + pub(super) signature: Signature, } impl Default for CoalesceFunc { @@ -126,6 +126,15 @@ impl ScalarUDFImpl for CoalesceFunc { internal_err!("coalesce should have been simplified to case") } + fn conditional_arguments<'a>( + &self, + args: &'a [Expr], + ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> { + let eager = vec![&args[0]]; + let lazy = args[1..].iter().collect(); + Some((eager, lazy)) + } + fn short_circuits(&self) -> bool { true } diff --git a/datafusion/functions/src/core/nvl.rs b/datafusion/functions/src/core/nvl.rs index c8b34c4b1780..0b9968a88fc9 100644 --- a/datafusion/functions/src/core/nvl.rs +++ b/datafusion/functions/src/core/nvl.rs @@ -15,21 +15,19 @@ // specific language governing permissions and limitations // under the License. 
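The `conditional_arguments` hook added to `CoalesceFunc` above tells the common-subexpression pass which arguments are always evaluated and which may be skipped. A minimal sketch of what it reports, assuming the `datafusion_functions::core::coalesce::CoalesceFunc` path and its `Default` constructor from the current crate layout:

    use datafusion_expr::{col, ScalarUDFImpl};
    use datafusion_functions::core::coalesce::CoalesceFunc;

    fn coalesce_conditional_args_sketch() {
        let args = vec![col("a"), col("b"), col("c")];
        // `coalesce` always evaluates its first argument; the rest are lazy and
        // must not be hoisted unconditionally by common-subexpression elimination.
        let (eager, lazy) = CoalesceFunc::default()
            .conditional_arguments(&args)
            .expect("coalesce reports conditional arguments");
        assert_eq!(eager, vec![&args[0]]);
        assert_eq!(lazy, vec![&args[1], &args[2]]);
    }

This is what lets `common_subexpr_eliminate` (changed further below) treat the trailing arguments of `coalesce` as conditionally executed instead of relying on `short_circuits` alone.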
-use arrow::array::Array; -use arrow::compute::is_not_null; -use arrow::compute::kernels::zip::zip; -use arrow::datatypes::DataType; -use datafusion_common::{utils::take_function_args, Result}; +use crate::core::coalesce::CoalesceFunc; +use arrow::datatypes::{DataType, FieldRef}; +use datafusion_common::Result; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, - Volatility, + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; -use std::sync::Arc; #[user_doc( doc_section(label = "Conditional Functions"), - description = "Returns _expression2_ if _expression1_ is NULL otherwise it returns _expression1_.", + description = "Returns _expression2_ if _expression1_ is NULL otherwise it returns _expression1_ and _expression2_ is not evaluated. This function can be used to substitute a default value for NULL values.", syntax_example = "nvl(expression1, expression2)", sql_example = r#"```sql > select nvl(null, 'a'); @@ -57,7 +55,7 @@ use std::sync::Arc; )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct NVLFunc { - signature: Signature, + coalesce: CoalesceFunc, aliases: Vec, } @@ -90,11 +88,13 @@ impl Default for NVLFunc { impl NVLFunc { pub fn new() -> Self { Self { - signature: Signature::uniform( - 2, - SUPPORTED_NVL_TYPES.to_vec(), - Volatility::Immutable, - ), + coalesce: CoalesceFunc { + signature: Signature::uniform( + 2, + SUPPORTED_NVL_TYPES.to_vec(), + Volatility::Immutable, + ), + }, aliases: vec![String::from("ifnull")], } } @@ -110,209 +110,45 @@ impl ScalarUDFImpl for NVLFunc { } fn signature(&self) -> &Signature { - &self.signature + &self.coalesce.signature } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + self.coalesce.return_type(arg_types) } - fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - nvl_func(&args.args) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } - - fn documentation(&self) -> Option<&Documentation> { - self.doc() + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + self.coalesce.return_field_from_args(args) } -} - -fn nvl_func(args: &[ColumnarValue]) -> Result { - let [lhs, rhs] = take_function_args("nvl/ifnull", args)?; - let (lhs_array, rhs_array) = match (lhs, rhs) { - (ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => { - (Arc::clone(lhs), rhs.to_array_of_size(lhs.len())?) 
- } - (ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => { - (Arc::clone(lhs), Arc::clone(rhs)) - } - (ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)) => { - (lhs.to_array_of_size(rhs.len())?, Arc::clone(rhs)) - } - (ColumnarValue::Scalar(lhs), ColumnarValue::Scalar(rhs)) => { - let mut current_value = lhs; - if lhs.is_null() { - current_value = rhs; - } - return Ok(ColumnarValue::Scalar(current_value.clone())); - } - }; - let to_apply = is_not_null(&lhs_array)?; - let value = zip(&to_apply, &lhs_array, &rhs_array)?; - Ok(ColumnarValue::Array(value)) -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use arrow::array::*; - use super::*; - use datafusion_common::ScalarValue; - - #[test] - fn nvl_int32() -> Result<()> { - let a = Int32Array::from(vec![ - Some(1), - Some(2), - None, - None, - Some(3), - None, - None, - Some(4), - Some(5), - ]); - let a = ColumnarValue::Array(Arc::new(a)); - - let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(6i32))); - - let result = nvl_func(&[a, lit_array])?; - let result = result.into_array(0).expect("Failed to convert to array"); - - let expected = Arc::new(Int32Array::from(vec![ - Some(1), - Some(2), - Some(6), - Some(6), - Some(3), - Some(6), - Some(6), - Some(4), - Some(5), - ])) as ArrayRef; - assert_eq!(expected.as_ref(), result.as_ref()); - Ok(()) + fn simplify( + &self, + args: Vec, + info: &dyn SimplifyInfo, + ) -> Result { + self.coalesce.simplify(args, info) } - #[test] - // Ensure that arrays with no nulls can also invoke nvl() correctly - fn nvl_int32_non_nulls() -> Result<()> { - let a = Int32Array::from(vec![1, 3, 10, 7, 8, 1, 2, 4, 5]); - let a = ColumnarValue::Array(Arc::new(a)); - - let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(20i32))); - - let result = nvl_func(&[a, lit_array])?; - let result = result.into_array(0).expect("Failed to convert to array"); - - let expected = Arc::new(Int32Array::from(vec![ - Some(1), - Some(3), - Some(10), - Some(7), - Some(8), - Some(1), - Some(2), - Some(4), - Some(5), - ])) as ArrayRef; - assert_eq!(expected.as_ref(), result.as_ref()); - Ok(()) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + self.coalesce.invoke_with_args(args) } - #[test] - fn nvl_boolean() -> Result<()> { - let a = BooleanArray::from(vec![Some(true), Some(false), None]); - let a = ColumnarValue::Array(Arc::new(a)); - - let lit_array = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))); - - let result = nvl_func(&[a, lit_array])?; - let result = result.into_array(0).expect("Failed to convert to array"); - - let expected = Arc::new(BooleanArray::from(vec![ - Some(true), - Some(false), - Some(false), - ])) as ArrayRef; - - assert_eq!(expected.as_ref(), result.as_ref()); - Ok(()) + fn conditional_arguments<'a>( + &self, + args: &'a [Expr], + ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> { + self.coalesce.conditional_arguments(args) } - #[test] - fn nvl_string() -> Result<()> { - let a = StringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]); - let a = ColumnarValue::Array(Arc::new(a)); - - let lit_array = ColumnarValue::Scalar(ScalarValue::from("bax")); - - let result = nvl_func(&[a, lit_array])?; - let result = result.into_array(0).expect("Failed to convert to array"); - - let expected = Arc::new(StringArray::from(vec![ - Some("foo"), - Some("bar"), - Some("bax"), - Some("baz"), - ])) as ArrayRef; - - assert_eq!(expected.as_ref(), result.as_ref()); - Ok(()) + fn short_circuits(&self) -> bool { + self.coalesce.short_circuits() } - #[test] - fn 
nvl_literal_first() -> Result<()> { - let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), Some(4)]); - let a = ColumnarValue::Array(Arc::new(a)); - - let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); - - let result = nvl_func(&[lit_array, a])?; - let result = result.into_array(0).expect("Failed to convert to array"); - - let expected = Arc::new(Int32Array::from(vec![ - Some(2), - Some(2), - Some(2), - Some(2), - Some(2), - Some(2), - ])) as ArrayRef; - assert_eq!(expected.as_ref(), result.as_ref()); - Ok(()) + fn aliases(&self) -> &[String] { + &self.aliases } - #[test] - fn nvl_scalar() -> Result<()> { - let a_null = ColumnarValue::Scalar(ScalarValue::Int32(None)); - let b_null = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); - - let result_null = nvl_func(&[a_null, b_null])?; - let result_null = result_null - .into_array(1) - .expect("Failed to convert to array"); - - let expected_null = Arc::new(Int32Array::from(vec![Some(2i32)])) as ArrayRef; - - assert_eq!(expected_null.as_ref(), result_null.as_ref()); - - let a_nnull = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); - let b_nnull = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))); - - let result_nnull = nvl_func(&[a_nnull, b_nnull])?; - let result_nnull = result_nnull - .into_array(1) - .expect("Failed to convert to array"); - - let expected_nnull = Arc::new(Int32Array::from(vec![Some(2i32)])) as ArrayRef; - assert_eq!(expected_nnull.as_ref(), result_nnull.as_ref()); - - Ok(()) + fn documentation(&self) -> Option<&Documentation> { + self.doc() } } diff --git a/datafusion/functions/src/core/nvl2.rs b/datafusion/functions/src/core/nvl2.rs index 82aa8d2a4cd5..45cb6760d062 100644 --- a/datafusion/functions/src/core/nvl2.rs +++ b/datafusion/functions/src/core/nvl2.rs @@ -15,17 +15,16 @@ // specific language governing permissions and limitations // under the License. 
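The preceding nvl.rs hunk removes the hand-rolled `nvl_func` kernel and its tests because NVL now delegates to `CoalesceFunc`: `nvl(a, b)` is planned like `coalesce(a, b)` and simplified to (roughly) `CASE WHEN a IS NOT NULL THEN a ELSE b END`, so the second argument is only evaluated when the first is NULL. The nvl2.rs hunk below applies the same CASE rewrite directly.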
-use arrow::array::Array; -use arrow::compute::is_not_null; -use arrow::compute::kernels::zip::zip; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{internal_err, utils::take_function_args, Result}; use datafusion_expr::{ - type_coercion::binary::comparison_coercion, ColumnarValue, Documentation, - ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, + conditional_expressions::CaseBuilder, + simplify::{ExprSimplifyResult, SimplifyInfo}, + type_coercion::binary::comparison_coercion, + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; -use std::sync::Arc; #[user_doc( doc_section(label = "Conditional Functions"), @@ -95,8 +94,37 @@ impl ScalarUDFImpl for NVL2Func { Ok(arg_types[1].clone()) } - fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - nvl2_func(&args.args) + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = + args.arg_fields[1].is_nullable() || args.arg_fields[2].is_nullable(); + let return_type = args.arg_fields[1].data_type().clone(); + Ok(Field::new(self.name(), return_type, nullable).into()) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("nvl2 should have been simplified to case") + } + + fn simplify( + &self, + args: Vec, + _info: &dyn SimplifyInfo, + ) -> Result { + let [test, if_non_null, if_null] = take_function_args(self.name(), args)?; + + let expr = CaseBuilder::new( + None, + vec![test.is_not_null()], + vec![if_non_null], + Some(Box::new(if_null)), + ) + .end()?; + + Ok(ExprSimplifyResult::Simplified(expr)) + } + + fn short_circuits(&self) -> bool { + true } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { @@ -123,42 +151,3 @@ impl ScalarUDFImpl for NVL2Func { self.doc() } } - -fn nvl2_func(args: &[ColumnarValue]) -> Result { - let mut len = 1; - let mut is_array = false; - for arg in args { - if let ColumnarValue::Array(array) = arg { - len = array.len(); - is_array = true; - break; - } - } - if is_array { - let args = args - .iter() - .map(|arg| match arg { - ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(len), - ColumnarValue::Array(array) => Ok(Arc::clone(array)), - }) - .collect::>>()?; - let [tested, if_non_null, if_null] = take_function_args("nvl2", args)?; - let to_apply = is_not_null(&tested)?; - let value = zip(&to_apply, &if_non_null, &if_null)?; - Ok(ColumnarValue::Array(value)) - } else { - let [tested, if_non_null, if_null] = take_function_args("nvl2", args)?; - match &tested { - ColumnarValue::Array(_) => { - internal_err!("except Scalar value, but got Array") - } - ColumnarValue::Scalar(scalar) => { - if scalar.is_null() { - Ok(if_null.clone()) - } else { - Ok(if_non_null.clone()) - } - } - } - } -} diff --git a/datafusion/functions/src/datetime/current_date.rs b/datafusion/functions/src/datetime/current_date.rs index 0ba3afd19bed..18b99bca8638 100644 --- a/datafusion/functions/src/datetime/current_date.rs +++ b/datafusion/functions/src/datetime/current_date.rs @@ -36,7 +36,9 @@ Returns the current date in the session time zone. The `current_date()` return value is determined at query time and will return the same date, no matter when in the query plan the function executes. 
"#, - syntax_example = "current_date()" + syntax_example = r#"current_date() + (optional) SET datafusion.execution.time_zone = '+00:00'; + SELECT current_date();"# )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct CurrentDateFunc { diff --git a/datafusion/functions/src/datetime/current_time.rs b/datafusion/functions/src/datetime/current_time.rs index 79d5bfc1783c..4f5b199cce41 100644 --- a/datafusion/functions/src/datetime/current_time.rs +++ b/datafusion/functions/src/datetime/current_time.rs @@ -15,26 +15,32 @@ // specific language governing permissions and limitations // under the License. +use arrow::array::timezone::Tz; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Time64; use arrow::datatypes::TimeUnit::Nanosecond; -use std::any::Any; - +use chrono::TimeZone; +use chrono::Timelike; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; +use std::any::Any; #[user_doc( doc_section(label = "Time and Date Functions"), description = r#" -Returns the current UTC time. +Returns the current time in the session time zone. The `current_time()` return value is determined at query time and will return the same time, no matter when in the query plan the function executes. + +The session time zone can be set using the statement 'SET datafusion.execution.time_zone = desired time zone'. The time zone can be a value like +00:00, 'Europe/London' etc. "#, - syntax_example = "current_time()" + syntax_example = r#"current_time() + (optional) SET datafusion.execution.time_zone = '+00:00'; + SELECT current_time();"# )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct CurrentTimeFunc { @@ -93,7 +99,20 @@ impl ScalarUDFImpl for CurrentTimeFunc { info: &dyn SimplifyInfo, ) -> Result { let now_ts = info.execution_props().query_execution_start_time; - let nano = now_ts.timestamp_nanos_opt().map(|ts| ts % 86400000000000); + + // Try to get timezone from config and convert to local time + let nano = info + .execution_props() + .config_options() + .and_then(|config| config.execution.time_zone.parse::().ok()) + .map_or_else( + || datetime_to_time_nanos(&now_ts), + |tz| { + let local_now = tz.from_utc_datetime(&now_ts.naive_utc()); + datetime_to_time_nanos(&local_now) + }, + ); + Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::Time64Nanosecond(nano), None, @@ -104,3 +123,97 @@ impl ScalarUDFImpl for CurrentTimeFunc { self.doc() } } + +// Helper function for conversion of datetime to a timestamp. 
+fn datetime_to_time_nanos(dt: &chrono::DateTime) -> Option { + let hour = dt.hour() as i64; + let minute = dt.minute() as i64; + let second = dt.second() as i64; + let nanosecond = dt.nanosecond() as i64; + Some((hour * 3600 + minute * 60 + second) * 1_000_000_000 + nanosecond) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::{DataType, TimeUnit::Nanosecond}; + use chrono::{DateTime, Utc}; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::execution_props::ExecutionProps; + use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; + use std::sync::Arc; + + struct MockSimplifyInfo { + execution_props: ExecutionProps, + } + + impl SimplifyInfo for MockSimplifyInfo { + fn is_boolean_type(&self, _expr: &Expr) -> Result { + Ok(false) + } + + fn nullable(&self, _expr: &Expr) -> Result { + Ok(true) + } + + fn execution_props(&self) -> &ExecutionProps { + &self.execution_props + } + + fn get_data_type(&self, _expr: &Expr) -> Result { + Ok(Time64(Nanosecond)) + } + } + + fn set_session_timezone_env(tz: &str, start_time: DateTime) -> MockSimplifyInfo { + let mut config = datafusion_common::config::ConfigOptions::default(); + config.execution.time_zone = tz.to_string(); + let mut execution_props = + ExecutionProps::new().with_query_execution_start_time(start_time); + execution_props.config_options = Some(Arc::new(config)); + MockSimplifyInfo { execution_props } + } + + #[test] + fn test_current_time_timezone_offset() { + // Use a fixed start time for consistent testing + let start_time = Utc.with_ymd_and_hms(2025, 1, 1, 12, 0, 0).unwrap(); + + // Test with UTC+05:00 + let info_plus_5 = set_session_timezone_env("+05:00", start_time); + let result_plus_5 = CurrentTimeFunc::new() + .simplify(vec![], &info_plus_5) + .unwrap(); + + // Test with UTC-05:00 + let info_minus_5 = set_session_timezone_env("-05:00", start_time); + let result_minus_5 = CurrentTimeFunc::new() + .simplify(vec![], &info_minus_5) + .unwrap(); + + // Extract nanoseconds from results + let nanos_plus_5 = match result_plus_5 { + ExprSimplifyResult::Simplified(Expr::Literal( + ScalarValue::Time64Nanosecond(Some(n)), + _, + )) => n, + _ => panic!("Expected Time64Nanosecond literal"), + }; + + let nanos_minus_5 = match result_minus_5 { + ExprSimplifyResult::Simplified(Expr::Literal( + ScalarValue::Time64Nanosecond(Some(n)), + _, + )) => n, + _ => panic!("Expected Time64Nanosecond literal"), + }; + + // Calculate the difference: UTC+05:00 should be 10 hours ahead of UTC-05:00 + let difference = nanos_plus_5 - nanos_minus_5; + + // 10 hours in nanoseconds + let expected_offset = 10i64 * 3600 * 1_000_000_000; + + assert_eq!(difference, expected_offset, "Expected 10-hour offset difference in nanoseconds between UTC+05:00 and UTC-05:00"); + } +} diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index 74e286de0f58..c4e89743bd55 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -687,7 +687,7 @@ mod tests { let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), - "Execution error: DATE_BIN expects origin argument to be a TIMESTAMP with nanosecond precision but got Timestamp(Microsecond, None)" + "Execution error: DATE_BIN expects origin argument to be a TIMESTAMP with nanosecond precision but got Timestamp(µs)" ); args = vec![ diff --git a/datafusion/functions/src/datetime/mod.rs b/datafusion/functions/src/datetime/mod.rs 
index 5729b1edae95..d80f14facf82 100644 --- a/datafusion/functions/src/datetime/mod.rs +++ b/datafusion/functions/src/datetime/mod.rs @@ -45,7 +45,6 @@ make_udf_function!(date_part::DatePartFunc, date_part); make_udf_function!(date_trunc::DateTruncFunc, date_trunc); make_udf_function!(make_date::MakeDateFunc, make_date); make_udf_function!(from_unixtime::FromUnixtimeFunc, from_unixtime); -make_udf_function!(now::NowFunc, now); make_udf_function!(to_char::ToCharFunc, to_char); make_udf_function!(to_date::ToDateFunc, to_date); make_udf_function!(to_local_time::ToLocalTimeFunc, to_local_time); @@ -56,6 +55,9 @@ make_udf_function!(to_timestamp::ToTimestampMillisFunc, to_timestamp_millis); make_udf_function!(to_timestamp::ToTimestampMicrosFunc, to_timestamp_micros); make_udf_function!(to_timestamp::ToTimestampNanosFunc, to_timestamp_nanos); +// create UDF with config +make_udf_function_with_config!(now::NowFunc, now); + // we cannot currently use the export_functions macro since it doesn't handle // functions with varargs currently @@ -91,6 +93,7 @@ pub mod expr_fn { ),( now, "returns the current timestamp in nanoseconds, using the same value for all instances of now() in same statement", + @config ), ( to_local_time, @@ -255,6 +258,7 @@ pub mod expr_fn { /// Returns all DataFusion functions defined in this package pub fn functions() -> Vec> { + use datafusion_common::config::ConfigOptions; vec![ current_date(), current_time(), @@ -263,7 +267,7 @@ pub fn functions() -> Vec> { date_trunc(), from_unixtime(), make_date(), - now(), + now(&ConfigOptions::default()), to_char(), to_date(), to_local_time(), diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index 65dadb42a89e..96a35c241ff0 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -19,12 +19,14 @@ use arrow::datatypes::DataType::Timestamp; use arrow::datatypes::TimeUnit::Nanosecond; use arrow::datatypes::{DataType, Field, FieldRef}; use std::any::Any; +use std::sync::Arc; +use datafusion_common::config::ConfigOptions; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarUDFImpl, Signature, - Volatility, + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, }; use datafusion_macros::user_doc; @@ -41,19 +43,30 @@ The `now()` return value is determined at query time and will return the same ti pub struct NowFunc { signature: Signature, aliases: Vec, + timezone: Option>, } impl Default for NowFunc { fn default() -> Self { - Self::new() + Self::new_with_config(&ConfigOptions::default()) } } impl NowFunc { + #[deprecated(since = "50.2.0", note = "use `new_with_config` instead")] pub fn new() -> Self { Self { signature: Signature::nullary(Volatility::Stable), aliases: vec!["current_timestamp".to_string()], + timezone: Some(Arc::from("+00")), + } + } + + pub fn new_with_config(config: &ConfigOptions) -> Self { + Self { + signature: Signature::nullary(Volatility::Stable), + aliases: vec!["current_timestamp".to_string()], + timezone: Some(Arc::from(config.execution.time_zone.as_str())), } } } @@ -77,10 +90,14 @@ impl ScalarUDFImpl for NowFunc { &self.signature } + fn with_updated_config(&self, config: &ConfigOptions) -> Option { + Some(Self::new_with_config(config).into()) + } + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { 
Ok(Field::new( self.name(), - Timestamp(Nanosecond, Some("+00:00".into())), + Timestamp(Nanosecond, self.timezone.clone()), false, ) .into()) @@ -106,8 +123,9 @@ impl ScalarUDFImpl for NowFunc { .execution_props() .query_execution_start_time .timestamp_nanos_opt(); + Ok(ExprSimplifyResult::Simplified(Expr::Literal( - ScalarValue::TimestampNanosecond(now_ts, Some("+00:00".into())), + ScalarValue::TimestampNanosecond(now_ts, self.timezone.clone()), None, ))) } diff --git a/datafusion/functions/src/encoding/inner.rs b/datafusion/functions/src/encoding/inner.rs index 5baa91936320..e5314ad220c8 100644 --- a/datafusion/functions/src/encoding/inner.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -24,7 +24,10 @@ use arrow::{ datatypes::{ByteArrayType, DataType}, }; use arrow_buffer::{Buffer, OffsetBufferBuilder}; -use base64::{engine::general_purpose, Engine as _}; +use base64::{ + engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig}, + Engine as _, +}; use datafusion_common::{ cast::{as_generic_binary_array, as_generic_string_array}, not_impl_err, plan_err, @@ -40,6 +43,14 @@ use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use std::any::Any; +// Allow padding characters, but don't require them, and don't generate them. +const BASE64_ENGINE: GeneralPurpose = GeneralPurpose::new( + &base64::alphabet::STANDARD, + GeneralPurposeConfig::new() + .with_encode_padding(false) + .with_decode_padding_mode(DecodePaddingMode::Indifferent), +); + #[user_doc( doc_section(label = "Binary String Functions"), description = "Encode binary data into a textual representation.", @@ -302,7 +313,7 @@ fn hex_encode(input: &[u8]) -> String { } fn base64_encode(input: &[u8]) -> String { - general_purpose::STANDARD_NO_PAD.encode(input) + BASE64_ENGINE.encode(input) } fn hex_decode(input: &[u8], buf: &mut [u8]) -> Result { @@ -315,7 +326,7 @@ fn hex_decode(input: &[u8], buf: &mut [u8]) -> Result { } fn base64_decode(input: &[u8], buf: &mut [u8]) -> Result { - general_purpose::STANDARD_NO_PAD + BASE64_ENGINE .decode_slice(input, buf) .map_err(|e| internal_datafusion_err!("Failed to decode from base64: {e}")) } @@ -364,18 +375,16 @@ where impl Encoding { fn encode_scalar(self, value: Option<&[u8]>) -> ColumnarValue { ColumnarValue::Scalar(match self { - Self::Base64 => ScalarValue::Utf8( - value.map(|v| general_purpose::STANDARD_NO_PAD.encode(v)), - ), + Self::Base64 => ScalarValue::Utf8(value.map(|v| BASE64_ENGINE.encode(v))), Self::Hex => ScalarValue::Utf8(value.map(hex::encode)), }) } fn encode_large_scalar(self, value: Option<&[u8]>) -> ColumnarValue { ColumnarValue::Scalar(match self { - Self::Base64 => ScalarValue::LargeUtf8( - value.map(|v| general_purpose::STANDARD_NO_PAD.encode(v)), - ), + Self::Base64 => { + ScalarValue::LargeUtf8(value.map(|v| BASE64_ENGINE.encode(v))) + } Self::Hex => ScalarValue::LargeUtf8(value.map(hex::encode)), }) } @@ -411,15 +420,9 @@ impl Encoding { }; let out = match self { - Self::Base64 => { - general_purpose::STANDARD_NO_PAD - .decode(value) - .map_err(|e| { - internal_datafusion_err!( - "Failed to decode value using base64: {e}" - ) - })? 
- } + Self::Base64 => BASE64_ENGINE.decode(value).map_err(|e| { + internal_datafusion_err!("Failed to decode value using base64: {e}") + })?, Self::Hex => hex::decode(value).map_err(|e| { internal_datafusion_err!("Failed to decode value using hex: {e}") })?, @@ -435,15 +438,9 @@ impl Encoding { }; let out = match self { - Self::Base64 => { - general_purpose::STANDARD_NO_PAD - .decode(value) - .map_err(|e| { - internal_datafusion_err!( - "Failed to decode value using base64: {e}" - ) - })? - } + Self::Base64 => BASE64_ENGINE.decode(value).map_err(|e| { + internal_datafusion_err!("Failed to decode value using base64: {e}") + })?, Self::Hex => hex::decode(value).map_err(|e| { internal_datafusion_err!("Failed to decode value using hex: {e}") })?, diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 228d704e29cb..9e195f2d5291 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -40,6 +40,7 @@ /// Exported functions accept: /// - `Vec` argument (single argument followed by a comma) /// - Variable number of `Expr` arguments (zero or more arguments, must be without commas) +/// - Functions that require config (marked with `@config` prefix) #[macro_export] macro_rules! export_functions { ($(($FUNC:ident, $DOC:expr, $($arg:tt)*)),*) => { @@ -49,6 +50,15 @@ macro_rules! export_functions { )* }; + // function that requires config (marked with @config) + (single $FUNC:ident, $DOC:expr, @config) => { + #[doc = $DOC] + pub fn $FUNC() -> datafusion_expr::Expr { + use datafusion_common::config::ConfigOptions; + super::$FUNC(&ConfigOptions::default()).call(vec![]) + } + }; + // single vector argument (a single argument followed by a comma) (single $FUNC:ident, $DOC:expr, $arg:ident,) => { #[doc = $DOC] @@ -89,6 +99,22 @@ macro_rules! make_udf_function { }; } +/// Creates a singleton `ScalarUDF` of the `$UDF` function and a function +/// named `$NAME` which returns that singleton. The function takes a +/// configuration argument of type `$CONFIG_TYPE` to create the UDF. +#[macro_export] +macro_rules! 
make_udf_function_with_config { + ($UDF:ty, $NAME:ident) => { + #[allow(rustdoc::redundant_explicit_links)] + #[doc = concat!("Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) implementation of ", stringify!($NAME))] + pub fn $NAME(config: &datafusion_common::config::ConfigOptions) -> std::sync::Arc { + std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl( + <$UDF>::new_with_config(&config), + )) + } + }; +} + /// Macro creates a sub module if the feature is not enabled /// /// The rationale for providing stub functions is to help users to configure datafusion diff --git a/datafusion/functions/src/planner.rs b/datafusion/functions/src/planner.rs index 7228cdc07e72..ccd167997003 100644 --- a/datafusion/functions/src/planner.rs +++ b/datafusion/functions/src/planner.rs @@ -25,7 +25,7 @@ use datafusion_expr::{ }; #[deprecated( - since = "0.50.0", + since = "50.0.0", note = "Use UnicodeFunctionPlanner and DateTimeFunctionPlanner instead" )] #[derive(Default, Debug)] diff --git a/datafusion/macros/Cargo.toml b/datafusion/macros/Cargo.toml index fe979720bc56..64781ddeaf42 100644 --- a/datafusion/macros/Cargo.toml +++ b/datafusion/macros/Cargo.toml @@ -43,4 +43,4 @@ proc-macro = true [dependencies] datafusion-doc = { workspace = true } quote = "1.0.41" -syn = { version = "2.0.106", features = ["full"] } +syn = { version = "2.0.108", features = ["full"] } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 3d5dee3a7255..4fb0f8553b4b 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -2117,7 +2117,7 @@ mod test { assert_analyzed_plan_eq!( plan, @r#" - Projection: CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None)) = CAST(CAST(Utf8("1998-03-18") AS Date32) AS Timestamp(Nanosecond, None)) + Projection: CAST(Utf8("1998-03-18") AS Timestamp(ns)) = CAST(CAST(Utf8("1998-03-18") AS Date32) AS Timestamp(ns)) EmptyRelation: rows=0 "# ) @@ -2258,7 +2258,7 @@ mod test { let err = coerce_case_expression(case, &schema).unwrap_err(); assert_snapshot!( err.strip_backtrace(), - @"Error during planning: Failed to coerce then (Date32, Float32, Binary) and else (Timestamp(Nanosecond, None)) to common types in CASE WHEN expression" + @"Error during planning: Failed to coerce then (Date32, Float32, Binary) and else (Timestamp(ns)) to common types in CASE WHEN expression" ); Ok(()) @@ -2465,7 +2465,7 @@ mod test { assert_analyzed_plan_eq!( plan, @r#" - Projection: a = CAST(CAST(a AS Map(Field { name: "key_value", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false)) AS Map(Field { name: "entries", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false)) + Projection: a = CAST(CAST(a AS Map("key_value": Struct("key": Utf8, "value": nullable Float64), unsorted)) AS Map("entries": Struct("key": Utf8, "value": nullable Float64), unsorted)) EmptyRelation: rows=0 "# ) @@ -2488,7 +2488,7 @@ mod test { assert_analyzed_plan_eq!( plan, @r#" - Projection: 
IntervalYearMonth("12") + CAST(Utf8("2000-01-01T00:00:00") AS Timestamp(Nanosecond, None)) + Projection: IntervalYearMonth("12") + CAST(Utf8("2000-01-01T00:00:00") AS Timestamp(ns)) EmptyRelation: rows=0 "# ) @@ -2513,7 +2513,7 @@ mod test { assert_analyzed_plan_eq!( plan, @r#" - Projection: CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None)) - CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None)) + Projection: CAST(Utf8("1998-03-18") AS Timestamp(ns)) - CAST(Utf8("1998-03-18") AS Timestamp(ns)) EmptyRelation: rows=0 "# ) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index ec1f8f991a8e..251006849459 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -652,10 +652,8 @@ impl CSEController for ExprCSEController<'_> { // In case of `ScalarFunction`s we don't know which children are surely // executed so start visiting all children conditionally and stop the // recursion with `TreeNodeRecursion::Jump`. - Expr::ScalarFunction(ScalarFunction { func, args }) - if func.short_circuits() => - { - Some((vec![], args.iter().collect())) + Expr::ScalarFunction(ScalarFunction { func, args }) => { + func.conditional_arguments(args) } // In case of `And` and `Or` the first child is surely executed, but we diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index c8be689fc5a4..ccf90f91e68f 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -1972,14 +1972,14 @@ mod tests { assert_optimized_plan_equal!( plan, - @r#" + @r" Projection: test.b [b:UInt32] LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [arr:Int32;N] Unnest: lists[sq.arr|depth=1] structs[] [arr:Int32;N] - TableScan: sq [arr:List(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N] - "# + TableScan: sq [arr:List(Field { data_type: Int32, nullable: true });N] + " ) } @@ -2007,14 +2007,14 @@ mod tests { assert_optimized_plan_equal!( plan, - @r#" + @r" Projection: test.b [b:UInt32] LeftSemi Join: Filter: __correlated_sq_1.a = test.b [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [a:UInt32;N] Unnest: lists[sq.a|depth=1] structs[] [a:UInt32;N] - TableScan: sq [a:List(Field { name: "item", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N] - "# + TableScan: sq [a:List(Field { data_type: UInt32, nullable: true });N] + " ) } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 2383787fa0e8..215f5e240d5d 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -25,7 +25,7 @@ use datafusion_common::tree_node::Transformed; use datafusion_common::{Column, Result}; use datafusion_expr::expr_rewriter::normalize_cols; use datafusion_expr::utils::expand_wildcard; -use datafusion_expr::{col, ExprFunctionExt, LogicalPlanBuilder}; +use datafusion_expr::{col, lit, ExprFunctionExt, Limit, LogicalPlanBuilder}; use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; /// Optimizer that replaces logical [[Distinct]] with a 
logical [[Aggregate]] @@ -54,6 +54,17 @@ use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; /// ) /// ORDER BY a DESC /// ``` +/// +/// In case there are no columns, the [[Distinct]] is replaced by a [[Limit]] +/// +/// ```text +/// SELECT DISTINCT * FROM empty_table +/// ``` +/// +/// Into +/// ```text +/// SELECT * FROM empty_table LIMIT 1 +/// ``` #[derive(Default, Debug)] pub struct ReplaceDistinctWithAggregate {} @@ -78,6 +89,16 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { LogicalPlan::Distinct(Distinct::All(input)) => { let group_expr = expand_wildcard(input.schema(), &input, None)?; + if group_expr.is_empty() { + // Special case: there are no columns to group by, so we can't replace it by a group by + // however, we can replace it by LIMIT 1 because there is either no output or a single empty row + return Ok(Transformed::yes(LogicalPlan::Limit(Limit { + skip: None, + fetch: Some(Box::new(lit(1i64))), + input, + }))); + } + let field_count = input.schema().fields().len(); for dep in input.schema().functional_dependencies().iter() { // If distinct is exactly the same with a previous GROUP BY, we can @@ -184,15 +205,17 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { #[cfg(test)] mod tests { - use std::sync::Arc; - use crate::assert_optimized_plan_eq_snapshot; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; use crate::test::*; + use arrow::datatypes::{Fields, Schema}; + use std::sync::Arc; use crate::OptimizerContext; use datafusion_common::Result; - use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder, Expr}; + use datafusion_expr::{ + col, logical_plan::builder::LogicalPlanBuilder, table_scan, Expr, + }; use datafusion_functions_aggregate::sum::sum; macro_rules! assert_optimized_plan_equal { @@ -274,4 +297,16 @@ mod tests { TableScan: test ") } + + #[test] + fn use_limit_1_when_no_columns() -> Result<()> { + let plan = table_scan(Some("test"), &Schema::new(Fields::empty()), None)? + .distinct()? 
+ .build()?; + + assert_optimized_plan_equal!(plan, @r" + Limit: skip=0, fetch=1 + TableScan: test + ") + } } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index c40906239073..204ce14e37d8 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -29,6 +29,7 @@ use std::sync::Arc; use datafusion_common::{ cast::{as_large_list_array, as_list_array}, + metadata::FieldMetadata, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{ @@ -57,7 +58,6 @@ use crate::simplify_expressions::unwrap_cast::{ unwrap_cast_in_comparison_for_binary, }; use crate::simplify_expressions::SimplifyInfo; -use datafusion_expr::expr::FieldMetadata; use datafusion_expr_common::casts::try_cast_literal_to_type; use indexmap::IndexSet; use regex::Regex; diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs b/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs index 131404e60706..e811ce731310 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs @@ -194,7 +194,7 @@ fn find_most_restrictive_predicate( let mut best_value: Option<&ScalarValue> = None; for (idx, pred) in predicates.iter().enumerate() { - if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = pred { + if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = pred { // Extract the literal value based on which side has it let scalar_value = match (right.as_literal(), left.as_literal()) { (Some(scalar), _) => Some(scalar), @@ -207,8 +207,12 @@ fn find_most_restrictive_predicate( let comparison = scalar.try_cmp(current_best)?; let is_better = if find_greater { comparison == std::cmp::Ordering::Greater + || (comparison == std::cmp::Ordering::Equal + && op == &Operator::Gt) } else { comparison == std::cmp::Ordering::Less + || (comparison == std::cmp::Ordering::Equal + && op == &Operator::Lt) }; if is_better { diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index 5a88604716d2..e52aeb1aee12 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -15,110 +15,506 @@ // specific language governing permissions and limitations // under the License. 
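For the `find_most_restrictive_predicate` change above: when two candidate bounds compare equal, the strict operator now wins the tie, since `x > 5` filters more rows than `x >= 5`. A small self-contained sketch of just that tie-break logic; the `is_better` helper mirrors the updated condition and is illustrative, not a DataFusion API:

```rust
use std::cmp::Ordering;

use datafusion_expr::Operator;

// A candidate bound replaces the current best if it is strictly tighter,
// or if it is equal but uses the strict comparison operator.
fn is_better(comparison: Ordering, op: &Operator, find_greater: bool) -> bool {
    if find_greater {
        comparison == Ordering::Greater
            || (comparison == Ordering::Equal && op == &Operator::Gt)
    } else {
        comparison == Ordering::Less
            || (comparison == Ordering::Equal && op == &Operator::Lt)
    }
}

fn main() {
    // Choosing the tightest lower bound among `x >= 5` and `x > 5`:
    // the literals compare equal, so the strict `>` is preferred.
    assert!(is_better(Ordering::Equal, &Operator::Gt, true));
    assert!(!is_better(Ordering::Equal, &Operator::GtEq, true));
    // For the tightest upper bound, `<` beats `<=` on a tie.
    assert!(is_better(Ordering::Equal, &Operator::Lt, false));
}
```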
-use arrow::array::builder::{Int32Builder, StringBuilder}; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::array::{Array, ArrayRef, Int32Array, Int32Builder, StringArray}; +use arrow::datatypes::{ArrowNativeTypeOp, Field, Schema}; use arrow::record_batch::RecordBatch; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_common::ScalarValue; +use arrow::util::test_util::seedable_rng; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; use datafusion_expr::Operator; -use datafusion_physical_expr::expressions::{BinaryExpr, CaseExpr, Column, Literal}; +use datafusion_physical_expr::expressions::{case, col, lit, BinaryExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use itertools::Itertools; +use rand::distr::uniform::SampleUniform; +use rand::distr::Alphanumeric; +use rand::rngs::StdRng; +use rand::{Rng, RngCore}; +use std::fmt::{Display, Formatter}; +use std::ops::Range; use std::sync::Arc; -fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> { - Arc::new(Column::new(name, index)) +fn make_x_cmp_y( + x: &Arc<dyn PhysicalExpr>, + op: Operator, + y: i32, +) -> Arc<dyn PhysicalExpr> { + Arc::new(BinaryExpr::new(Arc::clone(x), op, lit(y))) } -fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> { - Arc::new(Literal::new(ScalarValue::Int32(Some(n)))) -} +/// Create a record batch with the given number of rows and columns. +/// Columns are named `c<i>` where `i` is the column index. +/// +/// The minimum value for `column_count` is `3`. +/// `c1` contains incrementing int32 values +/// `c2` contains int32 values in blocks of 1000 that increment by 1000 +/// `c3` contains int32 values with one null inserted every 9 rows +/// `c4` to `cn`, if present, contain unspecified int32 values +fn make_batch(row_count: usize, column_count: usize) -> RecordBatch { + assert!(column_count >= 3); + + let mut c2 = Int32Builder::new(); + let mut c3 = Int32Builder::new(); + for i in 0..row_count { + c2.append_value(i as i32 / 1000 * 1000); -fn criterion_benchmark(c: &mut Criterion) { - // create input data - let mut c1 = Int32Builder::new(); - let mut c2 = StringBuilder::new(); - let mut c3 = StringBuilder::new(); - for i in 0..1000 { - c1.append_value(i); - if i % 7 == 0 { - c2.append_null(); - } else { - c2.append_value(format!("string {i}")); - } if i % 9 == 0 { c3.append_null(); } else { - c3.append_value(format!("other string {i}")); + c3.append_value(i as i32); } } - let c1 = Arc::new(c1.finish()); + let c1 = Arc::new(Int32Array::from_iter_values(0..row_count as i32)); let c2 = Arc::new(c2.finish()); let c3 = Arc::new(c3.finish()); - let schema = Schema::new(vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Utf8, true), - Field::new("c3", DataType::Utf8, true), - ]); - let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2, c3]).unwrap(); - - // use same predicate for all benchmarks - let predicate = Arc::new(BinaryExpr::new( - make_col("c1", 0), - Operator::LtEq, - make_lit_i32(500), - )); - - // CASE WHEN c1 <= 500 THEN 1 ELSE 0 END - c.bench_function("case_when: scalar or scalar", |b| { + let mut columns: Vec<ArrayRef> = vec![c1, c2, c3]; + for _ in 3..column_count { + columns.push(Arc::new(Int32Array::from_iter_values(0..row_count as i32))); + } + + let fields = columns + .iter() + .enumerate() + .map(|(i, c)| { + Field::new( + format!("c{}", i + 1), + c.data_type().clone(), + c.is_nullable(), + ) + }) + .collect::<Vec<_>>(); + + let schema = Arc::new(Schema::new(fields)); + RecordBatch::try_new(Arc::clone(&schema), columns).unwrap() +} + +fn
criterion_benchmark(c: &mut Criterion) { + run_benchmarks(c, &make_batch(8192, 3)); + run_benchmarks(c, &make_batch(8192, 50)); + run_benchmarks(c, &make_batch(8192, 100)); + + benchmark_lookup_table_case_when(c, 8192); +} + +fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) { + let c1 = col("c1", &batch.schema()).unwrap(); + let c2 = col("c2", &batch.schema()).unwrap(); + let c3 = col("c3", &batch.schema()).unwrap(); + + // No expression, when/then/else, literal values + c.bench_function( + format!( + "case_when {}x{}: CASE WHEN c1 <= 500 THEN 1 ELSE 0 END", + batch.num_rows(), + batch.num_columns() + ) + .as_str(), + |b| { + let expr = Arc::new( + case( + None, + vec![(make_x_cmp_y(&c1, Operator::LtEq, 500), lit(1))], + Some(lit(0)), + ) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap())) + }, + ); + + // No expression, when/then/else, column reference values + c.bench_function( + format!( + "case_when {}x{}: CASE WHEN c1 <= 500 THEN c2 ELSE c3 END", + batch.num_rows(), + batch.num_columns() + ) + .as_str(), + |b| { + let expr = Arc::new( + case( + None, + vec![(make_x_cmp_y(&c1, Operator::LtEq, 500), Arc::clone(&c2))], + Some(Arc::clone(&c3)), + ) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap())) + }, + ); + + // No expression, when/then, implicit else + c.bench_function( + format!( + "case_when {}x{}: CASE WHEN c1 <= 500 THEN c2 [ELSE NULL] END", + batch.num_rows(), + batch.num_columns() + ) + .as_str(), + |b| { + let expr = Arc::new( + case( + None, + vec![(make_x_cmp_y(&c1, Operator::LtEq, 500), Arc::clone(&c2))], + None, + ) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap())) + }, + ); + + // With expression, two when/then branches + c.bench_function( + format!( + "case_when {}x{}: CASE c1 WHEN 1 THEN c2 WHEN 2 THEN c3 END", + batch.num_rows(), + batch.num_columns() + ) + .as_str(), + |b| { + let expr = Arc::new( + case( + Some(Arc::clone(&c1)), + vec![(lit(1), Arc::clone(&c2)), (lit(2), Arc::clone(&c3))], + None, + ) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap())) + }, + ); + + // Many when/then branches where all are effectively reachable + c.bench_function(format!("case_when {}x{}: CASE WHEN c1 == 0 THEN 0 WHEN c1 == 1 THEN 1 ... WHEN c1 == n THEN n ELSE n + 1 END", batch.num_rows(), batch.num_columns()).as_str(), |b| { + let when_thens = (0..batch.num_rows() as i32).map(|i| (make_x_cmp_y(&c1, Operator::Eq, i), lit(i))).collect(); let expr = Arc::new( - CaseExpr::try_new( + case( None, - vec![(predicate.clone(), make_lit_i32(1))], - Some(make_lit_i32(0)), + when_thens, + Some(lit(batch.num_rows() as i32)) ) - .unwrap(), + .unwrap(), ); - b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap())) }); - // CASE WHEN c1 <= 500 THEN c2 [ELSE NULL] END - c.bench_function("case_when: column or null", |b| { + // Many when/then branches where all but the first few are effectively unreachable + c.bench_function(format!("case_when {}x{}: CASE WHEN c1 < 0 THEN 0 WHEN c1 < 1000 THEN 1 ... 
WHEN c1 < n * 1000 THEN n ELSE n + 1 END", batch.num_rows(), batch.num_columns()).as_str(), |b| { + let when_thens = (0..batch.num_rows() as i32).map(|i| (make_x_cmp_y(&c1, Operator::Lt, i * 1000), lit(i))).collect(); let expr = Arc::new( - CaseExpr::try_new(None, vec![(predicate.clone(), make_col("c2", 1))], None) + case( + None, + when_thens, + Some(lit(batch.num_rows() as i32)) + ) .unwrap(), ); - b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap())) }); - // CASE WHEN c1 <= 500 THEN c2 ELSE c3 END - c.bench_function("case_when: expr or expr", |b| { + // Many when/then branches where all are effectively reachable + c.bench_function(format!("case_when {}x{}: CASE c1 WHEN 0 THEN 0 WHEN 1 THEN 1 ... WHEN n THEN n ELSE n + 1 END", batch.num_rows(), batch.num_columns()).as_str(), |b| { + let when_thens = (0..batch.num_rows() as i32).map(|i| (lit(i), lit(i))).collect(); let expr = Arc::new( - CaseExpr::try_new( - None, - vec![(predicate.clone(), make_col("c2", 1))], - Some(make_col("c3", 2)), + case( + Some(Arc::clone(&c1)), + when_thens, + Some(lit(batch.num_rows() as i32)) ) - .unwrap(), + .unwrap(), ); - b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap())) }); - // CASE c1 WHEN 1 THEN c2 WHEN 2 THEN c3 END - c.bench_function("case_when: CASE expr", |b| { + // Many when/then branches where all but the first few are effectively unreachable + c.bench_function(format!("case_when {}x{}: CASE c2 WHEN 0 THEN 0 WHEN 1000 THEN 1 ... WHEN n * 1000 THEN n ELSE n + 1 END", batch.num_rows(), batch.num_columns()).as_str(), |b| { + let when_thens = (0..batch.num_rows() as i32).map(|i| (lit(i * 1000), lit(i))).collect(); let expr = Arc::new( - CaseExpr::try_new( - Some(make_col("c1", 0)), - vec![ - (make_lit_i32(1), make_col("c2", 1)), - (make_lit_i32(2), make_col("c3", 2)), - ], - None, + case( + Some(Arc::clone(&c2)), + when_thens, + Some(lit(batch.num_rows() as i32)) ) - .unwrap(), + .unwrap(), ); - b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap())) }); } +struct Options { + number_of_rows: usize, + range_of_values: Vec, + in_range_probability: f32, + null_probability: f32, +} + +fn generate_other_primitive_value( + rng: &mut impl RngCore, + exclude: &[T], +) -> T { + let mut value; + let retry_limit = 100; + for _ in 0..retry_limit { + value = rng.random_range(T::MIN_TOTAL_ORDER..=T::MAX_TOTAL_ORDER); + if !exclude.contains(&value) { + return value; + } + } + + panic!("Could not generate out of range value after {retry_limit} attempts"); +} + +fn create_random_string_generator( + length: Range, +) -> impl Fn(&mut dyn RngCore, &[String]) -> String { + assert!(length.end > length.start); + + move |rng, exclude| { + let retry_limit = 100; + for _ in 0..retry_limit { + let length = rng.random_range(length.clone()); + let value: String = rng + .sample_iter(Alphanumeric) + .take(length) + .map(char::from) + .collect(); + + if !exclude.contains(&value) { + return value; + } + } + + panic!("Could not generate out of range value after {retry_limit} attempts"); + } +} + +/// Create column with the provided number of rows +/// `in_range_percentage` is the percentage of values that should be inside the specified range +/// `null_percentage` is the percentage of null values +/// The rest of the values will be outside the specified range +fn generate_values_for_lookup( + options: Options, + 
generate_other_value: impl Fn(&mut StdRng, &[T]) -> T, +) -> A +where + T: Clone, + A: FromIterator>, +{ + // Create a value with specified range most of the time, but also some nulls and the rest is generic + + assert!( + options.in_range_probability + options.null_probability <= 1.0, + "Percentages must sum to 1.0 or less" + ); + + let rng = &mut seedable_rng(); + + let in_range_probability = 0.0..options.in_range_probability; + let null_range_probability = + in_range_probability.start..in_range_probability.start + options.null_probability; + let out_range_probability = null_range_probability.end..1.0; + + (0..options.number_of_rows) + .map(|_| { + let roll: f32 = rng.random(); + + match roll { + v if out_range_probability.contains(&v) => { + let index = rng.random_range(0..options.range_of_values.len()); + // Generate value in range + Some(options.range_of_values[index].clone()) + } + v if null_range_probability.contains(&v) => None, + _ => { + // Generate value out of range + Some(generate_other_value(rng, &options.range_of_values)) + } + } + }) + .collect::() +} + +fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) { + #[derive(Clone, Copy, Debug)] + struct CaseWhenLookupInput { + batch_size: usize, + + in_range_probability: f32, + null_probability: f32, + } + + impl Display for CaseWhenLookupInput { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "case_when {} rows: in_range: {}, nulls: {}", + self.batch_size, self.in_range_probability, self.null_probability, + ) + } + } + + let mut case_when_lookup = c.benchmark_group("lookup_table_case_when"); + + for in_range_probability in [0.1, 0.5, 0.9, 1.0] { + for null_probability in [0.0, 0.1, 0.5] { + if in_range_probability + null_probability > 1.0 { + continue; + } + + let input = CaseWhenLookupInput { + batch_size, + in_range_probability, + null_probability, + }; + + let when_thens_primitive_to_string = vec![ + (1, "something"), + (2, "very"), + (3, "interesting"), + (4, "is"), + (5, "going"), + (6, "to"), + (7, "happen"), + (30, "in"), + (31, "datafusion"), + (90, "when"), + (91, "you"), + (92, "find"), + (93, "it"), + (120, "let"), + (240, "me"), + (241, "know"), + (244, "please"), + (246, "thank"), + (250, "you"), + (252, "!"), + ]; + let when_thens_string_to_primitive = when_thens_primitive_to_string + .iter() + .map(|&(key, value)| (value, key)) + .collect_vec(); + + for num_entries in [5, 10, 20] { + for (name, values_range) in [ + ("all equally true", 0..num_entries), + // Test when early termination is beneficial + ("only first 2 are true", 0..2), + ] { + let when_thens_primitive_to_string = + when_thens_primitive_to_string[values_range.clone()].to_vec(); + + let when_thens_string_to_primitive = + when_thens_string_to_primitive[values_range].to_vec(); + + case_when_lookup.bench_with_input( + BenchmarkId::new( + format!( + "case when i32 -> utf8, {num_entries} entries, {name}" + ), + input, + ), + &input, + |b, input| { + let array: Int32Array = generate_values_for_lookup( + Options:: { + number_of_rows: batch_size, + range_of_values: when_thens_primitive_to_string + .iter() + .map(|(key, _)| *key) + .collect(), + in_range_probability: input.in_range_probability, + null_probability: input.null_probability, + }, + |rng, exclude| { + generate_other_primitive_value::(rng, exclude) + }, + ); + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "col1", + array.data_type().clone(), + true, + )])), + vec![Arc::new(array)], + ) + .unwrap(); + + let when_thens = 
when_thens_primitive_to_string + .iter() + .map(|&(key, value)| (lit(key), lit(value))) + .collect(); + + let expr = Arc::new( + case( + Some(col("col1", batch.schema_ref()).unwrap()), + when_thens, + Some(lit("whatever")), + ) + .unwrap(), + ); + + b.iter(|| { + black_box(expr.evaluate(black_box(&batch)).unwrap()) + }) + }, + ); + + case_when_lookup.bench_with_input( + BenchmarkId::new( + format!( + "case when utf8 -> i32, {num_entries} entries, {name}" + ), + input, + ), + &input, + |b, input| { + let array: StringArray = generate_values_for_lookup( + Options:: { + number_of_rows: batch_size, + range_of_values: when_thens_string_to_primitive + .iter() + .map(|(key, _)| (*key).to_string()) + .collect(), + in_range_probability: input.in_range_probability, + null_probability: input.null_probability, + }, + |rng, exclude| { + create_random_string_generator(3..10)(rng, exclude) + }, + ); + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "col1", + array.data_type().clone(), + true, + )])), + vec![Arc::new(array)], + ) + .unwrap(); + + let when_thens = when_thens_string_to_primitive + .iter() + .map(|&(key, value)| (lit(key), lit(value))) + .collect(); + + let expr = Arc::new( + case( + Some(col("col1", batch.schema_ref()).unwrap()), + when_thens, + Some(lit(1000)), + ) + .unwrap(), + ); + + b.iter(|| { + black_box(expr.evaluate(black_box(&batch)).unwrap()) + }) + }, + ); + } + } + } + } +} + criterion_group!(benches, criterion_benchmark); criterion_main!(benches); diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 66ce77ef415e..5b64884f65bb 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -20,10 +20,10 @@ use std::ops::Deref; use std::sync::Arc; use std::vec::IntoIter; -use super::projection::ProjectionTargets; use super::ProjectionMapping; use crate::expressions::Literal; use crate::physical_expr::add_offset_to_expr; +use crate::projection::ProjectionTargets; use crate::{PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index bcc6835e2f6c..a7289103806b 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -25,12 +25,13 @@ use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; mod class; mod ordering; -mod projection; mod properties; pub use class::{AcrossPartitions, ConstExpr, EquivalenceClass, EquivalenceGroup}; pub use ordering::OrderingEquivalenceClass; -pub use projection::{project_ordering, project_orderings, ProjectionMapping}; +// Re-export for backwards compatibility, we recommend importing from +// datafusion_physical_expr::projection instead +pub use crate::projection::{project_ordering, project_orderings, ProjectionMapping}; pub use properties::{ calculate_union, join_equivalence_properties, EquivalenceProperties, }; @@ -61,7 +62,7 @@ mod tests { use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use datafusion_common::{plan_err, Result}; + use datafusion_common::Result; use datafusion_physical_expr_common::sort_expr::PhysicalSortRequirement; /// Converts a string to a physical sort expression @@ -95,31 +96,6 @@ mod tests { sort_expr } - pub fn output_schema( - mapping: 
&ProjectionMapping, - input_schema: &Arc, - ) -> Result { - // Calculate output schema: - let mut fields = vec![]; - for (source, targets) in mapping.iter() { - let data_type = source.data_type(input_schema)?; - let nullable = source.nullable(input_schema)?; - for (target, _) in targets.iter() { - let Some(column) = target.as_any().downcast_ref::() else { - return plan_err!("Expects to have column"); - }; - fields.push(Field::new(column.name(), data_type.clone(), nullable)); - } - } - - let output_schema = Arc::new(Schema::new_with_metadata( - fields, - input_schema.metadata().clone(), - )); - - Ok(output_schema) - } - // Generate a schema which consists of 8 columns (a, b, c, d, e, f, g, h) pub fn create_test_schema() -> Result { let a = Field::new("a", DataType::Int32, true); diff --git a/datafusion/physical-expr/src/equivalence/properties/dependency.rs b/datafusion/physical-expr/src/equivalence/properties/dependency.rs index 26d5d32c6512..8945d18be430 100644 --- a/datafusion/physical-expr/src/equivalence/properties/dependency.rs +++ b/datafusion/physical-expr/src/equivalence/properties/dependency.rs @@ -387,11 +387,11 @@ mod tests { use super::*; use crate::equivalence::tests::{ - convert_to_sort_reqs, create_test_params, create_test_schema, output_schema, - parse_sort_expr, + convert_to_sort_reqs, create_test_params, create_test_schema, parse_sort_expr, }; use crate::equivalence::{convert_to_sort_exprs, ProjectionMapping}; use crate::expressions::{col, BinaryExpr, CastExpr, Column}; + use crate::projection::tests::output_schema; use crate::{ConstExpr, EquivalenceProperties, ScalarFunctionExpr}; use arrow::compute::SortOptions; diff --git a/datafusion/physical-expr/src/expressions/binary/kernels.rs b/datafusion/physical-expr/src/expressions/binary/kernels.rs index 71d1242eea85..36ecd1c81619 100644 --- a/datafusion/physical-expr/src/expressions/binary/kernels.rs +++ b/datafusion/physical-expr/src/expressions/binary/kernels.rs @@ -145,12 +145,14 @@ pub fn concat_elements_utf8view( left: &StringViewArray, right: &StringViewArray, ) -> std::result::Result { - let capacity = left - .data_buffers() - .iter() - .zip(right.data_buffers().iter()) - .map(|(b1, b2)| b1.len() + b2.len()) - .sum(); + if left.len() != right.len() { + return Err(ArrowError::ComputeError(format!( + "Arrays must have the same length: {} != {}", + left.len(), + right.len() + ))); + } + let capacity = left.len(); let mut result = StringViewBuilder::with_capacity(capacity); // Avoid reallocations by writing to a reused buffer (note we diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index d14146a20d8b..2db599047bcd 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -205,10 +205,15 @@ impl CaseExpr { let mut current_value = new_null_array(&return_type, batch.num_rows()); // We only consider non-null values while comparing with whens let mut remainder = not(&base_nulls)?; + let mut non_null_remainder_count = remainder.true_count(); for i in 0..self.when_then_expr.len() { - let when_value = self.when_then_expr[i] - .0 - .evaluate_selection(batch, &remainder)?; + // If there are no rows left to process, break out of the loop early + if non_null_remainder_count == 0 { + break; + } + + let when_predicate = &self.when_then_expr[i].0; + let when_value = when_predicate.evaluate_selection(batch, &remainder)?; let when_value = when_value.into_array(batch.num_rows())?; // build boolean array representing 
which rows match the "when" value let when_match = compare_with_eq( @@ -224,41 +229,46 @@ impl CaseExpr { _ => Cow::Owned(prep_null_mask_filter(&when_match)), }; // Make sure we only consider rows that have not been matched yet - let when_match = and(&when_match, &remainder)?; + let when_value = and(&when_match, &remainder)?; - // When no rows available for when clause, skip then clause - if when_match.true_count() == 0 { + // If the predicate did not match any rows, continue to the next branch immediately + let when_match_count = when_value.true_count(); + if when_match_count == 0 { continue; } - let then_value = self.when_then_expr[i] - .1 - .evaluate_selection(batch, &when_match)?; + let then_expression = &self.when_then_expr[i].1; + let then_value = then_expression.evaluate_selection(batch, &when_value)?; current_value = match then_value { ColumnarValue::Scalar(ScalarValue::Null) => { - nullif(current_value.as_ref(), &when_match)? + nullif(current_value.as_ref(), &when_value)? } ColumnarValue::Scalar(then_value) => { - zip(&when_match, &then_value.to_scalar()?, ¤t_value)? + zip(&when_value, &then_value.to_scalar()?, ¤t_value)? } ColumnarValue::Array(then_value) => { - zip(&when_match, &then_value, ¤t_value)? + zip(&when_value, &then_value, ¤t_value)? } }; - remainder = and_not(&remainder, &when_match)?; + remainder = and_not(&remainder, &when_value)?; + non_null_remainder_count -= when_match_count; } if let Some(e) = self.else_expr() { - // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; // null and unmatched tuples should be assigned else value remainder = or(&base_nulls, &remainder)?; - let else_ = expr - .evaluate_selection(batch, &remainder)? - .into_array(batch.num_rows())?; - current_value = zip(&remainder, &else_, ¤t_value)?; + + if remainder.true_count() > 0 { + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; + + let else_ = expr + .evaluate_selection(batch, &remainder)? 
+ .into_array(batch.num_rows())?; + current_value = zip(&remainder, &else_, ¤t_value)?; + } } Ok(ColumnarValue::Array(current_value)) @@ -277,10 +287,15 @@ impl CaseExpr { // start with nulls as default output let mut current_value = new_null_array(&return_type, batch.num_rows()); let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]); + let mut remainder_count = batch.num_rows(); for i in 0..self.when_then_expr.len() { - let when_value = self.when_then_expr[i] - .0 - .evaluate_selection(batch, &remainder)?; + // If there are no rows left to process, break out of the loop early + if remainder_count == 0 { + break; + } + + let when_predicate = &self.when_then_expr[i].0; + let when_value = when_predicate.evaluate_selection(batch, &remainder)?; let when_value = when_value.into_array(batch.num_rows())?; let when_value = as_boolean_array(&when_value).map_err(|_| { internal_datafusion_err!("WHEN expression did not return a BooleanArray") @@ -293,14 +308,14 @@ impl CaseExpr { // Make sure we only consider rows that have not been matched yet let when_value = and(&when_value, &remainder)?; - // When no rows available for when clause, skip then clause - if when_value.true_count() == 0 { + // If the predicate did not match any rows, continue to the next branch immediately + let when_match_count = when_value.true_count(); + if when_match_count == 0 { continue; } - let then_value = self.when_then_expr[i] - .1 - .evaluate_selection(batch, &when_value)?; + let then_expression = &self.when_then_expr[i].1; + let then_value = then_expression.evaluate_selection(batch, &when_value)?; current_value = match then_value { ColumnarValue::Scalar(ScalarValue::Null) => { @@ -317,10 +332,11 @@ impl CaseExpr { // Succeed tuples should be filtered out for short-circuit evaluation, // null values for the current when expr should be kept remainder = and_not(&remainder, &when_value)?; + remainder_count -= when_match_count; } if let Some(e) = self.else_expr() { - if remainder.true_count() > 0 { + if remainder_count > 0 { // keep `else_expr`'s data type and return type consistent let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; let else_ = expr diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 407e3e6a9d29..0419161b532c 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -439,8 +439,8 @@ mod tests { let expression = cast_with_options(col("a", &schema)?, &schema, Decimal128(6, 2), None)?; let e = expression.evaluate(&batch).unwrap_err().strip_backtrace(); // panics on OK - assert_snapshot!(e, @"Arrow error: Invalid argument error: 12345679 is too large to store in a Decimal128 of precision 6. Max is 999999"); - + assert_snapshot!(e, @"Arrow error: Invalid argument error: 123456.79 is too large to store in a Decimal128 of precision 6. 
Max is 9999.99"); + // safe cast should return null let expression_safe = cast_with_options( col("a", &schema)?, &schema, diff --git a/datafusion/physical-expr/src/expressions/dynamic_filters.rs b/datafusion/physical-expr/src/expressions/dynamic_filters.rs index a53b32c97689..964a193db833 100644 --- a/datafusion/physical-expr/src/expressions/dynamic_filters.rs +++ b/datafusion/physical-expr/src/expressions/dynamic_filters.rs @@ -381,14 +381,14 @@ mod test { ) .unwrap(); let snap = dynamic_filter_1.snapshot().unwrap().unwrap(); - insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 0 }, op: Eq, right: Literal { value: Int32(42), field: Field { name: "lit", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }"#); + insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 0 }, op: Eq, right: Literal { value: Int32(42), field: Field { name: "lit", data_type: Int32 } }, fail_on_overflow: false }"#); let dynamic_filter_2 = reassign_expr_columns( Arc::clone(&dynamic_filter) as Arc, &filter_schema_2, ) .unwrap(); let snap = dynamic_filter_2.snapshot().unwrap().unwrap(); - insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 1 }, op: Eq, right: Literal { value: Int32(42), field: Field { name: "lit", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }"#); + insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 1 }, op: Eq, right: Literal { value: Int32(42), field: Field { name: "lit", data_type: Int32 } }, fail_on_overflow: false }"#); // Both filters allow evaluating the same expression let batch_1 = RecordBatch::try_new( Arc::clone(&filter_schema_1), diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 6e425ee439d6..94e91d43a1c4 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -28,8 +28,8 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; +use datafusion_common::metadata::FieldMetadata; use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::expr::FieldMetadata; use datafusion_expr::Expr; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 468591d34d71..aa8c9e50fd71 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -37,6 +37,7 @@ pub mod intervals; mod partitioning; mod physical_expr; pub mod planner; +pub mod projection; mod scalar_function; pub mod simplifier; pub mod statistics; diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 73df60c42e96..7790380dffd5 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -25,13 +25,12 @@ use crate::{ use arrow::datatypes::Schema; use datafusion_common::config::ConfigOptions; +use datafusion_common::metadata::FieldMetadata; use datafusion_common::{ exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, ToDFSchema, }; use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::expr::{ - Alias, Cast, FieldMetadata, InList, Placeholder, ScalarFunction, -}; +use 
datafusion_expr::expr::{Alias, Cast, InList, Placeholder, ScalarFunction}; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; use datafusion_expr::{ diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/projection.rs similarity index 51% rename from datafusion/physical-expr/src/equivalence/projection.rs rename to datafusion/physical-expr/src/projection.rs index a4ed8187cfad..e35bfbb3a20d 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/projection.rs @@ -19,14 +19,426 @@ use std::ops::Deref; use std::sync::Arc; use crate::expressions::Column; +use crate::utils::collect_columns; use crate::PhysicalExpr; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{Field, Schema, SchemaRef}; +use datafusion_common::stats::{ColumnStatistics, Precision}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{internal_err, plan_err, Result}; +use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result}; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use indexmap::IndexMap; +use itertools::Itertools; + +/// A projection expression as used by projection operations. +/// +/// The expression is evaluated and the result is stored in a column +/// with the name specified by `alias`. +/// +/// For example, the SQL expression `a + b AS sum_ab` would be represented +/// as a `ProjectionExpr` where `expr` is the expression `a + b` +/// and `alias` is the string `sum_ab`. +#[derive(Debug, Clone)] +pub struct ProjectionExpr { + /// The expression that will be evaluated. + pub expr: Arc, + /// The name of the output column for use an output schema. + pub alias: String, +} + +impl std::fmt::Display for ProjectionExpr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.expr.to_string() == self.alias { + write!(f, "{}", self.alias) + } else { + write!(f, "{} AS {}", self.expr, self.alias) + } + } +} + +impl ProjectionExpr { + /// Create a new projection expression + pub fn new(expr: Arc, alias: String) -> Self { + Self { expr, alias } + } + + /// Create a new projection expression from an expression and a schema using the expression's output field name as alias. + pub fn new_from_expression( + expr: Arc, + schema: &Schema, + ) -> Result { + let field = expr.return_field(schema)?; + Ok(Self { + expr, + alias: field.name().to_string(), + }) + } +} + +impl From<(Arc, String)> for ProjectionExpr { + fn from(value: (Arc, String)) -> Self { + Self::new(value.0, value.1) + } +} + +impl From<&(Arc, String)> for ProjectionExpr { + fn from(value: &(Arc, String)) -> Self { + Self::new(Arc::clone(&value.0), value.1.clone()) + } +} + +impl From for (Arc, String) { + fn from(value: ProjectionExpr) -> Self { + (value.expr, value.alias) + } +} + +/// A collection of projection expressions. +/// +/// This struct encapsulates multiple `ProjectionExpr` instances, +/// representing a complete projection operation and provides +/// methods to manipulate and analyze the projection as a whole. 
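To illustrate the new `ProjectionExpr` type defined above, here is a hedged sketch of constructing one and formatting it. It assumes only the API added in this diff plus `Column`'s usual `name@index` `Display` rendering:

```rust
use std::sync::Arc;

use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::projection::ProjectionExpr;
use datafusion_physical_expr::PhysicalExpr;

fn main() {
    // `a@0 AS total`: the expression is evaluated and exposed under the alias.
    let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("a", 0));
    let proj = ProjectionExpr::new(expr, "total".to_string());

    // Display renders "<expr> AS <alias>" unless the two strings already match;
    // with Column's usual `name@index` formatting this prints "a@0 AS total".
    println!("{proj}");
}
```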
+#[derive(Debug, Clone)] +pub struct Projection { + exprs: Vec, +} + +impl std::fmt::Display for Projection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let exprs: Vec = self.exprs.iter().map(|e| e.to_string()).collect(); + write!(f, "Projection[{}]", exprs.join(", ")) + } +} + +impl From> for Projection { + fn from(value: Vec) -> Self { + Self { exprs: value } + } +} + +impl From<&[ProjectionExpr]> for Projection { + fn from(value: &[ProjectionExpr]) -> Self { + Self { + exprs: value.to_vec(), + } + } +} + +impl AsRef<[ProjectionExpr]> for Projection { + fn as_ref(&self) -> &[ProjectionExpr] { + &self.exprs + } +} + +impl Projection { + pub fn new(exprs: Vec) -> Self { + Self { exprs } + } + + /// Returns an iterator over the projection expressions + pub fn iter(&self) -> impl Iterator { + self.exprs.iter() + } + + /// Creates a ProjectionMapping from this projection + pub fn projection_mapping( + &self, + input_schema: &SchemaRef, + ) -> Result { + ProjectionMapping::try_new( + self.exprs + .iter() + .map(|p| (Arc::clone(&p.expr), p.alias.clone())), + input_schema, + ) + } + + /// Iterate over a clone of the projection expressions. + pub fn expr_iter(&self) -> impl Iterator> + '_ { + self.exprs.iter().map(|e| Arc::clone(&e.expr)) + } + + /// Apply another projection on top of this projection, returning the combined projection. + /// For example, if this projection is `SELECT c@2 AS x, b@1 AS y, a@0 as z` and the other projection is `SELECT x@0 + 1 AS c1, y@1 + z@2 as c2`, + /// we return a projection equivalent to `SELECT c@2 + 1 AS c1, b@1 + a@0 as c2`. + /// + /// # Example + /// + /// ```rust + /// use std::sync::Arc; + /// use datafusion_physical_expr::projection::{Projection, ProjectionExpr}; + /// use datafusion_physical_expr::expressions::{Column, BinaryExpr, Literal}; + /// use datafusion_common::{Result, ScalarValue}; + /// use datafusion_expr::Operator; + /// + /// fn main() -> Result<()> { + /// // Example from the docstring: + /// // Base projection: SELECT c@2 AS x, b@1 AS y, a@0 AS z + /// let base = Projection::new(vec![ + /// ProjectionExpr { + /// expr: Arc::new(Column::new("c", 2)), + /// alias: "x".to_string(), + /// }, + /// ProjectionExpr { + /// expr: Arc::new(Column::new("b", 1)), + /// alias: "y".to_string(), + /// }, + /// ProjectionExpr { + /// expr: Arc::new(Column::new("a", 0)), + /// alias: "z".to_string(), + /// }, + /// ]); + /// + /// // Top projection: SELECT x@0 + 1 AS c1, y@1 + z@2 AS c2 + /// let top = Projection::new(vec![ + /// ProjectionExpr { + /// expr: Arc::new(BinaryExpr::new( + /// Arc::new(Column::new("x", 0)), + /// Operator::Plus, + /// Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + /// )), + /// alias: "c1".to_string(), + /// }, + /// ProjectionExpr { + /// expr: Arc::new(BinaryExpr::new( + /// Arc::new(Column::new("y", 1)), + /// Operator::Plus, + /// Arc::new(Column::new("z", 2)), + /// )), + /// alias: "c2".to_string(), + /// }, + /// ]); + /// + /// // Expected result: SELECT c@2 + 1 AS c1, b@1 + a@0 AS c2 + /// let result = base.try_merge(&top)?; + /// + /// assert_eq!(result.as_ref().len(), 2); + /// assert_eq!(result.as_ref()[0].alias, "c1"); + /// assert_eq!(result.as_ref()[1].alias, "c2"); + /// + /// Ok(()) + /// } + /// ``` + /// + /// # Errors + /// This function returns an error if any expression in the `other` projection cannot be + /// applied on top of this projection. 
+ pub fn try_merge(&self, other: &Projection) -> Result { + let mut new_exprs = Vec::with_capacity(other.exprs.len()); + for proj_expr in &other.exprs { + let new_expr = update_expr(&proj_expr.expr, &self.exprs, true)? + .ok_or_else(|| { + internal_datafusion_err!( + "Failed to combine projections: expression {} could not be applied on top of existing projections {}", + proj_expr.expr, + self.exprs.iter().map(|e| format!("{e}")).join(", ") + ) + })?; + new_exprs.push(ProjectionExpr { + expr: new_expr, + alias: proj_expr.alias.clone(), + }); + } + Ok(Projection::new(new_exprs)) + } + + /// Extract the column indices used in this projection. + /// For example, for a projection `SELECT a AS x, b + 1 AS y`, where `a` is at index 0 and `b` is at index 1, + /// this function would return `[0, 1]`. + /// Repeated indices are returned only once, and the order is ascending. + pub fn column_indices(&self) -> Vec { + self.exprs + .iter() + .flat_map(|e| collect_columns(&e.expr).into_iter().map(|col| col.index())) + .sorted_unstable() + .dedup() + .collect_vec() + } + + /// Project a schema according to this projection. + /// For example, for a projection `SELECT a AS x, b + 1 AS y`, where `a` is at index 0 and `b` is at index 1, + /// if the input schema is `[a: Int32, b: Int32, c: Int32]`, the output schema would be `[x: Int32, y: Int32]`. + /// Fields' metadata are preserved from the input schema. + pub fn project_schema(&self, input_schema: &Schema) -> Result { + let fields: Result> = self + .exprs + .iter() + .map(|proj_expr| { + let metadata = proj_expr + .expr + .return_field(input_schema)? + .metadata() + .clone(); + + let field = Field::new( + &proj_expr.alias, + proj_expr.expr.data_type(input_schema)?, + proj_expr.expr.nullable(input_schema)?, + ) + .with_metadata(metadata); + + Ok(field) + }) + .collect(); + + Ok(Schema::new_with_metadata( + fields?, + input_schema.metadata().clone(), + )) + } + + /// Project statistics according to this projection. + /// For example, for a projection `SELECT a AS x, b + 1 AS y`, where `a` is at index 0 and `b` is at index 1, + /// if the input statistics has column statistics for columns `a`, `b`, and `c`, the output statistics would have column statistics for columns `x` and `y`. 
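To make `column_indices` and `project_schema` (described above) concrete, here is a small sketch assuming only the API introduced in this diff; the schema and expressions are made up for illustration:

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
use datafusion_physical_expr::projection::{Projection, ProjectionExpr};

fn main() -> Result<()> {
    // Input schema: [a: Int32, b: Int32, c: Int32]
    let input = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
        Field::new("c", DataType::Int32, false),
    ]);

    // SELECT a AS x, b + 1 AS y
    let projection = Projection::new(vec![
        ProjectionExpr::new(Arc::new(Column::new("a", 0)), "x".to_string()),
        ProjectionExpr::new(
            Arc::new(BinaryExpr::new(
                Arc::new(Column::new("b", 1)),
                Operator::Plus,
                Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
            )),
            "y".to_string(),
        ),
    ]);

    // Only columns 0 (`a`) and 1 (`b`) are referenced; `c` is pruned.
    assert_eq!(projection.column_indices(), vec![0, 1]);

    // The projected schema carries the aliases, not the input column names.
    let output = projection.project_schema(&input)?;
    assert_eq!(output.field(0).name(), "x");
    assert_eq!(output.field(1).name(), "y");
    Ok(())
}
```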
+ pub fn project_statistics( + &self, + mut stats: datafusion_common::Statistics, + input_schema: &Schema, + ) -> Result { + let mut primitive_row_size = 0; + let mut primitive_row_size_possible = true; + let mut column_statistics = vec![]; + + for proj_expr in &self.exprs { + let expr = &proj_expr.expr; + let col_stats = if let Some(col) = expr.as_any().downcast_ref::() { + stats.column_statistics[col.index()].clone() + } else { + // TODO stats: estimate more statistics from expressions + // (expressions should compute their statistics themselves) + ColumnStatistics::new_unknown() + }; + column_statistics.push(col_stats); + let data_type = expr.data_type(input_schema)?; + if let Some(value) = data_type.primitive_width() { + primitive_row_size += value; + continue; + } + primitive_row_size_possible = false; + } + + if primitive_row_size_possible { + stats.total_byte_size = + Precision::Exact(primitive_row_size).multiply(&stats.num_rows); + } + stats.column_statistics = column_statistics; + Ok(stats) + } +} + +impl<'a> IntoIterator for &'a Projection { + type Item = &'a ProjectionExpr; + type IntoIter = std::slice::Iter<'a, ProjectionExpr>; + + fn into_iter(self) -> Self::IntoIter { + self.exprs.iter() + } +} + +impl IntoIterator for Projection { + type Item = ProjectionExpr; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.exprs.into_iter() + } +} + +/// The function operates in two modes: +/// +/// 1) When `sync_with_child` is `true`: +/// +/// The function updates the indices of `expr` if the expression resides +/// in the input plan. For instance, given the expressions `a@1 + b@2` +/// and `c@0` with the input schema `c@2, a@0, b@1`, the expressions are +/// updated to `a@0 + b@1` and `c@2`. +/// +/// 2) When `sync_with_child` is `false`: +/// +/// The function determines how the expression would be updated if a projection +/// was placed before the plan associated with the expression. If the expression +/// cannot be rewritten after the projection, it returns `None`. For example, +/// given the expressions `c@0`, `a@1` and `b@2`, and the projection with +/// an output schema of `a, c_new`, then `c@0` becomes `c_new@1`, `a@1` becomes +/// `a@0`, but `b@2` results in `None` since the projection does not include `b`. +/// +/// # Errors +/// This function returns an error if `sync_with_child` is `true` and if any expression references +/// an index that is out of bounds for `projected_exprs`. +/// For example: +/// +/// - `expr` is `a@3` +/// - `projected_exprs` is \[`a@0`, `b@1`\] +/// +/// In this case, `a@3` references index 3, which is out of bounds for `projected_exprs` (which has length 2). +pub fn update_expr( + expr: &Arc, + projected_exprs: &[ProjectionExpr], + sync_with_child: bool, +) -> Result>> { + #[derive(Debug, PartialEq)] + enum RewriteState { + /// The expression is unchanged. + Unchanged, + /// Some part of the expression has been rewritten + RewrittenValid, + /// Some part of the expression has been rewritten, but some column + /// references could not be. 
+ RewrittenInvalid, + } + + let mut state = RewriteState::Unchanged; + + let new_expr = Arc::clone(expr) + .transform_up(|expr| { + if state == RewriteState::RewrittenInvalid { + return Ok(Transformed::no(expr)); + } + + let Some(column) = expr.as_any().downcast_ref::() else { + return Ok(Transformed::no(expr)); + }; + if sync_with_child { + state = RewriteState::RewrittenValid; + // Update the index of `column`: + let projected_expr = projected_exprs.get(column.index()).ok_or_else(|| { + internal_datafusion_err!( + "Column index {} out of bounds for projected expressions of length {}", + column.index(), + projected_exprs.len() + ) + })?; + Ok(Transformed::yes(Arc::clone(&projected_expr.expr))) + } else { + // default to invalid, in case we can't find the relevant column + state = RewriteState::RewrittenInvalid; + // Determine how to update `column` to accommodate `projected_exprs` + projected_exprs + .iter() + .enumerate() + .find_map(|(index, proj_expr)| { + proj_expr.expr.as_any().downcast_ref::().and_then( + |projected_column| { + (column.name().eq(projected_column.name()) + && column.index() == projected_column.index()) + .then(|| { + state = RewriteState::RewrittenValid; + Arc::new(Column::new(&proj_expr.alias, index)) as _ + }) + }, + ) + }) + .map_or_else( + || Ok(Transformed::no(expr)), + |c| Ok(Transformed::yes(c)), + ) + } + }) + .data()?; + + Ok((state == RewriteState::RewrittenValid).then_some(new_expr)) +} /// Stores target expressions, along with their indices, that associate with a /// source expression in a projection mapping. @@ -249,18 +661,46 @@ pub fn project_ordering( } #[cfg(test)] -mod tests { +pub(crate) mod tests { + use std::collections::HashMap; + use super::*; - use crate::equivalence::tests::output_schema; use crate::equivalence::{convert_to_orderings, EquivalenceProperties}; - use crate::expressions::{col, BinaryExpr}; + use crate::expressions::{col, BinaryExpr, Literal}; use crate::utils::tests::TestScalarUDF; use crate::{PhysicalExprRef, ScalarFunctionExpr}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion_common::config::ConfigOptions; + use datafusion_common::{ScalarValue, Statistics}; use datafusion_expr::{Operator, ScalarUDF}; + use insta::assert_snapshot; + + pub(crate) fn output_schema( + mapping: &ProjectionMapping, + input_schema: &Arc, + ) -> Result { + // Calculate output schema: + let mut fields = vec![]; + for (source, targets) in mapping.iter() { + let data_type = source.data_type(input_schema)?; + let nullable = source.nullable(input_schema)?; + for (target, _) in targets.iter() { + let Some(column) = target.as_any().downcast_ref::() else { + return plan_err!("Expects to have column"); + }; + fields.push(Field::new(column.name(), data_type.clone(), nullable)); + } + } + + let output_schema = Arc::new(Schema::new_with_metadata( + fields, + input_schema.metadata().clone(), + )); + + Ok(output_schema) + } #[test] fn project_orderings() -> Result<()> { @@ -1087,4 +1527,628 @@ mod tests { Ok(()) } + + fn get_stats() -> Statistics { + Statistics { + num_rows: Precision::Exact(5), + total_byte_size: Precision::Exact(23), + column_statistics: vec![ + ColumnStatistics { + distinct_count: Precision::Exact(5), + max_value: Precision::Exact(ScalarValue::Int64(Some(21))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), + null_count: Precision::Exact(0), + }, + ColumnStatistics { + distinct_count: Precision::Exact(1), + 
max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), + sum_value: Precision::Absent, + null_count: Precision::Exact(3), + }, + ColumnStatistics { + distinct_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))), + min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))), + sum_value: Precision::Exact(ScalarValue::Float32(Some(5.5))), + null_count: Precision::Absent, + }, + ], + } + } + + fn get_schema() -> Schema { + let field_0 = Field::new("col0", DataType::Int64, false); + let field_1 = Field::new("col1", DataType::Utf8, false); + let field_2 = Field::new("col2", DataType::Float32, false); + Schema::new(vec![field_0, field_1, field_2]) + } + + #[test] + fn test_stats_projection_columns_only() { + let source = get_stats(); + let schema = get_schema(); + + let projection = Projection::new(vec![ + ProjectionExpr { + expr: Arc::new(Column::new("col1", 1)), + alias: "col1".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("col0", 0)), + alias: "col0".to_string(), + }, + ]); + + let result = projection.project_statistics(source, &schema).unwrap(); + + let expected = Statistics { + num_rows: Precision::Exact(5), + total_byte_size: Precision::Exact(23), + column_statistics: vec![ + ColumnStatistics { + distinct_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), + sum_value: Precision::Absent, + null_count: Precision::Exact(3), + }, + ColumnStatistics { + distinct_count: Precision::Exact(5), + max_value: Precision::Exact(ScalarValue::Int64(Some(21))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), + null_count: Precision::Exact(0), + }, + ], + }; + + assert_eq!(result, expected); + } + + #[test] + fn test_stats_projection_column_with_primitive_width_only() { + let source = get_stats(); + let schema = get_schema(); + + let projection = Projection::new(vec![ + ProjectionExpr { + expr: Arc::new(Column::new("col2", 2)), + alias: "col2".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("col0", 0)), + alias: "col0".to_string(), + }, + ]); + + let result = projection.project_statistics(source, &schema).unwrap(); + + let expected = Statistics { + num_rows: Precision::Exact(5), + total_byte_size: Precision::Exact(60), + column_statistics: vec![ + ColumnStatistics { + distinct_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))), + min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))), + sum_value: Precision::Exact(ScalarValue::Float32(Some(5.5))), + null_count: Precision::Absent, + }, + ColumnStatistics { + distinct_count: Precision::Exact(5), + max_value: Precision::Exact(ScalarValue::Int64(Some(21))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), + null_count: Precision::Exact(0), + }, + ], + }; + + assert_eq!(result, expected); + } + + // Tests for Projection struct + + #[test] + fn test_projection_new() -> Result<()> { + let exprs = vec![ + ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "a".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("b", 1)), + alias: "b".to_string(), + }, + ]; + let projection = Projection::new(exprs.clone()); + assert_eq!(projection.as_ref().len(), 2); + Ok(()) + } + + #[test] + fn test_projection_from_vec() -> Result<()> { + let exprs = 
vec![ProjectionExpr { + expr: Arc::new(Column::new("x", 0)), + alias: "x".to_string(), + }]; + let projection: Projection = exprs.clone().into(); + assert_eq!(projection.as_ref().len(), 1); + Ok(()) + } + + #[test] + fn test_projection_as_ref() -> Result<()> { + let exprs = vec![ + ProjectionExpr { + expr: Arc::new(Column::new("col1", 0)), + alias: "col1".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("col2", 1)), + alias: "col2".to_string(), + }, + ]; + let projection = Projection::new(exprs); + let as_ref: &[ProjectionExpr] = projection.as_ref(); + assert_eq!(as_ref.len(), 2); + Ok(()) + } + + #[test] + fn test_column_indices_multiple_columns() -> Result<()> { + // Test with reversed column order to ensure proper reordering + let projection = Projection::new(vec![ + ProjectionExpr { + expr: Arc::new(Column::new("c", 5)), + alias: "c".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("b", 2)), + alias: "b".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "a".to_string(), + }, + ]); + // Should return sorted indices regardless of projection order + assert_eq!(projection.column_indices(), vec![0, 2, 5]); + Ok(()) + } + + #[test] + fn test_column_indices_duplicates() -> Result<()> { + // Test that duplicate column indices appear only once + let projection = Projection::new(vec![ + ProjectionExpr { + expr: Arc::new(Column::new("a", 1)), + alias: "a".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("b", 3)), + alias: "b".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("a2", 1)), // duplicate index + alias: "a2".to_string(), + }, + ]); + assert_eq!(projection.column_indices(), vec![1, 3]); + Ok(()) + } + + #[test] + fn test_column_indices_unsorted() -> Result<()> { + // Test that column indices are sorted in the output + let projection = Projection::new(vec![ + ProjectionExpr { + expr: Arc::new(Column::new("c", 5)), + alias: "c".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("a", 1)), + alias: "a".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("b", 3)), + alias: "b".to_string(), + }, + ]); + assert_eq!(projection.column_indices(), vec![1, 3, 5]); + Ok(()) + } + + #[test] + fn test_column_indices_complex_expr() -> Result<()> { + // Test with complex expressions containing multiple columns + let expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 1)), + Operator::Plus, + Arc::new(Column::new("b", 4)), + )); + let projection = Projection::new(vec![ + ProjectionExpr { + expr, + alias: "sum".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("c", 2)), + alias: "c".to_string(), + }, + ]); + // Should return [1, 2, 4] - all columns used, sorted and deduplicated + assert_eq!(projection.column_indices(), vec![1, 2, 4]); + Ok(()) + } + + #[test] + fn test_column_indices_empty() -> Result<()> { + let projection = Projection::new(vec![]); + assert_eq!(projection.column_indices(), Vec::::new()); + Ok(()) + } + + #[test] + fn test_merge_simple_columns() -> Result<()> { + // First projection: SELECT c@2 AS x, b@1 AS y, a@0 AS z + let base_projection = Projection::new(vec![ + ProjectionExpr { + expr: Arc::new(Column::new("c", 2)), + alias: "x".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("b", 1)), + alias: "y".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "z".to_string(), + }, + ]); + + // Second projection: SELECT y@1 AS col2, x@0 AS col1 + let top_projection = Projection::new(vec![ + 
ProjectionExpr { + expr: Arc::new(Column::new("y", 1)), + alias: "col2".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("x", 0)), + alias: "col1".to_string(), + }, + ]); + + // Merge should produce: SELECT b@1 AS col2, c@2 AS col1 + let merged = base_projection.try_merge(&top_projection)?; + assert_snapshot!(format!("{merged}"), @"Projection[b@1 AS col2, c@2 AS col1]"); + + Ok(()) + } + + #[test] + fn test_merge_with_expressions() -> Result<()> { + // First projection: SELECT c@2 AS x, b@1 AS y, a@0 AS z + let base_projection = Projection::new(vec![ + ProjectionExpr { + expr: Arc::new(Column::new("c", 2)), + alias: "x".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("b", 1)), + alias: "y".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "z".to_string(), + }, + ]); + + // Second projection: SELECT y@1 + z@2 AS c2, x@0 + 1 AS c1 + let top_projection = Projection::new(vec![ + ProjectionExpr { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("y", 1)), + Operator::Plus, + Arc::new(Column::new("z", 2)), + )), + alias: "c2".to_string(), + }, + ProjectionExpr { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 0)), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )), + alias: "c1".to_string(), + }, + ]); + + // Merge should produce: SELECT b@1 + a@0 AS c2, c@2 + 1 AS c1 + let merged = base_projection.try_merge(&top_projection)?; + assert_snapshot!(format!("{merged}"), @"Projection[b@1 + a@0 AS c2, c@2 + 1 AS c1]"); + + Ok(()) + } + + #[test] + fn try_merge_error() { + // Create a base projection + let base = Projection::new(vec![ + ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "x".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("b", 1)), + alias: "y".to_string(), + }, + ]); + + // Create a top projection that references a non-existent column index + let top = Projection::new(vec![ProjectionExpr { + expr: Arc::new(Column::new("z", 5)), // Invalid index + alias: "result".to_string(), + }]); + + // Attempt to merge and expect an error + let err_msg = base.try_merge(&top).unwrap_err().to_string(); + assert!( + err_msg.contains("Internal error: Column index 5 out of bounds for projected expressions of length 2"), + "Unexpected error message: {err_msg}", + ); + } + + #[test] + fn test_project_schema_simple_columns() -> Result<()> { + // Input schema: [col0: Int64, col1: Utf8, col2: Float32] + let input_schema = get_schema(); + + // Projection: SELECT col2 AS c, col0 AS a + let projection = Projection::new(vec![ + ProjectionExpr { + expr: Arc::new(Column::new("col2", 2)), + alias: "c".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("col0", 0)), + alias: "a".to_string(), + }, + ]); + + let output_schema = projection.project_schema(&input_schema)?; + + // Should have 2 fields + assert_eq!(output_schema.fields().len(), 2); + + // First field should be "c" with Float32 type + assert_eq!(output_schema.field(0).name(), "c"); + assert_eq!(output_schema.field(0).data_type(), &DataType::Float32); + + // Second field should be "a" with Int64 type + assert_eq!(output_schema.field(1).name(), "a"); + assert_eq!(output_schema.field(1).data_type(), &DataType::Int64); + + Ok(()) + } + + #[test] + fn test_project_schema_with_expressions() -> Result<()> { + // Input schema: [col0: Int64, col1: Utf8, col2: Float32] + let input_schema = get_schema(); + + // Projection: SELECT col0 + 1 AS incremented + let projection = Projection::new(vec![ProjectionExpr 
{ + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("col0", 0)), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int64(Some(1)))), + )), + alias: "incremented".to_string(), + }]); + + let output_schema = projection.project_schema(&input_schema)?; + + // Should have 1 field + assert_eq!(output_schema.fields().len(), 1); + + // Field should be "incremented" with Int64 type + assert_eq!(output_schema.field(0).name(), "incremented"); + assert_eq!(output_schema.field(0).data_type(), &DataType::Int64); + + Ok(()) + } + + #[test] + fn test_project_schema_preserves_metadata() -> Result<()> { + // Create schema with metadata + let mut metadata = HashMap::new(); + metadata.insert("key".to_string(), "value".to_string()); + let field_with_metadata = + Field::new("col0", DataType::Int64, false).with_metadata(metadata.clone()); + let input_schema = Schema::new(vec![ + field_with_metadata, + Field::new("col1", DataType::Utf8, false), + ]); + + // Projection: SELECT col0 AS renamed + let projection = Projection::new(vec![ProjectionExpr { + expr: Arc::new(Column::new("col0", 0)), + alias: "renamed".to_string(), + }]); + + let output_schema = projection.project_schema(&input_schema)?; + + // Should have 1 field + assert_eq!(output_schema.fields().len(), 1); + + // Field should be "renamed" with metadata preserved + assert_eq!(output_schema.field(0).name(), "renamed"); + assert_eq!(output_schema.field(0).metadata(), &metadata); + + Ok(()) + } + + #[test] + fn test_project_schema_empty() -> Result<()> { + let input_schema = get_schema(); + let projection = Projection::new(vec![]); + + let output_schema = projection.project_schema(&input_schema)?; + + assert_eq!(output_schema.fields().len(), 0); + + Ok(()) + } + + #[test] + fn test_project_statistics_columns_only() -> Result<()> { + let input_stats = get_stats(); + let input_schema = get_schema(); + + // Projection: SELECT col1 AS text, col0 AS num + let projection = Projection::new(vec![ + ProjectionExpr { + expr: Arc::new(Column::new("col1", 1)), + alias: "text".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("col0", 0)), + alias: "num".to_string(), + }, + ]); + + let output_stats = projection.project_statistics(input_stats, &input_schema)?; + + // Row count should be preserved + assert_eq!(output_stats.num_rows, Precision::Exact(5)); + + // Should have 2 column statistics (reordered from input) + assert_eq!(output_stats.column_statistics.len(), 2); + + // First column (col1 from input) + assert_eq!( + output_stats.column_statistics[0].distinct_count, + Precision::Exact(1) + ); + assert_eq!( + output_stats.column_statistics[0].max_value, + Precision::Exact(ScalarValue::from("x")) + ); + + // Second column (col0 from input) + assert_eq!( + output_stats.column_statistics[1].distinct_count, + Precision::Exact(5) + ); + assert_eq!( + output_stats.column_statistics[1].max_value, + Precision::Exact(ScalarValue::Int64(Some(21))) + ); + + Ok(()) + } + + #[test] + fn test_project_statistics_with_expressions() -> Result<()> { + let input_stats = get_stats(); + let input_schema = get_schema(); + + // Projection with expression: SELECT col0 + 1 AS incremented, col1 AS text + let projection = Projection::new(vec![ + ProjectionExpr { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("col0", 0)), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int64(Some(1)))), + )), + alias: "incremented".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("col1", 1)), + alias: "text".to_string(), + }, + ]); + + let output_stats 
= projection.project_statistics(input_stats, &input_schema)?; + + // Row count should be preserved + assert_eq!(output_stats.num_rows, Precision::Exact(5)); + + // Should have 2 column statistics + assert_eq!(output_stats.column_statistics.len(), 2); + + // First column (expression) should have unknown statistics + assert_eq!( + output_stats.column_statistics[0].distinct_count, + Precision::Absent + ); + assert_eq!( + output_stats.column_statistics[0].max_value, + Precision::Absent + ); + + // Second column (col1) should preserve statistics + assert_eq!( + output_stats.column_statistics[1].distinct_count, + Precision::Exact(1) + ); + + Ok(()) + } + + #[test] + fn test_project_statistics_primitive_width_only() -> Result<()> { + let input_stats = get_stats(); + let input_schema = get_schema(); + + // Projection with only primitive width columns: SELECT col2 AS f, col0 AS i + let projection = Projection::new(vec![ + ProjectionExpr { + expr: Arc::new(Column::new("col2", 2)), + alias: "f".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("col0", 0)), + alias: "i".to_string(), + }, + ]); + + let output_stats = projection.project_statistics(input_stats, &input_schema)?; + + // Row count should be preserved + assert_eq!(output_stats.num_rows, Precision::Exact(5)); + + // Total byte size should be recalculated for primitive types + // Float32 (4 bytes) + Int64 (8 bytes) = 12 bytes per row, 5 rows = 60 bytes + assert_eq!(output_stats.total_byte_size, Precision::Exact(60)); + + // Should have 2 column statistics + assert_eq!(output_stats.column_statistics.len(), 2); + + Ok(()) + } + + #[test] + fn test_project_statistics_empty() -> Result<()> { + let input_stats = get_stats(); + let input_schema = get_schema(); + + let projection = Projection::new(vec![]); + + let output_stats = projection.project_statistics(input_stats, &input_schema)?; + + // Row count should be preserved + assert_eq!(output_stats.num_rows, Precision::Exact(5)); + + // Should have no column statistics + assert_eq!(output_stats.column_statistics.len(), 0); + + // Total byte size should be 0 for empty projection + assert_eq!(output_stats.total_byte_size, Precision::Exact(0)); + + Ok(()) + } } diff --git a/datafusion/physical-optimizer/Cargo.toml b/datafusion/physical-optimizer/Cargo.toml index 15466cd86bb0..4df011fc0a05 100644 --- a/datafusion/physical-optimizer/Cargo.toml +++ b/datafusion/physical-optimizer/Cargo.toml @@ -52,5 +52,6 @@ recursive = { workspace = true, optional = true } [dev-dependencies] datafusion-expr = { workspace = true } +datafusion-functions = { workspace = true } insta = { workspace = true } tokio = { workspace = true } diff --git a/datafusion/physical-optimizer/src/projection_pushdown.rs b/datafusion/physical-optimizer/src/projection_pushdown.rs index 34affcbd4a19..987e3cb6f713 100644 --- a/datafusion/physical-optimizer/src/projection_pushdown.rs +++ b/datafusion/physical-optimizer/src/projection_pushdown.rs @@ -20,18 +20,32 @@ //! projections one by one if the operator below is amenable to this. If a //! projection reaches a source, it can even disappear from the plan entirely. 
-use std::sync::Arc;
-
 use crate::PhysicalOptimizerRule;
+use arrow::datatypes::{Fields, Schema, SchemaRef};
+use datafusion_common::alias::AliasGenerator;
+use std::collections::HashSet;
+use std::sync::Arc;
 
 use datafusion_common::config::ConfigOptions;
-use datafusion_common::tree_node::{TransformedResult, TreeNode};
-use datafusion_common::Result;
-use datafusion_physical_plan::projection::remove_unnecessary_projections;
+use datafusion_common::tree_node::{
+    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
+};
+use datafusion_common::{JoinSide, JoinType, Result};
+use datafusion_physical_expr::expressions::Column;
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter};
+use datafusion_physical_plan::joins::NestedLoopJoinExec;
+use datafusion_physical_plan::projection::{
+    remove_unnecessary_projections, ProjectionExec,
+};
 use datafusion_physical_plan::ExecutionPlan;
 
 /// This rule inspects `ProjectionExec`'s in the given physical plan and tries to
 /// remove or swap with its child.
+///
+/// Furthermore, it tries to push down projections from nested loop join filters that only depend
+/// on one side of the join. By pushing these projections down, expressions that only depend on one
+/// side of the join no longer have to be re-evaluated for every row of the cartesian product of
+/// the two sides.
 #[derive(Default, Debug)]
 pub struct ProjectionPushdown {}
 
@@ -48,6 +62,20 @@ impl PhysicalOptimizerRule for ProjectionPushdown {
         plan: Arc<dyn ExecutionPlan>,
         _config: &ConfigOptions,
     ) -> Result<Arc<dyn ExecutionPlan>> {
+        let alias_generator = AliasGenerator::new();
+        let plan = plan
+            .transform_up(|plan| {
+                match plan.as_any().downcast_ref::<NestedLoopJoinExec>() {
+                    None => Ok(Transformed::no(plan)),
+                    Some(nested_loop_join) => try_push_down_join_filter(
+                        Arc::clone(&plan),
+                        nested_loop_join,
+                        &alias_generator,
+                    ),
+                }
+            })
+            .map(|t| t.data)?;
+
         plan.transform_down(remove_unnecessary_projections).data()
     }
 
@@ -59,3 +87,713 @@ impl PhysicalOptimizerRule for ProjectionPushdown {
         true
     }
 }
+
+/// Tries to push down parts of the filter.
+///
+/// See [JoinFilterRewriter] for details.
+fn try_push_down_join_filter(
+    original_plan: Arc<dyn ExecutionPlan>,
+    join: &NestedLoopJoinExec,
+    alias_generator: &AliasGenerator,
+) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
+    // Mark joins are currently not supported.
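+    // They would require handling `JoinSide::None` column indices, which the rewrite code below
+    // treats as unreachable.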
+ if matches!(join.join_type(), JoinType::LeftMark | JoinType::RightMark) { + return Ok(Transformed::no(original_plan)); + } + + let projections = join.projection(); + let Some(filter) = join.filter() else { + return Ok(Transformed::no(original_plan)); + }; + + let original_lhs_length = join.left().schema().fields().len(); + let original_rhs_length = join.right().schema().fields().len(); + + let lhs_rewrite = try_push_down_projection( + Arc::clone(&join.right().schema()), + Arc::clone(join.left()), + JoinSide::Left, + filter.clone(), + alias_generator, + )?; + let rhs_rewrite = try_push_down_projection( + Arc::clone(&lhs_rewrite.data.0.schema()), + Arc::clone(join.right()), + JoinSide::Right, + lhs_rewrite.data.1, + alias_generator, + )?; + if !lhs_rewrite.transformed && !rhs_rewrite.transformed { + return Ok(Transformed::no(original_plan)); + } + + let join_filter = minimize_join_filter( + Arc::clone(rhs_rewrite.data.1.expression()), + rhs_rewrite.data.1.column_indices().to_vec(), + lhs_rewrite.data.0.schema().as_ref(), + rhs_rewrite.data.0.schema().as_ref(), + ); + + let new_lhs_length = lhs_rewrite.data.0.schema().fields.len(); + let projections = match projections { + None => match join.join_type() { + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { + // Build projections that ignore the newly projected columns. + let mut projections = Vec::new(); + projections.extend(0..original_lhs_length); + projections.extend(new_lhs_length..new_lhs_length + original_rhs_length); + projections + } + JoinType::LeftSemi | JoinType::LeftAnti => { + // Only return original left columns + let mut projections = Vec::new(); + projections.extend(0..original_lhs_length); + projections + } + JoinType::RightSemi | JoinType::RightAnti => { + // Only return original right columns + let mut projections = Vec::new(); + projections.extend(0..original_rhs_length); + projections + } + _ => unreachable!("Unsupported join type"), + }, + Some(projections) => { + let rhs_offset = new_lhs_length - original_lhs_length; + projections + .iter() + .map(|idx| { + if *idx >= original_lhs_length { + idx + rhs_offset + } else { + *idx + } + }) + .collect() + } + }; + + Ok(Transformed::yes(Arc::new(NestedLoopJoinExec::try_new( + lhs_rewrite.data.0, + rhs_rewrite.data.0, + Some(join_filter), + join.join_type(), + Some(projections), + )?))) +} + +/// Tries to push down parts of `expr` into the `join_side`. 
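+///
+/// The caller runs this once per join side, threading the rewritten filter from the first pass
+/// into the second. A rough sketch of that call pattern (simplified from
+/// `try_push_down_join_filter` above):
+///
+/// ```ignore
+/// let lhs = try_push_down_projection(
+///     right.schema(), Arc::clone(&left), JoinSide::Left, filter, alias_generator)?;
+/// let rhs = try_push_down_projection(
+///     lhs.data.0.schema(), Arc::clone(&right), JoinSide::Right, lhs.data.1, alias_generator)?;
+/// // `lhs.data.0` / `rhs.data.0` are the (possibly re-projected) join inputs;
+/// // `rhs.data.1` is the filter that is then minimized and attached to the new join.
+/// ```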
+fn try_push_down_projection( + other_schema: SchemaRef, + plan: Arc, + join_side: JoinSide, + join_filter: JoinFilter, + alias_generator: &AliasGenerator, +) -> Result, JoinFilter)>> { + let expr = Arc::clone(join_filter.expression()); + let original_plan_schema = plan.schema(); + let mut rewriter = JoinFilterRewriter::new( + join_side, + original_plan_schema.as_ref(), + join_filter.column_indices().to_vec(), + alias_generator, + ); + let new_expr = rewriter.rewrite(expr)?; + + if new_expr.transformed { + let new_join_side = + ProjectionExec::try_new(rewriter.join_side_projections, plan)?; + let new_schema = Arc::clone(&new_join_side.schema()); + + let (lhs_schema, rhs_schema) = match join_side { + JoinSide::Left => (new_schema, other_schema), + JoinSide::Right => (other_schema, new_schema), + JoinSide::None => unreachable!("Mark join not supported"), + }; + let intermediate_schema = rewriter + .intermediate_column_indices + .iter() + .map(|ci| match ci.side { + JoinSide::Left => Arc::clone(&lhs_schema.fields[ci.index]), + JoinSide::Right => Arc::clone(&rhs_schema.fields[ci.index]), + JoinSide::None => unreachable!("Mark join not supported"), + }) + .collect::(); + + let join_filter = JoinFilter::new( + new_expr.data, + rewriter.intermediate_column_indices, + Arc::new(Schema::new(intermediate_schema)), + ); + Ok(Transformed::yes((Arc::new(new_join_side), join_filter))) + } else { + Ok(Transformed::no((plan, join_filter))) + } +} + +/// Creates a new [JoinFilter] and tries to minimize the internal schema. +/// +/// This could eliminate some columns that were only part of a computation that has been pushed +/// down. As this computation is now materialized on one side of the join, the original input +/// columns are not needed anymore. +fn minimize_join_filter( + expr: Arc, + old_column_indices: Vec, + lhs_schema: &Schema, + rhs_schema: &Schema, +) -> JoinFilter { + let mut used_columns = HashSet::new(); + expr.apply(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { + used_columns.insert(col.index()); + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("Closure cannot fail"); + + let new_column_indices = old_column_indices + .iter() + .enumerate() + .filter(|(idx, _)| used_columns.contains(idx)) + .map(|(_, ci)| ci.clone()) + .collect::>(); + let fields = new_column_indices + .iter() + .map(|ci| match ci.side { + JoinSide::Left => lhs_schema.field(ci.index).clone(), + JoinSide::Right => rhs_schema.field(ci.index).clone(), + JoinSide::None => unreachable!("Mark join not supported"), + }) + .collect::(); + + let final_expr = expr + .transform_up(|expr| match expr.as_any().downcast_ref::() { + None => Ok(Transformed::no(expr)), + Some(column) => { + let new_idx = used_columns + .iter() + .filter(|idx| **idx < column.index()) + .count(); + let new_column = Column::new(column.name(), new_idx); + Ok(Transformed::yes( + Arc::new(new_column) as Arc + )) + } + }) + .expect("Closure cannot fail"); + + JoinFilter::new( + final_expr.data, + new_column_indices, + Arc::new(Schema::new(fields)), + ) +} + +/// Implements the push-down machinery. +/// +/// The rewriter starts at the top of the filter expression and traverses the expression tree. For +/// each (sub-)expression, the rewriter checks whether it only refers to one side of the join. If +/// this is never the case, no subexpressions of the filter can be pushed down. If there is a +/// subexpression that can be computed using only one side of the join, the entire subexpression is +/// pushed down to the join side. 
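+///
+/// For example, for the filter `a@0 + 1 > x@1 + 1` (with `a` on the left side and `x` on the
+/// right side), both comparison operands can be computed from a single side, so each is pushed
+/// into a projection below the join and the filter is rewritten to reference the new columns
+/// (plan shape taken from the tests below):
+///
+/// ```text
+/// NestedLoopJoinExec: filter=join_proj_push_down_1@0 > join_proj_push_down_2@1, projection=[a@0, x@2]
+///   ProjectionExec: expr=[a@0 as a, a@0 + 1 as join_proj_push_down_1]
+///     ...
+///   ProjectionExec: expr=[x@0 as x, x@0 + 1 as join_proj_push_down_2]
+///     ...
+/// ```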
+struct JoinFilterRewriter<'a> { + join_side: JoinSide, + join_side_schema: &'a Schema, + join_side_projections: Vec<(Arc, String)>, + intermediate_column_indices: Vec, + alias_generator: &'a AliasGenerator, +} + +impl<'a> JoinFilterRewriter<'a> { + /// Creates a new [JoinFilterRewriter]. + fn new( + join_side: JoinSide, + join_side_schema: &'a Schema, + column_indices: Vec, + alias_generator: &'a AliasGenerator, + ) -> Self { + let projections = join_side_schema + .fields() + .iter() + .enumerate() + .map(|(idx, field)| { + ( + Arc::new(Column::new(field.name(), idx)) as Arc, + field.name().to_string(), + ) + }) + .collect(); + + Self { + join_side, + join_side_schema, + join_side_projections: projections, + intermediate_column_indices: column_indices, + alias_generator, + } + } + + /// Executes the push-down machinery on `expr`. + /// + /// See the [JoinFilterRewriter] for further information. + fn rewrite( + &mut self, + expr: Arc, + ) -> Result>> { + let depends_on_this_side = self.depends_on_join_side(&expr, self.join_side)?; + // We don't push down things that do not depend on this side (other side or no side). + if !depends_on_this_side { + return Ok(Transformed::no(expr)); + } + + // Recurse if there is a dependency to both sides or if the entire expression is volatile. + let depends_on_other_side = + self.depends_on_join_side(&expr, self.join_side.negate())?; + let is_volatile = is_volatile_expression_tree(expr.as_ref()); + if depends_on_other_side || is_volatile { + return expr.map_children(|expr| self.rewrite(expr)); + } + + // There is only a dependency on this side. + + // If this expression has no children, we do not push down, as it should already be a column + // reference. + if expr.children().is_empty() { + return Ok(Transformed::no(expr)); + } + + // Otherwise, we push down a projection. + let alias = self.alias_generator.next("join_proj_push_down"); + let idx = self.create_new_column(alias.clone(), expr)?; + + Ok(Transformed::yes( + Arc::new(Column::new(&alias, idx)) as Arc + )) + } + + /// Creates a new column in the current join side. + fn create_new_column( + &mut self, + name: String, + expr: Arc, + ) -> Result { + // First, add a new projection. The expression must be rewritten, as it is no longer + // executed against the filter schema. + let new_idx = self.join_side_projections.len(); + let rewritten_expr = expr.transform_up(|expr| { + Ok(match expr.as_any().downcast_ref::() { + None => Transformed::no(expr), + Some(column) => { + let intermediate_column = + &self.intermediate_column_indices[column.index()]; + assert_eq!(intermediate_column.side, self.join_side); + + let join_side_index = intermediate_column.index; + let field = self.join_side_schema.field(join_side_index); + let new_column = Column::new(field.name(), join_side_index); + Transformed::yes(Arc::new(new_column) as Arc) + } + }) + })?; + self.join_side_projections.push((rewritten_expr.data, name)); + + // Then, update the column indices + let new_intermediate_idx = self.intermediate_column_indices.len(); + let idx = ColumnIndex { + index: new_idx, + side: self.join_side, + }; + self.intermediate_column_indices.push(idx); + + Ok(new_intermediate_idx) + } + + /// Checks whether the entire expression depends on the given `join_side`. 
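+    /// (That is, whether any column referenced by `expr` resolves to a `ColumnIndex` on that side.)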
+ fn depends_on_join_side( + &mut self, + expr: &Arc, + join_side: JoinSide, + ) -> Result { + let mut result = false; + expr.apply(|expr| match expr.as_any().downcast_ref::() { + None => Ok(TreeNodeRecursion::Continue), + Some(c) => { + let column_index = &self.intermediate_column_indices[c.index()]; + if column_index.side == join_side { + result = true; + return Ok(TreeNodeRecursion::Stop); + } + Ok(TreeNodeRecursion::Continue) + } + })?; + + Ok(result) + } +} + +fn is_volatile_expression_tree(expr: &dyn PhysicalExpr) -> bool { + if expr.is_volatile_node() { + return true; + } + + expr.children() + .iter() + .map(|expr| is_volatile_expression_tree(expr.as_ref())) + .reduce(|lhs, rhs| lhs || rhs) + .unwrap_or(false) +} + +#[cfg(test)] +mod test { + use super::*; + use arrow::datatypes::{DataType, Field, FieldRef, Schema}; + use datafusion_expr_common::operator::Operator; + use datafusion_functions::math::random; + use datafusion_physical_expr::expressions::{binary, lit}; + use datafusion_physical_expr::ScalarFunctionExpr; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + use datafusion_physical_plan::displayable; + use datafusion_physical_plan::empty::EmptyExec; + use insta::assert_snapshot; + use std::sync::Arc; + + #[tokio::test] + async fn no_computation_does_not_project() -> Result<()> { + let (left_schema, right_schema) = create_simple_schemas(); + let optimized_plan = run_test( + left_schema, + right_schema, + a_x(), + None, + a_greater_than_x, + JoinType::Inner, + )?; + + assert_snapshot!(optimized_plan, @r" + NestedLoopJoinExec: join_type=Inner, filter=a@0 > x@1 + EmptyExec + EmptyExec + "); + Ok(()) + } + + #[tokio::test] + async fn simple_push_down() -> Result<()> { + let (left_schema, right_schema) = create_simple_schemas(); + let optimized_plan = run_test( + left_schema, + right_schema, + a_x(), + None, + a_plus_one_greater_than_x_plus_one, + JoinType::Inner, + )?; + + assert_snapshot!(optimized_plan, @r" + NestedLoopJoinExec: join_type=Inner, filter=join_proj_push_down_1@0 > join_proj_push_down_2@1, projection=[a@0, x@2] + ProjectionExec: expr=[a@0 as a, a@0 + 1 as join_proj_push_down_1] + EmptyExec + ProjectionExec: expr=[x@0 as x, x@0 + 1 as join_proj_push_down_2] + EmptyExec + "); + Ok(()) + } + + #[tokio::test] + async fn does_not_push_down_short_circuiting_expressions() -> Result<()> { + let (left_schema, right_schema) = create_simple_schemas(); + let optimized_plan = run_test( + left_schema, + right_schema, + a_x(), + None, + |schema| { + binary( + lit(false), + Operator::And, + a_plus_one_greater_than_x_plus_one(schema)?, + schema, + ) + }, + JoinType::Inner, + )?; + + assert_snapshot!(optimized_plan, @r" + NestedLoopJoinExec: join_type=Inner, filter=false AND join_proj_push_down_1@0 > join_proj_push_down_2@1, projection=[a@0, x@2] + ProjectionExec: expr=[a@0 as a, a@0 + 1 as join_proj_push_down_1] + EmptyExec + ProjectionExec: expr=[x@0 as x, x@0 + 1 as join_proj_push_down_2] + EmptyExec + "); + Ok(()) + } + + #[tokio::test] + async fn does_not_push_down_volatile_functions() -> Result<()> { + let (left_schema, right_schema) = create_simple_schemas(); + let optimized_plan = run_test( + left_schema, + right_schema, + a_x(), + None, + a_plus_rand_greater_than_x, + JoinType::Inner, + )?; + + assert_snapshot!(optimized_plan, @r" + NestedLoopJoinExec: join_type=Inner, filter=a@0 + rand() > x@1 + EmptyExec + EmptyExec + "); + Ok(()) + } + + #[tokio::test] + async fn complex_schema_push_down() -> Result<()> { + let (left_schema, right_schema) = 
create_complex_schemas(); + + let optimized_plan = run_test( + left_schema, + right_schema, + a_b_x_z(), + None, + a_plus_b_greater_than_x_plus_z, + JoinType::Inner, + )?; + + assert_snapshot!(optimized_plan, @r" + NestedLoopJoinExec: join_type=Inner, filter=join_proj_push_down_1@0 > join_proj_push_down_2@1, projection=[a@0, b@1, c@2, x@4, y@5, z@6] + ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c, a@0 + b@1 as join_proj_push_down_1] + EmptyExec + ProjectionExec: expr=[x@0 as x, y@1 as y, z@2 as z, x@0 + z@2 as join_proj_push_down_2] + EmptyExec + "); + Ok(()) + } + + #[tokio::test] + async fn push_down_with_existing_projections() -> Result<()> { + let (left_schema, right_schema) = create_complex_schemas(); + + let optimized_plan = run_test( + left_schema, + right_schema, + a_b_x_z(), + Some(vec![1, 3, 5]), // ("b", "x", "z") + a_plus_b_greater_than_x_plus_z, + JoinType::Inner, + )?; + + assert_snapshot!(optimized_plan, @r" + NestedLoopJoinExec: join_type=Inner, filter=join_proj_push_down_1@0 > join_proj_push_down_2@1, projection=[b@1, x@4, z@6] + ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c, a@0 + b@1 as join_proj_push_down_1] + EmptyExec + ProjectionExec: expr=[x@0 as x, y@1 as y, z@2 as z, x@0 + z@2 as join_proj_push_down_2] + EmptyExec + "); + Ok(()) + } + + #[tokio::test] + async fn left_semi_join_projection() -> Result<()> { + let (left_schema, right_schema) = create_simple_schemas(); + + let left_semi_join_plan = run_test( + left_schema.clone(), + right_schema.clone(), + a_x(), + None, + a_plus_one_greater_than_x_plus_one, + JoinType::LeftSemi, + )?; + + assert_snapshot!(left_semi_join_plan, @r" + NestedLoopJoinExec: join_type=LeftSemi, filter=join_proj_push_down_1@0 > join_proj_push_down_2@1, projection=[a@0] + ProjectionExec: expr=[a@0 as a, a@0 + 1 as join_proj_push_down_1] + EmptyExec + ProjectionExec: expr=[x@0 as x, x@0 + 1 as join_proj_push_down_2] + EmptyExec + "); + Ok(()) + } + + #[tokio::test] + async fn right_semi_join_projection() -> Result<()> { + let (left_schema, right_schema) = create_simple_schemas(); + let right_semi_join_plan = run_test( + left_schema, + right_schema, + a_x(), + None, + a_plus_one_greater_than_x_plus_one, + JoinType::RightSemi, + )?; + assert_snapshot!(right_semi_join_plan, @r" + NestedLoopJoinExec: join_type=RightSemi, filter=join_proj_push_down_1@0 > join_proj_push_down_2@1, projection=[x@0] + ProjectionExec: expr=[a@0 as a, a@0 + 1 as join_proj_push_down_1] + EmptyExec + ProjectionExec: expr=[x@0 as x, x@0 + 1 as join_proj_push_down_2] + EmptyExec + "); + Ok(()) + } + + fn run_test( + left_schema: Schema, + right_schema: Schema, + column_indices: Vec, + existing_projections: Option>, + filter_expr_builder: impl FnOnce(&Schema) -> Result>, + join_type: JoinType, + ) -> Result { + let left = Arc::new(EmptyExec::new(Arc::new(left_schema.clone()))); + let right = Arc::new(EmptyExec::new(Arc::new(right_schema.clone()))); + + let join_fields: Vec<_> = column_indices + .iter() + .map(|ci| match ci.side { + JoinSide::Left => left_schema.field(ci.index).clone(), + JoinSide::Right => right_schema.field(ci.index).clone(), + JoinSide::None => unreachable!(), + }) + .collect(); + let join_schema = Arc::new(Schema::new(join_fields)); + + let filter_expr = filter_expr_builder(join_schema.as_ref())?; + + let join_filter = JoinFilter::new(filter_expr, column_indices, join_schema); + + let join = NestedLoopJoinExec::try_new( + left, + right, + Some(join_filter), + &join_type, + existing_projections, + )?; + + let optimizer = ProjectionPushdown::new(); 
+ let optimized_plan = optimizer.optimize(Arc::new(join), &Default::default())?; + + let displayable_plan = displayable(optimized_plan.as_ref()).indent(false); + Ok(displayable_plan.to_string()) + } + + fn create_simple_schemas() -> (Schema, Schema) { + let left_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let right_schema = Schema::new(vec![Field::new("x", DataType::Int32, false)]); + + (left_schema, right_schema) + } + + fn create_complex_schemas() -> (Schema, Schema) { + let left_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ]); + + let right_schema = Schema::new(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Int32, false), + Field::new("z", DataType::Int32, false), + ]); + + (left_schema, right_schema) + } + + fn a_x() -> Vec { + vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ] + } + + fn a_b_x_z() -> Vec { + vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 1, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ] + } + + fn a_plus_one_greater_than_x_plus_one( + join_schema: &Schema, + ) -> Result> { + let left_expr = binary( + Arc::new(Column::new("a", 0)), + Operator::Plus, + lit(1), + join_schema, + )?; + let right_expr = binary( + Arc::new(Column::new("x", 1)), + Operator::Plus, + lit(1), + join_schema, + )?; + binary(left_expr, Operator::Gt, right_expr, join_schema) + } + + fn a_plus_rand_greater_than_x(join_schema: &Schema) -> Result> { + let left_expr = binary( + Arc::new(Column::new("a", 0)), + Operator::Plus, + Arc::new(ScalarFunctionExpr::new( + "rand", + random(), + vec![], + FieldRef::new(Field::new("out", DataType::Float64, false)), + Arc::new(ConfigOptions::default()), + )), + join_schema, + )?; + let right_expr = Arc::new(Column::new("x", 1)); + binary(left_expr, Operator::Gt, right_expr, join_schema) + } + + fn a_greater_than_x(join_schema: &Schema) -> Result> { + binary( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Column::new("x", 1)), + join_schema, + ) + } + + fn a_plus_b_greater_than_x_plus_z( + join_schema: &Schema, + ) -> Result> { + let lhs = binary( + Arc::new(Column::new("a", 0)), + Operator::Plus, + Arc::new(Column::new("b", 1)), + join_schema, + )?; + let rhs = binary( + Arc::new(Column::new("x", 2)), + Operator::Plus, + Arc::new(Column::new("z", 3)), + join_schema, + )?; + binary(lhs, Operator::Gt, rhs, join_schema) + } +} diff --git a/datafusion/physical-plan/src/analyze.rs b/datafusion/physical-plan/src/analyze.rs index c095afe5e716..c696cf5aa5e6 100644 --- a/datafusion/physical-plan/src/analyze.rs +++ b/datafusion/physical-plan/src/analyze.rs @@ -26,6 +26,7 @@ use super::{ SendableRecordBatchStream, }; use crate::display::DisplayableExecutionPlan; +use crate::metrics::MetricType; use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch}; @@ -44,6 +45,8 @@ pub struct AnalyzeExec { verbose: bool, /// If statistics should be displayed show_statistics: bool, + /// Which metric categories should be displayed + metric_types: Vec, /// The input plan (the plan being analyzed) pub(crate) input: Arc, /// The output schema for RecordBatches of this exec node @@ -56,6 +59,7 @@ impl AnalyzeExec { 
pub fn new( verbose: bool, show_statistics: bool, + metric_types: Vec, input: Arc, schema: SchemaRef, ) -> Self { @@ -63,6 +67,7 @@ impl AnalyzeExec { AnalyzeExec { verbose, show_statistics, + metric_types, input, schema, cache, @@ -145,6 +150,7 @@ impl ExecutionPlan for AnalyzeExec { Ok(Arc::new(Self::new( self.verbose, self.show_statistics, + self.metric_types.clone(), children.pop().unwrap(), Arc::clone(&self.schema), ))) @@ -182,6 +188,7 @@ impl ExecutionPlan for AnalyzeExec { let captured_schema = Arc::clone(&self.schema); let verbose = self.verbose; let show_statistics = self.show_statistics; + let metric_types = self.metric_types.clone(); // future that gathers the results from all the tasks in the // JoinSet that computes the overall row count and final @@ -201,6 +208,7 @@ impl ExecutionPlan for AnalyzeExec { duration, captured_input, captured_schema, + &metric_types, ) }; @@ -219,6 +227,7 @@ fn create_output_batch( duration: std::time::Duration, input: Arc, schema: SchemaRef, + metric_types: &[MetricType], ) -> Result { let mut type_builder = StringBuilder::with_capacity(1, 1024); let mut plan_builder = StringBuilder::with_capacity(1, 1024); @@ -227,6 +236,7 @@ fn create_output_batch( type_builder.append_value("Plan with Metrics"); let annotated_plan = DisplayableExecutionPlan::with_metrics(input.as_ref()) + .set_metric_types(metric_types.to_vec()) .set_show_statistics(show_statistics) .indent(verbose) .to_string(); @@ -238,6 +248,7 @@ fn create_output_batch( type_builder.append_value("Plan with Full Metrics"); let annotated_plan = DisplayableExecutionPlan::with_full_metrics(input.as_ref()) + .set_metric_types(metric_types.to_vec()) .set_show_statistics(show_statistics) .indent(verbose) .to_string(); @@ -282,7 +293,13 @@ mod tests { let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); - let analyze_exec = Arc::new(AnalyzeExec::new(true, false, blocking_exec, schema)); + let analyze_exec = Arc::new(AnalyzeExec::new( + true, + false, + vec![MetricType::SUMMARY, MetricType::DEV], + blocking_exec, + schema, + )); let fut = collect(analyze_exec, task_ctx); let mut fut = fut.boxed(); diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index 5869c51b26b8..2597dc6408de 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -170,8 +170,18 @@ impl ExecutionPlan for CoalescePartitionsExec { "CoalescePartitionsExec requires at least one input partition" ), 1 => { - // bypass any threading / metrics if there is a single partition - self.input.execute(0, context) + // single-partition path: execute child directly, but ensure fetch is respected + // (wrap with ObservedStream only if fetch is present so we don't add overhead otherwise) + let child_stream = self.input.execute(0, context)?; + if self.fetch.is_some() { + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + return Ok(Box::pin(ObservedStream::new( + child_stream, + baseline_metrics, + self.fetch, + ))); + } + Ok(child_stream) } _ => { let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); @@ -351,4 +361,110 @@ mod tests { collect(coalesce_partitions_exec, task_ctx).await.unwrap(); } + + #[tokio::test] + async fn test_single_partition_with_fetch() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + + // Use existing scan_partitioned with 1 partition (returns 100 rows per partition) + let input = 
test::scan_partitioned(1); + + // Test with fetch=3 + let coalesce = CoalescePartitionsExec::new(input).with_fetch(Some(3)); + + let stream = coalesce.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + + let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum(); + assert_eq!(row_count, 3, "Should only return 3 rows due to fetch=3"); + + Ok(()) + } + + #[tokio::test] + async fn test_multi_partition_with_fetch_one() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + + // Create 4 partitions, each with 100 rows + // This simulates the real-world scenario where each partition has data + let input = test::scan_partitioned(4); + + // Test with fetch=1 (the original bug: was returning multiple rows instead of 1) + let coalesce = CoalescePartitionsExec::new(input).with_fetch(Some(1)); + + let stream = coalesce.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + + let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum(); + assert_eq!( + row_count, 1, + "Should only return 1 row due to fetch=1, not one per partition" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_single_partition_without_fetch() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + + // Use scan_partitioned with 1 partition + let input = test::scan_partitioned(1); + + // Test without fetch (should return all rows) + let coalesce = CoalescePartitionsExec::new(input); + + let stream = coalesce.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + + let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum(); + assert_eq!( + row_count, 100, + "Should return all 100 rows when fetch is None" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_single_partition_fetch_larger_than_batch() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + + // Use scan_partitioned with 1 partition (returns 100 rows) + let input = test::scan_partitioned(1); + + // Test with fetch larger than available rows + let coalesce = CoalescePartitionsExec::new(input).with_fetch(Some(200)); + + let stream = coalesce.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + + let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum(); + assert_eq!( + row_count, 100, + "Should return all available rows (100) when fetch (200) is larger" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_multi_partition_fetch_exact_match() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + + // Create 4 partitions, each with 100 rows + let num_partitions = 4; + let csv = test::scan_partitioned(num_partitions); + + // Test with fetch=400 (exactly all rows) + let coalesce = CoalescePartitionsExec::new(csv).with_fetch(Some(400)); + + let stream = coalesce.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + + let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum(); + assert_eq!(row_count, 400, "Should return exactly 400 rows"); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index 2420edfc743d..35ca0b65ae29 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -28,6 +28,7 @@ use datafusion_common::display::{GraphvizBuilder, PlanType, StringifiedPlan}; use datafusion_expr::display_schema; use datafusion_physical_expr::LexOrdering; +use crate::metrics::MetricType; use crate::render_tree::RenderTree; use super::{accept, 
ExecutionPlan, ExecutionPlanVisitor}; @@ -120,11 +121,17 @@ pub struct DisplayableExecutionPlan<'a> { show_statistics: bool, /// If schema should be displayed. See [`Self::set_show_schema`] show_schema: bool, + /// Which metric categories should be included when rendering + metric_types: Vec, // (TreeRender) Maximum total width of the rendered tree tree_maximum_render_width: usize, } impl<'a> DisplayableExecutionPlan<'a> { + fn default_metric_types() -> Vec { + vec![MetricType::SUMMARY, MetricType::DEV] + } + /// Create a wrapper around an [`ExecutionPlan`] which can be /// pretty printed in a variety of ways pub fn new(inner: &'a dyn ExecutionPlan) -> Self { @@ -133,6 +140,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_metrics: ShowMetrics::None, show_statistics: false, show_schema: false, + metric_types: Self::default_metric_types(), tree_maximum_render_width: 240, } } @@ -146,6 +154,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_metrics: ShowMetrics::Aggregated, show_statistics: false, show_schema: false, + metric_types: Self::default_metric_types(), tree_maximum_render_width: 240, } } @@ -159,6 +168,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_metrics: ShowMetrics::Full, show_statistics: false, show_schema: false, + metric_types: Self::default_metric_types(), tree_maximum_render_width: 240, } } @@ -178,6 +188,12 @@ impl<'a> DisplayableExecutionPlan<'a> { self } + /// Specify which metric types should be rendered alongside the plan + pub fn set_metric_types(mut self, metric_types: Vec) -> Self { + self.metric_types = metric_types; + self + } + /// Set the maximum render width for the tree format pub fn set_tree_maximum_render_width(mut self, width: usize) -> Self { self.tree_maximum_render_width = width; @@ -206,6 +222,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_metrics: ShowMetrics, show_statistics: bool, show_schema: bool, + metric_types: Vec, } impl fmt::Display for Wrapper<'_> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { @@ -216,6 +233,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_metrics: self.show_metrics, show_statistics: self.show_statistics, show_schema: self.show_schema, + metric_types: &self.metric_types, }; accept(self.plan, &mut visitor) } @@ -226,6 +244,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_metrics: self.show_metrics, show_statistics: self.show_statistics, show_schema: self.show_schema, + metric_types: self.metric_types.clone(), } } @@ -245,6 +264,7 @@ impl<'a> DisplayableExecutionPlan<'a> { plan: &'a dyn ExecutionPlan, show_metrics: ShowMetrics, show_statistics: bool, + metric_types: Vec, } impl fmt::Display for Wrapper<'_> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { @@ -255,6 +275,7 @@ impl<'a> DisplayableExecutionPlan<'a> { t, show_metrics: self.show_metrics, show_statistics: self.show_statistics, + metric_types: &self.metric_types, graphviz_builder: GraphvizBuilder::default(), parents: Vec::new(), }; @@ -272,6 +293,7 @@ impl<'a> DisplayableExecutionPlan<'a> { plan: self.inner, show_metrics: self.show_metrics, show_statistics: self.show_statistics, + metric_types: self.metric_types.clone(), } } @@ -306,6 +328,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_metrics: ShowMetrics, show_statistics: bool, show_schema: bool, + metric_types: Vec, } impl fmt::Display for Wrapper<'_> { @@ -317,6 +340,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_metrics: self.show_metrics, show_statistics: self.show_statistics, show_schema: self.show_schema, + metric_types: &self.metric_types, }; visitor.pre_visit(self.plan)?; Ok(()) @@ 
-328,6 +352,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_metrics: self.show_metrics, show_statistics: self.show_statistics, show_schema: self.show_schema, + metric_types: self.metric_types.clone(), } } @@ -382,6 +407,8 @@ struct IndentVisitor<'a, 'b> { show_statistics: bool, /// If schema should be displayed show_schema: bool, + /// Which metric types should be rendered + metric_types: &'a [MetricType], } impl ExecutionPlanVisitor for IndentVisitor<'_, '_> { @@ -394,6 +421,7 @@ impl ExecutionPlanVisitor for IndentVisitor<'_, '_> { ShowMetrics::Aggregated => { if let Some(metrics) = plan.metrics() { let metrics = metrics + .filter_by_metric_types(self.metric_types) .aggregate_by_name() .sorted_for_display() .timestamps_removed(); @@ -405,6 +433,7 @@ impl ExecutionPlanVisitor for IndentVisitor<'_, '_> { } ShowMetrics::Full => { if let Some(metrics) = plan.metrics() { + let metrics = metrics.filter_by_metric_types(self.metric_types); write!(self.f, ", metrics=[{metrics}]")?; } else { write!(self.f, ", metrics=[]")?; @@ -441,6 +470,8 @@ struct GraphvizVisitor<'a, 'b> { show_metrics: ShowMetrics, /// If statistics should be displayed show_statistics: bool, + /// Which metric types should be rendered + metric_types: &'a [MetricType], graphviz_builder: GraphvizBuilder, /// Used to record parent node ids when visiting a plan. @@ -478,6 +509,7 @@ impl ExecutionPlanVisitor for GraphvizVisitor<'_, '_> { ShowMetrics::Aggregated => { if let Some(metrics) = plan.metrics() { let metrics = metrics + .filter_by_metric_types(self.metric_types) .aggregate_by_name() .sorted_for_display() .timestamps_removed(); @@ -489,6 +521,7 @@ impl ExecutionPlanVisitor for GraphvizVisitor<'_, '_> { } ShowMetrics::Full => { if let Some(metrics) = plan.metrics() { + let metrics = metrics.filter_by_metric_types(self.metric_types); format!("metrics=[{metrics}]") } else { "metrics=[]".to_string() diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 4c293b0498e7..b5fe5ee5cda1 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -1137,7 +1137,7 @@ impl ExecutionPlan for HashJoinExec { // Add dynamic filters in Post phase if enabled if matches!(phase, FilterPushdownPhase::Post) - && config.optimizer.enable_dynamic_filter_pushdown + && config.optimizer.enable_join_dynamic_filter_pushdown { // Add actual dynamic filter to right side (probe side) let dynamic_filter = Self::create_dynamic_filter(&self.on); diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index adc00d9fe75e..88c50c2eb2ce 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -637,6 +637,7 @@ impl HashJoinStream { let (left_side, right_side) = get_final_indices_from_shared_bitmap( build_side.left_data.visited_indices_bitmap(), self.join_type, + true, ); let empty_right_batch = RecordBatch::new_empty(self.right.schema()); // use the left and right indices to produce the batch result diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 1d36db996434..b0c28cf994f7 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -24,11 +24,13 @@ pub use hash_join::HashJoinExec; pub use nested_loop_join::NestedLoopJoinExec; use parking_lot::Mutex; // Note: SortMergeJoin is not used 
in plans yet +pub use piecewise_merge_join::PiecewiseMergeJoinExec; pub use sort_merge_join::SortMergeJoinExec; pub use symmetric_hash_join::SymmetricHashJoinExec; mod cross_join; mod hash_join; mod nested_loop_join; +mod piecewise_merge_join; mod sort_merge_join; mod stream_join_utils; mod symmetric_hash_join; diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 0974b3a9114e..7ae09a42de88 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -48,11 +48,15 @@ use crate::{ use arrow::array::{ new_null_array, Array, BooleanArray, BooleanBufferBuilder, RecordBatchOptions, + UInt64Array, }; use arrow::buffer::BooleanBuffer; -use arrow::compute::{concat_batches, filter, filter_record_batch, not, BatchCoalescer}; +use arrow::compute::{ + concat_batches, filter, filter_record_batch, not, take, BatchCoalescer, +}; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_schema::DataType; use datafusion_common::cast::as_boolean_array; use datafusion_common::{ arrow_err, internal_datafusion_err, internal_err, project_schema, @@ -1661,11 +1665,30 @@ fn build_row_join_batch( // Broadcast the single build-side row to match the filtered // probe-side batch length let original_left_array = build_side_batch.column(column_index.index); - let scalar_value = ScalarValue::try_from_array( - original_left_array.as_ref(), - build_side_index, - )?; - scalar_value.to_array_of_size(filtered_probe_batch.num_rows())? + // Avoid using `ScalarValue::to_array_of_size()` for `List(Utf8View)` to avoid + // deep copies for buffers inside `Utf8View` array. See below for details. + // https://github.com/apache/datafusion/issues/18159 + // + // In other cases, `to_array_of_size()` is faster. + match original_left_array.data_type() { + DataType::List(field) | DataType::LargeList(field) + if field.data_type() == &DataType::Utf8View => + { + let indices_iter = std::iter::repeat_n( + build_side_index as u64, + filtered_probe_batch.num_rows(), + ); + let indices_array = UInt64Array::from_iter_values(indices_iter); + take(original_left_array.as_ref(), &indices_array, None)? + } + _ => { + let scalar_value = ScalarValue::try_from_array( + original_left_array.as_ref(), + build_side_index, + )?; + scalar_value.to_array_of_size(filtered_probe_batch.num_rows())? + } + } } else { // Take the filtered probe-side column using compute::take Arc::clone(filtered_probe_batch.column(column_index.index)) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs new file mode 100644 index 000000000000..646905e0d787 --- /dev/null +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs @@ -0,0 +1,1550 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Stream Implementation for PiecewiseMergeJoin's Classic Join (Left, Right, Full, Inner) + +use arrow::array::{new_null_array, Array, PrimitiveBuilder}; +use arrow::compute::{take, BatchCoalescer}; +use arrow::datatypes::UInt32Type; +use arrow::{ + array::{ArrayRef, RecordBatch, UInt32Array}, + compute::{sort_to_indices, take_record_batch}, +}; +use arrow_schema::{Schema, SchemaRef, SortOptions}; +use datafusion_common::NullEquality; +use datafusion_common::{internal_err, Result}; +use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; +use datafusion_expr::{JoinType, Operator}; +use datafusion_physical_expr::PhysicalExprRef; +use futures::{Stream, StreamExt}; +use std::{cmp::Ordering, task::ready}; +use std::{sync::Arc, task::Poll}; + +use crate::handle_state; +use crate::joins::piecewise_merge_join::exec::{BufferedSide, BufferedSideReadyState}; +use crate::joins::piecewise_merge_join::utils::need_produce_result_in_final; +use crate::joins::utils::{compare_join_arrays, get_final_indices_from_shared_bitmap}; +use crate::joins::utils::{BuildProbeJoinMetrics, StatefulStreamResult}; + +pub(super) enum PiecewiseMergeJoinStreamState { + WaitBufferedSide, + FetchStreamBatch, + ProcessStreamBatch(SortedStreamBatch), + ProcessUnmatched, + Completed, +} + +impl PiecewiseMergeJoinStreamState { + // Grab mutable reference to the current stream batch + fn try_as_process_stream_batch_mut(&mut self) -> Result<&mut SortedStreamBatch> { + match self { + PiecewiseMergeJoinStreamState::ProcessStreamBatch(state) => Ok(state), + _ => internal_err!("Expected streamed batch in StreamBatch"), + } + } +} + +/// The stream side incoming batch with required sort order. +/// +/// Note the compare key in the join predicate might include expressions on the original +/// columns, so we store the evaluated compare key separately. +/// e.g. For join predicate `buffer.v1 < (stream.v1 + 1)`, the `compare_key_values` field stores +/// the evaluated `stream.v1 + 1` array. 
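+///
+/// `batch` and `compare_key_values` are reordered by the same sort permutation in
+/// `fetch_stream_batch`, so row `i` of the batch lines up with element `i` of the compare key
+/// array(s).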
+pub(super) struct SortedStreamBatch {
+    pub batch: RecordBatch,
+    compare_key_values: Vec<ArrayRef>,
+}
+
+impl SortedStreamBatch {
+    #[allow(dead_code)]
+    fn new(batch: RecordBatch, compare_key_values: Vec<ArrayRef>) -> Self {
+        Self {
+            batch,
+            compare_key_values,
+        }
+    }
+
+    fn compare_key_values(&self) -> &Vec<ArrayRef> {
+        &self.compare_key_values
+    }
+}
+
+pub(super) struct ClassicPWMJStream {
+    // Output schema of the `PiecewiseMergeJoin`
+    pub schema: Arc<Schema>,
+
+    // Physical expression that is evaluated on the streamed side.
+    // We do not need `on_buffered` as it is already evaluated when
+    // creating the buffered side, which happens before initializing
+    // `PiecewiseMergeJoinStream`
+    pub on_streamed: PhysicalExprRef,
+    // Type of join
+    pub join_type: JoinType,
+    // Comparison operator
+    pub operator: Operator,
+    // Stream of incoming streamed-side batches
+    pub streamed: SendableRecordBatchStream,
+    // Streamed schema
+    streamed_schema: SchemaRef,
+    // Buffered side data
+    buffered_side: BufferedSide,
+    // Tracks the state of the `PiecewiseMergeJoin`
+    state: PiecewiseMergeJoinStreamState,
+    // Sort option for streamed side (specifies whether
+    // the sort is ascending or descending)
+    sort_option: SortOptions,
+    // Metrics for build + probe joins
+    join_metrics: BuildProbeJoinMetrics,
+    // Tracking incremental state for emitting record batches
+    batch_process_state: BatchProcessState,
+}
+
+impl RecordBatchStream for ClassicPWMJStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+}
+
+// `PiecewiseMergeJoinStreamState` is separated into `WaitBufferedSide`, `FetchStreamBatch`,
+// `ProcessStreamBatch`, `ProcessUnmatched` and `Completed`.
+//
+// Classic joins step through these states as follows:
+// 1. `WaitBufferedSide` - Load the buffered side data into memory.
+// 2. `FetchStreamBatch` - Fetch + sort incoming stream batches. Once the stream is exhausted,
+//    the state switches to `Completed` if other partitions are still processing the buffered
+//    side; only the last remaining partition switches to `ProcessUnmatched`.
+// 3. `ProcessStreamBatch` - Compare stream batch row values against the buffered side data.
+// 4. `ProcessUnmatched` - For `Right` and `Inner` joins there is nothing left to emit, so the
+//    state moves straight to `Completed`; for the other join types the unmatched buffered rows
+//    still need to be produced.
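+//
+// Transition summary (mirrors the code below):
+//   WaitBufferedSide   -> FetchStreamBatch
+//   FetchStreamBatch   -> ProcessStreamBatch            (a stream batch arrived)
+//   FetchStreamBatch   -> ProcessUnmatched | Completed  (stream exhausted; last vs. other partitions)
+//   ProcessStreamBatch -> FetchStreamBatch              (current stream batch fully processed)
+//   ProcessUnmatched   -> Completed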
+impl ClassicPWMJStream { + // Creates a new `PiecewiseMergeJoinStream` instance + #[allow(clippy::too_many_arguments)] + pub fn try_new( + schema: Arc, + on_streamed: PhysicalExprRef, + join_type: JoinType, + operator: Operator, + streamed: SendableRecordBatchStream, + buffered_side: BufferedSide, + state: PiecewiseMergeJoinStreamState, + sort_option: SortOptions, + join_metrics: BuildProbeJoinMetrics, + batch_size: usize, + ) -> Self { + Self { + schema: Arc::clone(&schema), + on_streamed, + join_type, + operator, + streamed_schema: streamed.schema(), + streamed, + buffered_side, + state, + sort_option, + join_metrics, + batch_process_state: BatchProcessState::new(schema, batch_size), + } + } + + fn poll_next_impl( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>> { + loop { + return match self.state { + PiecewiseMergeJoinStreamState::WaitBufferedSide => { + handle_state!(ready!(self.collect_buffered_side(cx))) + } + PiecewiseMergeJoinStreamState::FetchStreamBatch => { + handle_state!(ready!(self.fetch_stream_batch(cx))) + } + PiecewiseMergeJoinStreamState::ProcessStreamBatch(_) => { + handle_state!(self.process_stream_batch()) + } + PiecewiseMergeJoinStreamState::ProcessUnmatched => { + handle_state!(self.process_unmatched_buffered_batch()) + } + PiecewiseMergeJoinStreamState::Completed => Poll::Ready(None), + }; + } + } + + // Collects buffered side data + fn collect_buffered_side( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + let build_timer = self.join_metrics.build_time.timer(); + let buffered_data = ready!(self + .buffered_side + .try_as_initial_mut()? + .buffered_fut + .get_shared(cx))?; + build_timer.done(); + + // We will start fetching stream batches for classic joins + self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; + + self.buffered_side = + BufferedSide::Ready(BufferedSideReadyState { buffered_data }); + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + // Fetches incoming stream batches + fn fetch_stream_batch( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + match ready!(self.streamed.poll_next_unpin(cx)) { + None => { + if self + .buffered_side + .try_as_ready_mut()? + .buffered_data + .remaining_partitions + .fetch_sub(1, std::sync::atomic::Ordering::SeqCst) + == 1 + { + self.batch_process_state.reset(); + self.state = PiecewiseMergeJoinStreamState::ProcessUnmatched; + } else { + self.state = PiecewiseMergeJoinStreamState::Completed; + } + } + Some(Ok(batch)) => { + // Evaluate the streamed physical expression on the stream batch + let stream_values: ArrayRef = self + .on_streamed + .evaluate(&batch)? + .into_array(batch.num_rows())?; + + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + + // Sort stream values and change the streamed record batch accordingly + let indices = sort_to_indices( + stream_values.as_ref(), + Some(self.sort_option), + None, + )?; + let stream_batch = take_record_batch(&batch, &indices)?; + let stream_values = take(stream_values.as_ref(), &indices, None)?; + + // Reset BatchProcessState before processing a new stream batch + self.batch_process_state.reset(); + self.state = PiecewiseMergeJoinStreamState::ProcessStreamBatch( + SortedStreamBatch { + batch: stream_batch, + compare_key_values: vec![stream_values], + }, + ); + } + Some(Err(err)) => return Poll::Ready(Err(err)), + }; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + // Only classic join will call. 
This function will process stream batches and evaluate against + // the buffered side data. + fn process_stream_batch( + &mut self, + ) -> Result>> { + let buffered_side = self.buffered_side.try_as_ready_mut()?; + let stream_batch = self.state.try_as_process_stream_batch_mut()?; + + if let Some(batch) = self + .batch_process_state + .output_batches + .next_completed_batch() + { + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + + // Produce more work + let batch = resolve_classic_join( + buffered_side, + stream_batch, + Arc::clone(&self.schema), + self.operator, + self.sort_option, + self.join_type, + &mut self.batch_process_state, + )?; + + if !self.batch_process_state.continue_process { + // We finished scanning this stream batch. + self.batch_process_state + .output_batches + .finish_buffered_batch()?; + if let Some(b) = self + .batch_process_state + .output_batches + .next_completed_batch() + { + self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; + return Ok(StatefulStreamResult::Ready(Some(b))); + } + + // Nothing pending; hand back whatever `resolve` returned (often empty) and move on. + if self.batch_process_state.output_batches.is_empty() { + self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; + + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + } + + Ok(StatefulStreamResult::Ready(Some(batch))) + } + + // Process remaining unmatched rows + fn process_unmatched_buffered_batch( + &mut self, + ) -> Result>> { + // Return early for `JoinType::Right` and `JoinType::Inner` + if matches!(self.join_type, JoinType::Right | JoinType::Inner) { + self.state = PiecewiseMergeJoinStreamState::Completed; + return Ok(StatefulStreamResult::Ready(None)); + } + + if !self.batch_process_state.continue_process { + if let Some(batch) = self + .batch_process_state + .output_batches + .next_completed_batch() + { + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + + self.batch_process_state + .output_batches + .finish_buffered_batch()?; + if let Some(batch) = self + .batch_process_state + .output_batches + .next_completed_batch() + { + self.state = PiecewiseMergeJoinStreamState::Completed; + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + } + + let buffered_data = + Arc::clone(&self.buffered_side.try_as_ready().unwrap().buffered_data); + + let (buffered_indices, _streamed_indices) = get_final_indices_from_shared_bitmap( + &buffered_data.visited_indices_bitmap, + self.join_type, + true, + ); + + let new_buffered_batch = + take_record_batch(buffered_data.batch(), &buffered_indices)?; + let mut buffered_columns = new_buffered_batch.columns().to_vec(); + + let streamed_columns: Vec = self + .streamed_schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), new_buffered_batch.num_rows())) + .collect(); + + buffered_columns.extend(streamed_columns); + + let batch = RecordBatch::try_new(Arc::clone(&self.schema), buffered_columns)?; + + self.batch_process_state.output_batches.push_batch(batch)?; + + self.batch_process_state.continue_process = false; + if let Some(batch) = self + .batch_process_state + .output_batches + .next_completed_batch() + { + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + + self.batch_process_state + .output_batches + .finish_buffered_batch()?; + if let Some(batch) = self + .batch_process_state + .output_batches + .next_completed_batch() + { + self.state = PiecewiseMergeJoinStreamState::Completed; + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + + self.state = 
PiecewiseMergeJoinStreamState::Completed;
+        self.batch_process_state.reset();
+        Ok(StatefulStreamResult::Ready(None))
+    }
+}
+
+struct BatchProcessState {
+    // Coalesces completed output rows into batches of the configured size
+    output_batches: Box<BatchCoalescer>,
+    // Used to store the unmatched stream indices for `JoinType::Right` and `JoinType::Full`
+    unmatched_indices: PrimitiveBuilder<UInt32Type>,
+    // Start index on the buffered side; used to resume processing on the correct
+    // row
+    start_buffer_idx: usize,
+    // Start index on the stream side; used to resume processing on the correct
+    // row
+    start_stream_idx: usize,
+    // Signals if we found a match for the current stream row
+    found: bool,
+    // Signals to continue processing the current stream batch
+    continue_process: bool,
+    // Tracks whether the leading null compare keys have already been skipped
+    processed_null_count: bool,
+}
+
+impl BatchProcessState {
+    pub(crate) fn new(schema: Arc<Schema>, batch_size: usize) -> Self {
+        Self {
+            output_batches: Box::new(BatchCoalescer::new(schema, batch_size)),
+            unmatched_indices: PrimitiveBuilder::new(),
+            start_buffer_idx: 0,
+            start_stream_idx: 0,
+            found: false,
+            continue_process: true,
+            processed_null_count: false,
+        }
+    }
+
+    pub(crate) fn reset(&mut self) {
+        self.unmatched_indices = PrimitiveBuilder::new();
+        self.start_buffer_idx = 0;
+        self.start_stream_idx = 0;
+        self.found = false;
+        self.continue_process = true;
+        self.processed_null_count = false;
+    }
+}
+
+impl Stream for ClassicPWMJStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        self.poll_next_impl(cx)
+    }
+}
+
+// For Left, Right, Full, and Inner joins, incoming stream batches will already be sorted.
+#[allow(clippy::too_many_arguments)]
+fn resolve_classic_join(
+    buffered_side: &mut BufferedSideReadyState,
+    stream_batch: &SortedStreamBatch,
+    join_schema: Arc<Schema>,
+    operator: Operator,
+    sort_options: SortOptions,
+    join_type: JoinType,
+    batch_process_state: &mut BatchProcessState,
+) -> Result<RecordBatch> {
+    let buffered_len = buffered_side.buffered_data.values().len();
+    let stream_values = stream_batch.compare_key_values();
+
+    let mut buffer_idx = batch_process_state.start_buffer_idx;
+    let mut stream_idx = batch_process_state.start_stream_idx;
+
+    if !batch_process_state.processed_null_count {
+        let buffered_null_idx = buffered_side.buffered_data.values().null_count();
+        let stream_null_idx = stream_values[0].null_count();
+        buffer_idx = buffered_null_idx;
+        stream_idx = stream_null_idx;
+        batch_process_state.processed_null_count = true;
+    }
+
+    // Our buffer_idx variable allows us to start probing on the buffered side where we last matched
+    // in the previous stream row.
+    for row_idx in stream_idx..stream_batch.batch.num_rows() {
+        while buffer_idx < buffered_len {
+            let compare = {
+                let buffered_values = buffered_side.buffered_data.values();
+                compare_join_arrays(
+                    &[Arc::clone(&stream_values[0])],
+                    row_idx,
+                    &[Arc::clone(buffered_values)],
+                    buffer_idx,
+                    &[sort_options],
+                    NullEquality::NullEqualsNothing,
+                )?
+            };
+
+            // If we find a match we append all indices and move to the next stream row index
+            match operator {
+                Operator::Gt | Operator::Lt => {
+                    if matches!(compare, Ordering::Less) {
+                        batch_process_state.found = true;
+                        let count = buffered_len - buffer_idx;
+
+                        let batch = build_matched_indices_and_set_buffered_bitmap(
+                            (buffer_idx, count),
+                            (row_idx, count),
+                            buffered_side,
+                            stream_batch,
+                            join_type,
+                            Arc::clone(&join_schema),
+                        )?;
+
+                        batch_process_state.output_batches.push_batch(batch)?;
+
+                        // Flush batch and update pointers if we have a completed batch
+                        if let Some(batch) =
+                            batch_process_state.output_batches.next_completed_batch()
+                        {
+                            batch_process_state.found = false;
+                            batch_process_state.start_buffer_idx = buffer_idx;
+                            batch_process_state.start_stream_idx = row_idx + 1;
+                            return Ok(batch);
+                        }
+
+                        break;
+                    }
+                }
+                Operator::GtEq | Operator::LtEq => {
+                    if matches!(compare, Ordering::Equal | Ordering::Less) {
+                        batch_process_state.found = true;
+                        let count = buffered_len - buffer_idx;
+                        let batch = build_matched_indices_and_set_buffered_bitmap(
+                            (buffer_idx, count),
+                            (row_idx, count),
+                            buffered_side,
+                            stream_batch,
+                            join_type,
+                            Arc::clone(&join_schema),
+                        )?;
+
+                        // Flush batch and update pointers if we have a completed batch
+                        batch_process_state.output_batches.push_batch(batch)?;
+                        if let Some(batch) =
+                            batch_process_state.output_batches.next_completed_batch()
+                        {
+                            batch_process_state.found = false;
+                            batch_process_state.start_buffer_idx = buffer_idx;
+                            batch_process_state.start_stream_idx = row_idx + 1;
+                            return Ok(batch);
+                        }
+
+                        break;
+                    }
+                }
+                _ => {
+                    return internal_err!(
+                        "PiecewiseMergeJoin should not contain non-range operator {}",
+                        operator
+                    )
+                }
+            };
+
+            // Increment buffer_idx after every buffered row that does not yet match
+            buffer_idx += 1;
+        }
+
+        // If a match was not found for the current stream row index, the stream index is appended
+        // to the unmatched indices to be flushed later.
+        if matches!(join_type, JoinType::Right | JoinType::Full)
+            && !batch_process_state.found
+        {
+            batch_process_state
+                .unmatched_indices
+                .append_value(row_idx as u32);
+        }
+
+        batch_process_state.found = false;
+    }
+
+    // Flush all unmatched indices on the streamed side
+    if matches!(join_type, JoinType::Right | JoinType::Full) {
+        let batch = create_unmatched_batch(
+            &mut batch_process_state.unmatched_indices,
+            stream_batch,
+            Arc::clone(&join_schema),
+        )?;
+
+        batch_process_state.output_batches.push_batch(batch)?;
+    }
+
+    batch_process_state.continue_process = false;
+    Ok(RecordBatch::new_empty(Arc::clone(&join_schema)))
+}
+
+// Builds a record batch from index ranges on the buffered and streamed sides.
+//
+// Both ranges are given as (start index, count), matching the semantics of
+// batch.slice(start, count).
+fn build_matched_indices_and_set_buffered_bitmap(
+    buffered_range: (usize, usize),
+    streamed_range: (usize, usize),
+    buffered_side: &mut BufferedSideReadyState,
+    stream_batch: &SortedStreamBatch,
+    join_type: JoinType,
+    join_schema: Arc<Schema>,
+) -> Result<RecordBatch> {
+    // Mark the buffered indices as visited
+    if need_produce_result_in_final(join_type) {
+        let mut bitmap = buffered_side.buffered_data.visited_indices_bitmap.lock();
+        for i in buffered_range.0..buffered_range.0 + buffered_range.1 {
+            bitmap.set_bit(i, true);
+        }
+    }
+
+    let new_buffered_batch = buffered_side
+        .buffered_data
+        .batch()
+        .slice(buffered_range.0, buffered_range.1);
+    let mut buffered_columns = new_buffered_batch.columns().to_vec();
+
+    let indices = UInt32Array::from_value(streamed_range.0 as u32, streamed_range.1);
+    let new_stream_batch = take_record_batch(&stream_batch.batch, &indices)?;
+    let streamed_columns = new_stream_batch.columns().to_vec();
+
+    buffered_columns.extend(streamed_columns);
+
+    Ok(RecordBatch::try_new(
+        Arc::clone(&join_schema),
+        buffered_columns,
+    )?)
+}
+
+// Creates a record batch from the unmatched indices on the streamed side
+fn create_unmatched_batch(
+    streamed_indices: &mut PrimitiveBuilder<UInt32Type>,
+    stream_batch: &SortedStreamBatch,
+    join_schema: Arc<Schema>,
+) -> Result<RecordBatch> {
+    let streamed_indices = streamed_indices.finish();
+    let new_stream_batch = take_record_batch(&stream_batch.batch, &streamed_indices)?;
+    let streamed_columns = new_stream_batch.columns().to_vec();
+    let buffered_cols_len = join_schema.fields().len() - streamed_columns.len();
+
+    let num_rows = new_stream_batch.num_rows();
+    let mut buffered_columns: Vec<ArrayRef> = join_schema
+        .fields()
+        .iter()
+        .take(buffered_cols_len)
+        .map(|field| new_null_array(field.data_type(), num_rows))
+        .collect();
+
+    buffered_columns.extend(streamed_columns);
+
+    Ok(RecordBatch::try_new(
+        Arc::clone(&join_schema),
+        buffered_columns,
+    )?)
+} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + common, + joins::PiecewiseMergeJoinExec, + test::{build_table_i32, TestMemoryExec}, + ExecutionPlan, + }; + use arrow::array::{Date32Array, Date64Array}; + use arrow_schema::{DataType, Field}; + use datafusion_common::test_util::batches_to_string; + use datafusion_execution::TaskContext; + use datafusion_expr::JoinType; + use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; + use insta::assert_snapshot; + use std::sync::Arc; + + fn columns(schema: &Schema) -> Vec { + schema.fields().iter().map(|f| f.name().clone()).collect() + } + + fn build_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let batch = build_table_i32(a, b, c); + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + + fn build_date_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Date32, false), + Field::new(b.0, DataType::Date32, false), + Field::new(c.0, DataType::Date32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Date32Array::from(a.1.clone())), + Arc::new(Date32Array::from(b.1.clone())), + Arc::new(Date32Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + + fn build_date64_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Date64, false), + Field::new(b.0, DataType::Date64, false), + Field::new(c.0, DataType::Date64, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Date64Array::from(a.1.clone())), + Arc::new(Date64Array::from(b.1.clone())), + Arc::new(Date64Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + + fn join( + left: Arc, + right: Arc, + on: (Arc, Arc), + operator: Operator, + join_type: JoinType, + ) -> Result { + PiecewiseMergeJoinExec::try_new(left, right, on, operator, join_type, 1) + } + + async fn join_collect( + left: Arc, + right: Arc, + on: (PhysicalExprRef, PhysicalExprRef), + operator: Operator, + join_type: JoinType, + ) -> Result<(Vec, Vec)> { + join_collect_with_options(left, right, on, operator, join_type).await + } + + async fn join_collect_with_options( + left: Arc, + right: Arc, + on: (PhysicalExprRef, PhysicalExprRef), + operator: Operator, + join_type: JoinType, + ) -> Result<(Vec, Vec)> { + let task_ctx = Arc::new(TaskContext::default()); + let join = join(left, right, on, operator, join_type)?; + let columns = columns(&join.schema()); + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + Ok((columns, batches)) + } + + #[tokio::test] + async fn join_inner_less_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 3 | 7 | + // | 2 | 2 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![3, 2, 1]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 4 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 4]), + ("c2", &vec![70, 80, 90]), + ); + + 
let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 3 | 7 | 30 | 4 | 90 | + | 2 | 2 | 8 | 30 | 4 | 90 | + | 3 | 1 | 9 | 30 | 4 | 90 | + | 2 | 2 | 8 | 20 | 3 | 80 | + | 3 | 1 | 9 | 20 | 3 | 80 | + | 3 | 1 | 9 | 10 | 2 | 70 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_less_than_unsorted() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 3 | 7 | + // | 2 | 2 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![3, 2, 1]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // | 20 | 2 | 80 | + // | 30 | 4 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![3, 2, 4]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 3 | 7 | 30 | 4 | 90 | + | 2 | 2 | 8 | 30 | 4 | 90 | + | 3 | 1 | 9 | 30 | 4 | 90 | + | 2 | 2 | 8 | 10 | 3 | 70 | + | 3 | 1 | 9 | 10 | 3 | 70 | + | 3 | 1 | 9 | 20 | 2 | 80 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_greater_than_equal_to() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 2 | 7 | + // | 2 | 3 | 8 | + // | 3 | 4 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![2, 3, 4]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // | 20 | 2 | 80 | + // | 30 | 1 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![3, 2, 1]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::GtEq, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 2 | 7 | 30 | 1 | 90 | + | 2 | 3 | 8 | 30 | 1 | 90 | + | 3 | 4 | 9 | 30 | 1 | 90 | + | 1 | 2 | 7 | 20 | 2 | 80 | + | 2 | 3 | 8 | 20 | 2 | 80 | + | 3 | 4 | 9 | 20 | 2 | 80 | + | 2 | 3 | 8 | 10 | 3 | 70 | + | 3 | 4 | 9 | 10 | 3 | 70 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_empty_left() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // (empty) + // +----+----+----+ + let left = build_table( + ("a1", &Vec::::new()), + ("b1", &Vec::::new()), + ("c1", &Vec::::new()), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 1 | 1 | 1 | + // | 2 | 2 | 2 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![1, 2]), + ("b1", &vec![1, 2]), + ("c2", &vec![1, 2]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + let (_, batches) = + join_collect(left, right, on, Operator::LtEq, JoinType::Inner).await?; + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_full_greater_than_equal_to() -> Result<()> { + // +----+----+-----+ + // | a1 | b1 | c1 | + // +----+----+-----+ + // | 1 | 1 | 100 | + // | 2 | 2 | 200 | + // +----+----+-----+ + let left = build_table( + ("a1", &vec![1, 2]), + ("b1", &vec![1, 2]), + ("c1", &vec![100, 200]), + ); + + // +----+----+-----+ + // | a2 | b1 | c2 | + // +----+----+-----+ + // | 10 | 3 | 300 | + // | 20 | 2 | 400 | + // +----+----+-----+ + let right = build_table( + ("a2", &vec![10, 20]), + ("b1", &vec![3, 2]), + ("c2", &vec![300, 400]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::GtEq, JoinType::Full).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+-----+----+----+-----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+-----+----+----+-----+ + | 2 | 2 | 200 | 20 | 2 | 400 | + | | | | 10 | 3 | 300 | + | 1 | 1 | 100 | | | | + +----+----+-----+----+----+-----+ + "#); + + Ok(()) + } + + #[tokio::test] + async fn join_left_greater_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 3 | 8 | + // | 3 | 4 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 3, 4]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // | 20 | 2 | 80 | + // | 30 | 1 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![3, 2, 1]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::Left).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 2 | 3 | 8 | 30 | 1 | 90 | + | 3 | 4 | 9 | 30 | 1 | 90 | + | 2 | 3 | 8 | 20 | 2 | 80 | + | 3 | 4 | 9 | 20 | 2 | 80 | + | 3 | 4 | 9 | 10 | 3 | 70 | + | 1 | 1 | 7 | | | | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_right_greater_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 3 | 8 | + // | 3 | 4 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 3, 4]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 5 | 70 | + // | 20 | 3 | 80 | + // | 30 | 2 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![5, 3, 2]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::Right).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 2 | 3 | 8 | 30 | 2 | 90 | + | 3 | 4 | 9 | 30 | 2 | 90 | + | 3 | 4 | 9 | 20 | 3 | 80 | + | | | | 10 | 5 | 70 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_right_less_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 4 | 7 | + // | 2 | 3 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 3, 1]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 5 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 5]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Right).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 4 | 7 | 30 | 5 | 90 | + | 2 | 3 | 8 | 30 | 5 | 90 | + | 3 | 1 | 9 | 30 | 5 | 90 | + | 3 | 1 | 9 | 20 | 3 | 80 | + | 3 | 1 | 9 | 10 | 2 | 70 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_less_than_equal_with_dups() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 4 | 7 | + // | 2 | 4 | 8 | + // | 3 | 2 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 4, 2]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 4 | 70 | + // | 20 | 3 | 80 | + // | 30 | 2 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 3, 2]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) 
as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::LtEq, JoinType::Inner).await?; + + // Expected grouping follows right.b1 descending (4, 3, 2) + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 4 | 7 | 10 | 4 | 70 | + | 2 | 4 | 8 | 10 | 4 | 70 | + | 3 | 2 | 9 | 10 | 4 | 70 | + | 3 | 2 | 9 | 20 | 3 | 80 | + | 3 | 2 | 9 | 30 | 2 | 90 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_greater_than_unsorted_right() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 2 | 8 | + // | 3 | 4 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 2, 4]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // | 20 | 1 | 80 | + // | 30 | 2 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![3, 1, 2]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::Inner).await?; + + // Grouped by right in ascending evaluation for > (1,2,3) + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 2 | 2 | 8 | 20 | 1 | 80 | + | 3 | 4 | 9 | 20 | 1 | 80 | + | 3 | 4 | 9 | 30 | 2 | 90 | + | 3 | 4 | 9 | 10 | 3 | 70 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_left_less_than_equal_with_left_nulls_on_no_match() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 5 | 7 | + // | 2 | 4 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![5, 4, 1]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // +----+----+----+ + let right = build_table(("a2", &vec![10]), ("b1", &vec![3]), ("c2", &vec![70])); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::LtEq, JoinType::Left).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 3 | 1 | 9 | 10 | 3 | 70 | + | 1 | 5 | 7 | | | | + | 2 | 4 | 8 | | | | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_right_greater_than_equal_with_right_nulls_on_no_match() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 2 | 8 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2]), + ("b1", &vec![1, 2]), + ("c1", &vec![7, 8]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // | 20 | 5 | 80 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20]), + ("b1", &vec![3, 5]), + ("c2", &vec![70, 80]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::GtEq, JoinType::Right).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | | | | 10 | 3 | 70 | + | | | | 20 | 5 | 80 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_single_row_left_less_than() -> Result<()> { + let left = build_table(("a1", &vec![42]), ("b1", &vec![5]), ("c1", &vec![999])); + + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![1, 5, 7]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+-----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+-----+----+----+----+ + | 42 | 5 | 999 | 30 | 7 | 90 | + +----+----+-----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_empty_right() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 2, 3]), + ("c1", &vec![7, 8, 9]), + ); + + let right = build_table( + ("a2", &Vec::::new()), + ("b1", &Vec::::new()), + ("c2", &Vec::::new()), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_date32_inner_less_than() -> Result<()> { + // +----+-------+----+ + // | a1 | b1 | c1 | + // +----+-------+----+ + // | 1 | 19107 | 7 | + // | 2 | 19107 | 8 | + // | 3 | 19105 | 9 | + // +----+-------+----+ + let left = build_date_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![19107, 19107, 19105]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+-------+----+ + // | a2 | b1 | c2 | + // +----+-------+----+ + // | 10 | 19105 | 70 | + // | 20 | 19103 | 80 | + // | 30 | 19107 | 90 | + // +----+-------+----+ + let right = build_date_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![19105, 19103, 19107]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +------------+------------+------------+------------+------------+------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +------------+------------+------------+------------+------------+------------+ + | 1970-01-04 | 2022-04-23 | 1970-01-10 | 1970-01-31 | 2022-04-25 | 1970-04-01 | + +------------+------------+------------+------------+------------+------------+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_date64_inner_less_than() -> Result<()> { + // +----+---------------+----+ + // | a1 | b1 | c1 | + // +----+---------------+----+ + // | 1 | 1650903441000 | 7 | + // | 2 | 1650903441000 | 8 | + // | 3 | 1650703441000 | 9 | + // +----+---------------+----+ + let left = build_date64_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1650903441000, 1650903441000, 1650703441000]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+---------------+----+ + // | a2 | b1 | c2 | + // +----+---------------+----+ + // | 10 | 1650703441000 | 70 | + // | 20 | 1650503441000 | 80 | + // | 30 | 1650903441000 | 90 | + // +----+---------------+----+ + let right = build_date64_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![1650703441000, 1650503441000, 1650903441000]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | 1970-01-01T00:00:00.003 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_date64_right_less_than() -> Result<()> { + // +----+---------------+----+ + // | a1 | b1 | c1 | + // +----+---------------+----+ + // | 1 | 1650903441000 | 7 | + // | 2 | 1650703441000 | 8 | + // +----+---------------+----+ + let left = build_date64_table( + ("a1", &vec![1, 2]), + ("b1", &vec![1650903441000, 1650703441000]), + ("c1", &vec![7, 8]), + ); + + // +----+---------------+----+ + // | a2 | b1 | c2 | + // +----+---------------+----+ + // | 10 | 1650703441000 | 80 | + // | 20 | 1650903441000 | 90 | + // +----+---------------+----+ + let right = build_date64_table( + ("a2", &vec![10, 20]), + ("b1", &vec![1650703441000, 1650903441000]), + ("c2", &vec![80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Right).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | 1970-01-01T00:00:00.002 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.020 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | + | | | | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.080 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ +"#); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs new file mode 100644 index 000000000000..987f3e9df45a --- /dev/null +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -0,0 +1,748 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::Array;
+use arrow::{
+    array::{ArrayRef, BooleanBufferBuilder, RecordBatch},
+    compute::concat_batches,
+    util::bit_util,
+};
+use arrow_schema::{SchemaRef, SortOptions};
+use datafusion_common::not_impl_err;
+use datafusion_common::{internal_err, JoinSide, Result};
+use datafusion_execution::{
+    memory_pool::{MemoryConsumer, MemoryReservation},
+    SendableRecordBatchStream,
+};
+use datafusion_expr::{JoinType, Operator};
+use datafusion_physical_expr::equivalence::join_equivalence_properties;
+use datafusion_physical_expr::{
+    Distribution, LexOrdering, OrderingRequirements, PhysicalExpr, PhysicalExprRef,
+    PhysicalSortExpr,
+};
+use datafusion_physical_expr_common::physical_expr::fmt_sql;
+use futures::TryStreamExt;
+use parking_lot::Mutex;
+use std::fmt::Formatter;
+use std::sync::atomic::AtomicUsize;
+use std::sync::Arc;
+
+use crate::execution_plan::{boundedness_from_children, EmissionType};
+
+use crate::joins::piecewise_merge_join::classic_join::{
+    ClassicPWMJStream, PiecewiseMergeJoinStreamState,
+};
+use crate::joins::piecewise_merge_join::utils::{
+    build_visited_indices_map, is_existence_join, is_right_existence_join,
+};
+use crate::joins::utils::asymmetric_join_output_partitioning;
+use crate::{
+    joins::{
+        utils::{build_join_schema, BuildProbeJoinMetrics, OnceAsync, OnceFut},
+        SharedBitmapBuilder,
+    },
+    metrics::ExecutionPlanMetricsSet,
+    spill::get_record_batch_memory_size,
+    ExecutionPlan, PlanProperties,
+};
+use crate::{DisplayAs, DisplayFormatType, ExecutionPlanProperties};
+
+/// `PiecewiseMergeJoinExec` is a join execution plan that evaluates a single range filter and shows much
+/// better performance for these workloads than `NestedLoopJoin`
+///
+/// The physical planner will choose to evaluate this join when there is only one comparison filter: a
+/// binary expression using [`Operator::Lt`], [`Operator::LtEq`], [`Operator::Gt`], or [`Operator::GtEq`].
+/// Examples:
+/// - `col0` < `colb`, `col0` <= `colb`, `col0` > `colb`, `col0` >= `colb`
+///
+/// # Execution Plan Inputs
+/// For `PiecewiseMergeJoin` we label the right input as the 'streamed' side and the left input as the
+/// 'buffered' side.
+///
+/// `PiecewiseMergeJoin` takes a sorted input for the side to be buffered and is able to sort streamed record
+/// batches during processing. The buffered input must be sorted ascending or descending depending on the operator.
+///
+/// # Algorithms
+/// Classic joins are processed differently compared to existence joins.
+///
+/// ## Classic Joins (Inner, Full, Left, Right)
+/// For classic joins we buffer the build side and stream the probe side.
+/// Both sides are sorted so that we can iterate from index 0 to the end on each side. This ordering ensures
+/// that when we find the first matching pair of rows, we can emit the current stream row joined with all remaining
+/// buffered rows from the match position onward, without rescanning earlier buffered rows.
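As a standalone illustration of the paragraph above, here is a dependency-free sketch of the "first match, then emit the suffix" scan on plain vectors for a `buffered < streamed` predicate with both sides sorted descending; it is not the PR's code, which additionally tracks a visited bitmap, resumable indices, and batch coalescing.

```rust
// Illustrative only: the two-pointer scan behind the classic-join path.
fn classic_scan(streamed_desc: &[i32], buffered_desc: &[i32]) -> Vec<(i32, i32)> {
    let mut out = Vec::new();
    let mut buffer_idx = 0;
    for &s in streamed_desc {
        // Skip buffered rows that are too large to satisfy `buffered < s`;
        // the pointer never moves backwards across stream rows.
        while buffer_idx < buffered_desc.len() && buffered_desc[buffer_idx] >= s {
            buffer_idx += 1;
        }
        // Everything from `buffer_idx` onward satisfies the predicate.
        for &b in &buffered_desc[buffer_idx..] {
            out.push((b, s));
        }
    }
    out
}

fn main() {
    // Mirrors the `join_inner_less_than` test in classic_join.rs:
    // buffered b1 = [3, 2, 1], streamed b1 = [4, 3, 2] (both descending).
    let pairs = classic_scan(&[4, 3, 2], &[3, 2, 1]);
    assert_eq!(pairs.len(), 6);
}
```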
+///
+/// For `<` and `<=` operators, both inputs are sorted in **descending** order, while for `>` and `>=` operators
+/// they are sorted in **ascending** order. This choice ensures that the pointer on the buffered side can advance
+/// monotonically as we stream new batches from the stream side.
+///
+/// The streamed side may arrive unsorted, so this operator sorts each incoming batch in memory before
+/// processing. The buffered side is required to be globally sorted; the plan declares this requirement
+/// in `required_input_ordering`, which allows the optimizer to automatically insert a `SortExec` on that side if needed.
+/// By the time this operator runs, the buffered side is guaranteed to be in the proper order.
+///
+/// The pseudocode for the algorithm looks like this:
+///
+/// ```text
+/// for stream_row in stream_batch:
+///     for buffer_row in buffer_batch:
+///         if compare(stream_row, buffer_row):
+///             output stream_row X buffer_batch[buffer_row:]
+///         else:
+///             continue
+/// ```
+///
+/// The algorithm uses the (larger) streamed side to drive the loop: each stream row scans the buffered side
+/// only until its first match. Because every match then emits a whole run of buffered rows, output handling
+/// can be vectorized more effectively.
+///
+/// Here is an example:
+///
+/// We perform a `JoinType::Left` with these two batches and the operator being `Operator::Lt` (<). For each
+/// row on the streamed side we move a pointer on the buffered side until it matches the condition. Once we
+/// reach the matching row (in this case row 1 on the streamed side has its first match at row 2 on the
+/// buffered side; 100 < 200 is true), we can emit the matching row and every row after it. We can emit the
+/// rows like this because if the batch is sorted in ascending order, every subsequent row will also satisfy
+/// the condition as they will all be larger values.
+///
+/// ```text
+/// SQL statement:
+/// SELECT *
+/// FROM (VALUES (100), (200), (500)) AS streamed(a)
+/// LEFT JOIN (VALUES (100), (200), (200), (300), (400)) AS buffered(b)
+/// ON streamed.a < buffered.b;
+///
+/// Processing Row 1:
+///
+/// Sorted Buffered Side                                        Sorted Streamed Side
+///    ┌──────────────────┐                                        ┌──────────────────┐
+///  1 │ 100              │                                      1 │ 100              │
+///    ├──────────────────┤                                        ├──────────────────┤
+///  2 │ 200              │ ─┐                                   2 │ 200              │
+///    ├──────────────────┤  │ For row 1 on streamed side with      ├──────────────────┤
+///  3 │ 200              │  │ value 100, we emit rows 2 - 5      3 │ 500              │
+///    ├──────────────────┤  │ as matches when the operator is      └──────────────────┘
+///  4 │ 300              │  │ `Operator::Lt` (<), emitting all
+///    ├──────────────────┤  │ rows after the first match (row
+///  5 │ 400              │ ─┘ 2 buffered side; 100 < 200)
+///    └──────────────────┘
+///
+/// Processing Row 2:
+/// By sorting the streamed side we know the buffered pointer never needs to move backwards:
+/// probing for row 2 starts where row 1 found its first match.
+///
+/// Sorted Buffered Side                                        Sorted Streamed Side
+///    ┌──────────────────┐                                        ┌──────────────────┐
+///  1 │ 100              │                                      1 │ 100              │
+///    ├──────────────────┤                                        ├──────────────────┤
+///  2 │ 200              │ <- Start here when probing for the   2 │ 200              │
+///    ├──────────────────┤    streamed side row 2.                 ├──────────────────┤
+///  3 │ 200              │                                      3 │ 500              │
+///    ├──────────────────┤                                        └──────────────────┘
+///  4 │ 300              │
+///    ├──────────────────┤
+///  5 │ 400              │
+///    └──────────────────┘
+///
+/// ```
+///
+/// ## Existence Joins (Semi, Anti, Mark)
+/// Existence joins are dramatically faster with a `PiecewiseMergeJoin` because we only need to find
+/// the min/max value of the streamed side to be able to emit all matches on the buffered side.
By putting
+/// the side we need to mark onto the sorted buffered side, we can emit all of these matches at once.
+///
+/// For less-than operations (`<`) the buffered input is sorted in ascending order, and for greater-than
+/// (`>`) operations it is sorted in descending order. `SortExec` is used to enforce sorting on the buffered
+/// side; the streamed side does not need to be sorted because we only need to find its min/max.
+///
+/// For Left Semi, Anti, and Mark joins we swap the inputs so that the marked side is on the buffered side.
+///
+/// The pseudocode for the algorithm looks like this:
+///
+/// ```text
+/// // Using the example of a less than `<` operation
+/// let max = max_batch(streamed_batch)
+///
+/// for buffer_row in buffer_batch:
+///     if buffer_row < max:
+///         output buffer_batch[buffer_row:]
+/// ```
+///
+/// We only need to find the min/max value and iterate through the buffered side once.
+///
+/// Here is an example:
+/// We perform a `JoinType::LeftSemi` with these two batches and the operator being `Operator::Lt` (<). Because
+/// the operator is `Operator::Lt` we can find the minimum value in the streamed side; in this case it is 200.
+/// We can then advance a pointer from the start of the buffered side until we find the first value that satisfies
+/// the predicate. All rows after that first matched value satisfy the condition 200 < x, so we can mark all of
+/// those rows as matched.
+///
+/// ```text
+/// SQL statement:
+/// SELECT *
+/// FROM (VALUES (500), (200), (300)) AS streamed(a)
+/// LEFT SEMI JOIN (VALUES (100), (200), (200), (300), (400)) AS buffered(b)
+/// ON streamed.a < buffered.b;
+///
+/// Sorted Buffered Side                   Unsorted Streamed Side
+///    ┌──────────────────┐                   ┌──────────────────┐
+///  1 │ 100              │                 1 │ 500              │
+///    ├──────────────────┤                   ├──────────────────┤
+///  2 │ 200              │                 2 │ 200              │
+///    ├──────────────────┤                   ├──────────────────┤
+///  3 │ 200              │                 3 │ 300              │
+///    ├──────────────────┤                   └──────────────────┘
+///  4 │ 300              │ ─┐
+///    ├──────────────────┤  │ We emit matches for rows 4 - 5
+///  5 │ 400              │ ─┘ on the buffered side.
+///    └──────────────────┘
+///    min value: 200
+/// ```
+///
+/// For existence joins the buffered side must be sorted ascending for `Operator::Lt` (<) or
+/// `Operator::LtEq` (<=) and descending for `Operator::Gt` (>) or `Operator::GtEq` (>=); for classic
+/// joins the required order is the reverse, as described above.
+///
+/// # Partitioning Logic
+/// Piecewise Merge Join requires a single buffered-side partition plus a round-robin partitioned stream
+/// side. A counter on the buffered side coordinates when all streamed partitions have finished execution.
+/// This allows for processing the rest of the unmatched rows for Left and Full joins. The last partition
+/// that finishes execution is responsible for outputting the unmatched rows.
+///
+/// # Performance Explanation (cost)
+/// Piecewise Merge Join is used over Nested Loop Join due to its superior performance. Here is the breakdown:
+///
+/// R: Buffered Side
+/// S: Streamed Side
+///
+/// ## Piecewise Merge Join (PWMJ)
+///
+/// # Classic Join:
+/// Requires sorting the probe side and, for each probe row, scanning the buffered side until the first match
+/// is found.
+/// Complexity: `O(sort(S) + num_of_batches(|S|) * scan(R))`.
+///
+/// # Mark Join:
+/// Computes the min/max range of the probe keys and scans the buffered side only within that range.
+/// Complexity: `O(|S| + scan(R[range]))`.
+///
+/// ## Nested Loop Join
+/// Compares every row from `S` with every row from `R`.
+/// Complexity: `O(|S| * |R|)`.
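A dependency-free sketch of the existence-join idea from the worked example above (predicate `streamed.a < buffered.b`): only `min(streamed)` is needed, and the ascending-sorted buffered side is marked from the first row that exceeds it. The function name and the `Vec<bool>` marks are illustrative only, not the operator's actual output representation.

```rust
// Illustrative only: min/max-driven marking for an existence join.
fn mark_matches(streamed_unsorted: &[i32], buffered_asc: &[i32]) -> Vec<bool> {
    let min = streamed_unsorted
        .iter()
        .copied()
        .min()
        .expect("non-empty streamed side");
    // Index of the first buffered row with `min < b`; every later row matches too.
    let first_match = buffered_asc.partition_point(|&b| b <= min);
    (0..buffered_asc.len()).map(|i| i >= first_match).collect()
}

fn main() {
    // Mirrors the doc example: min(streamed) = 200, buffered rows 300 and 400 match.
    let marks = mark_matches(&[500, 200, 300], &[100, 200, 200, 300, 400]);
    assert_eq!(marks, vec![false, false, false, true, true]);
}
```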
+///
+/// # Further Reference Material
+/// DuckDB blog on Range Joins: [Range Joins in DuckDB](https://duckdb.org/2022/05/27/iejoin.html)
+#[derive(Debug)]
+pub struct PiecewiseMergeJoinExec {
+    /// Left buffered execution plan
+    pub buffered: Arc<dyn ExecutionPlan>,
+    /// Right streamed execution plan
+    pub streamed: Arc<dyn ExecutionPlan>,
+    /// The two expressions being compared
+    pub on: (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>),
+    /// Comparison operator in the range predicate
+    pub operator: Operator,
+    /// How the join is performed
+    pub join_type: JoinType,
+    /// The schema once the join is applied
+    schema: SchemaRef,
+    /// Buffered data
+    buffered_fut: OnceAsync<BufferedSideData>,
+    /// Execution metrics
+    metrics: ExecutionPlanMetricsSet,
+
+    /// Sort expressions - See above for more details [`PiecewiseMergeJoinExec`]
+    ///
+    /// The left sort order, descending for `<`, `<=` operations and ascending for `>`, `>=` operations
+    left_child_plan_required_order: LexOrdering,
+    /// The right sort order, descending for `<`, `<=` operations and ascending for `>`, `>=` operations.
+    /// Unsorted for mark joins
+    #[allow(unused)]
+    right_batch_required_orders: LexOrdering,
+
+    /// This determines the sort order of all join columns used in sorting the stream and buffered execution plans.
+    sort_options: SortOptions,
+    /// Cache holding plan properties like equivalences, output partitioning etc.
+    cache: PlanProperties,
+    /// Number of partitions to process
+    num_partitions: usize,
+}
+
+impl PiecewiseMergeJoinExec {
+    pub fn try_new(
+        buffered: Arc<dyn ExecutionPlan>,
+        streamed: Arc<dyn ExecutionPlan>,
+        on: (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>),
+        operator: Operator,
+        join_type: JoinType,
+        num_partitions: usize,
+    ) -> Result<Self> {
+        // TODO: Implement existence joins for PiecewiseMergeJoin
+        if is_existence_join(join_type) {
+            return not_impl_err!(
+                "Existence Joins are currently not supported for PiecewiseMergeJoin"
+            );
+        }
+
+        // Take the operator and enforce a sort order on the streamed + buffered side based on
+        // the operator type.
+ let sort_options = match operator { + Operator::Lt | Operator::LtEq => { + // For left existence joins the inputs will be swapped so the sort + // options are switched + if is_right_existence_join(join_type) { + SortOptions::new(false, true) + } else { + SortOptions::new(true, true) + } + } + Operator::Gt | Operator::GtEq => { + if is_right_existence_join(join_type) { + SortOptions::new(true, true) + } else { + SortOptions::new(false, true) + } + } + _ => { + return internal_err!( + "Cannot contain non-range operator in PiecewiseMergeJoinExec" + ) + } + }; + + // Give the same `sort_option for comparison later` + let left_child_plan_required_order = + vec![PhysicalSortExpr::new(Arc::clone(&on.0), sort_options)]; + let right_batch_required_orders = + vec![PhysicalSortExpr::new(Arc::clone(&on.1), sort_options)]; + + let Some(left_child_plan_required_order) = + LexOrdering::new(left_child_plan_required_order) + else { + return internal_err!( + "PiecewiseMergeJoinExec requires valid sort expressions for its left side" + ); + }; + let Some(right_batch_required_orders) = + LexOrdering::new(right_batch_required_orders) + else { + return internal_err!( + "PiecewiseMergeJoinExec requires valid sort expressions for its right side" + ); + }; + + let buffered_schema = buffered.schema(); + let streamed_schema = streamed.schema(); + + // Create output schema for the join + let schema = + Arc::new(build_join_schema(&buffered_schema, &streamed_schema, &join_type).0); + let cache = Self::compute_properties( + &buffered, + &streamed, + Arc::clone(&schema), + join_type, + &on, + )?; + + Ok(Self { + streamed, + buffered, + on, + operator, + join_type, + schema, + buffered_fut: Default::default(), + metrics: ExecutionPlanMetricsSet::new(), + left_child_plan_required_order, + right_batch_required_orders, + sort_options, + cache, + num_partitions, + }) + } + + /// Reference to buffered side execution plan + pub fn buffered(&self) -> &Arc { + &self.buffered + } + + /// Reference to streamed side execution plan + pub fn streamed(&self) -> &Arc { + &self.streamed + } + + /// Join type + pub fn join_type(&self) -> JoinType { + self.join_type + } + + /// Reference to sort options + pub fn sort_options(&self) -> &SortOptions { + &self.sort_options + } + + /// Get probe side (streamed side) for the PiecewiseMergeJoin + /// In current implementation, probe side is determined according to join type. + pub fn probe_side(join_type: &JoinType) -> JoinSide { + match join_type { + JoinType::Right + | JoinType::Inner + | JoinType::Full + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark => JoinSide::Right, + JoinType::Left + | JoinType::LeftAnti + | JoinType::LeftSemi + | JoinType::LeftMark => JoinSide::Left, + } + } + + pub fn compute_properties( + buffered: &Arc, + streamed: &Arc, + schema: SchemaRef, + join_type: JoinType, + join_on: &(PhysicalExprRef, PhysicalExprRef), + ) -> Result { + let eq_properties = join_equivalence_properties( + buffered.equivalence_properties().clone(), + streamed.equivalence_properties().clone(), + &join_type, + schema, + &Self::maintains_input_order(join_type), + Some(Self::probe_side(&join_type)), + std::slice::from_ref(join_on), + )?; + + let output_partitioning = + asymmetric_join_output_partitioning(buffered, streamed, &join_type)?; + + Ok(PlanProperties::new( + eq_properties, + output_partitioning, + EmissionType::Incremental, + boundedness_from_children([buffered, streamed]), + )) + } + + // TODO: Add input order. 
Now they're all `false` indicating it will not maintain the input order. + // However, for certain join types the order is maintained. This can be updated in the future after + // more testing. + fn maintains_input_order(join_type: JoinType) -> Vec { + match join_type { + // The existence side is expected to come in sorted + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { + vec![false, false] + } + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + vec![false, false] + } + // Left, Right, Full, Inner Join is not guaranteed to maintain + // input order as the streamed side will be sorted during + // execution for `PiecewiseMergeJoin` + _ => vec![false, false], + } + } + + // TODO + pub fn swap_inputs(&self) -> Result> { + todo!() + } +} + +impl ExecutionPlan for PiecewiseMergeJoinExec { + fn name(&self) -> &str { + "PiecewiseMergeJoinExec" + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.buffered, &self.streamed] + } + + fn required_input_distribution(&self) -> Vec { + vec![ + Distribution::SinglePartition, + Distribution::UnspecifiedDistribution, + ] + } + + fn required_input_ordering(&self) -> Vec> { + // Existence joins don't need to be sorted on one side. + if is_right_existence_join(self.join_type) { + unimplemented!() + } else { + // Sort the right side in memory, so we do not need to enforce any sorting + vec![ + Some(OrderingRequirements::from( + self.left_child_plan_required_order.clone(), + )), + None, + ] + } + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + match &children[..] { + [left, right] => Ok(Arc::new(PiecewiseMergeJoinExec::try_new( + Arc::clone(left), + Arc::clone(right), + self.on.clone(), + self.operator, + self.join_type, + self.num_partitions, + )?)), + _ => internal_err!( + "PiecewiseMergeJoin should have 2 children, found {}", + children.len() + ), + } + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let on_buffered = Arc::clone(&self.on.0); + let on_streamed = Arc::clone(&self.on.1); + + let metrics = BuildProbeJoinMetrics::new(partition, &self.metrics); + let buffered_fut = self.buffered_fut.try_once(|| { + let reservation = MemoryConsumer::new("PiecewiseMergeJoinInput") + .register(context.memory_pool()); + + let buffered_stream = self.buffered.execute(0, Arc::clone(&context))?; + Ok(build_buffered_data( + buffered_stream, + Arc::clone(&on_buffered), + metrics.clone(), + reservation, + build_visited_indices_map(self.join_type), + self.num_partitions, + )) + })?; + + let streamed = self.streamed.execute(partition, Arc::clone(&context))?; + + let batch_size = context.session_config().batch_size(); + + // TODO: Add existence joins + this is guarded at physical planner + if is_existence_join(self.join_type()) { + unreachable!() + } else { + Ok(Box::pin(ClassicPWMJStream::try_new( + Arc::clone(&self.schema), + on_streamed, + self.join_type, + self.operator, + streamed, + BufferedSide::Initial(BufferedSideInitialState { buffered_fut }), + PiecewiseMergeJoinStreamState::WaitBufferedSide, + self.sort_options, + metrics, + batch_size, + ))) + } + } +} + +impl DisplayAs for PiecewiseMergeJoinExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + let on_str = format!( + "({} {} {})", + fmt_sql(self.on.0.as_ref()), + self.operator, + fmt_sql(self.on.1.as_ref()) + ); + + match t { + DisplayFormatType::Default | 
DisplayFormatType::Verbose => { + write!( + f, + "PiecewiseMergeJoin: operator={:?}, join_type={:?}, on={}", + self.operator, self.join_type, on_str + ) + } + + DisplayFormatType::TreeRender => { + writeln!(f, "operator={:?}", self.operator)?; + if self.join_type != JoinType::Inner { + writeln!(f, "join_type={:?}", self.join_type)?; + } + writeln!(f, "on={on_str}") + } + } + } +} + +async fn build_buffered_data( + buffered: SendableRecordBatchStream, + on_buffered: PhysicalExprRef, + metrics: BuildProbeJoinMetrics, + reservation: MemoryReservation, + build_map: bool, + remaining_partitions: usize, +) -> Result { + let schema = buffered.schema(); + + // Combine batches and record number of rows + let initial = (Vec::new(), 0, metrics, reservation); + let (batches, num_rows, metrics, mut reservation) = buffered + .try_fold(initial, |mut acc, batch| async { + let batch_size = get_record_batch_memory_size(&batch); + acc.3.try_grow(batch_size)?; + acc.2.build_mem_used.add(batch_size); + acc.2.build_input_batches.add(1); + acc.2.build_input_rows.add(batch.num_rows()); + // Update row count + acc.1 += batch.num_rows(); + // Push batch to output + acc.0.push(batch); + Ok(acc) + }) + .await?; + + let single_batch = concat_batches(&schema, batches.iter())?; + + // Evaluate physical expression on the buffered side. + let buffered_values = on_buffered + .evaluate(&single_batch)? + .into_array(single_batch.num_rows())?; + + // We add the single batch size + the memory of the join keys + // size of the size estimation + let size_estimation = get_record_batch_memory_size(&single_batch) + + buffered_values.get_array_memory_size(); + reservation.try_grow(size_estimation)?; + metrics.build_mem_used.add(size_estimation); + + // Created visited indices bitmap only if the join type requires it + let visited_indices_bitmap = if build_map { + let bitmap_size = bit_util::ceil(single_batch.num_rows(), 8); + reservation.try_grow(bitmap_size)?; + metrics.build_mem_used.add(bitmap_size); + + let mut bitmap_buffer = BooleanBufferBuilder::new(single_batch.num_rows()); + bitmap_buffer.append_n(num_rows, false); + bitmap_buffer + } else { + BooleanBufferBuilder::new(0) + }; + + let buffered_data = BufferedSideData::new( + single_batch, + buffered_values, + Mutex::new(visited_indices_bitmap), + remaining_partitions, + reservation, + ); + + Ok(buffered_data) +} + +pub(super) struct BufferedSideData { + pub(super) batch: RecordBatch, + values: ArrayRef, + pub(super) visited_indices_bitmap: SharedBitmapBuilder, + pub(super) remaining_partitions: AtomicUsize, + _reservation: MemoryReservation, +} + +impl BufferedSideData { + pub(super) fn new( + batch: RecordBatch, + values: ArrayRef, + visited_indices_bitmap: SharedBitmapBuilder, + remaining_partitions: usize, + reservation: MemoryReservation, + ) -> Self { + Self { + batch, + values, + visited_indices_bitmap, + remaining_partitions: AtomicUsize::new(remaining_partitions), + _reservation: reservation, + } + } + + pub(super) fn batch(&self) -> &RecordBatch { + &self.batch + } + + pub(super) fn values(&self) -> &ArrayRef { + &self.values + } +} + +pub(super) enum BufferedSide { + /// Indicates that build-side not collected yet + Initial(BufferedSideInitialState), + /// Indicates that build-side data has been collected + Ready(BufferedSideReadyState), +} + +impl BufferedSide { + // Takes a mutable state of the buffered row batches + pub(super) fn try_as_initial_mut(&mut self) -> Result<&mut BufferedSideInitialState> { + match self { + BufferedSide::Initial(state) => Ok(state), 
+ _ => internal_err!("Expected build side in initial state"), + } + } + + pub(super) fn try_as_ready(&self) -> Result<&BufferedSideReadyState> { + match self { + BufferedSide::Ready(state) => Ok(state), + _ => { + internal_err!("Expected build side in ready state") + } + } + } + + /// Tries to extract BuildSideReadyState from BuildSide enum. + /// Returns an error if state is not Ready. + pub(super) fn try_as_ready_mut(&mut self) -> Result<&mut BufferedSideReadyState> { + match self { + BufferedSide::Ready(state) => Ok(state), + _ => internal_err!("Expected build side in ready state"), + } + } +} + +pub(super) struct BufferedSideInitialState { + pub(crate) buffered_fut: OnceFut, +} + +pub(super) struct BufferedSideReadyState { + /// Collected build-side data + pub(super) buffered_data: Arc, +} diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs new file mode 100644 index 000000000000..c85a7cc16f65 --- /dev/null +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! PiecewiseMergeJoin is currently experimental + +pub use exec::PiecewiseMergeJoinExec; + +mod classic_join; +mod exec; +mod utils; diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/utils.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/utils.rs new file mode 100644 index 000000000000..5bbb496322b5 --- /dev/null +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/utils.rs @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use datafusion_expr::JoinType; + +// Returns boolean for whether the join is a right existence join +pub(super) fn is_right_existence_join(join_type: JoinType) -> bool { + matches!( + join_type, + JoinType::RightAnti | JoinType::RightSemi | JoinType::RightMark + ) +} + +// Returns boolean for whether the join is an existence join +pub(super) fn is_existence_join(join_type: JoinType) -> bool { + matches!( + join_type, + JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftMark + | JoinType::RightMark + ) +} + +// Returns boolean to check if the join type needs to record +// buffered side matches for classic joins +pub(super) fn need_produce_result_in_final(join_type: JoinType) -> bool { + matches!(join_type, JoinType::Full | JoinType::Left) +} + +// Returns boolean for whether or not we need to build the buffered side +// bitmap for marking matched rows on the buffered side. +pub(super) fn build_visited_indices_map(join_type: JoinType) -> bool { + matches!( + join_type, + JoinType::Full + | JoinType::Left + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftMark + | JoinType::RightMark + ) +} diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs index 879f47638d2c..5a2e3669ab5e 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs @@ -34,7 +34,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics; -use crate::joins::utils::JoinFilter; +use crate::joins::utils::{compare_join_arrays, JoinFilter}; use crate::spill::spill_manager::SpillManager; use crate::{PhysicalExpr, RecordBatchStream, SendableRecordBatchStream}; @@ -1865,101 +1865,6 @@ fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) -> Vec Result { - let mut res = Ordering::Equal; - for ((left_array, right_array), sort_options) in - left_arrays.iter().zip(right_arrays).zip(sort_options) - { - macro_rules! 
compare_value { - ($T:ty) => {{ - let left_array = left_array.as_any().downcast_ref::<$T>().unwrap(); - let right_array = right_array.as_any().downcast_ref::<$T>().unwrap(); - match (left_array.is_null(left), right_array.is_null(right)) { - (false, false) => { - let left_value = &left_array.value(left); - let right_value = &right_array.value(right); - res = left_value.partial_cmp(right_value).unwrap(); - if sort_options.descending { - res = res.reverse(); - } - } - (true, false) => { - res = if sort_options.nulls_first { - Ordering::Less - } else { - Ordering::Greater - }; - } - (false, true) => { - res = if sort_options.nulls_first { - Ordering::Greater - } else { - Ordering::Less - }; - } - _ => { - res = match null_equality { - NullEquality::NullEqualsNothing => Ordering::Less, - NullEquality::NullEqualsNull => Ordering::Equal, - }; - } - } - }}; - } - - match left_array.data_type() { - DataType::Null => {} - DataType::Boolean => compare_value!(BooleanArray), - DataType::Int8 => compare_value!(Int8Array), - DataType::Int16 => compare_value!(Int16Array), - DataType::Int32 => compare_value!(Int32Array), - DataType::Int64 => compare_value!(Int64Array), - DataType::UInt8 => compare_value!(UInt8Array), - DataType::UInt16 => compare_value!(UInt16Array), - DataType::UInt32 => compare_value!(UInt32Array), - DataType::UInt64 => compare_value!(UInt64Array), - DataType::Float32 => compare_value!(Float32Array), - DataType::Float64 => compare_value!(Float64Array), - DataType::Utf8 => compare_value!(StringArray), - DataType::Utf8View => compare_value!(StringViewArray), - DataType::LargeUtf8 => compare_value!(LargeStringArray), - DataType::Binary => compare_value!(BinaryArray), - DataType::BinaryView => compare_value!(BinaryViewArray), - DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), - DataType::LargeBinary => compare_value!(LargeBinaryArray), - DataType::Decimal32(..) => compare_value!(Decimal32Array), - DataType::Decimal64(..) => compare_value!(Decimal64Array), - DataType::Decimal128(..) => compare_value!(Decimal128Array), - DataType::Timestamp(time_unit, None) => match time_unit { - TimeUnit::Second => compare_value!(TimestampSecondArray), - TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), - TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray), - TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray), - }, - DataType::Date32 => compare_value!(Date32Array), - DataType::Date64 => compare_value!(Date64Array), - dt => { - return not_impl_err!( - "Unsupported data type in sort merge join comparator: {}", - dt - ); - } - } - if !res.is_eq() { - break; - } - } - Ok(res) -} - /// A faster version of compare_join_arrays() that only output whether /// the given two rows are equal fn is_join_arrays_equal( diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index c50bfce93a2d..78652d443d3c 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -17,7 +17,7 @@ //! 
Join related functionality used both on logical and physical plans -use std::cmp::min; +use std::cmp::{min, Ordering}; use std::collections::HashSet; use std::fmt::{self, Debug}; use std::future::Future; @@ -43,7 +43,13 @@ use arrow::array::{ BooleanBufferBuilder, NativeAdapter, PrimitiveArray, RecordBatch, RecordBatchOptions, UInt32Array, UInt32Builder, UInt64Array, }; -use arrow::array::{ArrayRef, BooleanArray}; +use arrow::array::{ + ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, + Decimal128Array, FixedSizeBinaryArray, Float32Array, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, StringArray, + StringViewArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt8Array, +}; use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::compute::kernels::cmp::eq; use arrow::compute::{self, and, take, FilterBuilder}; @@ -51,12 +57,13 @@ use arrow::datatypes::{ ArrowNativeType, Field, Schema, SchemaBuilder, UInt32Type, UInt64Type, }; use arrow_ord::cmp::not_distinct; -use arrow_schema::ArrowError; +use arrow_schema::{ArrowError, DataType, SortOptions, TimeUnit}; use datafusion_common::cast::as_boolean_array; use datafusion_common::hash_utils::create_hashes; use datafusion_common::stats::Precision; use datafusion_common::{ - plan_err, DataFusionError, JoinSide, JoinType, NullEquality, Result, SharedResult, + not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, NullEquality, Result, + SharedResult, }; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::Operator; @@ -284,7 +291,7 @@ pub fn build_join_schema( JoinType::LeftSemi | JoinType::LeftAnti => left_fields().unzip(), JoinType::LeftMark => { let right_field = once(( - Field::new("mark", arrow::datatypes::DataType::Boolean, false), + Field::new("mark", DataType::Boolean, false), ColumnIndex { index: 0, side: JoinSide::None, @@ -295,7 +302,7 @@ pub fn build_join_schema( JoinType::RightSemi | JoinType::RightAnti => right_fields().unzip(), JoinType::RightMark => { let left_field = once(( - Field::new("mark", arrow_schema::DataType::Boolean, false), + Field::new("mark", DataType::Boolean, false), ColumnIndex { index: 0, side: JoinSide::None, @@ -812,9 +819,10 @@ pub(crate) fn need_produce_result_in_final(join_type: JoinType) -> bool { pub(crate) fn get_final_indices_from_shared_bitmap( shared_bitmap: &SharedBitmapBuilder, join_type: JoinType, + piecewise: bool, ) -> (UInt64Array, UInt32Array) { let bitmap = shared_bitmap.lock(); - get_final_indices_from_bit_map(&bitmap, join_type) + get_final_indices_from_bit_map(&bitmap, join_type, piecewise) } /// In the end of join execution, need to use bit map of the matched @@ -829,16 +837,22 @@ pub(crate) fn get_final_indices_from_shared_bitmap( pub(crate) fn get_final_indices_from_bit_map( left_bit_map: &BooleanBufferBuilder, join_type: JoinType, + // We add a flag for whether this is being passed from the `PiecewiseMergeJoin` + // because the bitmap can be for left + right `JoinType`s + piecewise: bool, ) -> (UInt64Array, UInt32Array) { let left_size = left_bit_map.len(); - if join_type == JoinType::LeftMark { + if join_type == JoinType::LeftMark || (join_type == JoinType::RightMark && piecewise) + { let left_indices = (0..left_size as u64).collect::(); let right_indices = (0..left_size) .map(|idx| left_bit_map.get_bit(idx).then_some(0)) .collect::(); return (left_indices, right_indices); } - let left_indices = if join_type 
== JoinType::LeftSemi { + let left_indices = if join_type == JoinType::LeftSemi + || (join_type == JoinType::RightSemi && piecewise) + { (0..left_size) .filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as u64)) .collect::() @@ -1749,6 +1763,99 @@ fn eq_dyn_null( } } +/// Get comparison result of two rows of join arrays +pub fn compare_join_arrays( + left_arrays: &[ArrayRef], + left: usize, + right_arrays: &[ArrayRef], + right: usize, + sort_options: &[SortOptions], + null_equality: NullEquality, +) -> Result { + let mut res = Ordering::Equal; + for ((left_array, right_array), sort_options) in + left_arrays.iter().zip(right_arrays).zip(sort_options) + { + macro_rules! compare_value { + ($T:ty) => {{ + let left_array = left_array.as_any().downcast_ref::<$T>().unwrap(); + let right_array = right_array.as_any().downcast_ref::<$T>().unwrap(); + match (left_array.is_null(left), right_array.is_null(right)) { + (false, false) => { + let left_value = &left_array.value(left); + let right_value = &right_array.value(right); + res = left_value.partial_cmp(right_value).unwrap(); + if sort_options.descending { + res = res.reverse(); + } + } + (true, false) => { + res = if sort_options.nulls_first { + Ordering::Less + } else { + Ordering::Greater + }; + } + (false, true) => { + res = if sort_options.nulls_first { + Ordering::Greater + } else { + Ordering::Less + }; + } + _ => { + res = match null_equality { + NullEquality::NullEqualsNothing => Ordering::Less, + NullEquality::NullEqualsNull => Ordering::Equal, + }; + } + } + }}; + } + + match left_array.data_type() { + DataType::Null => {} + DataType::Boolean => compare_value!(BooleanArray), + DataType::Int8 => compare_value!(Int8Array), + DataType::Int16 => compare_value!(Int16Array), + DataType::Int32 => compare_value!(Int32Array), + DataType::Int64 => compare_value!(Int64Array), + DataType::UInt8 => compare_value!(UInt8Array), + DataType::UInt16 => compare_value!(UInt16Array), + DataType::UInt32 => compare_value!(UInt32Array), + DataType::UInt64 => compare_value!(UInt64Array), + DataType::Float32 => compare_value!(Float32Array), + DataType::Float64 => compare_value!(Float64Array), + DataType::Binary => compare_value!(BinaryArray), + DataType::BinaryView => compare_value!(BinaryViewArray), + DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), + DataType::LargeBinary => compare_value!(LargeBinaryArray), + DataType::Utf8 => compare_value!(StringArray), + DataType::Utf8View => compare_value!(StringViewArray), + DataType::LargeUtf8 => compare_value!(LargeStringArray), + DataType::Decimal128(..) 
=> compare_value!(Decimal128Array), + DataType::Timestamp(time_unit, None) => match time_unit { + TimeUnit::Second => compare_value!(TimestampSecondArray), + TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), + TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray), + TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray), + }, + DataType::Date32 => compare_value!(Date32Array), + DataType::Date64 => compare_value!(Date64Array), + dt => { + return not_impl_err!( + "Unsupported data type in sort merge join comparator: {}", + dt + ); + } + } + if !res.is_eq() { + break; + } + } + Ok(res) +} + #[cfg(test)] mod tests { use std::collections::HashMap; diff --git a/datafusion/physical-plan/src/metrics/baseline.rs b/datafusion/physical-plan/src/metrics/baseline.rs index 15efb8f90aa2..858773b94664 100644 --- a/datafusion/physical-plan/src/metrics/baseline.rs +++ b/datafusion/physical-plan/src/metrics/baseline.rs @@ -21,6 +21,8 @@ use std::task::Poll; use arrow::record_batch::RecordBatch; +use crate::spill::get_record_batch_memory_size; + use super::{Count, ExecutionPlanMetricsSet, MetricBuilder, Time, Timestamp}; use datafusion_common::Result; @@ -53,6 +55,16 @@ pub struct BaselineMetrics { /// output rows: the total output rows output_rows: Count, + + /// Memory usage of all output batches. + /// + /// Note: This value may be overestimated. If multiple output `RecordBatch` + /// instances share underlying memory buffers, their sizes will be counted + /// multiple times. + /// Issue: + output_bytes: Count, + // Remember to update `docs/source/user-guide/metrics.md` when updating comments + // or adding new metrics } impl BaselineMetrics { @@ -62,9 +74,18 @@ impl BaselineMetrics { start_time.record(); Self { - end_time: MetricBuilder::new(metrics).end_timestamp(partition), - elapsed_compute: MetricBuilder::new(metrics).elapsed_compute(partition), - output_rows: MetricBuilder::new(metrics).output_rows(partition), + end_time: MetricBuilder::new(metrics) + .with_type(super::MetricType::SUMMARY) + .end_timestamp(partition), + elapsed_compute: MetricBuilder::new(metrics) + .with_type(super::MetricType::SUMMARY) + .elapsed_compute(partition), + output_rows: MetricBuilder::new(metrics) + .with_type(super::MetricType::SUMMARY) + .output_rows(partition), + output_bytes: MetricBuilder::new(metrics) + .with_type(super::MetricType::SUMMARY) + .output_bytes(partition), } } @@ -78,6 +99,7 @@ impl BaselineMetrics { end_time: Default::default(), elapsed_compute: self.elapsed_compute.clone(), output_rows: Default::default(), + output_bytes: Default::default(), } } @@ -205,6 +227,8 @@ impl RecordOutput for usize { impl RecordOutput for RecordBatch { fn record_output(self, bm: &BaselineMetrics) -> Self { bm.record_output(self.num_rows()); + let n_bytes = get_record_batch_memory_size(&self); + bm.output_bytes.add(n_bytes); self } } @@ -212,6 +236,8 @@ impl RecordOutput for RecordBatch { impl RecordOutput for &RecordBatch { fn record_output(self, bm: &BaselineMetrics) -> Self { bm.record_output(self.num_rows()); + let n_bytes = get_record_batch_memory_size(self); + bm.output_bytes.add(n_bytes); self } } diff --git a/datafusion/physical-plan/src/metrics/builder.rs b/datafusion/physical-plan/src/metrics/builder.rs index dbda0a310ce5..88ec1a3f67d1 100644 --- a/datafusion/physical-plan/src/metrics/builder.rs +++ b/datafusion/physical-plan/src/metrics/builder.rs @@ -19,6 +19,8 @@ use std::{borrow::Cow, sync::Arc}; +use crate::metrics::MetricType; + use super::{ Count, 
ExecutionPlanMetricsSet, Gauge, Label, Metric, MetricValue, Time, Timestamp, }; @@ -52,15 +54,23 @@ pub struct MetricBuilder<'a> { /// arbitrary name=value pairs identifying this metric labels: Vec

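Reviewer note (not part of the diff above): the operator-to-sort-direction mapping in `PiecewiseMergeJoinExec::try_new` is easy to misread, because right-existence joins swap the inputs and therefore flip the required direction while `nulls_first` stays fixed. The sketch below restates that decision table as a standalone helper; `required_sort_options` and `is_right_existence` are hypothetical names used only for illustration, and the snippet assumes the `arrow-schema` and `datafusion-expr` crates are available as dependencies.

```rust
// Hypothetical, self-contained restatement of the sort-direction table used by
// `PiecewiseMergeJoinExec::try_new` (names below are illustrative only).
use arrow_schema::SortOptions;
use datafusion_expr::{JoinType, Operator};

/// Mirrors the PR's `is_right_existence_join` helper.
fn is_right_existence(join_type: JoinType) -> bool {
    matches!(
        join_type,
        JoinType::RightAnti | JoinType::RightSemi | JoinType::RightMark
    )
}

/// `SortOptions::new(descending, nulls_first)`: the buffered side is always
/// sorted with `nulls_first = true`; only the direction depends on the range
/// operator, and it flips when the inputs will be swapped for a
/// right-existence join.
fn required_sort_options(
    operator: Operator,
    join_type: JoinType,
) -> Option<SortOptions> {
    match operator {
        Operator::Lt | Operator::LtEq => Some(if is_right_existence(join_type) {
            SortOptions::new(false, true) // ascending
        } else {
            SortOptions::new(true, true) // descending
        }),
        Operator::Gt | Operator::GtEq => Some(if is_right_existence(join_type) {
            SortOptions::new(true, true)
        } else {
            SortOptions::new(false, true)
        }),
        // `try_new` rejects every other operator with an internal error;
        // the sketch models that case as `None`.
        _ => None,
    }
}

fn main() {
    // `<` with an Inner join sorts descending ...
    assert_eq!(
        required_sort_options(Operator::Lt, JoinType::Inner),
        Some(SortOptions::new(true, true))
    );
    // ... while the same operator under a RightSemi join flips to ascending,
    // because the inputs will be swapped before execution.
    assert_eq!(
        required_sort_options(Operator::Lt, JoinType::RightSemi),
        Some(SortOptions::new(false, true))
    );
    println!("sort-direction table matches the match expression in try_new");
}
```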
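Reviewer note (not part of the diff above): the new `output_bytes` baseline metric sums `get_record_batch_memory_size` over every output batch, and its doc comment warns that shared buffers are counted once per batch. The arrow-only sketch below, written under the assumption that only the `arrow` crate is on hand, shows the underlying effect: a zero-copy slice reports the capacity of the buffer it shares with its parent, so naively adding per-batch sizes double counts that memory.

```rust
// Arrow-only illustration of the over-counting caveat; it does not call
// DataFusion's internal `get_record_batch_memory_size`, it only demonstrates
// the buffer-sharing effect that function is subject to.
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Int32Array};

fn main() {
    // One 1M-row Int32 column ...
    let parent: ArrayRef = Arc::new(Int32Array::from_iter_values(0..1_000_000));
    // ... and a zero-copy slice of its first half: both views share the same
    // underlying data buffer.
    let child = parent.slice(0, 500_000);

    let parent_bytes = parent.get_array_memory_size();
    let child_bytes = child.get_array_memory_size();

    // Summing per-batch sizes counts the shared buffer once per batch, which
    // is why the new `output_bytes` doc comment calls the value a possible
    // overestimate.
    println!(
        "parent = {parent_bytes} B, slice = {child_bytes} B, naive sum = {} B",
        parent_bytes + child_bytes
    );
}
```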