diff --git a/.baseline/checkstyle/checkstyle.xml b/.baseline/checkstyle/checkstyle.xml
index be92588c4552..9b22c83c24ae 100644
--- a/.baseline/checkstyle/checkstyle.xml
+++ b/.baseline/checkstyle/checkstyle.xml
@@ -132,6 +132,7 @@
               org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.*,
               org.apache.spark.sql.functions.*,
               org.apache.spark.sql.connector.iceberg.write.RowLevelOperation.Command.*,
+              org.apache.spark.sql.connector.write.RowLevelOperation.Command.*,
               org.junit.Assert.*"/>
diff --git a/.github/workflows/java-ci.yml b/.github/workflows/java-ci.yml
index 1fb4fbe7f13b..cab7fd1d4291 100644
--- a/.github/workflows/java-ci.yml
+++ b/.github/workflows/java-ci.yml
@@ -84,7 +84,7 @@ jobs:
      with:
        distribution: zulu
        java-version: 8
-    - run: ./gradlew -DflinkVersions=1.13,1.14,1.15 -DsparkVersions=2.4,3.0,3.1,3.2 -DhiveVersions=2,3 build -x test -x javadoc -x integrationTest
+    - run: ./gradlew -DflinkVersions=1.13,1.14,1.15 -DsparkVersions=2.4,3.0,3.1,3.2,3.3 -DhiveVersions=2,3 build -x test -x javadoc -x integrationTest

  build-javadoc:
    runs-on: ubuntu-20.04
diff --git a/.github/workflows/jmh-bechmarks.yml b/.github/workflows/jmh-bechmarks.yml
index f746deeef6cf..a38e5a949e8c 100644
--- a/.github/workflows/jmh-bechmarks.yml
+++ b/.github/workflows/jmh-bechmarks.yml
@@ -28,8 +28,8 @@ on:
        description: 'The branch name'
        required: true
      spark_version:
-        description: 'The spark project version to use, such as iceberg-spark-3.2'
-        default: 'iceberg-spark-3.2'
+        description: 'The spark project version to use, such as iceberg-spark-3.3'
+        default: 'iceberg-spark-3.3'
        required: true
      benchmarks:
        description: 'A list of comma-separated double-quoted Benchmark names, such as "IcebergSourceFlatParquetDataReadBenchmark", "IcebergSourceFlatParquetDataFilterBenchmark"'
diff --git a/.github/workflows/publish-snapshot.yml b/.github/workflows/publish-snapshot.yml
index b03f52f42f92..ff183b8c6ebd 100644
--- a/.github/workflows/publish-snapshot.yml
+++ b/.github/workflows/publish-snapshot.yml
@@ -40,5 +40,5 @@ jobs:
        java-version: 8
    - run: |
        ./gradlew printVersion
-        ./gradlew -DflinkVersions=1.13,1.14,1.15 -DsparkVersions=2.4,3.0,3.1,3.2 -DhiveVersions=2,3 publishApachePublicationToMavenRepository -PmavenUser=${{ secrets.NEXUS_USER }} -PmavenPassword=${{ secrets.NEXUS_PW }}
+        ./gradlew -DflinkVersions=1.13,1.14,1.15 -DsparkVersions=2.4,3.0,3.1,3.2,3.3 -DhiveVersions=2,3 publishApachePublicationToMavenRepository -PmavenUser=${{ secrets.NEXUS_USER }} -PmavenPassword=${{ secrets.NEXUS_PW }}
        ./gradlew -DflinkVersions= -DsparkVersions=3.2 -DscalaVersion=2.13 -DhiveVersions= publishApachePublicationToMavenRepository -PmavenUser=${{ secrets.NEXUS_USER }} -PmavenPassword=${{ secrets.NEXUS_PW }}
diff --git a/.github/workflows/spark-ci.yml b/.github/workflows/spark-ci.yml
index 79b6b7f9602e..7a19669b08e1 100644
--- a/.github/workflows/spark-ci.yml
+++ b/.github/workflows/spark-ci.yml
@@ -83,7 +83,7 @@ jobs:
    strategy:
      matrix:
        jvm: [8, 11]
-        spark: ['3.0', '3.1', '3.2']
+        spark: ['3.0', '3.1', '3.2', '3.3']
    env:
      SPARK_LOCAL_IP: localhost
    steps:
@@ -111,7 +111,7 @@
    strategy:
      matrix:
        jvm: [8, 11]
-        spark: ['3.2']
+        spark: ['3.2','3.3']
    env:
      SPARK_LOCAL_IP: localhost
    steps:
diff --git a/.gitignore b/.gitignore
index a7a414e7469f..4aafb18e52c7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -32,6 +32,7 @@
 spark/v2.4/spark/benchmark/*
 spark/v3.0/spark/benchmark/*
 spark/v3.1/spark/benchmark/*
 spark/v3.2/spark/benchmark/*
+spark/v3.3/spark/benchmark/*
 data/benchmark/*

 __pycache__/
diff --git a/dev/stage-binaries.sh b/dev/stage-binaries.sh
index ec5e1ac60682..c5c0778523ba 100755
--- a/dev/stage-binaries.sh
+++ b/dev/stage-binaries.sh
@@ -20,7 +20,7 @@
 SCALA_VERSION=2.12
 FLINK_VERSIONS=1.13,1.14,1.15
-SPARK_VERSIONS=2.4,3.0,3.1,3.2
+SPARK_VERSIONS=2.4,3.0,3.1,3.2,3.3
 HIVE_VERSIONS=2,3

 ./gradlew -Prelease -DscalaVersion=$SCALA_VERSION -DflinkVersions=$FLINK_VERSIONS -DsparkVersions=$SPARK_VERSIONS -DhiveVersions=$HIVE_VERSIONS publishApachePublicationToMavenRepository
@@ -28,4 +28,5 @@ HIVE_VERSIONS=2,3
 # Also publish Scala 2.13 Artifacts for versions that support it.
 # Flink does not yet support 2.13 (and is largely dropping a user-facing dependency on Scala). Hive doesn't need a Scala specification.
 ./gradlew -Prelease -DscalaVersion=2.13 :iceberg-spark:iceberg-spark-3.2_2.13:publishApachePublicationToMavenRepository :iceberg-spark:iceberg-spark-extensions-3.2_2.13:publishApachePublicationToMavenRepository :iceberg-spark:iceberg-spark-runtime-3.2_2.13:publishApachePublicationToMavenRepository
+./gradlew -Prelease -DscalaVersion=2.13 :iceberg-spark:iceberg-spark-3.3_2.13:publishApachePublicationToMavenRepository :iceberg-spark:iceberg-spark-extensions-3.3_2.13:publishApachePublicationToMavenRepository :iceberg-spark:iceberg-spark-runtime-3.3_2.13:publishApachePublicationToMavenRepository
diff --git a/gradle.properties b/gradle.properties
index 1c1df528890e..29fa24eccc9f 100644
--- a/gradle.properties
+++ b/gradle.properties
@@ -19,8 +19,8 @@ systemProp.defaultFlinkVersions=1.15
 systemProp.knownFlinkVersions=1.13,1.14,1.15
 systemProp.defaultHiveVersions=2
 systemProp.knownHiveVersions=2,3
-systemProp.defaultSparkVersions=3.2
-systemProp.knownSparkVersions=2.4,3.0,3.1,3.2
+systemProp.defaultSparkVersions=3.3
+systemProp.knownSparkVersions=2.4,3.0,3.1,3.2,3.3
 systemProp.defaultScalaVersion=2.12
 systemProp.knownScalaVersions=2.12,2.13
 org.gradle.parallel=true
diff --git a/jmh.gradle b/jmh.gradle
index 538fd96af406..cbbd58d0fca2 100644
--- a/jmh.gradle
+++ b/jmh.gradle
@@ -41,6 +41,10 @@ if (sparkVersions.contains("3.2")) {
   jmhProjects.add(project(":iceberg-spark:iceberg-spark-3.2_${scalaVersion}"))
 }

+if (sparkVersions.contains("3.3")) {
+  jmhProjects.add(project(":iceberg-spark:iceberg-spark-3.3_${scalaVersion}"))
+}
+
 jmhProjects.add(project(":iceberg-data"))

 configure(jmhProjects) {
diff --git a/settings.gradle b/settings.gradle
index 82e9c904674b..3f6e2cf036e9 100644
--- a/settings.gradle
+++ b/settings.gradle
@@ -151,6 +151,18 @@ if (sparkVersions.contains("3.2")) {
   project(":iceberg-spark:spark-runtime-3.2_${scalaVersion}").name = "iceberg-spark-runtime-3.2_${scalaVersion}"
 }

+if (sparkVersions.contains("3.3")) {
+  include ":iceberg-spark:spark-3.3_${scalaVersion}"
+  include ":iceberg-spark:spark-extensions-3.3_${scalaVersion}"
+  include ":iceberg-spark:spark-runtime-3.3_${scalaVersion}"
+  project(":iceberg-spark:spark-3.3_${scalaVersion}").projectDir = file('spark/v3.3/spark')
+  project(":iceberg-spark:spark-3.3_${scalaVersion}").name = "iceberg-spark-3.3_${scalaVersion}"
+  project(":iceberg-spark:spark-extensions-3.3_${scalaVersion}").projectDir = file('spark/v3.3/spark-extensions')
+  project(":iceberg-spark:spark-extensions-3.3_${scalaVersion}").name = "iceberg-spark-extensions-3.3_${scalaVersion}"
+  project(":iceberg-spark:spark-runtime-3.3_${scalaVersion}").projectDir = file('spark/v3.3/spark-runtime')
+  project(":iceberg-spark:spark-runtime-3.3_${scalaVersion}").name = "iceberg-spark-runtime-3.3_${scalaVersion}"
+}
+
 // hive 3 depends on hive 2, so always add hive 2 if hive3 is enabled
 if
(hiveVersions.contains("2") || hiveVersions.contains("3")) { include 'mr' diff --git a/spark/build.gradle b/spark/build.gradle index ca74e0570a75..1f0485889e3e 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -34,4 +34,8 @@ if (sparkVersions.contains("3.1")) { if (sparkVersions.contains("3.2")) { apply from: file("$projectDir/v3.2/build.gradle") +} + +if (sparkVersions.contains("3.3")) { + apply from: file("$projectDir/v3.3/build.gradle") } \ No newline at end of file diff --git a/spark/v3.3/build.gradle b/spark/v3.3/build.gradle new file mode 100644 index 000000000000..9d886447775e --- /dev/null +++ b/spark/v3.3/build.gradle @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +String sparkVersion = '3.3.0' +String sparkMajorVersion = '3.3' +String scalaVersion = System.getProperty("scalaVersion") != null ? System.getProperty("scalaVersion") : System.getProperty("defaultScalaVersion") + +def sparkProjects = [ + project(":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}"), + project(":iceberg-spark:iceberg-spark-extensions-${sparkMajorVersion}_${scalaVersion}"), + project(":iceberg-spark:iceberg-spark-runtime-${sparkMajorVersion}_${scalaVersion}"), +] + +project(":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}") { + apply plugin: 'scala' + apply plugin: 'com.github.alisiikh.scalastyle' + + sourceSets { + main { + scala.srcDirs = ['src/main/scala', 'src/main/java'] + java.srcDirs = [] + } + } + + dependencies { + implementation project(path: ':iceberg-bundled-guava', configuration: 'shadow') + api project(':iceberg-api') + implementation project(':iceberg-common') + implementation project(':iceberg-core') + implementation project(':iceberg-data') + implementation project(':iceberg-orc') + implementation project(':iceberg-parquet') + implementation project(':iceberg-arrow') + implementation("org.scala-lang.modules:scala-collection-compat_${scalaVersion}") + + compileOnly "com.google.errorprone:error_prone_annotations" + compileOnly "org.apache.avro:avro" + compileOnly("org.apache.spark:spark-hive_${scalaVersion}:${sparkVersion}") { + exclude group: 'org.apache.avro', module: 'avro' + exclude group: 'org.apache.arrow' + exclude group: 'org.apache.parquet' + // to make sure io.netty.buffer only comes from project(':iceberg-arrow') + exclude group: 'io.netty', module: 'netty-buffer' + exclude group: 'org.roaringbitmap' + } + + implementation("org.apache.parquet:parquet-column") + implementation("org.apache.parquet:parquet-hadoop") + + implementation("org.apache.orc:orc-core::nohive") { + exclude group: 'org.apache.hadoop' + exclude group: 'commons-lang' + // These artifacts are shaded and included in the orc-core fat jar + exclude group: 'com.google.protobuf', module: 
'protobuf-java' + exclude group: 'org.apache.hive', module: 'hive-storage-api' + } + + implementation("org.apache.arrow:arrow-vector") { + exclude group: 'io.netty', module: 'netty-buffer' + exclude group: 'io.netty', module: 'netty-common' + exclude group: 'com.google.code.findbugs', module: 'jsr305' + } + + testImplementation("org.apache.hadoop:hadoop-minicluster") { + exclude group: 'org.apache.avro', module: 'avro' + // to make sure io.netty.buffer only comes from project(':iceberg-arrow') + exclude group: 'io.netty', module: 'netty-buffer' + } + testImplementation project(path: ':iceberg-hive-metastore') + testImplementation project(path: ':iceberg-hive-metastore', configuration: 'testArtifacts') + testImplementation project(path: ':iceberg-api', configuration: 'testArtifacts') + testImplementation project(path: ':iceberg-core', configuration: 'testArtifacts') + testImplementation project(path: ':iceberg-data', configuration: 'testArtifacts') + testImplementation "org.xerial:sqlite-jdbc" + } + + tasks.withType(Test) { + // For vectorized reads + // Allow unsafe memory access to avoid the costly check arrow does to check if index is within bounds + systemProperty("arrow.enable_unsafe_memory_access", "true") + // Disable expensive null check for every get(index) call. + // Iceberg manages nullability checks itself instead of relying on arrow. + systemProperty("arrow.enable_null_check_for_get", "false") + + // Vectorized reads need more memory + maxHeapSize '2560m' + } +} + +project(":iceberg-spark:iceberg-spark-extensions-${sparkMajorVersion}_${scalaVersion}") { + apply plugin: 'java-library' + apply plugin: 'scala' + apply plugin: 'com.github.alisiikh.scalastyle' + apply plugin: 'antlr' + + configurations { + /* + The Gradle Antlr plugin erroneously adds both antlr-build and runtime dependencies to the runtime path. This + bug https://github.com/gradle/gradle/issues/820 exists because older versions of Antlr do not have separate + runtime and implementation dependencies and they do not want to break backwards compatibility. So to only end up with + the runtime dependency on the runtime classpath we remove the dependencies added by the plugin here. Then add + the runtime dependency back to only the runtime configuration manually. 
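
For context, the extensions module configured in this build file is enabled at runtime through Spark session configuration. A minimal sketch, not part of this diff; the catalog name "local" and the warehouse path are illustrative assumptions:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("iceberg-spark-3.3-example")
  // wire in the SQL extensions built by this module (MERGE/UPDATE/DELETE, CALL, ALTER TABLE ... WRITE, etc.)
  .config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
  // an illustrative catalog named "local", backed by a Hadoop warehouse
  .config("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog")
  .config("spark.sql.catalog.local.type", "hadoop")
  .config("spark.sql.catalog.local.warehouse", "/tmp/iceberg-warehouse")
  .getOrCreate()
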
+ */ + implementation { + extendsFrom = extendsFrom.findAll { it != configurations.antlr } + } + } + + dependencies { + implementation("org.scala-lang.modules:scala-collection-compat_${scalaVersion}") + + compileOnly "org.scala-lang:scala-library:${scalaVersion}" + compileOnly project(path: ':iceberg-bundled-guava', configuration: 'shadow') + compileOnly project(':iceberg-api') + compileOnly project(':iceberg-core') + compileOnly project(':iceberg-common') + compileOnly project(":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}") + compileOnly("org.apache.spark:spark-hive_${scalaVersion}:${sparkVersion}") { + exclude group: 'org.apache.avro', module: 'avro' + exclude group: 'org.apache.arrow' + exclude group: 'org.apache.parquet' + // to make sure io.netty.buffer only comes from project(':iceberg-arrow') + exclude group: 'io.netty', module: 'netty-buffer' + exclude group: 'org.roaringbitmap' + } + + testImplementation project(path: ':iceberg-hive-metastore') + testImplementation project(path: ':iceberg-hive-metastore', configuration: 'testArtifacts') + + testImplementation project(path: ':iceberg-api', configuration: 'testArtifacts') + testImplementation project(path: ':iceberg-hive-metastore', configuration: 'testArtifacts') + testImplementation project(path: ":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}", configuration: 'testArtifacts') + + testImplementation "org.apache.avro:avro" + + // Required because we remove antlr plugin dependencies from the compile configuration, see note above + runtimeOnly "org.antlr:antlr4-runtime:4.8" + antlr "org.antlr:antlr4:4.8" + } + + generateGrammarSource { + maxHeapSize = "64m" + arguments += ['-visitor', '-package', 'org.apache.spark.sql.catalyst.parser.extensions'] + } +} + +project(":iceberg-spark:iceberg-spark-runtime-${sparkMajorVersion}_${scalaVersion}") { + apply plugin: 'com.github.johnrengelman.shadow' + + tasks.jar.dependsOn tasks.shadowJar + + sourceSets { + integration { + java.srcDir "$projectDir/src/integration/java" + resources.srcDir "$projectDir/src/integration/resources" + } + } + + configurations { + implementation { + exclude group: 'org.apache.spark' + // included in Spark + exclude group: 'org.slf4j' + exclude group: 'org.apache.commons' + exclude group: 'commons-pool' + exclude group: 'commons-codec' + exclude group: 'org.xerial.snappy' + exclude group: 'javax.xml.bind' + exclude group: 'javax.annotation' + exclude group: 'com.github.luben' + exclude group: 'com.ibm.icu' + exclude group: 'org.glassfish' + exclude group: 'org.abego.treelayout' + exclude group: 'org.antlr' + } + } + + dependencies { + api project(':iceberg-api') + implementation project(":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}") + implementation project(":iceberg-spark:iceberg-spark-extensions-${sparkMajorVersion}_${scalaVersion}") + implementation project(':iceberg-aws') + implementation(project(':iceberg-aliyun')) { + exclude group: 'edu.umd.cs.findbugs', module: 'findbugs' + exclude group: 'org.apache.httpcomponents', module: 'httpclient' + exclude group: 'commons-logging', module: 'commons-logging' + } + implementation project(':iceberg-hive-metastore') + implementation(project(':iceberg-nessie')) { + exclude group: 'com.google.code.findbugs', module: 'jsr305' + } + + integrationImplementation "org.apache.spark:spark-hive_${scalaVersion}:${sparkVersion}" + integrationImplementation 'org.junit.vintage:junit-vintage-engine' + integrationImplementation 'org.slf4j:slf4j-simple' + 
integrationImplementation project(path: ':iceberg-api', configuration: 'testArtifacts') + integrationImplementation project(path: ':iceberg-hive-metastore', configuration: 'testArtifacts') + integrationImplementation project(path: ":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}", configuration: 'testArtifacts') + integrationImplementation project(path: ":iceberg-spark:iceberg-spark-extensions-${sparkMajorVersion}_${scalaVersion}", configuration: 'testArtifacts') + // Not allowed on our classpath, only the runtime jar is allowed + integrationCompileOnly project(":iceberg-spark:iceberg-spark-extensions-${sparkMajorVersion}_${scalaVersion}") + integrationCompileOnly project(":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}") + integrationCompileOnly project(':iceberg-api') + } + + shadowJar { + configurations = [project.configurations.runtimeClasspath] + + zip64 true + + // include the LICENSE and NOTICE files for the shaded Jar + from(projectDir) { + include 'LICENSE' + include 'NOTICE' + } + + // Relocate dependencies to avoid conflicts + relocate 'com.google', 'org.apache.iceberg.shaded.com.google' + relocate 'com.fasterxml', 'org.apache.iceberg.shaded.com.fasterxml' + relocate 'com.github.benmanes', 'org.apache.iceberg.shaded.com.github.benmanes' + relocate 'org.checkerframework', 'org.apache.iceberg.shaded.org.checkerframework' + relocate 'org.apache.avro', 'org.apache.iceberg.shaded.org.apache.avro' + relocate 'avro.shaded', 'org.apache.iceberg.shaded.org.apache.avro.shaded' + relocate 'com.thoughtworks.paranamer', 'org.apache.iceberg.shaded.com.thoughtworks.paranamer' + relocate 'org.apache.parquet', 'org.apache.iceberg.shaded.org.apache.parquet' + relocate 'shaded.parquet', 'org.apache.iceberg.shaded.org.apache.parquet.shaded' + relocate 'org.apache.orc', 'org.apache.iceberg.shaded.org.apache.orc' + relocate 'io.airlift', 'org.apache.iceberg.shaded.io.airlift' + // relocate Arrow and related deps to shade Iceberg specific version + relocate 'io.netty.buffer', 'org.apache.iceberg.shaded.io.netty.buffer' + relocate 'org.apache.arrow', 'org.apache.iceberg.shaded.org.apache.arrow' + relocate 'com.carrotsearch', 'org.apache.iceberg.shaded.com.carrotsearch' + relocate 'org.threeten.extra', 'org.apache.iceberg.shaded.org.threeten.extra' + relocate 'org.roaringbitmap', 'org.apache.iceberg.shaded.org.roaringbitmap' + + classifier null + } + + task integrationTest(type: Test) { + description = "Test Spark3 Runtime Jar against Spark ${sparkMajorVersion}" + group = "verification" + testClassesDirs = sourceSets.integration.output.classesDirs + classpath = sourceSets.integration.runtimeClasspath + files(shadowJar.archiveFile.get().asFile.path) + inputs.file(shadowJar.archiveFile.get().asFile.path) + } + integrationTest.dependsOn shadowJar + + jar { + enabled = false + } +} + diff --git a/spark/v3.3/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4 b/spark/v3.3/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4 new file mode 100644 index 000000000000..d0b228df0a2f --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4 @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * This file is an adaptation of Presto's and Spark's grammar files. + */ + +grammar IcebergSqlExtensions; + +@lexer::members { + /** + * Verify whether current token is a valid decimal token (which contains dot). + * Returns true if the character that follows the token is not a digit or letter or underscore. + * + * For example: + * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. + * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. + * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed + * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' + * which is not a digit or letter or underscore. + */ + public boolean isValidDecimal() { + int nextChar = _input.LA(1); + if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' || + nextChar == '_') { + return false; + } else { + return true; + } + } + + /** + * This method will be called when we see '/*' and try to match it as a bracketed comment. + * If the next character is '+', it should be parsed as hint later, and we cannot match + * it as a bracketed comment. + * + * Returns true if the next character is '+'. + */ + public boolean isHint() { + int nextChar = _input.LA(1); + if (nextChar == '+') { + return true; + } else { + return false; + } + } +} + +singleStatement + : statement EOF + ; + +statement + : CALL multipartIdentifier '(' (callArgument (',' callArgument)*)? ')' #call + | ALTER TABLE multipartIdentifier ADD PARTITION FIELD transform (AS name=identifier)? #addPartitionField + | ALTER TABLE multipartIdentifier DROP PARTITION FIELD transform #dropPartitionField + | ALTER TABLE multipartIdentifier REPLACE PARTITION FIELD transform WITH transform (AS name=identifier)? #replacePartitionField + | ALTER TABLE multipartIdentifier WRITE writeSpec #setWriteDistributionAndOrdering + | ALTER TABLE multipartIdentifier SET IDENTIFIER_KW FIELDS fieldList #setIdentifierFields + | ALTER TABLE multipartIdentifier DROP IDENTIFIER_KW FIELDS fieldList #dropIdentifierFields + ; + +writeSpec + : (writeDistributionSpec | writeOrderingSpec)* + ; + +writeDistributionSpec + : DISTRIBUTED BY PARTITION + ; + +writeOrderingSpec + : LOCALLY? ORDERED BY order + | UNORDERED + ; + +callArgument + : expression #positionalArgument + | identifier '=>' expression #namedArgument + ; + +order + : fields+=orderField (',' fields+=orderField)* + | '(' fields+=orderField (',' fields+=orderField)* ')' + ; + +orderField + : transform direction=(ASC | DESC)? (NULLS nullOrder=(FIRST | LAST))? 
+ ; + +transform + : multipartIdentifier #identityTransform + | transformName=identifier + '(' arguments+=transformArgument (',' arguments+=transformArgument)* ')' #applyTransform + ; + +transformArgument + : multipartIdentifier + | constant + ; + +expression + : constant + | stringMap + ; + +constant + : number #numericLiteral + | booleanValue #booleanLiteral + | STRING+ #stringLiteral + | identifier STRING #typeConstructor + ; + +stringMap + : MAP '(' constant (',' constant)* ')' + ; + +booleanValue + : TRUE | FALSE + ; + +number + : MINUS? EXPONENT_VALUE #exponentLiteral + | MINUS? DECIMAL_VALUE #decimalLiteral + | MINUS? INTEGER_VALUE #integerLiteral + | MINUS? BIGINT_LITERAL #bigIntLiteral + | MINUS? SMALLINT_LITERAL #smallIntLiteral + | MINUS? TINYINT_LITERAL #tinyIntLiteral + | MINUS? DOUBLE_LITERAL #doubleLiteral + | MINUS? FLOAT_LITERAL #floatLiteral + | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral + ; + +multipartIdentifier + : parts+=identifier ('.' parts+=identifier)* + ; + +identifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | nonReserved #unquotedIdentifier + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +fieldList + : fields+=multipartIdentifier (',' fields+=multipartIdentifier)* + ; + +nonReserved + : ADD | ALTER | AS | ASC | BY | CALL | DESC | DROP | FIELD | FIRST | LAST | NULLS | ORDERED | PARTITION | TABLE | WRITE + | DISTRIBUTED | LOCALLY | UNORDERED | REPLACE | WITH | IDENTIFIER_KW | FIELDS | SET + | TRUE | FALSE + | MAP + ; + +ADD: 'ADD'; +ALTER: 'ALTER'; +AS: 'AS'; +ASC: 'ASC'; +BY: 'BY'; +CALL: 'CALL'; +DESC: 'DESC'; +DISTRIBUTED: 'DISTRIBUTED'; +DROP: 'DROP'; +FIELD: 'FIELD'; +FIELDS: 'FIELDS'; +FIRST: 'FIRST'; +LAST: 'LAST'; +LOCALLY: 'LOCALLY'; +NULLS: 'NULLS'; +ORDERED: 'ORDERED'; +PARTITION: 'PARTITION'; +REPLACE: 'REPLACE'; +IDENTIFIER_KW: 'IDENTIFIER'; +SET: 'SET'; +TABLE: 'TABLE'; +UNORDERED: 'UNORDERED'; +WITH: 'WITH'; +WRITE: 'WRITE'; + +TRUE: 'TRUE'; +FALSE: 'FALSE'; + +MAP: 'MAP'; + +PLUS: '+'; +MINUS: '-'; + +STRING + : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | '"' ( ~('"'|'\\') | ('\\' .) )* '"' + ; + +BIGINT_LITERAL + : DIGIT+ 'L' + ; + +SMALLINT_LITERAL + : DIGIT+ 'S' + ; + +TINYINT_LITERAL + : DIGIT+ 'Y' + ; + +INTEGER_VALUE + : DIGIT+ + ; + +EXPONENT_VALUE + : DIGIT+ EXPONENT + | DECIMAL_DIGITS EXPONENT {isValidDecimal()}? + ; + +DECIMAL_VALUE + : DECIMAL_DIGITS {isValidDecimal()}? + ; + +FLOAT_LITERAL + : DIGIT+ EXPONENT? 'F' + | DECIMAL_DIGITS EXPONENT? 'F' {isValidDecimal()}? + ; + +DOUBLE_LITERAL + : DIGIT+ EXPONENT? 'D' + | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}? + ; + +BIGDECIMAL_LITERAL + : DIGIT+ EXPONENT? 'BD' + | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}? + ; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment DECIMAL_DIGITS + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + +fragment EXPONENT + : 'E' [+-]? DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +SIMPLE_COMMENT + : '--' ('\\\n' | ~[\r\n])* '\r'? '\n'? -> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' {!isHint()}? (BRACKETED_COMMENT|.)*? '*/' -> channel(HIDDEN) + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. +// We use this to be able to ignore and recover all the text +// when splitting statements with DelimiterLexer +UNRECOGNIZED + : . 
+ ; diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala new file mode 100644 index 000000000000..4fb9a48a3e00 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.extensions + +import org.apache.spark.sql.SparkSessionExtensions +import org.apache.spark.sql.catalyst.analysis.AlignedRowLevelIcebergCommandCheck +import org.apache.spark.sql.catalyst.analysis.AlignRowLevelCommandAssignments +import org.apache.spark.sql.catalyst.analysis.CheckMergeIntoTableConditions +import org.apache.spark.sql.catalyst.analysis.MergeIntoIcebergTableResolutionCheck +import org.apache.spark.sql.catalyst.analysis.ProcedureArgumentCoercion +import org.apache.spark.sql.catalyst.analysis.ResolveMergeIntoTableReferences +import org.apache.spark.sql.catalyst.analysis.ResolveProcedures +import org.apache.spark.sql.catalyst.analysis.RewriteDeleteFromIcebergTable +import org.apache.spark.sql.catalyst.analysis.RewriteMergeIntoTable +import org.apache.spark.sql.catalyst.analysis.RewriteUpdateTable +import org.apache.spark.sql.catalyst.optimizer.ExtendedReplaceNullWithFalseInPredicate +import org.apache.spark.sql.catalyst.optimizer.ExtendedSimplifyConditionalsInPredicate +import org.apache.spark.sql.catalyst.parser.extensions.IcebergSparkSqlExtensionsParser +import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Strategy +import org.apache.spark.sql.execution.datasources.v2.ExtendedV2Writes +import org.apache.spark.sql.execution.datasources.v2.OptimizeMetadataOnlyDeleteFromIcebergTable +import org.apache.spark.sql.execution.datasources.v2.ReplaceRewrittenRowLevelCommand +import org.apache.spark.sql.execution.datasources.v2.RowLevelCommandScanRelationPushDown +import org.apache.spark.sql.execution.dynamicpruning.RowLevelCommandDynamicPruning + +class IcebergSparkSessionExtensions extends (SparkSessionExtensions => Unit) { + + override def apply(extensions: SparkSessionExtensions): Unit = { + // parser extensions + extensions.injectParser { case (_, parser) => new IcebergSparkSqlExtensionsParser(parser) } + + // analyzer extensions + extensions.injectResolutionRule { spark => ResolveProcedures(spark) } + extensions.injectResolutionRule { spark => ResolveMergeIntoTableReferences(spark) } + extensions.injectResolutionRule { _ => CheckMergeIntoTableConditions } + extensions.injectResolutionRule { _ => ProcedureArgumentCoercion } + extensions.injectResolutionRule { _ => 
AlignRowLevelCommandAssignments } + extensions.injectResolutionRule { _ => RewriteDeleteFromIcebergTable } + extensions.injectResolutionRule { _ => RewriteUpdateTable } + extensions.injectResolutionRule { _ => RewriteMergeIntoTable } + extensions.injectCheckRule { _ => MergeIntoIcebergTableResolutionCheck } + extensions.injectCheckRule { _ => AlignedRowLevelIcebergCommandCheck } + + // optimizer extensions + extensions.injectOptimizerRule { _ => ExtendedSimplifyConditionalsInPredicate } + extensions.injectOptimizerRule { _ => ExtendedReplaceNullWithFalseInPredicate } + // pre-CBO rules run only once and the order of the rules is important + // - metadata deletes have to be attempted immediately after the operator optimization + // - dynamic filters should be added before replacing commands with rewrite plans + // - scans must be planned before building writes + extensions.injectPreCBORule { _ => OptimizeMetadataOnlyDeleteFromIcebergTable } + extensions.injectPreCBORule { _ => RowLevelCommandScanRelationPushDown } + extensions.injectPreCBORule { _ => ExtendedV2Writes } + extensions.injectPreCBORule { spark => RowLevelCommandDynamicPruning(spark) } + extensions.injectPreCBORule { _ => ReplaceRewrittenRowLevelCommand } + + // planner extensions + extensions.injectPlannerStrategy { spark => ExtendedDataSourceV2Strategy(spark) } + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala new file mode 100644 index 000000000000..fb654b646738 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.util.MapData +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.unsafe.types.UTF8String + +/** + * An InternalRow that projects particular columns from another InternalRow without copying + * the underlying data. 
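
A minimal sketch, not part of this diff, of how the projection behaves; the schema and ordinals below are illustrative assumptions:

import org.apache.spark.sql.catalyst.ProjectingInternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

// projected view: field 0 reads underlying ordinal 2, field 1 reads underlying ordinal 0
val schema = StructType(Seq(StructField("name", StringType), StructField("id", IntegerType)))
val projection = ProjectingInternalRow(schema, colOrdinals = Seq(2, 0))

val row = new GenericInternalRow(Array[Any](42, 3.14d, UTF8String.fromString("a")))
projection.project(row)

projection.getUTF8String(0) // UTF8String "a", read from ordinal 2 of the underlying row
projection.getInt(1)        // 42, read from ordinal 0 of the underlying row
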
+ */ +case class ProjectingInternalRow(schema: StructType, colOrdinals: Seq[Int]) extends InternalRow { + assert(schema.size == colOrdinals.size) + + private var row: InternalRow = _ + + override def numFields: Int = colOrdinals.size + + def project(row: InternalRow): Unit = { + this.row = row + } + + override def setNullAt(i: Int): Unit = { + throw new UnsupportedOperationException("Cannot modify InternalRowProjection") + } + + override def update(i: Int, value: Any): Unit = { + throw new UnsupportedOperationException("Cannot modify InternalRowProjection") + } + + override def copy(): InternalRow = { + val newRow = if (row != null) row.copy() else null + val newProjection = ProjectingInternalRow(schema, colOrdinals) + newProjection.project(newRow) + newProjection + } + + override def isNullAt(ordinal: Int): Boolean = { + row.isNullAt(colOrdinals(ordinal)) + } + + override def getBoolean(ordinal: Int): Boolean = { + row.getBoolean(colOrdinals(ordinal)) + } + + override def getByte(ordinal: Int): Byte = { + row.getByte(colOrdinals(ordinal)) + } + + override def getShort(ordinal: Int): Short = { + row.getShort(colOrdinals(ordinal)) + } + + override def getInt(ordinal: Int): Int = { + row.getInt(colOrdinals(ordinal)) + } + + override def getLong(ordinal: Int): Long = { + row.getLong(colOrdinals(ordinal)) + } + + override def getFloat(ordinal: Int): Float = { + row.getFloat(colOrdinals(ordinal)) + } + + override def getDouble(ordinal: Int): Double = { + row.getDouble(colOrdinals(ordinal)) + } + + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = { + row.getDecimal(colOrdinals(ordinal), precision, scale) + } + + override def getUTF8String(ordinal: Int): UTF8String = { + row.getUTF8String(colOrdinals(ordinal)) + } + + override def getBinary(ordinal: Int): Array[Byte] = { + row.getBinary(colOrdinals(ordinal)) + } + + override def getInterval(ordinal: Int): CalendarInterval = { + row.getInterval(colOrdinals(ordinal)) + } + + override def getStruct(ordinal: Int, numFields: Int): InternalRow = { + row.getStruct(colOrdinals(ordinal), numFields) + } + + override def getArray(ordinal: Int): ArrayData = { + row.getArray(colOrdinals(ordinal)) + } + + override def getMap(ordinal: Int): MapData = { + row.getMap(colOrdinals(ordinal)) + } + + override def get(ordinal: Int, dataType: DataType): AnyRef = { + row.get(colOrdinals(ordinal), dataType) + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlignRowLevelCommandAssignments.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlignRowLevelCommandAssignments.scala new file mode 100644 index 000000000000..ad416f2a1c63 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlignRowLevelCommandAssignments.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.AssignmentUtils +import org.apache.spark.sql.catalyst.plans.logical.Assignment +import org.apache.spark.sql.catalyst.plans.logical.DeleteAction +import org.apache.spark.sql.catalyst.plans.logical.InsertAction +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.UpdateAction +import org.apache.spark.sql.catalyst.plans.logical.UpdateIcebergTable +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * A rule that aligns assignments in UPDATE and MERGE operations. + * + * Note that this rule must be run before rewriting row-level commands. + */ +object AlignRowLevelCommandAssignments + extends Rule[LogicalPlan] with AssignmentAlignmentSupport { + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case u: UpdateIcebergTable if u.resolved && !u.aligned => + u.copy(assignments = alignAssignments(u.table, u.assignments)) + + case m: MergeIntoIcebergTable if m.resolved && !m.aligned => + val alignedMatchedActions = m.matchedActions.map { + case u @ UpdateAction(_, assignments) => + u.copy(assignments = alignAssignments(m.targetTable, assignments)) + case d: DeleteAction => + d + case _ => + throw new AnalysisException("Matched actions can only contain UPDATE or DELETE") + } + + val alignedNotMatchedActions = m.notMatchedActions.map { + case i @ InsertAction(_, assignments) => + // check no nested columns are present + val refs = assignments.map(_.key).map(AssignmentUtils.toAssignmentRef) + refs.foreach { ref => + if (ref.size > 1) { + throw new AnalysisException( + "Nested fields are not supported inside INSERT clauses of MERGE operations: " + + s"${ref.mkString("`", "`.`", "`")}") + } + } + + val colNames = refs.map(_.head) + + // check there are no duplicates + val duplicateColNames = colNames.groupBy(identity).collect { + case (name, matchingNames) if matchingNames.size > 1 => name + } + + if (duplicateColNames.nonEmpty) { + throw new AnalysisException( + s"Duplicate column names inside INSERT clause: ${duplicateColNames.mkString(", ")}") + } + + // reorder assignments by the target table column order + val assignmentMap = colNames.zip(assignments).toMap + i.copy(assignments = alignInsertActionAssignments(m.targetTable, assignmentMap)) + + case _ => + throw new AnalysisException("Not matched actions can only contain INSERT") + } + + m.copy(matchedActions = alignedMatchedActions, notMatchedActions = alignedNotMatchedActions) + } + + private def alignInsertActionAssignments( + targetTable: LogicalPlan, + assignmentMap: Map[String, Assignment]): Seq[Assignment] = { + + val resolver = conf.resolver + + targetTable.output.map { targetAttr => + val assignment = assignmentMap + .find { case (name, _) => resolver(name, targetAttr.name) } + .map { case (_, assignment) => assignment } + + if (assignment.isEmpty) { + throw new AnalysisException( + s"Cannot find column 
'${targetAttr.name}' of the target table among " + + s"the INSERT columns: ${assignmentMap.keys.mkString(", ")}. " + + "INSERT clauses must provide values for all columns of the target table.") + } + + val key = assignment.get.key + val value = castIfNeeded(targetAttr, assignment.get.value, resolver) + AssignmentUtils.handleCharVarcharLimits(Assignment(key, value)) + } + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlignedRowLevelIcebergCommandCheck.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlignedRowLevelIcebergCommandCheck.scala new file mode 100644 index 000000000000..d915e4f10949 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlignedRowLevelIcebergCommandCheck.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.UpdateIcebergTable + +object AlignedRowLevelIcebergCommandCheck extends (LogicalPlan => Unit) { + + override def apply(plan: LogicalPlan): Unit = { + plan foreach { + case m: MergeIntoIcebergTable if !m.aligned => + throw new AnalysisException(s"Could not align Iceberg MERGE INTO: $m") + case u: UpdateIcebergTable if !u.aligned => + throw new AnalysisException(s"Could not align Iceberg UPDATE: $u") + case _ => // OK + } + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentAlignmentSupport.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentAlignmentSupport.scala new file mode 100644 index 000000000000..14115bd3cbfe --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentAlignmentSupport.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.AnsiCast +import org.apache.spark.sql.catalyst.expressions.AssignmentUtils._ +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.GetStructField +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.plans.logical.Assignment +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import scala.collection.compat.immutable.ArraySeq +import scala.collection.mutable + +trait AssignmentAlignmentSupport extends CastSupport { + + self: SQLConfHelper => + + private case class ColumnUpdate(ref: Seq[String], expr: Expression) + + /** + * Aligns assignments to match table columns. + *

+ * This method processes and reorders given assignments so that each target column gets + * an expression it should be set to. If a column does not have a matching assignment, + * it will be set to its current value. For example, if one passes a table with columns c1, c2 + * and an assignment c2 = 1, this method will return c1 = c1, c2 = 1. + *

+ * This method also handles updates to nested columns. If there is an assignment to a particular + * nested field, this method will construct a new struct with one field updated + * preserving other fields that have not been modified. For example, if one passes a table with + * columns c1, c2 where c2 is a struct with fields n1 and n2 and an assignment c2.n2 = 1, + * this method will return c1 = c1, c2 = struct(c2.n1, 1). + * + * @param table a target table + * @param assignments assignments to align + * @return aligned assignments that match table columns + */ + protected def alignAssignments( + table: LogicalPlan, + assignments: Seq[Assignment]): Seq[Assignment] = { + + val columnUpdates = assignments.map(a => ColumnUpdate(toAssignmentRef(a.key), a.value)) + val outputExprs = applyUpdates(table.output, columnUpdates) + outputExprs.zip(table.output).map { + case (expr, attr) => handleCharVarcharLimits(Assignment(attr, expr)) + } + } + + private def applyUpdates( + cols: Seq[NamedExpression], + updates: Seq[ColumnUpdate], + resolver: Resolver = conf.resolver, + namePrefix: Seq[String] = Nil): Seq[Expression] = { + + // iterate through columns at the current level and find which column updates match + cols.map { col => + // find matches for this column or any of its children + val prefixMatchedUpdates = updates.filter(a => resolver(a.ref.head, col.name)) + prefixMatchedUpdates match { + // if there is no exact match and no match for children, return the column as is + case updates if updates.isEmpty => + col + + // if there is an exact match, return the assigned expression + case Seq(update) if isExactMatch(update, col, resolver) => + castIfNeeded(col, update.expr, resolver) + + // if there are matches only for children + case updates if !hasExactMatch(updates, col, resolver) => + col.dataType match { + case StructType(fields) => + // build field expressions + val fieldExprs = fields.zipWithIndex.map { case (field, ordinal) => + Alias(GetStructField(col, ordinal, Some(field.name)), field.name)() + } + + // recursively apply this method on nested fields + val newUpdates = updates.map(u => u.copy(ref = u.ref.tail)) + val updatedFieldExprs = applyUpdates( + ArraySeq.unsafeWrapArray(fieldExprs), + newUpdates, + resolver, + namePrefix :+ col.name) + + // construct a new struct with updated field expressions + toNamedStruct(ArraySeq.unsafeWrapArray(fields), updatedFieldExprs) + + case otherType => + val colName = (namePrefix :+ col.name).mkString(".") + throw new AnalysisException( + "Updating nested fields is only supported for StructType " + + s"but $colName is of type $otherType" + ) + } + + // if there are conflicting updates, throw an exception + // there are two illegal scenarios: + // - multiple updates to the same column + // - updates to a top-level struct and its nested fields (e.g., a.b and a.b.c) + case updates if hasExactMatch(updates, col, resolver) => + val conflictingCols = updates.map(u => (namePrefix ++ u.ref).mkString(".")) + throw new AnalysisException( + "Updates are in conflict for these columns: " + + conflictingCols.distinct.mkString(", ")) + } + } + } + + private def toNamedStruct(fields: Seq[StructField], fieldExprs: Seq[Expression]): Expression = { + val namedStructExprs = fields.zip(fieldExprs).flatMap { case (field, expr) => + Seq(Literal(field.name), expr) + } + CreateNamedStruct(namedStructExprs) + } + + private def hasExactMatch( + updates: Seq[ColumnUpdate], + col: NamedExpression, + resolver: Resolver): Boolean = { + + updates.exists(assignment => 
isExactMatch(assignment, col, resolver)) + } + + private def isExactMatch( + update: ColumnUpdate, + col: NamedExpression, + resolver: Resolver): Boolean = { + + update.ref match { + case Seq(namePart) if resolver(namePart, col.name) => true + case _ => false + } + } + + protected def castIfNeeded( + tableAttr: NamedExpression, + expr: Expression, + resolver: Resolver): Expression = { + + val storeAssignmentPolicy = conf.storeAssignmentPolicy + + // run the type check and catch type errors + storeAssignmentPolicy match { + case StoreAssignmentPolicy.STRICT | StoreAssignmentPolicy.ANSI => + if (expr.nullable && !tableAttr.nullable) { + throw new AnalysisException( + s"Cannot write nullable values to non-null column '${tableAttr.name}'") + } + + // use byName = true to catch cases when struct field names don't match + // e.g. a struct with fields (a, b) is assigned as a struct with fields (a, c) or (b, a) + val errors = new mutable.ArrayBuffer[String]() + val canWrite = DataType.canWrite( + expr.dataType, tableAttr.dataType, byName = true, resolver, tableAttr.name, + storeAssignmentPolicy, err => errors += err) + + if (!canWrite) { + throw new AnalysisException( + s"Cannot write incompatible data:\n- ${errors.mkString("\n- ")}") + } + + case _ => // OK + } + + storeAssignmentPolicy match { + case _ if tableAttr.dataType.sameType(expr.dataType) => + expr + case StoreAssignmentPolicy.ANSI => + AnsiCast(expr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)) + case _ => + Cast(expr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)) + } + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckMergeIntoTableConditions.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckMergeIntoTableConditions.scala new file mode 100644 index 000000000000..70f6694af60b --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckMergeIntoTableConditions.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
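
To make the alignment described above concrete, a hypothetical example; the table name and schema are assumptions, not from this diff:

import org.apache.spark.sql.SparkSession

// assumes a session with the Iceberg extensions enabled and a table
// demo.db.t with columns (c1 INT, c2 STRUCT<n1: INT, n2: INT>)
val spark = SparkSession.builder().getOrCreate()

// the single nested assignment below is aligned into c1 = c1, c2 = struct(c2.n1, 1)
// before the UPDATE is rewritten; castIfNeeded would wrap the literal in a cast
// if its type did not match c2.n2
spark.sql("UPDATE demo.db.t SET c2.n2 = 1 WHERE c1 > 0")
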
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.logical.DeleteAction +import org.apache.spark.sql.catalyst.plans.logical.InsertAction +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.UpdateAction +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * A rule that checks MERGE operations contain only supported conditions. + * + * Note that this rule must be run in the resolution batch before Spark executes CheckAnalysis. + * Otherwise, CheckAnalysis will throw a less descriptive error. + */ +object CheckMergeIntoTableConditions extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case m: MergeIntoIcebergTable if m.resolved => + checkMergeIntoCondition("SEARCH", m.mergeCondition) + + val actions = m.matchedActions ++ m.notMatchedActions + actions.foreach { + case DeleteAction(Some(cond)) => checkMergeIntoCondition("DELETE", cond) + case UpdateAction(Some(cond), _) => checkMergeIntoCondition("UPDATE", cond) + case InsertAction(Some(cond), _) => checkMergeIntoCondition("INSERT", cond) + case _ => // OK + } + + m + } + + private def checkMergeIntoCondition(condName: String, cond: Expression): Unit = { + if (!cond.deterministic) { + throw new AnalysisException( + s"Non-deterministic functions are not supported in $condName conditions of " + + s"MERGE operations: ${cond.sql}") + } + + if (SubqueryExpression.hasSubquery(cond)) { + throw new AnalysisException( + s"Subqueries are not supported in conditions of MERGE operations. " + + s"Found a subquery in the $condName condition: ${cond.sql}") + } + + if (cond.find(_.isInstanceOf[AggregateExpression]).isDefined) { + throw new AnalysisException( + s"Agg functions are not supported in $condName conditions of MERGE operations: " + {cond.sql}) + } + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/MergeIntoIcebergTableResolutionCheck.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/MergeIntoIcebergTableResolutionCheck.scala new file mode 100644 index 000000000000..b3a9bda280d2 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/MergeIntoIcebergTableResolutionCheck.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
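
A hedged sketch of the kind of statement the condition check above rejects; the table names demo.db.target and demo.db.source are illustrative assumptions:

import org.apache.spark.sql.{AnalysisException, SparkSession}

val spark = SparkSession.builder().getOrCreate()

try {
  spark.sql(
    """MERGE INTO demo.db.target t
      |USING demo.db.source s
      |ON t.id = s.id AND rand() > 0.5
      |WHEN MATCHED THEN UPDATE SET t.value = s.value""".stripMargin)
} catch {
  // non-deterministic, subquery, and aggregate conditions are rejected during analysis
  case e: AnalysisException => println(e.getMessage)
}
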
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.UnresolvedMergeIntoIcebergTable + +object MergeIntoIcebergTableResolutionCheck extends (LogicalPlan => Unit) { + + override def apply(plan: LogicalPlan): Unit = { + plan foreach { + case m: UnresolvedMergeIntoIcebergTable => + throw new AnalysisException(s"Could not resolve Iceberg MERGE INTO statement: $m") + case _ => // OK + } + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ProcedureArgumentCoercion.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ProcedureArgumentCoercion.scala new file mode 100644 index 000000000000..7f0ca8fadded --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ProcedureArgumentCoercion.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.plans.logical.Call +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +object ProcedureArgumentCoercion extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case c @ Call(procedure, args) if c.resolved => + val params = procedure.parameters + + val newArgs = args.zipWithIndex.map { case (arg, index) => + val param = params(index) + val paramType = param.dataType + val argType = arg.dataType + + if (paramType != argType && !Cast.canUpCast(argType, paramType)) { + throw new AnalysisException( + s"Wrong arg type for ${param.name}: cannot cast $argType to $paramType") + } + + if (paramType != argType) { + Cast(arg, paramType) + } else { + arg + } + } + + if (newArgs != args) { + c.copy(args = newArgs) + } else { + c + } + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoTableReferences.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoTableReferences.scala new file mode 100644 index 000000000000..63ebdef95730 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoTableReferences.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.logical.Assignment +import org.apache.spark.sql.catalyst.plans.logical.DeleteAction +import org.apache.spark.sql.catalyst.plans.logical.InsertAction +import org.apache.spark.sql.catalyst.plans.logical.InsertStarAction +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.UnresolvedMergeIntoIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.UpdateAction +import org.apache.spark.sql.catalyst.plans.logical.UpdateStarAction +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * A resolution rule similar to ResolveReferences in Spark but handles Iceberg MERGE operations. + */ +case class ResolveMergeIntoTableReferences(spark: SparkSession) extends Rule[LogicalPlan] { + + private lazy val analyzer: Analyzer = spark.sessionState.analyzer + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { + case m @ UnresolvedMergeIntoIcebergTable(targetTable, sourceTable, context) + if targetTable.resolved && sourceTable.resolved && m.duplicateResolved => + + val resolvedMatchedActions = context.matchedActions.map { + case DeleteAction(cond) => + val resolvedCond = cond.map(resolveCond("DELETE", _, m)) + DeleteAction(resolvedCond) + + case UpdateAction(cond, assignments) => + val resolvedCond = cond.map(resolveCond("UPDATE", _, m)) + // the update action can access columns from both target and source tables + val resolvedAssignments = resolveAssignments(assignments, m, resolveValuesWithSourceOnly = false) + UpdateAction(resolvedCond, resolvedAssignments) + + case UpdateStarAction(updateCondition) => + val resolvedUpdateCondition = updateCondition.map(resolveCond("UPDATE", _, m)) + val assignments = targetTable.output.map { attr => + Assignment(attr, UnresolvedAttribute(Seq(attr.name))) + } + // for UPDATE *, the value must be from the source table + val resolvedAssignments = resolveAssignments(assignments, m, resolveValuesWithSourceOnly = true) + UpdateAction(resolvedUpdateCondition, resolvedAssignments) + + case _ => + throw new AnalysisException("Matched actions can only contain UPDATE or DELETE") + } + + val resolvedNotMatchedActions = context.notMatchedActions.map { + case InsertAction(cond, assignments) => + // the insert action is used when not matched, so its condition and value can only + // access columns from the source table + val resolvedCond = cond.map(resolveCond("INSERT", _, Project(Nil, m.sourceTable))) + val 
resolvedAssignments = resolveAssignments(assignments, m, resolveValuesWithSourceOnly = true) + InsertAction(resolvedCond, resolvedAssignments) + + case InsertStarAction(cond) => + // the insert action is used when not matched, so its condition and value can only + // access columns from the source table + val resolvedCond = cond.map(resolveCond("INSERT", _, Project(Nil, m.sourceTable))) + val assignments = targetTable.output.map { attr => + Assignment(attr, UnresolvedAttribute(Seq(attr.name))) + } + val resolvedAssignments = resolveAssignments(assignments, m, resolveValuesWithSourceOnly = true) + InsertAction(resolvedCond, resolvedAssignments) + + case _ => + throw new AnalysisException("Not matched actions can only contain INSERT") + } + + val resolvedMergeCondition = resolveCond("SEARCH", context.mergeCondition, m) + + MergeIntoIcebergTable( + targetTable, + sourceTable, + mergeCondition = resolvedMergeCondition, + matchedActions = resolvedMatchedActions, + notMatchedActions = resolvedNotMatchedActions) + } + + private def resolveCond(condName: String, cond: Expression, plan: LogicalPlan): Expression = { + val resolvedCond = analyzer.resolveExpressionByPlanChildren(cond, plan) + + val unresolvedAttrs = resolvedCond.references.filter(!_.resolved) + if (unresolvedAttrs.nonEmpty) { + throw new AnalysisException( + s"Cannot resolve ${unresolvedAttrs.map(_.sql).mkString("[", ",", "]")} in $condName condition " + + s"of MERGE operation given input columns: ${plan.inputSet.toSeq.map(_.sql).mkString("[", ",", "]")}") + } + + resolvedCond + } + + // copied from ResolveReferences in Spark + private def resolveAssignments( + assignments: Seq[Assignment], + mergeInto: UnresolvedMergeIntoIcebergTable, + resolveValuesWithSourceOnly: Boolean): Seq[Assignment] = { + assignments.map { assign => + val resolvedKey = assign.key match { + case c if !c.resolved => + resolveMergeExprOrFail(c, Project(Nil, mergeInto.targetTable)) + case o => o + } + val resolvedValue = assign.value match { + // The update values may contain target and/or source references. + case c if !c.resolved => + if (resolveValuesWithSourceOnly) { + resolveMergeExprOrFail(c, Project(Nil, mergeInto.sourceTable)) + } else { + resolveMergeExprOrFail(c, mergeInto) + } + case o => o + } + Assignment(resolvedKey, resolvedValue) + } + } + + // copied from ResolveReferences in Spark + private def resolveMergeExprOrFail(e: Expression, p: LogicalPlan): Expression = { + val resolved = analyzer.resolveExpressionByPlanChildren(e, p) + resolved.references.filter(!_.resolved).foreach { a => + // Note: This will throw error only on unresolved attribute issues, + // not other resolution errors like mismatched data types. + val cols = p.inputSet.toSeq.map(_.sql).mkString(", ") + a.failAnalysis(s"cannot resolve ${a.sql} in MERGE command given columns [$cols]") + } + resolved + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveProcedures.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveProcedures.scala new file mode 100644 index 000000000000..ee69b5e344f0 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveProcedures.scala @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import java.util.Locale +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans.logical.Call +import org.apache.spark.sql.catalyst.plans.logical.CallArgument +import org.apache.spark.sql.catalyst.plans.logical.CallStatement +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.NamedArgument +import org.apache.spark.sql.catalyst.plans.logical.PositionalArgument +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.connector.catalog.CatalogPlugin +import org.apache.spark.sql.connector.catalog.LookupCatalog +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureCatalog +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter +import scala.collection.Seq + +case class ResolveProcedures(spark: SparkSession) extends Rule[LogicalPlan] with LookupCatalog { + + protected lazy val catalogManager: CatalogManager = spark.sessionState.catalogManager + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case CallStatement(CatalogAndIdentifier(catalog, ident), args) => + val procedure = catalog.asProcedureCatalog.loadProcedure(ident) + + val params = procedure.parameters + val normalizedParams = normalizeParams(params) + validateParams(normalizedParams) + + val normalizedArgs = normalizeArgs(args) + Call(procedure, args = buildArgExprs(normalizedParams, normalizedArgs).toSeq) + } + + private def validateParams(params: Seq[ProcedureParameter]): Unit = { + // should not be any duplicate param names + val duplicateParamNames = params.groupBy(_.name).collect { + case (name, matchingParams) if matchingParams.length > 1 => name + } + + if (duplicateParamNames.nonEmpty) { + throw new AnalysisException(s"Duplicate parameter names: ${duplicateParamNames.mkString("[", ",", "]")}") + } + + // optional params should be at the end + params.sliding(2).foreach { + case Seq(previousParam, currentParam) if !previousParam.required && currentParam.required => + throw new AnalysisException( + s"Optional parameters must be after required ones but $currentParam is after $previousParam") + case _ => + } + } + + private def buildArgExprs( + params: Seq[ProcedureParameter], + args: Seq[CallArgument]): Seq[Expression] = { + + // build a map of declared parameter names to their positions + val nameToPositionMap = params.map(_.name).zipWithIndex.toMap + + // build a map of parameter names to args + val nameToArgMap = buildNameToArgMap(params, args, nameToPositionMap) + + // verify all required parameters are provided + val missingParamNames = params.filter(_.required).collect { + case param if 
!nameToArgMap.contains(param.name) => param.name + } + + if (missingParamNames.nonEmpty) { + throw new AnalysisException(s"Missing required parameters: ${missingParamNames.mkString("[", ",", "]")}") + } + + val argExprs = new Array[Expression](params.size) + + nameToArgMap.foreach { case (name, arg) => + val position = nameToPositionMap(name) + argExprs(position) = arg.expr + } + + // assign nulls to optional params that were not set + params.foreach { + case p if !p.required && !nameToArgMap.contains(p.name) => + val position = nameToPositionMap(p.name) + argExprs(position) = Literal.create(null, p.dataType) + case _ => + } + + argExprs + } + + private def buildNameToArgMap( + params: Seq[ProcedureParameter], + args: Seq[CallArgument], + nameToPositionMap: Map[String, Int]): Map[String, CallArgument] = { + + val containsNamedArg = args.exists(_.isInstanceOf[NamedArgument]) + val containsPositionalArg = args.exists(_.isInstanceOf[PositionalArgument]) + + if (containsNamedArg && containsPositionalArg) { + throw new AnalysisException("Named and positional arguments cannot be mixed") + } + + if (containsNamedArg) { + buildNameToArgMapUsingNames(args, nameToPositionMap) + } else { + buildNameToArgMapUsingPositions(args, params) + } + } + + private def buildNameToArgMapUsingNames( + args: Seq[CallArgument], + nameToPositionMap: Map[String, Int]): Map[String, CallArgument] = { + + val namedArgs = args.asInstanceOf[Seq[NamedArgument]] + + val validationErrors = namedArgs.groupBy(_.name).collect { + case (name, matchingArgs) if matchingArgs.size > 1 => s"Duplicate procedure argument: $name" + case (name, _) if !nameToPositionMap.contains(name) => s"Unknown argument: $name" + } + + if (validationErrors.nonEmpty) { + throw new AnalysisException(s"Could not build name to arg map: ${validationErrors.mkString(", ")}") + } + + namedArgs.map(arg => arg.name -> arg).toMap + } + + private def buildNameToArgMapUsingPositions( + args: Seq[CallArgument], + params: Seq[ProcedureParameter]): Map[String, CallArgument] = { + + if (args.size > params.size) { + throw new AnalysisException("Too many arguments for procedure") + } + + args.zipWithIndex.map { case (arg, position) => + val param = params(position) + param.name -> arg + }.toMap + } + + private def normalizeParams(params: Seq[ProcedureParameter]): Seq[ProcedureParameter] = { + params.map { + case param if param.required => + val normalizedName = param.name.toLowerCase(Locale.ROOT) + ProcedureParameter.required(normalizedName, param.dataType) + case param => + val normalizedName = param.name.toLowerCase(Locale.ROOT) + ProcedureParameter.optional(normalizedName, param.dataType) + } + } + + private def normalizeArgs(args: Seq[CallArgument]): Seq[CallArgument] = { + args.map { + case a @ NamedArgument(name, _) => a.copy(name = name.toLowerCase(Locale.ROOT)) + case other => other + } + } + + implicit class CatalogHelper(plugin: CatalogPlugin) { + def asProcedureCatalog: ProcedureCatalog = plugin match { + case procedureCatalog: ProcedureCatalog => + procedureCatalog + case _ => + throw new AnalysisException(s"Cannot use catalog ${plugin.name}: not a ProcedureCatalog") + } + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromIcebergTable.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromIcebergTable.scala new file mode 100644 index 000000000000..501ef8200cf5 --- /dev/null +++ 
b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromIcebergTable.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.EqualNullSafe +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.Not +import org.apache.spark.sql.catalyst.plans.logical.DeleteFromIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.ReplaceIcebergData +import org.apache.spark.sql.catalyst.plans.logical.WriteDelta +import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ +import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations +import org.apache.spark.sql.connector.iceberg.write.SupportsDelta +import org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE +import org.apache.spark.sql.connector.write.RowLevelOperationTable +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Assigns a rewrite plan for v2 tables that support rewriting data to handle DELETE statements. + * + * If a table implements SupportsDelete and SupportsRowLevelOperations, this rule assigns a rewrite + * plan but the optimizer will check whether this particular DELETE statement can be handled + * by simply passing delete filters to the connector. If yes, the optimizer will then discard + * the rewrite plan. 
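+ *
+ * As a rough, illustrative sketch (the table name below is hypothetical): a statement such as
+ * {{{ DELETE FROM db.sample WHERE id > 10 }}} is assigned a WriteDelta rewrite plan when the
+ * operation implements SupportsDelta, and a ReplaceIcebergData plan otherwise; the replace-data
+ * plan keeps only the rows that do NOT match the condition, since such sources can only replace
+ * whole groups of rows (e.g. files or partitions) rather than individual rows.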
+ */ +object RewriteDeleteFromIcebergTable extends RewriteRowLevelIcebergCommand { + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case d @ DeleteFromIcebergTable(aliasedTable, Some(cond), None) if d.resolved => + EliminateSubqueryAliases(aliasedTable) match { + case r @ DataSourceV2Relation(tbl: SupportsRowLevelOperations, _, _, _, _) => + val table = buildOperationTable(tbl, DELETE, CaseInsensitiveStringMap.empty()) + val rewritePlan = table.operation match { + case _: SupportsDelta => + buildWriteDeltaPlan(r, table, cond) + case _ => + buildReplaceDataPlan(r, table, cond) + } + // keep the original relation in DELETE to try deleting using filters + DeleteFromIcebergTable(r, Some(cond), Some(rewritePlan)) + + case p => + throw new AnalysisException(s"$p is not an Iceberg table") + } + } + + // build a rewrite plan for sources that support replacing groups of data (e.g. files, partitions) + private def buildReplaceDataPlan( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + cond: Expression): ReplaceIcebergData = { + + // resolve all needed attrs (e.g. metadata attrs for grouping data on write) + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) + + // construct a read relation and include all required metadata columns + val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) + + // construct a plan that contains unmatched rows in matched groups that must be carried over + // such rows do not match the condition but have to be copied over as the source can replace + // only groups of rows + val remainingRowsFilter = Not(EqualNullSafe(cond, Literal.TrueLiteral)) + val remainingRowsPlan = Filter(remainingRowsFilter, readRelation) + + // build a plan to replace read groups in the table + val writeRelation = relation.copy(table = operationTable) + ReplaceIcebergData(writeRelation, remainingRowsPlan, relation) + } + + // build a rewrite plan for sources that support row deltas + private def buildWriteDeltaPlan( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + cond: Expression): WriteDelta = { + + // resolve all needed attrs (e.g. 
row ID and any required metadata attrs) + val rowIdAttrs = resolveRowIdAttrs(relation, operationTable.operation) + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) + + // construct a read relation and include all required metadata columns + val readRelation = buildRelationWithAttrs(relation, operationTable, rowIdAttrs ++ metadataAttrs) + + // construct a plan that only contains records to delete + val deletedRowsPlan = Filter(cond, readRelation) + val operationType = Alias(Literal(DELETE_OPERATION), OPERATION_COLUMN)() + val requiredWriteAttrs = dedupAttrs(rowIdAttrs ++ metadataAttrs) + val project = Project(operationType +: requiredWriteAttrs, deletedRowsPlan) + + // build a plan to write deletes to the table + val writeRelation = relation.copy(table = operationTable) + val projections = buildWriteDeltaProjections(project, Nil, rowIdAttrs, metadataAttrs) + WriteDelta(writeRelation, project, relation, projections) + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala new file mode 100644 index 000000000000..7ea5c42c9d1c --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala @@ -0,0 +1,387 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.ExtendedV2ExpressionUtils +import org.apache.spark.sql.catalyst.expressions.IsNotNull +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral +import org.apache.spark.sql.catalyst.expressions.MonotonicallyIncreasingID +import org.apache.spark.sql.catalyst.plans.FullOuter +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.LeftAnti +import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.catalyst.plans.RightOuter +import org.apache.spark.sql.catalyst.plans.logical.AppendData +import org.apache.spark.sql.catalyst.plans.logical.DeleteAction +import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.catalyst.plans.logical.HintInfo +import org.apache.spark.sql.catalyst.plans.logical.InsertAction +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.plans.logical.JoinHint +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeAction +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.MergeRows +import org.apache.spark.sql.catalyst.plans.logical.NO_BROADCAST_HASH +import org.apache.spark.sql.catalyst.plans.logical.NoStatsUnaryNode +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.ReplaceIcebergData +import org.apache.spark.sql.catalyst.plans.logical.UpdateAction +import org.apache.spark.sql.catalyst.plans.logical.WriteDelta +import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ +import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.iceberg.write.SupportsDelta +import org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE +import org.apache.spark.sql.connector.write.RowLevelOperationTable +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Assigns a rewrite plan for v2 tables that support rewriting data to handle MERGE statements. + * + * This rule assumes the commands have been fully resolved and all assignments have been aligned. + * That's why it must be run after AlignRowLevelCommandAssignments. 
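+ *
+ * As a rough, illustrative sketch (table and column names are hypothetical): a statement such as
+ * {{{ MERGE INTO db.target t USING db.source s ON t.id = s.id WHEN NOT MATCHED THEN INSERT * }}}
+ * has no MATCHED actions, so it is planned as an anti-join of the source against the target
+ * followed by a plain AppendData; statements that also contain MATCHED actions are assigned a
+ * WriteDelta rewrite plan if the operation supports deltas and a ReplaceIcebergData plan otherwise.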
+ */ +object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand { + + private final val ROW_FROM_SOURCE = "__row_from_source" + private final val ROW_FROM_TARGET = "__row_from_target" + private final val ROW_ID = "__row_id" + + private final val ROW_FROM_SOURCE_REF = FieldReference(ROW_FROM_SOURCE) + private final val ROW_FROM_TARGET_REF = FieldReference(ROW_FROM_TARGET) + private final val ROW_ID_REF = FieldReference(ROW_ID) + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case m @ MergeIntoIcebergTable(aliasedTable, source, cond, matchedActions, notMatchedActions, None) + if m.resolved && m.aligned && matchedActions.isEmpty && notMatchedActions.size == 1 => + + EliminateSubqueryAliases(aliasedTable) match { + case r: DataSourceV2Relation => + // NOT MATCHED conditions may only refer to columns in source so they can be pushed down + val insertAction = notMatchedActions.head.asInstanceOf[InsertAction] + val filteredSource = insertAction.condition match { + case Some(insertCond) => Filter(insertCond, source) + case None => source + } + + // when there are no MATCHED actions, use a left anti join to remove any matching rows + // and switch to using a regular append instead of a row-level merge + // only unmatched source rows that match the condition are appended to the table + val joinPlan = Join(filteredSource, r, LeftAnti, Some(cond), JoinHint.NONE) + + val outputExprs = insertAction.assignments.map(_.value) + val outputColNames = r.output.map(_.name) + val outputCols = outputExprs.zip(outputColNames).map { case (expr, name) => + Alias(expr, name)() + } + val project = Project(outputCols, joinPlan) + + AppendData.byPosition(r, project) + + case p => + throw new AnalysisException(s"$p is not an Iceberg table") + } + + case m @ MergeIntoIcebergTable(aliasedTable, source, cond, matchedActions, notMatchedActions, None) + if m.resolved && m.aligned && matchedActions.isEmpty => + + EliminateSubqueryAliases(aliasedTable) match { + case r: DataSourceV2Relation => + // when there are no MATCHED actions, use a left anti join to remove any matching rows + // and switch to using a regular append instead of a row-level merge + // only unmatched source rows that match action conditions are appended to the table + val joinPlan = Join(source, r, LeftAnti, Some(cond), JoinHint.NONE) + + val notMatchedConditions = notMatchedActions.map(actionCondition) + val notMatchedOutputs = notMatchedActions.map(actionOutput(_, Nil)) + + // merge rows as there are multiple not matched actions + val mergeRows = MergeRows( + isSourceRowPresent = TrueLiteral, + isTargetRowPresent = FalseLiteral, + matchedConditions = Nil, + matchedOutputs = Nil, + notMatchedConditions = notMatchedConditions, + notMatchedOutputs = notMatchedOutputs, + targetOutput = Nil, + rowIdAttrs = Nil, + performCardinalityCheck = false, + emitNotMatchedTargetRows = false, + output = buildMergeRowsOutput(Nil, notMatchedOutputs, r.output), + joinPlan) + + AppendData.byPosition(r, mergeRows) + + case p => + throw new AnalysisException(s"$p is not an Iceberg table") + } + + case m @ MergeIntoIcebergTable(aliasedTable, source, cond, matchedActions, notMatchedActions, None) + if m.resolved && m.aligned => + + EliminateSubqueryAliases(aliasedTable) match { + case r @ DataSourceV2Relation(tbl: SupportsRowLevelOperations, _, _, _, _) => + val table = buildOperationTable(tbl, MERGE, CaseInsensitiveStringMap.empty()) + val rewritePlan = table.operation match { + case _: SupportsDelta => + buildWriteDeltaPlan(r, table, 
source, cond, matchedActions, notMatchedActions) + case _ => + buildReplaceDataPlan(r, table, source, cond, matchedActions, notMatchedActions) + } + + m.copy(rewritePlan = Some(rewritePlan)) + + case p => + throw new AnalysisException(s"$p is not an Iceberg table") + } + } + + // build a rewrite plan for sources that support replacing groups of data (e.g. files, partitions) + private def buildReplaceDataPlan( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + source: LogicalPlan, + cond: Expression, + matchedActions: Seq[MergeAction], + notMatchedActions: Seq[MergeAction]): ReplaceIcebergData = { + + // resolve all needed attrs (e.g. metadata attrs for grouping data on write) + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) + + // construct a scan relation and include all required metadata columns + val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) + val readAttrs = readRelation.output + + // project an extra column to check if a target row exists after the join + // project a synthetic row ID to perform the cardinality check + val rowFromTarget = Alias(TrueLiteral, ROW_FROM_TARGET)() + val rowId = Alias(MonotonicallyIncreasingID(), ROW_ID)() + val targetTableProjExprs = readAttrs ++ Seq(rowFromTarget, rowId) + val targetTableProj = Project(targetTableProjExprs, readRelation) + + // project an extra column to check if a source row exists after the join + val rowFromSource = Alias(TrueLiteral, ROW_FROM_SOURCE)() + val sourceTableProjExprs = source.output :+ rowFromSource + val sourceTableProj = Project(sourceTableProjExprs, source) + + // use left outer join if there is no NOT MATCHED action, unmatched source rows can be discarded + // use full outer join in all other cases, unmatched source rows may be needed + // disable broadcasts for the target table to perform the cardinality check + val joinType = if (notMatchedActions.isEmpty) LeftOuter else FullOuter + val joinHint = JoinHint(leftHint = Some(HintInfo(Some(NO_BROADCAST_HASH))), rightHint = None) + val joinPlan = Join(NoStatsUnaryNode(targetTableProj), sourceTableProj, joinType, Some(cond), joinHint) + + // add an extra matched action to output the original row if none of the actual actions matched + // this is needed to keep target rows that should be copied over + val matchedConditions = matchedActions.map(actionCondition) :+ TrueLiteral + val matchedOutputs = matchedActions.map(actionOutput(_, metadataAttrs)) :+ readAttrs + + val notMatchedConditions = notMatchedActions.map(actionCondition) + val notMatchedOutputs = notMatchedActions.map(actionOutput(_, metadataAttrs)) + + val rowIdAttr = resolveAttrRef(ROW_ID_REF, joinPlan) + val rowFromSourceAttr = resolveAttrRef(ROW_FROM_SOURCE_REF, joinPlan) + val rowFromTargetAttr = resolveAttrRef(ROW_FROM_TARGET_REF, joinPlan) + + val mergeRows = MergeRows( + isSourceRowPresent = IsNotNull(rowFromSourceAttr), + isTargetRowPresent = if (notMatchedActions.isEmpty) TrueLiteral else IsNotNull(rowFromTargetAttr), + matchedConditions = matchedConditions, + matchedOutputs = matchedOutputs, + notMatchedConditions = notMatchedConditions, + notMatchedOutputs = notMatchedOutputs, + targetOutput = readAttrs, + rowIdAttrs = Seq(rowIdAttr), + performCardinalityCheck = isCardinalityCheckNeeded(matchedActions), + emitNotMatchedTargetRows = true, + output = buildMergeRowsOutput(matchedOutputs, notMatchedOutputs, readAttrs), + joinPlan) + + // build a plan to replace read groups in the table + val writeRelation = 
relation.copy(table = operationTable) + ReplaceIcebergData(writeRelation, mergeRows, relation) + } + + // build a rewrite plan for sources that support row deltas + private def buildWriteDeltaPlan( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + source: LogicalPlan, + cond: Expression, + matchedActions: Seq[MergeAction], + notMatchedActions: Seq[MergeAction]): WriteDelta = { + + // resolve all needed attrs (e.g. row ID and any required metadata attrs) + val rowAttrs = relation.output + val rowIdAttrs = resolveRowIdAttrs(relation, operationTable.operation) + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) + + // construct a scan relation and include all required metadata columns + val readRelation = buildRelationWithAttrs(relation, operationTable, rowIdAttrs ++ metadataAttrs) + val readAttrs = readRelation.output + + // project an extra column to check if a target row exists after the join + val targetTableProjExprs = readAttrs :+ Alias(TrueLiteral, ROW_FROM_TARGET)() + val targetTableProj = Project(targetTableProjExprs, readRelation) + + // project an extra column to check if a source row exists after the join + val sourceTableProjExprs = source.output :+ Alias(TrueLiteral, ROW_FROM_SOURCE)() + val sourceTableProj = Project(sourceTableProjExprs, source) + + // use inner join if there is no NOT MATCHED action, unmatched source rows can be discarded + // use right outer join in all other cases, unmatched source rows may be needed + // also disable broadcasts for the target table to perform the cardinality check + val joinType = if (notMatchedActions.isEmpty) Inner else RightOuter + val joinHint = JoinHint(leftHint = Some(HintInfo(Some(NO_BROADCAST_HASH))), rightHint = None) + val joinPlan = Join(NoStatsUnaryNode(targetTableProj), sourceTableProj, joinType, Some(cond), joinHint) + + val deleteRowValues = buildDeltaDeleteRowValues(rowAttrs, rowIdAttrs) + val metadataReadAttrs = readAttrs.filterNot(relation.outputSet.contains) + + val matchedConditions = matchedActions.map(actionCondition) + val matchedOutputs = matchedActions.map(deltaActionOutput(_, deleteRowValues, metadataReadAttrs)) + + val notMatchedConditions = notMatchedActions.map(actionCondition) + val notMatchedOutputs = notMatchedActions.map(deltaActionOutput(_, deleteRowValues, metadataReadAttrs)) + + val operationTypeAttr = AttributeReference(OPERATION_COLUMN, IntegerType, nullable = false)() + val rowFromSourceAttr = resolveAttrRef(ROW_FROM_SOURCE_REF, joinPlan) + val rowFromTargetAttr = resolveAttrRef(ROW_FROM_TARGET_REF, joinPlan) + + // merged rows must contain values for the operation type and all read attrs + val mergeRowsOutput = buildMergeRowsOutput(matchedOutputs, notMatchedOutputs, operationTypeAttr +: readAttrs) + + val mergeRows = MergeRows( + isSourceRowPresent = IsNotNull(rowFromSourceAttr), + isTargetRowPresent = if (notMatchedActions.isEmpty) TrueLiteral else IsNotNull(rowFromTargetAttr), + matchedConditions = matchedConditions, + matchedOutputs = matchedOutputs, + notMatchedConditions = notMatchedConditions, + notMatchedOutputs = notMatchedOutputs, + // only needed if emitting unmatched target rows + targetOutput = Nil, + rowIdAttrs = rowIdAttrs, + performCardinalityCheck = isCardinalityCheckNeeded(matchedActions), + emitNotMatchedTargetRows = false, + output = mergeRowsOutput, + joinPlan) + + // build a plan to write the row delta to the table + val writeRelation = relation.copy(table = operationTable) + val projections = 
buildWriteDeltaProjections(mergeRows, rowAttrs, rowIdAttrs, metadataAttrs) + WriteDelta(writeRelation, mergeRows, relation, projections) + } + + private def actionCondition(action: MergeAction): Expression = { + action.condition.getOrElse(TrueLiteral) + } + + private def actionOutput( + clause: MergeAction, + metadataAttrs: Seq[Attribute]): Seq[Expression] = { + + clause match { + case u: UpdateAction => + u.assignments.map(_.value) ++ metadataAttrs + + case _: DeleteAction => + Nil + + case i: InsertAction => + i.assignments.map(_.value) ++ metadataAttrs.map(attr => Literal(null, attr.dataType)) + + case other => + throw new AnalysisException(s"Unexpected action: $other") + } + } + + private def deltaActionOutput( + action: MergeAction, + deleteRowValues: Seq[Expression], + metadataAttrs: Seq[Attribute]): Seq[Expression] = { + + action match { + case u: UpdateAction => + Seq(Literal(UPDATE_OPERATION)) ++ u.assignments.map(_.value) ++ metadataAttrs + + case _: DeleteAction => + Seq(Literal(DELETE_OPERATION)) ++ deleteRowValues ++ metadataAttrs + + case i: InsertAction => + val metadataAttrValues = metadataAttrs.map(attr => Literal(null, attr.dataType)) + Seq(Literal(INSERT_OPERATION)) ++ i.assignments.map(_.value) ++ metadataAttrValues + + case other => + throw new AnalysisException(s"Unexpected action: $other") + } + } + + private def buildMergeRowsOutput( + matchedOutputs: Seq[Seq[Expression]], + notMatchedOutputs: Seq[Seq[Expression]], + attrs: Seq[Attribute]): Seq[Attribute] = { + + // collect all outputs from matched and not matched actions (ignoring DELETEs) + val outputs = matchedOutputs.filter(_.nonEmpty) ++ notMatchedOutputs.filter(_.nonEmpty) + + // build a correct nullability map for output attributes + // an attribute is nullable if at least one matched or not matched action may produce null + val nullabilityMap = attrs.indices.map { index => + index -> outputs.exists(output => output(index).nullable) + }.toMap + + attrs.zipWithIndex.map { case (attr, index) => + attr.withNullability(nullabilityMap(index)) + } + } + + private def isCardinalityCheckNeeded(actions: Seq[MergeAction]): Boolean = actions match { + case Seq(DeleteAction(None)) => false + case _ => true + } + + private def buildDeltaDeleteRowValues( + rowAttrs: Seq[Attribute], + rowIdAttrs: Seq[Attribute]): Seq[Expression] = { + + // nullify all row attrs that are not part of the row ID + val rowIdAttSet = AttributeSet(rowIdAttrs) + rowAttrs.map { + case attr if rowIdAttSet.contains(attr) => attr + case attr => Literal(null, attr.dataType) + } + } + + private def resolveAttrRef(ref: NamedReference, plan: LogicalPlan): AttributeReference = { + ExtendedV2ExpressionUtils.resolveRef[AttributeReference](ref, plan) + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala new file mode 100644 index 000000000000..ec3b9576d005 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.ProjectingInternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.ExtendedV2ExpressionUtils +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.WriteDeltaProjections +import org.apache.spark.sql.connector.iceberg.write.SupportsDelta +import org.apache.spark.sql.connector.write.RowLevelOperation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.types.StructType + +trait RewriteRowLevelIcebergCommand extends RewriteRowLevelCommand { + + protected def buildWriteDeltaProjections( + plan: LogicalPlan, + rowAttrs: Seq[Attribute], + rowIdAttrs: Seq[Attribute], + metadataAttrs: Seq[Attribute]): WriteDeltaProjections = { + + val rowProjection = if (rowAttrs.nonEmpty) { + Some(newLazyProjection(plan, rowAttrs, usePlanTypes = true)) + } else { + None + } + + // in MERGE, the plan may contain both delete and insert records that may affect + // the nullability of metadata columns (e.g. 
metadata columns for new records are always null) + // since metadata columns are never passed with new records to insert, + // use the actual metadata column types instead of the ones present in the plan + + val rowIdProjection = newLazyProjection(plan, rowIdAttrs, usePlanTypes = false) + + val metadataProjection = if (metadataAttrs.nonEmpty) { + Some(newLazyProjection(plan, metadataAttrs, usePlanTypes = false)) + } else { + None + } + + WriteDeltaProjections(rowProjection, rowIdProjection, metadataProjection) + } + + // the projection is done by name, ignoring expr IDs + private def newLazyProjection( + plan: LogicalPlan, + attrs: Seq[Attribute], + usePlanTypes: Boolean): ProjectingInternalRow = { + + val colOrdinals = attrs.map(attr => plan.output.indexWhere(_.name == attr.name)) + val schema = if (usePlanTypes) { + val planAttrs = colOrdinals.map(plan.output(_)) + StructType.fromAttributes(planAttrs) + } else { + StructType.fromAttributes(attrs) + } + ProjectingInternalRow(schema, colOrdinals) + } + + protected def resolveRowIdAttrs( + relation: DataSourceV2Relation, + operation: RowLevelOperation): Seq[AttributeReference] = { + + operation match { + case supportsDelta: SupportsDelta => + val rowIdAttrs = ExtendedV2ExpressionUtils.resolveRefs[AttributeReference]( + supportsDelta.rowId.toSeq, + relation) + + val nullableRowIdAttrs = rowIdAttrs.filter(_.nullable) + if (nullableRowIdAttrs.nonEmpty) { + throw new AnalysisException(s"Row ID attrs cannot be nullable: $nullableRowIdAttrs") + } + + rowIdAttrs + + case other => + throw new AnalysisException(s"Operation $other does not support deltas") + } + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala new file mode 100644 index 000000000000..f1c10c07da19 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.EqualNullSafe +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.If +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.Not +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.plans.logical.Assignment +import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.ReplaceIcebergData +import org.apache.spark.sql.catalyst.plans.logical.Union +import org.apache.spark.sql.catalyst.plans.logical.UpdateIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.WriteDelta +import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ +import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations +import org.apache.spark.sql.connector.iceberg.write.SupportsDelta +import org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE +import org.apache.spark.sql.connector.write.RowLevelOperationTable +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Assigns a rewrite plan for v2 tables that support rewriting data to handle UPDATE statements. + * + * This rule assumes the commands have been fully resolved and all assignments have been aligned. + * That's why it must be run after AlignRowLevelCommandAssignments. + * + * This rule also must be run in the same batch with DeduplicateRelations in Spark. + */ +object RewriteUpdateTable extends RewriteRowLevelIcebergCommand { + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case u @ UpdateIcebergTable(aliasedTable, assignments, cond, None) if u.resolved && u.aligned => + EliminateSubqueryAliases(aliasedTable) match { + case r @ DataSourceV2Relation(tbl: SupportsRowLevelOperations, _, _, _, _) => + val table = buildOperationTable(tbl, UPDATE, CaseInsensitiveStringMap.empty()) + val updateCond = cond.getOrElse(Literal.TrueLiteral) + val rewritePlan = table.operation match { + case _: SupportsDelta => + buildWriteDeltaPlan(r, table, assignments, updateCond) + case _ if SubqueryExpression.hasSubquery(updateCond) => + buildReplaceDataWithUnionPlan(r, table, assignments, updateCond) + case _ => + buildReplaceDataPlan(r, table, assignments, updateCond) + } + UpdateIcebergTable(r, assignments, cond, Some(rewritePlan)) + + case p => + throw new AnalysisException(s"$p is not an Iceberg table") + } + } + + // build a rewrite plan for sources that support replacing groups of data (e.g. files, partitions) + // if the condition does NOT contain a subquery + private def buildReplaceDataPlan( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + assignments: Seq[Assignment], + cond: Expression): ReplaceIcebergData = { + + // resolve all needed attrs (e.g. 
metadata attrs for grouping data on write) + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) + + // construct a read relation and include all required metadata columns + val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) + + // build a plan with updated and copied over records + val updatedAndRemainingRowsPlan = buildUpdateProjection(readRelation, assignments, cond) + + // build a plan to replace read groups in the table + val writeRelation = relation.copy(table = operationTable) + ReplaceIcebergData(writeRelation, updatedAndRemainingRowsPlan, relation) + } + + // build a rewrite plan for sources that support replacing groups of data (e.g. files, partitions) + // if the condition contains a subquery + private def buildReplaceDataWithUnionPlan( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + assignments: Seq[Assignment], + cond: Expression): ReplaceIcebergData = { + + // resolve all needed attrs (e.g. metadata attrs for grouping data on write) + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) + + // construct a read relation and include all required metadata columns + // the same read relation will be used to read records that must be updated and be copied over + // DeduplicateRelations will take care of duplicated attr IDs + val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) + + // build a plan for records that match the cond and should be updated + val matchedRowsPlan = Filter(cond, readRelation) + val updatedRowsPlan = buildUpdateProjection(matchedRowsPlan, assignments) + + // build a plan for records that did not match the cond but had to be copied over + val remainingRowFilter = Not(EqualNullSafe(cond, Literal.TrueLiteral)) + val remainingRowsPlan = Filter(remainingRowFilter, readRelation) + + // new state is a union of updated and copied over records + val updatedAndRemainingRowsPlan = Union(updatedRowsPlan, remainingRowsPlan) + + // build a plan to replace read groups in the table + val writeRelation = relation.copy(table = operationTable) + ReplaceIcebergData(writeRelation, updatedAndRemainingRowsPlan, relation) + } + + // build a rewrite plan for sources that support row deltas + private def buildWriteDeltaPlan( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + assignments: Seq[Assignment], + cond: Expression): WriteDelta = { + + // resolve all needed attrs (e.g. 
row ID and any required metadata attrs) + val rowAttrs = relation.output + val rowIdAttrs = resolveRowIdAttrs(relation, operationTable.operation) + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) + + // construct a scan relation and include all required metadata columns + val readRelation = buildRelationWithAttrs(relation, operationTable, rowIdAttrs ++ metadataAttrs) + + // build a plan for updated records that match the cond + val matchedRowsPlan = Filter(cond, readRelation) + val updatedRowsPlan = buildUpdateProjection(matchedRowsPlan, assignments) + val operationType = Alias(Literal(UPDATE_OPERATION), OPERATION_COLUMN)() + val project = Project(operationType +: updatedRowsPlan.output, updatedRowsPlan) + + // build a plan to write the row delta to the table + val writeRelation = relation.copy(table = operationTable) + val projections = buildWriteDeltaProjections(project, rowAttrs, rowIdAttrs, metadataAttrs) + WriteDelta(writeRelation, project, relation, projections) + } + + // this method assumes the assignments have been already aligned before + // the condition passed to this method may be different from the UPDATE condition + private def buildUpdateProjection( + plan: LogicalPlan, + assignments: Seq[Assignment], + cond: Expression = Literal.TrueLiteral): LogicalPlan = { + + // TODO: avoid executing the condition for each column + + // the plan output may include metadata columns that are not modified + // that's why the number of assignments may not match the number of plan output columns + + val assignedValues = assignments.map(_.value) + val updatedValues = plan.output.zipWithIndex.map { case (attr, index) => + if (index < assignments.size) { + val assignedExpr = assignedValues(index) + val updatedValue = If(cond, assignedExpr, attr) + Alias(updatedValue, attr.name)() + } else { + attr + } + } + + Project(updatedValues, plan) + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/expressions/AssignmentUtils.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/expressions/AssignmentUtils.scala new file mode 100644 index 000000000000..ce3818922c78 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/expressions/AssignmentUtils.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.plans.logical.Assignment +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.types.DataType + +object AssignmentUtils extends SQLConfHelper { + + /** + * Checks whether assignments are aligned and match table columns. + * + * @param table a target table + * @param assignments assignments to check + * @return true if the assignments are aligned + */ + def aligned(table: LogicalPlan, assignments: Seq[Assignment]): Boolean = { + val sameSize = table.output.size == assignments.size + sameSize && table.output.zip(assignments).forall { case (attr, assignment) => + val key = assignment.key + val value = assignment.value + val refsEqual = toAssignmentRef(attr).zip(toAssignmentRef(key)) + .forall{ case (attrRef, keyRef) => conf.resolver(attrRef, keyRef)} + + refsEqual && + DataType.equalsIgnoreCompatibleNullability(value.dataType, attr.dataType) && + (attr.nullable || !value.nullable) + } + } + + def toAssignmentRef(expr: Expression): Seq[String] = expr match { + case attr: AttributeReference => + Seq(attr.name) + case Alias(child, _) => + toAssignmentRef(child) + case GetStructField(child, _, Some(name)) => + toAssignmentRef(child) :+ name + case other: ExtractValue => + throw new AnalysisException(s"Updating nested fields is only supported for structs: $other") + case other => + throw new AnalysisException(s"Cannot convert to a reference, unsupported expression: $other") + } + + def handleCharVarcharLimits(assignment: Assignment): Assignment = { + val key = assignment.key + val value = assignment.value + + val rawKeyType = key.transform { + case attr: AttributeReference => + CharVarcharUtils.getRawType(attr.metadata) + .map(attr.withDataType) + .getOrElse(attr) + }.dataType + + if (CharVarcharUtils.hasCharVarchar(rawKeyType)) { + val newKey = key.transform { + case attr: AttributeReference => CharVarcharUtils.cleanAttrMetadata(attr) + } + val newValue = CharVarcharUtils.stringLengthCheck(value, rawKeyType) + Assignment(newKey, newValue) + } else { + assignment + } + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtendedV2ExpressionUtils.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtendedV2ExpressionUtils.scala new file mode 100644 index 000000000000..16ff67a70522 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtendedV2ExpressionUtils.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.expressions.{Expression => V2Expression} +import org.apache.spark.sql.connector.expressions.{SortDirection => V2SortDirection} +import org.apache.spark.sql.connector.expressions.{NullOrdering => V2NullOrdering} +import org.apache.spark.sql.connector.expressions.BucketTransform +import org.apache.spark.sql.connector.expressions.DaysTransform +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.HoursTransform +import org.apache.spark.sql.connector.expressions.IdentityTransform +import org.apache.spark.sql.connector.expressions.MonthsTransform +import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.expressions.SortValue +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.expressions.TruncateTransform +import org.apache.spark.sql.connector.expressions.YearsTransform +import org.apache.spark.sql.errors.QueryCompilationErrors + +/** + * A class that is inspired by V2ExpressionUtils in Spark but supports Iceberg transforms. + */ +object ExtendedV2ExpressionUtils extends SQLConfHelper { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + + def resolveRef[T <: NamedExpression](ref: NamedReference, plan: LogicalPlan): T = { + plan.resolve(ref.fieldNames.toSeq, conf.resolver) match { + case Some(namedExpr) => + namedExpr.asInstanceOf[T] + case None => + val name = ref.fieldNames.toSeq.quoted + val outputString = plan.output.map(_.name).mkString(",") + throw QueryCompilationErrors.cannotResolveAttributeError(name, outputString) + } + } + + def resolveRefs[T <: NamedExpression](refs: Seq[NamedReference], plan: LogicalPlan): Seq[T] = { + refs.map(ref => resolveRef[T](ref, plan)) + } + + def toCatalyst(expr: V2Expression, query: LogicalPlan): Expression = { + expr match { + case SortValue(child, direction, nullOrdering) => + val catalystChild = toCatalyst(child, query) + SortOrder(catalystChild, toCatalyst(direction), toCatalyst(nullOrdering), Seq.empty) + case IdentityTransform(ref) => + resolveRef[NamedExpression](ref, query) + case t: Transform if BucketTransform.unapply(t).isDefined => + t match { + // sort columns will be empty for bucket. 
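+ // e.g. (illustrative values) a bucket(16, id) transform resolves to IcebergBucketTransform(16, <resolved id attribute>)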
+ case BucketTransform(numBuckets, cols, _) => + IcebergBucketTransform(numBuckets, resolveRef[NamedExpression](cols.head, query)) + case _ => t.asInstanceOf[Expression] + // do nothing + } + case TruncateTransform(length, ref) => + IcebergTruncateTransform(resolveRef[NamedExpression](ref, query), length) + case YearsTransform(ref) => + IcebergYearTransform(resolveRef[NamedExpression](ref, query)) + case MonthsTransform(ref) => + IcebergMonthTransform(resolveRef[NamedExpression](ref, query)) + case DaysTransform(ref) => + IcebergDayTransform(resolveRef[NamedExpression](ref, query)) + case HoursTransform(ref) => + IcebergHourTransform(resolveRef[NamedExpression](ref, query)) + case ref: FieldReference => + resolveRef[NamedExpression](ref, query) + case _ => + throw new AnalysisException(s"$expr is not currently supported") + } + } + + private def toCatalyst(direction: V2SortDirection): SortDirection = direction match { + case V2SortDirection.ASCENDING => Ascending + case V2SortDirection.DESCENDING => Descending + } + + private def toCatalyst(nullOrdering: V2NullOrdering): NullOrdering = nullOrdering match { + case V2NullOrdering.NULLS_FIRST => NullsFirst + case V2NullOrdering.NULLS_LAST => NullsLast + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ExtendedReplaceNullWithFalseInPredicate.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ExtendedReplaceNullWithFalseInPredicate.scala new file mode 100644 index 000000000000..d62cf6d83969 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ExtendedReplaceNullWithFalseInPredicate.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.And +import org.apache.spark.sql.catalyst.expressions.CaseWhen +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.If +import org.apache.spark.sql.catalyst.expressions.In +import org.apache.spark.sql.catalyst.expressions.InSet +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral +import org.apache.spark.sql.catalyst.expressions.Not +import org.apache.spark.sql.catalyst.expressions.Or +import org.apache.spark.sql.catalyst.plans.logical.DeleteAction +import org.apache.spark.sql.catalyst.plans.logical.DeleteFromIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.InsertAction +import org.apache.spark.sql.catalyst.plans.logical.InsertStarAction +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeAction +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.UpdateAction +import org.apache.spark.sql.catalyst.plans.logical.UpdateIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.UpdateStarAction +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.INSET +import org.apache.spark.sql.catalyst.trees.TreePattern.NULL_LITERAL +import org.apache.spark.sql.catalyst.trees.TreePattern.TRUE_OR_FALSE_LITERAL +import org.apache.spark.sql.types.BooleanType +import org.apache.spark.util.Utils + +/** + * A rule similar to ReplaceNullWithFalseInPredicate in Spark but applies to Iceberg row-level commands. + */ +object ExtendedReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( + _.containsAnyPattern(NULL_LITERAL, TRUE_OR_FALSE_LITERAL, INSET)) { + + case d @ DeleteFromIcebergTable(_, Some(cond), _) => + d.copy(condition = Some(replaceNullWithFalse(cond))) + + case u @ UpdateIcebergTable(_, _, Some(cond), _) => + u.copy(condition = Some(replaceNullWithFalse(cond))) + + case m @ MergeIntoIcebergTable(_, _, mergeCond, matchedActions, notMatchedActions, _) => + m.copy( + mergeCondition = replaceNullWithFalse(mergeCond), + matchedActions = replaceNullWithFalse(matchedActions), + notMatchedActions = replaceNullWithFalse(notMatchedActions)) + } + + /** + * Recursively traverse the Boolean-type expression to replace + * `Literal(null, BooleanType)` with `FalseLiteral`, if possible. + * + * Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit + * an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or + * `Literal(null, BooleanType)`. + */ + private def replaceNullWithFalse(e: Expression): Expression = e match { + case Literal(null, BooleanType) => + FalseLiteral + // In SQL, the `Not(IN)` expression evaluates as follows: + // `NULL not in (1)` -> NULL + // `NULL not in (1, NULL)` -> NULL + // `1 not in (1, NULL)` -> false + // `1 not in (2, NULL)` -> NULL + // In predicate, NULL is equal to false, so we can simplify them to false directly. 
+ case Not(In(value, list)) if (value +: list).exists(isNullLiteral) => + FalseLiteral + case Not(InSet(value, list)) if isNullLiteral(value) || list.contains(null) => + FalseLiteral + + case And(left, right) => + And(replaceNullWithFalse(left), replaceNullWithFalse(right)) + case Or(left, right) => + Or(replaceNullWithFalse(left), replaceNullWithFalse(right)) + case cw: CaseWhen if cw.dataType == BooleanType => + val newBranches = cw.branches.map { case (cond, value) => + replaceNullWithFalse(cond) -> replaceNullWithFalse(value) + } + val newElseValue = cw.elseValue.map(replaceNullWithFalse).getOrElse(FalseLiteral) + CaseWhen(newBranches, newElseValue) + case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType => + If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal)) + case e if e.dataType == BooleanType => + e + case e => + val message = "Expected a Boolean type expression in replaceNullWithFalse, " + + s"but got the type `${e.dataType.catalogString}` in `${e.sql}`." + if (Utils.isTesting) { + throw new IllegalArgumentException(message) + } else { + logWarning(message) + e + } + } + + private def isNullLiteral(e: Expression): Boolean = e match { + case Literal(null, _) => true + case _ => false + } + + private def replaceNullWithFalse(mergeActions: Seq[MergeAction]): Seq[MergeAction] = { + mergeActions.map { + case u @ UpdateAction(Some(cond), _) => u.copy(condition = Some(replaceNullWithFalse(cond))) + case u @ UpdateStarAction(Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond))) + case d @ DeleteAction(Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond))) + case i @ InsertAction(Some(cond), _) => i.copy(condition = Some(replaceNullWithFalse(cond))) + case i @ InsertStarAction(Some(cond)) => i.copy(condition = Some(replaceNullWithFalse(cond))) + case other => other + } + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ExtendedSimplifyConditionalsInPredicate.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ExtendedSimplifyConditionalsInPredicate.scala new file mode 100644 index 000000000000..f4df565c44df --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ExtendedSimplifyConditionalsInPredicate.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
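A minimal sanity check of the NOT IN reasoning above, using catalyst's interpreted evaluation (a sketch for illustration; it is not part of the rule):

import org.apache.spark.sql.catalyst.expressions.{In, Literal, Not}
import org.apache.spark.sql.types.IntegerType

object NotInNullSketch {
  def main(args: Array[String]): Unit = {
    // 1 NOT IN (1, NULL) evaluates to false and 1 NOT IN (2, NULL) evaluates to NULL;
    // neither can ever be true, so rewriting such predicates to FalseLiteral cannot
    // change which rows a DELETE, UPDATE or MERGE condition selects.
    val hit  = Not(In(Literal(1), Seq(Literal(1), Literal(null, IntegerType))))
    val miss = Not(In(Literal(1), Seq(Literal(2), Literal(null, IntegerType))))
    println(hit.eval())  // false
    println(miss.eval()) // null
  }
}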
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.And +import org.apache.spark.sql.catalyst.expressions.CaseWhen +import org.apache.spark.sql.catalyst.expressions.Coalesce +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.If +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral +import org.apache.spark.sql.catalyst.expressions.Not +import org.apache.spark.sql.catalyst.expressions.Or +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.CASE_WHEN +import org.apache.spark.sql.catalyst.trees.TreePattern.IF +import org.apache.spark.sql.types.BooleanType + +/** + * A rule similar to SimplifyConditionalsInPredicate in Spark but applies to Iceberg row-level commands. + */ +object ExtendedSimplifyConditionalsInPredicate extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( + _.containsAnyPattern(CASE_WHEN, IF)) { + + case d @ DeleteFromIcebergTable(_, Some(cond), _) => + d.copy(condition = Some(simplifyConditional(cond))) + + case u @ UpdateIcebergTable(_, _, Some(cond), _) => + u.copy(condition = Some(simplifyConditional(cond))) + + case m @ MergeIntoIcebergTable(_, _, mergeCond, matchedActions, notMatchedActions, _) => + m.copy( + mergeCondition = simplifyConditional(mergeCond), + matchedActions = simplifyConditional(matchedActions), + notMatchedActions = simplifyConditional(notMatchedActions)) + } + + private def simplifyConditional(e: Expression): Expression = e match { + case And(left, right) => And(simplifyConditional(left), simplifyConditional(right)) + case Or(left, right) => Or(simplifyConditional(left), simplifyConditional(right)) + case If(cond, trueValue, FalseLiteral) => And(cond, trueValue) + case If(cond, trueValue, TrueLiteral) => Or(Not(Coalesce(Seq(cond, FalseLiteral))), trueValue) + case If(cond, FalseLiteral, falseValue) => + And(Not(Coalesce(Seq(cond, FalseLiteral))), falseValue) + case If(cond, TrueLiteral, falseValue) => Or(cond, falseValue) + case CaseWhen(Seq((cond, trueValue)), + Some(FalseLiteral) | Some(Literal(null, BooleanType)) | None) => + And(cond, trueValue) + case CaseWhen(Seq((cond, trueValue)), Some(TrueLiteral)) => + Or(Not(Coalesce(Seq(cond, FalseLiteral))), trueValue) + case CaseWhen(Seq((cond, FalseLiteral)), Some(elseValue)) => + And(Not(Coalesce(Seq(cond, FalseLiteral))), elseValue) + case CaseWhen(Seq((cond, TrueLiteral)), Some(elseValue)) => + Or(cond, elseValue) + case e if e.dataType == BooleanType => e + case e => + assert(e.dataType != BooleanType, + "Expected a Boolean type expression in ExtendedSimplifyConditionalsInPredicate, " + + s"but got the type `${e.dataType.catalogString}` in `${e.sql}`.") + e + } + + private def simplifyConditional(mergeActions: Seq[MergeAction]): Seq[MergeAction] = { + mergeActions.map { + case u @ UpdateAction(Some(cond), _) => u.copy(condition = Some(simplifyConditional(cond))) + case u @ UpdateStarAction(Some(cond)) => u.copy(condition = Some(simplifyConditional(cond))) + case d @ DeleteAction(Some(cond)) => d.copy(condition = Some(simplifyConditional(cond))) + case i @ InsertAction(Some(cond), _) => i.copy(condition = Some(simplifyConditional(cond))) + case i @ InsertStarAction(Some(cond)) => 
i.copy(condition = Some(simplifyConditional(cond))) + case other => other + } + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala new file mode 100644 index 000000000000..00d292bcf86c --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala @@ -0,0 +1,376 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.parser.extensions + +import java.util.Locale +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.atn.PredictionMode +import org.antlr.v4.runtime.misc.Interval +import org.antlr.v4.runtime.misc.ParseCancellationException +import org.antlr.v4.runtime.tree.TerminalNodeImpl +import org.apache.iceberg.common.DynConstructors +import org.apache.iceberg.spark.Spark3Util +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser.NonReservedContext +import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser.QuotedIdentifierContext +import org.apache.spark.sql.catalyst.plans.logical.DeleteFromIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.DeleteFromTable +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoContext +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoTable +import org.apache.spark.sql.catalyst.plans.logical.UnresolvedMergeIntoIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.UpdateIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.UpdateTable +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.catalog.TableCatalog +import org.apache.spark.sql.execution.command.ExplainCommand +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.VariableSubstitution +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.StructType +import 
scala.jdk.CollectionConverters._ +import scala.util.Try + +class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) extends ParserInterface { + + import IcebergSparkSqlExtensionsParser._ + + private lazy val substitutor = substitutorCtor.newInstance(SQLConf.get) + private lazy val astBuilder = new IcebergSqlExtensionsAstBuilder(delegate) + + /** + * Parse a string to a DataType. + */ + override def parseDataType(sqlText: String): DataType = { + delegate.parseDataType(sqlText) + } + + /** + * Parse a string to a raw DataType without CHAR/VARCHAR replacement. + */ + def parseRawDataType(sqlText: String): DataType = throw new UnsupportedOperationException() + + /** + * Parse a string to an Expression. + */ + override def parseExpression(sqlText: String): Expression = { + delegate.parseExpression(sqlText) + } + + /** + * Parse a string to a TableIdentifier. + */ + override def parseTableIdentifier(sqlText: String): TableIdentifier = { + delegate.parseTableIdentifier(sqlText) + } + + /** + * Parse a string to a FunctionIdentifier. + */ + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = { + delegate.parseFunctionIdentifier(sqlText) + } + + /** + * Parse a string to a multi-part identifier. + */ + override def parseMultipartIdentifier(sqlText: String): Seq[String] = { + delegate.parseMultipartIdentifier(sqlText) + } + + /** + * Creates StructType for a given SQL string, which is a comma separated list of field + * definitions which will preserve the correct Hive metadata. + */ + override def parseTableSchema(sqlText: String): StructType = { + delegate.parseTableSchema(sqlText) + } + + /** + * Parse a string to a LogicalPlan. + */ + override def parsePlan(sqlText: String): LogicalPlan = { + val sqlTextAfterSubstitution = substitutor.substitute(sqlText) + if (isIcebergCommand(sqlTextAfterSubstitution)) { + parse(sqlTextAfterSubstitution) { parser => astBuilder.visit(parser.singleStatement()) }.asInstanceOf[LogicalPlan] + } else { + val parsedPlan = delegate.parsePlan(sqlText) + parsedPlan match { + case e: ExplainCommand => + e.copy(logicalPlan = replaceRowLevelCommands(e.logicalPlan)) + case p => + replaceRowLevelCommands(p) + } + } + } + + private def replaceRowLevelCommands(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { + case DeleteFromTable(UnresolvedIcebergTable(aliasedTable), condition) => + DeleteFromIcebergTable(aliasedTable, Some(condition)) + + case UpdateTable(UnresolvedIcebergTable(aliasedTable), assignments, condition) => + UpdateIcebergTable(aliasedTable, assignments, condition) + + case MergeIntoTable(UnresolvedIcebergTable(aliasedTable), source, cond, matchedActions, notMatchedActions) => + // cannot construct MergeIntoIcebergTable right away as MERGE operations require special resolution + // that's why the condition and actions must be hidden from the regular resolution rules in Spark + // see ResolveMergeIntoTableReferences for details + val context = MergeIntoContext(cond, matchedActions, notMatchedActions) + UnresolvedMergeIntoIcebergTable(aliasedTable, source, context) + } + + object UnresolvedIcebergTable { + + def unapply(plan: LogicalPlan): Option[LogicalPlan] = { + EliminateSubqueryAliases(plan) match { + case UnresolvedRelation(multipartIdentifier, _, _) if isIcebergTable(multipartIdentifier) => + Some(plan) + case _ => + None + } + } + + private def isIcebergTable(multipartIdent: Seq[String]): Boolean = { + val catalogAndIdentifier = Spark3Util.catalogAndIdentifier(SparkSession.active, multipartIdent.asJava) + 
catalogAndIdentifier.catalog match { + case tableCatalog: TableCatalog => + Try(tableCatalog.loadTable(catalogAndIdentifier.identifier)) + .map(isIcebergTable) + .getOrElse(false) + + case _ => + false + } + } + + private def isIcebergTable(table: Table): Boolean = table match { + case _: SparkTable => true + case _ => false + } + } + + private def isIcebergCommand(sqlText: String): Boolean = { + val normalized = sqlText.toLowerCase(Locale.ROOT).trim() + // Strip simple SQL comments that terminate a line, e.g. comments starting with `--` . + .replaceAll("--.*?\\n", " ") + // Strip newlines. + .replaceAll("\\s+", " ") + // Strip comments of the form /* ... */. This must come after stripping newlines so that + // comments that span multiple lines are caught. + .replaceAll("/\\*.*?\\*/", " ") + .trim() + normalized.startsWith("call") || ( + normalized.startsWith("alter table") && ( + normalized.contains("add partition field") || + normalized.contains("drop partition field") || + normalized.contains("replace partition field") || + normalized.contains("write ordered by") || + normalized.contains("write locally ordered by") || + normalized.contains("write distributed by") || + normalized.contains("write unordered") || + normalized.contains("set identifier fields") || + normalized.contains("drop identifier fields"))) + } + + protected def parse[T](command: String)(toResult: IcebergSqlExtensionsParser => T): T = { + val lexer = new IcebergSqlExtensionsLexer(new UpperCaseCharStream(CharStreams.fromString(command))) + lexer.removeErrorListeners() + lexer.addErrorListener(IcebergParseErrorListener) + + val tokenStream = new CommonTokenStream(lexer) + val parser = new IcebergSqlExtensionsParser(tokenStream) + parser.addParseListener(IcebergSqlExtensionsPostProcessor) + parser.removeErrorListeners() + parser.addErrorListener(IcebergParseErrorListener) + + try { + try { + // first, try parsing with potentially faster SLL mode + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + toResult(parser) + } + catch { + case _: ParseCancellationException => + // if we fail, parse with LL mode + tokenStream.seek(0) // rewind input stream + parser.reset() + + // Try Again. 
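+ // LL(*) prediction is slower but complete: statements that the faster SLL pass rejected
+ // spuriously will parse on the second attempt, while genuinely malformed input still
+ // produces a proper error through IcebergParseErrorListener.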
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL) + toResult(parser) + } + } + catch { + case e: IcebergParseException if e.command.isDefined => + throw e + case e: IcebergParseException => + throw e.withCommand(command) + case e: AnalysisException => + val position = Origin(e.line, e.startPosition) + throw new IcebergParseException(Option(command), e.message, position, position) + } + } + + override def parseQuery(sqlText: String): LogicalPlan = { + parsePlan(sqlText) + } +} + +object IcebergSparkSqlExtensionsParser { + private val substitutorCtor: DynConstructors.Ctor[VariableSubstitution] = + DynConstructors.builder() + .impl(classOf[VariableSubstitution]) + .impl(classOf[VariableSubstitution], classOf[SQLConf]) + .build() +} + +/* Copied from Apache Spark's to avoid dependency on Spark Internals */ +class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream { + override def consume(): Unit = wrapped.consume + override def getSourceName(): String = wrapped.getSourceName + override def index(): Int = wrapped.index + override def mark(): Int = wrapped.mark + override def release(marker: Int): Unit = wrapped.release(marker) + override def seek(where: Int): Unit = wrapped.seek(where) + override def size(): Int = wrapped.size + + override def getText(interval: Interval): String = wrapped.getText(interval) + + // scalastyle:off + override def LA(i: Int): Int = { + val la = wrapped.LA(i) + if (la == 0 || la == IntStream.EOF) la + else Character.toUpperCase(la) + } + // scalastyle:on +} + +/** + * The post-processor validates & cleans-up the parse tree during the parse process. + */ +case object IcebergSqlExtensionsPostProcessor extends IcebergSqlExtensionsBaseListener { + + /** Remove the back ticks from an Identifier. */ + override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = { + replaceTokenByIdentifier(ctx, 1) { token => + // Remove the double back ticks in the string. + token.setText(token.getText.replace("``", "`")) + token + } + } + + /** Treat non-reserved keywords as Identifiers. 
*/ + override def exitNonReserved(ctx: NonReservedContext): Unit = { + replaceTokenByIdentifier(ctx, 0)(identity) + } + + private def replaceTokenByIdentifier( + ctx: ParserRuleContext, + stripMargins: Int)( + f: CommonToken => CommonToken = identity): Unit = { + val parent = ctx.getParent + parent.removeLastChild() + val token = ctx.getChild(0).getPayload.asInstanceOf[Token] + val newToken = new CommonToken( + new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream), + IcebergSqlExtensionsParser.IDENTIFIER, + token.getChannel, + token.getStartIndex + stripMargins, + token.getStopIndex - stripMargins) + parent.addChild(new TerminalNodeImpl(f(newToken))) + } +} + +/* Partially copied from Apache Spark's Parser to avoid dependency on Spark Internals */ +case object IcebergParseErrorListener extends BaseErrorListener { + override def syntaxError( + recognizer: Recognizer[_, _], + offendingSymbol: scala.Any, + line: Int, + charPositionInLine: Int, + msg: String, + e: RecognitionException): Unit = { + val (start, stop) = offendingSymbol match { + case token: CommonToken => + val start = Origin(Some(line), Some(token.getCharPositionInLine)) + val length = token.getStopIndex - token.getStartIndex + 1 + val stop = Origin(Some(line), Some(token.getCharPositionInLine + length)) + (start, stop) + case _ => + val start = Origin(Some(line), Some(charPositionInLine)) + (start, start) + } + throw new IcebergParseException(None, msg, start, stop) + } +} + +/** + * Copied from Apache Spark + * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It + * contains fields and an extended error message that make reporting and diagnosing errors easier. + */ +class IcebergParseException( + val command: Option[String], + message: String, + val start: Origin, + val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) { + + def this(message: String, ctx: ParserRuleContext) = { + this(Option(IcebergParserUtils.command(ctx)), + message, + IcebergParserUtils.position(ctx.getStart), + IcebergParserUtils.position(ctx.getStop)) + } + + override def getMessage: String = { + val builder = new StringBuilder + builder ++= "\n" ++= message + start match { + case Origin( + Some(l), Some(p), Some(startIndex), Some(stopIndex), Some(sqlText), Some(objectType), Some(objectName)) => + builder ++= s"(line $l, pos $p)\n" + command.foreach { cmd => + val (above, below) = cmd.split("\n").splitAt(l) + builder ++= "\n== SQL ==\n" + above.foreach(builder ++= _ += '\n') + builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n" + below.foreach(builder ++= _ += '\n') + } + case _ => + command.foreach { cmd => + builder ++= "\n== SQL ==\n" ++= cmd + } + } + builder.toString + } + + def withCommand(cmd: String): IcebergParseException = { + new IcebergParseException(Option(cmd), message, start, stop) + } +} \ No newline at end of file diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala new file mode 100644 index 000000000000..8bb4754e5c84 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
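To make the routing decision in isIcebergCommand above concrete, here is a sketch that applies the same normalization to a sample statement; the helper object and the sample SQL are assumptions for illustration:

import java.util.Locale

object IsIcebergCommandSketch {
  // mirrors the normalization performed by IcebergSparkSqlExtensionsParser.isIcebergCommand
  private def normalize(sqlText: String): String =
    sqlText.toLowerCase(Locale.ROOT).trim()
      .replaceAll("--.*?\\n", " ")    // strip line comments
      .replaceAll("\\s+", " ")        // collapse whitespace and newlines
      .replaceAll("/\\*.*?\\*/", " ") // strip comments of the form /* ... */
      .trim()

  def main(args: Array[String]): Unit = {
    val sql =
      """-- evolve the partition spec
        |ALTER TABLE db.events
        |ADD PARTITION FIELD bucket(16, id)""".stripMargin
    val normalized = normalize(sql)
    // "alter table db.events add partition field bucket(16, id)" -> routed to the Iceberg parser
    println(normalized.startsWith("alter table") && normalized.contains("add partition field"))
  }
}

Statements that do not match are handed to the delegate Spark parser, and any DELETE, UPDATE or MERGE plans that target Iceberg tables are then swapped for their Iceberg-specific counterparts by replaceRowLevelCommands.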
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.parser.extensions + +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.misc.Interval +import org.antlr.v4.runtime.tree.ParseTree +import org.antlr.v4.runtime.tree.TerminalNode +import org.apache.iceberg.DistributionMode +import org.apache.iceberg.NullOrder +import org.apache.iceberg.SortDirection +import org.apache.iceberg.expressions.Term +import org.apache.iceberg.spark.Spark3Util +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.parser.extensions.IcebergParserUtils.withOrigin +import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser._ +import org.apache.spark.sql.catalyst.plans.logical.AddPartitionField +import org.apache.spark.sql.catalyst.plans.logical.CallArgument +import org.apache.spark.sql.catalyst.plans.logical.CallStatement +import org.apache.spark.sql.catalyst.plans.logical.DropIdentifierFields +import org.apache.spark.sql.catalyst.plans.logical.DropPartitionField +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.NamedArgument +import org.apache.spark.sql.catalyst.plans.logical.PositionalArgument +import org.apache.spark.sql.catalyst.plans.logical.ReplacePartitionField +import org.apache.spark.sql.catalyst.plans.logical.SetIdentifierFields +import org.apache.spark.sql.catalyst.plans.logical.SetWriteDistributionAndOrdering +import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.connector.expressions +import org.apache.spark.sql.connector.expressions.ApplyTransform +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.IdentityTransform +import org.apache.spark.sql.connector.expressions.LiteralValue +import org.apache.spark.sql.connector.expressions.Transform +import scala.jdk.CollectionConverters._ + +class IcebergSqlExtensionsAstBuilder(delegate: ParserInterface) extends IcebergSqlExtensionsBaseVisitor[AnyRef] { + + private def toBuffer[T](list: java.util.List[T]): scala.collection.mutable.Buffer[T] = list.asScala + private def toSeq[T](list: java.util.List[T]): Seq[T] = toBuffer(list).toSeq + + /** + * Create a [[CallStatement]] for a stored procedure call. + */ + override def visitCall(ctx: CallContext): CallStatement = withOrigin(ctx) { + val name = toSeq(ctx.multipartIdentifier.parts).map(_.getText) + val args = toSeq(ctx.callArgument).map(typedVisit[CallArgument]) + CallStatement(name, args) + } + + /** + * Create an ADD PARTITION FIELD logical command. 
+ */ + override def visitAddPartitionField(ctx: AddPartitionFieldContext): AddPartitionField = withOrigin(ctx) { + AddPartitionField( + typedVisit[Seq[String]](ctx.multipartIdentifier), + typedVisit[Transform](ctx.transform), + Option(ctx.name).map(_.getText)) + } + + /** + * Create a DROP PARTITION FIELD logical command. + */ + override def visitDropPartitionField(ctx: DropPartitionFieldContext): DropPartitionField = withOrigin(ctx) { + DropPartitionField( + typedVisit[Seq[String]](ctx.multipartIdentifier), + typedVisit[Transform](ctx.transform)) + } + + + /** + * Create an REPLACE PARTITION FIELD logical command. + */ + override def visitReplacePartitionField(ctx: ReplacePartitionFieldContext): ReplacePartitionField = withOrigin(ctx) { + ReplacePartitionField( + typedVisit[Seq[String]](ctx.multipartIdentifier), + typedVisit[Transform](ctx.transform(0)), + typedVisit[Transform](ctx.transform(1)), + Option(ctx.name).map(_.getText)) + } + + /** + * Create an SET IDENTIFIER FIELDS logical command. + */ + override def visitSetIdentifierFields(ctx: SetIdentifierFieldsContext): SetIdentifierFields = withOrigin(ctx) { + SetIdentifierFields( + typedVisit[Seq[String]](ctx.multipartIdentifier), + toSeq(ctx.fieldList.fields).map(_.getText)) + } + + /** + * Create an DROP IDENTIFIER FIELDS logical command. + */ + override def visitDropIdentifierFields(ctx: DropIdentifierFieldsContext): DropIdentifierFields = withOrigin(ctx) { + DropIdentifierFields( + typedVisit[Seq[String]](ctx.multipartIdentifier), + toSeq(ctx.fieldList.fields).map(_.getText)) + } + + /** + * Create a [[SetWriteDistributionAndOrdering]] for changing the write distribution and ordering. + */ + override def visitSetWriteDistributionAndOrdering( + ctx: SetWriteDistributionAndOrderingContext): SetWriteDistributionAndOrdering = { + + val tableName = typedVisit[Seq[String]](ctx.multipartIdentifier) + + val (distributionSpec, orderingSpec) = toDistributionAndOrderingSpec(ctx.writeSpec) + + if (distributionSpec == null && orderingSpec == null) { + throw new AnalysisException( + "ALTER TABLE has no changes: missing both distribution and ordering clauses") + } + + val distributionMode = if (distributionSpec != null) { + DistributionMode.HASH + } else if (orderingSpec.UNORDERED != null || orderingSpec.LOCALLY != null) { + DistributionMode.NONE + } else { + DistributionMode.RANGE + } + + val ordering = if (orderingSpec != null && orderingSpec.order != null) { + toSeq(orderingSpec.order.fields).map(typedVisit[(Term, SortDirection, NullOrder)]) + } else { + Seq.empty + } + + SetWriteDistributionAndOrdering(tableName, distributionMode, ordering) + } + + private def toDistributionAndOrderingSpec( + writeSpec: WriteSpecContext): (WriteDistributionSpecContext, WriteOrderingSpecContext) = { + + if (writeSpec.writeDistributionSpec.size > 1) { + throw new AnalysisException("ALTER TABLE contains multiple distribution clauses") + } + + if (writeSpec.writeOrderingSpec.size > 1) { + throw new AnalysisException("ALTER TABLE contains multiple ordering clauses") + } + + val distributionSpec = toBuffer(writeSpec.writeDistributionSpec).headOption.orNull + val orderingSpec = toBuffer(writeSpec.writeOrderingSpec).headOption.orNull + + (distributionSpec, orderingSpec) + } + + /** + * Create an order field. 
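For context, these are the statements handled by visitSetWriteDistributionAndOrdering above, together with the distribution mode each one maps to. A sketch assuming an active SparkSession with the Iceberg extensions enabled and an existing Iceberg table db.events (catalog configuration omitted):

import org.apache.spark.sql.SparkSession

object WriteSpecSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("write-spec-sketch")
      .master("local[*]")
      .config("spark.sql.extensions",
        "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
      .getOrCreate()

    spark.sql("ALTER TABLE db.events WRITE ORDERED BY ts")            // DistributionMode.RANGE, global sort
    spark.sql("ALTER TABLE db.events WRITE LOCALLY ORDERED BY ts")    // DistributionMode.NONE, per-task sort
    spark.sql("ALTER TABLE db.events WRITE DISTRIBUTED BY PARTITION") // DistributionMode.HASH
    spark.sql("ALTER TABLE db.events WRITE UNORDERED")                // DistributionMode.NONE, no sort order
  }
}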
+ */ + override def visitOrderField(ctx: OrderFieldContext): (Term, SortDirection, NullOrder) = { + val term = Spark3Util.toIcebergTerm(typedVisit[Transform](ctx.transform)) + val direction = Option(ctx.ASC).map(_ => SortDirection.ASC) + .orElse(Option(ctx.DESC).map(_ => SortDirection.DESC)) + .getOrElse(SortDirection.ASC) + val nullOrder = Option(ctx.FIRST).map(_ => NullOrder.NULLS_FIRST) + .orElse(Option(ctx.LAST).map(_ => NullOrder.NULLS_LAST)) + .getOrElse(if (direction == SortDirection.ASC) NullOrder.NULLS_FIRST else NullOrder.NULLS_LAST) + (term, direction, nullOrder) + } + + /** + * Create an IdentityTransform for a column reference. + */ + override def visitIdentityTransform(ctx: IdentityTransformContext): Transform = withOrigin(ctx) { + IdentityTransform(FieldReference(typedVisit[Seq[String]](ctx.multipartIdentifier()))) + } + + /** + * Create a named Transform from argument expressions. + */ + override def visitApplyTransform(ctx: ApplyTransformContext): Transform = withOrigin(ctx) { + val args = toSeq(ctx.arguments).map(typedVisit[expressions.Expression]) + ApplyTransform(ctx.transformName.getText, args) + } + + /** + * Create a transform argument from a column reference or a constant. + */ + override def visitTransformArgument(ctx: TransformArgumentContext): expressions.Expression = withOrigin(ctx) { + val reference = Option(ctx.multipartIdentifier()) + .map(typedVisit[Seq[String]]) + .map(FieldReference(_)) + val literal = Option(ctx.constant) + .map(visitConstant) + .map(lit => LiteralValue(lit.value, lit.dataType)) + reference.orElse(literal) + .getOrElse(throw new IcebergParseException(s"Invalid transform argument", ctx)) + } + + /** + * Return a multi-part identifier as Seq[String]. + */ + override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] = withOrigin(ctx) { + toSeq(ctx.parts).map(_.getText) + } + + /** + * Create a positional argument in a stored procedure call. + */ + override def visitPositionalArgument(ctx: PositionalArgumentContext): CallArgument = withOrigin(ctx) { + val expr = typedVisit[Expression](ctx.expression) + PositionalArgument(expr) + } + + /** + * Create a named argument in a stored procedure call. 
+ */ + override def visitNamedArgument(ctx: NamedArgumentContext): CallArgument = withOrigin(ctx) { + val name = ctx.identifier.getText + val expr = typedVisit[Expression](ctx.expression) + NamedArgument(name, expr) + } + + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { + visit(ctx.statement).asInstanceOf[LogicalPlan] + } + + def visitConstant(ctx: ConstantContext): Literal = { + delegate.parseExpression(ctx.getText).asInstanceOf[Literal] + } + + override def visitExpression(ctx: ExpressionContext): Expression = { + // reconstruct the SQL string and parse it using the main Spark parser + // while we can avoid the logic to build Spark expressions, we still have to parse them + // we cannot call ctx.getText directly since it will not render spaces correctly + // that's why we need to recurse down the tree in reconstructSqlString + val sqlString = reconstructSqlString(ctx) + delegate.parseExpression(sqlString) + } + + private def reconstructSqlString(ctx: ParserRuleContext): String = { + toBuffer(ctx.children).map { + case c: ParserRuleContext => reconstructSqlString(c) + case t: TerminalNode => t.getText + }.mkString(" ") + } + + private def typedVisit[T](ctx: ParseTree): T = { + ctx.accept(this).asInstanceOf[T] + } +} + +/* Partially copied from Apache Spark's Parser to avoid dependency on Spark Internals */ +object IcebergParserUtils { + + private[sql] def withOrigin[T](ctx: ParserRuleContext)(f: => T): T = { + val current = CurrentOrigin.get + CurrentOrigin.set(position(ctx.getStart)) + try { + f + } finally { + CurrentOrigin.set(current) + } + } + + private[sql] def position(token: Token): Origin = { + val opt = Option(token) + Origin(opt.map(_.getLine), opt.map(_.getCharPositionInLine)) + } + + /** Get the command which created the token. */ + private[sql] def command(ctx: ParserRuleContext): String = { + val stream = ctx.getStart.getInputStream + stream.getText(Interval.of(0, stream.size() - 1)) + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/planning/RewrittenRowLevelCommand.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/planning/RewrittenRowLevelCommand.scala new file mode 100644 index 000000000000..f45e44e2ce09 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/planning/RewrittenRowLevelCommand.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
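The CALL statements parsed by visitCall above accept both positional and named arguments. A sketch using the documented rollback_to_snapshot procedure; the catalog and table names are assumptions, and spark is the session from the earlier sketch:

// positional arguments (visitPositionalArgument)
spark.sql("CALL my_catalog.system.rollback_to_snapshot('db.events', 123456789)")

// named arguments (visitNamedArgument), which may appear in any order
spark.sql("CALL my_catalog.system.rollback_to_snapshot(table => 'db.events', snapshot_id => 123456789)")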
+ */ + +package org.apache.spark.sql.catalyst.planning + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.ReplaceIcebergData +import org.apache.spark.sql.catalyst.plans.logical.RowLevelCommand +import org.apache.spark.sql.catalyst.plans.logical.UpdateIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.WriteDelta +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation + +/** + * An extractor for operations such as DELETE and MERGE that require rewriting data. + * + * This class extracts the following entities: + * - the row-level command (such as DeleteFromIcebergTable); + * - the read relation in the rewrite plan that can be either DataSourceV2Relation or + * DataSourceV2ScanRelation depending on whether the planning has already happened; + * - the current rewrite plan. + */ +object RewrittenRowLevelCommand { + type ReturnType = (RowLevelCommand, LogicalPlan, LogicalPlan) + + def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { + case c: RowLevelCommand if c.rewritePlan.nonEmpty => + val rewritePlan = c.rewritePlan.get + + // both ReplaceData and WriteDelta reference a write relation + // but the corresponding read relation should be at the bottom of the write plan + // both the write and read relations will share the same RowLevelOperationTable object + // that's why it is safe to use reference equality to find the needed read relation + + val allowScanDuplication = c match { + // group-based updates that rely on the union approach may have multiple identical scans + case _: UpdateIcebergTable if rewritePlan.isInstanceOf[ReplaceIcebergData] => true + case _ => false + } + + rewritePlan match { + case rd @ ReplaceIcebergData(DataSourceV2Relation(table, _, _, _, _), query, _, _) => + val readRelation = findReadRelation(table, query, allowScanDuplication) + readRelation.map((c, _, rd)) + case wd @ WriteDelta(DataSourceV2Relation(table, _, _, _, _), query, _, _, _) => + val readRelation = findReadRelation(table, query, allowScanDuplication) + readRelation.map((c, _, wd)) + case _ => + None + } + + case _ => + None + } + + private def findReadRelation( + table: Table, + plan: LogicalPlan, + allowScanDuplication: Boolean): Option[LogicalPlan] = { + + val readRelations = plan.collect { + case r: DataSourceV2Relation if r.table eq table => r + case r: DataSourceV2ScanRelation if r.relation.table eq table => r + } + + // in some cases, the optimizer replaces the v2 read relation with a local relation + // for example, there is no reason to query the table if the condition is always false + // that's why it is valid not to find the corresponding v2 read relation + + readRelations match { + case relations if relations.isEmpty => + None + + case Seq(relation) => + Some(relation) + + case Seq(relation1: DataSourceV2Relation, relation2: DataSourceV2Relation) + if allowScanDuplication && (relation1.table eq relation2.table) => + Some(relation1) + + case Seq(relation1: DataSourceV2ScanRelation, relation2: DataSourceV2ScanRelation) + if allowScanDuplication && (relation1.scan eq relation2.scan) => + Some(relation1) + + case Seq(relation1, relation2) if allowScanDuplication => + throw new AnalysisException(s"Row-level read relations don't match: $relation1, $relation2") + + case relations if allowScanDuplication => + 
throw new AnalysisException(s"Expected up to two row-level read relations: $relations") + + case relations => + throw new AnalysisException(s"Expected only one row-level read relation: $relations") + } + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AddPartitionField.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AddPartitionField.scala new file mode 100644 index 000000000000..e8b1b2941161 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AddPartitionField.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.expressions.Transform + +case class AddPartitionField(table: Seq[String], transform: Transform, name: Option[String]) extends LeafCommand { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"AddPartitionField ${table.quoted} ${name.map(n => s"$n=").getOrElse("")}${transform.describe}" + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Call.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Call.scala new file mode 100644 index 000000000000..9616dae5a8d3 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Call.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.connector.iceberg.catalog.Procedure + +case class Call(procedure: Procedure, args: Seq[Expression]) extends LeafCommand { + override lazy val output: Seq[Attribute] = procedure.outputType.toAttributes + + override def simpleString(maxFields: Int): String = { + s"Call${truncatedString(output.toSeq, "[", ", ", "]", maxFields)} ${procedure.description}" + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DeleteFromIcebergTable.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DeleteFromIcebergTable.scala new file mode 100644 index 000000000000..d1268e416c50 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DeleteFromIcebergTable.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Expression + +case class DeleteFromIcebergTable( + table: LogicalPlan, + condition: Option[Expression], + rewritePlan: Option[LogicalPlan] = None) extends RowLevelCommand { + + override def children: Seq[LogicalPlan] = if (rewritePlan.isDefined) { + table :: rewritePlan.get :: Nil + } else { + table :: Nil + } + + override def withNewRewritePlan(newRewritePlan: LogicalPlan): RowLevelCommand = { + copy(rewritePlan = Some(newRewritePlan)) + } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): DeleteFromIcebergTable = { + if (newChildren.size == 1) { + copy(table = newChildren.head, rewritePlan = None) + } else { + require(newChildren.size == 2, "DeleteFromIcebergTable expects either one or two children") + val Seq(newTable, newRewritePlan) = newChildren.take(2) + copy(table = newTable, rewritePlan = Some(newRewritePlan)) + } + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropIdentifierFields.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropIdentifierFields.scala new file mode 100644 index 000000000000..29dd686a0fba --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropIdentifierFields.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class DropIdentifierFields( + table: Seq[String], + fields: Seq[String]) extends LeafCommand { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"DropIdentifierFields ${table.quoted} (${fields.quoted})" + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropPartitionField.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropPartitionField.scala new file mode 100644 index 000000000000..fb1451324182 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropPartitionField.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.expressions.Transform + +case class DropPartitionField(table: Seq[String], transform: Transform) extends LeafCommand { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"DropPartitionField ${table.quoted} ${transform.describe}" + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeIntoIcebergTable.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeIntoIcebergTable.scala new file mode 100644 index 000000000000..8f84851dcda2 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeIntoIcebergTable.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.AssignmentUtils +import org.apache.spark.sql.catalyst.expressions.Expression + +case class MergeIntoIcebergTable( + targetTable: LogicalPlan, + sourceTable: LogicalPlan, + mergeCondition: Expression, + matchedActions: Seq[MergeAction], + notMatchedActions: Seq[MergeAction], + rewritePlan: Option[LogicalPlan] = None) extends RowLevelCommand { + + lazy val aligned: Boolean = { + val matchedActionsAligned = matchedActions.forall { + case UpdateAction(_, assignments) => + AssignmentUtils.aligned(targetTable, assignments) + case _: DeleteAction => + true + case _ => + false + } + + val notMatchedActionsAligned = notMatchedActions.forall { + case InsertAction(_, assignments) => + AssignmentUtils.aligned(targetTable, assignments) + case _ => + false + } + + matchedActionsAligned && notMatchedActionsAligned + } + + def condition: Option[Expression] = Some(mergeCondition) + + override def children: Seq[LogicalPlan] = if (rewritePlan.isDefined) { + targetTable :: sourceTable :: rewritePlan.get :: Nil + } else { + targetTable :: sourceTable :: Nil + } + + override def withNewRewritePlan(newRewritePlan: LogicalPlan): RowLevelCommand = { + copy(rewritePlan = Some(newRewritePlan)) + } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): MergeIntoIcebergTable = { + + newChildren match { + case Seq(newTarget, newSource) => + copy(targetTable = newTarget, sourceTable = newSource, rewritePlan = None) + case Seq(newTarget, newSource, newRewritePlan) => + copy(targetTable = newTarget, sourceTable = newSource, rewritePlan = Some(newRewritePlan)) + case _ => + throw new IllegalArgumentException("MergeIntoIcebergTable expects either two or three children") + } + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala new file mode 100644 index 000000000000..3607194fe8c8 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
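For reference, the kind of statement that becomes a MergeIntoIcebergTable plan. A sketch that assumes the spark session from the earlier sketch, an Iceberg target table db.events, and a source view updates:

spark.sql(
  """MERGE INTO db.events t
    |USING updates s
    |ON t.id = s.id
    |WHEN MATCHED AND s.op = 'delete' THEN DELETE
    |WHEN MATCHED THEN UPDATE SET *
    |WHEN NOT MATCHED THEN INSERT *""".stripMargin)

The WHEN MATCHED clauses become matchedActions and the WHEN NOT MATCHED clause becomes notMatchedActions, which is what the aligned check above inspects once the star actions have been expanded into per-column assignments.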
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.util.truncatedString + +case class MergeRows( + isSourceRowPresent: Expression, + isTargetRowPresent: Expression, + matchedConditions: Seq[Expression], + matchedOutputs: Seq[Seq[Expression]], + notMatchedConditions: Seq[Expression], + notMatchedOutputs: Seq[Seq[Expression]], + targetOutput: Seq[Expression], + rowIdAttrs: Seq[Attribute], + performCardinalityCheck: Boolean, + emitNotMatchedTargetRows: Boolean, + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + + require(targetOutput.nonEmpty || !emitNotMatchedTargetRows) + + override lazy val producedAttributes: AttributeSet = { + AttributeSet(output.filterNot(attr => inputSet.contains(attr))) + } + + override lazy val references: AttributeSet = child.outputSet + + override def simpleString(maxFields: Int): String = { + s"MergeRows${truncatedString(output, "[", ", ", "]", maxFields)}" + } + + override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = { + copy(child = newChild) + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/NoStatsUnaryNode.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/NoStatsUnaryNode.scala new file mode 100644 index 000000000000..c21df71f069d --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/NoStatsUnaryNode.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class NoStatsUnaryNode(child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output + override def stats: Statistics = Statistics(Long.MaxValue) + + override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = { + copy(child = newChild) + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplaceIcebergData.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplaceIcebergData.scala new file mode 100644 index 000000000000..2b741bef121a --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplaceIcebergData.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.analysis.NamedRelation +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.write.Write +import org.apache.spark.sql.types.DataType + +/** + * Replace data in an existing table. 
+ */ +case class ReplaceIcebergData( + table: NamedRelation, + query: LogicalPlan, + originalTable: NamedRelation, + write: Option[Write] = None) extends V2WriteCommandLike { + + override lazy val references: AttributeSet = query.outputSet + override lazy val stringArgs: Iterator[Any] = Iterator(table, query, write) + + // the incoming query may include metadata columns + lazy val dataInput: Seq[Attribute] = { + val tableAttrNames = table.output.map(_.name) + query.output.filter(attr => tableAttrNames.exists(conf.resolver(_, attr.name))) + } + + override def outputResolved: Boolean = { + assert(table.resolved && query.resolved, + "`outputResolved` can only be called when `table` and `query` are both resolved.") + + // take into account only incoming data columns and ignore metadata columns in the query + // they will be discarded after the logical write is built in the optimizer + // metadata columns may be needed to request a correct distribution or ordering + // but are not passed back to the data source during writes + + table.skipSchemaResolution || (dataInput.size == table.output.size && + dataInput.zip(table.output).forall { case (inAttr, outAttr) => + val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType) + // names and types must match, nullability must be compatible + inAttr.name == outAttr.name && + DataType.equalsIgnoreCompatibleNullability(inAttr.dataType, outType) && + (outAttr.nullable || !inAttr.nullable) + }) + } + + override protected def withNewChildInternal(newChild: LogicalPlan): ReplaceIcebergData = { + copy(query = newChild) + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplacePartitionField.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplacePartitionField.scala new file mode 100644 index 000000000000..8c660c6f37b1 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplacePartitionField.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.expressions.Transform + +case class ReplacePartitionField( + table: Seq[String], + transformFrom: Transform, + transformTo: Transform, + name: Option[String]) extends LeafCommand { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"ReplacePartitionField ${table.quoted} ${transformFrom.describe} " + + s"with ${name.map(n => s"$n=").getOrElse("")}${transformTo.describe}" + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/RowLevelCommand.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/RowLevelCommand.scala new file mode 100644 index 000000000000..837ee963bcea --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/RowLevelCommand.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Expression + +trait RowLevelCommand extends Command with SupportsSubquery { + def condition: Option[Expression] + def rewritePlan: Option[LogicalPlan] + def withNewRewritePlan(newRewritePlan: LogicalPlan): RowLevelCommand +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetIdentifierFields.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetIdentifierFields.scala new file mode 100644 index 000000000000..a5fa28a617e7 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetIdentifierFields.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.expressions.Transform + +case class SetIdentifierFields( + table: Seq[String], + fields: Seq[String]) extends LeafCommand { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"SetIdentifierFields ${table.quoted} (${fields.quoted})" + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/UnresolvedMergeIntoIcebergTable.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/UnresolvedMergeIntoIcebergTable.scala new file mode 100644 index 000000000000..895aa733ff20 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/UnresolvedMergeIntoIcebergTable.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Expression + +/** + * A node that hides the MERGE condition and actions from regular Spark resolution. + */ +case class UnresolvedMergeIntoIcebergTable( + targetTable: LogicalPlan, + sourceTable: LogicalPlan, + context: MergeIntoContext) extends BinaryCommand { + + def duplicateResolved: Boolean = targetTable.outputSet.intersect(sourceTable.outputSet).isEmpty + + override def left: LogicalPlan = targetTable + override def right: LogicalPlan = sourceTable + + override protected def withNewChildrenInternal(newLeft: LogicalPlan, newRight: LogicalPlan): LogicalPlan = { + copy(targetTable = newLeft, sourceTable = newRight) + } +} + +case class MergeIntoContext( + mergeCondition: Expression, + matchedActions: Seq[MergeAction], + notMatchedActions: Seq[MergeAction]) diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/UpdateIcebergTable.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/UpdateIcebergTable.scala new file mode 100644 index 000000000000..790eb9380e3d --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/UpdateIcebergTable.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
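For orientation, this is roughly the SQL that the MERGE plan nodes above (UnresolvedMergeIntoIcebergTable, later resolved into MergeIntoIcebergTable) model; the catalog and table names and the `spark` session are illustrative only, not part of this patch, and assume the Iceberg SQL extensions are enabled:

    // a sketch of a MERGE statement that is first parsed into UnresolvedMergeIntoIcebergTable
    // and, once resolved and aligned, becomes MergeIntoIcebergTable
    spark.sql(
      """MERGE INTO cat.db.target t
        |USING cat.db.source s
        |ON t.id = s.id
        |WHEN MATCHED THEN UPDATE SET *
        |WHEN NOT MATCHED THEN INSERT *
        |""".stripMargin)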
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.AssignmentUtils +import org.apache.spark.sql.catalyst.expressions.Expression + +case class UpdateIcebergTable( + table: LogicalPlan, + assignments: Seq[Assignment], + condition: Option[Expression], + rewritePlan: Option[LogicalPlan] = None) extends RowLevelCommand { + + lazy val aligned: Boolean = AssignmentUtils.aligned(table, assignments) + + override def children: Seq[LogicalPlan] = if (rewritePlan.isDefined) { + table :: rewritePlan.get :: Nil + } else { + table :: Nil + } + + override def withNewRewritePlan(newRewritePlan: LogicalPlan): RowLevelCommand = { + copy(rewritePlan = Some(newRewritePlan)) + } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): UpdateIcebergTable = { + if (newChildren.size == 1) { + copy(table = newChildren.head, rewritePlan = None) + } else { + require(newChildren.size == 2, "UpdateTable expects either one or two children") + val Seq(newTable, newRewritePlan) = newChildren.take(2) + copy(table = newTable, rewritePlan = Some(newRewritePlan)) + } + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/V2WriteCommandLike.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/V2WriteCommandLike.scala new file mode 100644 index 000000000000..9192d74b7caf --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/V2WriteCommandLike.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.analysis.NamedRelation +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeSet + +// a node similar to V2WriteCommand in Spark but does not extend Command +// as ReplaceData and WriteDelta that extend this trait are nested within other commands +trait V2WriteCommandLike extends UnaryNode { + def table: NamedRelation + def query: LogicalPlan + def outputResolved: Boolean + + override lazy val resolved: Boolean = table.resolved && query.resolved && outputResolved + + override def child: LogicalPlan = query + override def output: Seq[Attribute] = Seq.empty + override def producedAttributes: AttributeSet = outputSet + // Commands are eagerly executed. They will be converted to LocalRelation after the DataFrame + // is created. That said, the statistics of a command is useless. Here we just return a dummy + // statistics to avoid unnecessary statistics calculation of command's children. + override def stats: Statistics = Statistics.DUMMY +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/WriteDelta.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/WriteDelta.scala new file mode 100644 index 000000000000..534427cc0410 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/WriteDelta.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.analysis.NamedRelation +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.ExtendedV2ExpressionUtils +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.catalyst.util.RowDeltaUtils.OPERATION_COLUMN +import org.apache.spark.sql.catalyst.util.WriteDeltaProjections +import org.apache.spark.sql.connector.iceberg.write.DeltaWrite +import org.apache.spark.sql.connector.iceberg.write.SupportsDelta +import org.apache.spark.sql.connector.write.RowLevelOperationTable +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.StructField + +/** + * Writes a delta of rows to an existing table. 
+ */
+case class WriteDelta(
+    table: NamedRelation,
+    query: LogicalPlan,
+    originalTable: NamedRelation,
+    projections: WriteDeltaProjections,
+    write: Option[DeltaWrite] = None) extends V2WriteCommandLike {
+
+  override protected lazy val stringArgs: Iterator[Any] = Iterator(table, query, write)
+
+  private def operationResolved: Boolean = {
+    val attr = query.output.head
+    attr.name == OPERATION_COLUMN && attr.dataType == IntegerType && !attr.nullable
+  }
+
+  private def operation: SupportsDelta = {
+    EliminateSubqueryAliases(table) match {
+      case DataSourceV2Relation(RowLevelOperationTable(_, operation), _, _, _, _) =>
+        operation match {
+          case supportsDelta: SupportsDelta =>
+            supportsDelta
+          case _ =>
+            throw new AnalysisException(s"Operation $operation is not a delta operation")
+        }
+      case _ =>
+        throw new AnalysisException(s"Cannot retrieve row-level operation from $table")
+    }
+  }
+
+  private def rowAttrsResolved: Boolean = {
+    table.skipSchemaResolution || (projections.rowProjection match {
+      case Some(projection) =>
+        table.output.size == projection.schema.size &&
+          projection.schema.zip(table.output).forall { case (field, outAttr) =>
+            isCompatible(field, outAttr)
+          }
+      case None =>
+        true
+    })
+  }
+
+  private def rowIdAttrsResolved: Boolean = {
+    val rowIdAttrs = ExtendedV2ExpressionUtils.resolveRefs[AttributeReference](
+      operation.rowId.toSeq,
+      originalTable)
+
+    projections.rowIdProjection.schema.forall { field =>
+      rowIdAttrs.exists(rowIdAttr => isCompatible(field, rowIdAttr))
+    }
+  }
+
+  private def metadataAttrsResolved: Boolean = {
+    projections.metadataProjection match {
+      case Some(projection) =>
+        val metadataAttrs = ExtendedV2ExpressionUtils.resolveRefs[AttributeReference](
+          operation.requiredMetadataAttributes.toSeq,
+          originalTable)
+
+        projection.schema.forall { field =>
+          metadataAttrs.exists(metadataAttr => isCompatible(field, metadataAttr))
+        }
+      case None =>
+        true
+    }
+  }
+
+  private def isCompatible(projectionField: StructField, outAttr: NamedExpression): Boolean = {
+    // fall back to the projected field's own type when no char/varchar metadata is present
+    val inType = CharVarcharUtils.getRawType(projectionField.metadata).getOrElse(projectionField.dataType)
+    val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType)
+    // names and types must match, nullability must be compatible
+    projectionField.name == outAttr.name &&
+      DataType.equalsIgnoreCompatibleNullability(inType, outType) &&
+      (outAttr.nullable || !projectionField.nullable)
+  }
+
+  override def outputResolved: Boolean = {
+    assert(table.resolved && query.resolved,
+      "`outputResolved` can only be called when `table` and `query` are both resolved.")
+
+    operationResolved && rowAttrsResolved && rowIdAttrsResolved && metadataAttrsResolved
+  }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): WriteDelta = {
+    copy(query = newChild)
+  }
+}
diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
new file mode 100644
index 000000000000..be15f32bc1b8
--- /dev/null
+++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Expression + +/** + * A CALL statement, as parsed from SQL. + */ +case class CallStatement(name: Seq[String], args: Seq[CallArgument]) extends LeafParsedStatement + +/** + * An argument in a CALL statement. + */ +sealed trait CallArgument { + def expr: Expression +} + +/** + * An argument in a CALL statement identified by name. + */ +case class NamedArgument(name: String, expr: Expression) extends CallArgument + +/** + * An argument in a CALL statement identified by position. + */ +case class PositionalArgument(expr: Expression) extends CallArgument diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala new file mode 100644 index 000000000000..9f8f07b77db2 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.util + +object RowDeltaUtils { + final val OPERATION_COLUMN: String = "__row_operation" + final val DELETE_OPERATION: Int = 1 + final val UPDATE_OPERATION: Int = 2 + final val INSERT_OPERATION: Int = 3 +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/util/WriteDeltaProjections.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/util/WriteDeltaProjections.scala new file mode 100644 index 000000000000..e4e8b147d139 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/util/WriteDeltaProjections.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
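As a rough sketch of how the CALL statement nodes above fit together (the procedure name and argument values are made up for illustration and not taken from this patch), a parsed statement could be represented as:

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.plans.logical.{CallStatement, NamedArgument, PositionalArgument}

    // CALL cat.system.rollback_to_snapshot(table => 'db.tbl', snapshot_id => 123)
    val withNamedArgs = CallStatement(
      name = Seq("cat", "system", "rollback_to_snapshot"),
      args = Seq(
        NamedArgument("table", Literal("db.tbl")),
        NamedArgument("snapshot_id", Literal(123L))))

    // CALL cat.system.rollback_to_snapshot('db.tbl', 123)
    val withPositionalArgs = CallStatement(
      name = Seq("cat", "system", "rollback_to_snapshot"),
      args = Seq(
        PositionalArgument(Literal("db.tbl")),
        PositionalArgument(Literal(123L))))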
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.ProjectingInternalRow + +case class WriteDeltaProjections( + rowProjection: Option[ProjectingInternalRow], + rowIdProjection: ProjectingInternalRow, + metadataProjection: Option[ProjectingInternalRow]) diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/connector/expressions/TruncateTransform.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/connector/expressions/TruncateTransform.scala new file mode 100644 index 000000000000..2a3269e2db1d --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/connector/expressions/TruncateTransform.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.connector.expressions + +import org.apache.spark.sql.types.IntegerType + +private[sql] object TruncateTransform { + def unapply(expr: Expression): Option[(Int, FieldReference)] = expr match { + case transform: Transform => + transform match { + case NamedTransform("truncate", Seq(Ref(seq: Seq[String]), Lit(value: Int, IntegerType))) => + Some((value, FieldReference(seq))) + case NamedTransform("truncate", Seq(Lit(value: Int, IntegerType), Ref(seq: Seq[String]))) => + Some((value, FieldReference(seq))) + case _ => + None + } + case _ => + None + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/connector/write/ExtendedLogicalWriteInfoImpl.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/connector/write/ExtendedLogicalWriteInfoImpl.scala new file mode 100644 index 000000000000..0ae10ada5a39 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/connector/write/ExtendedLogicalWriteInfoImpl.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
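A minimal, hypothetical use of the TruncateTransform extractor above; it is private[sql], so calling code must sit under the org.apache.spark.sql package, and the transform here is built with the public Expressions factory rather than anything from this patch:

    import org.apache.spark.sql.connector.expressions.{Expressions, Transform, TruncateTransform}

    // truncate(4, col) built through the public factory methods
    val transform: Transform = Expressions.apply("truncate", Expressions.literal(4), Expressions.column("col"))

    transform match {
      case TruncateTransform(width, ref) =>
        println(s"truncate width=$width on ${ref.describe}")   // truncate width=4 on col
      case _ =>
        println("not a truncate transform")
    }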
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.connector.write + +import org.apache.spark.sql.connector.iceberg.write.ExtendedLogicalWriteInfo +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +private[sql] case class ExtendedLogicalWriteInfoImpl( + queryId: String, + schema: StructType, + options: CaseInsensitiveStringMap, + rowIdSchema: StructType = null, + metadataSchema: StructType = null) extends ExtendedLogicalWriteInfo diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddPartitionFieldExec.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddPartitionFieldExec.scala new file mode 100644 index 000000000000..55f327f7e45e --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddPartitionFieldExec.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.Spark3Util +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog +import org.apache.spark.sql.connector.expressions.Transform + +case class AddPartitionFieldExec( + catalog: TableCatalog, + ident: Identifier, + transform: Transform, + name: Option[String]) extends LeafV2CommandExec { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + iceberg.table.updateSpec() + .addField(name.orNull, Spark3Util.toIcebergTerm(transform)) + .commit() + + case table => + throw new UnsupportedOperationException(s"Cannot add partition field to non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"AddPartitionField ${catalog.name}.${ident.quoted} ${name.map(n => s"$n=").getOrElse("")}${transform.describe}" + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CallExec.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CallExec.scala new file mode 100644 index 000000000000..f66962a8c453 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CallExec.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.connector.iceberg.catalog.Procedure +import scala.collection.compat.immutable.ArraySeq + +case class CallExec( + output: Seq[Attribute], + procedure: Procedure, + input: InternalRow) extends LeafV2CommandExec { + + override protected def run(): Seq[InternalRow] = { + ArraySeq.unsafeWrapArray(procedure.call(input)) + } + + override def simpleString(maxFields: Int): String = { + s"CallExec${truncatedString(output, "[", ", ", "]", maxFields)} ${procedure.description}" + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropIdentifierFieldsExec.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropIdentifierFieldsExec.scala new file mode 100644 index 000000000000..dee778b474f9 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropIdentifierFieldsExec.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.relocated.com.google.common.base.Preconditions +import org.apache.iceberg.relocated.com.google.common.collect.Sets +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog + +case class DropIdentifierFieldsExec( + catalog: TableCatalog, + ident: Identifier, + fields: Seq[String]) extends LeafV2CommandExec { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val schema = iceberg.table.schema + val identifierFieldNames = Sets.newHashSet(schema.identifierFieldNames) + + for (name <- fields) { + Preconditions.checkArgument(schema.findField(name) != null, + "Cannot complete drop identifier fields operation: field %s not found", name) + Preconditions.checkArgument(identifierFieldNames.contains(name), + "Cannot complete drop identifier fields operation: %s is not an identifier field", name) + identifierFieldNames.remove(name) + } + + iceberg.table.updateSchema() + .setIdentifierFields(identifierFieldNames) + .commit(); + case table => + throw new UnsupportedOperationException(s"Cannot drop identifier fields in non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"DropIdentifierFields ${catalog.name}.${ident.quoted} (${fields.quoted})"; + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropPartitionFieldExec.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropPartitionFieldExec.scala new file mode 100644 index 000000000000..9a153f0c004e --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropPartitionFieldExec.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
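As a usage sketch, exec nodes like AddPartitionFieldExec and DropIdentifierFieldsExec above back ALTER TABLE extensions along the following lines; the table name and the `spark` session are illustrative and assume the Iceberg SQL extensions are configured for the session:

    // assumes spark.sql.extensions includes IcebergSparkSessionExtensions and `cat` is an Iceberg catalog
    spark.sql("ALTER TABLE cat.db.events ADD PARTITION FIELD bucket(16, id) AS id_bucket")   // AddPartitionFieldExec
    spark.sql("ALTER TABLE cat.db.events DROP IDENTIFIER FIELDS id")                         // DropIdentifierFieldsExec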
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.Spark3Util +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.IdentityTransform +import org.apache.spark.sql.connector.expressions.Transform + +case class DropPartitionFieldExec( + catalog: TableCatalog, + ident: Identifier, + transform: Transform) extends LeafV2CommandExec { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val schema = iceberg.table.schema + transform match { + case IdentityTransform(FieldReference(parts)) if parts.size == 1 && schema.findField(parts.head) == null => + // the name is not present in the Iceberg schema, so it must be a partition field name, not a column name + iceberg.table.updateSpec() + .removeField(parts.head) + .commit() + + case _ => + iceberg.table.updateSpec() + .removeField(Spark3Util.toIcebergTerm(transform)) + .commit() + } + + case table => + throw new UnsupportedOperationException(s"Cannot drop partition field in non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"DropPartitionField ${catalog.name}.${ident.quoted} ${transform.describe}" + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Implicits.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Implicits.scala new file mode 100644 index 000000000000..85bda0b08d46 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Implicits.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.write.RowLevelOperationTable + +/** + * A class similar to DataSourceV2Implicits in Spark but contains custom implicit helpers. 
+ */ +object ExtendedDataSourceV2Implicits { + implicit class TableHelper(table: Table) { + def asRowLevelOperationTable: RowLevelOperationTable = { + table match { + case rowLevelOperationTable: RowLevelOperationTable => + rowLevelOperationTable + case _ => + throw new AnalysisException(s"Table ${table.name} is not a row-level operation table") + } + } + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala new file mode 100644 index 000000000000..b5eb3e47b7cc --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.Spark3Util +import org.apache.iceberg.spark.SparkCatalog +import org.apache.iceberg.spark.SparkSessionCatalog +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.Strategy +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.catalyst.plans.logical.AddPartitionField +import org.apache.spark.sql.catalyst.plans.logical.Call +import org.apache.spark.sql.catalyst.plans.logical.DeleteFromIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.DropIdentifierFields +import org.apache.spark.sql.catalyst.plans.logical.DropPartitionField +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeRows +import org.apache.spark.sql.catalyst.plans.logical.NoStatsUnaryNode +import org.apache.spark.sql.catalyst.plans.logical.ReplaceIcebergData +import org.apache.spark.sql.catalyst.plans.logical.ReplacePartitionField +import org.apache.spark.sql.catalyst.plans.logical.SetIdentifierFields +import org.apache.spark.sql.catalyst.plans.logical.SetWriteDistributionAndOrdering +import org.apache.spark.sql.catalyst.plans.logical.WriteDelta +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import scala.jdk.CollectionConverters._ + +case class ExtendedDataSourceV2Strategy(spark: SparkSession) 
extends Strategy with PredicateHelper { + + import DataSourceV2Implicits._ + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case c @ Call(procedure, args) => + val input = buildInternalRow(args) + CallExec(c.output, procedure, input) :: Nil + + case AddPartitionField(IcebergCatalogAndIdentifier(catalog, ident), transform, name) => + AddPartitionFieldExec(catalog, ident, transform, name) :: Nil + + case DropPartitionField(IcebergCatalogAndIdentifier(catalog, ident), transform) => + DropPartitionFieldExec(catalog, ident, transform) :: Nil + + case ReplacePartitionField(IcebergCatalogAndIdentifier(catalog, ident), transformFrom, transformTo, name) => + ReplacePartitionFieldExec(catalog, ident, transformFrom, transformTo, name) :: Nil + + case SetIdentifierFields(IcebergCatalogAndIdentifier(catalog, ident), fields) => + SetIdentifierFieldsExec(catalog, ident, fields) :: Nil + + case DropIdentifierFields(IcebergCatalogAndIdentifier(catalog, ident), fields) => + DropIdentifierFieldsExec(catalog, ident, fields) :: Nil + + case SetWriteDistributionAndOrdering( + IcebergCatalogAndIdentifier(catalog, ident), distributionMode, ordering) => + SetWriteDistributionAndOrderingExec(catalog, ident, distributionMode, ordering) :: Nil + + case ReplaceIcebergData(_: DataSourceV2Relation, query, r: DataSourceV2Relation, Some(write)) => + // refresh the cache using the original relation + ReplaceDataExec(planLater(query), refreshCache(r), write) :: Nil + + case WriteDelta(_: DataSourceV2Relation, query, r: DataSourceV2Relation, projs, Some(write)) => + // refresh the cache using the original relation + WriteDeltaExec(planLater(query), refreshCache(r), projs, write) :: Nil + + case MergeRows(isSourceRowPresent, isTargetRowPresent, matchedConditions, matchedOutputs, notMatchedConditions, + notMatchedOutputs, targetOutput, rowIdAttrs, performCardinalityCheck, emitNotMatchedTargetRows, + output, child) => + + MergeRowsExec(isSourceRowPresent, isTargetRowPresent, matchedConditions, matchedOutputs, notMatchedConditions, + notMatchedOutputs, targetOutput, rowIdAttrs, performCardinalityCheck, emitNotMatchedTargetRows, + output, planLater(child)) :: Nil + + case DeleteFromIcebergTable(DataSourceV2ScanRelation(r, _, output, _), condition, None) => + // the optimizer has already checked that this delete can be handled using a metadata operation + val deleteCond = condition.getOrElse(Literal.TrueLiteral) + val predicates = splitConjunctivePredicates(deleteCond) + val normalizedPredicates = DataSourceStrategy.normalizeExprs(predicates, output) + val filters = normalizedPredicates.flatMap { pred => + val filter = DataSourceStrategy.translateFilter(pred, supportNestedPredicatePushdown = true) + if (filter.isEmpty) { + throw QueryCompilationErrors.cannotTranslateExpressionToSourceFilterError(pred) + } + filter + }.toArray + DeleteFromTableExec(r.table.asDeletable, filters, refreshCache(r)) :: Nil + + case NoStatsUnaryNode(child) => + planLater(child) :: Nil + + case _ => Nil + } + + private def buildInternalRow(exprs: Seq[Expression]): InternalRow = { + val values = new Array[Any](exprs.size) + for (index <- exprs.indices) { + values(index) = exprs(index).eval() + } + new GenericInternalRow(values) + } + + private def refreshCache(r: DataSourceV2Relation)(): Unit = { + spark.sharedState.cacheManager.recacheByPlan(spark, r) + } + + private object IcebergCatalogAndIdentifier { + def unapply(identifier: Seq[String]): Option[(TableCatalog, Identifier)] = { + val catalogAndIdentifier = 
Spark3Util.catalogAndIdentifier(spark, identifier.asJava) + catalogAndIdentifier.catalog match { + case icebergCatalog: SparkCatalog => + Some((icebergCatalog, catalogAndIdentifier.identifier)) + case icebergCatalog: SparkSessionCatalog[_] => + Some((icebergCatalog, catalogAndIdentifier.identifier)) + case _ => + None + } + } + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDistributionAndOrderingUtils.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDistributionAndOrderingUtils.scala new file mode 100644 index 000000000000..8c37b1b75924 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDistributionAndOrderingUtils.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.ExtendedV2ExpressionUtils.toCatalyst +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.RepartitionByExpression +import org.apache.spark.sql.catalyst.plans.logical.Sort +import org.apache.spark.sql.connector.distributions.ClusteredDistribution +import org.apache.spark.sql.connector.distributions.OrderedDistribution +import org.apache.spark.sql.connector.distributions.UnspecifiedDistribution +import org.apache.spark.sql.connector.write.RequiresDistributionAndOrdering +import org.apache.spark.sql.connector.write.Write +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf +import scala.collection.compat.immutable.ArraySeq + +/** + * A rule that is inspired by DistributionAndOrderingUtils in Spark but supports Iceberg transforms. + * + * Note that similarly to the original rule in Spark, it does not let AQE pick the number of shuffle + * partitions. See SPARK-34230 for context. 
+ */ +object ExtendedDistributionAndOrderingUtils { + + def prepareQuery(write: Write, query: LogicalPlan, conf: SQLConf): LogicalPlan = write match { + case write: RequiresDistributionAndOrdering => + val numPartitions = write.requiredNumPartitions() + val distribution = write.requiredDistribution match { + case d: OrderedDistribution => d.ordering.map(e => toCatalyst(e, query)) + case d: ClusteredDistribution => d.clustering.map(e => toCatalyst(e, query)) + case _: UnspecifiedDistribution => Array.empty[Expression] + } + + val queryWithDistribution = if (distribution.nonEmpty) { + val finalNumPartitions = if (numPartitions > 0) { + numPartitions + } else { + conf.numShufflePartitions + } + // the conversion to catalyst expressions above produces SortOrder expressions + // for OrderedDistribution and generic expressions for ClusteredDistribution + // this allows RepartitionByExpression to pick either range or hash partitioning + RepartitionByExpression(ArraySeq.unsafeWrapArray(distribution), query, finalNumPartitions) + } else if (numPartitions > 0) { + throw QueryCompilationErrors.numberOfPartitionsNotAllowedWithUnspecifiedDistributionError() + } else { + query + } + + val ordering = write.requiredOrdering.toSeq + .map(e => toCatalyst(e, query)) + .asInstanceOf[Seq[SortOrder]] + + val queryWithDistributionAndOrdering = if (ordering.nonEmpty) { + Sort(ordering, global = false, queryWithDistribution) + } else { + queryWithDistribution + } + + queryWithDistributionAndOrdering + + case _ => + query + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedV2Writes.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedV2Writes.scala new file mode 100644 index 000000000000..3e60d2494b05 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedV2Writes.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
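For intuition, a minimal sketch of what prepareQuery above does for a write that requires clustering by `id` into 10 partitions with no required ordering; the Write instance and the plan below are hand-rolled stand-ins, not part of this patch:

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
    import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
    import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, Expressions, SortOrder}
    import org.apache.spark.sql.connector.write.{RequiresDistributionAndOrdering, Write}
    import org.apache.spark.sql.internal.SQLConf
    import org.apache.spark.sql.types.LongType

    // a tiny plan with a single `id` column standing in for the write query
    val query = LocalRelation(AttributeReference("id", LongType)())

    // a write that asks to cluster rows by `id` into 10 partitions, with no required ordering
    val write = new Write with RequiresDistributionAndOrdering {
      override def requiredDistribution(): Distribution =
        Distributions.clustered(Array[V2Expression](Expressions.column("id")))
      override def requiredNumPartitions(): Int = 10
      override def requiredOrdering(): Array[SortOrder] = Array.empty[SortOrder]
    }

    // wraps the plan in a RepartitionByExpression over `id` with 10 partitions;
    // a Sort would be layered on top only if requiredOrdering were non-empty
    val prepared = ExtendedDistributionAndOrderingUtils.prepareQuery(write, query, SQLConf.get)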
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.UUID +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.catalyst.plans.logical.AppendData +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.OverwriteByExpression +import org.apache.spark.sql.catalyst.plans.logical.OverwritePartitionsDynamic +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.ReplaceIcebergData +import org.apache.spark.sql.catalyst.plans.logical.WriteDelta +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.utils.PlanUtils.isIcebergRelation +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.iceberg.write.DeltaWriteBuilder +import org.apache.spark.sql.connector.write.ExtendedLogicalWriteInfoImpl +import org.apache.spark.sql.connector.write.SupportsDynamicOverwrite +import org.apache.spark.sql.connector.write.SupportsOverwrite +import org.apache.spark.sql.connector.write.SupportsTruncate +import org.apache.spark.sql.connector.write.WriteBuilder +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources.AlwaysTrue +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType + +/** + * A rule that is inspired by V2Writes in Spark but supports Iceberg transforms. + */ +object ExtendedV2Writes extends Rule[LogicalPlan] with PredicateHelper { + + import DataSourceV2Implicits._ + + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case a @ AppendData(r: DataSourceV2Relation, query, options, _, None) if isIcebergRelation(r) => + val writeBuilder = newWriteBuilder(r.table, query.schema, options) + val write = writeBuilder.build() + val newQuery = ExtendedDistributionAndOrderingUtils.prepareQuery(write, query, conf) + a.copy(write = Some(write), query = newQuery) + + case o @ OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, options, _, None) + if isIcebergRelation(r) => + // fail if any filter cannot be converted. correctness depends on removing all matching data. 
+ val filters = splitConjunctivePredicates(deleteExpr).flatMap { pred => + val filter = DataSourceStrategy.translateFilter(pred, supportNestedPredicatePushdown = true) + if (filter.isEmpty) { + throw QueryCompilationErrors.cannotTranslateExpressionToSourceFilterError(pred) + } + filter + }.toArray + + val table = r.table + val writeBuilder = newWriteBuilder(table, query.schema, options) + val write = writeBuilder match { + case builder: SupportsTruncate if isTruncate(filters) => + builder.truncate().build() + case builder: SupportsOverwrite => + builder.overwrite(filters).build() + case _ => + throw QueryExecutionErrors.overwriteTableByUnsupportedExpressionError(table) + } + + val newQuery = ExtendedDistributionAndOrderingUtils.prepareQuery(write, query, conf) + o.copy(write = Some(write), query = newQuery) + + case o @ OverwritePartitionsDynamic(r: DataSourceV2Relation, query, options, _, None) + if isIcebergRelation(r) => + val table = r.table + val writeBuilder = newWriteBuilder(table, query.schema, options) + val write = writeBuilder match { + case builder: SupportsDynamicOverwrite => + builder.overwriteDynamicPartitions().build() + case _ => + throw QueryExecutionErrors.dynamicPartitionOverwriteUnsupportedByTableError(table) + } + val newQuery = ExtendedDistributionAndOrderingUtils.prepareQuery(write, query, conf) + o.copy(write = Some(write), query = newQuery) + + case rd @ ReplaceIcebergData(r: DataSourceV2Relation, query, _, None) => + val rowSchema = StructType.fromAttributes(rd.dataInput) + val writeBuilder = newWriteBuilder(r.table, rowSchema, Map.empty) + val write = writeBuilder.build() + val newQuery = ExtendedDistributionAndOrderingUtils.prepareQuery(write, query, conf) + rd.copy(write = Some(write), query = Project(rd.dataInput, newQuery)) + + case wd @ WriteDelta(r: DataSourceV2Relation, query, _, projections, None) => + val rowSchema = projections.rowProjection.map(_.schema).orNull + val rowIdSchema = projections.rowIdProjection.schema + val metadataSchema = projections.metadataProjection.map(_.schema).orNull + val writeBuilder = newWriteBuilder(r.table, rowSchema, Map.empty, rowIdSchema, metadataSchema) + writeBuilder match { + case builder: DeltaWriteBuilder => + val deltaWrite = builder.build() + val newQuery = ExtendedDistributionAndOrderingUtils.prepareQuery(deltaWrite, query, conf) + wd.copy(write = Some(deltaWrite), query = newQuery) + case other => + throw new AnalysisException(s"$other is not DeltaWriteBuilder") + } + } + + private def isTruncate(filters: Array[Filter]): Boolean = { + filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue] + } + + private def newWriteBuilder( + table: Table, + rowSchema: StructType, + writeOptions: Map[String, String], + rowIdSchema: StructType = null, + metadataSchema: StructType = null): WriteBuilder = { + val info = ExtendedLogicalWriteInfoImpl( + queryId = UUID.randomUUID().toString, + rowSchema, + writeOptions.asOptions, + rowIdSchema, + metadataSchema) + table.asWritable.newWriteBuilder(info) + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala new file mode 100644 index 000000000000..4fbf8a523a54 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.SparkException +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Ascending +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.catalyst.expressions.BasePredicate +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate +import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.UnaryExecNode + +case class MergeRowsExec( + isSourceRowPresent: Expression, + isTargetRowPresent: Expression, + matchedConditions: Seq[Expression], + matchedOutputs: Seq[Seq[Expression]], + notMatchedConditions: Seq[Expression], + notMatchedOutputs: Seq[Seq[Expression]], + targetOutput: Seq[Expression], + rowIdAttrs: Seq[Attribute], + performCardinalityCheck: Boolean, + emitNotMatchedTargetRows: Boolean, + output: Seq[Attribute], + child: SparkPlan) extends UnaryExecNode { + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + if (performCardinalityCheck) { + // request a local sort by the row ID attrs to co-locate matches for the same target row + Seq(rowIdAttrs.map(attr => SortOrder(attr, Ascending))) + } else { + Seq(Nil) + } + } + + @transient override lazy val producedAttributes: AttributeSet = { + AttributeSet(output.filterNot(attr => inputSet.contains(attr))) + } + + @transient override lazy val references: AttributeSet = child.outputSet + + override def simpleString(maxFields: Int): String = { + s"MergeRowsExec${truncatedString(output, "[", ", ", "]", maxFields)}" + } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = { + copy(child = newChild) + } + + protected override def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions(processPartition) + } + + private def createProjection(exprs: Seq[Expression], attrs: Seq[Attribute]): UnsafeProjection = { + UnsafeProjection.create(exprs, attrs) + } + + private def createPredicate(expr: Expression, attrs: Seq[Attribute]): BasePredicate = { + GeneratePredicate.generate(expr, attrs) + } + + private def applyProjection( + actions: Seq[(BasePredicate, Option[UnsafeProjection])], + inputRow: InternalRow): InternalRow = { + + // find the first action where the predicate evaluates to true + // if there are overlapping conditions in actions, use the first matching action + // in the example below, when id = 5, both actions match but the first one is 
applied + // WHEN MATCHED AND id > 1 AND id < 10 UPDATE * + // WHEN MATCHED AND id = 5 OR id = 21 DELETE + + val pair = actions.find { + case (predicate, _) => predicate.eval(inputRow) + } + + // apply the projection to produce an output row, or return null to suppress this row + pair match { + case Some((_, Some(projection))) => + projection.apply(inputRow) + case _ => + null + } + } + + private def processPartition(rowIterator: Iterator[InternalRow]): Iterator[InternalRow] = { + val inputAttrs = child.output + + val isSourceRowPresentPred = createPredicate(isSourceRowPresent, inputAttrs) + val isTargetRowPresentPred = createPredicate(isTargetRowPresent, inputAttrs) + + val matchedPreds = matchedConditions.map(createPredicate(_, inputAttrs)) + val matchedProjs = matchedOutputs.map { + case output if output.nonEmpty => Some(createProjection(output, inputAttrs)) + case _ => None + } + val matchedPairs = matchedPreds zip matchedProjs + + val notMatchedPreds = notMatchedConditions.map(createPredicate(_, inputAttrs)) + val notMatchedProjs = notMatchedOutputs.map { + case output if output.nonEmpty => Some(createProjection(output, inputAttrs)) + case _ => None + } + val nonMatchedPairs = notMatchedPreds zip notMatchedProjs + + val projectTargetCols = createProjection(targetOutput, inputAttrs) + val rowIdProj = createProjection(rowIdAttrs, inputAttrs) + + // This method is responsible for processing an input row to emit the resultant row with an + // additional column that indicates whether the row is going to be included in the final + // output of the merge or not. + // 1. Found a target row for which there is no corresponding source row (join condition not met) + // - Only project the target columns if we need to output unchanged rows + // 2. Found a source row for which there is no corresponding target row (join condition not met) + // - Apply the not matched actions (i.e. INSERT actions) if the not matched conditions are met. + // 3. Found a source row for which there is a corresponding target row (join condition met) + // - Apply the matched actions (i.e. DELETE or UPDATE actions) if the matched conditions are met. + def processRow(inputRow: InternalRow): InternalRow = { + if (emitNotMatchedTargetRows && !isSourceRowPresentPred.eval(inputRow)) { + projectTargetCols.apply(inputRow) + } else if (!isTargetRowPresentPred.eval(inputRow)) { + applyProjection(nonMatchedPairs, inputRow) + } else { + applyProjection(matchedPairs, inputRow) + } + } + + var lastMatchedRowId: InternalRow = null + + def processRowWithCardinalityCheck(inputRow: InternalRow): InternalRow = { + val isSourceRowPresent = isSourceRowPresentPred.eval(inputRow) + val isTargetRowPresent = isTargetRowPresentPred.eval(inputRow) + + if (isSourceRowPresent && isTargetRowPresent) { + val currentRowId = rowIdProj.apply(inputRow) + if (currentRowId == lastMatchedRowId) { + throw new SparkException( + "The ON search condition of the MERGE statement matched a single row from " + + "the target table with multiple rows of the source table.
This could result " + + "in the target row being operated on more than once with an update or delete " + + "operation and is not allowed.") + } + lastMatchedRowId = currentRowId.copy() + } else { + lastMatchedRowId = null + } + + if (emitNotMatchedTargetRows && !isSourceRowPresent) { + projectTargetCols.apply(inputRow) + } else if (!isTargetRowPresent) { + applyProjection(nonMatchedPairs, inputRow) + } else { + applyProjection(matchedPairs, inputRow) + } + } + + val processFunc: InternalRow => InternalRow = if (performCardinalityCheck) { + processRowWithCardinalityCheck + } else { + processRow + } + + rowIterator + .map(processFunc) + .filter(row => row != null) + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromIcebergTable.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromIcebergTable.scala new file mode 100644 index 000000000000..e81a567eb2a8 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromIcebergTable.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.plans.logical.DeleteFromIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.SupportsDelete +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources +import org.slf4j.LoggerFactory + +/** + * Checks whether a metadata delete is possible and nullifies the rewrite plan if the source can + * handle this delete without executing the rewrite plan. + * + * Note this rule must be run after expression optimization. 
+ */ +object OptimizeMetadataOnlyDeleteFromIcebergTable extends Rule[LogicalPlan] with PredicateHelper { + + val logger = LoggerFactory.getLogger(OptimizeMetadataOnlyDeleteFromIcebergTable.getClass) + + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case d @ DeleteFromIcebergTable(relation: DataSourceV2Relation, cond, Some(_)) => + val deleteCond = cond.getOrElse(Literal.TrueLiteral) + relation.table match { + case table: SupportsDelete if !SubqueryExpression.hasSubquery(deleteCond) => + val predicates = splitConjunctivePredicates(deleteCond) + val normalizedPredicates = DataSourceStrategy.normalizeExprs(predicates, relation.output) + val dataSourceFilters = toDataSourceFilters(normalizedPredicates) + val allPredicatesTranslated = normalizedPredicates.size == dataSourceFilters.length + if (allPredicatesTranslated && table.canDeleteWhere(dataSourceFilters)) { + logger.info(s"Optimizing delete expression: ${dataSourceFilters.mkString(",")} as metadata delete") + d.copy(rewritePlan = None) + } else { + d + } + case _ => + d + } + } + + protected def toDataSourceFilters(predicates: Seq[Expression]): Array[sources.Filter] = { + predicates.flatMap { p => + val filter = DataSourceStrategy.translateFilter(p, supportNestedPredicatePushdown = true) + if (filter.isEmpty) { + logWarning(s"Cannot translate expression to source filter: $p") + } + filter + }.toArray + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceDataExec.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceDataExec.scala new file mode 100644 index 000000000000..26c652469ac4 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceDataExec.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.connector.write.Write +import org.apache.spark.sql.execution.SparkPlan + +/** + * Physical plan node to replace data in existing tables. 
+ */ +case class ReplaceDataExec( + query: SparkPlan, + refreshCache: () => Unit, + write: Write) extends V2ExistingTableWriteExec { + + override lazy val references: AttributeSet = query.outputSet + override lazy val stringArgs: Iterator[Any] = Iterator(query, write) + + override protected def withNewChildInternal(newChild: SparkPlan): ReplaceDataExec = { + copy(query = newChild) + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplacePartitionFieldExec.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplacePartitionFieldExec.scala new file mode 100644 index 000000000000..fcae0a5defc4 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplacePartitionFieldExec.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.Spark3Util +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.IdentityTransform +import org.apache.spark.sql.connector.expressions.Transform + +case class ReplacePartitionFieldExec( + catalog: TableCatalog, + ident: Identifier, + transformFrom: Transform, + transformTo: Transform, + name: Option[String]) extends LeafV2CommandExec { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val schema = iceberg.table.schema + transformFrom match { + case IdentityTransform(FieldReference(parts)) if parts.size == 1 && schema.findField(parts.head) == null => + // the name is not present in the Iceberg schema, so it must be a partition field name, not a column name + iceberg.table.updateSpec() + .removeField(parts.head) + .addField(name.orNull, Spark3Util.toIcebergTerm(transformTo)) + .commit() + + case _ => + iceberg.table.updateSpec() + .removeField(Spark3Util.toIcebergTerm(transformFrom)) + .addField(name.orNull, Spark3Util.toIcebergTerm(transformTo)) + .commit() + } + + case table => + throw new UnsupportedOperationException(s"Cannot replace partition field in non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"ReplacePartitionField ${catalog.name}.${ident.quoted} 
${transformFrom.describe} " + + s"with ${name.map(n => s"$n=").getOrElse("")}${transformTo.describe}" + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceRewrittenRowLevelCommand.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceRewrittenRowLevelCommand.scala new file mode 100644 index 000000000000..414d4c0ec305 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceRewrittenRowLevelCommand.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.RowLevelCommand +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * Replaces operations such as DELETE and MERGE with the corresponding rewrite plans. + */ +object ReplaceRewrittenRowLevelCommand extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case c: RowLevelCommand if c.rewritePlan.isDefined => + c.rewritePlan.get + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala new file mode 100644 index 000000000000..2be73cb6eeac --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.planning.RewrittenRowLevelCommand +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType + +object RowLevelCommandScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { + import ExtendedDataSourceV2Implicits._ + + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + // push down the filter from the command condition instead of the filter in the rewrite plan, + // which may be negated for copy-on-write operations + case RewrittenRowLevelCommand(command, relation: DataSourceV2Relation, rewritePlan) => + val table = relation.table.asRowLevelOperationTable + val scanBuilder = table.newScanBuilder(relation.options) + + val (pushedFilters, remainingFilters) = command.condition match { + case Some(cond) => pushFilters(cond, scanBuilder, relation.output) + case None => (Nil, Nil) + } + + val (scan, output) = PushDownUtils.pruneColumns(scanBuilder, relation, relation.output, Nil) + + logInfo( + s""" + |Pushing operators to ${relation.name} + |Pushed filters: ${pushedFilters.mkString(", ")} + |Filters that were not pushed: ${remainingFilters.mkString(",")} + |Output: ${output.mkString(", ")} + """.stripMargin) + + // replace DataSourceV2Relation with DataSourceV2ScanRelation for the row operation table + // there may be multiple read relations for UPDATEs that rely on the UNION approach + val newRewritePlan = rewritePlan transform { + case r: DataSourceV2Relation if r.table eq table => + DataSourceV2ScanRelation(r, scan, toOutputAttrs(scan.readSchema(), r)) + } + + command.withNewRewritePlan(newRewritePlan) + } + + private def pushFilters( + cond: Expression, + scanBuilder: ScanBuilder, + tableAttrs: Seq[AttributeReference]): (Seq[Filter], Seq[Predicate]) = { + + val tableAttrSet = AttributeSet(tableAttrs) + val filters = splitConjunctivePredicates(cond).filter(_.references.subsetOf(tableAttrSet)) + val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, tableAttrs) + val (_, normalizedFiltersWithoutSubquery) = + normalizedFilters.partition(SubqueryExpression.hasSubquery) + + val (pushedFilters, _) = PushDownUtils.pushFilters(scanBuilder, normalizedFiltersWithoutSubquery) + (pushedFilters.left.getOrElse(Seq.empty), pushedFilters.right.getOrElse(Seq.empty)) + } + + private def toOutputAttrs( + schema: StructType, + relation: DataSourceV2Relation): Seq[AttributeReference] = { + val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap + val cleaned = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema) + cleaned.toAttributes.map { + // keep the attribute id during transformation + a => a.withExprId(nameToAttr(a.name).exprId) + } + } +} diff --git 
a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetIdentifierFieldsExec.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetIdentifierFieldsExec.scala new file mode 100644 index 000000000000..b50550ad38ef --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetIdentifierFieldsExec.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog +import scala.jdk.CollectionConverters._ + +case class SetIdentifierFieldsExec( + catalog: TableCatalog, + ident: Identifier, + fields: Seq[String]) extends LeafV2CommandExec { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + iceberg.table.updateSchema() + .setIdentifierFields(fields.asJava) + .commit(); + case table => + throw new UnsupportedOperationException(s"Cannot set identifier fields in non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"SetIdentifierFields ${catalog.name}.${ident.quoted} (${fields.quoted})"; + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetWriteDistributionAndOrderingExec.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetWriteDistributionAndOrderingExec.scala new file mode 100644 index 000000000000..386485b10b05 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetWriteDistributionAndOrderingExec.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
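The SetIdentifierFieldsExec node above backs the SET IDENTIFIER FIELDS clause. A brief usage sketch, again on the hypothetical table from the earlier examples:

// marks id as the table's row identifier field (identifier columns must be required / NOT NULL)
spark.sql("ALTER TABLE local.db.events SET IDENTIFIER FIELDS id")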
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.DistributionMode +import org.apache.iceberg.NullOrder +import org.apache.iceberg.SortDirection +import org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE +import org.apache.iceberg.expressions.Term +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog + +case class SetWriteDistributionAndOrderingExec( + catalog: TableCatalog, + ident: Identifier, + distributionMode: DistributionMode, + sortOrder: Seq[(Term, SortDirection, NullOrder)]) extends LeafV2CommandExec { + + import CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val txn = iceberg.table.newTransaction() + + val orderBuilder = txn.replaceSortOrder() + sortOrder.foreach { + case (term, SortDirection.ASC, nullOrder) => + orderBuilder.asc(term, nullOrder) + case (term, SortDirection.DESC, nullOrder) => + orderBuilder.desc(term, nullOrder) + } + orderBuilder.commit() + + txn.updateProperties() + .set(WRITE_DISTRIBUTION_MODE, distributionMode.modeName()) + .commit() + + txn.commitTransaction() + + case table => + throw new UnsupportedOperationException(s"Cannot set write order of non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + val tableIdent = s"${catalog.name}.${ident.quoted}" + val order = sortOrder.map { + case (term, direction, nullOrder) => s"$term $direction $nullOrder" + }.mkString(", ") + s"SetWriteDistributionAndOrdering $tableIdent $distributionMode $order" + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteDeltaExec.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteDeltaExec.scala new file mode 100644 index 000000000000..fa4e4f648313 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteDeltaExec.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.SparkEnv +import org.apache.spark.SparkException +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ +import org.apache.spark.sql.catalyst.util.WriteDeltaProjections +import org.apache.spark.sql.connector.iceberg.write.DeltaWrite +import org.apache.spark.sql.connector.iceberg.write.DeltaWriter +import org.apache.spark.sql.connector.write.BatchWrite +import org.apache.spark.sql.connector.write.DataWriter +import org.apache.spark.sql.connector.write.DataWriterFactory +import org.apache.spark.sql.connector.write.PhysicalWriteInfoImpl +import org.apache.spark.sql.connector.write.WriterCommitMessage +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.CustomMetrics +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.util.LongAccumulator +import org.apache.spark.util.Utils +import scala.collection.compat.immutable.ArraySeq +import scala.util.control.NonFatal + +/** + * Physical plan node to write a delta of rows to an existing table. + */ +case class WriteDeltaExec( + query: SparkPlan, + refreshCache: () => Unit, + projections: WriteDeltaProjections, + write: DeltaWrite) extends ExtendedV2ExistingTableWriteExec[DeltaWriter[InternalRow]] { + + override lazy val references: AttributeSet = query.outputSet + override lazy val stringArgs: Iterator[Any] = Iterator(query, write) + + override lazy val writingTask: WritingSparkTask[DeltaWriter[InternalRow]] = { + DeltaWithMetadataWritingSparkTask(projections) + } + + override protected def withNewChildInternal(newChild: SparkPlan): WriteDeltaExec = { + copy(query = newChild) + } +} + +// a trait similar to V2ExistingTableWriteExec but supports custom write tasks +trait ExtendedV2ExistingTableWriteExec[W <: DataWriter[InternalRow]] extends V2ExistingTableWriteExec { + def writingTask: WritingSparkTask[W] + + protected override def writeWithV2(batchWrite: BatchWrite): Seq[InternalRow] = { + val rdd: RDD[InternalRow] = { + val tempRdd = query.execute() + // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single + // partition rdd to make sure we at least set up one write task to write the metadata. + if (tempRdd.partitions.length == 0) { + sparkContext.parallelize(Seq.empty[InternalRow], 1) + } else { + tempRdd + } + } + // introduce a local var to avoid serializing the whole class + val task = writingTask + val writerFactory = batchWrite.createBatchWriterFactory( + PhysicalWriteInfoImpl(rdd.getNumPartitions)) + val useCommitCoordinator = batchWrite.useCommitCoordinator + val messages = new Array[WriterCommitMessage](rdd.partitions.length) + val totalNumRowsAccumulator = new LongAccumulator() + + logInfo(s"Start processing data source write support: $batchWrite. " + + s"The input RDD has ${messages.length} partitions.") + + // Avoid object not serializable issue. 
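+ // writeMetrics is a local copy of customMetrics so that the runJob closure below captures only this map and not the whole exec node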
+ val writeMetrics: Map[String, SQLMetric] = customMetrics + + try { + sparkContext.runJob( + rdd, + (context: TaskContext, iter: Iterator[InternalRow]) => + task.run(writerFactory, context, iter, useCommitCoordinator, writeMetrics), + rdd.partitions.indices, + (index, result: DataWritingSparkTaskResult) => { + val commitMessage = result.writerCommitMessage + messages(index) = commitMessage + totalNumRowsAccumulator.add(result.numRows) + batchWrite.onDataWriterCommit(commitMessage) + } + ) + + logInfo(s"Data source write support $batchWrite is committing.") + batchWrite.commit(messages) + logInfo(s"Data source write support $batchWrite committed.") + commitProgress = Some(StreamWriterCommitProgress(totalNumRowsAccumulator.value)) + } catch { + case cause: Throwable => + logError(s"Data source write support $batchWrite is aborting.") + try { + batchWrite.abort(messages) + } catch { + case t: Throwable => + logError(s"Data source write support $batchWrite failed to abort.") + cause.addSuppressed(t) + throw QueryExecutionErrors.writingJobFailedError(cause) + } + logError(s"Data source write support $batchWrite aborted.") + cause match { + // Only wrap non fatal exceptions. + case NonFatal(e) => throw QueryExecutionErrors.writingJobAbortedError(e) + case _ => throw cause + } + } + + Nil + } +} + +trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serializable { + + protected def writeFunc(writer: W, row: InternalRow): Unit + + def run( + writerFactory: DataWriterFactory, + context: TaskContext, + iter: Iterator[InternalRow], + useCommitCoordinator: Boolean, + customMetrics: Map[String, SQLMetric]): DataWritingSparkTaskResult = { + val stageId = context.stageId() + val stageAttempt = context.stageAttemptNumber() + val partId = context.partitionId() + val taskId = context.taskAttemptId() + val attemptId = context.attemptNumber() + val dataWriter = writerFactory.createWriter(partId, taskId).asInstanceOf[W] + + var count = 0L + // write the data and commit this writer. + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + while (iter.hasNext) { + if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) { + CustomMetrics.updateMetrics(ArraySeq.unsafeWrapArray(dataWriter.currentMetricsValues), customMetrics) + } + + // Count is here. 
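+ // the per-task row count is returned in DataWritingSparkTaskResult and added to the driver-side row accumulator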
+ count += 1 + writeFunc(dataWriter, iter.next()) + } + + CustomMetrics.updateMetrics(ArraySeq.unsafeWrapArray(dataWriter.currentMetricsValues), customMetrics) + + val msg = if (useCommitCoordinator) { + val coordinator = SparkEnv.get.outputCommitCoordinator + val commitAuthorized = coordinator.canCommit(stageId, stageAttempt, partId, attemptId) + if (commitAuthorized) { + logInfo(s"Commit authorized for partition $partId (task $taskId, attempt $attemptId, " + + s"stage $stageId.$stageAttempt)") + dataWriter.commit() + } else { + val commitDeniedException = QueryExecutionErrors.commitDeniedError( + partId, taskId, attemptId, stageId, stageAttempt) + logInfo(commitDeniedException.getMessage) + // throwing CommitDeniedException will trigger the catch block for abort + throw commitDeniedException + } + + } else { + logInfo(s"Writer for partition ${context.partitionId()} is committing.") + dataWriter.commit() + } + + logInfo(s"Committed partition $partId (task $taskId, attempt $attemptId, " + + s"stage $stageId.$stageAttempt)") + + DataWritingSparkTaskResult(count, msg) + + })(catchBlock = { + // If there is an error, abort this writer + logError(s"Aborting commit for partition $partId (task $taskId, attempt $attemptId, " + + s"stage $stageId.$stageAttempt)") + dataWriter.abort() + logError(s"Aborted commit for partition $partId (task $taskId, attempt $attemptId, " + + s"stage $stageId.$stageAttempt)") + }, finallyBlock = { + dataWriter.close() + }) + } +} + +case class DeltaWithMetadataWritingSparkTask( + projs: WriteDeltaProjections) extends WritingSparkTask[DeltaWriter[InternalRow]] { + + private lazy val rowProjection = projs.rowProjection.orNull + private lazy val rowIdProjection = projs.rowIdProjection + private lazy val metadataProjection = projs.metadataProjection.orNull + + override protected def writeFunc(writer: DeltaWriter[InternalRow], row: InternalRow): Unit = { + val operation = row.getInt(0) + + operation match { + case DELETE_OPERATION => + rowIdProjection.project(row) + metadataProjection.project(row) + writer.delete(metadataProjection, rowIdProjection) + + case UPDATE_OPERATION => + rowProjection.project(row) + rowIdProjection.project(row) + metadataProjection.project(row) + writer.update(metadataProjection, rowIdProjection, rowProjection) + + case INSERT_OPERATION => + rowProjection.project(row) + writer.insert(rowProjection) + + case other => + throw new SparkException(s"Unexpected operation ID: $other") + } + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelCommandDynamicPruning.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelCommandDynamicPruning.scala new file mode 100644 index 000000000000..1a9dec3b1727 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelCommandDynamicPruning.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.dynamicpruning + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.And +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeMap +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.DynamicPruningSubquery +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.ExtendedV2ExpressionUtils +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.planning.RewrittenRowLevelCommand +import org.apache.spark.sql.catalyst.plans.LeftSemi +import org.apache.spark.sql.catalyst.plans.logical.DeleteFromIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.plans.logical.JoinHint +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.ReplaceIcebergData +import org.apache.spark.sql.catalyst.plans.logical.RowLevelCommand +import org.apache.spark.sql.catalyst.plans.logical.Sort +import org.apache.spark.sql.catalyst.plans.logical.Subquery +import org.apache.spark.sql.catalyst.plans.logical.UpdateIcebergTable +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION +import org.apache.spark.sql.catalyst.trees.TreePattern.SORT +import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import scala.collection.compat.immutable.ArraySeq + +/** + * A rule that adds a runtime filter for row-level commands. + * + * Note that only group-based rewrite plans (i.e. ReplaceData) are taken into account. + * Row-based rewrite plans are subject to usual runtime filtering. 
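A sketch of a statement this rule targets, assuming the session and hypothetical tables from the earlier examples (updates is a hypothetical source view) and a copy-on-write MERGE; the rule only fires for group-based ReplaceData rewrites, with dynamic partition pruning enabled and a non-trivial condition.

// only target files that can match the ON condition need to be rewritten; the rule
// injects a DynamicPruningSubquery built from a left-semi join against the source
// so the target scan is pruned at runtime
spark.sql(
  """MERGE INTO local.db.events t
    |USING updates s
    |ON t.id = s.id
    |WHEN MATCHED THEN UPDATE SET *
    |WHEN NOT MATCHED THEN INSERT *""".stripMargin)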
+ */ +case class RowLevelCommandDynamicPruning(spark: SparkSession) extends Rule[LogicalPlan] with PredicateHelper { + + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + // apply special dynamic filtering only for plans that don't support deltas + case RewrittenRowLevelCommand( + command: RowLevelCommand, + DataSourceV2ScanRelation(_, scan: SupportsRuntimeFiltering, _, _), + rewritePlan: ReplaceIcebergData) if conf.dynamicPartitionPruningEnabled && isCandidate(command) => + + // use reference equality to find exactly the required scan relations + val newRewritePlan = rewritePlan transformUp { + case r: DataSourceV2ScanRelation if r.scan eq scan => + val pruningKeys = ExtendedV2ExpressionUtils.resolveRefs[Attribute]( + ArraySeq.unsafeWrapArray(scan.filterAttributes), r) + val dynamicPruningCond = buildDynamicPruningCondition(r, command, pruningKeys) + val filter = Filter(dynamicPruningCond, r) + // always optimize dynamic filtering subqueries for row-level commands as it is important + // to rewrite introduced predicates as joins because Spark recently stopped optimizing + // dynamic subqueries to facilitate broadcast reuse + optimizeSubquery(filter) + } + command.withNewRewritePlan(newRewritePlan) + } + + private def isCandidate(command: RowLevelCommand): Boolean = command.condition match { + case Some(cond) if cond != Literal.TrueLiteral => true + case _ => false + } + + private def buildDynamicPruningCondition( + relation: DataSourceV2ScanRelation, + command: RowLevelCommand, + pruningKeys: Seq[Attribute]): Expression = { + + // construct a filtering plan with the original scan relation + val matchingRowsPlan = command match { + case d: DeleteFromIcebergTable => + Filter(d.condition.get, relation) + + case u: UpdateIcebergTable => + // UPDATEs with subqueries are rewritten using a UNION with two identical scan relations + // the analyzer clones of them and assigns fresh expr IDs so that attributes don't collide + // this rule assigns dynamic filters to both scan relations based on the update condition + // the condition always refers to the original expr IDs and must be transformed + // see RewriteUpdateTable for more details + val attrMap = buildAttrMap(u.table.output, relation.output) + val transformedCond = u.condition.get transform { + case attr: AttributeReference if attrMap.contains(attr) => attrMap(attr) + } + Filter(transformedCond, relation) + + case m: MergeIntoIcebergTable => + Join(relation, m.sourceTable, LeftSemi, Some(m.mergeCondition), JoinHint.NONE) + } + + // clone the original relation in the filtering plan and assign new expr IDs to avoid conflicts + val transformedMatchingRowsPlan = matchingRowsPlan transformUpWithNewOutput { + case r: DataSourceV2ScanRelation if r eq relation => + val oldOutput = r.output + val newOutput = oldOutput.map(_.newInstance()) + r.copy(output = newOutput) -> oldOutput.zip(newOutput) + } + + val filterableScan = relation.scan.asInstanceOf[SupportsRuntimeFiltering] + val buildKeys = ExtendedV2ExpressionUtils.resolveRefs[Attribute]( + ArraySeq.unsafeWrapArray(filterableScan.filterAttributes), + transformedMatchingRowsPlan) + val buildQuery = Project(buildKeys, transformedMatchingRowsPlan) + val dynamicPruningSubqueries = pruningKeys.zipWithIndex.map { case (key, index) => + DynamicPruningSubquery(key, buildQuery, buildKeys, index, onlyInBroadcast = false) + } + + // combine all dynamic subqueries to produce the final condition + dynamicPruningSubqueries.reduce(And) + } + + private def buildAttrMap( + tableAttrs: 
Seq[Attribute], + scanAttrs: Seq[Attribute]): AttributeMap[Attribute] = { + + val resolver = conf.resolver + val attrMapping = tableAttrs.flatMap { tableAttr => + scanAttrs + .find(scanAttr => resolver(scanAttr.name, tableAttr.name)) + .map(scanAttr => tableAttr -> scanAttr) + } + AttributeMap(attrMapping) + } + + // borrowed from OptimizeSubqueries in Spark + private def optimizeSubquery(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning( + _.containsPattern(PLAN_EXPRESSION)) { + case s: SubqueryExpression => + val Subquery(newPlan, _) = spark.sessionState.optimizer.execute(Subquery.fromExpression(s)) + // At this point we have an optimized subquery plan that we are going to attach + // to this subquery expression. Here we can safely remove any top level sort + // in the plan as tuples produced by a subquery are un-ordered. + s.withNewPlan(removeTopLevelSort(newPlan)) + } + + // borrowed from OptimizeSubqueries in Spark + private def removeTopLevelSort(plan: LogicalPlan): LogicalPlan = { + if (!plan.containsPattern(SORT)) { + return plan + } + plan match { + case Sort(_, _, child) => child + case Project(fields, child) => Project(fields, removeTopLevelSort(child)) + case other => other + } + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/Employee.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/Employee.java new file mode 100644 index 000000000000..51ac57855bbe --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/Employee.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.extensions; + +import java.util.Objects; + +public class Employee { + private Integer id; + private String dep; + + public Employee() { + } + + public Employee(Integer id, String dep) { + this.id = id; + this.dep = dep; + } + + public Integer getId() { + return id; + } + + public void setId(Integer id) { + this.id = id; + } + + public String getDep() { + return dep; + } + + public void setDep(String dep) { + this.dep = dep; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } else if (other == null || getClass() != other.getClass()) { + return false; + } + + Employee employee = (Employee) other; + return Objects.equals(id, employee.id) && Objects.equals(dep, employee.dep); + } + + @Override + public int hashCode() { + return Objects.hash(id, dep); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkExtensionsTestBase.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkExtensionsTestBase.java new file mode 100644 index 000000000000..585ea3cd81bd --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkExtensionsTestBase.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.extensions; + +import java.util.Map; +import java.util.Random; +import java.util.concurrent.ThreadLocalRandom; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.hive.HiveCatalog; +import org.apache.iceberg.hive.TestHiveMetastore; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.BeforeClass; + +import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS; + +public abstract class SparkExtensionsTestBase extends SparkCatalogTestBase { + + private static final Random RANDOM = ThreadLocalRandom.current(); + + public SparkExtensionsTestBase(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @BeforeClass + public static void startMetastoreAndSpark() { + SparkTestBase.metastore = new TestHiveMetastore(); + metastore.start(); + SparkTestBase.hiveConf = metastore.hiveConf(); + + SparkTestBase.spark = SparkSession.builder() + .master("local[2]") + .config("spark.testing", "true") + .config(SQLConf.PARTITION_OVERWRITE_MODE().key(), "dynamic") + .config("spark.sql.extensions", IcebergSparkSessionExtensions.class.getName()) + .config("spark.hadoop." 
+ METASTOREURIS.varname, hiveConf.get(METASTOREURIS.varname)) + .config("spark.sql.shuffle.partitions", "4") + .config("spark.sql.hive.metastorePartitionPruningFallbackOnException", "true") + .config("spark.sql.legacy.respectNullabilityInTextDatasetConversion", "true") + .config(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), String.valueOf(RANDOM.nextBoolean())) + .enableHiveSupport() + .getOrCreate(); + + SparkTestBase.catalog = (HiveCatalog) + CatalogUtil.loadCatalog(HiveCatalog.class.getName(), "hive", ImmutableMap.of(), hiveConf); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java new file mode 100644 index 000000000000..278559a5601d --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.extensions; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.Assert; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +import static org.apache.iceberg.DataOperations.DELETE; +import static org.apache.iceberg.DataOperations.OVERWRITE; +import static org.apache.iceberg.SnapshotSummary.ADDED_DELETE_FILES_PROP; +import static org.apache.iceberg.SnapshotSummary.ADDED_FILES_PROP; +import static org.apache.iceberg.SnapshotSummary.CHANGED_PARTITION_COUNT_PROP; +import static org.apache.iceberg.SnapshotSummary.DELETED_FILES_PROP; +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.PARQUET_VECTORIZATION_ENABLED; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_HASH; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_NONE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_RANGE; + +@RunWith(Parameterized.class) +public abstract class SparkRowLevelOperationsTestBase extends SparkExtensionsTestBase { + + private static final Random RANDOM = ThreadLocalRandom.current(); + + protected final String fileFormat; + protected final boolean vectorized; + protected final String distributionMode; + + public SparkRowLevelOperationsTestBase(String catalogName, String implementation, + Map config, String fileFormat, + boolean vectorized, + String distributionMode) { + super(catalogName, implementation, config); + this.fileFormat = fileFormat; + this.vectorized = vectorized; + this.distributionMode = distributionMode; + } + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}," + + " format = {3}, vectorized = {4}, distributionMode = {5}") + public static Object[][] parameters() { + return new Object[][] { + { "testhive", SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default" + ), + "orc", + true, + WRITE_DISTRIBUTION_MODE_NONE + }, + { "testhive", SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default" + ), + "parquet", + true, + WRITE_DISTRIBUTION_MODE_NONE + }, + { "testhadoop", SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hadoop" + ), + "parquet", + RANDOM.nextBoolean(), + WRITE_DISTRIBUTION_MODE_HASH + }, + { "spark_catalog", SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "clients", "1", + "parquet-enabled", "false", + "cache-enabled", "false" // Spark will delete tables using v1, leaving the cache out of sync + ), + "avro", + false, + WRITE_DISTRIBUTION_MODE_RANGE + } + }; + } + + protected abstract Map extraTableProperties(); + + protected void initTable() { + sql("ALTER TABLE %s SET 
TBLPROPERTIES('%s' '%s')", tableName, DEFAULT_FILE_FORMAT, fileFormat); + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, WRITE_DISTRIBUTION_MODE, distributionMode); + + switch (fileFormat) { + case "parquet": + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%b')", tableName, PARQUET_VECTORIZATION_ENABLED, vectorized); + break; + case "orc": + Assert.assertTrue(vectorized); + break; + case "avro": + Assert.assertFalse(vectorized); + break; + } + + Map<String, String> props = extraTableProperties(); + props.forEach((prop, value) -> { + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, prop, value); + }); + } + + protected void createAndInitTable(String schema) { + createAndInitTable(schema, null); + } + + protected void createAndInitTable(String schema, String jsonData) { + sql("CREATE TABLE %s (%s) USING iceberg", tableName, schema); + initTable(); + + if (jsonData != null) { + try { + Dataset<Row> ds = toDS(schema, jsonData); + ds.writeTo(tableName).append(); + } catch (NoSuchTableException e) { + throw new RuntimeException("Failed to write data", e); + } + } + } + + protected void append(String table, String jsonData) { + append(table, null, jsonData); + } + + protected void append(String table, String schema, String jsonData) { + try { + Dataset<Row> ds = toDS(schema, jsonData); + ds.coalesce(1).writeTo(table).append(); + } catch (NoSuchTableException e) { + throw new RuntimeException("Failed to write data", e); + } + } + + protected void createOrReplaceView(String name, String jsonData) { + createOrReplaceView(name, null, jsonData); + } + + protected void createOrReplaceView(String name, String schema, String jsonData) { + Dataset<Row> ds = toDS(schema, jsonData); + ds.createOrReplaceTempView(name); + } + + protected <T> void createOrReplaceView(String name, List<T> data, Encoder<T> encoder) { + spark.createDataset(data, encoder).createOrReplaceTempView(name); + } + + private Dataset<Row> toDS(String schema, String jsonData) { + List<String> jsonRows = Arrays.stream(jsonData.split("\n")) + .filter(str -> str.trim().length() > 0) + .collect(Collectors.toList()); + Dataset<String> jsonDS = spark.createDataset(jsonRows, Encoders.STRING()); + + if (schema != null) { + return spark.read().schema(schema).json(jsonDS); + } else { + return spark.read().json(jsonDS); + } + } + + protected void validateDelete(Snapshot snapshot, String changedPartitionCount, String deletedDataFiles) { + validateSnapshot(snapshot, DELETE, changedPartitionCount, deletedDataFiles, null, null); + } + + protected void validateCopyOnWrite(Snapshot snapshot, String changedPartitionCount, + String deletedDataFiles, String addedDataFiles) { + validateSnapshot(snapshot, OVERWRITE, changedPartitionCount, deletedDataFiles, null, addedDataFiles); + } + + protected void validateMergeOnRead(Snapshot snapshot, String changedPartitionCount, + String addedDeleteFiles, String addedDataFiles) { + validateSnapshot(snapshot, OVERWRITE, changedPartitionCount, null, addedDeleteFiles, addedDataFiles); + } + + protected void validateSnapshot(Snapshot snapshot, String operation, String changedPartitionCount, + String deletedDataFiles, String addedDeleteFiles, String addedDataFiles) { + Assert.assertEquals("Operation must match", operation, snapshot.operation()); + validateProperty(snapshot, CHANGED_PARTITION_COUNT_PROP, changedPartitionCount); + validateProperty(snapshot, DELETED_FILES_PROP, deletedDataFiles); + validateProperty(snapshot, ADDED_DELETE_FILES_PROP, addedDeleteFiles); + validateProperty(snapshot, ADDED_FILES_PROP, addedDataFiles); + } + + protected void 
validateProperty(Snapshot snapshot, String property, Set<String> expectedValues) { + String actual = snapshot.summary().get(property); + Assert.assertTrue("Snapshot property " + property + " has unexpected value, actual = " + + actual + ", expected one of : " + String.join(",", expectedValues), + expectedValues.contains(actual)); + } + + protected void validateProperty(Snapshot snapshot, String property, String expectedValue) { + String actual = snapshot.summary().get(property); + Assert.assertEquals("Snapshot property " + property + " has unexpected value.", expectedValue, actual); + } + + protected void sleep(long millis) { + try { + Thread.sleep(millis); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java new file mode 100644 index 000000000000..e274ad857875 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java @@ -0,0 +1,903 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.extensions; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.DatumWriter; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.joda.time.DateTime; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestAddFilesProcedure extends SparkExtensionsTestBase { + + private final String sourceTableName = "source_table"; + private File fileTableDir; + + public TestAddFilesProcedure(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + @Before + public void setupTempDirs() { + try { + fileTableDir = temp.newFolder(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @After + public void dropTables() { + sql("DROP TABLE IF EXISTS %s", sourceTableName); + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void addDataUnpartitioned() { + createUnpartitionedFileTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + Assert.assertEquals(2L, result); + + assertEquals("Iceberg table contains correct data", + sql("SELECT * FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Ignore // TODO Classpath issues prevent us from actually writing to a Spark ORC table + public void addDataUnpartitionedOrc() { + createUnpartitionedFileTable("orc"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '`orc`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + Assert.assertEquals(2L, result); + + assertEquals("Iceberg table contains correct data", + sql("SELECT * FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void addAvroFile() throws Exception { + // Spark Session Catalog cannot load metadata tables + // with "The namespace in session catalog must have exactly one name part" + Assume.assumeFalse(catalogName.equals("spark_catalog")); + + // Create an Avro file + + Schema schema = SchemaBuilder.record("record").fields() + .requiredInt("id") + 
.requiredString("data") + .endRecord(); + GenericRecord record1 = new GenericData.Record(schema); + record1.put("id", 1L); + record1.put("data", "a"); + GenericRecord record2 = new GenericData.Record(schema); + record2.put("id", 2L); + record2.put("data", "b"); + File outputFile = temp.newFile("test.avro"); + + DatumWriter datumWriter = new GenericDatumWriter(schema); + DataFileWriter dataFileWriter = new DataFileWriter(datumWriter); + dataFileWriter.create(schema, outputFile); + dataFileWriter.append(record1); + dataFileWriter.append(record2); + dataFileWriter.close(); + + String createIceberg = + "CREATE TABLE %s (id Long, data String) USING iceberg"; + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '`avro`.`%s`')", + catalogName, tableName, outputFile.getPath()); + Assert.assertEquals(1L, result); + + List expected = Lists.newArrayList( + new Object[]{1L, "a"}, + new Object[]{2L, "b"} + ); + + assertEquals("Iceberg table contains correct data", + expected, + sql("SELECT * FROM %s ORDER BY id", tableName)); + + List actualRecordCount = sql("select %s from %s.files", + DataFile.RECORD_COUNT.name(), + tableName); + List expectedRecordCount = Lists.newArrayList(); + expectedRecordCount.add(new Object[]{2L}); + assertEquals("Iceberg file metadata should have correct metadata count", + expectedRecordCount, actualRecordCount); + } + + // TODO Adding spark-avro doesn't work in tests + @Ignore + public void addDataUnpartitionedAvro() { + createUnpartitionedFileTable("avro"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '`avro`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + Assert.assertEquals(2L, result); + + assertEquals("Iceberg table contains correct data", + sql("SELECT * FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void addDataUnpartitionedHive() { + createUnpartitionedHiveTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '%s')", + catalogName, tableName, sourceTableName); + + Assert.assertEquals(2L, result); + + assertEquals("Iceberg table contains correct data", + sql("SELECT * FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void addDataUnpartitionedExtraCol() { + createUnpartitionedFileTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String, foo string) USING iceberg"; + + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + Assert.assertEquals(2L, result); + + assertEquals("Iceberg table contains correct data", + sql("SELECT * FROM %s ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Test + public void addDataUnpartitionedMissingCol() { + createUnpartitionedFileTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String) USING iceberg"; + + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`')", + 
catalogName, tableName, fileTableDir.getAbsolutePath()); + + Assert.assertEquals(2L, result); + + assertEquals("Iceberg table contains correct data", + sql("SELECT id, name, dept FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void addDataPartitionedMissingCol() { + createPartitionedFileTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + Assert.assertEquals(8L, result); + + assertEquals("Iceberg table contains correct data", + sql("SELECT id, name, dept FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void addDataPartitioned() { + createPartitionedFileTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + Assert.assertEquals(8L, result); + + assertEquals("Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Ignore // TODO Classpath issues prevent us from actually writing to a Spark ORC table + public void addDataPartitionedOrc() { + createPartitionedFileTable("orc"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + Assert.assertEquals(8L, result); + + assertEquals("Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + // TODO Adding spark-avro doesn't work in tests + @Ignore + public void addDataPartitionedAvro() { + createPartitionedFileTable("avro"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '`avro`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + Assert.assertEquals(8L, result); + + assertEquals("Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Test + public void addDataPartitionedHive() { + createPartitionedHiveTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '%s')", + catalogName, tableName, sourceTableName); + + Assert.assertEquals(8L, result); + + assertEquals("Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER 
BY id", tableName)); + } + + @Test + public void addPartitionToPartitioned() { + createPartitionedFileTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + Assert.assertEquals(2L, result); + + assertEquals("Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Test + public void addDataPartitionedByDateToPartitioned() { + createDatePartitionedFileTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, date Date) USING iceberg PARTITIONED BY (date)"; + + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`', map('date', '2021-01-01'))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + Assert.assertEquals(2L, result); + + assertEquals("Iceberg table contains correct data", + sql("SELECT id, name, dept, date FROM %s WHERE date = '2021-01-01' ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, date FROM %s ORDER BY id", tableName)); + } + + @Test + public void addFilteredPartitionsToPartitioned() { + createCompositePartitionedTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg " + + "PARTITIONED BY (id, dept)"; + + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + Assert.assertEquals(2L, result); + + assertEquals("Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Test + public void addFilteredPartitionsToPartitioned2() { + createCompositePartitionedTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg " + + "PARTITIONED BY (id, dept)"; + + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`', map('dept', 'hr'))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + Assert.assertEquals(6L, result); + + assertEquals("Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE dept = 'hr' ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Test + public void addFilteredPartitionsToPartitionedWithNullValueFilteringOnId() { + createCompositePartitionedTableWithNullValueInPartitionColumn("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg " + + "PARTITIONED BY (id, dept)"; + + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + Assert.assertEquals(2L, result); + + assertEquals("Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", sourceTableName), + 
sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Test + public void addFilteredPartitionsToPartitionedWithNullValueFilteringOnDept() { + createCompositePartitionedTableWithNullValueInPartitionColumn("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg " + + "PARTITIONED BY (id, dept)"; + + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`', map('dept', 'hr'))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + Assert.assertEquals(6L, result); + + assertEquals("Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE dept = 'hr' ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Test + public void addWeirdCaseHiveTable() { + createWeirdCaseTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, `naMe` String, dept String, subdept String) USING iceberg " + + "PARTITIONED BY (`naMe`)"; + + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '%s', map('naMe', 'John Doe'))", + catalogName, tableName, sourceTableName); + + Assert.assertEquals(2L, result); + + /* + While we would like to use + SELECT id, `naMe`, dept, subdept FROM %s WHERE `naMe` = 'John Doe' ORDER BY id + Spark does not actually handle this pushdown correctly for hive based tables and it returns 0 records + */ + List expected = + sql("SELECT id, `naMe`, dept, subdept from %s ORDER BY id", sourceTableName) + .stream() + .filter(r -> r[1].equals("John Doe")) + .collect(Collectors.toList()); + + // TODO when this assert breaks Spark fixed the pushdown issue + Assert.assertEquals("If this assert breaks it means that Spark has fixed the pushdown issue", 0, + sql("SELECT id, `naMe`, dept, subdept from %s WHERE `naMe` = 'John Doe' ORDER BY id", sourceTableName) + .size()); + + // Pushdown works for iceberg + Assert.assertEquals("We should be able to pushdown mixed case partition keys", 2, + sql("SELECT id, `naMe`, dept, subdept FROM %s WHERE `naMe` = 'John Doe' ORDER BY id", tableName) + .size()); + + assertEquals("Iceberg table contains correct data", + expected, + sql("SELECT id, `naMe`, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Test + public void addPartitionToPartitionedHive() { + createPartitionedHiveTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '%s', map('id', 1))", + catalogName, tableName, sourceTableName); + + Assert.assertEquals(2L, result); + + assertEquals("Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Test + public void invalidDataImport() { + createPartitionedFileTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); + + AssertHelpers.assertThrows("Should forbid adding of partitioned data to unpartitioned table", + IllegalArgumentException.class, + "Cannot use partition filter with an unpartitioned table", + () -> scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))", + catalogName, tableName, 
fileTableDir.getAbsolutePath()) + ); + + AssertHelpers.assertThrows("Should forbid adding of partitioned data to unpartitioned table", + IllegalArgumentException.class, + "Cannot add partitioned files to an unpartitioned table", + () -> scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()) + ); + } + + @Test + public void invalidDataImportPartitioned() { + createUnpartitionedFileTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + AssertHelpers.assertThrows("Should forbid adding with a mismatching partition spec", + IllegalArgumentException.class, + "is greater than the number of partitioned columns", + () -> scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`', map('x', '1', 'y', '2'))", + catalogName, tableName, fileTableDir.getAbsolutePath())); + + AssertHelpers.assertThrows("Should forbid adding with partition spec with incorrect columns", + IllegalArgumentException.class, + "specified partition filter refers to columns that are not partitioned", + () -> scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`', map('dept', '2'))", + catalogName, tableName, fileTableDir.getAbsolutePath())); + } + + + @Test + public void addTwice() { + createPartitionedHiveTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + Object result1 = scalarSql("CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + "partition_filter => map('id', 1))", + catalogName, tableName, sourceTableName); + Assert.assertEquals(2L, result1); + + Object result2 = scalarSql("CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + "partition_filter => map('id', 2))", + catalogName, tableName, sourceTableName); + Assert.assertEquals(2L, result2); + + assertEquals("Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", tableName)); + assertEquals("Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 2 ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 2 ORDER BY id", tableName)); + } + + @Test + public void duplicateDataPartitioned() { + createPartitionedHiveTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + scalarSql("CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + "partition_filter => map('id', 1))", + catalogName, tableName, sourceTableName); + + AssertHelpers.assertThrows("Should not allow adding duplicate files", + IllegalStateException.class, + "Cannot complete import because data files to be imported already" + + " exist within the target table", + () -> scalarSql("CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + "partition_filter => map('id', 1))", + catalogName, tableName, sourceTableName)); + } + + @Test + public void duplicateDataPartitionedAllowed() { + createPartitionedHiveTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg 
PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + Object result1 = scalarSql("CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + "partition_filter => map('id', 1))", + catalogName, tableName, sourceTableName); + + Assert.assertEquals(2L, result1); + + Object result2 = scalarSql("CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + "partition_filter => map('id', 1)," + + "check_duplicate_files => false)", + catalogName, tableName, sourceTableName); + + Assert.assertEquals(2L, result2); + + + assertEquals("Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 UNION ALL " + + "SELECT id, name, dept, subdept FROM %s WHERE id = 1", sourceTableName, sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s", tableName, tableName)); + } + + @Test + public void duplicateDataUnpartitioned() { + createUnpartitionedHiveTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); + + scalarSql("CALL %s.system.add_files('%s', '%s')", + catalogName, tableName, sourceTableName); + + AssertHelpers.assertThrows("Should not allow adding duplicate files", + IllegalStateException.class, + "Cannot complete import because data files to be imported already" + + " exist within the target table", + () -> scalarSql("CALL %s.system.add_files('%s', '%s')", + catalogName, tableName, sourceTableName)); + } + + @Test + public void duplicateDataUnpartitionedAllowed() { + createUnpartitionedHiveTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); + + Object result1 = scalarSql("CALL %s.system.add_files('%s', '%s')", + catalogName, tableName, sourceTableName); + Assert.assertEquals(2L, result1); + + Object result2 = scalarSql("CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s'," + + "check_duplicate_files => false)", + catalogName, tableName, sourceTableName); + Assert.assertEquals(2L, result2); + + assertEquals("Iceberg table contains correct data", + sql("SELECT * FROM (SELECT * FROM %s UNION ALL " + + "SELECT * from %s) ORDER BY id", sourceTableName, sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + + } + + @Test + public void testEmptyImportDoesNotThrow() { + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + sql(createIceberg, tableName); + + // Empty path based import + Object pathResult = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + Assert.assertEquals(0L, pathResult); + assertEquals("Iceberg table contains no added data when importing from an empty path", + emptyQueryResult, + sql("SELECT * FROM %s ORDER BY id", tableName)); + + // Empty table based import + String createHive = "CREATE TABLE %s (id Integer, name String, dept String, subdept String) STORED AS parquet"; + sql(createHive, sourceTableName); + + Object tableResult = scalarSql("CALL %s.system.add_files('%s', '%s')", + catalogName, tableName, sourceTableName); + Assert.assertEquals(0L, tableResult); + assertEquals("Iceberg table contains no added data when importing from an empty table", + emptyQueryResult, + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testPartitionedImportFromEmptyPartitionDoesNotThrow() { + 
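+ // The Hive partition added below exists in the metastore but contains no data files, so the import should succeed and add nothing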
createPartitionedHiveTable(); + + final int emptyPartitionId = 999; + // Add an empty partition to the hive table + sql("ALTER TABLE %s ADD PARTITION (id = '%d') LOCATION '%d'", sourceTableName, + emptyPartitionId, emptyPartitionId); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + Object tableResult = scalarSql("CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + "partition_filter => map('id', %d))", + catalogName, tableName, sourceTableName, emptyPartitionId); + + Assert.assertEquals(0L, tableResult); + assertEquals("Iceberg table contains no added data when importing from an empty table", + emptyQueryResult, + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + private static final List<Object[]> emptyQueryResult = Lists.newArrayList(); + + private static final StructField[] struct = { + new StructField("id", DataTypes.IntegerType, true, Metadata.empty()), + new StructField("name", DataTypes.StringType, true, Metadata.empty()), + new StructField("dept", DataTypes.StringType, true, Metadata.empty()), + new StructField("subdept", DataTypes.StringType, true, Metadata.empty()) + }; + + private static final Dataset<Row> unpartitionedDF = + spark.createDataFrame( + ImmutableList.of( + RowFactory.create(1, "John Doe", "hr", "communications"), + RowFactory.create(2, "Jane Doe", "hr", "salary"), + RowFactory.create(3, "Matt Doe", "hr", "communications"), + RowFactory.create(4, "Will Doe", "facilities", "all")), + new StructType(struct)).repartition(1); + + private static final Dataset<Row> singleNullRecordDF = + spark.createDataFrame( + ImmutableList.of( + RowFactory.create(null, null, null, null)), + new StructType(struct)).repartition(1); + + private static final Dataset<Row> partitionedDF = + unpartitionedDF.select("name", "dept", "subdept", "id"); + + private static final Dataset<Row> compositePartitionedDF = + unpartitionedDF.select("name", "subdept", "id", "dept"); + + private static final Dataset<Row> compositePartitionedNullRecordDF = + singleNullRecordDF.select("name", "subdept", "id", "dept"); + + private static final Dataset<Row> weirdColumnNamesDF = + unpartitionedDF.select( + unpartitionedDF.col("id"), + unpartitionedDF.col("subdept"), + unpartitionedDF.col("dept"), + unpartitionedDF.col("name").as("naMe")); + + private static final StructField[] dateStruct = { + new StructField("id", DataTypes.IntegerType, true, Metadata.empty()), + new StructField("name", DataTypes.StringType, true, Metadata.empty()), + new StructField("dept", DataTypes.StringType, true, Metadata.empty()), + new StructField("ts", DataTypes.DateType, true, Metadata.empty()) + }; + + private static java.sql.Date toDate(String value) { + return new java.sql.Date(DateTime.parse(value).getMillis()); + } + + private static final Dataset<Row> dateDF = + spark.createDataFrame( + ImmutableList.of( + RowFactory.create(1, "John Doe", "hr", toDate("2021-01-01")), + RowFactory.create(2, "Jane Doe", "hr", toDate("2021-01-01")), + RowFactory.create(3, "Matt Doe", "hr", toDate("2021-01-02")), + RowFactory.create(4, "Will Doe", "facilities", toDate("2021-01-02"))), + new StructType(dateStruct)).repartition(2); + + private void createUnpartitionedFileTable(String format) { + String createParquet = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING %s LOCATION '%s'"; + + sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath()); + 
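+ // Write the same single-file dataset twice so the source location contains two data files, matching the 2L expectations in the tests above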
unpartitionedDF.write().insertInto(sourceTableName); + unpartitionedDF.write().insertInto(sourceTableName); + } + + private void createPartitionedFileTable(String format) { + String createParquet = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING %s PARTITIONED BY (id) " + + "LOCATION '%s'"; + + sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath()); + + partitionedDF.write().insertInto(sourceTableName); + partitionedDF.write().insertInto(sourceTableName); + } + + private void createCompositePartitionedTable(String format) { + String createParquet = "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING %s " + + "PARTITIONED BY (id, dept) LOCATION '%s'"; + sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath()); + + compositePartitionedDF.write().insertInto(sourceTableName); + compositePartitionedDF.write().insertInto(sourceTableName); + } + + private void createCompositePartitionedTableWithNullValueInPartitionColumn(String format) { + String createParquet = "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING %s " + + "PARTITIONED BY (id, dept) LOCATION '%s'"; + sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath()); + + Dataset unionedDF = compositePartitionedDF.unionAll(compositePartitionedNullRecordDF) + .select("name", "subdept", "id", "dept") + .repartition(1); + + unionedDF.write().insertInto(sourceTableName); + unionedDF.write().insertInto(sourceTableName); + } + + private void createWeirdCaseTable() { + String createParquet = + "CREATE TABLE %s (id Integer, subdept String, dept String) " + + "PARTITIONED BY (`naMe` String) STORED AS parquet"; + + sql(createParquet, sourceTableName); + + weirdColumnNamesDF.write().insertInto(sourceTableName); + weirdColumnNamesDF.write().insertInto(sourceTableName); + + } + + private void createUnpartitionedHiveTable() { + String createHive = "CREATE TABLE %s (id Integer, name String, dept String, subdept String) STORED AS parquet"; + + sql(createHive, sourceTableName); + + unpartitionedDF.write().insertInto(sourceTableName); + unpartitionedDF.write().insertInto(sourceTableName); + } + + private void createPartitionedHiveTable() { + String createHive = "CREATE TABLE %s (name String, dept String, subdept String) " + + "PARTITIONED BY (id Integer) STORED AS parquet"; + + sql(createHive, sourceTableName); + + partitionedDF.write().insertInto(sourceTableName); + partitionedDF.write().insertInto(sourceTableName); + } + + private void createDatePartitionedFileTable(String format) { + String createParquet = "CREATE TABLE %s (id Integer, name String, dept String, date Date) USING %s " + + "PARTITIONED BY (date) LOCATION '%s'"; + + sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath()); + + dateDF.write().insertInto(sourceTableName); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTablePartitionFields.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTablePartitionFields.java new file mode 100644 index 000000000000..9d630508b6e4 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTablePartitionFields.java @@ -0,0 +1,415 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.extensions; + +import java.util.Map; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Table; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.spark.sql.connector.catalog.CatalogManager; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestAlterTablePartitionFields extends SparkExtensionsTestBase { + public TestAlterTablePartitionFields(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testAddIdentityPartition() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD category", tableName); + + table.refresh(); + + PartitionSpec expected = PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .identity("category") + .build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testAddBucketPartition() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, id)", tableName); + + table.refresh(); + + PartitionSpec expected = PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .bucket("id", 16, "id_bucket_16") + .build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testAddTruncatePartition() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD truncate(data, 4)", tableName); + + table.refresh(); + + PartitionSpec expected = PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .truncate("data", 4, "data_trunc_4") + .build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testAddYearsPartition() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertTrue("Table should start unpartitioned", 
table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD years(ts)", tableName); + + table.refresh(); + + PartitionSpec expected = PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .year("ts") + .build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testAddMonthsPartition() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD months(ts)", tableName); + + table.refresh(); + + PartitionSpec expected = PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .month("ts") + .build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testAddDaysPartition() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD days(ts)", tableName); + + table.refresh(); + + PartitionSpec expected = PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .day("ts") + .build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testAddHoursPartition() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD hours(ts)", tableName); + + table.refresh(); + + PartitionSpec expected = PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .hour("ts") + .build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testAddNamedPartition() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, id) AS shard", tableName); + + table.refresh(); + + PartitionSpec expected = PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .bucket("id", 16, "shard") + .build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testDropIdentityPartition() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg PARTITIONED BY (category)", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals("Table should start with 1 partition field", 1, table.spec().fields().size()); + + sql("ALTER TABLE %s DROP PARTITION FIELD category", tableName); + + table.refresh(); + + PartitionSpec expected = PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .alwaysNull("category", "category") + .build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testDropDaysPartition() { + sql("CREATE TABLE %s (id bigint NOT NULL, ts timestamp, data string) USING iceberg PARTITIONED BY 
(days(ts))", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals("Table should start with 1 partition field", 1, table.spec().fields().size()); + + sql("ALTER TABLE %s DROP PARTITION FIELD days(ts)", tableName); + + table.refresh(); + + PartitionSpec expected = PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .alwaysNull("ts", "ts_day") + .build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testDropBucketPartition() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg PARTITIONED BY (bucket(16, id))", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals("Table should start with 1 partition field", 1, table.spec().fields().size()); + + sql("ALTER TABLE %s DROP PARTITION FIELD bucket(16, id)", tableName); + + table.refresh(); + + PartitionSpec expected = PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .alwaysNull("id", "id_bucket") + .build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testDropPartitionByName() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, id) AS shard", tableName); + + table.refresh(); + + Assert.assertEquals("Table should have 1 partition field", 1, table.spec().fields().size()); + + // Should be recognized as iceberg command even with extra white spaces + sql("ALTER TABLE %s DROP PARTITION \n FIELD shard", tableName); + + table.refresh(); + + PartitionSpec expected = PartitionSpec.builderFor(table.schema()) + .withSpecId(2) + .alwaysNull("id", "shard") + .build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testReplacePartition() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD days(ts)", tableName); + table.refresh(); + PartitionSpec expected = PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .day("ts") + .build(); + Assert.assertEquals("Should have new spec field", expected, table.spec()); + + sql("ALTER TABLE %s REPLACE PARTITION FIELD days(ts) WITH hours(ts)", tableName); + table.refresh(); + expected = PartitionSpec.builderFor(table.schema()) + .withSpecId(2) + .alwaysNull("ts", "ts_day") + .hour("ts") + .build(); + Assert.assertEquals("Should changed from daily to hourly partitioned field", expected, table.spec()); + } + + @Test + public void testReplacePartitionAndRename() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD days(ts)", tableName); + table.refresh(); + PartitionSpec expected = PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .day("ts") + .build(); + Assert.assertEquals("Should have new spec field", expected, table.spec()); + + sql("ALTER 
TABLE %s REPLACE PARTITION FIELD days(ts) WITH hours(ts) AS hour_col", tableName); + table.refresh(); + expected = PartitionSpec.builderFor(table.schema()) + .withSpecId(2) + .alwaysNull("ts", "ts_day") + .hour("ts", "hour_col") + .build(); + Assert.assertEquals("Should changed from daily to hourly partitioned field", expected, table.spec()); + } + + @Test + public void testReplaceNamedPartition() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD days(ts) AS day_col", tableName); + table.refresh(); + PartitionSpec expected = PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .day("ts", "day_col") + .build(); + Assert.assertEquals("Should have new spec field", expected, table.spec()); + + sql("ALTER TABLE %s REPLACE PARTITION FIELD day_col WITH hours(ts)", tableName); + table.refresh(); + expected = PartitionSpec.builderFor(table.schema()) + .withSpecId(2) + .alwaysNull("ts", "day_col") + .hour("ts") + .build(); + Assert.assertEquals("Should changed from daily to hourly partitioned field", expected, table.spec()); + } + + @Test + public void testReplaceNamedPartitionAndRenameDifferently() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD days(ts) AS day_col", tableName); + table.refresh(); + PartitionSpec expected = PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .day("ts", "day_col") + .build(); + Assert.assertEquals("Should have new spec field", expected, table.spec()); + + sql("ALTER TABLE %s REPLACE PARTITION FIELD day_col WITH hours(ts) AS hour_col", tableName); + table.refresh(); + expected = PartitionSpec.builderFor(table.schema()) + .withSpecId(2) + .alwaysNull("ts", "day_col") + .hour("ts", "hour_col") + .build(); + Assert.assertEquals("Should changed from daily to hourly partitioned field", expected, table.spec()); + } + + @Test + public void testSparkTableAddDropPartitions() throws Exception { + sql("CREATE TABLE %s (id bigint NOT NULL, ts timestamp, data string) USING iceberg", tableName); + Assert.assertEquals("spark table partition should be empty", 0, sparkTable().partitioning().length); + + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, id) AS shard", tableName); + assertPartitioningEquals(sparkTable(), 1, "bucket(16, id)"); + + sql("ALTER TABLE %s ADD PARTITION FIELD truncate(data, 4)", tableName); + assertPartitioningEquals(sparkTable(), 2, "truncate(data, 4)"); + + sql("ALTER TABLE %s ADD PARTITION FIELD years(ts)", tableName); + assertPartitioningEquals(sparkTable(), 3, "years(ts)"); + + sql("ALTER TABLE %s DROP PARTITION FIELD years(ts)", tableName); + assertPartitioningEquals(sparkTable(), 2, "truncate(data, 4)"); + + sql("ALTER TABLE %s DROP PARTITION FIELD truncate(data, 4)", tableName); + assertPartitioningEquals(sparkTable(), 1, "bucket(16, id)"); + + sql("ALTER TABLE %s DROP PARTITION FIELD shard", tableName); + sql("DESCRIBE %s", tableName); + Assert.assertEquals("spark table partition should be empty", 0, sparkTable().partitioning().length); + } + + private void assertPartitioningEquals(SparkTable table, int len, String transform) 
{ + Assert.assertEquals("spark table partition should be " + len, len, table.partitioning().length); + Assert.assertEquals("latest spark table partition transform should match", + transform, table.partitioning()[len - 1].toString()); + } + + private SparkTable sparkTable() throws Exception { + validationCatalog.loadTable(tableIdent).refresh(); + CatalogManager catalogManager = spark.sessionState().catalogManager(); + TableCatalog catalog = (TableCatalog) catalogManager.catalog(catalogName); + Identifier identifier = Identifier.of(tableIdent.namespace().levels(), tableIdent.name()); + return (SparkTable) catalog.loadTable(identifier); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTableSchema.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTableSchema.java new file mode 100644 index 000000000000..ac12953d0a7e --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTableSchema.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.extensions; + +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestAlterTableSchema extends SparkExtensionsTestBase { + public TestAlterTableSchema(String catalogName, String implementation, Map<String, String> config) { + super(catalogName, implementation, config); + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testSetIdentifierFields() { + sql("CREATE TABLE %s (id bigint NOT NULL, " + + "location struct<lon:bigint,lat:bigint> NOT NULL) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start without identifier", table.schema().identifierFieldIds().isEmpty()); + + sql("ALTER TABLE %s SET IDENTIFIER FIELDS id", tableName); + table.refresh(); + Assert.assertEquals("Should have new identifier field", + Sets.newHashSet(table.schema().findField("id").fieldId()), + table.schema().identifierFieldIds()); + + sql("ALTER TABLE %s SET IDENTIFIER FIELDS id, location.lon", tableName); + table.refresh(); + Assert.assertEquals("Should have new identifier field", + Sets.newHashSet( + table.schema().findField("id").fieldId(), + table.schema().findField("location.lon").fieldId()), + table.schema().identifierFieldIds()); + + sql("ALTER TABLE %s SET IDENTIFIER FIELDS location.lon", tableName); + table.refresh(); + Assert.assertEquals("Should have new identifier field", + Sets.newHashSet(table.schema().findField("location.lon").fieldId()), + table.schema().identifierFieldIds()); + } + + @Test + public void testSetInvalidIdentifierFields() { + sql("CREATE TABLE %s (id bigint NOT NULL, id2 bigint) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start without identifier", table.schema().identifierFieldIds().isEmpty()); + AssertHelpers.assertThrows("should not allow setting unknown fields", + IllegalArgumentException.class, + "not found in current schema or added columns", + () -> sql("ALTER TABLE %s SET IDENTIFIER FIELDS unknown", tableName)); + + AssertHelpers.assertThrows("should not allow setting optional fields", + IllegalArgumentException.class, + "not a required field", + () -> sql("ALTER TABLE %s SET IDENTIFIER FIELDS id2", tableName)); + } + + @Test + public void testDropIdentifierFields() { + sql("CREATE TABLE %s (id bigint NOT NULL, " + + "location struct<lon:bigint,lat:bigint> NOT NULL) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start without identifier", table.schema().identifierFieldIds().isEmpty()); + + sql("ALTER TABLE %s SET IDENTIFIER FIELDS id, location.lon", tableName); + table.refresh(); + Assert.assertEquals("Should have new identifier fields", + Sets.newHashSet( + table.schema().findField("id").fieldId(), + table.schema().findField("location.lon").fieldId()), + table.schema().identifierFieldIds()); + + sql("ALTER TABLE %s DROP IDENTIFIER FIELDS id", tableName); + table.refresh(); + Assert.assertEquals("Should removed identifier field", + Sets.newHashSet(table.schema().findField("location.lon").fieldId()), + table.schema().identifierFieldIds()); + + sql("ALTER TABLE %s SET IDENTIFIER FIELDS id, location.lon", tableName); + table.refresh(); + Assert.assertEquals("Should have new identifier fields", + Sets.newHashSet( + 
table.schema().findField("id").fieldId(), + table.schema().findField("location.lon").fieldId()), + table.schema().identifierFieldIds()); + + sql("ALTER TABLE %s DROP IDENTIFIER FIELDS id, location.lon", tableName); + table.refresh(); + Assert.assertEquals("Should have no identifier field", + Sets.newHashSet(), + table.schema().identifierFieldIds()); + } + + @Test + public void testDropInvalidIdentifierFields() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string NOT NULL, " + + "location struct NOT NULL) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start without identifier", table.schema().identifierFieldIds().isEmpty()); + AssertHelpers.assertThrows("should not allow dropping unknown fields", + IllegalArgumentException.class, + "field unknown not found", + () -> sql("ALTER TABLE %s DROP IDENTIFIER FIELDS unknown", tableName)); + + sql("ALTER TABLE %s SET IDENTIFIER FIELDS id", tableName); + AssertHelpers.assertThrows("should not allow dropping a field that is not an identifier", + IllegalArgumentException.class, + "data is not an identifier field", + () -> sql("ALTER TABLE %s DROP IDENTIFIER FIELDS data", tableName)); + + AssertHelpers.assertThrows("should not allow dropping a nested field that is not an identifier", + IllegalArgumentException.class, + "location.lon is not an identifier field", + () -> sql("ALTER TABLE %s DROP IDENTIFIER FIELDS location.lon", tableName)); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAncestorsOfProcedure.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAncestorsOfProcedure.java new file mode 100644 index 000000000000..baf464d94ad0 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAncestorsOfProcedure.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.extensions; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.spark.sql.AnalysisException; +import org.junit.After; +import org.junit.Test; + +public class TestAncestorsOfProcedure extends SparkExtensionsTestBase { + + public TestAncestorsOfProcedure(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testAncestorOfUsingEmptyArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Long currentSnapshotId = table.currentSnapshot().snapshotId(); + Long currentTimestamp = table.currentSnapshot().timestampMillis(); + Long preSnapshotId = table.currentSnapshot().parentId(); + Long preTimeStamp = table.snapshot(table.currentSnapshot().parentId()).timestampMillis(); + + List output = sql("CALL %s.system.ancestors_of('%s')", + catalogName, tableIdent); + + assertEquals( + "Procedure output must match", + ImmutableList.of( + row(currentSnapshotId, currentTimestamp), + row(preSnapshotId, preTimeStamp)), + output); + } + + @Test + public void testAncestorOfUsingSnapshotId() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Long currentSnapshotId = table.currentSnapshot().snapshotId(); + Long currentTimestamp = table.currentSnapshot().timestampMillis(); + Long preSnapshotId = table.currentSnapshot().parentId(); + Long preTimeStamp = table.snapshot(table.currentSnapshot().parentId()).timestampMillis(); + + assertEquals( + "Procedure output must match", + ImmutableList.of( + row(currentSnapshotId, currentTimestamp), + row(preSnapshotId, preTimeStamp)), + sql("CALL %s.system.ancestors_of('%s', %dL)", catalogName, tableIdent, currentSnapshotId)); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(preSnapshotId, preTimeStamp)), + sql("CALL %s.system.ancestors_of('%s', %dL)", catalogName, tableIdent, preSnapshotId)); + } + + @Test + public void testAncestorOfWithRollBack() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + table.refresh(); + Long firstSnapshotId = table.currentSnapshot().snapshotId(); + Long firstTimestamp = table.currentSnapshot().timestampMillis(); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + table.refresh(); + Long secondSnapshotId = table.currentSnapshot().snapshotId(); + Long secondTimestamp = table.currentSnapshot().timestampMillis(); + sql("INSERT INTO TABLE %s VALUES (3, 'c')", tableName); + table.refresh(); + Long thirdSnapshotId = table.currentSnapshot().snapshotId(); + Long thirdTimestamp = table.currentSnapshot().timestampMillis(); + + // roll back + sql("CALL %s.system.rollback_to_snapshot('%s', %dL)", + catalogName, tableIdent, secondSnapshotId); + + sql("INSERT INTO TABLE %s VALUES (4, 'd')", 
tableName); + table.refresh(); + Long fourthSnapshotId = table.currentSnapshot().snapshotId(); + Long fourthTimestamp = table.currentSnapshot().timestampMillis(); + + assertEquals( + "Procedure output must match", + ImmutableList.of( + row(fourthSnapshotId, fourthTimestamp), + row(secondSnapshotId, secondTimestamp), + row(firstSnapshotId, firstTimestamp)), + sql("CALL %s.system.ancestors_of('%s', %dL)", catalogName, tableIdent, fourthSnapshotId)); + + assertEquals( + "Procedure output must match", + ImmutableList.of( + row(thirdSnapshotId, thirdTimestamp), + row(secondSnapshotId, secondTimestamp), + row(firstSnapshotId, firstTimestamp)), + sql("CALL %s.system.ancestors_of('%s', %dL)", catalogName, tableIdent, thirdSnapshotId)); + } + + @Test + public void testAncestorOfUsingNamedArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Long firstSnapshotId = table.currentSnapshot().snapshotId(); + Long firstTimestamp = table.currentSnapshot().timestampMillis(); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(firstSnapshotId, firstTimestamp)), + sql("CALL %s.system.ancestors_of(snapshot_id => %dL, table => '%s')", + catalogName, firstSnapshotId, tableIdent)); + } + + @Test + public void testInvalidAncestorOfCases() { + AssertHelpers.assertThrows("Should reject calls without all required args", + AnalysisException.class, "Missing required parameters", + () -> sql("CALL %s.system.ancestors_of()", catalogName)); + + AssertHelpers.assertThrows("Should reject calls with empty table identifier", + IllegalArgumentException.class, "Cannot handle an empty identifier for argument table", + () -> sql("CALL %s.system.ancestors_of('')", catalogName)); + + AssertHelpers.assertThrows("Should reject calls with invalid arg types", + AnalysisException.class, "Wrong arg type for snapshot_id: cannot cast", + () -> sql("CALL %s.system.ancestors_of('%s', 1.1)", catalogName, tableIdent)); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCallStatementParser.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCallStatementParser.java new file mode 100644 index 000000000000..a6815d21a7b3 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCallStatementParser.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.extensions; + +import java.math.BigDecimal; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.List; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.catalyst.expressions.Literal$; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.catalyst.parser.ParserInterface; +import org.apache.spark.sql.catalyst.parser.extensions.IcebergParseException; +import org.apache.spark.sql.catalyst.plans.logical.CallArgument; +import org.apache.spark.sql.catalyst.plans.logical.CallStatement; +import org.apache.spark.sql.catalyst.plans.logical.NamedArgument; +import org.apache.spark.sql.catalyst.plans.logical.PositionalArgument; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import scala.collection.JavaConverters; + +public class TestCallStatementParser { + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private static SparkSession spark = null; + private static ParserInterface parser = null; + + @BeforeClass + public static void startSpark() { + TestCallStatementParser.spark = SparkSession.builder() + .master("local[2]") + .config("spark.sql.extensions", IcebergSparkSessionExtensions.class.getName()) + .config("spark.extra.prop", "value") + .getOrCreate(); + TestCallStatementParser.parser = spark.sessionState().sqlParser(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestCallStatementParser.spark; + TestCallStatementParser.spark = null; + TestCallStatementParser.parser = null; + currentSpark.stop(); + } + + @Test + public void testCallWithPositionalArgs() throws ParseException { + CallStatement call = (CallStatement) parser.parsePlan("CALL c.n.func(1, '2', 3L, true, 1.0D, 9.0e1, 900e-1BD)"); + Assert.assertEquals(ImmutableList.of("c", "n", "func"), JavaConverters.seqAsJavaList(call.name())); + + Assert.assertEquals(7, call.args().size()); + + checkArg(call, 0, 1, DataTypes.IntegerType); + checkArg(call, 1, "2", DataTypes.StringType); + checkArg(call, 2, 3L, DataTypes.LongType); + checkArg(call, 3, true, DataTypes.BooleanType); + checkArg(call, 4, 1.0D, DataTypes.DoubleType); + checkArg(call, 5, 9.0e1, DataTypes.DoubleType); + checkArg(call, 6, new BigDecimal("900e-1"), DataTypes.createDecimalType(3, 1)); + } + + @Test + public void testCallWithNamedArgs() throws ParseException { + CallStatement call = (CallStatement) parser.parsePlan("CALL cat.system.func(c1 => 1, c2 => '2', c3 => true)"); + Assert.assertEquals(ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name())); + + Assert.assertEquals(3, call.args().size()); + + checkArg(call, 0, "c1", 1, DataTypes.IntegerType); + checkArg(call, 1, "c2", "2", DataTypes.StringType); + checkArg(call, 2, "c3", true, DataTypes.BooleanType); + } + + @Test + public void testCallWithMixedArgs() throws ParseException { + CallStatement call = (CallStatement) parser.parsePlan("CALL cat.system.func(c1 => 1, '2')"); + 
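// a named argument (c1 => 1) and a positional argument ('2') are accepted together in one call +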
Assert.assertEquals(ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name())); + + Assert.assertEquals(2, call.args().size()); + + checkArg(call, 0, "c1", 1, DataTypes.IntegerType); + checkArg(call, 1, "2", DataTypes.StringType); + } + + @Test + public void testCallWithTimestampArg() throws ParseException { + CallStatement call = (CallStatement) parser.parsePlan("CALL cat.system.func(TIMESTAMP '2017-02-03T10:37:30.00Z')"); + Assert.assertEquals(ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name())); + + Assert.assertEquals(1, call.args().size()); + + checkArg(call, 0, Timestamp.from(Instant.parse("2017-02-03T10:37:30.00Z")), DataTypes.TimestampType); + } + + @Test + public void testCallWithVarSubstitution() throws ParseException { + CallStatement call = (CallStatement) parser.parsePlan("CALL cat.system.func('${spark.extra.prop}')"); + Assert.assertEquals(ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name())); + + Assert.assertEquals(1, call.args().size()); + + checkArg(call, 0, "value", DataTypes.StringType); + } + + @Test + public void testCallParseError() { + AssertHelpers.assertThrows("Should fail with a sensible parse error", IcebergParseException.class, + "missing '(' at 'radish'", + () -> parser.parsePlan("CALL cat.system radish kebab")); + } + + @Test + public void testCallStripsComments() throws ParseException { + List callStatementsWithComments = Lists.newArrayList( + "/* bracketed comment */ CALL cat.system.func('${spark.extra.prop}')", + "/**/ CALL cat.system.func('${spark.extra.prop}')", + "-- single line comment \n CALL cat.system.func('${spark.extra.prop}')", + "-- multiple \n-- single line \n-- comments \n CALL cat.system.func('${spark.extra.prop}')", + "/* select * from multiline_comment \n where x like '%sql%'; */ CALL cat.system.func('${spark.extra.prop}')", + "/* {\"app\": \"dbt\", \"dbt_version\": \"1.0.1\", \"profile_name\": \"profile1\", \"target_name\": \"dev\", " + + "\"node_id\": \"model.profile1.stg_users\"} \n*/ CALL cat.system.func('${spark.extra.prop}')", + "/* Some multi-line comment \n" + + "*/ CALL /* inline comment */ cat.system.func('${spark.extra.prop}') -- ending comment", + "CALL -- a line ending comment\n" + + "cat.system.func('${spark.extra.prop}')" + ); + for (String sqlText : callStatementsWithComments) { + CallStatement call = (CallStatement) parser.parsePlan(sqlText); + Assert.assertEquals(ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name())); + + Assert.assertEquals(1, call.args().size()); + + checkArg(call, 0, "value", DataTypes.StringType); + } + } + + private void checkArg(CallStatement call, int index, Object expectedValue, DataType expectedType) { + checkArg(call, index, null, expectedValue, expectedType); + } + + private void checkArg(CallStatement call, int index, String expectedName, + Object expectedValue, DataType expectedType) { + + if (expectedName != null) { + NamedArgument arg = checkCast(call.args().apply(index), NamedArgument.class); + Assert.assertEquals(expectedName, arg.name()); + } else { + CallArgument arg = call.args().apply(index); + checkCast(arg, PositionalArgument.class); + } + + Expression expectedExpr = toSparkLiteral(expectedValue, expectedType); + Expression actualExpr = call.args().apply(index).expr(); + Assert.assertEquals("Arg types must match", expectedExpr.dataType(), actualExpr.dataType()); + Assert.assertEquals("Arg must match", expectedExpr, actualExpr); + } + + private Literal 
toSparkLiteral(Object value, DataType dataType) { + return Literal$.MODULE$.create(value, dataType); + } + + private T checkCast(Object value, Class expectedClass) { + Assert.assertTrue("Expected instance of " + expectedClass.getName(), expectedClass.isInstance(value)); + return expectedClass.cast(value); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCherrypickSnapshotProcedure.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCherrypickSnapshotProcedure.java new file mode 100644 index 000000000000..c69964693189 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCherrypickSnapshotProcedure.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.extensions; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.junit.After; +import org.junit.Test; + +import static org.apache.iceberg.TableProperties.WRITE_AUDIT_PUBLISH_ENABLED; + +public class TestCherrypickSnapshotProcedure extends SparkExtensionsTestBase { + + public TestCherrypickSnapshotProcedure(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testCherrypickSnapshotUsingPositionalArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED); + + spark.conf().set("spark.wap.id", "1"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should not see rows from staged snapshot", + ImmutableList.of(), + sql("SELECT * FROM %s", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot wapSnapshot = Iterables.getOnlyElement(table.snapshots()); + + List output = sql( + "CALL %s.system.cherrypick_snapshot('%s', %dL)", + catalogName, tableIdent, wapSnapshot.snapshotId()); + + table.refresh(); + + Snapshot currentSnapshot = table.currentSnapshot(); + + assertEquals("Procedure output must match", + 
ImmutableList.of(row(wapSnapshot.snapshotId(), currentSnapshot.snapshotId())), + output); + + assertEquals("Cherrypick must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", tableName)); + } + + @Test + public void testCherrypickSnapshotUsingNamedArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED); + + spark.conf().set("spark.wap.id", "1"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should not see rows from staged snapshot", + ImmutableList.of(), + sql("SELECT * FROM %s", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot wapSnapshot = Iterables.getOnlyElement(table.snapshots()); + + List output = sql( + "CALL %s.system.cherrypick_snapshot(snapshot_id => %dL, table => '%s')", + catalogName, wapSnapshot.snapshotId(), tableIdent); + + table.refresh(); + + Snapshot currentSnapshot = table.currentSnapshot(); + + assertEquals("Procedure output must match", + ImmutableList.of(row(wapSnapshot.snapshotId(), currentSnapshot.snapshotId())), + output); + + assertEquals("Cherrypick must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", tableName)); + } + + @Test + public void testCherrypickSnapshotRefreshesRelationCache() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED); + + Dataset query = spark.sql("SELECT * FROM " + tableName + " WHERE id = 1"); + query.createOrReplaceTempView("tmp"); + + spark.sql("CACHE TABLE tmp"); + + assertEquals("View should not produce rows", ImmutableList.of(), sql("SELECT * FROM tmp")); + + spark.conf().set("spark.wap.id", "1"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should not see rows from staged snapshot", + ImmutableList.of(), + sql("SELECT * FROM %s", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot wapSnapshot = Iterables.getOnlyElement(table.snapshots()); + + sql("CALL %s.system.cherrypick_snapshot('%s', %dL)", + catalogName, tableIdent, wapSnapshot.snapshotId()); + + assertEquals("Cherrypick snapshot should be visible", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM tmp")); + + sql("UNCACHE TABLE tmp"); + } + + @Test + public void testCherrypickInvalidSnapshot() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + AssertHelpers.assertThrows("Should reject invalid snapshot id", + ValidationException.class, "Cannot cherry-pick unknown snapshot ID", + () -> sql("CALL %s.system.cherrypick_snapshot('%s', -1L)", catalogName, tableIdent)); + } + + @Test + public void testInvalidCherrypickSnapshotCases() { + AssertHelpers.assertThrows("Should not allow mixed args", + AnalysisException.class, "Named and positional arguments cannot be mixed", + () -> sql("CALL %s.system.cherrypick_snapshot('n', table => 't', 1L)", catalogName)); + + AssertHelpers.assertThrows("Should not resolve procedures in arbitrary namespaces", + NoSuchProcedureException.class, "not found", + () -> sql("CALL %s.custom.cherrypick_snapshot('n', 't', 1L)", catalogName)); + + AssertHelpers.assertThrows("Should reject calls without all required args", + AnalysisException.class, "Missing required parameters", + () -> sql("CALL %s.system.cherrypick_snapshot('t')", catalogName)); + + 
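// an empty table identifier must also be rejected +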
AssertHelpers.assertThrows("Should reject calls with empty table identifier", + IllegalArgumentException.class, "Cannot handle an empty identifier", + () -> sql("CALL %s.system.cherrypick_snapshot('', 1L)", catalogName)); + + AssertHelpers.assertThrows("Should reject calls with invalid arg types", + AnalysisException.class, "Wrong arg type for snapshot_id: cannot cast", + () -> sql("CALL %s.system.cherrypick_snapshot('t', 2.2)", catalogName)); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestConflictValidation.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestConflictValidation.java new file mode 100644 index 000000000000..4ce44818ab46 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestConflictValidation.java @@ -0,0 +1,411 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.extensions; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.functions; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class TestConflictValidation extends SparkExtensionsTestBase { + + public TestConflictValidation(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Before + public void createTables() { + sql("CREATE TABLE %s (id int, data string) USING iceberg " + + "PARTITIONED BY (id)" + + "TBLPROPERTIES" + + "('format-version'='2'," + + "'write.delete.mode'='merge-on-read')", tableName); + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testOverwriteFilterSerializableIsolation() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + + List records = Lists.newArrayList(new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).writeTo(tableName).append(); + + // Validating from previous snapshot finds conflicts + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + AssertHelpers.assertThrowsCause("Conflicting 
new data files should throw exception", + ValidationException.class, + "Found conflicting files that can contain records matching ref(name=\"id\") == 1:", + () -> { + try { + conflictingDf.writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwrite(functions.col("id").equalTo(1)); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf.writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwrite(functions.col("id").equalTo(1)); + } + + @Test + public void testOverwriteFilterSerializableIsolation2() throws Exception { + List records = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(1, "b")); + spark.createDataFrame(records, SimpleRecord.class).coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + + // This should add a delete file + sql("DELETE FROM %s WHERE id='1' and data='b'", tableName); + table.refresh(); + + // Validating from previous snapshot finds conflicts + List conflictingRecords = Lists.newArrayList(new SimpleRecord(1, "a")); + Dataset conflictingDf = spark.createDataFrame(conflictingRecords, SimpleRecord.class); + AssertHelpers.assertThrowsCause("Conflicting new delete files should throw exception", + ValidationException.class, + "Found new conflicting delete files that can apply to records matching ref(name=\"id\") == 1:", + () -> { + try { + conflictingDf.writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwrite(functions.col("id").equalTo(1)); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf.writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwrite(functions.col("id").equalTo(1)); + } + + @Test + public void testOverwriteFilterSerializableIsolation3() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + + // This should delete a data file + sql("DELETE FROM %s WHERE id='1'", tableName); + table.refresh(); + + // Validating from previous snapshot finds conflicts + List conflictingRecords = Lists.newArrayList(new SimpleRecord(1, "a")); + Dataset conflictingDf = spark.createDataFrame(conflictingRecords, SimpleRecord.class); + AssertHelpers.assertThrowsCause("Conflicting deleted data files should throw exception", + ValidationException.class, + "Found conflicting deleted files that can contain records matching ref(name=\"id\") == 1:", + () -> { + try { + conflictingDf.writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + 
.overwrite(functions.col("id").equalTo(1)); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf.writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwrite(functions.col("id").equalTo(1)); + } + + @Test + public void testOverwriteFilterNoSnapshotIdValidation() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + + List records = Lists.newArrayList(new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).writeTo(tableName).append(); + + // Validating from no snapshot id defaults to beginning snapshot id and finds conflicts + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + AssertHelpers.assertThrowsCause("Conflicting new data files should throw exception", + ValidationException.class, + "Found conflicting files that can contain records matching ref(name=\"id\") == 1:", + () -> { + try { + conflictingDf.writeTo(tableName) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwrite(functions.col("id").equalTo(1)); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf.writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwrite(functions.col("id").equalTo(1)); + } + + @Test + public void testOverwriteFilterSnapshotIsolation() throws Exception { + List records = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(1, "b")); + spark.createDataFrame(records, SimpleRecord.class).coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + + // This should add a delete file + sql("DELETE FROM %s WHERE id='1' and data='b'", tableName); + table.refresh(); + + // Validating from previous snapshot finds conflicts + List conflictingRecords = Lists.newArrayList(new SimpleRecord(1, "a")); + Dataset conflictingDf = spark.createDataFrame(conflictingRecords, SimpleRecord.class); + AssertHelpers.assertThrowsCause("Conflicting new delete files should throw exception", + ValidationException.class, + "Found new conflicting delete files that can apply to records matching ref(name=\"id\") == 1:", + () -> { + try { + conflictingDf.writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwrite(functions.col("id").equalTo(1)); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf.writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwrite(functions.col("id").equalTo(1)); + } + + @Test + public void 
testOverwriteFilterSnapshotIsolation2() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + + List records = Lists.newArrayList( + new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).writeTo(tableName).append(); + + // Validation should not fail due to conflicting data file in snapshot isolation mode + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + conflictingDf.writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwrite(functions.col("id").equalTo(1)); + } + + @Test + public void testOverwritePartitionSerializableIsolation() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + final long snapshotId = table.currentSnapshot().snapshotId(); + + List records = Lists.newArrayList(new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).writeTo(tableName).append(); + + // Validating from previous snapshot finds conflicts + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + AssertHelpers.assertThrowsCause("Conflicting deleted data files should throw exception", + ValidationException.class, + "Found conflicting files that can contain records matching partitions [id=1]", + () -> { + try { + conflictingDf.writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwritePartitions(); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf.writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwritePartitions(); + } + + @Test + public void testOverwritePartitionSnapshotIsolation() throws Exception { + List records = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(1, "b")); + spark.createDataFrame(records, SimpleRecord.class).coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + final long snapshotId = table.currentSnapshot().snapshotId(); + + // This should generate a delete file + sql("DELETE FROM %s WHERE data='a'", tableName); + + // Validating from previous snapshot finds conflicts + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + AssertHelpers.assertThrowsCause("Conflicting deleted data files should throw exception", + ValidationException.class, + "Found new conflicting delete files that can apply to records matching [id=1]", + () -> { + try { + conflictingDf.writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwritePartitions(); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf.writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + 
.option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwritePartitions(); + } + + @Test + public void testOverwritePartitionSnapshotIsolation2() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + final long snapshotId = table.currentSnapshot().snapshotId(); + + // This should delete a data file + sql("DELETE FROM %s WHERE id='1'", tableName); + + // Validating from previous snapshot finds conflicts + List records = Lists.newArrayList( + new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).coalesce(1).writeTo(tableName).append(); + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + + AssertHelpers.assertThrowsCause("Conflicting deleted data files should throw exception", + ValidationException.class, + "Found conflicting deleted files that can apply to records matching [id=1]", + () -> { + try { + conflictingDf.writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwritePartitions(); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf.writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwritePartitions(); + } + + @Test + public void testOverwritePartitionSnapshotIsolation3() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + final long snapshotId = table.currentSnapshot().snapshotId(); + + List records = Lists.newArrayList(new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).writeTo(tableName).append(); + + // Validation should not find conflicting data file in snapshot isolation mode + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + conflictingDf.writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwritePartitions(); + } + + @Test + public void testOverwritePartitionNoSnapshotIdValidation() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + + List records = Lists.newArrayList( + new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).writeTo(tableName).append(); + + // Validating from null snapshot is equivalent to validating from beginning + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + AssertHelpers.assertThrowsCause("Conflicting deleted data files should throw exception", + ValidationException.class, + "Found conflicting files that can contain records matching partitions [id=1]", + () -> { + try { + conflictingDf.writeTo(tableName) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwritePartitions(); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + // Validating from latest snapshot should succeed + table.refresh(); + long snapshotId = table.currentSnapshot().snapshotId(); + conflictingDf.writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, 
IsolationLevel.SERIALIZABLE.toString()) + .overwritePartitions(); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteDelete.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteDelete.java new file mode 100644 index 000000000000..c9d15906251f --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteDelete.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.extensions; + +import java.util.Map; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; + +public class TestCopyOnWriteDelete extends TestDelete { + + public TestCopyOnWriteDelete(String catalogName, String implementation, Map<String, String> config, + String fileFormat, Boolean vectorized, String distributionMode) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode); + } + + @Override + protected Map<String, String> extraTableProperties() { + return ImmutableMap.of(TableProperties.DELETE_MODE, "copy-on-write"); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java new file mode 100644 index 000000000000..60aba632646f --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.extensions; + +import java.util.Map; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; + +public class TestCopyOnWriteMerge extends TestMerge { + + public TestCopyOnWriteMerge(String catalogName, String implementation, Map<String, String> config, + String fileFormat, boolean vectorized, String distributionMode) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode); + } + + @Override + protected Map<String, String> extraTableProperties() { + return ImmutableMap.of(TableProperties.MERGE_MODE, "copy-on-write"); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteUpdate.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteUpdate.java new file mode 100644 index 000000000000..cc73ecba9ddf --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteUpdate.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.extensions; + +import java.util.Map; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; + +public class TestCopyOnWriteUpdate extends TestUpdate { + + public TestCopyOnWriteUpdate(String catalogName, String implementation, Map<String, String> config, + String fileFormat, boolean vectorized, String distributionMode) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode); + } + + @Override + protected Map<String, String> extraTableProperties() { + return ImmutableMap.of(TableProperties.UPDATE_MODE, "copy-on-write"); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java new file mode 100644 index 000000000000..d492a30eb827 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java @@ -0,0 +1,876 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.extensions; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.spark.SparkException; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.internal.SQLConf; +import org.hamcrest.CoreMatchers; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE; +import static org.apache.iceberg.TableProperties.DELETE_ISOLATION_LEVEL; +import static org.apache.iceberg.TableProperties.DELETE_MODE; +import static org.apache.iceberg.TableProperties.DELETE_MODE_DEFAULT; +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; +import static org.apache.iceberg.TableProperties.SPLIT_SIZE; +import static org.apache.spark.sql.functions.lit; + +public abstract class TestDelete extends SparkRowLevelOperationsTestBase { + + public TestDelete(String catalogName, String implementation, Map config, + String fileFormat, Boolean vectorized, String distributionMode) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode); + } + + @BeforeClass + public static void setupSparkConf() { + spark.conf().set("spark.sql.shuffle.partitions", "4"); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS deleted_id"); + sql("DROP TABLE IF EXISTS deleted_dep"); + } + + @Test + public void testDeleteFileThenMetadataDelete() throws Exception { + Assume.assumeFalse("Avro does not support metadata delete", fileFormat.equals("avro")); + createAndInitUnpartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + + // MOR mode: writes a delete file as null cannot be deleted 
by metadata + sql("DELETE FROM %s AS t WHERE t.id IS NULL", tableName); + + // Metadata Delete + Table table = Spark3Util.loadIcebergTable(spark, tableName); + Set dataFilesBefore = TestHelpers.dataFiles(table); + + sql("DELETE FROM %s AS t WHERE t.id = 1", tableName); + + Set dataFilesAfter = TestHelpers.dataFiles(table); + Assert.assertTrue("Data file should have been removed", dataFilesBefore.size() > dataFilesAfter.size()); + + assertEquals("Should have expected rows", + ImmutableList.of(row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testDeleteWithFalseCondition() { + createAndInitUnpartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware')", tableName); + + sql("DELETE FROM %s WHERE id = 1 AND id > 20", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testDeleteFromEmptyTable() { + createAndInitUnpartitionedTable(); + + sql("DELETE FROM %s WHERE id IN (1)", tableName); + sql("DELETE FROM %s WHERE dep = 'hr'", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); + + assertEquals("Should have expected rows", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testExplain() { + createAndInitUnpartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + + sql("EXPLAIN DELETE FROM %s WHERE id <=> 1", tableName); + + sql("EXPLAIN DELETE FROM %s WHERE true", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 1 snapshot", 1, Iterables.size(table.snapshots())); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + } + + @Test + public void testDeleteWithAlias() { + createAndInitUnpartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + + sql("DELETE FROM %s AS t WHERE t.id IS NULL", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testDeleteWithDynamicFileFiltering() throws NoSuchTableException { + createAndInitPartitionedTable(); + + append(new Employee(1, "hr"), new Employee(3, "hr")); + append(new Employee(1, "hardware"), new Employee(2, "hardware")); + + sql("DELETE FROM %s WHERE id = 2", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = table.currentSnapshot(); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "1", "1", "1"); + } else { + validateMergeOnRead(currentSnapshot, "1", "1", null); + } + + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hardware"), row(1, "hr"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + } + + @Test + public void testDeleteNonExistingRecords() { + createAndInitPartitionedTable(); + + sql("INSERT INTO TABLE 
%s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + + sql("DELETE FROM %s AS t WHERE t.id > 10", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = table.currentSnapshot(); + + if (fileFormat.equals("orc") || fileFormat.equals("parquet")) { + validateDelete(currentSnapshot, "0", null); + } else { + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "0", null, null); + } else { + validateMergeOnRead(currentSnapshot, "0", null, null); + } + } + + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + } + + @Test + public void testDeleteWithoutCondition() { + createAndInitPartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'hardware')", tableName); + sql("INSERT INTO TABLE %s VALUES (null, 'hr')", tableName); + + sql("DELETE FROM %s", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 4 snapshots", 4, Iterables.size(table.snapshots())); + + // should be a delete instead of an overwrite as it is done through a metadata operation + Snapshot currentSnapshot = table.currentSnapshot(); + validateDelete(currentSnapshot, "2", "3"); + + assertEquals("Should have expected rows", + ImmutableList.of(), + sql("SELECT * FROM %s", tableName)); + } + + @Test + public void testDeleteUsingMetadataWithComplexCondition() { + createAndInitPartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'dep1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'dep2')", tableName); + sql("INSERT INTO TABLE %s VALUES (null, 'dep3')", tableName); + + sql("DELETE FROM %s WHERE dep > 'dep2' OR dep = CAST(4 AS STRING) OR dep = 'dep2'", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 4 snapshots", 4, Iterables.size(table.snapshots())); + + // should be a delete instead of an overwrite as it is done through a metadata operation + Snapshot currentSnapshot = table.currentSnapshot(); + validateDelete(currentSnapshot, "2", "2"); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "dep1")), + sql("SELECT * FROM %s", tableName)); + } + + @Test + public void testDeleteWithArbitraryPartitionPredicates() { + createAndInitPartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'hardware')", tableName); + sql("INSERT INTO TABLE %s VALUES (null, 'hr')", tableName); + + // %% is an escaped version of % + sql("DELETE FROM %s WHERE id = 10 OR dep LIKE '%%ware'", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 4 snapshots", 4, Iterables.size(table.snapshots())); + + // should be an overwrite since cannot be executed using a metadata operation + Snapshot currentSnapshot = table.currentSnapshot(); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "1", "1", null); + } else { + validateMergeOnRead(currentSnapshot, "1", "1", null); + } + + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + } + + @Test + public void testDeleteWithNonDeterministicCondition() { + createAndInitPartitionedTable(); 
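+ // rand() is non-deterministic, so this DELETE should fail during analysis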
+ + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware')", tableName); + + AssertHelpers.assertThrows("Should complain about non-deterministic expressions", + AnalysisException.class, "nondeterministic expressions are only allowed", + () -> sql("DELETE FROM %s WHERE id = 1 AND rand() > 0.5", tableName)); + } + + @Test + public void testDeleteWithFoldableConditions() { + createAndInitPartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware')", tableName); + + // should keep all rows and don't trigger execution + sql("DELETE FROM %s WHERE false", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + // should keep all rows and don't trigger execution + sql("DELETE FROM %s WHERE 50 <> 50", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + // should keep all rows and don't trigger execution + sql("DELETE FROM %s WHERE 1 > null", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + // should remove all rows + sql("DELETE FROM %s WHERE 21 = 21", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); + } + + @Test + public void testDeleteWithNullConditions() { + createAndInitPartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (0, null), (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + + // should keep all rows as null is never equal to null + sql("DELETE FROM %s WHERE dep = null", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(0, null), row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + // null = 'software' -> null + // should delete using metadata operation only + sql("DELETE FROM %s WHERE dep = 'software'", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(0, null), row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + // should delete using metadata operation only + sql("DELETE FROM %s WHERE dep <=> NULL", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = table.currentSnapshot(); + validateDelete(currentSnapshot, "1", "1"); + } + + @Test + public void testDeleteWithInAndNotInConditions() { + createAndInitUnpartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + + sql("DELETE FROM %s WHERE id IN (1, null)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + sql("DELETE FROM %s WHERE id NOT IN (null, 1)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(2, "hardware"), 
row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + sql("DELETE FROM %s WHERE id NOT IN (1, 10)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + } + + @Test + public void testDeleteWithMultipleRowGroupsParquet() throws NoSuchTableException { + Assume.assumeTrue(fileFormat.equalsIgnoreCase("parquet")); + + createAndInitPartitionedTable(); + + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", tableName, PARQUET_ROW_GROUP_SIZE_BYTES, 100); + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", tableName, SPLIT_SIZE, 100); + + List ids = Lists.newArrayListWithCapacity(200); + for (int id = 1; id <= 200; id++) { + ids.add(id); + } + Dataset df = spark.createDataset(ids, Encoders.INT()) + .withColumnRenamed("value", "id") + .withColumn("dep", lit("hr")); + df.coalesce(1).writeTo(tableName).append(); + + Assert.assertEquals(200, spark.table(tableName).count()); + + // delete a record from one of two row groups and copy over the second one + sql("DELETE FROM %s WHERE id IN (200, 201)", tableName); + + Assert.assertEquals(199, spark.table(tableName).count()); + } + + @Test + public void testDeleteWithConditionOnNestedColumn() { + createAndInitNestedColumnsTable(); + + sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", 3, \"c2\", \"v1\"))", tableName); + sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2, \"c2\", \"v2\"))", tableName); + + sql("DELETE FROM %s WHERE complex.c1 = id + 2", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(2)), + sql("SELECT id FROM %s", tableName)); + + sql("DELETE FROM %s t WHERE t.complex.c1 = id", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(), + sql("SELECT id FROM %s", tableName)); + } + + @Test + public void testDeleteWithInSubquery() throws NoSuchTableException { + createAndInitUnpartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + + createOrReplaceView("deleted_id", Arrays.asList(0, 1, null), Encoders.INT()); + createOrReplaceView("deleted_dep", Arrays.asList("software", "hr"), Encoders.STRING()); + + sql("DELETE FROM %s WHERE id IN (SELECT * FROM deleted_id) AND dep IN (SELECT * from deleted_dep)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + append(new Employee(1, "hr"), new Employee(-1, "hr")); + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + sql("DELETE FROM %s WHERE id IS NULL OR id IN (SELECT value + 2 FROM deleted_id)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(1, "hr")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + append(new Employee(null, "hr"), new Employee(2, "hr")); + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(1, "hr"), row(2, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + sql("DELETE FROM %s WHERE id IN (SELECT value + 2 FROM deleted_id) AND dep = 'hr'", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(1, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + 
} + + @Test + public void testDeleteWithMultiColumnInSubquery() throws NoSuchTableException { + createAndInitUnpartitionedTable(); + + append(new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + + List<Employee> deletedEmployees = Arrays.asList(new Employee(null, "hr"), new Employee(1, "hr")); + createOrReplaceView("deleted_employee", deletedEmployees, Encoders.bean(Employee.class)); + + sql("DELETE FROM %s WHERE (id, dep) IN (SELECT id, dep FROM deleted_employee)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + } + + @Test + public void testDeleteWithNotInSubquery() throws NoSuchTableException { + createAndInitUnpartitionedTable(); + + append(new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + + createOrReplaceView("deleted_id", Arrays.asList(-1, -2, null), Encoders.INT()); + createOrReplaceView("deleted_dep", Arrays.asList("software", "hr"), Encoders.STRING()); + + // the file filter subquery (nested loop left-anti join) returns 0 records + sql("DELETE FROM %s WHERE id NOT IN (SELECT * FROM deleted_id)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + sql("DELETE FROM %s WHERE id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + sql("DELETE FROM %s WHERE id NOT IN (SELECT * FROM deleted_id) OR dep IN ('software', 'hr')", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + sql("DELETE FROM %s t WHERE " + + "id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL) AND " + + "EXISTS (SELECT 1 FROM deleted_dep WHERE t.dep = deleted_dep.value)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + sql("DELETE FROM %s t WHERE " + + "id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL) OR " + + "EXISTS (SELECT 1 FROM deleted_dep WHERE t.dep = deleted_dep.value)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + } + + @Test + public void testDeleteOnNonIcebergTableNotSupported() { + createOrReplaceView("testtable", "{ \"c1\": -100, \"c2\": -200 }"); + + AssertHelpers.assertThrows("Delete is supported only for Iceberg tables", + AnalysisException.class, "DELETE is only supported with v2 tables.", + () -> sql("DELETE FROM %s WHERE c1 = -100", "testtable")); + } + + @Test + public void testDeleteWithExistSubquery() throws NoSuchTableException { + createAndInitUnpartitionedTable(); + + append(new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + + createOrReplaceView("deleted_id", Arrays.asList(-1, -2, null), Encoders.INT()); +
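The first two DELETE statements in testDeleteWithNotInSubquery above hinge on SQL three-valued logic: once the IN-list contains a NULL, NOT IN can only evaluate to FALSE or NULL, never TRUE, so no rows are deleted until the NULLs are filtered out of the subquery. A standalone sketch of that behaviour in plain Spark SQL, with no Iceberg involved (class name is illustrative):

    import org.apache.spark.sql.SparkSession;

    class NotInNullDemo {
      static void show(SparkSession spark) {
        // 1 <> NULL is NULL, so the whole NOT IN predicate degrades to NULL and the row is kept
        spark.sql("SELECT 1 NOT IN (-1, -2, NULL) AS with_null, " +
            "1 NOT IN (-1, -2) AS without_null").show();
        // with_null -> NULL, without_null -> true
      }
    }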
createOrReplaceView("deleted_dep", Arrays.asList("software", "hr"), Encoders.STRING()); + + sql("DELETE FROM %s t WHERE EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + sql("DELETE FROM %s t WHERE EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + sql("DELETE FROM %s t WHERE EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value) OR t.id IS NULL", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(2, "hardware")), + sql("SELECT * FROM %s", tableName)); + + sql("DELETE FROM %s t WHERE " + + "EXISTS (SELECT 1 FROM deleted_id di WHERE t.id = di.value) AND " + + "EXISTS (SELECT 1 FROM deleted_dep dd WHERE t.dep = dd.value)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(2, "hardware")), + sql("SELECT * FROM %s", tableName)); + } + + @Test + public void testDeleteWithNotExistsSubquery() throws NoSuchTableException { + createAndInitUnpartitionedTable(); + + append(new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + + createOrReplaceView("deleted_id", Arrays.asList(-1, -2, null), Encoders.INT()); + createOrReplaceView("deleted_dep", Arrays.asList("software", "hr"), Encoders.STRING()); + + sql("DELETE FROM %s t WHERE " + + "NOT EXISTS (SELECT 1 FROM deleted_id di WHERE t.id = di.value + 2) AND " + + "NOT EXISTS (SELECT 1 FROM deleted_dep dd WHERE t.dep = dd.value)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + sql("DELETE FROM %s t WHERE NOT EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + String subquery = "SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2"; + sql("DELETE FROM %s t WHERE NOT EXISTS (%s) OR t.id = 1", tableName, subquery); + assertEquals("Should have expected rows", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + } + + @Test + public void testDeleteWithScalarSubquery() throws NoSuchTableException { + createAndInitUnpartitionedTable(); + + append(new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + + createOrReplaceView("deleted_id", Arrays.asList(1, 100, null), Encoders.INT()); + + // TODO: Spark does not support AQE and DPP with aggregates at the moment + withSQLConf(ImmutableMap.of(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "false"), () -> { + sql("DELETE FROM %s t WHERE id <= (SELECT min(value) FROM deleted_id)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + }); + } + + @Test + public void testDeleteThatRequiresGroupingBeforeWrite() throws NoSuchTableException { + createAndInitPartitionedTable(); + + append(new Employee(0, "hr"), new Employee(1, "hr"), new Employee(2, "hr")); + append(new Employee(0, "ops"), new Employee(1, "ops"), new Employee(2, "ops")); + append(new Employee(0, "hr"), 
new Employee(1, "hr"), new Employee(2, "hr")); + append(new Employee(0, "ops"), new Employee(1, "ops"), new Employee(2, "ops")); + + createOrReplaceView("deleted_id", Arrays.asList(1, 100), Encoders.INT()); + + String originalNumOfShufflePartitions = spark.conf().get("spark.sql.shuffle.partitions"); + try { + // set the num of shuffle partitions to 1 to ensure we have only 1 writing task + spark.conf().set("spark.sql.shuffle.partitions", "1"); + + sql("DELETE FROM %s t WHERE id IN (SELECT * FROM deleted_id)", tableName); + Assert.assertEquals("Should have expected num of rows", 8L, spark.table(tableName).count()); + } finally { + spark.conf().set("spark.sql.shuffle.partitions", originalNumOfShufflePartitions); + } + } + + @Test + public synchronized void testDeleteWithSerializableIsolation() throws InterruptedException { + // cannot run tests with concurrency for Hadoop tables without atomic renames + Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop")); + + createAndInitUnpartitionedTable(); + + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, DELETE_ISOLATION_LEVEL, "serializable"); + + // Pre-populate the table to force it to use the Spark Writers instead of Metadata-Only Delete + // for more consistent exception stack + List ids = ImmutableList.of(1, 2); + Dataset inputDF = spark.createDataset(ids, Encoders.INT()) + .withColumnRenamed("value", "id") + .withColumn("dep", lit("hr")); + try { + inputDF.coalesce(1).writeTo(tableName).append(); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + + ExecutorService executorService = MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + + // delete thread + Future deleteFuture = executorService.submit(() -> { + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + sql("DELETE FROM %s WHERE id = 1", tableName); + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = executorService.submit(() -> { + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + + try { + inputDF.coalesce(1).writeTo(tableName).append(); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + + barrier.incrementAndGet(); + } + }); + + try { + deleteFuture.get(); + Assert.fail("Expected a validation exception"); + } catch (ExecutionException e) { + Throwable sparkException = e.getCause(); + Assert.assertThat(sparkException, CoreMatchers.instanceOf(SparkException.class)); + Throwable validationException = sparkException.getCause(); + Assert.assertThat(validationException, CoreMatchers.instanceOf(ValidationException.class)); + String errMsg = validationException.getMessage(); + Assert.assertThat(errMsg, CoreMatchers.containsString("Found conflicting files that can contain")); + } finally { + appendFuture.cancel(true); + } + + executorService.shutdown(); + Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES)); + } + + @Test + public synchronized void testDeleteWithSnapshotIsolation() throws InterruptedException, ExecutionException { + // cannot run tests with concurrency for Hadoop tables without atomic renames + Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop")); + + createAndInitUnpartitionedTable(); + + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", 
tableName, DELETE_ISOLATION_LEVEL, "snapshot"); + + ExecutorService executorService = MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + + // delete thread + Future deleteFuture = executorService.submit(() -> { + for (int numOperations = 0; numOperations < 20; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + sql("DELETE FROM %s WHERE id = 1", tableName); + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = executorService.submit(() -> { + List ids = ImmutableList.of(1, 2); + Dataset inputDF = spark.createDataset(ids, Encoders.INT()) + .withColumnRenamed("value", "id") + .withColumn("dep", lit("hr")); + + for (int numOperations = 0; numOperations < 20; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + + try { + inputDF.coalesce(1).writeTo(tableName).append(); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + + barrier.incrementAndGet(); + } + }); + + try { + deleteFuture.get(); + } finally { + appendFuture.cancel(true); + } + + executorService.shutdown(); + Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES)); + } + + @Test + public void testDeleteRefreshesRelationCache() throws NoSuchTableException { + createAndInitPartitionedTable(); + + append(new Employee(1, "hr"), new Employee(3, "hr")); + append(new Employee(1, "hardware"), new Employee(2, "hardware")); + + Dataset query = spark.sql("SELECT * FROM " + tableName + " WHERE id = 1"); + query.createOrReplaceTempView("tmp"); + + spark.sql("CACHE TABLE tmp"); + + assertEquals("View should have correct data", + ImmutableList.of(row(1, "hardware"), row(1, "hr")), + sql("SELECT * FROM tmp ORDER BY id, dep")); + + sql("DELETE FROM %s WHERE id = 1", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = table.currentSnapshot(); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "2", "2", "2"); + } else { + validateMergeOnRead(currentSnapshot, "2", "2", null); + } + assertEquals("Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + + assertEquals("Should refresh the relation cache", + ImmutableList.of(), + sql("SELECT * FROM tmp ORDER BY id, dep")); + + spark.sql("UNCACHE TABLE tmp"); + } + + @Test + public void testDeleteWithMultipleSpecs() { + createAndInitTable("id INT, dep STRING, category STRING"); + + // write an unpartitioned file + append(tableName, "{ \"id\": 1, \"dep\": \"hr\", \"category\": \"c1\"}"); + + // write a file partitioned by dep + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + append(tableName, "{ \"id\": 2, \"dep\": \"hr\", \"category\": \"c1\" }\n" + + "{ \"id\": 3, \"dep\": \"hr\", \"category\": \"c1\" }"); + + // write a file partitioned by dep and category + sql("ALTER TABLE %s ADD PARTITION FIELD category", tableName); + append(tableName, "{ \"id\": 5, \"dep\": \"hr\", \"category\": \"c1\"}"); + + // write another file partitioned by dep + sql("ALTER TABLE %s DROP PARTITION FIELD category", tableName); + append(tableName, "{ \"id\": 7, \"dep\": \"hr\", \"category\": \"c1\"}"); + + sql("DELETE FROM %s WHERE id IN (1, 3, 5, 7)", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + 
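testDeleteWithMultipleSpecs relies on partition-spec evolution: files are written under an unpartitioned spec, a dep spec, a dep/category spec, and then a dep spec again, so the single DELETE that follows has to handle files from three different specs. A sketch of just the spec-evolution steps, assuming a placeholder table name and the Iceberg SQL extensions being enabled:

    import org.apache.spark.sql.SparkSession;

    class SpecEvolution {
      // Each ALTER TABLE changes the current partition spec; files already written keep
      // the spec they were written with, which is what the DELETE above has to cope with.
      static void evolve(SparkSession spark, String tableName) {
        spark.sql(String.format("ALTER TABLE %s ADD PARTITION FIELD dep", tableName));
        spark.sql(String.format("ALTER TABLE %s ADD PARTITION FIELD category", tableName));
        spark.sql(String.format("ALTER TABLE %s DROP PARTITION FIELD category", tableName));
      }
    }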
Assert.assertEquals("Should have 5 snapshots", 5, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = table.currentSnapshot(); + if (mode(table) == COPY_ON_WRITE) { + // copy-on-write is tested against v1 and such tables have different partition evolution behavior + // that's why the number of changed partitions is 4 for copy-on-write + validateCopyOnWrite(currentSnapshot, "4", "4", "1"); + } else { + validateMergeOnRead(currentSnapshot, "3", "3", null); + } + + assertEquals("Should have expected rows", + ImmutableList.of(row(2, "hr", "c1")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + // TODO: multiple stripes for ORC + + protected void createAndInitPartitionedTable() { + sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg PARTITIONED BY (dep)", tableName); + initTable(); + } + + protected void createAndInitUnpartitionedTable() { + sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg", tableName); + initTable(); + } + + protected void createAndInitNestedColumnsTable() { + sql("CREATE TABLE %s (id INT, complex STRUCT) USING iceberg", tableName); + initTable(); + } + + protected void append(Employee... employees) throws NoSuchTableException { + List input = Arrays.asList(employees); + Dataset inputDF = spark.createDataFrame(input, Employee.class); + inputDF.coalesce(1).writeTo(tableName).append(); + } + + private RowLevelOperationMode mode(Table table) { + String modeName = table.properties().getOrDefault(DELETE_MODE, DELETE_MODE_DEFAULT); + return RowLevelOperationMode.fromName(modeName); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestExpireSnapshotsProcedure.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestExpireSnapshotsProcedure.java new file mode 100644 index 000000000000..a77dafe5af05 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestExpireSnapshotsProcedure.java @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.extensions; + +import java.io.IOException; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +import static org.apache.iceberg.TableProperties.GC_ENABLED; + +public class TestExpireSnapshotsProcedure extends SparkExtensionsTestBase { + + public TestExpireSnapshotsProcedure(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testExpireSnapshotsInEmptyTable() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + List output = sql( + "CALL %s.system.expire_snapshots('%s')", + catalogName, tableIdent); + assertEquals("Should not delete any files", ImmutableList.of(row(0L, 0L, 0L, 0L, 0L)), output); + } + + @Test + public void testExpireSnapshotsUsingPositionalArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + waitUntilAfter(firstSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + Timestamp secondSnapshotTimestamp = Timestamp.from(Instant.ofEpochMilli(secondSnapshot.timestampMillis())); + + Assert.assertEquals("Should be 2 snapshots", 2, Iterables.size(table.snapshots())); + + // expire without retainLast param + List output1 = sql( + "CALL %s.system.expire_snapshots('%s', TIMESTAMP '%s')", + catalogName, tableIdent, secondSnapshotTimestamp); + assertEquals("Procedure output must match", + ImmutableList.of(row(0L, 0L, 0L, 0L, 1L)), + output1); + + table.refresh(); + + Assert.assertEquals("Should expire one snapshot", 1, Iterables.size(table.snapshots())); + + sql("INSERT OVERWRITE %s VALUES (3, 'c')", tableName); + sql("INSERT INTO TABLE %s VALUES (4, 'd')", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(3L, "c"), row(4L, "d")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + waitUntilAfter(table.currentSnapshot().timestampMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + Assert.assertEquals("Should be 3 snapshots", 3, Iterables.size(table.snapshots())); + + // expire with retainLast param + List output = sql( + "CALL 
%s.system.expire_snapshots('%s', TIMESTAMP '%s', 2)", + catalogName, tableIdent, currentTimestamp); + assertEquals("Procedure output must match", + ImmutableList.of(row(2L, 0L, 0L, 2L, 1L)), + output); + } + + @Test + public void testExpireSnapshotUsingNamedArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals("Should be 2 snapshots", 2, Iterables.size(table.snapshots())); + + waitUntilAfter(table.currentSnapshot().timestampMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + List output = sql( + "CALL %s.system.expire_snapshots(" + + "older_than => TIMESTAMP '%s'," + + "table => '%s'," + + "retain_last => 1)", + catalogName, currentTimestamp, tableIdent); + assertEquals("Procedure output must match", + ImmutableList.of(row(0L, 0L, 0L, 0L, 1L)), + output); + } + + @Test + public void testExpireSnapshotsGCDisabled() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'false')", tableName, GC_ENABLED); + + AssertHelpers.assertThrows("Should reject call", + ValidationException.class, "Cannot expire snapshots: GC is disabled", + () -> sql("CALL %s.system.expire_snapshots('%s')", catalogName, tableIdent)); + } + + @Test + public void testInvalidExpireSnapshotsCases() { + AssertHelpers.assertThrows("Should not allow mixed args", + AnalysisException.class, "Named and positional arguments cannot be mixed", + () -> sql("CALL %s.system.expire_snapshots('n', table => 't')", catalogName)); + + AssertHelpers.assertThrows("Should not resolve procedures in arbitrary namespaces", + NoSuchProcedureException.class, "not found", + () -> sql("CALL %s.custom.expire_snapshots('n', 't')", catalogName)); + + AssertHelpers.assertThrows("Should reject calls without all required args", + AnalysisException.class, "Missing required parameters", + () -> sql("CALL %s.system.expire_snapshots()", catalogName)); + + AssertHelpers.assertThrows("Should reject calls with invalid arg types", + AnalysisException.class, "Wrong arg type", + () -> sql("CALL %s.system.expire_snapshots('n', 2.2)", catalogName)); + + AssertHelpers.assertThrows("Should reject calls with empty table identifier", + IllegalArgumentException.class, "Cannot handle an empty identifier", + () -> sql("CALL %s.system.expire_snapshots('')", catalogName)); + } + + @Test + public void testResolvingTableInAnotherCatalog() throws IOException { + String anotherCatalog = "another_" + catalogName; + spark.conf().set("spark.sql.catalog." + anotherCatalog, SparkCatalog.class.getName()); + spark.conf().set("spark.sql.catalog." + anotherCatalog + ".type", "hadoop"); + spark.conf().set("spark.sql.catalog." + anotherCatalog + ".warehouse", "file:" + temp.newFolder().toString()); + + sql("CREATE TABLE %s.%s (id bigint NOT NULL, data string) USING iceberg", anotherCatalog, tableIdent); + + AssertHelpers.assertThrows("Should reject calls for a table in another catalog", + IllegalArgumentException.class, "Cannot run procedure in catalog", + () -> sql("CALL %s.system.expire_snapshots('%s')", catalogName, anotherCatalog + "." 
+ tableName)); + } + + @Test + public void testConcurrentExpireSnapshots() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + sql("INSERT INTO TABLE %s VALUES (3, 'c')", tableName); + sql("INSERT INTO TABLE %s VALUES (4, 'd')", tableName); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + List output = sql( + "CALL %s.system.expire_snapshots(" + + "older_than => TIMESTAMP '%s'," + + "table => '%s'," + + "max_concurrent_deletes => %s," + + "retain_last => 1)", + catalogName, currentTimestamp, tableIdent, 4); + assertEquals("Expiring snapshots concurrently should succeed", + ImmutableList.of(row(0L, 0L, 0L, 0L, 3L)), output); + } + + @Test + public void testConcurrentExpireSnapshotsWithInvalidInput() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + AssertHelpers.assertThrows("Should throw an error when max_concurrent_deletes = 0", + IllegalArgumentException.class, "max_concurrent_deletes should have value > 0", + () -> sql("CALL %s.system.expire_snapshots(table => '%s', max_concurrent_deletes => %s)", + catalogName, tableIdent, 0)); + + AssertHelpers.assertThrows("Should throw an error when max_concurrent_deletes < 0 ", + IllegalArgumentException.class, "max_concurrent_deletes should have value > 0", + () -> sql( + "CALL %s.system.expire_snapshots(table => '%s', max_concurrent_deletes => %s)", + catalogName, tableIdent, -1)); + + } + + @Test + public void testExpireDeleteFiles() throws Exception { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", tableName); + + List records = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "d") + ); + spark.createDataset(records, Encoders.bean(SimpleRecord.class)).coalesce(1).writeTo(tableName).append(); + sql("DELETE FROM %s WHERE id=1", tableName); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + + Assert.assertEquals("Should have 1 delete manifest", 1, TestHelpers.deleteManifests(table).size()); + Assert.assertEquals("Should have 1 delete file", 1, TestHelpers.deleteFiles(table).size()); + Path deleteManifestPath = new Path(TestHelpers.deleteManifests(table).iterator().next().path()); + Path deleteFilePath = new Path(String.valueOf(TestHelpers.deleteFiles(table).iterator().next().path())); + + sql("CALL %s.system.rewrite_data_files(" + + "table => '%s'," + + "options => map(" + + "'delete-file-threshold','1'," + + "'use-starting-sequence-number', 'false'))", + catalogName, tableIdent); + table.refresh(); + + sql("INSERT INTO TABLE %s VALUES (5, 'e')", tableName); // this txn moves the file to the DELETED state + sql("INSERT INTO TABLE %s VALUES (6, 'f')", tableName); // this txn removes the file reference + table.refresh(); + + Assert.assertEquals("Should have no delete manifests", 0, TestHelpers.deleteManifests(table).size()); + Assert.assertEquals("Should have no delete files", 0, TestHelpers.deleteFiles(table).size()); + + FileSystem localFs = FileSystem.getLocal(new Configuration()); + Assert.assertTrue("Delete manifest should still exist", localFs.exists(deleteManifestPath)); + Assert.assertTrue("Delete file should still exist", localFs.exists(deleteFilePath)); + + Timestamp currentTimestamp = 
Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + List output = sql("CALL %s.system.expire_snapshots(" + + "older_than => TIMESTAMP '%s'," + + "table => '%s'," + + "retain_last => 1)", + catalogName, currentTimestamp, tableIdent); + + assertEquals("Should deleted 1 data and pos delete file and 4 manifests and lists (one for each txn)", + ImmutableList.of(row(1L, 1L, 0L, 4L, 4L)), output); + Assert.assertFalse("Delete manifest should be removed", localFs.exists(deleteManifestPath)); + Assert.assertFalse("Delete file should be removed", localFs.exists(deleteFilePath)); + } + + @Test + public void testExpireSnapshotWithStreamResultsEnabled() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals("Should be 2 snapshots", 2, Iterables.size(table.snapshots())); + + waitUntilAfter(table.currentSnapshot().timestampMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + List output = sql( + "CALL %s.system.expire_snapshots(" + + "older_than => TIMESTAMP '%s'," + + "table => '%s'," + + "retain_last => 1, " + + "stream_results => true)", + catalogName, currentTimestamp, tableIdent); + assertEquals("Procedure output must match", ImmutableList.of(row(0L, 0L, 0L, 0L, 1L)), output); + } + + @Test + public void testExpireSnapshotsProcedureWorksWithSqlComments() { + // Ensure that systems such as dbt, that inject comments into the generated SQL files, will + // work with Iceberg-specific DDL + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals("Should be 2 snapshots", 2, Iterables.size(table.snapshots())); + + waitUntilAfter(table.currentSnapshot().timestampMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + String callStatement = + "/* CALL statement is used to expire snapshots */\n" + + "-- And we have single line comments as well \n" + + "/* And comments that span *multiple* \n" + + " lines */ CALL /* this is the actual CALL */ %s.system.expire_snapshots(" + + " older_than => TIMESTAMP '%s'," + + " table => '%s'," + + " retain_last => 1)"; + List output = sql( + callStatement, catalogName, currentTimestamp, tableIdent); + assertEquals("Procedure output must match", + ImmutableList.of(row(0L, 0L, 0L, 0L, 1L)), + output); + + table.refresh(); + + Assert.assertEquals("Should be 1 snapshot remaining", 1, Iterables.size(table.snapshots())); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestIcebergExpressions.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestIcebergExpressions.java new file mode 100644 index 000000000000..ce88814ce937 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestIcebergExpressions.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.extensions; + +import java.math.BigDecimal; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.IcebergTruncateTransform; +import org.junit.After; +import org.junit.Test; + +public class TestIcebergExpressions extends SparkExtensionsTestBase { + + public TestIcebergExpressions(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP VIEW IF EXISTS emp"); + sql("DROP VIEW IF EXISTS v"); + } + + @Test + public void testTruncateExpressions() { + sql("CREATE TABLE %s ( " + + " int_c INT, long_c LONG, dec_c DECIMAL(4, 2), str_c STRING, binary_c BINARY " + + ") USING iceberg", tableName); + + sql("CREATE TEMPORARY VIEW emp " + + "AS SELECT * FROM VALUES (101, 10001, 10.65, '101-Employee', CAST('1234' AS BINARY)) " + + "AS EMP(int_c, long_c, dec_c, str_c, binary_c)"); + + sql("INSERT INTO %s SELECT * FROM emp", tableName); + + Dataset df = spark.sql("SELECT * FROM " + tableName); + df.select( + new Column(new IcebergTruncateTransform(df.col("int_c").expr(), 2)).as("int_c"), + new Column(new IcebergTruncateTransform(df.col("long_c").expr(), 2)).as("long_c"), + new Column(new IcebergTruncateTransform(df.col("dec_c").expr(), 50)).as("dec_c"), + new Column(new IcebergTruncateTransform(df.col("str_c").expr(), 2)).as("str_c"), + new Column(new IcebergTruncateTransform(df.col("binary_c").expr(), 2)).as("binary_c") + ).createOrReplaceTempView("v"); + + assertEquals("Should have expected rows", + ImmutableList.of(row(100, 10000L, new BigDecimal("10.50"), "10", "12")), + sql("SELECT int_c, long_c, dec_c, str_c, CAST(binary_c AS STRING) FROM v")); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java new file mode 100644 index 000000000000..0173f9797d0b --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java @@ -0,0 +1,1813 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.extensions; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.spark.SparkException; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.internal.SQLConf; +import org.hamcrest.CoreMatchers; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.apache.iceberg.TableProperties.MERGE_ISOLATION_LEVEL; +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; +import static org.apache.iceberg.TableProperties.SPLIT_SIZE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE; +import static org.apache.spark.sql.functions.lit; + +public abstract class TestMerge extends SparkRowLevelOperationsTestBase { + + public TestMerge(String catalogName, String implementation, Map config, + String fileFormat, boolean vectorized, String distributionMode) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode); + } + + @BeforeClass + public static void setupSparkConf() { + spark.conf().set("spark.sql.shuffle.partitions", "4"); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS source"); + } + + @Test + public void testMergeWithStaticPredicatePushDown() { + createAndInitTable("id BIGINT, dep STRING"); + + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + // add a data file to the 'software' partition + append(tableName, "{ \"id\": 1, \"dep\": \"software\" }"); + + // add a data file to the 'hr' partition + append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }"); + + Table table = validationCatalog.loadTable(tableIdent); + + Snapshot snapshot = table.currentSnapshot(); + String dataFilesCount = snapshot.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP); + Assert.assertEquals("Must have 2 files before MERGE", "2", dataFilesCount); + + createOrReplaceView("source", + "{ 
\"id\": 1, \"dep\": \"finance\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }"); + + // remove the data file from the 'hr' partition to ensure it is not scanned + withUnavailableFiles(snapshot.addedFiles(table.io()), () -> { + // disable dynamic pruning and rely only on static predicate pushdown + withSQLConf(ImmutableMap.of(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED().key(), "false"), () -> { + sql("MERGE INTO %s t USING source " + + "ON t.id == source.id AND t.dep IN ('software') AND source.id < 10 " + + "WHEN MATCHED AND source.id = 1 THEN " + + " UPDATE SET dep = source.dep " + + "WHEN NOT MATCHED THEN " + + " INSERT (dep, id) VALUES (source.dep, source.id)", tableName); + }); + }); + + ImmutableList expectedRows = ImmutableList.of( + row(1L, "finance"), // updated + row(1L, "hr"), // kept + row(2L, "hardware") // new + ); + assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + } + + @Test + public void testMergeIntoEmptyTargetInsertAllNonMatchingRows() { + createAndInitTable("id INT, dep STRING"); + + createOrReplaceView("source", "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 3, \"dep\": \"emp-id-3\" }"); + + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN NOT MATCHED THEN " + + " INSERT *", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(1, "emp-id-1"), // new + row(2, "emp-id-2"), // new + row(3, "emp-id-3") // new + ); + assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testMergeIntoEmptyTargetInsertOnlyMatchingRows() { + createAndInitTable("id INT, dep STRING"); + + createOrReplaceView("source", "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 3, \"dep\": \"emp-id-3\" }"); + + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN NOT MATCHED AND (s.id >=2) THEN " + + " INSERT *", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(2, "emp-id-2"), // new + row(3, "emp-id-3") // new + ); + assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testMergeWithOnlyUpdateClause() { + createAndInitTable("id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-six\" }"); + + createOrReplaceView("source", "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET *", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(1, "emp-id-1"), // updated + row(6, "emp-id-six") // kept + ); + assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testMergeWithOnlyDeleteClause() { + createAndInitTable("id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView("source", "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 6 
THEN " + + " DELETE", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(1, "emp-id-one") // kept + ); + assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testMergeWithAllCauses() { + createAndInitTable("id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView("source", "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testMergeWithAllCausesWithExplicitColumnSpecification() { + createAndInitTable("id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView("source", "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET t.id = s.id, t.dep = s.dep " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT (t.id, t.dep) VALUES (s.id, s.dep)", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testMergeWithSourceCTE() { + createAndInitTable("id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-two\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView("source", "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-3\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 5, \"dep\": \"emp-id-6\" }"); + + sql("WITH cte1 AS (SELECT id + 1 AS id, dep FROM source) " + + "MERGE INTO %s AS t USING cte1 AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 2 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 3 THEN " + + " INSERT *", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(2, "emp-id-2"), // updated + row(3, "emp-id-3") // new + ); + assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testMergeWithSourceFromSetOps() { + createAndInitTable("id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView("source", "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + String derivedSource = + "SELECT * FROM source WHERE id = 2 " + + "UNION ALL " + + "SELECT * FROM source WHERE id = 1 OR id = 6"; + + sql("MERGE INTO %s AS t USING (%s) AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 
1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", tableName, derivedSource); + + ImmutableList expectedRows = ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSource() { + createAndInitTable("id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + List sourceIds = Lists.newArrayList(); + for (int i = 0; i < 10_000; i++) { + sourceIds.add(i); + } + Dataset ds = spark.createDataset(sourceIds, Encoders.INT()); + ds.union(ds).createOrReplaceTempView("source"); + + String errorMsg = "a single row from the target table with multiple rows of the source table"; + AssertHelpers.assertThrowsCause("Should complain about multiple matches", + SparkException.class, errorMsg, + () -> { + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.value " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET id = 10 " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.value = 2 THEN " + + " INSERT (id, dep) VALUES (s.value, null)", tableName); + }); + + assertEquals("Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + } + + @Test + public void testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSourceEnabledHashShuffleJoin() { + createAndInitTable("id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + List sourceIds = Lists.newArrayList(); + for (int i = 0; i < 10_000; i++) { + sourceIds.add(i); + } + Dataset ds = spark.createDataset(sourceIds, Encoders.INT()); + ds.union(ds).createOrReplaceTempView("source"); + + withSQLConf(ImmutableMap.of(SQLConf.PREFER_SORTMERGEJOIN().key(), "false"), () -> { + String errorMsg = "a single row from the target table with multiple rows of the source table"; + AssertHelpers.assertThrowsCause("Should complain about multiple matches", + SparkException.class, errorMsg, + () -> { + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.value " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET id = 10 " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.value = 2 THEN " + + " INSERT (id, dep) VALUES (s.value, null)", tableName); + }); + }); + + assertEquals("Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + } + + @Test + public void testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSourceNoEqualityCondition() { + createAndInitTable("id INT, dep STRING", "{ \"id\": 1, \"dep\": \"emp-id-one\" }"); + + List sourceIds = Lists.newArrayList(); + for (int i = 0; i < 10_000; i++) { + sourceIds.add(i); + } + Dataset ds = spark.createDataset(sourceIds, Encoders.INT()); + ds.union(ds).createOrReplaceTempView("source"); + + withSQLConf(ImmutableMap.of(SQLConf.PREFER_SORTMERGEJOIN().key(), "false"), () -> { + String errorMsg = "a single row from the target table with multiple rows of the source table"; + AssertHelpers.assertThrowsCause("Should complain about multiple matches", + SparkException.class, errorMsg, + () -> { + sql("MERGE INTO %s AS t USING source AS s " 
+ + "ON t.id > s.value " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET id = 10 " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.value = 2 THEN " + + " INSERT (id, dep) VALUES (s.value, null)", tableName); + }); + }); + + assertEquals("Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + } + + @Test + public void testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSourceNoNotMatchedActions() { + createAndInitTable("id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + List sourceIds = Lists.newArrayList(); + for (int i = 0; i < 10_000; i++) { + sourceIds.add(i); + } + Dataset ds = spark.createDataset(sourceIds, Encoders.INT()); + ds.union(ds).createOrReplaceTempView("source"); + + String errorMsg = "a single row from the target table with multiple rows of the source table"; + AssertHelpers.assertThrowsCause("Should complain about multiple matches", + SparkException.class, errorMsg, + () -> { + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.value " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET id = 10 " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE", tableName); + }); + + assertEquals("Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + } + + @Test + public void testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSourceNoNotMatchedActionsNoEqualityCondition() { + createAndInitTable("id INT, dep STRING", "{ \"id\": 1, \"dep\": \"emp-id-one\" }"); + + List sourceIds = Lists.newArrayList(); + for (int i = 0; i < 10_000; i++) { + sourceIds.add(i); + } + Dataset ds = spark.createDataset(sourceIds, Encoders.INT()); + ds.union(ds).createOrReplaceTempView("source"); + + String errorMsg = "a single row from the target table with multiple rows of the source table"; + AssertHelpers.assertThrowsCause("Should complain about multiple matches", + SparkException.class, errorMsg, + () -> { + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id > s.value " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET id = 10 " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE", tableName); + }); + + assertEquals("Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + } + + @Test + public void testMergeWithMultipleUpdatesForTargetRow() { + createAndInitTable("id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView("source", "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + String errorMsg = "a single row from the target table with multiple rows of the source table"; + AssertHelpers.assertThrowsCause("Should complain about multiple matches", + SparkException.class, errorMsg, + () -> { + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", tableName); + }); + + assertEquals("Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), + sql("SELECT * FROM %s ORDER 
BY id ASC NULLS LAST", tableName)); + } + + @Test + public void testMergeWithUnconditionalDelete() { + createAndInitTable("id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView("source", "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(2, "emp-id-2") // new + ); + assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testMergeWithSingleConditionalDelete() { + createAndInitTable("id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView("source", "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + String errorMsg = "a single row from the target table with multiple rows of the source table"; + AssertHelpers.assertThrowsCause("Should complain about multiple matches", + SparkException.class, errorMsg, + () -> { + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", tableName); + }); + + assertEquals("Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + } + + @Test + public void testMergeWithIdentityTransform() { + for (DistributionMode mode : DistributionMode.values()) { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD identity(dep)", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, WRITE_DISTRIBUTION_MODE, mode.modeName()); + + append(tableName, + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView("source", "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + + removeTables(); + } + } + + @Test + public void testMergeWithDaysTransform() { + for (DistributionMode mode : DistributionMode.values()) { + createAndInitTable("id INT, ts TIMESTAMP"); + sql("ALTER TABLE %s ADD PARTITION FIELD days(ts)", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, WRITE_DISTRIBUTION_MODE, mode.modeName()); + + append(tableName, "id INT, ts TIMESTAMP", + "{ \"id\": 1, \"ts\": \"2000-01-01 00:00:00\" }\n" + + "{ \"id\": 6, \"ts\": \"2000-01-06 00:00:00\" }"); + + createOrReplaceView("source", "id INT, ts 
TIMESTAMP", + "{ \"id\": 2, \"ts\": \"2001-01-02 00:00:00\" }\n" + + "{ \"id\": 1, \"ts\": \"2001-01-01 00:00:00\" }\n" + + "{ \"id\": 6, \"ts\": \"2001-01-06 00:00:00\" }"); + + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(1, "2001-01-01 00:00:00"), // updated + row(2, "2001-01-02 00:00:00") // new + ); + assertEquals("Should have expected rows", + expectedRows, + sql("SELECT id, CAST(ts AS STRING) FROM %s ORDER BY id", tableName)); + + removeTables(); + } + } + + @Test + public void testMergeWithBucketTransform() { + for (DistributionMode mode : DistributionMode.values()) { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(2, dep)", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, WRITE_DISTRIBUTION_MODE, mode.modeName()); + + append(tableName, + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView("source", "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + + removeTables(); + } + } + + @Test + public void testMergeWithTruncateTransform() { + for (DistributionMode mode : DistributionMode.values()) { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD truncate(dep, 2)", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, WRITE_DISTRIBUTION_MODE, mode.modeName()); + + append(tableName, + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView("source", "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + + removeTables(); + } + } + + @Test + public void testMergeIntoPartitionedAndOrderedTable() { + for (DistributionMode mode : DistributionMode.values()) { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + sql("ALTER TABLE %s WRITE ORDERED BY (id)", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, WRITE_DISTRIBUTION_MODE, mode.modeName()); + + append(tableName, + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + 
createOrReplaceView("source", "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + + removeTables(); + } + } + + @Test + public void testSelfMerge() { + createAndInitTable("id INT, v STRING", + "{ \"id\": 1, \"v\": \"v1\" }\n" + + "{ \"id\": 2, \"v\": \"v2\" }"); + + sql("MERGE INTO %s t USING %s s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET v = 'x' " + + "WHEN NOT MATCHED THEN " + + " INSERT *", tableName, tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(1, "x"), // updated + row(2, "v2") // kept + ); + assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testSelfMergeWithCaching() { + createAndInitTable("id INT, v STRING", + "{ \"id\": 1, \"v\": \"v1\" }\n" + + "{ \"id\": 2, \"v\": \"v2\" }"); + + sql("CACHE TABLE %s", tableName); + + sql("MERGE INTO %s t USING %s s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET v = 'x' " + + "WHEN NOT MATCHED THEN " + + " INSERT *", tableName, tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(1, "x"), // updated + row(2, "v2") // kept + ); + assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testMergeWithSourceAsSelfSubquery() { + createAndInitTable("id INT, v STRING", + "{ \"id\": 1, \"v\": \"v1\" }\n" + + "{ \"id\": 2, \"v\": \"v2\" }"); + + createOrReplaceView("source", Arrays.asList(1, null), Encoders.INT()); + + sql("MERGE INTO %s t USING (SELECT id AS value FROM %s r JOIN source ON r.id = source.value) s " + + "ON t.id == s.value " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET v = 'x' " + + "WHEN NOT MATCHED THEN " + + " INSERT (v, id) VALUES ('invalid', -1) ", tableName, tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(1, "x"), // updated + row(2, "v2") // kept + ); + assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public synchronized void testMergeWithSerializableIsolation() throws InterruptedException { + // cannot run tests with concurrency for Hadoop tables without atomic renames + Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop")); + + createAndInitTable("id INT, dep STRING"); + createOrReplaceView("source", Collections.singletonList(1), Encoders.INT()); + + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, MERGE_ISOLATION_LEVEL, "serializable"); + + ExecutorService executorService = MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + + // merge thread + Future mergeFuture = executorService.submit(() -> { + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + sql("MERGE INTO %s t USING source s " + + "ON t.id == 
s.value " + + "WHEN MATCHED THEN " + + " UPDATE SET dep = 'x'", tableName); + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = executorService.submit(() -> { + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + barrier.incrementAndGet(); + } + }); + + try { + mergeFuture.get(); + Assert.fail("Expected a validation exception"); + } catch (ExecutionException e) { + Throwable sparkException = e.getCause(); + Assert.assertThat(sparkException, CoreMatchers.instanceOf(SparkException.class)); + Throwable validationException = sparkException.getCause(); + Assert.assertThat(validationException, CoreMatchers.instanceOf(ValidationException.class)); + String errMsg = validationException.getMessage(); + Assert.assertThat(errMsg, CoreMatchers.containsString("Found conflicting files that can contain")); + } finally { + appendFuture.cancel(true); + } + + executorService.shutdown(); + Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES)); + } + + @Test + public synchronized void testMergeWithSnapshotIsolation() throws InterruptedException, ExecutionException { + // cannot run tests with concurrency for Hadoop tables without atomic renames + Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop")); + + createAndInitTable("id INT, dep STRING"); + createOrReplaceView("source", Collections.singletonList(1), Encoders.INT()); + + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, MERGE_ISOLATION_LEVEL, "snapshot"); + + ExecutorService executorService = MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + + // merge thread + Future mergeFuture = executorService.submit(() -> { + for (int numOperations = 0; numOperations < 20; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.value " + + "WHEN MATCHED THEN " + + " UPDATE SET dep = 'x'", tableName); + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = executorService.submit(() -> { + for (int numOperations = 0; numOperations < 20; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + barrier.incrementAndGet(); + } + }); + + try { + mergeFuture.get(); + } finally { + appendFuture.cancel(true); + } + + executorService.shutdown(); + Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES)); + } + + @Test + public void testMergeWithExtraColumnsInSource() { + createAndInitTable("id INT, v STRING", + "{ \"id\": 1, \"v\": \"v1\" }\n" + + "{ \"id\": 2, \"v\": \"v2\" }"); + createOrReplaceView("source", + "{ \"id\": 1, \"extra_col\": -1, \"v\": \"v1_1\" }\n" + + "{ \"id\": 3, \"extra_col\": -1, \"v\": \"v3\" }\n" + + "{ \"id\": 4, \"extra_col\": -1, \"v\": \"v4\" }"); + + sql("MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET v = source.v " + + "WHEN NOT MATCHED THEN " + + " INSERT (v, id) VALUES (source.v, source.id)", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(1, "v1_1"), // new + row(2, "v2"), // kept + row(3, "v3"), // new + row(4, "v4") // new + ); + assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", 
tableName)); + } + + @Test + public void testMergeWithNullsInTargetAndSource() { + createAndInitTable("id INT, v STRING", + "{ \"id\": null, \"v\": \"v1\" }\n" + + "{ \"id\": 2, \"v\": \"v2\" }"); + + createOrReplaceView("source", + "{ \"id\": null, \"v\": \"v1_1\" }\n" + + "{ \"id\": 4, \"v\": \"v4\" }"); + + sql("MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET v = source.v " + + "WHEN NOT MATCHED THEN " + + " INSERT (v, id) VALUES (source.v, source.id)", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(null, "v1"), // kept + row(null, "v1_1"), // new + row(2, "v2"), // kept + row(4, "v4") // new + ); + assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", tableName)); + } + + @Test + public void testMergeWithNullSafeEquals() { + createAndInitTable("id INT, v STRING", + "{ \"id\": null, \"v\": \"v1\" }\n" + + "{ \"id\": 2, \"v\": \"v2\" }"); + + createOrReplaceView("source", + "{ \"id\": null, \"v\": \"v1_1\" }\n" + + "{ \"id\": 4, \"v\": \"v4\" }"); + + sql("MERGE INTO %s t USING source " + + "ON t.id <=> source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET v = source.v " + + "WHEN NOT MATCHED THEN " + + " INSERT (v, id) VALUES (source.v, source.id)", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(null, "v1_1"), // updated + row(2, "v2"), // kept + row(4, "v4") // new + ); + assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", tableName)); + } + + @Test + public void testMergeWithNullCondition() { + createAndInitTable("id INT, v STRING", + "{ \"id\": null, \"v\": \"v1\" }\n" + + "{ \"id\": 2, \"v\": \"v2\" }"); + + createOrReplaceView("source", + "{ \"id\": null, \"v\": \"v1_1\" }\n" + + "{ \"id\": 2, \"v\": \"v2_2\" }"); + + sql("MERGE INTO %s t USING source " + + "ON t.id == source.id AND NULL " + + "WHEN MATCHED THEN " + + " UPDATE SET v = source.v " + + "WHEN NOT MATCHED THEN " + + " INSERT (v, id) VALUES (source.v, source.id)", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(null, "v1"), // kept + row(null, "v1_1"), // new + row(2, "v2"), // kept + row(2, "v2_2") // new + ); + assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", tableName)); + } + + @Test + public void testMergeWithNullActionConditions() { + createAndInitTable("id INT, v STRING", + "{ \"id\": 1, \"v\": \"v1\" }\n" + + "{ \"id\": 2, \"v\": \"v2\" }"); + + createOrReplaceView("source", + "{ \"id\": 1, \"v\": \"v1_1\" }\n" + + "{ \"id\": 2, \"v\": \"v2_2\" }\n" + + "{ \"id\": 3, \"v\": \"v3_3\" }"); + + // all conditions are NULL and will never match any rows + sql("MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED AND source.id = 1 AND NULL THEN " + + " UPDATE SET v = source.v " + + "WHEN MATCHED AND source.v = 'v1_1' AND NULL THEN " + + " DELETE " + + "WHEN NOT MATCHED AND source.id = 3 AND NULL THEN " + + " INSERT (v, id) VALUES (source.v, source.id)", tableName); + + ImmutableList expectedRows1 = ImmutableList.of( + row(1, "v1"), // kept + row(2, "v2") // kept + ); + assertEquals("Output should match", expectedRows1, sql("SELECT * FROM %s ORDER BY v", tableName)); + + // only the update and insert conditions are NULL + sql("MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED AND source.id = 1 AND NULL THEN " + + " UPDATE SET v = source.v " + + "WHEN MATCHED AND source.v = 'v1_1' THEN " + + " DELETE " + + "WHEN NOT MATCHED AND source.id = 3 AND NULL THEN " 
+ + " INSERT (v, id) VALUES (source.v, source.id)", tableName); + + ImmutableList expectedRows2 = ImmutableList.of( + row(2, "v2") // kept + ); + assertEquals("Output should match", expectedRows2, sql("SELECT * FROM %s ORDER BY v", tableName)); + } + + @Test + public void testMergeWithMultipleMatchingActions() { + createAndInitTable("id INT, v STRING", + "{ \"id\": 1, \"v\": \"v1\" }\n" + + "{ \"id\": 2, \"v\": \"v2\" }"); + + createOrReplaceView("source", + "{ \"id\": 1, \"v\": \"v1_1\" }\n" + + "{ \"id\": 2, \"v\": \"v2_2\" }"); + + // the order of match actions is important in this case + sql("MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED AND source.id = 1 THEN " + + " UPDATE SET v = source.v " + + "WHEN MATCHED AND source.v = 'v1_1' THEN " + + " DELETE " + + "WHEN NOT MATCHED THEN " + + " INSERT (v, id) VALUES (source.v, source.id)", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(1, "v1_1"), // updated (also matches the delete cond but update is first) + row(2, "v2") // kept (matches neither the update nor the delete cond) + ); + assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", tableName)); + } + + @Test + public void testMergeWithMultipleRowGroupsParquet() throws NoSuchTableException { + Assume.assumeTrue(fileFormat.equalsIgnoreCase("parquet")); + + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", tableName, PARQUET_ROW_GROUP_SIZE_BYTES, 100); + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", tableName, SPLIT_SIZE, 100); + + createOrReplaceView("source", Collections.singletonList(1), Encoders.INT()); + + List ids = Lists.newArrayListWithCapacity(200); + for (int id = 1; id <= 200; id++) { + ids.add(id); + } + Dataset df = spark.createDataset(ids, Encoders.INT()) + .withColumnRenamed("value", "id") + .withColumn("dep", lit("hr")); + df.coalesce(1).writeTo(tableName).append(); + + Assert.assertEquals(200, spark.table(tableName).count()); + + // update a record from one of two row groups and copy over the second one + sql("MERGE INTO %s t USING source " + + "ON t.id == source.value " + + "WHEN MATCHED THEN " + + " UPDATE SET dep = 'x'", tableName); + + Assert.assertEquals(200, spark.table(tableName).count()); + } + + @Test + public void testMergeInsertOnly() { + createAndInitTable("id STRING, v STRING", + "{ \"id\": \"a\", \"v\": \"v1\" }\n" + + "{ \"id\": \"b\", \"v\": \"v2\" }"); + createOrReplaceView("source", + "{ \"id\": \"a\", \"v\": \"v1_1\" }\n" + + "{ \"id\": \"a\", \"v\": \"v1_2\" }\n" + + "{ \"id\": \"c\", \"v\": \"v3\" }\n" + + "{ \"id\": \"d\", \"v\": \"v4_1\" }\n" + + "{ \"id\": \"d\", \"v\": \"v4_2\" }"); + + sql("MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN NOT MATCHED THEN " + + " INSERT *", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row("a", "v1"), // kept + row("b", "v2"), // kept + row("c", "v3"), // new + row("d", "v4_1"), // new + row("d", "v4_2") // new + ); + assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testMergeInsertOnlyWithCondition() { + createAndInitTable("id INTEGER, v INTEGER", "{ \"id\": 1, \"v\": 1 }"); + createOrReplaceView("source", + "{ \"id\": 1, \"v\": 11, \"is_new\": true }\n" + + "{ \"id\": 2, \"v\": 21, \"is_new\": true }\n" + + "{ \"id\": 2, \"v\": 22, \"is_new\": false }"); + + // validate assignments are reordered to match 
the table attrs + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.id " + + "WHEN NOT MATCHED AND is_new = TRUE THEN " + + " INSERT (v, id) VALUES (s.v + 100, s.id)", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(1, 1), // kept + row(2, 121) // new + ); + assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testMergeAlignsUpdateAndInsertActions() { + createAndInitTable("id INT, a INT, b STRING", "{ \"id\": 1, \"a\": 2, \"b\": \"str\" }"); + createOrReplaceView("source", + "{ \"id\": 1, \"c1\": -2, \"c2\": \"new_str_1\" }\n" + + "{ \"id\": 2, \"c1\": -20, \"c2\": \"new_str_2\" }"); + + sql("MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET b = c2, a = c1, t.id = source.id " + + "WHEN NOT MATCHED THEN " + + " INSERT (b, a, id) VALUES (c2, c1, id)", tableName); + + assertEquals("Output should match", + ImmutableList.of(row(1, -2, "new_str_1"), row(2, -20, "new_str_2")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testMergeMixedCaseAlignsUpdateAndInsertActions() { + createAndInitTable("id INT, a INT, b STRING", "{ \"id\": 1, \"a\": 2, \"b\": \"str\" }"); + createOrReplaceView( + "source", + "{ \"id\": 1, \"c1\": -2, \"c2\": \"new_str_1\" }\n" + + "{ \"id\": 2, \"c1\": -20, \"c2\": \"new_str_2\" }"); + + sql("MERGE INTO %s t USING source " + + "ON t.iD == source.Id " + + "WHEN MATCHED THEN " + + " UPDATE SET B = c2, A = c1, t.Id = source.ID " + + "WHEN NOT MATCHED THEN " + + " INSERT (b, A, iD) VALUES (c2, c1, id)", tableName); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, -2, "new_str_1"), row(2, -20, "new_str_2")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, -2, "new_str_1")), + sql("SELECT * FROM %s WHERE id = 1 ORDER BY id", tableName)); + assertEquals( + "Output should match", + ImmutableList.of(row(2, -20, "new_str_2")), + sql("SELECT * FROM %s WHERE b = 'new_str_2' ORDER BY id", tableName)); + } + + @Test + public void testMergeUpdatesNestedStructFields() { + createAndInitTable("id INT, s STRUCT<c1:INT,c2:STRUCT<a:ARRAY<INT>,m:MAP<STRING,STRING>>>", + "{ \"id\": 1, \"s\": { \"c1\": 2, \"c2\": { \"a\": [1,2], \"m\": { \"a\": \"b\"} } } } }"); + createOrReplaceView("source", "{ \"id\": 1, \"c1\": -2 }"); + + // update primitive, array, map columns inside a struct + sql("MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s.c1 = source.c1, t.s.c2.a = array(-1, -2), t.s.c2.m = map('k', 'v')", tableName); + + assertEquals("Output should match", + ImmutableList.of(row(1, row(-2, row(ImmutableList.of(-1, -2), ImmutableMap.of("k", "v"))))), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + // set primitive, array, map columns to NULL (proper casts should be in place) + sql("MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s.c1 = NULL, t.s.c2 = NULL", tableName); + + assertEquals("Output should match", + ImmutableList.of(row(1, row(null, null))), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + // update all fields in a struct + sql("MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s = named_struct('c1', 100, 'c2', named_struct('a', array(1), 'm', map('x', 'y')))", tableName); + + assertEquals("Output should match", + ImmutableList.of(row(1, row(100, row(ImmutableList.of(1), 
ImmutableMap.of("x", "y"))))), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testMergeWithInferredCasts() { + createAndInitTable("id INT, s STRING", "{ \"id\": 1, \"s\": \"value\" }"); + createOrReplaceView("source", "{ \"id\": 1, \"c1\": -2}"); + + // -2 in source should be casted to "-2" in target + sql("MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s = source.c1", tableName); + + assertEquals("Output should match", + ImmutableList.of(row(1, "-2")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testMergeModifiesNullStruct() { + createAndInitTable("id INT, s STRUCT", "{ \"id\": 1, \"s\": null }"); + createOrReplaceView("source", "{ \"id\": 1, \"n1\": -10 }"); + + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s.n1 = s.n1", tableName); + + assertEquals("Output should match", + ImmutableList.of(row(1, row(-10, null))), + sql("SELECT * FROM %s", tableName)); + } + + @Test + public void testMergeRefreshesRelationCache() { + createAndInitTable("id INT, name STRING", "{ \"id\": 1, \"name\": \"n1\" }"); + createOrReplaceView("source", "{ \"id\": 1, \"name\": \"n2\" }"); + + Dataset query = spark.sql("SELECT name FROM " + tableName); + query.createOrReplaceTempView("tmp"); + + spark.sql("CACHE TABLE tmp"); + + assertEquals("View should have correct data", + ImmutableList.of(row("n1")), + sql("SELECT * FROM tmp")); + + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.name = s.name", tableName); + + assertEquals("View should have correct data", + ImmutableList.of(row("n2")), + sql("SELECT * FROM tmp")); + + spark.sql("UNCACHE TABLE tmp"); + } + + @Test + public void testMergeWithMultipleNotMatchedActions() { + createAndInitTable("id INT, dep STRING", "{ \"id\": 0, \"dep\": \"emp-id-0\" }"); + + createOrReplaceView("source", "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 3, \"dep\": \"emp-id-3\" }"); + + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN NOT MATCHED AND s.id = 1 THEN " + + " INSERT (dep, id) VALUES (s.dep, -1)" + + "WHEN NOT MATCHED THEN " + + " INSERT *", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(-1, "emp-id-1"), // new + row(0, "emp-id-0"), // kept + row(2, "emp-id-2"), // new + row(3, "emp-id-3") // new + ); + assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testMergeWithMultipleConditionalNotMatchedActions() { + createAndInitTable("id INT, dep STRING", "{ \"id\": 0, \"dep\": \"emp-id-0\" }"); + + createOrReplaceView("source", "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 3, \"dep\": \"emp-id-3\" }"); + + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN NOT MATCHED AND s.id = 1 THEN " + + " INSERT (dep, id) VALUES (s.dep, -1)" + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(-1, "emp-id-1"), // new + row(0, "emp-id-0"), // kept + row(2, "emp-id-2") // new + ); + assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testMergeResolvesColumnsByName() { + createAndInitTable("id 
INT, badge INT, dep STRING", + "{ \"id\": 1, \"badge\": 1000, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 6, \"badge\": 6000, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView("source", "badge INT, id INT, dep STRING", + "{ \"badge\": 1001, \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"badge\": 6006, \"id\": 6, \"dep\": \"emp-id-6\" }\n" + + "{ \"badge\": 7007, \"id\": 7, \"dep\": \"emp-id-7\" }"); + + sql("MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED THEN " + + " UPDATE SET * " + + "WHEN NOT MATCHED THEN " + + " INSERT * ", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(1, 1001, "emp-id-1"), // updated + row(6, 6006, "emp-id-6"), // updated + row(7, 7007, "emp-id-7") // new + ); + assertEquals("Should have expected rows", expectedRows, + sql("SELECT id, badge, dep FROM %s ORDER BY id", tableName)); + } + + @Test + public void testMergeShouldResolveWhenThereAreNoUnresolvedExpressionsOrColumns() { + // ensures that MERGE INTO will resolve into the correct action even if no columns + // or otherwise unresolved expressions exist in the query (testing SPARK-34962) + createAndInitTable("id INT, dep STRING"); + + createOrReplaceView("source", "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 3, \"dep\": \"emp-id-3\" }"); + + sql("MERGE INTO %s AS t USING source AS s " + + "ON 1 != 1 " + + "WHEN MATCHED THEN " + + " UPDATE SET * " + + "WHEN NOT MATCHED THEN " + + " INSERT *", tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(1, "emp-id-1"), // new + row(2, "emp-id-2"), // new + row(3, "emp-id-3") // new + ); + assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testMergeWithNonExistingColumns() { + createAndInitTable("id INT, c STRUCT>"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + + AssertHelpers.assertThrows("Should complain about the invalid top-level column", + AnalysisException.class, "cannot resolve t.invalid_col", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.invalid_col = s.c2", tableName); + }); + + AssertHelpers.assertThrows("Should complain about the invalid nested column", + AnalysisException.class, "No such struct field invalid_col", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n2.invalid_col = s.c2", tableName); + }); + + AssertHelpers.assertThrows("Should complain about the invalid top-level column", + AnalysisException.class, "cannot resolve invalid_col", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n2.dn1 = s.c2 " + + "WHEN NOT MATCHED THEN " + + " INSERT (id, invalid_col) VALUES (s.c1, null)", tableName); + }); + } + + @Test + public void testMergeWithInvalidColumnsInInsert() { + createAndInitTable("id INT, c STRUCT>"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + + AssertHelpers.assertThrows("Should complain about the nested column", + AnalysisException.class, "Nested fields are not supported inside INSERT clauses", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n2.dn1 = s.c2 " + + "WHEN NOT MATCHED THEN " + + " INSERT (id, c.n2) VALUES (s.c1, null)", tableName); + }); + + AssertHelpers.assertThrows("Should 
complain about duplicate columns", + AnalysisException.class, "Duplicate column names inside INSERT clause", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n2.dn1 = s.c2 " + + "WHEN NOT MATCHED THEN " + + " INSERT (id, id) VALUES (s.c1, null)", tableName); + }); + + AssertHelpers.assertThrows("Should complain about missing columns", + AnalysisException.class, "must provide values for all columns of the target table", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN NOT MATCHED THEN " + + " INSERT (id) VALUES (s.c1)", tableName); + }); + } + + @Test + public void testMergeWithInvalidUpdates() { + createAndInitTable("id INT, a ARRAY>, m MAP"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + + AssertHelpers.assertThrows("Should complain about updating an array column", + AnalysisException.class, "Updating nested fields is only supported for structs", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.a.c1 = s.c2", tableName); + }); + + AssertHelpers.assertThrows("Should complain about updating a map column", + AnalysisException.class, "Updating nested fields is only supported for structs", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.m.key = 'new_key'", tableName); + }); + } + + @Test + public void testMergeWithConflictingUpdates() { + createAndInitTable("id INT, c STRUCT>"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + + AssertHelpers.assertThrows("Should complain about conflicting updates to a top-level column", + AnalysisException.class, "Updates are in conflict", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.id = 1, t.c.n1 = 2, t.id = 2", tableName); + }); + + AssertHelpers.assertThrows("Should complain about conflicting updates to a nested column", + AnalysisException.class, "Updates are in conflict for these columns", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n1 = 1, t.id = 2, t.c.n1 = 2", tableName); + }); + + AssertHelpers.assertThrows("Should complain about conflicting updates to a nested column", + AnalysisException.class, "Updates are in conflict", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET c.n1 = 1, c = named_struct('n1', 1, 'n2', named_struct('dn1', 1, 'dn2', 2))", tableName); + }); + } + + @Test + public void testMergeWithInvalidAssignments() { + createAndInitTable("id INT NOT NULL, s STRUCT> NOT NULL"); + createOrReplaceView( + "source", + "c1 INT, c2 STRUCT NOT NULL, c3 STRING NOT NULL, c4 STRUCT", + "{ \"c1\": -100, \"c2\": { \"n1\" : 1 }, \"c3\" : 'str', \"c4\": { \"dn2\": 1, \"dn2\": 2 } }"); + + for (String policy : new String[]{"ansi", "strict"}) { + withSQLConf(ImmutableMap.of("spark.sql.storeAssignmentPolicy", policy), () -> { + + AssertHelpers.assertThrows("Should complain about writing nulls to a top-level column", + AnalysisException.class, "Cannot write nullable values to non-null column", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.id = NULL", tableName); + }); + + AssertHelpers.assertThrows("Should complain about writing nulls to a nested column", + AnalysisException.class, "Cannot write 
nullable values to non-null column", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s.n1 = NULL", tableName); + }); + + AssertHelpers.assertThrows("Should complain about writing missing fields in structs", + AnalysisException.class, "missing fields", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s = s.c2", tableName); + }); + + AssertHelpers.assertThrows("Should complain about writing invalid data types", + AnalysisException.class, "Cannot safely cast", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s.n1 = s.c3", tableName); + }); + + AssertHelpers.assertThrows("Should complain about writing incompatible structs", + AnalysisException.class, "field name does not match", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s.n2 = s.c4", tableName); + }); + }); + } + } + + @Test + public void testMergeWithNonDeterministicConditions() { + createAndInitTable("id INT, c STRUCT>"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + + AssertHelpers.assertThrows("Should complain about non-deterministic search conditions", + AnalysisException.class, "Non-deterministic functions are not supported", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 AND rand() > t.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n1 = -1", tableName); + }); + + AssertHelpers.assertThrows("Should complain about non-deterministic update conditions", + AnalysisException.class, "Non-deterministic functions are not supported", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED AND rand() > t.id THEN " + + " UPDATE SET t.c.n1 = -1", tableName); + }); + + AssertHelpers.assertThrows("Should complain about non-deterministic delete conditions", + AnalysisException.class, "Non-deterministic functions are not supported", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED AND rand() > t.id THEN " + + " DELETE", tableName); + }); + + AssertHelpers.assertThrows("Should complain about non-deterministic insert conditions", + AnalysisException.class, "Non-deterministic functions are not supported", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN NOT MATCHED AND rand() > c1 THEN " + + " INSERT (id, c) VALUES (1, null)", tableName); + }); + } + + @Test + public void testMergeWithAggregateExpressions() { + createAndInitTable("id INT, c STRUCT>"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + + AssertHelpers.assertThrows("Should complain about agg expressions in search conditions", + AnalysisException.class, "Agg functions are not supported", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 AND max(t.id) == 1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n1 = -1", tableName); + }); + + AssertHelpers.assertThrows("Should complain about agg expressions in update conditions", + AnalysisException.class, "Agg functions are not supported", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED AND sum(t.id) < 1 THEN " + + " UPDATE SET t.c.n1 = -1", tableName); + }); + + AssertHelpers.assertThrows("Should complain about agg expressions in delete conditions", + AnalysisException.class, "Agg functions are not supported", + () -> { + sql("MERGE 
INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED AND sum(t.id) THEN " + + " DELETE", tableName); + }); + + AssertHelpers.assertThrows("Should complain about agg expressions in insert conditions", + AnalysisException.class, "Agg functions are not supported", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN NOT MATCHED AND sum(c1) < 1 THEN " + + " INSERT (id, c) VALUES (1, null)", tableName); + }); + } + + @Test + public void testMergeWithSubqueriesInConditions() { + createAndInitTable("id INT, c STRUCT<n1:INT,n2:STRUCT<dn1:INT,dn2:INT>>"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + + AssertHelpers.assertThrows("Should complain about subquery expressions", + AnalysisException.class, "Subqueries are not supported in conditions", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 AND t.id < (SELECT max(c2) FROM source) " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n1 = s.c2", tableName); + }); + + AssertHelpers.assertThrows("Should complain about subquery expressions", + AnalysisException.class, "Subqueries are not supported in conditions", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED AND t.id < (SELECT max(c2) FROM source) THEN " + + " UPDATE SET t.c.n1 = s.c2", tableName); + }); + + AssertHelpers.assertThrows("Should complain about subquery expressions", + AnalysisException.class, "Subqueries are not supported in conditions", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED AND t.id NOT IN (SELECT c2 FROM source) THEN " + + " DELETE", tableName); + }); + + AssertHelpers.assertThrows("Should complain about subquery expressions", + AnalysisException.class, "Subqueries are not supported in conditions", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN NOT MATCHED AND s.c1 IN (SELECT c2 FROM source) THEN " + + " INSERT (id, c) VALUES (1, null)", tableName); + }); + } + + @Test + public void testMergeWithTargetColumnsInInsertConditions() { + createAndInitTable("id INT, c2 INT"); + createOrReplaceView("source", "{ \"id\": 1, \"value\": 11 }"); + + AssertHelpers.assertThrows("Should complain about the target column", + AnalysisException.class, "Cannot resolve [c2]", + () -> { + sql("MERGE INTO %s t USING source s " + + "ON t.id == s.id " + + "WHEN NOT MATCHED AND c2 = 1 THEN " + + " INSERT (id, c2) VALUES (s.id, null)", tableName); + }); + } + + @Test + public void testMergeWithNonIcebergTargetTableNotSupported() { + createOrReplaceView("target", "{ \"c1\": -100, \"c2\": -200 }"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + + AssertHelpers.assertThrows("Should complain about non-Iceberg target table", + UnsupportedOperationException.class, "MERGE INTO TABLE is not supported temporarily.", + () -> { + sql("MERGE INTO target t USING source s " + + "ON t.c1 == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET *"); + }); + } + + /** + * Tests a merge where both the source and target are evaluated to be partitioned by SinglePartition at planning time + * but DynamicFileFilterExec will return an empty target. 
+ */ + @Test + public void testMergeSinglePartitionPartitioning() { + // This table will only have a single file and a single partition + createAndInitTable("id INT", "{\"id\": -1}"); + + // Coalesce forces our source into a SinglePartition distribution + spark.range(0, 5).coalesce(1).createOrReplaceTempView("source"); + + sql("MERGE INTO %s t USING source s ON t.id = s.id " + + "WHEN MATCHED THEN UPDATE SET *" + + "WHEN NOT MATCHED THEN INSERT *", + tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(-1), + row(0), + row(1), + row(2), + row(3), + row(4) + ); + + List result = sql("SELECT * FROM %s ORDER BY id", tableName); + assertEquals("Should correctly add the non-matching rows", expectedRows, result); + } + + @Test + public void testMergeEmptyTable() { + // This table will only have a single file and a single partition + createAndInitTable("id INT", null); + + // Coalesce forces our source into a SinglePartition distribution + spark.range(0, 5).coalesce(1).createOrReplaceTempView("source"); + + sql("MERGE INTO %s t USING source s ON t.id = s.id " + + "WHEN MATCHED THEN UPDATE SET *" + + "WHEN NOT MATCHED THEN INSERT *", + tableName); + + ImmutableList expectedRows = ImmutableList.of( + row(0), + row(1), + row(2), + row(3), + row(4) + ); + + List result = sql("SELECT * FROM %s ORDER BY id", tableName); + assertEquals("Should correctly add the non-matching rows", expectedRows, result); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java new file mode 100644 index 000000000000..d5454e036017 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.extensions; + +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.RowDelta; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.iceberg.spark.source.TestSparkCatalog; +import org.apache.spark.SparkException; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.junit.Test; +import org.junit.runners.Parameterized; + +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +public class TestMergeOnReadDelete extends TestDelete { + + public TestMergeOnReadDelete(String catalogName, String implementation, Map config, + String fileFormat, Boolean vectorized, String distributionMode) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode); + } + + @Override + protected Map extraTableProperties() { + return ImmutableMap.of( + TableProperties.FORMAT_VERSION, "2", + TableProperties.DELETE_MODE, "merge-on-read" + ); + } + + @Parameterized.AfterParam + public static void clearTestSparkCatalogCache() { + TestSparkCatalog.clearTables(); + } + + @Test + public void testCommitUnknownException() { + createAndInitTable("id INT, dep STRING, category STRING"); + + // write unpartitioned files + append(tableName, "{ \"id\": 1, \"dep\": \"hr\", \"category\": \"c1\"}"); + append(tableName, "{ \"id\": 2, \"dep\": \"hr\", \"category\": \"c1\" }\n" + + "{ \"id\": 3, \"dep\": \"hr\", \"category\": \"c1\" }"); + + Table table = validationCatalog.loadTable(tableIdent); + + RowDelta newRowDelta = table.newRowDelta(); + RowDelta spyNewRowDelta = spy(newRowDelta); + doAnswer(invocation -> { + newRowDelta.commit(); + throw new CommitStateUnknownException(new RuntimeException("Datacenter on Fire")); + }).when(spyNewRowDelta).commit(); + + Table spyTable = spy(table); + when(spyTable.newRowDelta()).thenReturn(spyNewRowDelta); + SparkTable sparkTable = new SparkTable(spyTable, false); + + ImmutableMap config = ImmutableMap.of( + "type", "hive", + "default-namespace", "default" + ); + spark.conf().set("spark.sql.catalog.dummy_catalog", "org.apache.iceberg.spark.source.TestSparkCatalog"); + config.forEach((key, value) -> spark.conf().set("spark.sql.catalog.dummy_catalog." 
+ key, value)); + Identifier ident = Identifier.of(new String[]{"default"}, "table"); + TestSparkCatalog.setTable(ident, sparkTable); + + // Although an exception is thrown here, write and commit have succeeded + AssertHelpers.assertThrowsWithCause("Should throw a Commit State Unknown Exception", + SparkException.class, + "Writing job aborted", + CommitStateUnknownException.class, + "Datacenter on Fire", + () -> sql("DELETE FROM %s WHERE id = 2", "dummy_catalog.default.table")); + + // Since write and commit succeeded, the rows should be readable + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr", "c1"), row(3, "hr", "c1")), + sql("SELECT * FROM %s ORDER BY id", "dummy_catalog.default.table")); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadMerge.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadMerge.java new file mode 100644 index 000000000000..d32f3464c3d0 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadMerge.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.extensions; + +import java.util.Map; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; + +public class TestMergeOnReadMerge extends TestMerge { + + public TestMergeOnReadMerge(String catalogName, String implementation, Map config, + String fileFormat, boolean vectorized, String distributionMode) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode); + } + + @Override + protected Map extraTableProperties() { + return ImmutableMap.of( + TableProperties.FORMAT_VERSION, "2", + TableProperties.MERGE_MODE, "merge-on-read" + ); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadUpdate.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadUpdate.java new file mode 100644 index 000000000000..97c34f155894 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadUpdate.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.extensions; + +import java.util.Map; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; + +public class TestMergeOnReadUpdate extends TestUpdate { + + public TestMergeOnReadUpdate(String catalogName, String implementation, Map config, + String fileFormat, boolean vectorized, String distributionMode) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode); + } + + @Override + protected Map extraTableProperties() { + return ImmutableMap.of( + TableProperties.FORMAT_VERSION, "2", + TableProperties.UPDATE_MODE, "merge-on-read" + ); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMetadataTables.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMetadataTables.java new file mode 100644 index 000000000000..60c57e12a3ff --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMetadataTables.java @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.extensions; + +import java.io.IOException; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.avro.generic.GenericData.Record; +import org.apache.commons.collections.ListUtils; +import org.apache.iceberg.FileContent; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + + +public class TestMetadataTables extends SparkExtensionsTestBase { + + public TestMetadataTables(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testUnpartitionedTable() throws Exception { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", tableName); + + List records = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "d") + ); + spark.createDataset(records, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + sql("DELETE FROM %s WHERE id=1", tableName); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + List expectedDataManifests = TestHelpers.dataManifests(table); + List expectedDeleteManifests = TestHelpers.deleteManifests(table); + Assert.assertEquals("Should have 1 data manifest", 1, expectedDataManifests.size()); + Assert.assertEquals("Should have 1 delete manifest", 1, expectedDeleteManifests.size()); + + Schema entriesTableSchema = Spark3Util.loadIcebergTable(spark, tableName + ".entries").schema(); + Schema filesTableSchema = Spark3Util.loadIcebergTable(spark, tableName + ".files").schema(); + + // check delete files table + List actualDeleteFiles = spark.sql("SELECT * FROM " + tableName + ".delete_files").collectAsList(); + Assert.assertEquals("Metadata table should return one delete file", 1, actualDeleteFiles.size()); + + List expectedDeleteFiles = expectedEntries(table, FileContent.POSITION_DELETES, + entriesTableSchema, expectedDeleteManifests, null); + Assert.assertEquals("Should be one delete file manifest entry", 1, expectedDeleteFiles.size()); + TestHelpers.assertEqualsSafe(filesTableSchema.asStruct(), expectedDeleteFiles.get(0), actualDeleteFiles.get(0)); + + // check data files table + List actualDataFiles = spark.sql("SELECT * FROM " + tableName + ".data_files").collectAsList(); + Assert.assertEquals("Metadata table should return one data file", 1, actualDataFiles.size()); + + List expectedDataFiles = expectedEntries(table, FileContent.DATA, + entriesTableSchema, expectedDataManifests, null); + Assert.assertEquals("Should be one data file manifest entry", 1, expectedDataFiles.size()); + TestHelpers.assertEqualsSafe(filesTableSchema.asStruct(), expectedDataFiles.get(0), actualDataFiles.get(0)); + + // check all files 
table + List actualFiles = spark.sql("SELECT * FROM " + tableName + ".files ORDER BY content").collectAsList(); + Assert.assertEquals("Metadata table should return two files", 2, actualFiles.size()); + + List expectedFiles = Stream.concat(expectedDataFiles.stream(), expectedDeleteFiles.stream()) + .collect(Collectors.toList()); + Assert.assertEquals("Should have two files manifest entries", 2, expectedFiles.size()); + TestHelpers.assertEqualsSafe(filesTableSchema.asStruct(), expectedFiles.get(0), actualFiles.get(0)); + TestHelpers.assertEqualsSafe(filesTableSchema.asStruct(), expectedFiles.get(1), actualFiles.get(1)); + } + + @Test + public void testPartitionedTable() throws Exception { + sql("CREATE TABLE %s (id bigint, data string) " + + "USING iceberg " + + "PARTITIONED BY (data) " + + "TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", tableName); + + List recordsA = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "a") + ); + spark.createDataset(recordsA, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + List recordsB = Lists.newArrayList( + new SimpleRecord(1, "b"), + new SimpleRecord(2, "b") + ); + spark.createDataset(recordsB, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + sql("DELETE FROM %s WHERE id=1 AND data='a'", tableName); + sql("DELETE FROM %s WHERE id=1 AND data='b'", tableName); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + Schema entriesTableSchema = Spark3Util.loadIcebergTable(spark, tableName + ".entries").schema(); + + List expectedDataManifests = TestHelpers.dataManifests(table); + List expectedDeleteManifests = TestHelpers.deleteManifests(table); + Assert.assertEquals("Should have 2 data manifests", 2, expectedDataManifests.size()); + Assert.assertEquals("Should have 2 delete manifests", 2, expectedDeleteManifests.size()); + + Schema filesTableSchema = Spark3Util.loadIcebergTable(spark, tableName + ".delete_files").schema(); + + // Check delete files table + List expectedDeleteFiles = expectedEntries(table, FileContent.POSITION_DELETES, + entriesTableSchema, expectedDeleteManifests, "a"); + Assert.assertEquals("Should have one delete file manifest entry", 1, expectedDeleteFiles.size()); + + List actualDeleteFiles = spark.sql("SELECT * FROM " + tableName + ".delete_files " + + "WHERE partition.data='a'").collectAsList(); + Assert.assertEquals("Metadata table should return one delete file", 1, actualDeleteFiles.size()); + TestHelpers.assertEqualsSafe(filesTableSchema.asStruct(), expectedDeleteFiles.get(0), actualDeleteFiles.get(0)); + + // Check data files table + List expectedDataFiles = expectedEntries(table, FileContent.DATA, + entriesTableSchema, expectedDataManifests, "a"); + Assert.assertEquals("Should have one data file manifest entry", 1, expectedDataFiles.size()); + + List actualDataFiles = spark.sql("SELECT * FROM " + tableName + ".data_files " + + "WHERE partition.data='a'").collectAsList(); + Assert.assertEquals("Metadata table should return one data file", 1, actualDataFiles.size()); + TestHelpers.assertEqualsSafe(filesTableSchema.asStruct(), expectedDataFiles.get(0), actualDataFiles.get(0)); + + List actualPartitionsWithProjection = + spark.sql("SELECT file_count FROM " + tableName + ".partitions ").collectAsList(); + Assert.assertEquals("Metadata table should return two partitions record", 2, actualPartitionsWithProjection.size()); + for (int i = 0; i < 2; ++i) { + Assert.assertEquals(1, 
actualPartitionsWithProjection.get(i).get(0)); + } + + // Check files table + List expectedFiles = Stream.concat(expectedDataFiles.stream(), expectedDeleteFiles.stream()) + .collect(Collectors.toList()); + Assert.assertEquals("Should have two file manifest entries", 2, expectedFiles.size()); + + List actualFiles = spark.sql("SELECT * FROM " + tableName + ".files " + + "WHERE partition.data='a' ORDER BY content").collectAsList(); + Assert.assertEquals("Metadata table should return two files", 2, actualFiles.size()); + TestHelpers.assertEqualsSafe(filesTableSchema.asStruct(), expectedFiles.get(0), actualFiles.get(0)); + TestHelpers.assertEqualsSafe(filesTableSchema.asStruct(), expectedFiles.get(1), actualFiles.get(1)); + } + + @Test + public void testAllFilesUnpartitioned() throws Exception { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", tableName); + + List records = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "d") + ); + spark.createDataset(records, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + // Create delete file + sql("DELETE FROM %s WHERE id=1", tableName); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + List expectedDataManifests = TestHelpers.dataManifests(table); + Assert.assertEquals("Should have 1 data manifest", 1, expectedDataManifests.size()); + List expectedDeleteManifests = TestHelpers.deleteManifests(table); + Assert.assertEquals("Should have 1 delete manifest", 1, expectedDeleteManifests.size()); + + // Clear table to test whether 'all_files' can read past files + List results = sql("DELETE FROM %s", tableName); + Assert.assertEquals("Table should be cleared", 0, results.size()); + + Schema entriesTableSchema = Spark3Util.loadIcebergTable(spark, tableName + ".entries").schema(); + Schema filesTableSchema = Spark3Util.loadIcebergTable(spark, tableName + ".all_data_files").schema(); + + // Check all data files table + List actualDataFiles = spark.sql("SELECT * FROM " + tableName + ".all_data_files").collectAsList(); + + List expectedDataFiles = expectedEntries(table, FileContent.DATA, + entriesTableSchema, expectedDataManifests, null); + Assert.assertEquals("Should be one data file manifest entry", 1, expectedDataFiles.size()); + Assert.assertEquals("Metadata table should return one data file", 1, actualDataFiles.size()); + TestHelpers.assertEqualsSafe(filesTableSchema.asStruct(), expectedDataFiles.get(0), actualDataFiles.get(0)); + + // Check all delete files table + List actualDeleteFiles = spark.sql("SELECT * FROM " + tableName + ".all_delete_files").collectAsList(); + List expectedDeleteFiles = expectedEntries(table, FileContent.POSITION_DELETES, + entriesTableSchema, expectedDeleteManifests, null); + Assert.assertEquals("Should be one delete file manifest entry", 1, expectedDeleteFiles.size()); + Assert.assertEquals("Metadata table should return one delete file", 1, actualDeleteFiles.size()); + TestHelpers.assertEqualsSafe(filesTableSchema.asStruct(), expectedDeleteFiles.get(0), actualDeleteFiles.get(0)); + + // Check all files table + List actualFiles = spark.sql("SELECT * FROM " + tableName + ".all_files ORDER BY content") + .collectAsList(); + List expectedFiles = ListUtils.union(expectedDataFiles, expectedDeleteFiles); + expectedFiles.sort(Comparator.comparing(r -> ((Integer) r.get("content")))); + Assert.assertEquals("Metadata 
table should return two files", 2, actualFiles.size()); + TestHelpers.assertEqualsSafe(filesTableSchema.asStruct(), expectedFiles, actualFiles); + } + + @Test + public void testAllFilesPartitioned() throws Exception { + // Create table and insert data + sql("CREATE TABLE %s (id bigint, data string) " + + "USING iceberg " + + "PARTITIONED BY (data) " + + "TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", tableName); + + List recordsA = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "a") + ); + spark.createDataset(recordsA, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + List recordsB = Lists.newArrayList( + new SimpleRecord(1, "b"), + new SimpleRecord(2, "b") + ); + spark.createDataset(recordsB, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + // Create delete file + sql("DELETE FROM %s WHERE id=1", tableName); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + List expectedDataManifests = TestHelpers.dataManifests(table); + Assert.assertEquals("Should have 2 data manifests", 2, expectedDataManifests.size()); + List expectedDeleteManifests = TestHelpers.deleteManifests(table); + Assert.assertEquals("Should have 1 delete manifest", 1, expectedDeleteManifests.size()); + + // Clear table to test whether 'all_files' can read past files + List results = sql("DELETE FROM %s", tableName); + Assert.assertEquals("Table should be cleared", 0, results.size()); + + Schema entriesTableSchema = Spark3Util.loadIcebergTable(spark, tableName + ".entries").schema(); + Schema filesTableSchema = Spark3Util.loadIcebergTable(spark, tableName + ".all_data_files").schema(); + + // Check all data files table + List actualDataFiles = spark.sql("SELECT * FROM " + tableName + ".all_data_files " + + "WHERE partition.data='a'").collectAsList(); + List expectedDataFiles = expectedEntries(table, FileContent.DATA, + entriesTableSchema, expectedDataManifests, "a"); + Assert.assertEquals("Should be one data file manifest entry", 1, expectedDataFiles.size()); + Assert.assertEquals("Metadata table should return one data file", 1, actualDataFiles.size()); + TestHelpers.assertEqualsSafe(filesTableSchema.asStruct(), expectedDataFiles.get(0), actualDataFiles.get(0)); + + // Check all delete files table + List actualDeleteFiles = spark.sql("SELECT * FROM " + tableName + ".all_delete_files " + + "WHERE partition.data='a'").collectAsList(); + List expectedDeleteFiles = expectedEntries(table, FileContent.POSITION_DELETES, + entriesTableSchema, expectedDeleteManifests, "a"); + Assert.assertEquals("Should be one delete file manifest entry", 1, expectedDeleteFiles.size()); + Assert.assertEquals("Metadata table should return one delete file", 1, actualDeleteFiles.size()); + + TestHelpers.assertEqualsSafe(filesTableSchema.asStruct(), expectedDeleteFiles.get(0), actualDeleteFiles.get(0)); + + // Check all files table + List actualFiles = spark.sql("SELECT * FROM " + tableName + ".all_files WHERE partition.data='a' " + + "ORDER BY content").collectAsList(); + List expectedFiles = ListUtils.union(expectedDataFiles, expectedDeleteFiles); + expectedFiles.sort(Comparator.comparing(r -> ((Integer) r.get("content")))); + Assert.assertEquals("Metadata table should return two files", 2, actualFiles.size()); + TestHelpers.assertEqualsSafe(filesTableSchema.asStruct(), expectedFiles, actualFiles); + } + + /** + * Find matching manifest entries of an Iceberg table + * @param table iceberg table + * @param
expectedContent file content to populate on entries + * @param entriesTableSchema schema of Manifest entries + * @param manifestsToExplore manifests to explore of the table + * @param partValue partition value that manifest entries must match, or null to skip filtering + */ + private List expectedEntries(Table table, FileContent expectedContent, Schema entriesTableSchema, + List manifestsToExplore, String partValue) throws IOException { + List expected = Lists.newArrayList(); + for (ManifestFile manifest : manifestsToExplore) { + InputFile in = table.io().newInputFile(manifest.path()); + try (CloseableIterable rows = Avro.read(in).project(entriesTableSchema).build()) { + for (Record record : rows) { + if ((Integer) record.get("status") < 2 /* added or existing */) { + Record file = (Record) record.get("data_file"); + if (partitionMatch(file, partValue)) { + asMetadataRecord(file, expectedContent); + expected.add(file); + } + } + } + } + } + return expected; + } + + // Populate certain fields derived in the metadata tables + private void asMetadataRecord(Record file, FileContent content) { + file.put(0, content.id()); + file.put(3, 0); // specId + } + + private boolean partitionMatch(Record file, String partValue) { + if (partValue == null) { + return true; + } + Record partition = (Record) file.get(4); + return partValue.equals(partition.get(0).toString()); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMigrateTableProcedure.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMigrateTableProcedure.java new file mode 100644 index 000000000000..d66e75add16f --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMigrateTableProcedure.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.extensions; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.spark.sql.AnalysisException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestMigrateTableProcedure extends SparkExtensionsTestBase { + + public TestMigrateTableProcedure(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s_BACKUP_", tableName); + } + + @Test + public void testMigrate() throws IOException { + Assume.assumeTrue(catalogName.equals("spark_catalog")); + String location = temp.newFolder().toString(); + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", tableName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + Object result = scalarSql("CALL %s.system.migrate('%s')", catalogName, tableName); + + Assert.assertEquals("Should have added one file", 1L, result); + + Table createdTable = validationCatalog.loadTable(tableIdent); + + String tableLocation = createdTable.location().replace("file:", ""); + Assert.assertEquals("Table should have original location", location, tableLocation); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("DROP TABLE %s", tableName + "_BACKUP_"); + } + + @Test + public void testMigrateWithOptions() throws IOException { + Assume.assumeTrue(catalogName.equals("spark_catalog")); + String location = temp.newFolder().toString(); + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", tableName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Object result = scalarSql("CALL %s.system.migrate('%s', map('foo', 'bar'))", catalogName, tableName); + + Assert.assertEquals("Should have added one file", 1L, result); + + Table createdTable = validationCatalog.loadTable(tableIdent); + + Map props = createdTable.properties(); + Assert.assertEquals("Should have extra property set", "bar", props.get("foo")); + + String tableLocation = createdTable.location().replace("file:", ""); + Assert.assertEquals("Table should have original location", location, tableLocation); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("DROP TABLE %s", tableName + "_BACKUP_"); + } + + @Test + public void testMigrateWithInvalidMetricsConfig() throws IOException { + Assume.assumeTrue(catalogName.equals("spark_catalog")); + + String location = temp.newFolder().toString(); + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", tableName, location); + + AssertHelpers.assertThrows("Should reject invalid metrics config", + ValidationException.class, "Invalid metrics config", + () -> { + String props = "map('write.metadata.metrics.column.x', 
'X')"; + sql("CALL %s.system.migrate('%s', %s)", catalogName, tableName, props); + }); + } + + @Test + public void testMigrateWithConflictingProps() throws IOException { + Assume.assumeTrue(catalogName.equals("spark_catalog")); + + String location = temp.newFolder().toString(); + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", tableName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Object result = scalarSql("CALL %s.system.migrate('%s', map('migrated', 'false'))", catalogName, tableName); + Assert.assertEquals("Should have added one file", 1L, result); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should override user value", "true", table.properties().get("migrated")); + } + + @Test + public void testInvalidMigrateCases() { + AssertHelpers.assertThrows("Should reject calls without all required args", + AnalysisException.class, "Missing required parameters", + () -> sql("CALL %s.system.migrate()", catalogName)); + + AssertHelpers.assertThrows("Should reject calls with invalid arg types", + AnalysisException.class, "Wrong arg type", + () -> sql("CALL %s.system.migrate(map('foo','bar'))", catalogName)); + + AssertHelpers.assertThrows("Should reject calls with empty table identifier", + IllegalArgumentException.class, "Cannot handle an empty identifier", + () -> sql("CALL %s.system.migrate('')", catalogName)); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRegisterTableProcedure.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRegisterTableProcedure.java new file mode 100644 index 000000000000..d1f1905e098c --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRegisterTableProcedure.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.extensions; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.Table; +import org.apache.iceberg.hive.HiveTableOperations; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.functions; +import org.apache.spark.sql.types.DataTypes; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestRegisterTableProcedure extends SparkExtensionsTestBase { + + private final String targetName; + + public TestRegisterTableProcedure( + String catalogName, + String implementation, + Map config) { + super(catalogName, implementation, config); + targetName = tableName("register_table"); + } + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + @After + public void dropTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s", targetName); + } + + @Test + public void testRegisterTable() throws NoSuchTableException, ParseException { + Assume.assumeTrue("Register only implemented on Hive Catalogs", + spark.conf().get("spark.sql.catalog." + catalogName + ".type").equals("hive")); + + long numRows = 1000; + + sql("CREATE TABLE %s (id int, data string) using ICEBERG", tableName); + spark.range(0, numRows) + .withColumn("data", functions.col("id").cast(DataTypes.StringType)) + .writeTo(tableName) + .append(); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + long originalFileCount = (long) scalarSql("SELECT COUNT(*) from %s.files", tableName); + long currentSnapshotId = table.currentSnapshot().snapshotId(); + String metadataJson = ((HiveTableOperations) (((HasTableOperations) table).operations())).currentMetadataLocation(); + + List result = sql("CALL %s.system.register_table('%s', '%s')", catalogName, targetName, metadataJson); + Assert.assertEquals("Current Snapshot is not correct", currentSnapshotId, result.get(0)[0]); + + List original = sql("SELECT * FROM %s", tableName); + List registered = sql("SELECT * FROM %s", targetName); + assertEquals("Registered table rows should match original table rows", original, registered); + Assert.assertEquals("Should have the right row count in the procedure result", + numRows, result.get(0)[1]); + Assert.assertEquals("Should have the right datafile count in the procedure result", + originalFileCount, result.get(0)[2]); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRemoveOrphanFilesProcedure.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRemoveOrphanFilesProcedure.java new file mode 100644 index 000000000000..e3a3bbf64b87 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRemoveOrphanFilesProcedure.java @@ -0,0 +1,403 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.extensions; + +import java.io.IOException; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.TableProperties.GC_ENABLED; +import static org.apache.iceberg.TableProperties.WRITE_AUDIT_PUBLISH_ENABLED; + +public class TestRemoveOrphanFilesProcedure extends SparkExtensionsTestBase { + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + public TestRemoveOrphanFilesProcedure(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s PURGE", tableName); + sql("DROP TABLE IF EXISTS p PURGE"); + } + + @Test + public void testRemoveOrphanFilesInEmptyTable() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + List output = sql( + "CALL %s.system.remove_orphan_files('%s')", + catalogName, tableIdent); + assertEquals("Should be no orphan files", ImmutableList.of(), output); + + assertEquals("Should have no rows", + ImmutableList.of(), + sql("SELECT * FROM %s", tableName)); + } + + @Test + public void testRemoveOrphanFilesInDataFolder() throws IOException { + if (catalogName.equals("testhadoop")) { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + } else { + // give a fresh location to Hive tables as Spark will not clean up the table location + // correctly while dropping tables through spark_catalog + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg LOCATION '%s'", + tableName, temp.newFolder()); + } + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + String metadataLocation = table.location() + "/metadata"; + String dataLocation = table.location() + "/data"; + + // produce orphan files in the data location using parquet + sql("CREATE TABLE p (id bigint) USING parquet LOCATION '%s'", dataLocation); + sql("INSERT INTO TABLE p VALUES (1)"); + + // wait to ensure files are old enough + 
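// note: remove_orphan_files only targets files last modified before the older_than timestamp, so the test waits and then passes the current time explicitly +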
waitUntilAfter(System.currentTimeMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + // check for orphans in the metadata folder + List output1 = sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s'," + + "location => '%s')", + catalogName, tableIdent, currentTimestamp, metadataLocation); + assertEquals("Should be no orphan files in the metadata folder", ImmutableList.of(), output1); + + // check for orphans in the table location + List output2 = sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + Assert.assertEquals("Should be orphan files in the data folder", 1, output2.size()); + + // the previous call should have deleted all orphan files + List output3 = sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + Assert.assertEquals("Should be no more orphan files in the data folder", 0, output3.size()); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testRemoveOrphanFilesDryRun() throws IOException { + if (catalogName.equals("testhadoop")) { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + } else { + // give a fresh location to Hive tables as Spark will not clean up the table location + // correctly while dropping tables through spark_catalog + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg LOCATION '%s'", + tableName, temp.newFolder()); + } + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + // produce orphan files in the table location using parquet + sql("CREATE TABLE p (id bigint) USING parquet LOCATION '%s'", table.location()); + sql("INSERT INTO TABLE p VALUES (1)"); + + // wait to ensure files are old enough + waitUntilAfter(System.currentTimeMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + // check for orphans without deleting + List output1 = sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s'," + + "dry_run => true)", + catalogName, tableIdent, currentTimestamp); + Assert.assertEquals("Should be one orphan files", 1, output1.size()); + + // actually delete orphans + List output2 = sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + Assert.assertEquals("Should be one orphan files", 1, output2.size()); + + // the previous call should have deleted all orphan files + List output3 = sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + Assert.assertEquals("Should be no more orphan files", 0, output3.size()); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testRemoveOrphanFilesGCDisabled() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'false')", tableName, 
GC_ENABLED); + + AssertHelpers.assertThrows("Should reject call", + ValidationException.class, "Cannot delete orphan files: GC is disabled", + () -> sql("CALL %s.system.remove_orphan_files('%s')", catalogName, tableIdent)); + + // reset the property to enable the table purging in removeTable. + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, GC_ENABLED); + } + + @Test + public void testRemoveOrphanFilesWap() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED); + + spark.conf().set("spark.wap.id", "1"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should not see rows from staged snapshot", + ImmutableList.of(), + sql("SELECT * FROM %s", tableName)); + + List output = sql( + "CALL %s.system.remove_orphan_files('%s')", catalogName, tableIdent); + assertEquals("Should be no orphan files", ImmutableList.of(), output); + } + + @Test + public void testInvalidRemoveOrphanFilesCases() { + AssertHelpers.assertThrows("Should not allow mixed args", + AnalysisException.class, "Named and positional arguments cannot be mixed", + () -> sql("CALL %s.system.remove_orphan_files('n', table => 't')", catalogName)); + + AssertHelpers.assertThrows("Should not resolve procedures in arbitrary namespaces", + NoSuchProcedureException.class, "not found", + () -> sql("CALL %s.custom.remove_orphan_files('n', 't')", catalogName)); + + AssertHelpers.assertThrows("Should reject calls without all required args", + AnalysisException.class, "Missing required parameters", + () -> sql("CALL %s.system.remove_orphan_files()", catalogName)); + + AssertHelpers.assertThrows("Should reject calls with invalid arg types", + AnalysisException.class, "Wrong arg type", + () -> sql("CALL %s.system.remove_orphan_files('n', 2.2)", catalogName)); + + AssertHelpers.assertThrows("Should reject calls with empty table identifier", + IllegalArgumentException.class, "Cannot handle an empty identifier", + () -> sql("CALL %s.system.remove_orphan_files('')", catalogName)); + } + + @Test + public void testConcurrentRemoveOrphanFiles() throws IOException { + if (catalogName.equals("testhadoop")) { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + } else { + // give a fresh location to Hive tables as Spark will not clean up the table location + // correctly while dropping tables through spark_catalog + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg LOCATION '%s'", + tableName, temp.newFolder()); + } + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + String metadataLocation = table.location() + "/metadata"; + String dataLocation = table.location() + "/data"; + + // produce orphan files in the data location using parquet + sql("CREATE TABLE p (id bigint) USING parquet LOCATION '%s'", dataLocation); + sql("INSERT INTO TABLE p VALUES (1)"); + sql("INSERT INTO TABLE p VALUES (10)"); + sql("INSERT INTO TABLE p VALUES (100)"); + sql("INSERT INTO TABLE p VALUES (1000)"); + + // wait to ensure files are old enough + waitUntilAfter(System.currentTimeMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + // check for orphans in the table location + List output = sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + 
"max_concurrent_deletes => %s," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, 4, currentTimestamp); + Assert.assertEquals("Should be orphan files in the data folder", 4, output.size()); + + // the previous call should have deleted all orphan files + List output3 = sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "max_concurrent_deletes => %s," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, 4, currentTimestamp); + Assert.assertEquals("Should be no more orphan files in the data folder", 0, output3.size()); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testConcurrentRemoveOrphanFilesWithInvalidInput() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + AssertHelpers.assertThrows("Should throw an error when max_concurrent_deletes = 0", + IllegalArgumentException.class, "max_concurrent_deletes should have value > 0", + () -> sql("CALL %s.system.remove_orphan_files(table => '%s', max_concurrent_deletes => %s)", + catalogName, tableIdent, 0)); + + AssertHelpers.assertThrows("Should throw an error when max_concurrent_deletes < 0 ", + IllegalArgumentException.class, "max_concurrent_deletes should have value > 0", + () -> sql( + "CALL %s.system.remove_orphan_files(table => '%s', max_concurrent_deletes => %s)", + catalogName, tableIdent, -1)); + + String tempViewName = "file_list_test"; + spark.emptyDataFrame().createOrReplaceTempView(tempViewName); + + AssertHelpers.assertThrows( + "Should throw an error if file_list_view is missing required columns", + IllegalArgumentException.class, + "does not exist. Available:", + () -> + sql( + "CALL %s.system.remove_orphan_files(table => '%s', file_list_view => '%s')", + catalogName, tableIdent, tempViewName)); + + spark + .createDataset(Lists.newArrayList(), Encoders.tuple(Encoders.INT(), Encoders.TIMESTAMP())) + .toDF("file_path", "last_modified") + .createOrReplaceTempView(tempViewName); + + AssertHelpers.assertThrows( + "Should throw an error if file_path has wrong type", + IllegalArgumentException.class, + "Invalid file_path column", + () -> + sql( + "CALL %s.system.remove_orphan_files(table => '%s', file_list_view => '%s')", + catalogName, tableIdent, tempViewName)); + + spark + .createDataset(Lists.newArrayList(), Encoders.tuple(Encoders.STRING(), Encoders.STRING())) + .toDF("file_path", "last_modified") + .createOrReplaceTempView(tempViewName); + + AssertHelpers.assertThrows( + "Should throw an error if last_modified has wrong type", + IllegalArgumentException.class, + "Invalid last_modified column", + () -> + sql( + "CALL %s.system.remove_orphan_files(table => '%s', file_list_view => '%s')", + catalogName, tableIdent, tempViewName)); + } + + @Test + public void testRemoveOrphanFilesWithDeleteFiles() throws Exception { + sql("CREATE TABLE %s (id int, data string) USING iceberg TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", tableName); + + List records = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "d") + ); + spark.createDataset(records, Encoders.bean(SimpleRecord.class)).coalesce(1).writeTo(tableName).append(); + sql("DELETE FROM %s WHERE id=1", tableName); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + Assert.assertEquals("Should have 1 delete manifest", 1, 
TestHelpers.deleteManifests(table).size()); + Assert.assertEquals("Should have 1 delete file", 1, TestHelpers.deleteFiles(table).size()); + Path deleteManifestPath = new Path(TestHelpers.deleteManifests(table).iterator().next().path()); + Path deleteFilePath = new Path(String.valueOf(TestHelpers.deleteFiles(table).iterator().next().path())); + + // wait to ensure files are old enough + waitUntilAfter(System.currentTimeMillis()); + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + // delete orphans + List output = sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + Assert.assertEquals("Should be no orphan files", 0, output.size()); + + FileSystem localFs = FileSystem.getLocal(new Configuration()); + Assert.assertTrue("Delete manifest should still exist", localFs.exists(deleteManifestPath)); + Assert.assertTrue("Delete file should still exist", localFs.exists(deleteFilePath)); + + records.remove(new SimpleRecord(1, "a")); + Dataset resultDF = spark.read().format("iceberg").load(tableName); + List actualRecords = resultDF + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + Assert.assertEquals("Rows must match", records, actualRecords); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRequiredDistributionAndOrdering.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRequiredDistributionAndOrdering.java new file mode 100644 index 000000000000..1d3ccc766288 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRequiredDistributionAndOrdering.java @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.extensions; + +import java.math.BigDecimal; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.source.ThreeColumnRecord; +import org.apache.spark.SparkException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.After; +import org.junit.Test; + +public class TestRequiredDistributionAndOrdering extends SparkExtensionsTestBase { + + public TestRequiredDistributionAndOrdering(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void dropTestTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testDefaultLocalSortWithBucketTransforms() throws NoSuchTableException { + sql("CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c1))", tableName); + + List data = ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A") + ); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should insert a local sort by partition columns by default + inputDF.writeTo(tableName).append(); + + assertEquals("Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @Test + public void testPartitionColumnsArePrependedForRangeDistribution() throws NoSuchTableException { + sql("CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c1))", tableName); + + List data = ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A") + ); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should automatically prepend partition columns to the ordering + sql("ALTER TABLE %s WRITE ORDERED BY c1, c2", tableName); + + inputDF.writeTo(tableName).append(); + + assertEquals("Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @Test + public void testSortOrderIncludesPartitionColumns() throws NoSuchTableException { + sql("CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c1))", tableName); + + List data = ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A") + ); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = 
ds.coalesce(1).sortWithinPartitions("c1"); + + // should succeed with a correct sort order + sql("ALTER TABLE %s WRITE ORDERED BY bucket(2, c3), c1, c2", tableName); + + inputDF.writeTo(tableName).append(); + + assertEquals("Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @Test + public void testHashDistributionOnBucketedColumn() throws NoSuchTableException { + sql("CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c1))", tableName); + + List data = ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A") + ); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should automatically prepend partition columns to the local ordering after hash distribution + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION ORDERED BY c1, c2", tableName); + + inputDF.writeTo(tableName).append(); + + assertEquals("Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @Test + public void testDisabledDistributionAndOrdering() { + sql("CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c1))", tableName); + + List data = ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A") + ); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should fail if ordering is disabled + AssertHelpers.assertThrows("Should reject writes without ordering", + SparkException.class, "Writing job aborted", + () -> { + try { + inputDF.writeTo(tableName) + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false") + .append(); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + public void testDefaultSortOnDecimalBucketedColumn() { + sql("CREATE TABLE %s (c1 INT, c2 DECIMAL(20, 2)) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c2))", tableName); + + sql("INSERT INTO %s VALUES (1, 20.2), (2, 40.2), (3, 60.2)", tableName); + + List expected = ImmutableList.of( + row(1, new BigDecimal("20.20")), + row(2, new BigDecimal("40.20")), + row(3, new BigDecimal("60.20")) + ); + + assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName)); + } + + @Test + public void testDefaultSortOnStringBucketedColumn() { + sql("CREATE TABLE %s (c1 INT, c2 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c2))", tableName); + + sql("INSERT INTO %s VALUES (1, 'A'), (2, 'B')", tableName); + + List expected = ImmutableList.of( + row(1, "A"), + row(2, "B") + ); + + assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName)); + } + + @Test + public void testDefaultSortOnDecimalTruncatedColumn() { + sql("CREATE TABLE %s (c1 INT, c2 DECIMAL(20, 2)) " + + "USING iceberg " + + "PARTITIONED BY (truncate(2, 
c2))", tableName); + + sql("INSERT INTO %s VALUES (1, 20.2), (2, 40.2)", tableName); + + List expected = ImmutableList.of( + row(1, new BigDecimal("20.20")), + row(2, new BigDecimal("40.20")) + ); + + assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName)); + } + + @Test + public void testDefaultSortOnLongTruncatedColumn() { + sql("CREATE TABLE %s (c1 INT, c2 BIGINT) " + + "USING iceberg " + + "PARTITIONED BY (truncate(2, c2))", tableName); + + sql("INSERT INTO %s VALUES (1, 22222222222222), (2, 444444444444)", tableName); + + List expected = ImmutableList.of( + row(1, 22222222222222L), + row(2, 444444444444L) + ); + + assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName)); + } + + @Test + public void testRangeDistributionWithQuotedColumnNames() throws NoSuchTableException { + sql("CREATE TABLE %s (`c.1` INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, `c.1`))", tableName); + + List data = ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A") + ); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.selectExpr("c1 as `c.1`", "c2", "c3").coalesce(1).sortWithinPartitions("`c.1`"); + + sql("ALTER TABLE %s WRITE ORDERED BY `c.1`, c2", tableName); + + inputDF.writeTo(tableName).append(); + + assertEquals("Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteDataFilesProcedure.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteDataFilesProcedure.java new file mode 100644 index 000000000000..a6fa5e45eb7e --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteDataFilesProcedure.java @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.extensions; + +import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.source.ThreeColumnRecord; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.junit.After; +import org.junit.Test; + + +public class TestRewriteDataFilesProcedure extends SparkExtensionsTestBase { + + public TestRewriteDataFilesProcedure(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testRewriteDataFilesInEmptyTable() { + createTable(); + List output = sql( + "CALL %s.system.rewrite_data_files('%s')", catalogName, tableIdent); + assertEquals("Procedure output must match", + ImmutableList.of(row(0, 0)), + output); + } + + @Test + public void testRewriteDataFilesOnPartitionTable() { + createPartitionTable(); + // create 5 files for each partition (c2 = 'foo' and c2 = 'bar') + insertData(10); + List expectedRecords = currentData(); + + List output = sql( + "CALL %s.system.rewrite_data_files(table => '%s')", catalogName, tableIdent); + + assertEquals("Action should rewrite 10 data files and add 2 data files (one per partition) ", + ImmutableList.of(row(10, 2)), + output); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @Test + public void testRewriteDataFilesOnNonPartitionTable() { + createTable(); + // create 10 files under non-partitioned table + insertData(10); + List expectedRecords = currentData(); + + List output = sql( + "CALL %s.system.rewrite_data_files(table => '%s')", catalogName, tableIdent); + + assertEquals("Action should rewrite 10 data files and add 1 data files", + ImmutableList.of(row(10, 1)), + output); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @Test + public void testRewriteDataFilesWithOptions() { + createTable(); + // create 10 files under non-partitioned table + insertData(10); + List expectedRecords = currentData(); + + // set the min-input-files = 12, instead of default 5 to skip compacting the files. 
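+ // (the options map is forwarded to the underlying rewrite action; with only 10 input files no group reaches the 12-file minimum, so nothing is rewritten)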
+ List output = sql( + "CALL %s.system.rewrite_data_files(table => '%s', options => map('min-input-files','12'))", + catalogName, tableIdent); + + assertEquals("Action should rewrite 0 data files and add 0 data files", + ImmutableList.of(row(0, 0)), + output); + + List actualRecords = currentData(); + assertEquals("Data should not change", expectedRecords, actualRecords); + } + + @Test + public void testRewriteDataFilesWithSortStrategy() { + createTable(); + // create 10 files under non-partitioned table + insertData(10); + List expectedRecords = currentData(); + + // set sort_order = c1 DESC LAST + List output = sql( + "CALL %s.system.rewrite_data_files(table => '%s', " + + "strategy => 'sort', sort_order => 'c1 DESC NULLS LAST')", + catalogName, tableIdent); + + assertEquals("Action should rewrite 10 data files and add 1 data files", + ImmutableList.of(row(10, 1)), + output); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @Test + public void testRewriteDataFilesWithFilter() { + createTable(); + // create 10 files under non-partitioned table + insertData(10); + List expectedRecords = currentData(); + + // select only 5 files for compaction (files that may have c1 = 1) + List output = sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + + " where => 'c1 = 1 and c2 is not null')", catalogName, tableIdent); + + assertEquals("Action should rewrite 5 data files (containing c1 = 1) and add 1 data files", + ImmutableList.of(row(5, 1)), + output); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @Test + public void testRewriteDataFilesWithFilterOnPartitionTable() { + createPartitionTable(); + // create 5 files for each partition (c2 = 'foo' and c2 = 'bar') + insertData(10); + List expectedRecords = currentData(); + + // select only 5 files for compaction (files in the partition c2 = 'bar') + List output = sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + + " where => 'c2 = \"bar\"')", catalogName, tableIdent); + + assertEquals("Action should rewrite 5 data files from single matching partition" + + "(containing c2 = bar) and add 1 data files", + ImmutableList.of(row(5, 1)), + output); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @Test + public void testRewriteDataFilesWithInFilterOnPartitionTable() { + createPartitionTable(); + // create 5 files for each partition (c2 = 'foo' and c2 = 'bar') + insertData(10); + List expectedRecords = currentData(); + + // select only 5 files for compaction (files in the partition c2 in ('bar')) + List output = sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + + " where => 'c2 in (\"bar\")')", catalogName, tableIdent); + + assertEquals("Action should rewrite 5 data files from single matching partition" + + "(containing c2 = bar) and add 1 data files", + ImmutableList.of(row(5, 1)), + output); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @Test + public void testRewriteDataFilesWithAllPossibleFilters() { + createPartitionTable(); + // create 5 files for each partition (c2 = 'foo' and c2 = 'bar') + insertData(10); + + // Pass the literal value which is not present in the data files. + // So that parsing can be tested on a same dataset without actually compacting the files. 
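+ // (each call below should only exercise predicate parsing and conversion via SparkFilters; no data file matches, so no files are rewritten)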
+ + // EqualTo + sql("CALL %s.system.rewrite_data_files(table => '%s'," + + " where => 'c1 = 3')", catalogName, tableIdent); + // GreaterThan + sql("CALL %s.system.rewrite_data_files(table => '%s'," + + " where => 'c1 > 3')", catalogName, tableIdent); + // GreaterThanOrEqual + sql("CALL %s.system.rewrite_data_files(table => '%s'," + + " where => 'c1 >= 3')", catalogName, tableIdent); + // LessThan + sql("CALL %s.system.rewrite_data_files(table => '%s'," + + " where => 'c1 < 0')", catalogName, tableIdent); + // LessThanOrEqual + sql("CALL %s.system.rewrite_data_files(table => '%s'," + + " where => 'c1 <= 0')", catalogName, tableIdent); + // In + sql("CALL %s.system.rewrite_data_files(table => '%s'," + + " where => 'c1 in (3,4,5)')", catalogName, tableIdent); + // IsNull + sql("CALL %s.system.rewrite_data_files(table => '%s'," + + " where => 'c1 is null')", catalogName, tableIdent); + // IsNotNull + sql("CALL %s.system.rewrite_data_files(table => '%s'," + + " where => 'c3 is not null')", catalogName, tableIdent); + // And + sql("CALL %s.system.rewrite_data_files(table => '%s'," + + " where => 'c1 = 3 and c2 = \"bar\"')", catalogName, tableIdent); + // Or + sql("CALL %s.system.rewrite_data_files(table => '%s'," + + " where => 'c1 = 3 or c1 = 5')", catalogName, tableIdent); + // Not + sql("CALL %s.system.rewrite_data_files(table => '%s'," + + " where => 'c1 not in (1,2)')", catalogName, tableIdent); + // StringStartsWith + sql("CALL %s.system.rewrite_data_files(table => '%s'," + + " where => 'c2 like \"%s\"')", catalogName, tableIdent, "car%"); + + // TODO: Enable when org.apache.iceberg.spark.SparkFilters have implementations for StringEndsWith & StringContains + // StringEndsWith + // sql("CALL %s.system.rewrite_data_files(table => '%s'," + + // " where => 'c2 like \"%s\"')", catalogName, tableIdent, "%car"); + // StringContains + // sql("CALL %s.system.rewrite_data_files(table => '%s'," + + // " where => 'c2 like \"%s\"')", catalogName, tableIdent, "%car%"); + } + + @Test + public void testRewriteDataFilesWithInvalidInputs() { + createTable(); + // create 2 files under non-partitioned table + insertData(2); + + // Test for invalid strategy + AssertHelpers.assertThrows("Should reject calls with unsupported strategy error message", + IllegalArgumentException.class, "unsupported strategy: temp. 
Only binpack,sort is supported", + () -> sql("CALL %s.system.rewrite_data_files(table => '%s', options => map('min-input-files','2'), " + + "strategy => 'temp')", catalogName, tableIdent)); + + // Test for sort_order with binpack strategy + AssertHelpers.assertThrows("Should reject calls with error message", + IllegalArgumentException.class, "Cannot set strategy to sort, it has already been set", + () -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'binpack', " + + "sort_order => 'c1 ASC NULLS FIRST')", catalogName, tableIdent)); + + // Test for sort_order with invalid null order + AssertHelpers.assertThrows("Should reject calls with error message", + IllegalArgumentException.class, "Unable to parse sortOrder:", + () -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " + + "sort_order => 'c1 ASC none')", catalogName, tableIdent)); + + // Test for sort_order with invalid sort direction + AssertHelpers.assertThrows("Should reject calls with error message", + IllegalArgumentException.class, "Unable to parse sortOrder:", + () -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " + + "sort_order => 'c1 none NULLS FIRST')", catalogName, tableIdent)); + + // Test for sort_order with invalid column name + AssertHelpers.assertThrows("Should reject calls with error message", + ValidationException.class, "Cannot find field 'col1' in struct:" + + " struct<1: c1: optional int, 2: c2: optional string, 3: c3: optional string>", + () -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " + + "sort_order => 'col1 DESC NULLS FIRST')", catalogName, tableIdent)); + + // Test for the where option with an invalid filter column col1 + AssertHelpers.assertThrows("Should reject calls with error message", + IllegalArgumentException.class, "Cannot parse predicates in where option: col1 = 3", + () -> sql("CALL %s.system.rewrite_data_files(table => '%s', " + + "where => 'col1 = 3')", catalogName, tableIdent)); + } + + @Test + public void testInvalidCasesForRewriteDataFiles() { + AssertHelpers.assertThrows("Should not allow mixed args", + AnalysisException.class, "Named and positional arguments cannot be mixed", + () -> sql("CALL %s.system.rewrite_data_files('n', table => 't')", catalogName)); + + AssertHelpers.assertThrows("Should not resolve procedures in arbitrary namespaces", + NoSuchProcedureException.class, "not found", + () -> sql("CALL %s.custom.rewrite_data_files('n', 't')", catalogName)); + + AssertHelpers.assertThrows("Should reject calls without all required args", + AnalysisException.class, "Missing required parameters", + () -> sql("CALL %s.system.rewrite_data_files()", catalogName)); + + AssertHelpers.assertThrows("Should reject duplicate arg names", + AnalysisException.class, "Duplicate procedure argument: table", + () -> sql("CALL %s.system.rewrite_data_files(table => 't', table => 't')", catalogName)); + + AssertHelpers.assertThrows("Should reject calls with empty table identifier", + IllegalArgumentException.class, "Cannot handle an empty identifier", + () -> sql("CALL %s.system.rewrite_data_files('')", catalogName)); + } + + private void createTable() { + sql("CREATE TABLE %s (c1 int, c2 string, c3 string) USING iceberg", tableName); + } + + private void createPartitionTable() { + sql("CREATE TABLE %s (c1 int, c2 string, c3 string) USING iceberg PARTITIONED BY (c2)", tableName); + } + + private void insertData(int filesCount) { + ThreeColumnRecord record1 = new ThreeColumnRecord(1, "foo", null); +
ThreeColumnRecord record2 = new ThreeColumnRecord(2, "bar", null); + + List records = Lists.newArrayList(); + IntStream.range(0, filesCount / 2).forEach(i -> { + records.add(record1); + records.add(record2); + }); + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).repartition(filesCount); + try { + df.writeTo(tableName).append(); + } catch (org.apache.spark.sql.catalyst.analysis.NoSuchTableException e) { + throw new RuntimeException(e); + } + } + + private List currentData() { + return rowsToJava(spark.sql("SELECT * FROM " + tableName + " order by c1, c2, c3").collectAsList()); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteManifestsProcedure.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteManifestsProcedure.java new file mode 100644 index 000000000000..dcf0a2d91e3e --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteManifestsProcedure.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.extensions; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +import static org.apache.iceberg.TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED; + +public class TestRewriteManifestsProcedure extends SparkExtensionsTestBase { + + public TestRewriteManifestsProcedure(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testRewriteManifestsInEmptyTable() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + List output = sql( + "CALL %s.system.rewrite_manifests('%s')", catalogName, tableIdent); + assertEquals("Procedure output must match", + ImmutableList.of(row(0, 0)), + output); + } + + @Test + public void testRewriteLargeManifests() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg PARTITIONED BY (data)", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a'), (2, 'b'), (3, 'c'), (4, 'd')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals("Must have 1 manifest", 1, table.currentSnapshot().allManifests(table.io()).size()); + + sql("ALTER TABLE %s SET TBLPROPERTIES ('commit.manifest.target-size-bytes' '1')", tableName); + + List output = sql( + "CALL %s.system.rewrite_manifests('%s')", catalogName, tableIdent); + assertEquals("Procedure output must match", + ImmutableList.of(row(1, 4)), + output); + + table.refresh(); + + Assert.assertEquals("Must have 4 manifests", 4, table.currentSnapshot().allManifests(table.io()).size()); + } + + @Test + public void testRewriteSmallManifestsWithSnapshotIdInheritance() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg PARTITIONED BY (data)", tableName); + + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s')", tableName, SNAPSHOT_ID_INHERITANCE_ENABLED, "true"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + sql("INSERT INTO TABLE %s VALUES (3, 'c')", tableName); + sql("INSERT INTO TABLE %s VALUES (4, 'd')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals("Must have 4 manifest", 4, table.currentSnapshot().allManifests(table.io()).size()); + + List output = sql( + "CALL %s.system.rewrite_manifests(table => '%s')", catalogName, tableIdent); + assertEquals("Procedure output must match", + ImmutableList.of(row(4, 1)), + output); + + table.refresh(); + + Assert.assertEquals("Must have 1 manifests", 1, table.currentSnapshot().allManifests(table.io()).size()); + } + + @Test + public void testRewriteSmallManifestsWithoutCaching() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg PARTITIONED BY (data)", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals("Must have 2 manifest", 2, table.currentSnapshot().allManifests(table.io()).size()); + + List output = sql( + "CALL 
%s.system.rewrite_manifests(use_caching => false, table => '%s')", catalogName, tableIdent); + assertEquals("Procedure output must match", + ImmutableList.of(row(2, 1)), + output); + + table.refresh(); + + Assert.assertEquals("Must have 1 manifests", 1, table.currentSnapshot().allManifests(table.io()).size()); + } + + @Test + public void testRewriteManifestsCaseInsensitiveArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg PARTITIONED BY (data)", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals("Must have 2 manifest", 2, table.currentSnapshot().allManifests(table.io()).size()); + + List output = sql( + "CALL %s.system.rewrite_manifests(usE_cAcHiNg => false, tAbLe => '%s')", catalogName, tableIdent); + assertEquals("Procedure output must match", + ImmutableList.of(row(2, 1)), + output); + + table.refresh(); + + Assert.assertEquals("Must have 1 manifests", 1, table.currentSnapshot().allManifests(table.io()).size()); + } + + @Test + public void testInvalidRewriteManifestsCases() { + AssertHelpers.assertThrows("Should not allow mixed args", + AnalysisException.class, "Named and positional arguments cannot be mixed", + () -> sql("CALL %s.system.rewrite_manifests('n', table => 't')", catalogName)); + + AssertHelpers.assertThrows("Should not resolve procedures in arbitrary namespaces", + NoSuchProcedureException.class, "not found", + () -> sql("CALL %s.custom.rewrite_manifests('n', 't')", catalogName)); + + AssertHelpers.assertThrows("Should reject calls without all required args", + AnalysisException.class, "Missing required parameters", + () -> sql("CALL %s.system.rewrite_manifests()", catalogName)); + + AssertHelpers.assertThrows("Should reject calls with invalid arg types", + AnalysisException.class, "Wrong arg type", + () -> sql("CALL %s.system.rewrite_manifests('n', 2.2)", catalogName)); + + AssertHelpers.assertThrows("Should reject duplicate arg names name", + AnalysisException.class, "Duplicate procedure argument: table", + () -> sql("CALL %s.system.rewrite_manifests(table => 't', tAbLe => 't')", catalogName)); + + AssertHelpers.assertThrows("Should reject calls with empty table identifier", + IllegalArgumentException.class, "Cannot handle an empty identifier", + () -> sql("CALL %s.system.rewrite_manifests('')", catalogName)); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToSnapshotProcedure.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToSnapshotProcedure.java new file mode 100644 index 000000000000..d3e6bdcbc285 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToSnapshotProcedure.java @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.extensions; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.junit.After; +import org.junit.Assume; +import org.junit.Test; + +public class TestRollbackToSnapshotProcedure extends SparkExtensionsTestBase { + + public TestRollbackToSnapshotProcedure(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testRollbackToSnapshotUsingPositionalArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = sql( + "CALL %s.system.rollback_to_snapshot('%s', %dL)", + catalogName, tableIdent, firstSnapshot.snapshotId()); + + assertEquals("Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals("Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testRollbackToSnapshotUsingNamedArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = sql( + "CALL %s.system.rollback_to_snapshot(snapshot_id => %dL, table => '%s')", + catalogName, firstSnapshot.snapshotId(), tableIdent); + + assertEquals("Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals("Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testRollbackToSnapshotRefreshesRelationCache() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING 
iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + Dataset query = spark.sql("SELECT * FROM " + tableName + " WHERE id = 1"); + query.createOrReplaceTempView("tmp"); + + spark.sql("CACHE TABLE tmp"); + + assertEquals("View should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM tmp")); + + List output = sql( + "CALL %s.system.rollback_to_snapshot(table => '%s', snapshot_id => %dL)", + catalogName, tableIdent, firstSnapshot.snapshotId()); + + assertEquals("Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals("View cache must be invalidated", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM tmp")); + + sql("UNCACHE TABLE tmp"); + } + + @Test + public void testRollbackToSnapshotWithQuotedIdentifiers() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + StringBuilder quotedNamespaceBuilder = new StringBuilder(); + for (String level : tableIdent.namespace().levels()) { + quotedNamespaceBuilder.append("`"); + quotedNamespaceBuilder.append(level); + quotedNamespaceBuilder.append("`"); + } + String quotedNamespace = quotedNamespaceBuilder.toString(); + + List output = sql( + "CALL %s.system.rollback_to_snapshot('%s', %d)", + catalogName, quotedNamespace + ".`" + tableIdent.name() + "`", firstSnapshot.snapshotId()); + + assertEquals("Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals("Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testRollbackToSnapshotWithoutExplicitCatalog() { + Assume.assumeTrue("Working only with the session catalog", "spark_catalog".equals(catalogName)); + + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + // use camel case intentionally to test case sensitivity + List output = sql( + "CALL SyStEm.rOLlBaCk_to_SnApShOt('%s', %dL)", + tableIdent, firstSnapshot.snapshotId()); + + assertEquals("Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals("Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", 
tableName)); + } + + @Test + public void testRollbackToInvalidSnapshot() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + AssertHelpers.assertThrows("Should reject invalid snapshot id", + ValidationException.class, "Cannot roll back to unknown snapshot id", + () -> sql("CALL %s.system.rollback_to_snapshot('%s', -1L)", catalogName, tableIdent)); + } + + @Test + public void testInvalidRollbackToSnapshotCases() { + AssertHelpers.assertThrows("Should not allow mixed args", + AnalysisException.class, "Named and positional arguments cannot be mixed", + () -> sql("CALL %s.system.rollback_to_snapshot(namespace => 'n1', table => 't', 1L)", catalogName)); + + AssertHelpers.assertThrows("Should not resolve procedures in arbitrary namespaces", + NoSuchProcedureException.class, "not found", + () -> sql("CALL %s.custom.rollback_to_snapshot('n', 't', 1L)", catalogName)); + + AssertHelpers.assertThrows("Should reject calls without all required args", + AnalysisException.class, "Missing required parameters", + () -> sql("CALL %s.system.rollback_to_snapshot('t')", catalogName)); + + AssertHelpers.assertThrows("Should reject calls without all required args", + AnalysisException.class, "Missing required parameters", + () -> sql("CALL %s.system.rollback_to_snapshot(1L)", catalogName)); + + AssertHelpers.assertThrows("Should reject calls without all required args", + AnalysisException.class, "Missing required parameters", + () -> sql("CALL %s.system.rollback_to_snapshot(table => 't')", catalogName)); + + AssertHelpers.assertThrows("Should reject calls with invalid arg types", + AnalysisException.class, "Wrong arg type for snapshot_id: cannot cast", + () -> sql("CALL %s.system.rollback_to_snapshot('t', 2.2)", catalogName)); + + AssertHelpers.assertThrows("Should reject calls with empty table identifier", + IllegalArgumentException.class, "Cannot handle an empty identifier", + () -> sql("CALL %s.system.rollback_to_snapshot('', 1L)", catalogName)); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToTimestampProcedure.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToTimestampProcedure.java new file mode 100644 index 000000000000..52fc12c7d01e --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToTimestampProcedure.java @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.extensions; + +import java.time.LocalDateTime; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.junit.After; +import org.junit.Assume; +import org.junit.Test; + +public class TestRollbackToTimestampProcedure extends SparkExtensionsTestBase { + + public TestRollbackToTimestampProcedure(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testRollbackToTimestampUsingPositionalArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + String firstSnapshotTimestamp = LocalDateTime.now().toString(); + + waitUntilAfter(firstSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = sql( + "CALL %s.system.rollback_to_timestamp('%s',TIMESTAMP '%s')", + catalogName, tableIdent, firstSnapshotTimestamp); + + assertEquals("Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals("Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testRollbackToTimestampUsingNamedArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + String firstSnapshotTimestamp = LocalDateTime.now().toString(); + + waitUntilAfter(firstSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = sql( + "CALL %s.system.rollback_to_timestamp(timestamp => TIMESTAMP '%s', table => '%s')", + catalogName, firstSnapshotTimestamp, tableIdent); + + assertEquals("Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals("Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testRollbackToTimestampRefreshesRelationCache() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + 
String firstSnapshotTimestamp = LocalDateTime.now().toString(); + + waitUntilAfter(firstSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + Dataset query = spark.sql("SELECT * FROM " + tableName + " WHERE id = 1"); + query.createOrReplaceTempView("tmp"); + + spark.sql("CACHE TABLE tmp"); + + assertEquals("View should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM tmp")); + + List output = sql( + "CALL %s.system.rollback_to_timestamp(table => '%s', timestamp => TIMESTAMP '%s')", + catalogName, tableIdent, firstSnapshotTimestamp); + + assertEquals("Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals("View cache must be invalidated", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM tmp")); + + sql("UNCACHE TABLE tmp"); + } + + @Test + public void testRollbackToTimestampWithQuotedIdentifiers() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + String firstSnapshotTimestamp = LocalDateTime.now().toString(); + + waitUntilAfter(firstSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + StringBuilder quotedNamespaceBuilder = new StringBuilder(); + for (String level : tableIdent.namespace().levels()) { + quotedNamespaceBuilder.append("`"); + quotedNamespaceBuilder.append(level); + quotedNamespaceBuilder.append("`"); + } + String quotedNamespace = quotedNamespaceBuilder.toString(); + + List output = sql( + "CALL %s.system.rollback_to_timestamp('%s', TIMESTAMP '%s')", + catalogName, quotedNamespace + ".`" + tableIdent.name() + "`", firstSnapshotTimestamp); + + assertEquals("Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals("Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testRollbackToTimestampWithoutExplicitCatalog() { + Assume.assumeTrue("Working only with the session catalog", "spark_catalog".equals(catalogName)); + + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + String firstSnapshotTimestamp = LocalDateTime.now().toString(); + + waitUntilAfter(firstSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + // use camel case intentionally to test case sensitivity + List output = sql( + "CALL SyStEm.rOLlBaCk_to_TiMeStaMp('%s', TIMESTAMP '%s')", + tableIdent, firstSnapshotTimestamp); + + assertEquals("Procedure output must match", + 
ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals("Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testInvalidRollbackToTimestampCases() { + String timestamp = "TIMESTAMP '2007-12-03T10:15:30'"; + + AssertHelpers.assertThrows("Should not allow mixed args", + AnalysisException.class, "Named and positional arguments cannot be mixed", + () -> sql("CALL %s.system.rollback_to_timestamp(namespace => 'n1', 't', %s)", catalogName, timestamp)); + + AssertHelpers.assertThrows("Should not resolve procedures in arbitrary namespaces", + NoSuchProcedureException.class, "not found", + () -> sql("CALL %s.custom.rollback_to_timestamp('n', 't', %s)", catalogName, timestamp)); + + AssertHelpers.assertThrows("Should reject calls without all required args", + AnalysisException.class, "Missing required parameters", + () -> sql("CALL %s.system.rollback_to_timestamp('t')", catalogName)); + + AssertHelpers.assertThrows("Should reject calls without all required args", + AnalysisException.class, "Missing required parameters", + () -> sql("CALL %s.system.rollback_to_timestamp(timestamp => %s)", catalogName, timestamp)); + + AssertHelpers.assertThrows("Should reject calls without all required args", + AnalysisException.class, "Missing required parameters", + () -> sql("CALL %s.system.rollback_to_timestamp(table => 't')", catalogName)); + + AssertHelpers.assertThrows("Should reject calls with extra args", + AnalysisException.class, "Too many arguments", + () -> sql("CALL %s.system.rollback_to_timestamp('n', 't', %s, 1L)", catalogName, timestamp)); + + AssertHelpers.assertThrows("Should reject calls with invalid arg types", + AnalysisException.class, "Wrong arg type for timestamp: cannot cast", + () -> sql("CALL %s.system.rollback_to_timestamp('t', 2.2)", catalogName)); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java new file mode 100644 index 000000000000..0ea8c4861e8c --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.extensions; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.junit.After; +import org.junit.Assume; +import org.junit.Test; + +import static org.apache.iceberg.TableProperties.WRITE_AUDIT_PUBLISH_ENABLED; + +public class TestSetCurrentSnapshotProcedure extends SparkExtensionsTestBase { + + public TestSetCurrentSnapshotProcedure(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testSetCurrentSnapshotUsingPositionalArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = sql( + "CALL %s.system.set_current_snapshot('%s', %dL)", + catalogName, tableIdent, firstSnapshot.snapshotId()); + + assertEquals("Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals("Set must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testSetCurrentSnapshotUsingNamedArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = sql( + "CALL %s.system.set_current_snapshot(snapshot_id => %dL, table => '%s')", + catalogName, firstSnapshot.snapshotId(), tableIdent); + + assertEquals("Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals("Set must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testSetCurrentSnapshotWap() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED); + + spark.conf().set("spark.wap.id", "1"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should not see rows from staged snapshot", + ImmutableList.of(), + sql("SELECT * FROM 
%s", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot wapSnapshot = Iterables.getOnlyElement(table.snapshots()); + + List output = sql( + "CALL %s.system.set_current_snapshot(table => '%s', snapshot_id => %dL)", + catalogName, tableIdent, wapSnapshot.snapshotId()); + + assertEquals("Procedure output must match", + ImmutableList.of(row(null, wapSnapshot.snapshotId())), + output); + + assertEquals("Current snapshot must be set correctly", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", tableName)); + } + + @Test + public void tesSetCurrentSnapshotWithoutExplicitCatalog() { + Assume.assumeTrue("Working only with the session catalog", "spark_catalog".equals(catalogName)); + + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + // use camel case intentionally to test case sensitivity + List output = sql( + "CALL SyStEm.sEt_cuRrEnT_sNaPsHot('%s', %dL)", + tableIdent, firstSnapshot.snapshotId()); + + assertEquals("Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals("Set must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testSetCurrentSnapshotToInvalidSnapshot() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + Namespace namespace = tableIdent.namespace(); + String tableName = tableIdent.name(); + + AssertHelpers.assertThrows("Should reject invalid snapshot id", + ValidationException.class, "Cannot roll back to unknown snapshot id", + () -> sql("CALL %s.system.set_current_snapshot('%s', -1L)", catalogName, tableIdent)); + } + + @Test + public void testInvalidRollbackToSnapshotCases() { + AssertHelpers.assertThrows("Should not allow mixed args", + AnalysisException.class, "Named and positional arguments cannot be mixed", + () -> sql("CALL %s.system.set_current_snapshot(namespace => 'n1', table => 't', 1L)", catalogName)); + + AssertHelpers.assertThrows("Should not resolve procedures in arbitrary namespaces", + NoSuchProcedureException.class, "not found", + () -> sql("CALL %s.custom.set_current_snapshot('n', 't', 1L)", catalogName)); + + AssertHelpers.assertThrows("Should reject calls without all required args", + AnalysisException.class, "Missing required parameters", + () -> sql("CALL %s.system.set_current_snapshot('t')", catalogName)); + + AssertHelpers.assertThrows("Should reject calls without all required args", + AnalysisException.class, "Missing required parameters", + () -> sql("CALL %s.system.set_current_snapshot(1L)", catalogName)); + + AssertHelpers.assertThrows("Should reject calls without all required args", + AnalysisException.class, "Missing required parameters", + () -> sql("CALL %s.system.set_current_snapshot(snapshot_id => 1L)", catalogName)); + + AssertHelpers.assertThrows("Should reject calls without all required args", + AnalysisException.class, "Missing required parameters", + () -> sql("CALL %s.system.set_current_snapshot(table => 't')", 
catalogName)); + + AssertHelpers.assertThrows("Should reject calls with invalid arg types", + AnalysisException.class, "Wrong arg type for snapshot_id: cannot cast", + () -> sql("CALL %s.system.set_current_snapshot('t', 2.2)", catalogName)); + + AssertHelpers.assertThrows("Should reject calls with empty table identifier", + IllegalArgumentException.class, "Cannot handle an empty identifier", + () -> sql("CALL %s.system.set_current_snapshot('', 1L)", catalogName)); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetWriteDistributionAndOrdering.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetWriteDistributionAndOrdering.java new file mode 100644 index 000000000000..473278d25068 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetWriteDistributionAndOrdering.java @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.extensions; + +import java.util.Map; +import org.apache.iceberg.NullOrder; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +import static org.apache.iceberg.expressions.Expressions.bucket; + +public class TestSetWriteDistributionAndOrdering extends SparkExtensionsTestBase { + public TestSetWriteDistributionAndOrdering(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testSetWriteOrderByColumn() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE ORDERED BY category, id", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "range", distributionMode); + + SortOrder expected = SortOrder.builderFor(table.schema()) + .withOrderId(1) + .asc("category", NullOrder.NULLS_FIRST) + .asc("id", NullOrder.NULLS_FIRST) + .build(); + Assert.assertEquals("Should have expected order", expected, table.sortOrder()); + } + + @Test + public void testSetWriteOrderByColumnWithDirection() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + 
Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE ORDERED BY category ASC, id DESC", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "range", distributionMode); + + SortOrder expected = SortOrder.builderFor(table.schema()) + .withOrderId(1) + .asc("category", NullOrder.NULLS_FIRST) + .desc("id", NullOrder.NULLS_LAST) + .build(); + Assert.assertEquals("Should have expected order", expected, table.sortOrder()); + } + + @Test + public void testSetWriteOrderByColumnWithDirectionAndNullOrder() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE ORDERED BY category ASC NULLS LAST, id DESC NULLS FIRST", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "range", distributionMode); + + SortOrder expected = SortOrder.builderFor(table.schema()) + .withOrderId(1) + .asc("category", NullOrder.NULLS_LAST) + .desc("id", NullOrder.NULLS_FIRST) + .build(); + Assert.assertEquals("Should have expected order", expected, table.sortOrder()); + } + + @Test + public void testSetWriteOrderByTransform() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE ORDERED BY category DESC, bucket(16, id), id", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "range", distributionMode); + + SortOrder expected = SortOrder.builderFor(table.schema()) + .withOrderId(1) + .desc("category") + .asc(bucket("id", 16)) + .asc("id") + .build(); + Assert.assertEquals("Should have expected order", expected, table.sortOrder()); + } + + @Test + public void testSetWriteUnordered() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE ORDERED BY category DESC, bucket(16, id), id", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "range", distributionMode); + + Assert.assertNotEquals("Table must be sorted", SortOrder.unsorted(), table.sortOrder()); + + sql("ALTER TABLE %s WRITE UNORDERED", tableName); + + table.refresh(); + + String newDistributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("New distribution mode must match", "none", newDistributionMode); + + Assert.assertEquals("New sort order must match", SortOrder.unsorted(), table.sortOrder()); + } + + @Test + public void testSetWriteLocallyOrdered() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", 
tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE LOCALLY ORDERED BY category DESC, bucket(16, id), id", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "none", distributionMode); + + SortOrder expected = SortOrder.builderFor(table.schema()) + .withOrderId(1) + .desc("category") + .asc(bucket("id", 16)) + .asc("id") + .build(); + Assert.assertEquals("Sort order must match", expected, table.sortOrder()); + } + + @Test + public void testSetWriteDistributedByWithSort() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION ORDERED BY id", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "hash", distributionMode); + + SortOrder expected = SortOrder.builderFor(table.schema()) + .withOrderId(1) + .asc("id") + .build(); + Assert.assertEquals("Sort order must match", expected, table.sortOrder()); + } + + @Test + public void testSetWriteDistributedByWithLocalSort() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION LOCALLY ORDERED BY id", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "hash", distributionMode); + + SortOrder expected = SortOrder.builderFor(table.schema()) + .withOrderId(1) + .asc("id") + .build(); + Assert.assertEquals("Sort order must match", expected, table.sortOrder()); + } + + @Test + public void testSetWriteDistributedByAndUnordered() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION UNORDERED", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "hash", distributionMode); + + Assert.assertEquals("Sort order must match", SortOrder.unsorted(), table.sortOrder()); + } + + @Test + public void testSetWriteDistributedByOnly() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION UNORDERED", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "hash", distributionMode); + + 
Assert.assertEquals("Sort order must match", SortOrder.unsorted(), table.sortOrder()); + } + + @Test + public void testSetWriteDistributedAndUnorderedInverted() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE UNORDERED DISTRIBUTED BY PARTITION", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "hash", distributionMode); + + Assert.assertEquals("Sort order must match", SortOrder.unsorted(), table.sortOrder()); + } + + @Test + public void testSetWriteDistributedAndLocallyOrderedInverted() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE ORDERED BY id DISTRIBUTED BY PARTITION", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "hash", distributionMode); + + SortOrder expected = SortOrder.builderFor(table.schema()) + .withOrderId(1) + .asc("id") + .build(); + Assert.assertEquals("Sort order must match", expected, table.sortOrder()); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSnapshotTableProcedure.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSnapshotTableProcedure.java new file mode 100644 index 000000000000..66fa8e80c515 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSnapshotTableProcedure.java @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.extensions; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.spark.sql.AnalysisException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestSnapshotTableProcedure extends SparkExtensionsTestBase { + private static final String sourceName = "spark_catalog.default.source"; + // Currently we can only snapshot tables out of the Spark session catalog + + public TestSnapshotTableProcedure(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s", sourceName); + } + + @Test + public void testSnapshot() throws IOException { + String location = temp.newFolder().toString(); + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", sourceName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName); + Object result = scalarSql("CALL %s.system.snapshot('%s', '%s')", catalogName, sourceName, tableName); + + Assert.assertEquals("Should have added one file", 1L, result); + + Table createdTable = validationCatalog.loadTable(tableIdent); + String tableLocation = createdTable.location(); + Assert.assertNotEquals("Table should not have the original location", location, tableLocation); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testSnapshotWithProperties() throws IOException { + String location = temp.newFolder().toString(); + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", sourceName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName); + Object result = scalarSql( + "CALL %s.system.snapshot(source_table => '%s', table => '%s', properties => map('foo','bar'))", + catalogName, sourceName, tableName); + + Assert.assertEquals("Should have added one file", 1L, result); + + Table createdTable = validationCatalog.loadTable(tableIdent); + + String tableLocation = createdTable.location(); + Assert.assertNotEquals("Table should not have the original location", location, tableLocation); + + Map props = createdTable.properties(); + Assert.assertEquals("Should have extra property set", "bar", props.get("foo")); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testSnapshotWithAlternateLocation() throws IOException { + Assume.assumeTrue("No snapshotting with alternate locations with Hadoop catalogs", !catalogName.contains("hadoop")); + String location = temp.newFolder().toString(); + String snapshotLocation = temp.newFolder().toString(); + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", sourceName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName); + Object[] result = sql( + "CALL 
%s.system.snapshot(source_table => '%s', table => '%s', location => '%s')", + catalogName, sourceName, tableName, snapshotLocation).get(0); + + Assert.assertEquals("Should have added one file", 1L, result[0]); + + String storageLocation = validationCatalog.loadTable(tableIdent).location(); + Assert.assertEquals("Snapshot should be made at specified location", snapshotLocation, storageLocation); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testDropTable() throws IOException { + String location = temp.newFolder().toString(); + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", sourceName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName); + + Object result = scalarSql("CALL %s.system.snapshot('%s', '%s')", catalogName, sourceName, tableName); + Assert.assertEquals("Should have added one file", 1L, result); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", tableName)); + + sql("DROP TABLE %s", tableName); + + assertEquals("Source table should be intact", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", sourceName)); + } + + @Test + public void testSnapshotWithConflictingProps() throws IOException { + String location = temp.newFolder().toString(); + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", sourceName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName); + + Object result = scalarSql( + "CALL %s.system.snapshot(" + + "source_table => '%s'," + + "table => '%s'," + + "properties => map('%s', 'true', 'snapshot', 'false'))", + catalogName, sourceName, tableName, TableProperties.GC_ENABLED); + Assert.assertEquals("Should have added one file", 1L, result); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Map props = table.properties(); + Assert.assertEquals("Should override user value", "true", props.get("snapshot")); + Assert.assertEquals("Should override user value", "false", props.get(TableProperties.GC_ENABLED)); + } + + @Test + public void testInvalidSnapshotsCases() throws IOException { + String location = temp.newFolder().toString(); + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", sourceName, location); + + AssertHelpers.assertThrows("Should reject calls without all required args", + AnalysisException.class, "Missing required parameters", + () -> sql("CALL %s.system.snapshot('foo')", catalogName)); + + AssertHelpers.assertThrows("Should reject calls with invalid arg types", + AnalysisException.class, "Wrong arg type", + () -> sql("CALL %s.system.snapshot('n', 't', map('foo', 'bar'))", catalogName)); + + AssertHelpers.assertThrows("Should reject calls with invalid map args", + AnalysisException.class, "cannot resolve 'map", + () -> sql("CALL %s.system.snapshot('%s', 'fable', 'loc', map(2, 1, 1))", catalogName, sourceName)); + + AssertHelpers.assertThrows("Should reject calls with empty table identifier", + IllegalArgumentException.class, "Cannot handle an empty identifier", + () -> sql("CALL %s.system.snapshot('', 'dest')", catalogName)); + + AssertHelpers.assertThrows("Should reject calls with empty table identifier", + IllegalArgumentException.class, "Cannot 
handle an empty identifier", + () -> sql("CALL %s.system.snapshot('src', '')", catalogName)); + } +} diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java new file mode 100644 index 000000000000..edc3944e69fe --- /dev/null +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java @@ -0,0 +1,1001 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.extensions; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.spark.SparkException; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.internal.SQLConf; +import org.hamcrest.CoreMatchers; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.apache.iceberg.DataOperations.OVERWRITE; +import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE; +import static org.apache.iceberg.SnapshotSummary.ADDED_FILES_PROP; +import static org.apache.iceberg.SnapshotSummary.CHANGED_PARTITION_COUNT_PROP; +import static org.apache.iceberg.SnapshotSummary.DELETED_FILES_PROP; +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; +import static org.apache.iceberg.TableProperties.SPLIT_SIZE; +import static 
org.apache.iceberg.TableProperties.UPDATE_ISOLATION_LEVEL; +import static org.apache.iceberg.TableProperties.UPDATE_MODE; +import static org.apache.iceberg.TableProperties.UPDATE_MODE_DEFAULT; +import static org.apache.spark.sql.functions.lit; + +public abstract class TestUpdate extends SparkRowLevelOperationsTestBase { + + public TestUpdate(String catalogName, String implementation, Map<String, String> config, + String fileFormat, boolean vectorized, String distributionMode) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode); + } + + @BeforeClass + public static void setupSparkConf() { + spark.conf().set("spark.sql.shuffle.partitions", "4"); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS updated_id"); + sql("DROP TABLE IF EXISTS updated_dep"); + sql("DROP TABLE IF EXISTS deleted_employee"); + } + + @Test + public void testExplain() { + createAndInitTable("id INT, dep STRING"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + + sql("EXPLAIN UPDATE %s SET dep = 'invalid' WHERE id <=> 1", tableName); + + sql("EXPLAIN UPDATE %s SET dep = 'invalid' WHERE true", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 1 snapshot", 1, Iterables.size(table.snapshots())); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + } + + @Test + public void testUpdateEmptyTable() { + createAndInitTable("id INT, dep STRING"); + + sql("UPDATE %s SET dep = 'invalid' WHERE id IN (1)", tableName); + sql("UPDATE %s SET id = -1 WHERE dep = 'hr'", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); + + assertEquals("Should have expected rows", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testUpdateWithAlias() { + createAndInitTable("id INT, dep STRING", "{ \"id\": 1, \"dep\": \"a\" }"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql("UPDATE %s AS t SET t.dep = 'invalid'", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "invalid")), + sql("SELECT * FROM %s", tableName)); + } + + @Test + public void testUpdateAlignsAssignments() { + createAndInitTable("id INT, c1 INT, c2 INT"); + + sql("INSERT INTO TABLE %s VALUES (1, 11, 111), (2, 22, 222)", tableName); + + sql("UPDATE %s SET `c2` = c2 - 2, c1 = `c1` - 1 WHERE id <=> 1", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1, 10, 109), row(2, 22, 222)), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testUpdateWithUnsupportedPartitionPredicate() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'software'), (2, 'hr')", tableName); + + sql("UPDATE %s t SET `t`.`id` = -1 WHERE t.dep LIKE '%%r' ", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(1, "software")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testUpdateWithDynamicFileFiltering() { +
createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + append(tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 3, \"dep\": \"hr\" }"); + append(tableName, + "{ \"id\": 1, \"dep\": \"hardware\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }"); + + sql("UPDATE %s SET id = cast('-1' AS INT) WHERE id = 2", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = table.currentSnapshot(); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "1", "1", "1"); + } else { + validateMergeOnRead(currentSnapshot, "1", "1", "1"); + } + + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(1, "hardware"), row(1, "hr"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + } + + @Test + public void testUpdateNonExistingRecords() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + + sql("UPDATE %s SET id = -1 WHERE id > 10", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = table.currentSnapshot(); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "0", null, null); + } else { + validateMergeOnRead(currentSnapshot, "0", null, null); + } + + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + } + + @Test + public void testUpdateWithoutCondition() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'hardware')", tableName); + sql("INSERT INTO TABLE %s VALUES (null, 'hr')", tableName); + + // set the num of shuffle partitions to 200 instead of default 4 to reduce the chance of hashing + // records for multiple source files to one writing task (needed for a predictable num of output files) + withSQLConf(ImmutableMap.of(SQLConf.SHUFFLE_PARTITIONS().key(), "200"), () -> { + sql("UPDATE %s SET id = -1", tableName); + }); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 4 snapshots", 4, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = table.currentSnapshot(); + + Assert.assertEquals("Operation must match", OVERWRITE, currentSnapshot.operation()); + if (mode(table) == COPY_ON_WRITE) { + Assert.assertEquals("Operation must match", OVERWRITE, currentSnapshot.operation()); + validateProperty(currentSnapshot, CHANGED_PARTITION_COUNT_PROP, "2"); + validateProperty(currentSnapshot, DELETED_FILES_PROP, "3"); + validateProperty(currentSnapshot, ADDED_FILES_PROP, ImmutableSet.of("2", "3")); + } else { + validateMergeOnRead(currentSnapshot, "2", "2", "2"); + } + + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(-1, "hr")), + sql("SELECT * FROM %s ORDER BY dep ASC", tableName)); + } + + @Test + public void testUpdateWithNullConditions() { + createAndInitTable("id INT, dep STRING"); + + 
append(tableName, + "{ \"id\": 0, \"dep\": null }\n" + + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }"); + + // should not update any rows as null is never equal to null + sql("UPDATE %s SET id = -1 WHERE dep = NULL", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(0, null), row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + // should not update any rows as the condition does not match any records + sql("UPDATE %s SET id = -1 WHERE dep = 'software'", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(0, null), row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + // should update one matching row with a null-safe condition + sql("UPDATE %s SET dep = 'invalid', id = -1 WHERE dep <=> NULL", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "invalid"), row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testUpdateWithInAndNotInConditions() { + createAndInitTable("id INT, dep STRING"); + + append(tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + + sql("UPDATE %s SET id = -1 WHERE id IN (1, null)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + sql("UPDATE %s SET id = 100 WHERE id NOT IN (null, 1)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + sql("UPDATE %s SET id = 100 WHERE id NOT IN (1, 10)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(100, "hardware"), row(100, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", tableName)); + } + + @Test + public void testUpdateWithMultipleRowGroupsParquet() throws NoSuchTableException { + Assume.assumeTrue(fileFormat.equalsIgnoreCase("parquet")); + + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", tableName, PARQUET_ROW_GROUP_SIZE_BYTES, 100); + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", tableName, SPLIT_SIZE, 100); + + List<Integer> ids = Lists.newArrayListWithCapacity(200); + for (int id = 1; id <= 200; id++) { + ids.add(id); + } + Dataset<Row> df = spark.createDataset(ids, Encoders.INT()) + .withColumnRenamed("value", "id") + .withColumn("dep", lit("hr")); + df.coalesce(1).writeTo(tableName).append(); + + Assert.assertEquals(200, spark.table(tableName).count()); + + // update a record from one of two row groups and copy over the second one + sql("UPDATE %s SET id = -1 WHERE id IN (200, 201)", tableName); + + Assert.assertEquals(200, spark.table(tableName).count()); + } + + @Test + public void testUpdateNestedStructFields() { + createAndInitTable("id INT, s STRUCT<c1:INT,c2:STRUCT<a:ARRAY<INT>,m:MAP<STRING, STRING>>>", + "{ \"id\": 1, \"s\": { \"c1\": 2, \"c2\": { \"a\": [1,2], \"m\": { \"a\": \"b\"} } } } }"); + + // update primitive, array, map columns inside a struct + sql("UPDATE %s SET s.c1 = -1, s.c2.m = map('k', 'v'), s.c2.a = array(-1)", tableName); + + assertEquals("Output should match", + ImmutableList.of(row(1, row(-1, row(ImmutableList.of(-1), ImmutableMap.of("k", "v"))))), +
sql("SELECT * FROM %s", tableName)); + + // set primitive, array, map columns to NULL (proper casts should be in place) + sql("UPDATE %s SET s.c1 = NULL, s.c2 = NULL WHERE id IN (1)", tableName); + + assertEquals("Output should match", + ImmutableList.of(row(1, row(null, null))), + sql("SELECT * FROM %s", tableName)); + + // update all fields in a struct + sql("UPDATE %s SET s = named_struct('c1', 1, 'c2', named_struct('a', array(1), 'm', null))", tableName); + + assertEquals("Output should match", + ImmutableList.of(row(1, row(1, row(ImmutableList.of(1), null)))), + sql("SELECT * FROM %s", tableName)); + } + + @Test + public void testUpdateWithUserDefinedDistribution() { + createAndInitTable("id INT, c2 INT, c3 INT"); + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(8, c3)", tableName); + + append(tableName, + "{ \"id\": 1, \"c2\": 11, \"c3\": 1 }\n" + + "{ \"id\": 2, \"c2\": 22, \"c3\": 1 }\n" + + "{ \"id\": 3, \"c2\": 33, \"c3\": 1 }"); + + // request a global sort + sql("ALTER TABLE %s WRITE ORDERED BY c2", tableName); + sql("UPDATE %s SET c2 = -22 WHERE id NOT IN (1, 3)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, 11, 1), row(2, -22, 1), row(3, 33, 1)), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + // request a local sort + sql("ALTER TABLE %s WRITE LOCALLY ORDERED BY id", tableName); + sql("UPDATE %s SET c2 = -33 WHERE id = 3", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, 11, 1), row(2, -22, 1), row(3, -33, 1)), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + // request a hash distribution + local sort + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION ORDERED BY id", tableName); + sql("UPDATE %s SET c2 = -11 WHERE id = 1", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, -11, 1), row(2, -22, 1), row(3, -33, 1)), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + } + + @Test + public synchronized void testUpdateWithSerializableIsolation() throws InterruptedException { + // cannot run tests with concurrency for Hadoop tables without atomic renames + Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop")); + + createAndInitTable("id INT, dep STRING"); + + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, UPDATE_ISOLATION_LEVEL, "serializable"); + + ExecutorService executorService = MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + + // update thread + Future updateFuture = executorService.submit(() -> { + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + sql("UPDATE %s SET id = -1 WHERE id = 1", tableName); + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = executorService.submit(() -> { + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + barrier.incrementAndGet(); + } + }); + + try { + updateFuture.get(); + Assert.fail("Expected a validation exception"); + } catch (ExecutionException e) { + Throwable sparkException = e.getCause(); + Assert.assertThat(sparkException, CoreMatchers.instanceOf(SparkException.class)); + Throwable validationException = sparkException.getCause(); + Assert.assertThat(validationException, 
CoreMatchers.instanceOf(ValidationException.class)); + String errMsg = validationException.getMessage(); + Assert.assertThat(errMsg, CoreMatchers.containsString("Found conflicting files that can contain")); + } finally { + appendFuture.cancel(true); + } + + executorService.shutdown(); + Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES)); + } + + @Test + public synchronized void testUpdateWithSnapshotIsolation() throws InterruptedException, ExecutionException { + // cannot run tests with concurrency for Hadoop tables without atomic renames + Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop")); + + createAndInitTable("id INT, dep STRING"); + + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, UPDATE_ISOLATION_LEVEL, "snapshot"); + + ExecutorService executorService = MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + + // update thread + Future<?> updateFuture = executorService.submit(() -> { + for (int numOperations = 0; numOperations < 20; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + sql("UPDATE %s SET id = -1 WHERE id = 1", tableName); + barrier.incrementAndGet(); + } + }); + + // append thread + Future<?> appendFuture = executorService.submit(() -> { + for (int numOperations = 0; numOperations < 20; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + barrier.incrementAndGet(); + } + }); + + try { + updateFuture.get(); + } finally { + appendFuture.cancel(true); + } + + executorService.shutdown(); + Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES)); + } + + @Test + public void testUpdateWithInferredCasts() { + createAndInitTable("id INT, s STRING", "{ \"id\": 1, \"s\": \"value\" }"); + + sql("UPDATE %s SET s = -1 WHERE id = 1", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "-1")), + sql("SELECT * FROM %s", tableName)); + } + + @Test + public void testUpdateModifiesNullStruct() { + createAndInitTable("id INT, s STRUCT<n1:INT,n2:INT>", "{ \"id\": 1, \"s\": null }"); + + sql("UPDATE %s SET s.n1 = -1 WHERE id = 1", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1, row(-1, null))), + sql("SELECT * FROM %s", tableName)); + } + + @Test + public void testUpdateRefreshesRelationCache() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + append(tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 3, \"dep\": \"hr\" }"); + + append(tableName, + "{ \"id\": 1, \"dep\": \"hardware\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }"); + + Dataset<Row> query = spark.sql("SELECT * FROM " + tableName + " WHERE id = 1"); + query.createOrReplaceTempView("tmp"); + + spark.sql("CACHE TABLE tmp"); + + assertEquals("View should have correct data", + ImmutableList.of(row(1, "hardware"), row(1, "hr")), + sql("SELECT * FROM tmp ORDER BY id, dep")); + + sql("UPDATE %s SET id = -1 WHERE id = 1", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = table.currentSnapshot(); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "2", "2", "2"); + } else { + validateMergeOnRead(currentSnapshot, "2", "2", "2"); + } + + assertEquals("Should have
expected rows", + ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(2, "hardware"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + + assertEquals("Should refresh the relation cache", + ImmutableList.of(), + sql("SELECT * FROM tmp ORDER BY id, dep")); + + spark.sql("UNCACHE TABLE tmp"); + } + + @Test + public void testUpdateWithInSubquery() { + createAndInitTable("id INT, dep STRING"); + + append(tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + + createOrReplaceView("updated_id", Arrays.asList(0, 1, null), Encoders.INT()); + createOrReplaceView("updated_dep", Arrays.asList("software", "hr"), Encoders.STRING()); + + sql("UPDATE %s SET id = -1 WHERE " + + "id IN (SELECT * FROM updated_id) AND " + + "dep IN (SELECT * from updated_dep)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + sql("UPDATE %s SET id = 5 WHERE id IS NULL OR id IN (SELECT value + 1 FROM updated_id)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(5, "hardware"), row(5, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + + append(tableName, + "{ \"id\": null, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hr\" }"); + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(2, "hr"), row(5, "hardware"), row(5, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", tableName)); + + sql("UPDATE %s SET id = 10 WHERE id IN (SELECT value + 2 FROM updated_id) AND dep = 'hr'", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(5, "hardware"), row(5, "hr"), row(10, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", tableName)); + } + + @Test + public void testUpdateWithInSubqueryAndDynamicFileFiltering() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION", tableName); + + append(tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 3, \"dep\": \"hr\" }"); + append(tableName, + "{ \"id\": 1, \"dep\": \"hardware\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }"); + + createOrReplaceView("updated_id", Arrays.asList(-1, 2), Encoders.INT()); + + sql("UPDATE %s SET id = -1 WHERE id IN (SELECT * FROM updated_id)", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = table.currentSnapshot(); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "1", "1", "1"); + } else { + validateMergeOnRead(currentSnapshot, "1", "1", "1"); + } + + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(1, "hardware"), row(1, "hr"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + } + + @Test + public void testUpdateWithSelfSubquery() { + createAndInitTable("id INT, dep STRING"); + + append(tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hr\" }"); + + sql("UPDATE %s SET dep = 'x' WHERE id IN (SELECT id + 1 FROM %s)", tableName, tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "x")), + sql("SELECT * FROM 
%s ORDER BY id", tableName)); + + // TODO: Spark does not support AQE and DPP with aggregates at the moment + withSQLConf(ImmutableMap.of(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "false"), () -> { + sql("UPDATE %s SET dep = 'y' WHERE " + + "id = (SELECT count(*) FROM (SELECT DISTINCT id FROM %s) AS t)", tableName, tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "y")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + }); + + sql("UPDATE %s SET id = (SELECT id - 2 FROM %s WHERE id = 1)", tableName, tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(-1, "y")), + sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + } + + @Test + public void testUpdateWithMultiColumnInSubquery() { + createAndInitTable("id INT, dep STRING"); + + append(tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + + List deletedEmployees = Arrays.asList(new Employee(null, "hr"), new Employee(1, "hr")); + createOrReplaceView("deleted_employee", deletedEmployees, Encoders.bean(Employee.class)); + + sql("UPDATE %s SET dep = 'x', id = -1 WHERE (id, dep) IN (SELECT id, dep FROM deleted_employee)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "x"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + } + + @Test + public void testUpdateWithNotInSubquery() { + createAndInitTable("id INT, dep STRING"); + + append(tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + + createOrReplaceView("updated_id", Arrays.asList(-1, -2, null), Encoders.INT()); + createOrReplaceView("updated_dep", Arrays.asList("software", "hr"), Encoders.STRING()); + + // the file filter subquery (nested loop lef-anti join) returns 0 records + sql("UPDATE %s SET id = -1 WHERE id NOT IN (SELECT * FROM updated_id)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + sql("UPDATE %s SET id = -1 WHERE id NOT IN (SELECT * FROM updated_id WHERE value IS NOT NULL)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", tableName)); + + sql("UPDATE %s SET id = 5 WHERE id NOT IN (SELECT * FROM updated_id) OR dep IN ('software', 'hr')", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(5, "hr"), row(5, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", tableName)); + } + + @Test + public void testUpdateWithExistSubquery() { + createAndInitTable("id INT, dep STRING"); + + append(tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + + createOrReplaceView("updated_id", Arrays.asList(-1, -2, null), Encoders.INT()); + createOrReplaceView("updated_dep", Arrays.asList("hr", null), Encoders.STRING()); + + sql("UPDATE %s t SET id = -1 WHERE EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + sql("UPDATE %s t 
SET dep = 'x', id = -1 WHERE " + + "EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value + 2)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "x"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + + sql("UPDATE %s t SET id = -2 WHERE " + + "EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value) OR " + + "t.id IS NULL", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(-2, "hr"), row(-2, "x"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + + sql("UPDATE %s t SET id = 1 WHERE " + + "EXISTS (SELECT 1 FROM updated_id ui WHERE t.id = ui.value) AND " + + "EXISTS (SELECT 1 FROM updated_dep ud WHERE t.dep = ud.value)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(-2, "x"), row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + } + + @Test + public void testUpdateWithNotExistsSubquery() { + createAndInitTable("id INT, dep STRING"); + + append(tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + + createOrReplaceView("updated_id", Arrays.asList(-1, -2, null), Encoders.INT()); + createOrReplaceView("updated_dep", Arrays.asList("hr", "software"), Encoders.STRING()); + + sql("UPDATE %s t SET id = -1 WHERE NOT EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value + 2)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(1, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + + sql("UPDATE %s t SET id = 5 WHERE " + + "NOT EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value) OR " + + "t.id = 1", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(5, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + + sql("UPDATE %s t SET id = 10 WHERE " + + "NOT EXISTS (SELECT 1 FROM updated_id ui WHERE t.id = ui.value) AND " + + "EXISTS (SELECT 1 FROM updated_dep ud WHERE t.dep = ud.value)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(10, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + } + + @Test + public void testUpdateWithScalarSubquery() { + createAndInitTable("id INT, dep STRING"); + + append(tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + + createOrReplaceView("updated_id", Arrays.asList(1, 100, null), Encoders.INT()); + + // TODO: Spark does not support AQE and DPP with aggregates at the moment + withSQLConf(ImmutableMap.of(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "false"), () -> { + sql("UPDATE %s SET id = -1 WHERE id <= (SELECT min(value) FROM updated_id)", tableName); + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + }); + } + + @Test + public void testUpdateThatRequiresGroupingBeforeWrite() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + append(tableName, + "{ \"id\": 0, \"dep\": \"hr\" }\n" + + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hr\" }"); + + append(tableName, + "{ \"id\": 0, \"dep\": \"ops\" }\n" + + "{ \"id\": 1, \"dep\": \"ops\" }\n" + + 
"{ \"id\": 2, \"dep\": \"ops\" }"); + + append(tableName, + "{ \"id\": 0, \"dep\": \"hr\" }\n" + + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hr\" }"); + + append(tableName, + "{ \"id\": 0, \"dep\": \"ops\" }\n" + + "{ \"id\": 1, \"dep\": \"ops\" }\n" + + "{ \"id\": 2, \"dep\": \"ops\" }"); + + createOrReplaceView("updated_id", Arrays.asList(1, 100), Encoders.INT()); + + String originalNumOfShufflePartitions = spark.conf().get("spark.sql.shuffle.partitions"); + try { + // set the num of shuffle partitions to 1 to ensure we have only 1 writing task + spark.conf().set("spark.sql.shuffle.partitions", "1"); + + sql("UPDATE %s t SET id = -1 WHERE id IN (SELECT * FROM updated_id)", tableName); + Assert.assertEquals("Should have expected num of rows", 12L, spark.table(tableName).count()); + } finally { + spark.conf().set("spark.sql.shuffle.partitions", originalNumOfShufflePartitions); + } + } + + @Test + public void testUpdateWithVectorization() { + createAndInitTable("id INT, dep STRING"); + + append(tableName, + "{ \"id\": 0, \"dep\": \"hr\" }\n" + + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hr\" }"); + + withSQLConf(ImmutableMap.of(SparkSQLProperties.VECTORIZATION_ENABLED, "true"), () -> { + sql("UPDATE %s t SET id = -1", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(-1, "hr"), row(-1, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + }); + } + + @Test + public void testUpdateModifyPartitionSourceField() throws NoSuchTableException { + createAndInitTable("id INT, dep STRING, country STRING"); + + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(4, id)", tableName); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + List ids = Lists.newArrayListWithCapacity(100); + for (int id = 1; id <= 100; id++) { + ids.add(id); + } + + Dataset df1 = spark.createDataset(ids, Encoders.INT()) + .withColumnRenamed("value", "id") + .withColumn("dep", lit("hr")) + .withColumn("country", lit("usa")); + df1.coalesce(1).writeTo(tableName).append(); + + Dataset df2 = spark.createDataset(ids, Encoders.INT()) + .withColumnRenamed("value", "id") + .withColumn("dep", lit("software")) + .withColumn("country", lit("usa")); + df2.coalesce(1).writeTo(tableName).append(); + + Dataset df3 = spark.createDataset(ids, Encoders.INT()) + .withColumnRenamed("value", "id") + .withColumn("dep", lit("hardware")) + .withColumn("country", lit("usa")); + df3.coalesce(1).writeTo(tableName).append(); + + sql("UPDATE %s SET id = -1 WHERE id IN (10, 11, 12, 13, 14, 15, 16, 17, 18, 19)", tableName); + Assert.assertEquals(30L, scalarSql("SELECT count(*) FROM %s WHERE id = -1", tableName)); + } + + @Test + public void testUpdateWithStaticPredicatePushdown() { + createAndInitTable("id INT, dep STRING"); + + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + // add a data file to the 'software' partition + append(tableName, "{ \"id\": 1, \"dep\": \"software\" }"); + + // add a data file to the 'hr' partition + append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }"); + + Table table = validationCatalog.loadTable(tableIdent); + + Snapshot snapshot = table.currentSnapshot(); + String dataFilesCount = snapshot.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP); + Assert.assertEquals("Must have 2 files before UPDATE", "2", dataFilesCount); + + // remove the data file from the 'hr' partition to ensure it is not scanned + DataFile dataFile = Iterables.getOnlyElement(snapshot.addedFiles(table.io())); + 
table.io().deleteFile(dataFile.path().toString()); + + // disable dynamic pruning and rely only on static predicate pushdown + withSQLConf(ImmutableMap.of(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED().key(), "false"), () -> { + sql("UPDATE %s SET id = -1 WHERE dep IN ('software') AND id == 1", tableName); + }); + } + + @Test + public void testUpdateWithInvalidUpdates() { + createAndInitTable("id INT, a ARRAY<STRUCT<c1:INT,c2:INT>>, m MAP<STRING,STRING>"); + + AssertHelpers.assertThrows("Should complain about updating an array column", + AnalysisException.class, "Updating nested fields is only supported for structs", + () -> sql("UPDATE %s SET a.c1 = 1", tableName)); + + AssertHelpers.assertThrows("Should complain about updating a map column", + AnalysisException.class, "Updating nested fields is only supported for structs", + () -> sql("UPDATE %s SET m.key = 'new_key'", tableName)); + } + + @Test + public void testUpdateWithConflictingAssignments() { + createAndInitTable("id INT, c STRUCT<n1:INT,n2:STRUCT<dn1:INT,dn2:INT>>"); + + AssertHelpers.assertThrows("Should complain about conflicting updates to a top-level column", + AnalysisException.class, "Updates are in conflict", + () -> sql("UPDATE %s t SET t.id = 1, t.c.n1 = 2, t.id = 2", tableName)); + + AssertHelpers.assertThrows("Should complain about conflicting updates to a nested column", + AnalysisException.class, "Updates are in conflict for these columns", + () -> sql("UPDATE %s t SET t.c.n1 = 1, t.id = 2, t.c.n1 = 2", tableName)); + + AssertHelpers.assertThrows("Should complain about conflicting updates to a nested column", + AnalysisException.class, "Updates are in conflict", + () -> { + sql("UPDATE %s SET c.n1 = 1, c = named_struct('n1', 1, 'n2', named_struct('dn1', 1, 'dn2', 2))", tableName); + }); + } + + @Test + public void testUpdateWithInvalidAssignments() { + createAndInitTable("id INT NOT NULL, s STRUCT<n1:INT NOT NULL,n2:STRUCT<dn1:INT,dn2:INT>> NOT NULL"); + + for (String policy : new String[]{"ansi", "strict"}) { + withSQLConf(ImmutableMap.of("spark.sql.storeAssignmentPolicy", policy), () -> { + + AssertHelpers.assertThrows("Should complain about writing nulls to a top-level column", + AnalysisException.class, "Cannot write nullable values to non-null column", + () -> sql("UPDATE %s t SET t.id = NULL", tableName)); + + AssertHelpers.assertThrows("Should complain about writing nulls to a nested column", + AnalysisException.class, "Cannot write nullable values to non-null column", + () -> sql("UPDATE %s t SET t.s.n1 = NULL", tableName)); + + AssertHelpers.assertThrows("Should complain about writing missing fields in structs", + AnalysisException.class, "missing fields", + () -> sql("UPDATE %s t SET t.s = named_struct('n1', 1)", tableName)); + + AssertHelpers.assertThrows("Should complain about writing invalid data types", + AnalysisException.class, "Cannot safely cast", + () -> sql("UPDATE %s t SET t.s.n1 = 'str'", tableName)); + + AssertHelpers.assertThrows("Should complain about writing incompatible structs", + AnalysisException.class, "field name does not match", + () -> sql("UPDATE %s t SET t.s.n2 = named_struct('dn2', 1, 'dn1', 2)", tableName)); + }); + } + } + + @Test + public void testUpdateWithNonDeterministicCondition() { + createAndInitTable("id INT, dep STRING"); + + AssertHelpers.assertThrows("Should complain about non-deterministic expressions", + AnalysisException.class, "nondeterministic expressions are only allowed", + () -> sql("UPDATE %s SET id = -1 WHERE id = 1 AND rand() > 0.5", tableName)); + } + + @Test + public void testUpdateOnNonIcebergTableNotSupported() { + createOrReplaceView("testtable", "{ \"c1\": -100,
\"c2\": -200 }"); + + AssertHelpers.assertThrows("UPDATE is not supported for non iceberg table", + UnsupportedOperationException.class, "not supported temporarily", + () -> sql("UPDATE %s SET c1 = -1 WHERE c2 = 1", "testtable")); + } + + private RowLevelOperationMode mode(Table table) { + String modeName = table.properties().getOrDefault(UPDATE_MODE, UPDATE_MODE_DEFAULT); + return RowLevelOperationMode.fromName(modeName); + } +} diff --git a/spark/v3.3/spark-runtime/LICENSE b/spark/v3.3/spark-runtime/LICENSE new file mode 100644 index 000000000000..9d1522431696 --- /dev/null +++ b/spark/v3.3/spark-runtime/LICENSE @@ -0,0 +1,629 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache Avro. + +Copyright: 2014-2017 The Apache Software Foundation. +Home page: https://parquet.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains the Jackson JSON processor. + +Copyright: 2007-2019 Tatu Saloranta and other contributors +Home page: http://jackson.codehaus.org/ +License: http://www.apache.org/licenses/LICENSE-2.0.txt + +-------------------------------------------------------------------------------- + +This binary artifact contains Paranamer. + +Copyright: 2000-2007 INRIA, France Telecom, 2006-2018 Paul Hammant & ThoughtWorks Inc +Home page: https://github.com/paul-hammant/paranamer +License: https://github.com/paul-hammant/paranamer/blob/master/LICENSE.txt (BSD) + +License text: +| Portions copyright (c) 2006-2018 Paul Hammant & ThoughtWorks Inc +| Portions copyright (c) 2000-2007 INRIA, France Telecom +| All rights reserved. +| +| Redistribution and use in source and binary forms, with or without +| modification, are permitted provided that the following conditions +| are met: +| 1. Redistributions of source code must retain the above copyright +| notice, this list of conditions and the following disclaimer. +| 2. Redistributions in binary form must reproduce the above copyright +| notice, this list of conditions and the following disclaimer in the +| documentation and/or other materials provided with the distribution. +| 3. Neither the name of the copyright holders nor the names of its +| contributors may be used to endorse or promote products derived from +| this software without specific prior written permission. +| +| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +| ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +| LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +| CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +| SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +| INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +| CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +| ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +| THE POSSIBILITY OF SUCH DAMAGE. + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache Parquet. + +Copyright: 2014-2017 The Apache Software Foundation. +Home page: https://parquet.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache Thrift. + +Copyright: 2006-2010 The Apache Software Foundation. +Home page: https://thrift.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains code from Daniel Lemire's JavaFastPFOR project. + +Copyright: 2013 Daniel Lemire +Home page: https://github.com/lemire/JavaFastPFOR +License: Apache License Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains fastutil. + +Copyright: 2002-2014 Sebastiano Vigna +Home page: http://fastutil.di.unimi.it/ +License: http://www.apache.org/licenses/LICENSE-2.0.html + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache ORC. + +Copyright: 2013-2019 The Apache Software Foundation. +Home page: https://orc.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache Hive's storage API via ORC. + +Copyright: 2013-2019 The Apache Software Foundation. +Home page: https://hive.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Google protobuf via ORC. + +Copyright: 2008 Google Inc. +Home page: https://developers.google.com/protocol-buffers +License: https://github.com/protocolbuffers/protobuf/blob/master/LICENSE (BSD) + +License text: + +| Copyright 2008 Google Inc. All rights reserved. +| +| Redistribution and use in source and binary forms, with or without +| modification, are permitted provided that the following conditions are +| met: +| +| * Redistributions of source code must retain the above copyright +| notice, this list of conditions and the following disclaimer. +| * Redistributions in binary form must reproduce the above +| copyright notice, this list of conditions and the following disclaimer +| in the documentation and/or other materials provided with the +| distribution. +| * Neither the name of Google Inc. nor the names of its +| contributors may be used to endorse or promote products derived from +| this software without specific prior written permission. 
+| +| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +| "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +| LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +| A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +| OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +| SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +| LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +| DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +| THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +| (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +| OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +| +| Code generated by the Protocol Buffer compiler is owned by the owner +| of the input file used when generating it. This code is not +| standalone and requires a support library to be linked with it. This +| support library is itself covered by the above license. + +-------------------------------------------------------------------------------- + +This binary artifact contains Airlift Aircompressor. + +Copyright: 2011-2019 Aircompressor authors. +Home page: https://github.com/airlift/aircompressor +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Airlift Slice. + +Copyright: 2013-2019 Slice authors. +Home page: https://github.com/airlift/slice +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains JetBrains annotations. + +Copyright: 2000-2020 JetBrains s.r.o. +Home page: https://github.com/JetBrains/java-annotations +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains code from Cloudera Kite. + +Copyright: 2013-2017 Cloudera Inc. +Home page: https://kitesdk.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains code from Presto. + +Copyright: 2016 Facebook and contributors +Home page: https://prestodb.io/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Google Guava. + +Copyright: 2006-2019 The Guava Authors +Home page: https://github.com/google/guava +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Google Error Prone Annotations. + +Copyright: Copyright 2011-2019 The Error Prone Authors +Home page: https://github.com/google/error-prone +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains findbugs-annotations by Stephen Connolly. + +Copyright: 2011-2016 Stephen Connolly, Greg Lucas +Home page: https://github.com/stephenc/findbugs-annotations +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Google j2objc Annotations. + +Copyright: Copyright 2012-2018 Google Inc. 
+Home page: https://github.com/google/j2objc/tree/master/annotations +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains checkerframework checker-qual Annotations. + +Copyright: 2004-2019 the Checker Framework developers +Home page: https://github.com/typetools/checker-framework +License: https://github.com/typetools/checker-framework/blob/master/LICENSE.txt (MIT license) + +License text: +| The annotations are licensed under the MIT License. (The text of this +| license appears below.) More specifically, all the parts of the Checker +| Framework that you might want to include with your own program use the +| MIT License. This is the checker-qual.jar file and all the files that +| appear in it: every file in a qual/ directory, plus utility files such +| as NullnessUtil.java, RegexUtil.java, SignednessUtil.java, etc. +| In addition, the cleanroom implementations of third-party annotations, +| which the Checker Framework recognizes as aliases for its own +| annotations, are licensed under the MIT License. +| +| Permission is hereby granted, free of charge, to any person obtaining a copy +| of this software and associated documentation files (the "Software"), to deal +| in the Software without restriction, including without limitation the rights +| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +| copies of the Software, and to permit persons to whom the Software is +| furnished to do so, subject to the following conditions: +| +| The above copyright notice and this permission notice shall be included in +| all copies or substantial portions of the Software. +| +| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +| THE SOFTWARE. + +-------------------------------------------------------------------------------- + +This binary artifact contains Animal Sniffer Annotations. + +Copyright: 2009-2018 codehaus.org +Home page: https://www.mojohaus.org/animal-sniffer/animal-sniffer-annotations/ +License: https://www.mojohaus.org/animal-sniffer/animal-sniffer-annotations/license.html (MIT license) + +License text: +| The MIT License +| +| Copyright (c) 2009 codehaus.org. +| +| Permission is hereby granted, free of charge, to any person obtaining a copy +| of this software and associated documentation files (the "Software"), to deal +| in the Software without restriction, including without limitation the rights +| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +| copies of the Software, and to permit persons to whom the Software is +| furnished to do so, subject to the following conditions: +| +| The above copyright notice and this permission notice shall be included in +| all copies or substantial portions of the Software. +| +| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +| THE SOFTWARE. + +-------------------------------------------------------------------------------- + +This binary artifact contains Caffeine by Ben Manes. + +Copyright: 2014-2019 Ben Manes and contributors +Home page: https://github.com/ben-manes/caffeine +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache Arrow. + +Copyright: 2016-2019 The Apache Software Foundation. +Home page: https://arrow.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Netty's buffer library. + +Copyright: 2014-2020 The Netty Project +Home page: https://netty.io/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Google FlatBuffers. + +Copyright: 2013-2020 Google Inc. +Home page: https://google.github.io/flatbuffers/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Carrot Search Labs HPPC. + +Copyright: 2002-2019 Carrot Search s.c. +Home page: http://labs.carrotsearch.com/hppc.html +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains code from Apache Lucene via Carrot Search HPPC. + +Copyright: 2011-2020 The Apache Software Foundation. +Home page: https://lucene.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache Yetus audience annotations. + +Copyright: 2008-2020 The Apache Software Foundation. +Home page: https://yetus.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains ThreeTen. + +Copyright: 2007-present, Stephen Colebourne & Michael Nascimento Santos. +Home page: https://www.threeten.org/threeten-extra/ +License: https://github.com/ThreeTen/threeten-extra/blob/master/LICENSE.txt (BSD 3-clause) + +License text: + +| All rights reserved. +| +| * Redistribution and use in source and binary forms, with or without +| modification, are permitted provided that the following conditions are met: +| +| * Redistributions of source code must retain the above copyright notice, +| this list of conditions and the following disclaimer. +| +| * Redistributions in binary form must reproduce the above copyright notice, +| this list of conditions and the following disclaimer in the documentation +| and/or other materials provided with the distribution. +| +| * Neither the name of JSR-310 nor the names of its contributors +| may be used to endorse or promote products derived from this software +| without specific prior written permission. 
+| +| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +| "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +| LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +| A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +| CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +| EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +| PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +| PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +| LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +| NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +| SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +-------------------------------------------------------------------------------- + +This binary artifact contains code from Project Nessie. + +Copyright: 2020 Dremio Corporation. +Home page: https://projectnessie.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This product includes code from Apache Spark. + +* vectorized reading of definition levels in BaseVectorizedParquetValuesReader.java +* portions of the extensions parser +* casting logic in AssignmentAlignmentSupport +* implementation of SetAccumulator. + +Copyright: 2011-2018 The Apache Software Foundation +Home page: https://spark.apache.org/ +License: https://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This product includes code from Delta Lake. + +* AssignmentAlignmentSupport is an independent development but UpdateExpressionsSupport in Delta was used as a reference. + +Copyright: 2020 The Delta Lake Project Authors. +Home page: https://delta.io/ +License: https://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary includes code from Apache Commons. + +* Core ArrayUtil. + +Copyright: 2020 The Apache Software Foundation +Home page: https://commons.apache.org/ +License: https://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache HttpComponents Client. + +Copyright: 1999-2022 The Apache Software Foundation. +Home page: https://hc.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 diff --git a/spark/v3.3/spark-runtime/NOTICE b/spark/v3.3/spark-runtime/NOTICE new file mode 100644 index 000000000000..4a1f4dfde1cc --- /dev/null +++ b/spark/v3.3/spark-runtime/NOTICE @@ -0,0 +1,508 @@ + +Apache Iceberg +Copyright 2017-2022 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +-------------------------------------------------------------------------------- + +This binary artifact contains code from Kite, developed at Cloudera, Inc. with +the following copyright notice: + +| Copyright 2013 Cloudera Inc. +| +| Licensed under the Apache License, Version 2.0 (the "License"); +| you may not use this file except in compliance with the License. 
+| You may obtain a copy of the License at +| +| http://www.apache.org/licenses/LICENSE-2.0 +| +| Unless required by applicable law or agreed to in writing, software +| distributed under the License is distributed on an "AS IS" BASIS, +| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +| See the License for the specific language governing permissions and +| limitations under the License. + +-------------------------------------------------------------------------------- + +This binary artifact includes Apache ORC with the following in its NOTICE file: + +| Apache ORC +| Copyright 2013-2019 The Apache Software Foundation +| +| This product includes software developed by The Apache Software +| Foundation (http://www.apache.org/). +| +| This product includes software developed by Hewlett-Packard: +| (c) Copyright [2014-2015] Hewlett-Packard Development Company, L.P + +-------------------------------------------------------------------------------- + +This binary artifact includes Airlift Aircompressor with the following in its +NOTICE file: + +| Snappy Copyright Notices +| ========================= +| +| * Copyright 2011 Dain Sundstrom +| * Copyright 2011, Google Inc. +| +| +| Snappy License +| =============== +| Copyright 2011, Google Inc. +| All rights reserved. +| +| Redistribution and use in source and binary forms, with or without +| modification, are permitted provided that the following conditions are +| met: +| +| * Redistributions of source code must retain the above copyright +| notice, this list of conditions and the following disclaimer. +| * Redistributions in binary form must reproduce the above +| copyright notice, this list of conditions and the following disclaimer +| in the documentation and/or other materials provided with the +| distribution. +| * Neither the name of Google Inc. nor the names of its +| contributors may be used to endorse or promote products derived from +| this software without specific prior written permission. +| +| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +| "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +| LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +| A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +| OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +| SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +| LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +| DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +| THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +| (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +| OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +-------------------------------------------------------------------------------- + +This binary artifact includes Carrot Search Labs HPPC with the following in its +NOTICE file: + +| ACKNOWLEDGEMENT +| =============== +| +| HPPC borrowed code, ideas or both from: +| +| * Apache Lucene, http://lucene.apache.org/ +| (Apache license) +| * Fastutil, http://fastutil.di.unimi.it/ +| (Apache license) +| * Koloboke, https://github.com/OpenHFT/Koloboke +| (Apache license) + +-------------------------------------------------------------------------------- + +This binary artifact includes Apache Yetus with the following in its NOTICE +file: + +| Apache Yetus +| Copyright 2008-2020 The Apache Software Foundation +| +| This product includes software developed at +| The Apache Software Foundation (https://www.apache.org/). +| +| --- +| Additional licenses for the Apache Yetus Source/Website: +| --- +| +| +| See LICENSE for terms. + +-------------------------------------------------------------------------------- + +This binary artifact includes Google Protobuf with the following copyright +notice: + +| Copyright 2008 Google Inc. All rights reserved. +| +| Redistribution and use in source and binary forms, with or without +| modification, are permitted provided that the following conditions are +| met: +| +| * Redistributions of source code must retain the above copyright +| notice, this list of conditions and the following disclaimer. +| * Redistributions in binary form must reproduce the above +| copyright notice, this list of conditions and the following disclaimer +| in the documentation and/or other materials provided with the +| distribution. +| * Neither the name of Google Inc. nor the names of its +| contributors may be used to endorse or promote products derived from +| this software without specific prior written permission. +| +| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +| "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +| LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +| A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +| OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +| SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +| LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +| DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +| THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +| (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +| OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +| +| Code generated by the Protocol Buffer compiler is owned by the owner +| of the input file used when generating it. This code is not +| standalone and requires a support library to be linked with it. This +| support library is itself covered by the above license. + +-------------------------------------------------------------------------------- + +This binary artifact includes Apache Arrow with the following in its NOTICE file: + +| Apache Arrow +| Copyright 2016-2019 The Apache Software Foundation +| +| This product includes software developed at +| The Apache Software Foundation (http://www.apache.org/). +| +| This product includes software from the SFrame project (BSD, 3-clause). +| * Copyright (C) 2015 Dato, Inc. +| * Copyright (c) 2009 Carnegie Mellon University. 
+| +| This product includes software from the Feather project (Apache 2.0) +| https://github.com/wesm/feather +| +| This product includes software from the DyND project (BSD 2-clause) +| https://github.com/libdynd +| +| This product includes software from the LLVM project +| * distributed under the University of Illinois Open Source +| +| This product includes software from the google-lint project +| * Copyright (c) 2009 Google Inc. All rights reserved. +| +| This product includes software from the mman-win32 project +| * Copyright https://code.google.com/p/mman-win32/ +| * Licensed under the MIT License; +| +| This product includes software from the LevelDB project +| * Copyright (c) 2011 The LevelDB Authors. All rights reserved. +| * Use of this source code is governed by a BSD-style license that can be +| * Moved from Kudu http://github.com/cloudera/kudu +| +| This product includes software from the CMake project +| * Copyright 2001-2009 Kitware, Inc. +| * Copyright 2012-2014 Continuum Analytics, Inc. +| * All rights reserved. +| +| This product includes software from https://github.com/matthew-brett/multibuild (BSD 2-clause) +| * Copyright (c) 2013-2016, Matt Terry and Matthew Brett; all rights reserved. +| +| This product includes software from the Ibis project (Apache 2.0) +| * Copyright (c) 2015 Cloudera, Inc. +| * https://github.com/cloudera/ibis +| +| This product includes software from Dremio (Apache 2.0) +| * Copyright (C) 2017-2018 Dremio Corporation +| * https://github.com/dremio/dremio-oss +| +| This product includes software from Google Guava (Apache 2.0) +| * Copyright (C) 2007 The Guava Authors +| * https://github.com/google/guava +| +| This product include software from CMake (BSD 3-Clause) +| * CMake - Cross Platform Makefile Generator +| * Copyright 2000-2019 Kitware, Inc. and Contributors +| +| The web site includes files generated by Jekyll. +| +| -------------------------------------------------------------------------------- +| +| This product includes code from Apache Kudu, which includes the following in +| its NOTICE file: +| +| Apache Kudu +| Copyright 2016 The Apache Software Foundation +| +| This product includes software developed at +| The Apache Software Foundation (http://www.apache.org/). +| +| Portions of this software were developed at +| Cloudera, Inc (http://www.cloudera.com/). +| +| -------------------------------------------------------------------------------- +| +| This product includes code from Apache ORC, which includes the following in +| its NOTICE file: +| +| Apache ORC +| Copyright 2013-2019 The Apache Software Foundation +| +| This product includes software developed by The Apache Software +| Foundation (http://www.apache.org/). +| +| This product includes software developed by Hewlett-Packard: +| (c) Copyright [2014-2015] Hewlett-Packard Development Company, L.P + +-------------------------------------------------------------------------------- + +This binary artifact includes Netty buffers with the following in its NOTICE +file: + +| The Netty Project +| ================= +| +| Please visit the Netty web site for more information: +| +| * https://netty.io/ +| +| Copyright 2014 The Netty Project +| +| The Netty Project licenses this file to you under the Apache License, +| version 2.0 (the "License"); you may not use this file except in compliance +| with the License. 
You may obtain a copy of the License at: +| +| http://www.apache.org/licenses/LICENSE-2.0 +| +| Unless required by applicable law or agreed to in writing, software +| distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +| WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +| License for the specific language governing permissions and limitations +| under the License. +| +| Also, please refer to each LICENSE..txt file, which is located in +| the 'license' directory of the distribution file, for the license terms of the +| components that this product depends on. +| +| ------------------------------------------------------------------------------- +| This product contains the extensions to Java Collections Framework which has +| been derived from the works by JSR-166 EG, Doug Lea, and Jason T. Greene: +| +| * LICENSE: +| * license/LICENSE.jsr166y.txt (Public Domain) +| * HOMEPAGE: +| * http://gee.cs.oswego.edu/cgi-bin/viewcvs.cgi/jsr166/ +| * http://viewvc.jboss.org/cgi-bin/viewvc.cgi/jbosscache/experimental/jsr166/ +| +| This product contains a modified version of Robert Harder's Public Domain +| Base64 Encoder and Decoder, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.base64.txt (Public Domain) +| * HOMEPAGE: +| * http://iharder.sourceforge.net/current/java/base64/ +| +| This product contains a modified portion of 'Webbit', an event based +| WebSocket and HTTP server, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.webbit.txt (BSD License) +| * HOMEPAGE: +| * https://github.com/joewalnes/webbit +| +| This product contains a modified portion of 'SLF4J', a simple logging +| facade for Java, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.slf4j.txt (MIT License) +| * HOMEPAGE: +| * http://www.slf4j.org/ +| +| This product contains a modified portion of 'Apache Harmony', an open source +| Java SE, which can be obtained at: +| +| * NOTICE: +| * license/NOTICE.harmony.txt +| * LICENSE: +| * license/LICENSE.harmony.txt (Apache License 2.0) +| * HOMEPAGE: +| * http://archive.apache.org/dist/harmony/ +| +| This product contains a modified portion of 'jbzip2', a Java bzip2 compression +| and decompression library written by Matthew J. Francis. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.jbzip2.txt (MIT License) +| * HOMEPAGE: +| * https://code.google.com/p/jbzip2/ +| +| This product contains a modified portion of 'libdivsufsort', a C API library to construct +| the suffix array and the Burrows-Wheeler transformed string for any input string of +| a constant-size alphabet written by Yuta Mori. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.libdivsufsort.txt (MIT License) +| * HOMEPAGE: +| * https://github.com/y-256/libdivsufsort +| +| This product contains a modified portion of Nitsan Wakart's 'JCTools', Java Concurrency Tools for the JVM, +| which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.jctools.txt (ASL2 License) +| * HOMEPAGE: +| * https://github.com/JCTools/JCTools +| +| This product optionally depends on 'JZlib', a re-implementation of zlib in +| pure Java, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.jzlib.txt (BSD style License) +| * HOMEPAGE: +| * http://www.jcraft.com/jzlib/ +| +| This product optionally depends on 'Compress-LZF', a Java library for encoding and +| decoding data in LZF format, written by Tatu Saloranta. 
It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.compress-lzf.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/ning/compress +| +| This product optionally depends on 'lz4', a LZ4 Java compression +| and decompression library written by Adrien Grand. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.lz4.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/jpountz/lz4-java +| +| This product optionally depends on 'lzma-java', a LZMA Java compression +| and decompression library, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.lzma-java.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/jponge/lzma-java +| +| This product contains a modified portion of 'jfastlz', a Java port of FastLZ compression +| and decompression library written by William Kinney. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.jfastlz.txt (MIT License) +| * HOMEPAGE: +| * https://code.google.com/p/jfastlz/ +| +| This product contains a modified portion of and optionally depends on 'Protocol Buffers', Google's data +| interchange format, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.protobuf.txt (New BSD License) +| * HOMEPAGE: +| * https://github.com/google/protobuf +| +| This product optionally depends on 'Bouncy Castle Crypto APIs' to generate +| a temporary self-signed X.509 certificate when the JVM does not provide the +| equivalent functionality. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.bouncycastle.txt (MIT License) +| * HOMEPAGE: +| * http://www.bouncycastle.org/ +| +| This product optionally depends on 'Snappy', a compression library produced +| by Google Inc, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.snappy.txt (New BSD License) +| * HOMEPAGE: +| * https://github.com/google/snappy +| +| This product optionally depends on 'JBoss Marshalling', an alternative Java +| serialization API, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.jboss-marshalling.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/jboss-remoting/jboss-marshalling +| +| This product optionally depends on 'Caliper', Google's micro- +| benchmarking framework, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.caliper.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/google/caliper +| +| This product optionally depends on 'Apache Commons Logging', a logging +| framework, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.commons-logging.txt (Apache License 2.0) +| * HOMEPAGE: +| * http://commons.apache.org/logging/ +| +| This product optionally depends on 'Apache Log4J', a logging framework, which +| can be obtained at: +| +| * LICENSE: +| * license/LICENSE.log4j.txt (Apache License 2.0) +| * HOMEPAGE: +| * http://logging.apache.org/log4j/ +| +| This product optionally depends on 'Aalto XML', an ultra-high performance +| non-blocking XML processor, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.aalto-xml.txt (Apache License 2.0) +| * HOMEPAGE: +| * http://wiki.fasterxml.com/AaltoHome +| +| This product contains a modified version of 'HPACK', a Java implementation of +| the HTTP/2 HPACK algorithm written by Twitter. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.hpack.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/twitter/hpack +| +| This product contains a modified version of 'HPACK', a Java implementation of +| the HTTP/2 HPACK algorithm written by Cory Benfield. 
It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.hyper-hpack.txt (MIT License) +| * HOMEPAGE: +| * https://github.com/python-hyper/hpack/ +| +| This product contains a modified version of 'HPACK', a Java implementation of +| the HTTP/2 HPACK algorithm written by Tatsuhiro Tsujikawa. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.nghttp2-hpack.txt (MIT License) +| * HOMEPAGE: +| * https://github.com/nghttp2/nghttp2/ +| +| This product contains a modified portion of 'Apache Commons Lang', a Java library +| provides utilities for the java.lang API, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.commons-lang.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://commons.apache.org/proper/commons-lang/ +| +| +| This product contains the Maven wrapper scripts from 'Maven Wrapper', that provides an easy way to ensure a user has everything necessary to run the Maven build. +| +| * LICENSE: +| * license/LICENSE.mvn-wrapper.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/takari/maven-wrapper +| +| This product contains the dnsinfo.h header file, that provides a way to retrieve the system DNS configuration on MacOS. +| This private header is also used by Apple's open source +| mDNSResponder (https://opensource.apple.com/tarballs/mDNSResponder/). +| +| * LICENSE: +| * license/LICENSE.dnsinfo.txt (Apache License 2.0) +| * HOMEPAGE: +| * http://www.opensource.apple.com/source/configd/configd-453.19/dnsinfo/dnsinfo.h + +-------------------------------------------------------------------------------- + +This binary artifact includes Project Nessie with the following in its NOTICE +file: + +| Dremio +| Copyright 2015-2017 Dremio Corporation +| +| This product includes software developed at +| The Apache Software Foundation (http://www.apache.org/). + diff --git a/spark/v3.3/spark-runtime/src/integration/java/org/apache/iceberg/spark/SmokeTest.java b/spark/v3.3/spark-runtime/src/integration/java/org/apache/iceberg/spark/SmokeTest.java new file mode 100644 index 000000000000..40c8b27f447e --- /dev/null +++ b/spark/v3.3/spark-runtime/src/integration/java/org/apache/iceberg/spark/SmokeTest.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.spark.extensions.SparkExtensionsTestBase; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class SmokeTest extends SparkExtensionsTestBase { + + public SmokeTest(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Before + public void dropTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + // Run through our Doc's Getting Started Example + // TODO Update doc example so that it can actually be run, modifications were required for this test suite to run + @Test + public void testGettingStarted() throws IOException { + // Creating a table + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + // Writing + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); + Assert.assertEquals("Should have inserted 3 rows", + 3L, scalarSql("SELECT COUNT(*) FROM %s", tableName)); + + sql("DROP TABLE IF EXISTS source"); + sql("CREATE TABLE source (id bigint, data string) USING parquet LOCATION '%s'", temp.newFolder()); + sql("INSERT INTO source VALUES (10, 'd'), (11, 'ee')"); + + sql("INSERT INTO %s SELECT id, data FROM source WHERE length(data) = 1", tableName); + Assert.assertEquals("Table should now have 4 rows", + 4L, scalarSql("SELECT COUNT(*) FROM %s", tableName)); + + sql("DROP TABLE IF EXISTS updates"); + sql("CREATE TABLE updates (id bigint, data string) USING parquet LOCATION '%s'", temp.newFolder()); + sql("INSERT INTO updates VALUES (1, 'x'), (2, 'x'), (4, 'z')"); + + sql("MERGE INTO %s t USING (SELECT * FROM updates) u ON t.id = u.id\n" + + "WHEN MATCHED THEN UPDATE SET t.data = u.data\n" + + "WHEN NOT MATCHED THEN INSERT *", tableName); + Assert.assertEquals("Table should now have 5 rows", + 5L, scalarSql("SELECT COUNT(*) FROM %s", tableName)); + Assert.assertEquals("Record 1 should now have data x", + "x", scalarSql("SELECT data FROM %s WHERE id = 1", tableName)); + + // Reading + Assert.assertEquals("There should be 2 records with data x", + 2L, scalarSql( + "SELECT count(1) as count FROM %s WHERE data = 'x' GROUP BY data ", tableName)); + + // Not supported because of Spark limitation + if (!catalogName.equals("spark_catalog")) { + Assert.assertEquals("There should be 3 snapshots", + 3L, scalarSql("SELECT COUNT(*) FROM %s.snapshots", tableName)); + } + } + + // From Spark DDL Docs section + @Test + public void testAlterTable() throws NoSuchTableException { + sql("CREATE TABLE %s (category int, id bigint, data string, ts timestamp) USING iceberg", tableName); + Table table = getTable(); + // Add examples + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, id)", tableName); + sql("ALTER TABLE %s ADD PARTITION FIELD truncate(data, 4)", tableName); + sql("ALTER TABLE %s ADD PARTITION FIELD years(ts)", tableName); + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, category) AS shard", tableName); + table = getTable(); + Assert.assertEquals("Table should have 4 partition fields", 4, table.spec().fields().size()); + + // Drop Examples + sql("ALTER TABLE %s DROP PARTITION FIELD bucket(16, id)", tableName); + sql("ALTER TABLE %s DROP PARTITION FIELD truncate(data, 4)", tableName); + sql("ALTER TABLE %s DROP PARTITION FIELD years(ts)", tableName); + sql("ALTER TABLE %s DROP 
PARTITION FIELD shard", tableName); + + table = getTable(); + Assert.assertEquals("Table should have 4 partition fields", 4, table.spec().fields().size()); + // VoidTransform is package private so we can't reach it here, just checking name + Assert.assertTrue("All transforms should be void", + table.spec().fields().stream().allMatch(pf -> pf.transform().toString().equals("void"))); + + // Sort order examples + sql("ALTER TABLE %s WRITE ORDERED BY category, id", tableName); + sql("ALTER TABLE %s WRITE ORDERED BY category ASC, id DESC", tableName); + sql("ALTER TABLE %s WRITE ORDERED BY category ASC NULLS LAST, id DESC NULLS FIRST", tableName); + table = getTable(); + Assert.assertEquals("Table should be sorted on 2 fields", 2, table.sortOrder().fields().size()); + } + + @Test + public void testCreateTable() { + sql("DROP TABLE IF EXISTS %s", tableName("first")); + sql("DROP TABLE IF EXISTS %s", tableName("second")); + sql("DROP TABLE IF EXISTS %s", tableName("third")); + + sql("CREATE TABLE %s (\n" + + " id bigint COMMENT 'unique id',\n" + + " data string)\n" + + "USING iceberg", tableName("first")); + getTable("first"); // Table should exist + + sql("CREATE TABLE %s (\n" + + " id bigint,\n" + + " data string,\n" + + " category string)\n" + + "USING iceberg\n" + + "PARTITIONED BY (category)", tableName("second")); + Table second = getTable("second"); + Assert.assertEquals("Should be partitioned on 1 column", 1, second.spec().fields().size()); + + sql("CREATE TABLE %s (\n" + + " id bigint,\n" + + " data string,\n" + + " category string,\n" + + " ts timestamp)\n" + + "USING iceberg\n" + + "PARTITIONED BY (bucket(16, id), days(ts), category)", tableName("third")); + Table third = getTable("third"); + Assert.assertEquals("Should be partitioned on 3 columns", 3, third.spec().fields().size()); + } + + private Table getTable(String name) { + return validationCatalog.loadTable(TableIdentifier.of("default", name)); + } + + private Table getTable() { + return validationCatalog.loadTable(tableIdent); + } + +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/SparkBenchmarkUtil.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/SparkBenchmarkUtil.java new file mode 100644 index 000000000000..17df7d2cf9d7 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/SparkBenchmarkUtil.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.iceberg.spark;
+
+import java.util.List;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.types.Types;
+import org.apache.spark.sql.catalyst.expressions.Attribute;
+import org.apache.spark.sql.catalyst.expressions.AttributeReference;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
+import org.apache.spark.sql.types.StructType;
+import scala.collection.JavaConverters;
+
+public class SparkBenchmarkUtil {
+
+  private SparkBenchmarkUtil() {
+  }
+
+  public static UnsafeProjection projection(Schema expectedSchema, Schema actualSchema) {
+    StructType struct = SparkSchemaUtil.convert(actualSchema);
+
+    List<AttributeReference> refs = JavaConverters.seqAsJavaListConverter(struct.toAttributes()).asJava();
+    List<Attribute> attrs = Lists.newArrayListWithExpectedSize(struct.fields().length);
+    List<Expression> exprs = Lists.newArrayListWithExpectedSize(struct.fields().length);
+
+    for (AttributeReference ref : refs) {
+      attrs.add(ref.toAttribute());
+    }
+
+    for (Types.NestedField field : expectedSchema.columns()) {
+      int indexInIterSchema = struct.fieldIndex(field.name());
+      exprs.add(refs.get(indexInIterSchema));
+    }
+
+    return UnsafeProjection.create(
+        JavaConverters.asScalaBufferConverter(exprs).asScala().toSeq(),
+        JavaConverters.asScalaBufferConverter(attrs).asScala().toSeq());
+  }
+}
diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/action/IcebergSortCompactionBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/action/IcebergSortCompactionBenchmark.java
new file mode 100644
index 000000000000..17f9e23862be
--- /dev/null
+++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/action/IcebergSortCompactionBenchmark.java
@@ -0,0 +1,378 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ + + +package org.apache.iceberg.spark.action; + +import java.io.IOException; +import java.util.Collections; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.NullOrder; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortDirection; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.BinPackStrategy; +import org.apache.iceberg.relocated.com.google.common.io.Files; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.DataTypes; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Timeout; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +@Fork(1) +@State(Scope.Benchmark) +@Measurement(iterations = 10) +@BenchmarkMode(Mode.SingleShotTime) +@Timeout(time = 1000, timeUnit = TimeUnit.HOURS) +public class IcebergSortCompactionBenchmark { + + private static final String[] NAMESPACE = new String[] {"default"}; + private static final String NAME = "sortbench"; + private static final Identifier IDENT = Identifier.of(NAMESPACE, NAME); + private static final String FULL_IDENT = Spark3Util.quotedFullIdentifier("spark_catalog", IDENT); + private static final int NUM_FILES = 8; + private static final long NUM_ROWS = 7500000L; + private static final long UNIQUE_VALUES = NUM_ROWS / 4; + + private final Configuration hadoopConf = initHadoopConf(); + private SparkSession spark; + + @Setup + public void setupBench() { + setupSpark(); + } + + @TearDown + public void teardownBench() { + tearDownSpark(); + } + + @Setup(Level.Iteration) + public void setupIteration() { + initTable(); + appendData(); + } + + @TearDown(Level.Iteration) + public void cleanUpIteration() throws IOException { + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void sortInt() { + SparkActions.get() + .rewriteDataFiles(table(), FULL_IDENT) + .option(BinPackStrategy.REWRITE_ALL, "true") + .sort(SortOrder + .builderFor(table().schema()) + .sortBy("intCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void sortInt2() { + SparkActions.get() + .rewriteDataFiles(table(), FULL_IDENT) + .option(BinPackStrategy.REWRITE_ALL, "true") + .sort(SortOrder + 
.builderFor(table().schema()) + .sortBy("intCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol2", SortDirection.ASC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void sortInt3() { + SparkActions.get() + .rewriteDataFiles(table(), FULL_IDENT) + .option(BinPackStrategy.REWRITE_ALL, "true") + .sort(SortOrder + .builderFor(table().schema()) + .sortBy("intCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol2", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol3", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol4", SortDirection.ASC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void sortInt4() { + SparkActions.get() + .rewriteDataFiles(table(), FULL_IDENT) + .option(BinPackStrategy.REWRITE_ALL, "true") + .sort(SortOrder + .builderFor(table().schema()) + .sortBy("intCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol2", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol3", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol4", SortDirection.ASC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void sortString() { + SparkActions.get() + .rewriteDataFiles(table(), FULL_IDENT) + .option(BinPackStrategy.REWRITE_ALL, "true") + .sort(SortOrder + .builderFor(table().schema()) + .sortBy("stringCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void sortFourColumns() { + SparkActions.get() + .rewriteDataFiles(table(), FULL_IDENT) + .option(BinPackStrategy.REWRITE_ALL, "true") + .sort(SortOrder + .builderFor(table().schema()) + .sortBy("stringCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("dateCol", SortDirection.DESC, NullOrder.NULLS_FIRST) + .sortBy("doubleCol", SortDirection.DESC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void sortSixColumns() { + SparkActions.get() + .rewriteDataFiles(table(), FULL_IDENT) + .option(BinPackStrategy.REWRITE_ALL, "true") + .sort(SortOrder + .builderFor(table().schema()) + .sortBy("stringCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("dateCol", SortDirection.DESC, NullOrder.NULLS_FIRST) + .sortBy("timestampCol", SortDirection.DESC, NullOrder.NULLS_FIRST) + .sortBy("doubleCol", SortDirection.DESC, NullOrder.NULLS_FIRST) + .sortBy("longCol", SortDirection.DESC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void zSortInt() { + SparkActions.get() + .rewriteDataFiles(table(), FULL_IDENT) + .option(BinPackStrategy.REWRITE_ALL, "true") + .zOrder("intCol") + .execute(); + } + + @Benchmark + @Threads(1) + public void zSortInt2() { + SparkActions.get() + .rewriteDataFiles(table(), FULL_IDENT) + .option(BinPackStrategy.REWRITE_ALL, "true") + .zOrder("intCol", "intCol2") + .execute(); + } + + @Benchmark + @Threads(1) + public void zSortInt3() { + SparkActions.get() + .rewriteDataFiles(table(), FULL_IDENT) + .option(BinPackStrategy.REWRITE_ALL, "true") + .zOrder("intCol", "intCol2", "intCol3") + .execute(); + } + + @Benchmark + @Threads(1) + public void zSortInt4() { + SparkActions.get() + .rewriteDataFiles(table(), FULL_IDENT) + .option(BinPackStrategy.REWRITE_ALL, "true") + .zOrder("intCol", "intCol2", "intCol3", "intCol4") + .execute(); + } + + @Benchmark + @Threads(1) + 
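+  // Z-order rewrite keyed on a single string column, mirroring the sortString case above.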
public void zSortString() { + SparkActions.get() + .rewriteDataFiles(table(), FULL_IDENT) + .option(BinPackStrategy.REWRITE_ALL, "true") + .zOrder("stringCol") + .execute(); + } + + @Benchmark + @Threads(1) + public void zSortFourColumns() { + SparkActions.get() + .rewriteDataFiles(table(), FULL_IDENT) + .option(BinPackStrategy.REWRITE_ALL, "true") + .zOrder("stringCol", "intCol", "dateCol", "doubleCol") + .execute(); + } + + @Benchmark + @Threads(1) + public void zSortSixColumns() { + SparkActions.get() + .rewriteDataFiles(table(), FULL_IDENT) + .option(BinPackStrategy.REWRITE_ALL, "true") + .zOrder("stringCol", "intCol", "dateCol", "timestampCol", "doubleCol", "longCol") + .execute(); + } + + protected Configuration initHadoopConf() { + return new Configuration(); + } + + protected final void initTable() { + Schema schema = new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "intCol2", Types.IntegerType.get()), + required(4, "intCol3", Types.IntegerType.get()), + required(5, "intCol4", Types.IntegerType.get()), + required(6, "floatCol", Types.FloatType.get()), + optional(7, "doubleCol", Types.DoubleType.get()), + optional(8, "dateCol", Types.DateType.get()), + optional(9, "timestampCol", Types.TimestampType.withZone()), + optional(10, "stringCol", Types.StringType.get())); + + SparkSessionCatalog catalog; + try { + catalog = (SparkSessionCatalog) + Spark3Util.catalogAndIdentifier(spark(), "spark_catalog").catalog(); + catalog.dropTable(IDENT); + catalog.createTable(IDENT, SparkSchemaUtil.convert(schema), new Transform[0], Collections.emptyMap()); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private void appendData() { + Dataset df = spark().range(0, NUM_ROWS * NUM_FILES, 1, NUM_FILES) + .drop("id") + .withColumn("longCol", new RandomGeneratingUDF(UNIQUE_VALUES).randomLongUDF().apply()) + .withColumn( + "intCol", + new RandomGeneratingUDF(UNIQUE_VALUES).randomLongUDF().apply().cast(DataTypes.IntegerType)) + .withColumn( + "intCol2", + new RandomGeneratingUDF(UNIQUE_VALUES).randomLongUDF().apply().cast(DataTypes.IntegerType)) + .withColumn( + "intCol3", + new RandomGeneratingUDF(UNIQUE_VALUES).randomLongUDF().apply().cast(DataTypes.IntegerType)) + .withColumn( + "intCol4", + new RandomGeneratingUDF(UNIQUE_VALUES).randomLongUDF().apply().cast(DataTypes.IntegerType)) + .withColumn( + "floatCol", + new RandomGeneratingUDF(UNIQUE_VALUES).randomLongUDF().apply().cast(DataTypes.FloatType)) + .withColumn( + "doubleCol", + new RandomGeneratingUDF(UNIQUE_VALUES).randomLongUDF().apply().cast(DataTypes.DoubleType)) + .withColumn("dateCol", date_add(current_date(), col("intCol").mod(NUM_FILES))) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", new RandomGeneratingUDF(UNIQUE_VALUES).randomString().apply()); + writeData(df); + } + + private void writeData(Dataset df) { + df.write().format("iceberg").mode(SaveMode.Append).save(NAME); + } + + protected final Table table() { + try { + return Spark3Util.loadIcebergTable(spark(), NAME); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + protected final SparkSession spark() { + return spark; + } + + protected String getCatalogWarehouse() { + String location = Files.createTempDir().getAbsolutePath() + "/" + UUID.randomUUID() + "/"; + return location; + } + + protected void cleanupFiles() throws IOException { + spark.sql("DROP TABLE IF EXISTS " + NAME); + } + + protected void setupSpark() { + 
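+    // Register Iceberg's SparkSessionCatalog as the built-in spark_catalog, backed by a
+    // Hadoop-type catalog whose warehouse is a fresh temporary directory for this run.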
SparkSession.Builder builder = + SparkSession.builder() + .config("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog") + .config("spark.sql.catalog.spark_catalog.type", "hadoop") + .config("spark.sql.catalog.spark_catalog.warehouse", getCatalogWarehouse()) + .master("local[*]"); + spark = builder.getOrCreate(); + Configuration sparkHadoopConf = spark.sessionState().newHadoopConf(); + hadoopConf.forEach(entry -> sparkHadoopConf.set(entry.getKey(), entry.getValue())); + } + + protected void tearDownSpark() { + spark.stop(); + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/action/RandomGeneratingUDF.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/action/RandomGeneratingUDF.java new file mode 100644 index 000000000000..cfbd9d4fb3f6 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/action/RandomGeneratingUDF.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.action; + +import java.io.Serializable; +import java.util.Random; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.RandomUtil; +import org.apache.spark.sql.expressions.UserDefinedFunction; +import org.apache.spark.sql.types.DataTypes; + +import static org.apache.spark.sql.functions.udf; + +class RandomGeneratingUDF implements Serializable { + private final long uniqueValues; + private Random rand = new Random(); + + RandomGeneratingUDF(long uniqueValues) { + this.uniqueValues = uniqueValues; + } + + UserDefinedFunction randomLongUDF() { + return udf(() -> rand.nextLong() % (uniqueValues / 2), DataTypes.LongType).asNondeterministic().asNonNullable(); + } + + UserDefinedFunction randomString() { + return udf(() -> (String) RandomUtil.generatePrimitive(Types.StringType.get(), rand), DataTypes.StringType) + .asNondeterministic().asNonNullable(); + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetReadersFlatDataBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetReadersFlatDataBenchmark.java new file mode 100644 index 000000000000..3c58f5278ca3 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetReadersFlatDataBenchmark.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data.parquet; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.common.DynMethods; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.SparkBenchmarkUtil; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkParquetReaders; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; +import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport; +import org.apache.spark.sql.types.StructType; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +/** + * A benchmark that evaluates the performance of reading Parquet data with a flat schema using + * Iceberg and Spark Parquet readers. 
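+ * It compares Iceberg's Parquet reader, an unsafe-projection variant of it, and Spark's
+ * built-in ParquetReadSupport, each with and without a column projection.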
+ * + * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=SparkParquetReadersFlatDataBenchmark + * -PjmhOutputPath=benchmark/spark-parquet-readers-flat-data-benchmark-result.txt + * + */ +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.SingleShotTime) +public class SparkParquetReadersFlatDataBenchmark { + + private static final DynMethods.UnboundMethod APPLY_PROJECTION = DynMethods.builder("apply") + .impl(UnsafeProjection.class, InternalRow.class) + .build(); + private static final Schema SCHEMA = new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "decimalCol", Types.DecimalType.of(20, 5)), + optional(6, "dateCol", Types.DateType.get()), + optional(7, "timestampCol", Types.TimestampType.withZone()), + optional(8, "stringCol", Types.StringType.get())); + private static final Schema PROJECTED_SCHEMA = new Schema( + required(1, "longCol", Types.LongType.get()), + optional(5, "decimalCol", Types.DecimalType.of(20, 5)), + optional(8, "stringCol", Types.StringType.get())); + private static final int NUM_RECORDS = 10000000; + private File dataFile; + + @Setup + public void setupBenchmark() throws IOException { + dataFile = File.createTempFile("parquet-flat-data-benchmark", ".parquet"); + dataFile.delete(); + List records = RandomData.generateList(SCHEMA, NUM_RECORDS, 0L); + try (FileAppender writer = Parquet.write(Files.localOutput(dataFile)) + .schema(SCHEMA) + .named("benchmark") + .build()) { + writer.addAll(records); + } + } + + @TearDown + public void tearDownBenchmark() { + if (dataFile != null) { + dataFile.delete(); + } + } + + @Benchmark + @Threads(1) + public void readUsingIcebergReader(Blackhole blackHole) throws IOException { + try (CloseableIterable rows = Parquet.read(Files.localInput(dataFile)) + .project(SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(SCHEMA, type)) + .build()) { + + for (InternalRow row : rows) { + blackHole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readUsingIcebergReaderUnsafe(Blackhole blackhole) throws IOException { + try (CloseableIterable rows = Parquet.read(Files.localInput(dataFile)) + .project(SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(SCHEMA, type)) + .build()) { + + Iterable unsafeRows = Iterables.transform( + rows, + APPLY_PROJECTION.bind(SparkBenchmarkUtil.projection(SCHEMA, SCHEMA))::invoke); + + for (InternalRow row : unsafeRows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readUsingSparkReader(Blackhole blackhole) throws IOException { + StructType sparkSchema = SparkSchemaUtil.convert(SCHEMA); + try (CloseableIterable rows = Parquet.read(Files.localInput(dataFile)) + .project(SCHEMA) + .readSupport(new ParquetReadSupport()) + .set("org.apache.spark.sql.parquet.row.requested_schema", sparkSchema.json()) + .set("spark.sql.parquet.binaryAsString", "false") + .set("spark.sql.parquet.int96AsTimestamp", "false") + .callInit() + .build()) { + + for (InternalRow row : rows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readWithProjectionUsingIcebergReader(Blackhole blackhole) throws IOException { + try (CloseableIterable rows = Parquet.read(Files.localInput(dataFile)) + 
.project(PROJECTED_SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(PROJECTED_SCHEMA, type)) + .build()) { + + for (InternalRow row : rows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readWithProjectionUsingIcebergReaderUnsafe(Blackhole blackhole) throws IOException { + try (CloseableIterable rows = Parquet.read(Files.localInput(dataFile)) + .project(PROJECTED_SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(PROJECTED_SCHEMA, type)) + .build()) { + + Iterable unsafeRows = Iterables.transform( + rows, + APPLY_PROJECTION.bind(SparkBenchmarkUtil.projection(PROJECTED_SCHEMA, PROJECTED_SCHEMA))::invoke); + + for (InternalRow row : unsafeRows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readWithProjectionUsingSparkReader(Blackhole blackhole) throws IOException { + StructType sparkSchema = SparkSchemaUtil.convert(PROJECTED_SCHEMA); + try (CloseableIterable rows = Parquet.read(Files.localInput(dataFile)) + .project(PROJECTED_SCHEMA) + .readSupport(new ParquetReadSupport()) + .set("org.apache.spark.sql.parquet.row.requested_schema", sparkSchema.json()) + .set("spark.sql.parquet.binaryAsString", "false") + .set("spark.sql.parquet.int96AsTimestamp", "false") + .callInit() + .build()) { + + for (InternalRow row : rows) { + blackhole.consume(row); + } + } + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetReadersNestedDataBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetReadersNestedDataBenchmark.java new file mode 100644 index 000000000000..e894f50ebfbc --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetReadersNestedDataBenchmark.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.data.parquet; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.common.DynMethods; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.SparkBenchmarkUtil; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkParquetReaders; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; +import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport; +import org.apache.spark.sql.types.StructType; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +/** + * A benchmark that evaluates the performance of reading nested Parquet data using + * Iceberg and Spark Parquet readers. 
+ * + * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=SparkParquetReadersNestedDataBenchmark + * -PjmhOutputPath=benchmark/spark-parquet-readers-nested-data-benchmark-result.txt + * + */ +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.SingleShotTime) +public class SparkParquetReadersNestedDataBenchmark { + + private static final DynMethods.UnboundMethod APPLY_PROJECTION = DynMethods.builder("apply") + .impl(UnsafeProjection.class, InternalRow.class) + .build(); + private static final Schema SCHEMA = new Schema( + required(0, "id", Types.LongType.get()), + optional(4, "nested", Types.StructType.of( + required(1, "col1", Types.StringType.get()), + required(2, "col2", Types.DoubleType.get()), + required(3, "col3", Types.LongType.get()) + )) + ); + private static final Schema PROJECTED_SCHEMA = new Schema( + optional(4, "nested", Types.StructType.of( + required(1, "col1", Types.StringType.get()) + )) + ); + private static final int NUM_RECORDS = 10000000; + private File dataFile; + + @Setup + public void setupBenchmark() throws IOException { + dataFile = File.createTempFile("parquet-nested-data-benchmark", ".parquet"); + dataFile.delete(); + List records = RandomData.generateList(SCHEMA, NUM_RECORDS, 0L); + try (FileAppender writer = Parquet.write(Files.localOutput(dataFile)) + .schema(SCHEMA) + .named("benchmark") + .build()) { + writer.addAll(records); + } + } + + @TearDown + public void tearDownBenchmark() { + if (dataFile != null) { + dataFile.delete(); + } + } + + @Benchmark + @Threads(1) + public void readUsingIcebergReader(Blackhole blackhole) throws IOException { + try (CloseableIterable rows = Parquet.read(Files.localInput(dataFile)) + .project(SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(SCHEMA, type)) + .build()) { + + for (InternalRow row : rows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readUsingIcebergReaderUnsafe(Blackhole blackhole) throws IOException { + try (CloseableIterable rows = Parquet.read(Files.localInput(dataFile)) + .project(SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(SCHEMA, type)) + .build()) { + + Iterable unsafeRows = Iterables.transform( + rows, + APPLY_PROJECTION.bind(SparkBenchmarkUtil.projection(SCHEMA, SCHEMA))::invoke); + + for (InternalRow row : unsafeRows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readUsingSparkReader(Blackhole blackhole) throws IOException { + StructType sparkSchema = SparkSchemaUtil.convert(SCHEMA); + try (CloseableIterable rows = Parquet.read(Files.localInput(dataFile)) + .project(SCHEMA) + .readSupport(new ParquetReadSupport()) + .set("org.apache.spark.sql.parquet.row.requested_schema", sparkSchema.json()) + .set("spark.sql.parquet.binaryAsString", "false") + .set("spark.sql.parquet.int96AsTimestamp", "false") + .callInit() + .build()) { + + for (InternalRow row : rows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readWithProjectionUsingIcebergReader(Blackhole blackhole) throws IOException { + try (CloseableIterable rows = Parquet.read(Files.localInput(dataFile)) + .project(PROJECTED_SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(PROJECTED_SCHEMA, type)) + .build()) { + + for (InternalRow row : rows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void 
readWithProjectionUsingIcebergReaderUnsafe(Blackhole blackhole) throws IOException { + try (CloseableIterable rows = Parquet.read(Files.localInput(dataFile)) + .project(PROJECTED_SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(PROJECTED_SCHEMA, type)) + .build()) { + + Iterable unsafeRows = Iterables.transform( + rows, + APPLY_PROJECTION.bind(SparkBenchmarkUtil.projection(PROJECTED_SCHEMA, PROJECTED_SCHEMA))::invoke); + + for (InternalRow row : unsafeRows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readWithProjectionUsingSparkReader(Blackhole blackhole) throws IOException { + StructType sparkSchema = SparkSchemaUtil.convert(PROJECTED_SCHEMA); + try (CloseableIterable rows = Parquet.read(Files.localInput(dataFile)) + .project(PROJECTED_SCHEMA) + .readSupport(new ParquetReadSupport()) + .set("org.apache.spark.sql.parquet.row.requested_schema", sparkSchema.json()) + .set("spark.sql.parquet.binaryAsString", "false") + .set("spark.sql.parquet.int96AsTimestamp", "false") + .callInit() + .build()) { + + for (InternalRow row : rows) { + blackhole.consume(row); + } + } + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersFlatDataBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersFlatDataBenchmark.java new file mode 100644 index 000000000000..5fe33de3d69e --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersFlatDataBenchmark.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
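Both the Spark-reader benchmarks above and the Spark-writer benchmarks that follow hand Spark a schema obtained from SparkSchemaUtil.convert(SCHEMA) and serialized with json(). A minimal sketch of that conversion for the flat schema (illustrative; the rendered type string is approximate):

    StructType sparkSchema = SparkSchemaUtil.convert(SCHEMA);
    // Expected to render roughly as:
    // struct<longCol:bigint,intCol:int,floatCol:float,doubleCol:double,
    //        decimalCol:decimal(20,5),dateCol:date,timestampCol:timestamp,stringCol:string>
    System.out.println(sparkSchema.catalogString());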
+ */ + +package org.apache.iceberg.spark.data.parquet; + +import java.io.File; +import java.io.IOException; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkParquetWriters; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.datasources.parquet.ParquetWriteSupport; +import org.apache.spark.sql.types.StructType; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +/** + * A benchmark that evaluates the performance of writing Parquet data with a flat schema using + * Iceberg and Spark Parquet writers. + * + * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=SparkParquetWritersFlatDataBenchmark + * -PjmhOutputPath=benchmark/spark-parquet-writers-flat-data-benchmark-result.txt + * + */ +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.SingleShotTime) +public class SparkParquetWritersFlatDataBenchmark { + + private static final Schema SCHEMA = new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "decimalCol", Types.DecimalType.of(20, 5)), + optional(6, "dateCol", Types.DateType.get()), + optional(7, "timestampCol", Types.TimestampType.withZone()), + optional(8, "stringCol", Types.StringType.get())); + private static final int NUM_RECORDS = 1000000; + private Iterable rows; + private File dataFile; + + @Setup + public void setupBenchmark() throws IOException { + rows = RandomData.generateSpark(SCHEMA, NUM_RECORDS, 0L); + dataFile = File.createTempFile("parquet-flat-data-benchmark", ".parquet"); + dataFile.delete(); + } + + @TearDown(Level.Iteration) + public void tearDownBenchmark() { + if (dataFile != null) { + dataFile.delete(); + } + } + + @Benchmark + @Threads(1) + public void writeUsingIcebergWriter() throws IOException { + try (FileAppender writer = Parquet.write(Files.localOutput(dataFile)) + .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(SparkSchemaUtil.convert(SCHEMA), msgType)) + .schema(SCHEMA) + .build()) { + + writer.addAll(rows); + } + } + + @Benchmark + @Threads(1) + public void writeUsingSparkWriter() throws IOException { + StructType sparkSchema = SparkSchemaUtil.convert(SCHEMA); + try (FileAppender writer = Parquet.write(Files.localOutput(dataFile)) + .writeSupport(new ParquetWriteSupport()) + .set("org.apache.spark.sql.parquet.row.attributes", sparkSchema.json()) + .set("spark.sql.parquet.writeLegacyFormat", 
"false") + .set("spark.sql.parquet.binaryAsString", "false") + .set("spark.sql.parquet.int96AsTimestamp", "false") + .set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS") + .schema(SCHEMA) + .build()) { + + writer.addAll(rows); + } + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersNestedDataBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersNestedDataBenchmark.java new file mode 100644 index 000000000000..0b591c2e2cf5 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersNestedDataBenchmark.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data.parquet; + +import java.io.File; +import java.io.IOException; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkParquetWriters; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.datasources.parquet.ParquetWriteSupport; +import org.apache.spark.sql.types.StructType; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +/** + * A benchmark that evaluates the performance of writing nested Parquet data using + * Iceberg and Spark Parquet writers. 
+ * + * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=SparkParquetWritersNestedDataBenchmark + * -PjmhOutputPath=benchmark/spark-parquet-writers-nested-data-benchmark-result.txt + * + */ +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.SingleShotTime) +public class SparkParquetWritersNestedDataBenchmark { + + private static final Schema SCHEMA = new Schema( + required(0, "id", Types.LongType.get()), + optional(4, "nested", Types.StructType.of( + required(1, "col1", Types.StringType.get()), + required(2, "col2", Types.DoubleType.get()), + required(3, "col3", Types.LongType.get()) + )) + ); + private static final int NUM_RECORDS = 1000000; + private Iterable rows; + private File dataFile; + + @Setup + public void setupBenchmark() throws IOException { + rows = RandomData.generateSpark(SCHEMA, NUM_RECORDS, 0L); + dataFile = File.createTempFile("parquet-nested-data-benchmark", ".parquet"); + dataFile.delete(); + } + + @TearDown(Level.Iteration) + public void tearDownBenchmark() { + if (dataFile != null) { + dataFile.delete(); + } + } + + @Benchmark + @Threads(1) + public void writeUsingIcebergWriter() throws IOException { + try (FileAppender writer = Parquet.write(Files.localOutput(dataFile)) + .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(SparkSchemaUtil.convert(SCHEMA), msgType)) + .schema(SCHEMA) + .build()) { + + writer.addAll(rows); + } + } + + @Benchmark + @Threads(1) + public void writeUsingSparkWriter() throws IOException { + StructType sparkSchema = SparkSchemaUtil.convert(SCHEMA); + try (FileAppender writer = Parquet.write(Files.localOutput(dataFile)) + .writeSupport(new ParquetWriteSupport()) + .set("org.apache.spark.sql.parquet.row.attributes", sparkSchema.json()) + .set("spark.sql.parquet.writeLegacyFormat", "false") + .set("spark.sql.parquet.binaryAsString", "false") + .set("spark.sql.parquet.int96AsTimestamp", "false") + .set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS") + .schema(SCHEMA) + .build()) { + + writer.addAll(rows); + } + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/Action.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/Action.java new file mode 100644 index 000000000000..1820a801b2fb --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/Action.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
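The Action interface defined next is a zero-argument callback consumed by the withSQLConf and withTableProperties helpers that IcebergSourceBenchmark adds further down in this patch, so call sites can pass plain lambdas. An illustrative usage sketch inside a benchmark subclass (the property key and value here are only examples):

    Map<String, String> props = Maps.newHashMap();
    props.put(TableProperties.SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024));
    withTableProperties(props, () -> {
      // the helper restores (or removes) the touched properties after the action runs
      Dataset<Row> df = spark().read().format("iceberg").load(table().location());
      materialize(df);
    });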
+ */ + +package org.apache.iceberg.spark.source; + +@FunctionalInterface +public interface Action { + void invoke(); +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceBenchmark.java new file mode 100644 index 000000000000..00addb8a327b --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceBenchmark.java @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.Map; +import java.util.UUID; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.UpdateProperties; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.StructType; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.SingleShotTime) +public abstract class IcebergSourceBenchmark { + + private final Configuration hadoopConf = initHadoopConf(); + private final Table table = initTable(); + private SparkSession spark; + + protected abstract Configuration initHadoopConf(); + + protected final Configuration hadoopConf() { + return hadoopConf; + } + + protected abstract Table initTable(); + + protected final Table table() { + return table; + } + + protected final SparkSession spark() { + return spark; + } + + protected String newTableLocation() { + String tmpDir = hadoopConf.get("hadoop.tmp.dir"); + Path tablePath = new Path(tmpDir, "spark-iceberg-table-" + UUID.randomUUID()); + return tablePath.toString(); + } + + protected String dataLocation() { + Map properties = table.properties(); + return properties.getOrDefault(TableProperties.WRITE_DATA_LOCATION, String.format("%s/data", table.location())); + } + + protected void cleanupFiles() throws IOException { + try (FileSystem fileSystem = FileSystem.get(hadoopConf)) { + Path dataPath = new 
Path(dataLocation()); + fileSystem.delete(dataPath, true); + Path tablePath = new Path(table.location()); + fileSystem.delete(tablePath, true); + } + } + + protected void setupSpark(boolean enableDictionaryEncoding) { + SparkSession.Builder builder = SparkSession.builder() + .config("spark.ui.enabled", false); + if (!enableDictionaryEncoding) { + builder.config("parquet.dictionary.page.size", "1") + .config("parquet.enable.dictionary", false) + .config(TableProperties.PARQUET_DICT_SIZE_BYTES, "1"); + } + builder.master("local"); + spark = builder.getOrCreate(); + Configuration sparkHadoopConf = spark.sessionState().newHadoopConf(); + hadoopConf.forEach(entry -> sparkHadoopConf.set(entry.getKey(), entry.getValue())); + } + + protected void setupSpark() { + setupSpark(false); + } + + protected void tearDownSpark() { + spark.stop(); + } + + protected void materialize(Dataset ds) { + ds.queryExecution().toRdd().toJavaRDD().foreach(record -> { }); + } + + protected void appendAsFile(Dataset ds) { + // ensure the schema is precise (including nullability) + StructType sparkSchema = SparkSchemaUtil.convert(table.schema()); + spark.createDataFrame(ds.rdd(), sparkSchema) + .coalesce(1) + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(table.location()); + } + + protected void withSQLConf(Map conf, Action action) { + SQLConf sqlConf = SQLConf.get(); + + Map currentConfValues = Maps.newHashMap(); + conf.keySet().forEach(confKey -> { + if (sqlConf.contains(confKey)) { + String currentConfValue = sqlConf.getConfString(confKey); + currentConfValues.put(confKey, currentConfValue); + } + }); + + conf.forEach((confKey, confValue) -> { + if (SQLConf.isStaticConfigKey(confKey)) { + throw new RuntimeException("Cannot modify the value of a static config: " + confKey); + } + sqlConf.setConfString(confKey, confValue); + }); + + try { + action.invoke(); + } finally { + conf.forEach((confKey, confValue) -> { + if (currentConfValues.containsKey(confKey)) { + sqlConf.setConfString(confKey, currentConfValues.get(confKey)); + } else { + sqlConf.unsetConf(confKey); + } + }); + } + } + + protected void withTableProperties(Map props, Action action) { + Map tableProps = table.properties(); + Map currentPropValues = Maps.newHashMap(); + props.keySet().forEach(propKey -> { + if (tableProps.containsKey(propKey)) { + String currentPropValue = tableProps.get(propKey); + currentPropValues.put(propKey, currentPropValue); + } + }); + + UpdateProperties updateProperties = table.updateProperties(); + props.forEach(updateProperties::set); + updateProperties.commit(); + + try { + action.invoke(); + } finally { + UpdateProperties restoreProperties = table.updateProperties(); + props.forEach((propKey, propValue) -> { + if (currentPropValues.containsKey(propKey)) { + restoreProperties.set(propKey, currentPropValues.get(propKey)); + } else { + restoreProperties.remove(propKey); + } + }); + restoreProperties.commit(); + } + } + + protected FileFormat fileFormat() { + throw new UnsupportedOperationException("Unsupported file format"); + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceDeleteBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceDeleteBenchmark.java new file mode 100644 index 000000000000..0cdf209889b7 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceDeleteBenchmark.java @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license 
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.RowDelta; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.ClusteredEqualityDeleteWriter; +import org.apache.iceberg.io.ClusteredPositionDeleteWriter; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +public abstract class IcebergSourceDeleteBenchmark extends IcebergSourceBenchmark { + private static final Logger LOG = LoggerFactory.getLogger(IcebergSourceDeleteBenchmark.class); + private static final long TARGET_FILE_SIZE_IN_BYTES = 512L * 1024 * 1024; + + protected static final int NUM_FILES = 1; + protected static final int NUM_ROWS = 10 * 1000 * 1000; + + @Setup + public void setupBenchmark() throws IOException { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) 
+ public void readIcebergVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + tableProperties.put(TableProperties.PARQUET_VECTORIZATION_ENABLED, "true"); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readIcebergWithIsDeletedColumn() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).filter("_deleted = false"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDeletedRows() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).filter("_deleted = true"); + materialize(df); + }); + } + + protected abstract void appendData() throws IOException; + + protected void writeData(int fileNum) { + Dataset df = spark().range(NUM_ROWS) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(MOD(longCol, 2147483647) AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("dateCol", date_add(current_date(), fileNum)) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")); + appendAsFile(df); + } + + @Override + protected Table initTable() { + Schema schema = new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(6, "dateCol", Types.DateType.get()), + optional(7, "timestampCol", Types.TimestampType.withZone()), + optional(8, "stringCol", Types.StringType.get())); + PartitionSpec partitionSpec = PartitionSpec.unpartitioned(); + HadoopTables tables = new HadoopTables(hadoopConf()); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + properties.put(TableProperties.FORMAT_VERSION, "2"); + return tables.create(schema, partitionSpec, properties, newTableLocation()); + } + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + protected void writePosDeletes(CharSequence path, long numRows, double percentage) throws IOException { + writePosDeletes(path, numRows, percentage, 1); + } + + protected void writePosDeletes(CharSequence path, long numRows, double percentage, + int numDeleteFile) throws IOException { + writePosDeletesWithNoise(path, numRows, percentage, 0, numDeleteFile); + } + + protected void writePosDeletesWithNoise(CharSequence path, long numRows, double percentage, int numNoise, + int numDeleteFile) throws IOException { + Set deletedPos = Sets.newHashSet(); + while (deletedPos.size() < numRows * percentage) { + deletedPos.add(ThreadLocalRandom.current().nextLong(numRows)); + } + LOG.info("pos delete row count: {}, num of delete files: {}", deletedPos.size(), numDeleteFile); + + int partitionSize = (int) (numRows * 
percentage) / numDeleteFile; + Iterable> sets = Iterables.partition(deletedPos, partitionSize); + for (List item : sets) { + writePosDeletes(path, item, numNoise); + } + } + + protected void writePosDeletes(CharSequence path, List deletedPos, int numNoise) throws IOException { + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = SparkFileWriterFactory.builderFor(table()) + .dataFileFormat(fileFormat()) + .build(); + + ClusteredPositionDeleteWriter writer = new ClusteredPositionDeleteWriter<>( + writerFactory, fileFactory, table().io(), TARGET_FILE_SIZE_IN_BYTES); + + PartitionSpec unpartitionedSpec = table().specs().get(0); + + PositionDelete positionDelete = PositionDelete.create(); + try (ClusteredPositionDeleteWriter closeableWriter = writer) { + for (Long pos : deletedPos) { + positionDelete.set(path, pos, null); + closeableWriter.write(positionDelete, unpartitionedSpec, null); + for (int i = 0; i < numNoise; i++) { + positionDelete.set(noisePath(path), pos, null); + closeableWriter.write(positionDelete, unpartitionedSpec, null); + } + } + } + + RowDelta rowDelta = table().newRowDelta(); + writer.result().deleteFiles().forEach(rowDelta::addDeletes); + rowDelta.validateDeletedFiles().commit(); + } + + protected void writeEqDeletes(long numRows, double percentage) throws IOException { + Set deletedValues = Sets.newHashSet(); + while (deletedValues.size() < numRows * percentage) { + deletedValues.add(ThreadLocalRandom.current().nextLong(numRows)); + } + + List rows = Lists.newArrayList(); + for (Long value : deletedValues) { + GenericInternalRow genericInternalRow = new GenericInternalRow(7); + genericInternalRow.setLong(0, value); + genericInternalRow.setInt(1, (int) (value % Integer.MAX_VALUE)); + genericInternalRow.setFloat(2, (float) value); + genericInternalRow.setNullAt(3); + genericInternalRow.setNullAt(4); + genericInternalRow.setNullAt(5); + genericInternalRow.setNullAt(6); + rows.add(genericInternalRow); + } + LOG.info("Num of equality deleted rows: {}", rows.size()); + + writeEqDeletes(rows); + } + + private void writeEqDeletes(List rows) throws IOException { + int equalityFieldId = table().schema().findField("longCol").fieldId(); + + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = SparkFileWriterFactory + .builderFor(table()) + .dataFileFormat(fileFormat()) + .equalityDeleteRowSchema(table().schema()) + .equalityFieldIds(new int[]{equalityFieldId}) + .build(); + + ClusteredEqualityDeleteWriter writer = new ClusteredEqualityDeleteWriter<>( + writerFactory, fileFactory, table().io(), TARGET_FILE_SIZE_IN_BYTES); + + PartitionSpec unpartitionedSpec = table().specs().get(0); + try (ClusteredEqualityDeleteWriter closeableWriter = writer) { + for (InternalRow row : rows) { + closeableWriter.write(row, unpartitionedSpec, null); + } + } + + RowDelta rowDelta = table().newRowDelta(); + LOG.info("Num of Delete File: {}", writer.result().deleteFiles().size()); + writer.result().deleteFiles().forEach(rowDelta::addDeletes); + rowDelta.validateDeletedFiles().commit(); + } + + private OutputFileFactory newFileFactory() { + return OutputFileFactory.builderFor(table(), 1, 1) + .format(fileFormat()) + .build(); + } + + private CharSequence noisePath(CharSequence path) { + // assume the data file name would be something like "00000-0-30da64e0-56b5-4743-a11b-3188a1695bf7-00001.parquet" + // so the dataFileSuffixLen is the UUID string length + length of "-00001.parquet", which is 36 + 14 = 60. 
It's OK + // to be not accurate here. + int dataFileSuffixLen = 60; + UUID uuid = UUID.randomUUID(); + if (path.length() > dataFileSuffixLen) { + return path.subSequence(0, dataFileSuffixLen) + uuid.toString(); + } else { + return uuid.toString(); + } + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceFlatDataBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceFlatDataBenchmark.java new file mode 100644 index 000000000000..9e206321a540 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceFlatDataBenchmark.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Types; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +public abstract class IcebergSourceFlatDataBenchmark extends IcebergSourceBenchmark { + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + @Override + protected final Table initTable() { + Schema schema = new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "decimalCol", Types.DecimalType.of(20, 5)), + optional(6, "dateCol", Types.DateType.get()), + optional(7, "timestampCol", Types.TimestampType.withZone()), + optional(8, "stringCol", Types.StringType.get())); + PartitionSpec partitionSpec = PartitionSpec.unpartitioned(); + HadoopTables tables = new HadoopTables(hadoopConf()); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + return tables.create(schema, partitionSpec, properties, newTableLocation()); + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceNestedDataBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceNestedDataBenchmark.java new file mode 100644 index 000000000000..5a0d9359ec6b --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceNestedDataBenchmark.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Types; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +public abstract class IcebergSourceNestedDataBenchmark extends IcebergSourceBenchmark { + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + @Override + protected final Table initTable() { + Schema schema = new Schema( + required(0, "id", Types.LongType.get()), + optional(4, "nested", Types.StructType.of( + required(1, "col1", Types.StringType.get()), + required(2, "col2", Types.DoubleType.get()), + required(3, "col3", Types.LongType.get()) + )) + ); + PartitionSpec partitionSpec = PartitionSpec.unpartitioned(); + HadoopTables tables = new HadoopTables(hadoopConf()); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + return tables.create(schema, partitionSpec, properties, newTableLocation()); + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceNestedListDataBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceNestedListDataBenchmark.java new file mode 100644 index 000000000000..369a1507b648 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceNestedListDataBenchmark.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Types; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +public abstract class IcebergSourceNestedListDataBenchmark extends IcebergSourceBenchmark { + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + @Override + protected final Table initTable() { + Schema schema = new Schema( + required(0, "id", Types.LongType.get()), + optional(1, "outerlist", Types.ListType.ofOptional(2, + Types.StructType.of(required(3, "innerlist", Types.ListType.ofRequired(4, Types.StringType.get()))) + )) + ); + PartitionSpec partitionSpec = PartitionSpec.unpartitioned(); + HadoopTables tables = new HadoopTables(hadoopConf()); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + return tables.create(schema, partitionSpec, properties, newTableLocation()); + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/WritersBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/WritersBenchmark.java new file mode 100644 index 000000000000..acec471bdfd1 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/WritersBenchmark.java @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
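WritersBenchmark, whose body follows, compares clustered, fanout and legacy writers against both an unpartitioned spec and a bucket(intCol, 32) partitioned spec. Clustered writers expect incoming rows already grouped by partition, which is why setupBenchmark sorts the generated rows by the bucket transform; a sketch of the partition value that sort key corresponds to (illustrative only):

    Transform<Integer, Integer> bucket = Transforms.bucket(Types.IntegerType.get(), 32);
    int partitionValue = bucket.apply(42);  // same key the sort derives per row via row.getInt(1)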
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionKey; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.ClusteredDataWriter; +import org.apache.iceberg.io.ClusteredEqualityDeleteWriter; +import org.apache.iceberg.io.ClusteredPositionDeleteWriter; +import org.apache.iceberg.io.DeleteSchemaUtil; +import org.apache.iceberg.io.FanoutDataWriter; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.io.TaskWriter; +import org.apache.iceberg.io.UnpartitionedWriter; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.transforms.Transform; +import org.apache.iceberg.transforms.Transforms; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.infra.Blackhole; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +public abstract class WritersBenchmark extends IcebergSourceBenchmark { + + private static final int NUM_ROWS = 2500000; + private static final long TARGET_FILE_SIZE_IN_BYTES = 50L * 1024 * 1024; + + private static final Schema SCHEMA = new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "decimalCol", Types.DecimalType.of(20, 5)), + optional(6, "timestampCol", Types.TimestampType.withZone()), + optional(7, "stringCol", Types.StringType.get()) + ); + + private Iterable rows; + private Iterable positionDeleteRows; + private PartitionSpec unpartitionedSpec; + private PartitionSpec partitionedSpec; + + @Override + protected abstract FileFormat fileFormat(); + + @Setup + public void setupBenchmark() { + setupSpark(); + + List data = Lists.newArrayList(RandomData.generateSpark(SCHEMA, NUM_ROWS, 0L)); + Transform transform = Transforms.bucket(Types.IntegerType.get(), 32); + data.sort(Comparator.comparingInt(row -> transform.apply(row.getInt(1)))); + this.rows = data; + + this.positionDeleteRows = RandomData.generateSpark(DeleteSchemaUtil.pathPosSchema(), NUM_ROWS, 0L); + + this.unpartitionedSpec = table().specs().get(0); + Preconditions.checkArgument(unpartitionedSpec.isUnpartitioned()); + this.partitionedSpec = table().specs().get(1); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + @Override + 
protected final Table initTable() { + HadoopTables tables = new HadoopTables(hadoopConf()); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map properties = Maps.newHashMap(); + Table table = tables.create(SCHEMA, spec, properties, newTableLocation()); + + // add a partitioned spec to the table + table.updateSpec() + .addField(Expressions.bucket("intCol", 32)) + .commit(); + + return table; + } + + @Benchmark + @Threads(1) + public void writeUnpartitionedClusteredDataWriter(Blackhole blackhole) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = SparkFileWriterFactory.builderFor(table()) + .dataFileFormat(fileFormat()) + .dataSchema(table().schema()) + .build(); + + ClusteredDataWriter writer = new ClusteredDataWriter<>( + writerFactory, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES); + + try (ClusteredDataWriter closeableWriter = writer) { + for (InternalRow row : rows) { + closeableWriter.write(row, unpartitionedSpec, null); + } + } + + blackhole.consume(writer); + } + + @Benchmark + @Threads(1) + public void writeUnpartitionedLegacyDataWriter(Blackhole blackhole) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + + Schema writeSchema = table().schema(); + StructType sparkWriteType = SparkSchemaUtil.convert(writeSchema); + SparkAppenderFactory appenders = SparkAppenderFactory.builderFor(table(), writeSchema, sparkWriteType) + .spec(unpartitionedSpec) + .build(); + + TaskWriter writer = new UnpartitionedWriter<>( + unpartitionedSpec, fileFormat(), appenders, + fileFactory, io, TARGET_FILE_SIZE_IN_BYTES); + + try (TaskWriter closableWriter = writer) { + for (InternalRow row : rows) { + closableWriter.write(row); + } + } + + blackhole.consume(writer.complete()); + } + + @Benchmark + @Threads(1) + public void writePartitionedClusteredDataWriter(Blackhole blackhole) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = SparkFileWriterFactory.builderFor(table()) + .dataFileFormat(fileFormat()) + .dataSchema(table().schema()) + .build(); + + ClusteredDataWriter writer = new ClusteredDataWriter<>( + writerFactory, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES); + + PartitionKey partitionKey = new PartitionKey(partitionedSpec, table().schema()); + StructType dataSparkType = SparkSchemaUtil.convert(table().schema()); + InternalRowWrapper internalRowWrapper = new InternalRowWrapper(dataSparkType); + + try (ClusteredDataWriter closeableWriter = writer) { + for (InternalRow row : rows) { + partitionKey.partition(internalRowWrapper.wrap(row)); + closeableWriter.write(row, partitionedSpec, partitionKey); + } + } + + blackhole.consume(writer); + } + + @Benchmark + @Threads(1) + public void writePartitionedLegacyDataWriter(Blackhole blackhole) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + + Schema writeSchema = table().schema(); + StructType sparkWriteType = SparkSchemaUtil.convert(writeSchema); + SparkAppenderFactory appenders = SparkAppenderFactory.builderFor(table(), writeSchema, sparkWriteType) + .spec(partitionedSpec) + .build(); + + TaskWriter writer = new SparkPartitionedWriter( + partitionedSpec, fileFormat(), appenders, + fileFactory, io, TARGET_FILE_SIZE_IN_BYTES, + writeSchema, sparkWriteType); + + try (TaskWriter closableWriter = writer) { + for (InternalRow row : rows) { + closableWriter.write(row); 
+ } + } + + blackhole.consume(writer.complete()); + } + + @Benchmark + @Threads(1) + public void writePartitionedFanoutDataWriter(Blackhole blackhole) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = SparkFileWriterFactory.builderFor(table()) + .dataFileFormat(fileFormat()) + .dataSchema(table().schema()) + .build(); + + FanoutDataWriter writer = new FanoutDataWriter<>( + writerFactory, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES); + + PartitionKey partitionKey = new PartitionKey(partitionedSpec, table().schema()); + StructType dataSparkType = SparkSchemaUtil.convert(table().schema()); + InternalRowWrapper internalRowWrapper = new InternalRowWrapper(dataSparkType); + + try (FanoutDataWriter closeableWriter = writer) { + for (InternalRow row : rows) { + partitionKey.partition(internalRowWrapper.wrap(row)); + closeableWriter.write(row, partitionedSpec, partitionKey); + } + } + + blackhole.consume(writer); + } + + @Benchmark + @Threads(1) + public void writePartitionedLegacyFanoutDataWriter(Blackhole blackhole) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + + Schema writeSchema = table().schema(); + StructType sparkWriteType = SparkSchemaUtil.convert(writeSchema); + SparkAppenderFactory appenders = SparkAppenderFactory.builderFor(table(), writeSchema, sparkWriteType) + .spec(partitionedSpec) + .build(); + + TaskWriter writer = new SparkPartitionedFanoutWriter( + partitionedSpec, fileFormat(), appenders, + fileFactory, io, TARGET_FILE_SIZE_IN_BYTES, + writeSchema, sparkWriteType); + + try (TaskWriter closableWriter = writer) { + for (InternalRow row : rows) { + closableWriter.write(row); + } + } + + blackhole.consume(writer.complete()); + } + + @Benchmark + @Threads(1) + public void writePartitionedClusteredEqualityDeleteWriter(Blackhole blackhole) throws IOException { + FileIO io = table().io(); + + int equalityFieldId = table().schema().findField("longCol").fieldId(); + + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = SparkFileWriterFactory.builderFor(table()) + .dataFileFormat(fileFormat()) + .equalityDeleteRowSchema(table().schema()) + .equalityFieldIds(new int[]{equalityFieldId}) + .build(); + + ClusteredEqualityDeleteWriter writer = new ClusteredEqualityDeleteWriter<>( + writerFactory, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES); + + PartitionKey partitionKey = new PartitionKey(partitionedSpec, table().schema()); + StructType deleteSparkType = SparkSchemaUtil.convert(table().schema()); + InternalRowWrapper internalRowWrapper = new InternalRowWrapper(deleteSparkType); + + try (ClusteredEqualityDeleteWriter closeableWriter = writer) { + for (InternalRow row : rows) { + partitionKey.partition(internalRowWrapper.wrap(row)); + closeableWriter.write(row, partitionedSpec, partitionKey); + } + } + + blackhole.consume(writer); + } + + @Benchmark + @Threads(1) + public void writeUnpartitionedClusteredPositionDeleteWriter(Blackhole blackhole) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = SparkFileWriterFactory.builderFor(table()) + .dataFileFormat(fileFormat()) + .build(); + + ClusteredPositionDeleteWriter writer = new ClusteredPositionDeleteWriter<>( + writerFactory, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES); + + PositionDelete positionDelete = PositionDelete.create(); + try (ClusteredPositionDeleteWriter 
closeableWriter = writer) { + for (InternalRow row : positionDeleteRows) { + String path = row.getString(0); + long pos = row.getLong(1); + positionDelete.set(path, pos, null); + closeableWriter.write(positionDelete, unpartitionedSpec, null); + } + } + + blackhole.consume(writer); + } + + private OutputFileFactory newFileFactory() { + return OutputFileFactory.builderFor(table(), 1, 1) + .format(fileFormat()) + .build(); + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/AvroWritersBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/AvroWritersBenchmark.java new file mode 100644 index 000000000000..868edcc90517 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/AvroWritersBenchmark.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source.avro; + +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.spark.source.WritersBenchmark; + +/** + * A benchmark that evaluates the performance of various Iceberg writers for Avro data. + *
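A note on the writer variants exercised above: ClusteredDataWriter assumes incoming rows are already grouped by partition, keeps only one data file open at a time, and rejects a partition that reappears after its file has been closed, whereas FanoutDataWriter accepts rows in any order by keeping one open file per partition it has seen, at the cost of more open handles and memory. The following is only a minimal sketch of that difference, reusing the fields set up in WritersBenchmark above; it assumes it runs inside a method that declares throws IOException and that rows are ordered by the bucket partition for the clustered case.

  // Sketch only: contrasts the clustered and fanout data writers used in the benchmark above.
  PartitionKey partitionKey = new PartitionKey(partitionedSpec, table().schema());
  InternalRowWrapper wrapper = new InternalRowWrapper(SparkSchemaUtil.convert(table().schema()));

  // Clustered writer: rows must arrive grouped by partition; one open file at a time.
  try (ClusteredDataWriter<InternalRow> clustered =
      new ClusteredDataWriter<>(writerFactory, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES)) {
    for (InternalRow row : rows) {            // assumed sorted by bucket("intCol")
      partitionKey.partition(wrapper.wrap(row));
      clustered.write(row, partitionedSpec, partitionKey);
    }
  }

  // Fanout writer: tolerates unordered rows by keeping one open file per partition.
  try (FanoutDataWriter<InternalRow> fanout =
      new FanoutDataWriter<>(writerFactory, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES)) {
    for (InternalRow row : rows) {            // any row order works
      partitionKey.partition(wrapper.wrap(row));
      fanout.write(row, partitionedSpec, partitionKey);
    }
  }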
+ * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=AvroWritersBenchmark + * -PjmhOutputPath=benchmark/avro-writers-benchmark-result.txt + * + */ +public class AvroWritersBenchmark extends WritersBenchmark { + + @Override + protected FileFormat fileFormat() { + return FileFormat.AVRO; + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/IcebergSourceFlatAvroDataReadBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/IcebergSourceFlatAvroDataReadBenchmark.java new file mode 100644 index 000000000000..14d39fcad750 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/IcebergSourceFlatAvroDataReadBenchmark.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source.avro; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceFlatDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +/** + * A benchmark that evaluates the performance of reading Avro data with a flat schema + * using Iceberg and the built-in file source in Spark. + *
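The read benchmarks in this and the following classes scope their settings with two helpers inherited from the benchmark base class, withTableProperties(...) and withSQLConf(...), which apply the given overrides only around the measured action and then restore the previous values. The helper itself is not part of this diff; the sketch below only illustrates that save-apply-restore idea for the SQL-conf case, with hypothetical names, and is not the actual base-class code.

  // Illustrative only: a scoped SQL-conf override in the spirit of withSQLConf(...).
  static void withScopedSqlConf(SparkSession spark, Map<String, String> overrides, Runnable action) {
    Map<String, String> previous = Maps.newHashMap();
    overrides.forEach((key, value) -> {
      if (spark.conf().contains(key)) {
        previous.put(key, spark.conf().get(key));   // remember the current value
      }
      spark.conf().set(key, value);
    });
    try {
      action.run();
    } finally {
      overrides.keySet().forEach(key -> {
        if (previous.containsKey(key)) {
          spark.conf().set(key, previous.get(key)); // restore what was set before
        } else {
          spark.conf().unset(key);
        }
      });
    }
  }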
+ * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceFlatAvroDataReadBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-flat-avro-data-read-benchmark-result.txt + * + */ +public class IcebergSourceFlatAvroDataReadBenchmark extends IcebergSourceFlatDataBenchmark { + + private static final int NUM_FILES = 10; + private static final int NUM_ROWS = 1000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSource() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf(conf, () -> { + Dataset df = spark().read().format("avro").load(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).select("longCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSource() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf(conf, () -> { + Dataset df = spark().read().format("avro").load(dataLocation()).select("longCol"); + materialize(df); + }); + } + + private void appendData() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(DEFAULT_FILE_FORMAT, "avro"); + withTableProperties(tableProperties, () -> { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset df = spark().range(NUM_ROWS) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(20, 5))")) + .withColumn("dateCol", date_add(current_date(), fileNum)) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")); + appendAsFile(df); + } + }); + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/IcebergSourceNestedAvroDataReadBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/IcebergSourceNestedAvroDataReadBenchmark.java new file mode 100644 index 000000000000..e9e492717cc3 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/IcebergSourceNestedAvroDataReadBenchmark.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source.avro; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceNestedDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.struct; + + +/** + * A benchmark that evaluates the performance of reading Avro data with a nested schema + * using Iceberg and the built-in file source in Spark. + *
+ * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceNestedAvroDataReadBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-avro-data-read-benchmark-result.txt + * + */ +public class IcebergSourceNestedAvroDataReadBenchmark extends IcebergSourceNestedDataBenchmark { + + private static final int NUM_FILES = 10; + private static final int NUM_ROWS = 1000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSource() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf(conf, () -> { + Dataset df = spark().read().format("avro").load(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).select("nested.col3"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSource() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf(conf, () -> { + Dataset df = spark().read().format("avro").load(dataLocation()).select("nested.col3"); + materialize(df); + }); + } + + private void appendData() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(DEFAULT_FILE_FORMAT, "avro"); + withTableProperties(tableProperties, () -> { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset df = spark().range(NUM_ROWS) + .withColumn( + "nested", + struct( + expr("CAST(id AS string) AS col1"), + expr("CAST(id AS double) AS col2"), + lit(fileNum).cast("long").as("col3") + )); + appendAsFile(df); + } + }); + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataBenchmark.java new file mode 100644 index 000000000000..329c9ffe7738 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataBenchmark.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source.orc; + +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceBenchmark; +import org.apache.iceberg.types.Types; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + + +/** + * Same as {@link org.apache.iceberg.spark.source.IcebergSourceFlatDataBenchmark} but we disable the Timestamp with + * zone type for ORC performance tests as Spark native reader does not support ORC's TIMESTAMP_INSTANT type + */ +public abstract class IcebergSourceFlatORCDataBenchmark extends IcebergSourceBenchmark { + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + @Override + protected final Table initTable() { + Schema schema = new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "decimalCol", Types.DecimalType.of(20, 5)), + optional(6, "dateCol", Types.DateType.get()), + // Disable timestamp column for ORC performance tests as Spark native reader does not support ORC's + // TIMESTAMP_INSTANT type + // optional(7, "timestampCol", Types.TimestampType.withZone()), + optional(8, "stringCol", Types.StringType.get())); + PartitionSpec partitionSpec = PartitionSpec.unpartitioned(); + HadoopTables tables = new HadoopTables(hadoopConf()); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + return tables.create(schema, partitionSpec, properties, newTableLocation()); + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataReadBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataReadBenchmark.java new file mode 100644 index 000000000000..46aa8f83999d --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataReadBenchmark.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source.orc; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +/** + * A benchmark that evaluates the performance of reading ORC data with a flat schema + * using Iceberg and the built-in file source in Spark. + *
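Two separate switches are being compared in the vectorized variants of this class: Iceberg's own vectorized ORC reader is turned on per scan through the SparkReadOptions.VECTORIZATION_ENABLED read option, while the built-in ORC source is vectorized through the spark.sql.orc.enableVectorizedReader session conf, which the benchmark scopes with withSQLConf. Condensed from the methods in this class, the two read paths are roughly as follows (a sketch, not additional benchmark code):

  String tableLocation = table().location();

  // Iceberg scan with Iceberg's vectorized ORC reader enabled for this read only.
  Dataset<Row> icebergVectorized = spark().read()
      .option(SparkReadOptions.VECTORIZATION_ENABLED, "true")
      .format("iceberg")
      .load(tableLocation);

  // Built-in ORC source; vectorization is a session-level SQL conf instead
  // (set globally here for brevity, scoped via withSQLConf in the benchmark).
  spark().conf().set(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "true");
  Dataset<Row> fileSourceVectorized = spark().read().orc(dataLocation());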
+ * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceFlatORCDataReadBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-flat-orc-data-read-benchmark-result.txt + * + */ +public class IcebergSourceFlatORCDataReadBenchmark extends IcebergSourceFlatORCDataBenchmark { + + private static final int NUM_FILES = 10; + private static final int NUM_ROWS = 1000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readIcebergNonVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readIcebergVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().option(SparkReadOptions.VECTORIZATION_ENABLED, "true") + .format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf(conf, () -> { + Dataset df = spark().read().orc(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf(conf, () -> { + Dataset df = spark().read().orc(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIcebergNonVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).select("longCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIcebergVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().option(SparkReadOptions.VECTORIZATION_ENABLED, "true") + .format("iceberg").load(tableLocation).select("longCol"); + materialize(df); + }); + } + + + @Benchmark + @Threads(1) + public void readWithProjectionFileSourceVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf(conf, () -> { + Dataset df = spark().read().orc(dataLocation()).select("longCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + 
public void readWithProjectionFileSourceNonVectorized() { + Map<String, String> conf = Maps.newHashMap(); + conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf(conf, () -> { + Dataset<Row> df = spark().read().orc(dataLocation()).select("longCol"); + materialize(df); + }); + } + + private void appendData() { + Map<String, String> tableProperties = Maps.newHashMap(); + tableProperties.put(DEFAULT_FILE_FORMAT, "orc"); + withTableProperties(tableProperties, () -> { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset<Row> df = spark().range(NUM_ROWS) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(20, 5))")) + .withColumn("dateCol", date_add(current_date(), fileNum)) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")); + appendAsFile(df); + } + }); + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedListORCDataWriteBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedListORCDataWriteBenchmark.java new file mode 100644 index 000000000000..f4edce20ec99 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedListORCDataWriteBenchmark.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source.orc; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceNestedListDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +import static org.apache.spark.sql.functions.array_repeat; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.struct; + +/** + * A benchmark that evaluates the performance of writing nested ORC data using Iceberg + * and the built-in file source in Spark. + *
+ * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceNestedListORCDataWriteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-list-orc-data-write-benchmark-result.txt + * + */ +public class IcebergSourceNestedListORCDataWriteBenchmark extends IcebergSourceNestedListDataBenchmark { + + @Setup + public void setupBenchmark() { + setupSpark(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Param({"2000", "20000"}) + private int numRows; + + @Benchmark + @Threads(1) + public void writeIceberg() { + String tableLocation = table().location(); + benchmarkData().write().format("iceberg").option("write-format", "orc") + .mode(SaveMode.Append).save(tableLocation); + } + + @Benchmark + @Threads(1) + public void writeIcebergDictionaryOff() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put("orc.dictionary.key.threshold", "0"); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + benchmarkData().write().format("iceberg").option("write-format", "orc") + .mode(SaveMode.Append).save(tableLocation); + }); + } + + @Benchmark + @Threads(1) + public void writeFileSource() { + benchmarkData().write().mode(SaveMode.Append).orc(dataLocation()); + } + + private Dataset benchmarkData() { + return spark().range(numRows) + .withColumn("outerlist", array_repeat(struct( + expr("array_repeat(CAST(id AS string), 1000) AS innerlist")), + 10)) + .coalesce(1); + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedORCDataReadBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedORCDataReadBenchmark.java new file mode 100644 index 000000000000..dd1122c3abd6 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedORCDataReadBenchmark.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source.orc; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.source.IcebergSourceNestedDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.struct; + + +/** + * A benchmark that evaluates the performance of reading ORC data with a nested schema + * using Iceberg and the built-in file source in Spark. + *
+ * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceNestedORCDataReadBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-orc-data-read-benchmark-result.txt + * + */ +public class IcebergSourceNestedORCDataReadBenchmark extends IcebergSourceNestedDataBenchmark { + + private static final int NUM_FILES = 10; + private static final int NUM_ROWS = 1000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readIcebergNonVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readIcebergVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().option(SparkReadOptions.VECTORIZATION_ENABLED, "true") + .format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf(conf, () -> { + Dataset df = spark().read().orc(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIcebergNonVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).selectExpr("nested.col3"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIcebergVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().option(SparkReadOptions.VECTORIZATION_ENABLED, "true") + .format("iceberg").load(tableLocation).selectExpr("nested.col3"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf(conf, () -> { + Dataset df = spark().read().orc(dataLocation()).selectExpr("nested.col3"); + materialize(df); + }); + } + + private void appendData() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(DEFAULT_FILE_FORMAT, "orc"); + withTableProperties(tableProperties, () -> { + for (int fileNum = 0; fileNum < NUM_FILES; fileNum++) { + Dataset df = spark().range(NUM_ROWS) + .withColumn( + "nested", + struct( + expr("CAST(id AS string) AS col1"), + expr("CAST(id AS 
double) AS col2"), + lit(fileNum).cast("long").as("col3") + )); + appendAsFile(df); + } + }); + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataFilterBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataFilterBenchmark.java new file mode 100644 index 000000000000..885f9b4f0aa6 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataFilterBenchmark.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source.parquet; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceFlatDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +/** + * A benchmark that evaluates the file skipping capabilities in the Spark data source for Iceberg. + * + * This class uses a dataset with a flat schema, where the records are clustered according to the + * column used in the filter predicate. + * + * The performance is compared to the built-in file source in Spark. 
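Because appendData() writes one file per fileNum and gives every row in that file the same dateCol value, the filter dateCol == date_add(current_date(), 1) matches at most one of the roughly 500 data files. Iceberg records per-file column bounds in its manifests, so the scan can be planned over just that file, which is the file-skipping effect this benchmark measures against the plain Parquet source. A quick way to eyeball the difference outside JMH, as a sketch that reuses the benchmark's locations and filter:

  String filterCond = "dateCol == date_add(current_date(), 1)";

  // Iceberg: file skipping happens at planning time using min/max column stats.
  Dataset<Row> iceberg = spark().read().format("iceberg")
      .load(table().location())
      .filter(filterCond);
  iceberg.explain(true);   // the scan should plan a single matching data file

  // Built-in Parquet source: every file under dataLocation() is listed and opened,
  // with only footer/row-group level filtering available inside each file.
  Dataset<Row> fileSource = spark().read().parquet(dataLocation()).filter(filterCond);
  fileSource.explain(true);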
+ * + * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceFlatParquetDataFilterBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-flat-parquet-data-filter-benchmark-result.txt + * + */ +public class IcebergSourceFlatParquetDataFilterBenchmark extends IcebergSourceFlatDataBenchmark { + + private static final String FILTER_COND = "dateCol == date_add(current_date(), 1)"; + private static final int NUM_FILES = 500; + private static final int NUM_ROWS = 10000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readWithFilterIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).filter(FILTER_COND); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithFilterFileSourceVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf(conf, () -> { + Dataset df = spark().read().parquet(dataLocation()).filter(FILTER_COND); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithFilterFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf(conf, () -> { + Dataset df = spark().read().parquet(dataLocation()).filter(FILTER_COND); + materialize(df); + }); + } + + private void appendData() { + for (int fileNum = 1; fileNum < NUM_FILES; fileNum++) { + Dataset df = spark().range(NUM_ROWS) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(20, 5))")) + .withColumn("dateCol", date_add(current_date(), fileNum)) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")); + appendAsFile(df); + } + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataReadBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataReadBenchmark.java new file mode 100644 index 000000000000..bbbdd288d193 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataReadBenchmark.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source.parquet; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceFlatDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +/** + * A benchmark that evaluates the performance of reading Parquet data with a flat schema + * using Iceberg and the built-in file source in Spark. + * + * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceFlatParquetDataReadBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-flat-parquet-data-read-benchmark-result.txt + * + */ +public class IcebergSourceFlatParquetDataReadBenchmark extends IcebergSourceFlatDataBenchmark { + + private static final int NUM_FILES = 10; + private static final int NUM_ROWS = 1000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf(conf, () -> { + Dataset df = spark().read().parquet(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf(conf, () -> { + Dataset df = spark().read().parquet(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).select("longCol"); + 
materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSourceVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf(conf, () -> { + Dataset df = spark().read().parquet(dataLocation()).select("longCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf(conf, () -> { + Dataset df = spark().read().parquet(dataLocation()).select("longCol"); + materialize(df); + }); + } + + private void appendData() { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset df = spark().range(NUM_ROWS) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(20, 5))")) + .withColumn("dateCol", date_add(current_date(), fileNum)) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")); + appendAsFile(df); + } + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataWriteBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataWriteBenchmark.java new file mode 100644 index 000000000000..dd3f5080f5cd --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataWriteBenchmark.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source.parquet; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceFlatDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +import static org.apache.spark.sql.functions.expr; + +/** + * A benchmark that evaluates the performance of writing Parquet data with a flat schema + * using Iceberg and the built-in file source in Spark. 
+ * + * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceFlatParquetDataWriteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-flat-parquet-data-write-benchmark-result.txt + * + */ +public class IcebergSourceFlatParquetDataWriteBenchmark extends IcebergSourceFlatDataBenchmark { + + private static final int NUM_ROWS = 5000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void writeIceberg() { + String tableLocation = table().location(); + benchmarkData().write().format("iceberg").mode(SaveMode.Append).save(tableLocation); + } + + @Benchmark + @Threads(1) + public void writeFileSource() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_COMPRESSION().key(), "gzip"); + withSQLConf(conf, () -> benchmarkData().write().mode(SaveMode.Append).parquet(dataLocation())); + } + + private Dataset benchmarkData() { + return spark().range(NUM_ROWS) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(20, 5))")) + .withColumn("dateCol", expr("DATE_ADD(CURRENT_DATE(), (longCol % 20))")) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")) + .coalesce(1); + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedListParquetDataWriteBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedListParquetDataWriteBenchmark.java new file mode 100644 index 000000000000..a06033944eda --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedListParquetDataWriteBenchmark.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source.parquet; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceNestedListDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +import static org.apache.spark.sql.functions.array_repeat; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.struct; + +/** + * A benchmark that evaluates the performance of writing nested Parquet data using Iceberg + * and the built-in file source in Spark. + * + * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceNestedListParquetDataWriteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-list-parquet-data-write-benchmark-result.txt + * + */ +public class IcebergSourceNestedListParquetDataWriteBenchmark extends IcebergSourceNestedListDataBenchmark { + + @Setup + public void setupBenchmark() { + setupSpark(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Param({"2000", "20000"}) + private int numRows; + + @Benchmark + @Threads(1) + public void writeIceberg() { + String tableLocation = table().location(); + benchmarkData().write().format("iceberg").mode(SaveMode.Append).save(tableLocation); + } + + @Benchmark + @Threads(1) + public void writeFileSource() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_COMPRESSION().key(), "gzip"); + withSQLConf(conf, () -> benchmarkData().write().mode(SaveMode.Append).parquet(dataLocation())); + } + + private Dataset benchmarkData() { + return spark().range(numRows) + .withColumn("outerlist", array_repeat(struct( + expr("array_repeat(CAST(id AS string), 1000) AS innerlist")), + 10)) + .coalesce(1); + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataFilterBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataFilterBenchmark.java new file mode 100644 index 000000000000..1036bc5f1b99 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataFilterBenchmark.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source.parquet; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceNestedDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.struct; + +/** + * A benchmark that evaluates the file skipping capabilities in the Spark data source for Iceberg. + * + * This class uses a dataset with nested data, where the records are clustered according to the + * column used in the filter predicate. + * + * The performance is compared to the built-in file source in Spark. + * + * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceNestedParquetDataFilterBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-parquet-data-filter-benchmark-result.txt + * + */ +public class IcebergSourceNestedParquetDataFilterBenchmark extends IcebergSourceNestedDataBenchmark { + + private static final String FILTER_COND = "nested.col3 == 0"; + private static final int NUM_FILES = 500; + private static final int NUM_ROWS = 10000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readWithFilterIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).filter(FILTER_COND); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithFilterFileSourceVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf(conf, () -> { + Dataset df = spark().read().parquet(dataLocation()).filter(FILTER_COND); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithFilterFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf(conf, () -> { + Dataset df = spark().read().parquet(dataLocation()).filter(FILTER_COND); + materialize(df); + }); + } + + private void appendData() { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset df = spark().range(NUM_ROWS) + .withColumn( + "nested", + struct( + expr("CAST(id AS string) AS col1"), + expr("CAST(id AS double) AS col2"), + lit(fileNum).cast("long").as("col3") + )); + appendAsFile(df); + } + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataReadBenchmark.java 
b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataReadBenchmark.java new file mode 100644 index 000000000000..fdd1c217004b --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataReadBenchmark.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source.parquet; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceNestedDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.struct; + +/** + * A benchmark that evaluates the performance of reading nested Parquet data using Iceberg + * and the built-in file source in Spark. 
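The projection variants in this class read only nested.col3, so the comparison is really about nested schema pruning: the Spark Parquet source prunes unused struct fields when spark.sql.optimizer.nestedSchemaPruning.enabled is on (the benchmark sets it explicitly through withSQLConf), while the Iceberg reader derives the pruned read schema from the projected Iceberg schema itself. In condensed form, the two projected reads look roughly like this (a sketch, not additional benchmark code):

  // Iceberg: the requested nested field drives Iceberg's own column projection.
  Dataset<Row> icebergProjected = spark().read().format("iceberg")
      .load(table().location())
      .selectExpr("nested.col3");

  // Built-in Parquet source: nested-field pruning is controlled by a SQL conf,
  // applied in the benchmark via withSQLConf(...).
  spark().conf().set(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED().key(), "true");
  Dataset<Row> fileSourceProjected = spark().read()
      .parquet(dataLocation())
      .selectExpr("nested.col3");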
+ * + * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceNestedParquetDataReadBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-parquet-data-read-benchmark-result.txt + * + */ +public class IcebergSourceNestedParquetDataReadBenchmark extends IcebergSourceNestedDataBenchmark { + + private static final int NUM_FILES = 10; + private static final int NUM_ROWS = 1000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf(conf, () -> { + Dataset df = spark().read().parquet(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf(conf, () -> { + Dataset df = spark().read().parquet(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties(tableProperties, () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).selectExpr("nested.col3"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSourceVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + conf.put(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED().key(), "true"); + withSQLConf(conf, () -> { + Dataset df = spark().read().parquet(dataLocation()).selectExpr("nested.col3"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + conf.put(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED().key(), "true"); + withSQLConf(conf, () -> { + Dataset df = spark().read().parquet(dataLocation()).selectExpr("nested.col3"); + materialize(df); + }); + } + + private void appendData() { + for (int fileNum = 0; fileNum < NUM_FILES; fileNum++) { + Dataset df = spark().range(NUM_ROWS) + .withColumn( + "nested", + struct( + expr("CAST(id AS string) AS col1"), + expr("CAST(id AS double) AS col2"), + lit(fileNum).cast("long").as("col3") + )); + appendAsFile(df); + } + } +} diff --git 
a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataWriteBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataWriteBenchmark.java new file mode 100644 index 000000000000..65265426f9f8 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataWriteBenchmark.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source.parquet; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceNestedDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.struct; + +/** + * A benchmark that evaluates the performance of writing nested Parquet data using Iceberg + * and the built-in file source in Spark. 
+ * + * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceNestedParquetDataWriteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-parquet-data-write-benchmark-result.txt + * + */ +public class IcebergSourceNestedParquetDataWriteBenchmark extends IcebergSourceNestedDataBenchmark { + + private static final int NUM_ROWS = 5000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void writeIceberg() { + String tableLocation = table().location(); + benchmarkData().write().format("iceberg").mode(SaveMode.Append).save(tableLocation); + } + + @Benchmark + @Threads(1) + public void writeFileSource() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_COMPRESSION().key(), "gzip"); + withSQLConf(conf, () -> benchmarkData().write().mode(SaveMode.Append).parquet(dataLocation())); + } + + private Dataset benchmarkData() { + return spark().range(NUM_ROWS) + .withColumn( + "nested", + struct( + expr("CAST(id AS string) AS col1"), + expr("CAST(id AS double) AS col2"), + expr("id AS col3") + )) + .coalesce(1); + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetEqDeleteBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetEqDeleteBenchmark.java new file mode 100644 index 000000000000..4884e8fe5db1 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetEqDeleteBenchmark.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source.parquet; + +import java.io.IOException; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.spark.source.IcebergSourceDeleteBenchmark; +import org.openjdk.jmh.annotations.Param; + +/** + * A benchmark that evaluates the non-vectorized read and vectorized read with equality delete in the Spark data source + * for Iceberg. + *
+ * This class uses a dataset with a flat schema. + * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3:jmh + * -PjmhIncludeRegex=IcebergSourceParquetEqDeleteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-parquet-eq-delete-benchmark-result.txt + * + */ +public class IcebergSourceParquetEqDeleteBenchmark extends IcebergSourceDeleteBenchmark { + @Param({"0", "0.000001", "0.05", "0.25", "0.5", "1"}) + private double percentDeleteRow; + + @Override + protected void appendData() throws IOException { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + writeData(fileNum); + + if (percentDeleteRow > 0) { + // add equality deletes + table().refresh(); + writeEqDeletes(NUM_ROWS, percentDeleteRow); + } + } + } + + @Override + protected FileFormat fileFormat() { + return FileFormat.PARQUET; + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetMultiDeleteFileBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetMultiDeleteFileBenchmark.java new file mode 100644 index 000000000000..69b5be1d9171 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetMultiDeleteFileBenchmark.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source.parquet; + +import java.io.IOException; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.spark.source.IcebergSourceDeleteBenchmark; +import org.openjdk.jmh.annotations.Param; + +/** + * A benchmark that evaluates the non-vectorized read and vectorized read with pos-delete in the Spark data source for + * Iceberg. + *
+ * This class uses a dataset with a flat schema. + * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3:jmh \ + * -PjmhIncludeRegex=IcebergSourceParquetMultiDeleteFileBenchmark \ + * -PjmhOutputPath=benchmark/iceberg-source-parquet-multi-delete-file-benchmark-result.txt + * + */ +public class IcebergSourceParquetMultiDeleteFileBenchmark extends IcebergSourceDeleteBenchmark { + @Param({"1", "2", "5", "10"}) + private int numDeleteFile; + + @Override + protected void appendData() throws IOException { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + writeData(fileNum); + + table().refresh(); + for (DataFile file : table().currentSnapshot().addedFiles(table().io())) { + writePosDeletes(file.path(), NUM_ROWS, 0.25, numDeleteFile); + } + } + } + + @Override + protected FileFormat fileFormat() { + return FileFormat.PARQUET; + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetPosDeleteBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetPosDeleteBenchmark.java new file mode 100644 index 000000000000..326ecd5fa28b --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetPosDeleteBenchmark.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source.parquet; + +import java.io.IOException; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.spark.source.IcebergSourceDeleteBenchmark; +import org.openjdk.jmh.annotations.Param; + +/** + * A benchmark that evaluates the non-vectorized read and vectorized read with pos-delete in the Spark data source for + * Iceberg. + *
+ * This class uses a dataset with a flat schema. + * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3:jmh + * -PjmhIncludeRegex=IcebergSourceParquetPosDeleteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-parquet-pos-delete-benchmark-result.txt + * + */ +public class IcebergSourceParquetPosDeleteBenchmark extends IcebergSourceDeleteBenchmark { + @Param({"0", "0.000001", "0.05", "0.25", "0.5", "1"}) + private double percentDeleteRow; + + @Override + protected void appendData() throws IOException { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + writeData(fileNum); + + if (percentDeleteRow > 0) { + // add pos-deletes + table().refresh(); + for (DataFile file : table().currentSnapshot().addedFiles(table().io())) { + writePosDeletes(file.path(), NUM_ROWS, percentDeleteRow); + } + } + } + } + + @Override + protected FileFormat fileFormat() { + return FileFormat.PARQUET; + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetWithUnrelatedDeleteBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetWithUnrelatedDeleteBenchmark.java new file mode 100644 index 000000000000..548c391746b2 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetWithUnrelatedDeleteBenchmark.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source.parquet; + +import java.io.IOException; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.spark.source.IcebergSourceDeleteBenchmark; +import org.openjdk.jmh.annotations.Param; + +/** + * A benchmark that evaluates the non-vectorized read and vectorized read with pos-delete in the Spark data source for + * Iceberg. + *
+ * This class uses a dataset with a flat schema. + * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3:jmh + * -PjmhIncludeRegex=IcebergSourceParquetWithUnrelatedDeleteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-parquet-with-unrelated-delete-benchmark-result.txt + * + */ +public class IcebergSourceParquetWithUnrelatedDeleteBenchmark extends IcebergSourceDeleteBenchmark { + private static final double PERCENT_DELETE_ROW = 0.05; + @Param({"0", "0.05", "0.25", "0.5"}) + private double percentUnrelatedDeletes; + + @Override + protected void appendData() throws IOException { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + writeData(fileNum); + + table().refresh(); + for (DataFile file : table().currentSnapshot().addedFiles(table().io())) { + writePosDeletesWithNoise(file.path(), NUM_ROWS, PERCENT_DELETE_ROW, + (int) (percentUnrelatedDeletes / PERCENT_DELETE_ROW), 1); + } + } + } + + @Override + protected FileFormat fileFormat() { + return FileFormat.PARQUET; + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/ParquetWritersBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/ParquetWritersBenchmark.java new file mode 100644 index 000000000000..0d347222e0e5 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/ParquetWritersBenchmark.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source.parquet; + +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.spark.source.WritersBenchmark; + +/** + * A benchmark that evaluates the performance of various Iceberg writers for Parquet data. 
+ * + * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh \ + * -PjmhIncludeRegex=ParquetWritersBenchmark \ + * -PjmhOutputPath=benchmark/parquet-writers-benchmark-result.txt + * + */ +public class ParquetWritersBenchmark extends WritersBenchmark { + + @Override + protected FileFormat fileFormat() { + return FileFormat.PARQUET; + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadDictionaryEncodedFlatParquetDataBenchmark.java b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadDictionaryEncodedFlatParquetDataBenchmark.java new file mode 100644 index 000000000000..0bb0a1192312 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadDictionaryEncodedFlatParquetDataBenchmark.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source.parquet.vectorized; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.Map; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.types.DataTypes; +import org.openjdk.jmh.annotations.Setup; + +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.pmod; +import static org.apache.spark.sql.functions.to_date; +import static org.apache.spark.sql.functions.to_timestamp; + +/** + * Benchmark to compare performance of reading Parquet dictionary encoded data with a flat schema using vectorized + * Iceberg read path and the built-in file source in Spark. + *
+ * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=VectorizedReadDictionaryEncodedFlatParquetDataBenchmark + * -PjmhOutputPath=benchmark/results.txt + * + */ +public class VectorizedReadDictionaryEncodedFlatParquetDataBenchmark extends VectorizedReadFlatParquetDataBenchmark { + + @Setup + @Override + public void setupBenchmark() { + setupSpark(true); + appendData(); + // Allow unsafe memory access to avoid the costly check arrow does to check if index is within bounds + System.setProperty("arrow.enable_unsafe_memory_access", "true"); + // Disable expensive null check for every get(index) call. + // Iceberg manages nullability checks itself instead of relying on arrow. + System.setProperty("arrow.enable_null_check_for_get", "false"); + } + + @Override + Map parquetWriteProps() { + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + return properties; + } + + @Override + void appendData() { + Dataset df = idDF(); + df = withLongColumnDictEncoded(df); + df = withIntColumnDictEncoded(df); + df = withFloatColumnDictEncoded(df); + df = withDoubleColumnDictEncoded(df); + df = withDecimalColumnDictEncoded(df); + df = withDateColumnDictEncoded(df); + df = withTimestampColumnDictEncoded(df); + df = withStringColumnDictEncoded(df); + df = df.drop("id"); + df.write().format("iceberg") + .mode(SaveMode.Append) + .save(table().location()); + } + + private static Column modColumn() { + return pmod(col("id"), lit(9)); + } + + private Dataset idDF() { + return spark().range(0, NUM_ROWS_PER_FILE * NUM_FILES, 1, NUM_FILES).toDF(); + } + + private static Dataset withLongColumnDictEncoded(Dataset df) { + return df.withColumn("longCol", modColumn().cast(DataTypes.LongType)); + } + + private static Dataset withIntColumnDictEncoded(Dataset df) { + return df.withColumn("intCol", modColumn().cast(DataTypes.IntegerType)); + } + + private static Dataset withFloatColumnDictEncoded(Dataset df) { + return df.withColumn("floatCol", modColumn().cast(DataTypes.FloatType)); + + } + + private static Dataset withDoubleColumnDictEncoded(Dataset df) { + return df.withColumn("doubleCol", modColumn().cast(DataTypes.DoubleType)); + } + + private static Dataset withDecimalColumnDictEncoded(Dataset df) { + Types.DecimalType type = Types.DecimalType.of(20, 5); + return df.withColumn("decimalCol", lit(bigDecimal(type, 0)).plus(modColumn())); + } + + private static Dataset withDateColumnDictEncoded(Dataset df) { + Column days = modColumn().cast(DataTypes.ShortType); + return df.withColumn("dateCol", date_add(to_date(lit("04/12/2019"), "MM/dd/yyyy"), days)); + } + + private static Dataset withTimestampColumnDictEncoded(Dataset df) { + Column days = modColumn().cast(DataTypes.ShortType); + return df.withColumn("timestampCol", to_timestamp(date_add(to_date(lit("04/12/2019"), "MM/dd/yyyy"), days))); + } + + private static Dataset withStringColumnDictEncoded(Dataset df) { + return df.withColumn("stringCol", modColumn().cast(DataTypes.StringType)); + } + + private static BigDecimal bigDecimal(Types.DecimalType type, int value) { + BigInteger unscaled = new BigInteger(String.valueOf(value + 1)); + return new BigDecimal(unscaled, type.scale()); + } +} diff --git a/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadFlatParquetDataBenchmark.java 
b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadFlatParquetDataBenchmark.java new file mode 100644 index 000000000000..3e2a566ead55 --- /dev/null +++ b/spark/v3.3/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadFlatParquetDataBenchmark.java @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source.parquet.vectorized; + +import java.io.IOException; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceBenchmark; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.pmod; +import static org.apache.spark.sql.functions.when; + +/** + * Benchmark to compare performance of reading Parquet data with a flat schema using vectorized Iceberg read path and + * the built-in file source in Spark. + *
+ * To run this benchmark for spark-3.3: + * + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=VectorizedReadFlatParquetDataBenchmark + * -PjmhOutputPath=benchmark/results.txt + * + */ +public class VectorizedReadFlatParquetDataBenchmark extends IcebergSourceBenchmark { + + static final int NUM_FILES = 5; + static final int NUM_ROWS_PER_FILE = 10_000_000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + // Allow unsafe memory access to avoid the costly check arrow does to check if index is within bounds + System.setProperty("arrow.enable_unsafe_memory_access", "true"); + // Disable expensive null check for every get(index) call. + // Iceberg manages nullability checks itself instead of relying on arrow. + System.setProperty("arrow.enable_null_check_for_get", "false"); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + @Override + protected Table initTable() { + Schema schema = new Schema( + optional(1, "longCol", Types.LongType.get()), + optional(2, "intCol", Types.IntegerType.get()), + optional(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "decimalCol", Types.DecimalType.of(20, 5)), + optional(6, "dateCol", Types.DateType.get()), + optional(7, "timestampCol", Types.TimestampType.withZone()), + optional(8, "stringCol", Types.StringType.get())); + PartitionSpec partitionSpec = PartitionSpec.unpartitioned(); + HadoopTables tables = new HadoopTables(hadoopConf()); + Map properties = parquetWriteProps(); + return tables.create(schema, partitionSpec, properties, newTableLocation()); + } + + Map parquetWriteProps() { + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + properties.put(TableProperties.PARQUET_DICT_SIZE_BYTES, "1"); + return properties; + } + + void appendData() { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset df = spark().range(NUM_ROWS_PER_FILE) + .withColumn( + "longCol", + when(pmod(col("id"), lit(10)).equalTo(lit(0)), lit(null)) + .otherwise(col("id"))) + .drop("id") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(20, 5))")) + .withColumn("dateCol", date_add(current_date(), fileNum)) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(longCol AS STRING)")); + appendAsFile(df); + } + } + + @Benchmark + @Threads(1) + public void readIntegersIcebergVectorized5k() { + withTableProperties(tablePropsWithVectorizationEnabled(5000), () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg") + .load(tableLocation).select("intCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readIntegersSparkVectorized5k() { + withSQLConf(sparkConfWithVectorizationEnabled(5000), () -> { + Dataset df = spark().read().parquet(dataLocation()).select("intCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readLongsIcebergVectorized5k() { + withTableProperties(tablePropsWithVectorizationEnabled(5000), () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg") + 
.load(tableLocation).select("longCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readLongsSparkVectorized5k() { + withSQLConf(sparkConfWithVectorizationEnabled(5000), () -> { + Dataset df = spark().read().parquet(dataLocation()).select("longCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFloatsIcebergVectorized5k() { + withTableProperties(tablePropsWithVectorizationEnabled(5000), () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg") + .load(tableLocation).select("floatCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFloatsSparkVectorized5k() { + withSQLConf(sparkConfWithVectorizationEnabled(5000), () -> { + Dataset df = spark().read().parquet(dataLocation()).select("floatCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDoublesIcebergVectorized5k() { + withTableProperties(tablePropsWithVectorizationEnabled(5000), () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg") + .load(tableLocation).select("doubleCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDoublesSparkVectorized5k() { + withSQLConf(sparkConfWithVectorizationEnabled(5000), () -> { + Dataset df = spark().read().parquet(dataLocation()).select("doubleCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDecimalsIcebergVectorized5k() { + withTableProperties(tablePropsWithVectorizationEnabled(5000), () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg") + .load(tableLocation).select("decimalCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDecimalsSparkVectorized5k() { + withSQLConf(sparkConfWithVectorizationEnabled(5000), () -> { + Dataset df = spark().read().parquet(dataLocation()).select("decimalCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDatesIcebergVectorized5k() { + withTableProperties(tablePropsWithVectorizationEnabled(5000), () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg") + .load(tableLocation).select("dateCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDatesSparkVectorized5k() { + withSQLConf(sparkConfWithVectorizationEnabled(5000), () -> { + Dataset df = spark().read().parquet(dataLocation()).select("dateCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readTimestampsIcebergVectorized5k() { + withTableProperties(tablePropsWithVectorizationEnabled(5000), () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg") + .load(tableLocation).select("timestampCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readTimestampsSparkVectorized5k() { + withSQLConf(sparkConfWithVectorizationEnabled(5000), () -> { + Dataset df = spark().read().parquet(dataLocation()).select("timestampCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readStringsIcebergVectorized5k() { + withTableProperties(tablePropsWithVectorizationEnabled(5000), () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg") + .load(tableLocation).select("stringCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readStringsSparkVectorized5k() { + withSQLConf(sparkConfWithVectorizationEnabled(5000), 
() -> { + Dataset df = spark().read().parquet(dataLocation()).select("stringCol"); + materialize(df); + }); + } + + private static Map tablePropsWithVectorizationEnabled(int batchSize) { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(TableProperties.PARQUET_VECTORIZATION_ENABLED, "true"); + tableProperties.put(TableProperties.PARQUET_BATCH_SIZE, String.valueOf(batchSize)); + return tableProperties; + } + + private static Map sparkConfWithVectorizationEnabled(int batchSize) { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key(), String.valueOf(batchSize)); + return conf; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/BaseCatalog.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/BaseCatalog.java new file mode 100644 index 000000000000..43447e346648 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/BaseCatalog.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
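
The two helpers above toggle vectorized reads per benchmark invocation; the table-property variant (tablePropsWithVectorizationEnabled) can also be applied persistently to a table. A minimal sketch, assuming an already-loaded Table handle and a batch size of 5000 (both placeholders); the TableProperties constants are the same ones the benchmark uses.

import org.apache.iceberg.Table;
import org.apache.iceberg.TableProperties;

public class EnableVectorizedReads {
  // Persistently enable Iceberg's vectorized Parquet reads, mirroring what
  // tablePropsWithVectorizationEnabled() applies per run via withTableProperties().
  public static void enable(Table table, int batchSize) {
    table.updateProperties()
        .set(TableProperties.PARQUET_VECTORIZATION_ENABLED, "true")
        .set(TableProperties.PARQUET_BATCH_SIZE, String.valueOf(batchSize))
        .commit();
  }
}
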
+ */ + +package org.apache.iceberg.spark; + +import org.apache.iceberg.spark.procedures.SparkProcedures; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.iceberg.spark.source.HasIcebergCatalog; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.StagingTableCatalog; +import org.apache.spark.sql.connector.catalog.SupportsNamespaces; +import org.apache.spark.sql.connector.iceberg.catalog.Procedure; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureCatalog; + +abstract class BaseCatalog implements StagingTableCatalog, ProcedureCatalog, SupportsNamespaces, HasIcebergCatalog { + + @Override + public Procedure loadProcedure(Identifier ident) throws NoSuchProcedureException { + String[] namespace = ident.namespace(); + String name = ident.name(); + + // namespace resolution is case insensitive until we have a way to configure case sensitivity in catalogs + if (namespace.length == 1 && namespace[0].equalsIgnoreCase("system")) { + ProcedureBuilder builder = SparkProcedures.newBuilder(name); + if (builder != null) { + return builder.withTableCatalog(this).build(); + } + } + + throw new NoSuchProcedureException(ident); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/CommitMetadata.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/CommitMetadata.java new file mode 100644 index 000000000000..58137250003a --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/CommitMetadata.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark; + +import java.util.Map; +import java.util.concurrent.Callable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.util.ExceptionUtil; + +/** + * utility class to accept thread local commit properties + */ +public class CommitMetadata { + + private CommitMetadata() { + + } + + private static final ThreadLocal> COMMIT_PROPERTIES = ThreadLocal.withInitial(ImmutableMap::of); + + /** + * running the code wrapped as a caller, and any snapshot committed within the callable object will be attached with + * the metadata defined in properties + * @param properties extra commit metadata to attach to the snapshot committed within callable + * @param callable the code to be executed + * @param exClass the expected type of exception which would be thrown from callable + */ + public static R withCommitProperties( + Map properties, Callable callable, Class exClass) throws E { + COMMIT_PROPERTIES.set(properties); + try { + return callable.call(); + } catch (Throwable e) { + ExceptionUtil.castAndThrow(e, exClass); + return null; + } finally { + COMMIT_PROPERTIES.set(ImmutableMap.of()); + } + } + + public static Map commitProperties() { + return COMMIT_PROPERTIES.get(); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/FileRewriteCoordinator.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/FileRewriteCoordinator.java new file mode 100644 index 000000000000..acd5f64d7ed6 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/FileRewriteCoordinator.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.util.Pair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class FileRewriteCoordinator { + + private static final Logger LOG = LoggerFactory.getLogger(FileRewriteCoordinator.class); + private static final FileRewriteCoordinator INSTANCE = new FileRewriteCoordinator(); + + private final Map, Set> resultMap = Maps.newConcurrentMap(); + + private FileRewriteCoordinator() { + } + + public static FileRewriteCoordinator get() { + return INSTANCE; + } + + /** + * Called to persist the output of a rewrite action for a specific group. 
Since the write is done via a + * Spark Datasource, we have to propagate the result through this side-effect call. + * @param table table where the rewrite is occurring + * @param fileSetID the id used to identify the source set of files being rewritten + * @param newDataFiles the new files which have been written + */ + public void stageRewrite(Table table, String fileSetID, Set newDataFiles) { + LOG.debug("Staging the output for {} - fileset {} with {} files", table.name(), fileSetID, newDataFiles.size()); + Pair id = toID(table, fileSetID); + resultMap.put(id, newDataFiles); + } + + public Set fetchNewDataFiles(Table table, String fileSetID) { + Pair id = toID(table, fileSetID); + Set result = resultMap.get(id); + ValidationException.check(result != null, + "No results for rewrite of file set %s in table %s", + fileSetID, table); + + return result; + } + + public void clearRewrite(Table table, String fileSetID) { + LOG.debug("Removing entry from RewriteCoordinator for {} - id {}", table.name(), fileSetID); + Pair id = toID(table, fileSetID); + resultMap.remove(id); + } + + public Set fetchSetIDs(Table table) { + return resultMap.keySet().stream() + .filter(e -> e.first().equals(tableUUID(table))) + .map(Pair::second) + .collect(Collectors.toSet()); + } + + private Pair toID(Table table, String setID) { + String tableUUID = tableUUID(table); + return Pair.of(tableUUID, setID); + } + + private String tableUUID(Table table) { + TableOperations ops = ((HasTableOperations) table).operations(); + return ops.current().uuid(); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/FileScanTaskSetManager.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/FileScanTaskSetManager.java new file mode 100644 index 000000000000..827b674ca16d --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/FileScanTaskSetManager.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
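
Looking back at CommitMetadata from the hunk above: the intended usage is to wrap a Spark write in withCommitProperties so the snapshot produced by that write carries extra summary metadata. A rough sketch; the catalog, table, and property key are made up:

import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.spark.CommitMetadata;
import org.apache.spark.sql.SparkSession;

public class CommitMetadataExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("commit-metadata").getOrCreate();

    // The write inside the callable runs with "app.job-id" in the thread-local commit
    // properties, so the Spark write path can attach it to the committed snapshot's summary.
    CommitMetadata.withCommitProperties(
        ImmutableMap.of("app.job-id", "nightly-load-42"),
        () -> spark.sql("INSERT INTO my_catalog.db.sample VALUES (1, 'a')"),
        RuntimeException.class);
  }
}
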
+ */ + +package org.apache.iceberg.spark; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.util.Pair; + +public class FileScanTaskSetManager { + + private static final FileScanTaskSetManager INSTANCE = new FileScanTaskSetManager(); + + private final Map, List> tasksMap = Maps.newConcurrentMap(); + + private FileScanTaskSetManager() { + } + + public static FileScanTaskSetManager get() { + return INSTANCE; + } + + public void stageTasks(Table table, String setID, List tasks) { + Preconditions.checkArgument(tasks != null && tasks.size() > 0, "Cannot stage null or empty tasks"); + Pair id = toID(table, setID); + tasksMap.put(id, tasks); + } + + public List fetchTasks(Table table, String setID) { + Pair id = toID(table, setID); + return tasksMap.get(id); + } + + public List removeTasks(Table table, String setID) { + Pair id = toID(table, setID); + return tasksMap.remove(id); + } + + public Set fetchSetIDs(Table table) { + return tasksMap.keySet().stream() + .filter(e -> e.first().equals(tableUUID(table))) + .map(Pair::second) + .collect(Collectors.toSet()); + } + + private String tableUUID(Table table) { + TableOperations ops = ((HasTableOperations) table).operations(); + return ops.current().uuid(); + } + + private Pair toID(Table table, String setID) { + return Pair.of(tableUUID(table), setID); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java new file mode 100644 index 000000000000..e8889bc1fd01 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark; + +import org.apache.iceberg.transforms.Transform; +import org.apache.iceberg.transforms.Transforms; +import org.apache.iceberg.types.Type; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; + +public class IcebergSpark { + private IcebergSpark() { + } + + public static void registerBucketUDF(SparkSession session, String funcName, DataType sourceType, int numBuckets) { + SparkTypeToType typeConverter = new SparkTypeToType(); + Type sourceIcebergType = typeConverter.atomic(sourceType); + Transform bucket = Transforms.bucket(sourceIcebergType, numBuckets); + session.udf().register(funcName, + value -> bucket.apply(SparkValueConverter.convert(sourceIcebergType, value)), DataTypes.IntegerType); + } + + public static void registerTruncateUDF(SparkSession session, String funcName, DataType sourceType, int width) { + SparkTypeToType typeConverter = new SparkTypeToType(); + Type sourceIcebergType = typeConverter.atomic(sourceType); + Transform truncate = Transforms.truncate(sourceIcebergType, width); + session.udf().register(funcName, + value -> truncate.apply(SparkValueConverter.convert(sourceIcebergType, value)), sourceType); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/JobGroupInfo.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/JobGroupInfo.java new file mode 100644 index 000000000000..a35808fd8ce6 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/JobGroupInfo.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +/** + * Captures information about the current job + * which is used for displaying on the UI + */ +public class JobGroupInfo { + private String groupId; + private String description; + private boolean interruptOnCancel; + + public JobGroupInfo(String groupId, String desc, boolean interruptOnCancel) { + this.groupId = groupId; + this.description = desc; + this.interruptOnCancel = interruptOnCancel; + } + + public String groupId() { + return groupId; + } + + public String description() { + return description; + } + + public boolean interruptOnCancel() { + return interruptOnCancel; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/JobGroupUtils.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/JobGroupUtils.java new file mode 100644 index 000000000000..155dce707701 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/JobGroupUtils.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +import org.apache.spark.SparkContext; +import org.apache.spark.SparkContext$; + +public class JobGroupUtils { + + private static final String JOB_GROUP_ID = SparkContext$.MODULE$.SPARK_JOB_GROUP_ID(); + private static final String JOB_GROUP_DESC = SparkContext$.MODULE$.SPARK_JOB_DESCRIPTION(); + private static final String JOB_INTERRUPT_ON_CANCEL = SparkContext$.MODULE$.SPARK_JOB_INTERRUPT_ON_CANCEL(); + + private JobGroupUtils() { + } + + public static JobGroupInfo getJobGroupInfo(SparkContext sparkContext) { + String groupId = sparkContext.getLocalProperty(JOB_GROUP_ID); + String description = sparkContext.getLocalProperty(JOB_GROUP_DESC); + String interruptOnCancel = sparkContext.getLocalProperty(JOB_INTERRUPT_ON_CANCEL); + return new JobGroupInfo(groupId, description, Boolean.parseBoolean(interruptOnCancel)); + } + + public static void setJobGroupInfo(SparkContext sparkContext, JobGroupInfo info) { + sparkContext.setLocalProperty(JOB_GROUP_ID, info.groupId()); + sparkContext.setLocalProperty(JOB_GROUP_DESC, info.description()); + sparkContext.setLocalProperty(JOB_INTERRUPT_ON_CANCEL, String.valueOf(info.interruptOnCancel())); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/PathIdentifier.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/PathIdentifier.java new file mode 100644 index 000000000000..235097ea46cc --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/PathIdentifier.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
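
Back to IcebergSpark from the hunk above: registerBucketUDF makes Iceberg's bucket transform callable from SQL, which is handy for checking how values would cluster under a bucket partition spec. A small illustration; the function name and query are assumptions:

import org.apache.iceberg.spark.IcebergSpark;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;

public class BucketUdfExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("bucket-udf").getOrCreate();

    // Register a SQL function that applies Iceberg's bucket(16) transform to long values.
    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket16", DataTypes.LongType, 16);

    // See how 1000 sequential ids spread across 16 buckets.
    spark.sql("SELECT iceberg_bucket16(id) AS bucket, COUNT(*) AS cnt "
        + "FROM range(1000) GROUP BY iceberg_bucket16(id) ORDER BY bucket").show();
  }
}
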
+ */ + +package org.apache.iceberg.spark; + +import java.util.List; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.Splitter; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.spark.sql.connector.catalog.Identifier; + +public class PathIdentifier implements Identifier { + private static final Splitter SPLIT = Splitter.on("/"); + private static final Joiner JOIN = Joiner.on("/"); + private final String[] namespace; + private final String location; + private final String name; + + public PathIdentifier(String location) { + this.location = location; + List pathParts = SPLIT.splitToList(location); + name = Iterables.getLast(pathParts); + namespace = pathParts.size() > 1 ? + new String[]{JOIN.join(pathParts.subList(0, pathParts.size() - 1))} : + new String[0]; + } + + @Override + public String[] namespace() { + return namespace; + } + + @Override + public String name() { + return name; + } + + public String location() { + return location; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithReordering.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithReordering.java new file mode 100644 index 000000000000..3bdf984ed219 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithReordering.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Supplier; +import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Type.TypeID; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampType; + +public class PruneColumnsWithReordering extends TypeUtil.CustomOrderSchemaVisitor { + private final StructType requestedType; + private final Set filterRefs; + private DataType current = null; + + PruneColumnsWithReordering(StructType requestedType, Set filterRefs) { + this.requestedType = requestedType; + this.filterRefs = filterRefs; + } + + @Override + public Type schema(Schema schema, Supplier structResult) { + this.current = requestedType; + try { + return structResult.get(); + } finally { + this.current = null; + } + } + + @Override + public Type struct(Types.StructType struct, Iterable fieldResults) { + Preconditions.checkNotNull(struct, "Cannot prune null struct. Pruning must start with a schema."); + Preconditions.checkArgument(current instanceof StructType, "Not a struct: %s", current); + StructType requestedStruct = (StructType) current; + + List fields = struct.fields(); + List types = Lists.newArrayList(fieldResults); + + boolean changed = false; + // use a LinkedHashMap to preserve the original order of filter fields that are not projected + Map projectedFields = Maps.newLinkedHashMap(); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + Type type = types.get(i); + + if (type == null) { + changed = true; + + } else if (field.type() == type) { + projectedFields.put(field.name(), field); + + } else if (field.isOptional()) { + changed = true; + projectedFields.put(field.name(), + Types.NestedField.optional(field.fieldId(), field.name(), type)); + + } else { + changed = true; + projectedFields.put(field.name(), + Types.NestedField.required(field.fieldId(), field.name(), type)); + } + } + + // Construct a new struct with the projected struct's order + boolean reordered = false; + StructField[] requestedFields = requestedStruct.fields(); + List newFields = Lists.newArrayListWithExpectedSize(requestedFields.length); + for (int i = 0; i < requestedFields.length; i += 1) { + // fields are resolved by name because Spark only sees the current table schema. 
+ String name = requestedFields[i].name(); + if (!fields.get(i).name().equals(name)) { + reordered = true; + } + newFields.add(projectedFields.remove(name)); + } + + // Add remaining filter fields that were not explicitly projected + if (!projectedFields.isEmpty()) { + newFields.addAll(projectedFields.values()); + changed = true; // order probably changed + } + + if (reordered || changed) { + return Types.StructType.of(newFields); + } + + return struct; + } + + @Override + public Type field(Types.NestedField field, Supplier fieldResult) { + Preconditions.checkArgument(current instanceof StructType, "Not a struct: %s", current); + StructType requestedStruct = (StructType) current; + + // fields are resolved by name because Spark only sees the current table schema. + if (requestedStruct.getFieldIndex(field.name()).isEmpty()) { + // make sure that filter fields are projected even if they aren't in the requested schema. + if (filterRefs.contains(field.fieldId())) { + return field.type(); + } + return null; + } + + int fieldIndex = requestedStruct.fieldIndex(field.name()); + StructField requestedField = requestedStruct.fields()[fieldIndex]; + + Preconditions.checkArgument(requestedField.nullable() || field.isRequired(), + "Cannot project an optional field as non-null: %s", field.name()); + + this.current = requestedField.dataType(); + try { + return fieldResult.get(); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + "Invalid projection for field " + field.name() + ": " + e.getMessage(), e); + } finally { + this.current = requestedStruct; + } + } + + @Override + public Type list(Types.ListType list, Supplier elementResult) { + Preconditions.checkArgument(current instanceof ArrayType, "Not an array: %s", current); + ArrayType requestedArray = (ArrayType) current; + + Preconditions.checkArgument(requestedArray.containsNull() || !list.isElementOptional(), + "Cannot project an array of optional elements as required elements: %s", requestedArray); + + this.current = requestedArray.elementType(); + try { + Type elementType = elementResult.get(); + if (list.elementType() == elementType) { + return list; + } + + // must be a projected element type, create a new list + if (list.isElementOptional()) { + return Types.ListType.ofOptional(list.elementId(), elementType); + } else { + return Types.ListType.ofRequired(list.elementId(), elementType); + } + } finally { + this.current = requestedArray; + } + } + + @Override + public Type map(Types.MapType map, Supplier keyResult, Supplier valueResult) { + Preconditions.checkArgument(current instanceof MapType, "Not a map: %s", current); + MapType requestedMap = (MapType) current; + + Preconditions.checkArgument(requestedMap.valueContainsNull() || !map.isValueOptional(), + "Cannot project a map of optional values as required values: %s", map); + Preconditions.checkArgument(StringType.class.isInstance(requestedMap.keyType()), + "Invalid map key type (not string): %s", requestedMap.keyType()); + + this.current = requestedMap.valueType(); + try { + Type valueType = valueResult.get(); + if (map.valueType() == valueType) { + return map; + } + + if (map.isValueOptional()) { + return Types.MapType.ofOptional(map.keyId(), map.valueId(), map.keyType(), valueType); + } else { + return Types.MapType.ofRequired(map.keyId(), map.valueId(), map.keyType(), valueType); + } + } finally { + this.current = requestedMap; + } + } + + @Override + public Type primitive(Type.PrimitiveType primitive) { + Class expectedType = TYPES.get(primitive.typeId()); + 
Preconditions.checkArgument(expectedType != null && expectedType.isInstance(current), + "Cannot project %s to incompatible type: %s", primitive, current); + + // additional checks based on type + switch (primitive.typeId()) { + case DECIMAL: + Types.DecimalType decimal = (Types.DecimalType) primitive; + DecimalType requestedDecimal = (DecimalType) current; + Preconditions.checkArgument(requestedDecimal.scale() == decimal.scale(), + "Cannot project decimal with incompatible scale: %s != %s", requestedDecimal.scale(), decimal.scale()); + Preconditions.checkArgument(requestedDecimal.precision() >= decimal.precision(), + "Cannot project decimal with incompatible precision: %s < %s", + requestedDecimal.precision(), decimal.precision()); + break; + case TIMESTAMP: + Types.TimestampType timestamp = (Types.TimestampType) primitive; + Preconditions.checkArgument(timestamp.shouldAdjustToUTC(), + "Cannot project timestamp (without time zone) as timestamptz (with time zone)"); + break; + default: + } + + return primitive; + } + + private static final ImmutableMap> TYPES = ImmutableMap + .>builder() + .put(TypeID.BOOLEAN, BooleanType.class) + .put(TypeID.INTEGER, IntegerType.class) + .put(TypeID.LONG, LongType.class) + .put(TypeID.FLOAT, FloatType.class) + .put(TypeID.DOUBLE, DoubleType.class) + .put(TypeID.DATE, DateType.class) + .put(TypeID.TIMESTAMP, TimestampType.class) + .put(TypeID.DECIMAL, DecimalType.class) + .put(TypeID.UUID, StringType.class) + .put(TypeID.STRING, StringType.class) + .put(TypeID.FIXED, BinaryType.class) + .put(TypeID.BINARY, BinaryType.class) + .build(); +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithoutReordering.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithoutReordering.java new file mode 100644 index 000000000000..c6984e2fe8cd --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithoutReordering.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark; + +import java.util.List; +import java.util.Set; +import java.util.function.Supplier; +import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Type.TypeID; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampType; + +public class PruneColumnsWithoutReordering extends TypeUtil.CustomOrderSchemaVisitor { + private final StructType requestedType; + private final Set filterRefs; + private DataType current = null; + + PruneColumnsWithoutReordering(StructType requestedType, Set filterRefs) { + this.requestedType = requestedType; + this.filterRefs = filterRefs; + } + + @Override + public Type schema(Schema schema, Supplier structResult) { + this.current = requestedType; + try { + return structResult.get(); + } finally { + this.current = null; + } + } + + @Override + public Type struct(Types.StructType struct, Iterable fieldResults) { + Preconditions.checkNotNull(struct, "Cannot prune null struct. Pruning must start with a schema."); + Preconditions.checkArgument(current instanceof StructType, "Not a struct: %s", current); + + List fields = struct.fields(); + List types = Lists.newArrayList(fieldResults); + + boolean changed = false; + List newFields = Lists.newArrayListWithExpectedSize(types.size()); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + Type type = types.get(i); + + if (type == null) { + changed = true; + + } else if (field.type() == type) { + newFields.add(field); + + } else if (field.isOptional()) { + changed = true; + newFields.add(Types.NestedField.optional(field.fieldId(), field.name(), type)); + + } else { + changed = true; + newFields.add(Types.NestedField.required(field.fieldId(), field.name(), type)); + } + } + + if (changed) { + return Types.StructType.of(newFields); + } + + return struct; + } + + @Override + public Type field(Types.NestedField field, Supplier fieldResult) { + Preconditions.checkArgument(current instanceof StructType, "Not a struct: %s", current); + StructType requestedStruct = (StructType) current; + + // fields are resolved by name because Spark only sees the current table schema. + if (requestedStruct.getFieldIndex(field.name()).isEmpty()) { + // make sure that filter fields are projected even if they aren't in the requested schema. 
+ if (filterRefs.contains(field.fieldId())) { + return field.type(); + } + return null; + } + + int fieldIndex = requestedStruct.fieldIndex(field.name()); + StructField requestedField = requestedStruct.fields()[fieldIndex]; + + Preconditions.checkArgument(requestedField.nullable() || field.isRequired(), + "Cannot project an optional field as non-null: %s", field.name()); + + this.current = requestedField.dataType(); + try { + return fieldResult.get(); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + "Invalid projection for field " + field.name() + ": " + e.getMessage(), e); + } finally { + this.current = requestedStruct; + } + } + + @Override + public Type list(Types.ListType list, Supplier elementResult) { + Preconditions.checkArgument(current instanceof ArrayType, "Not an array: %s", current); + ArrayType requestedArray = (ArrayType) current; + + Preconditions.checkArgument(requestedArray.containsNull() || !list.isElementOptional(), + "Cannot project an array of optional elements as required elements: %s", requestedArray); + + this.current = requestedArray.elementType(); + try { + Type elementType = elementResult.get(); + if (list.elementType() == elementType) { + return list; + } + + // must be a projected element type, create a new list + if (list.isElementOptional()) { + return Types.ListType.ofOptional(list.elementId(), elementType); + } else { + return Types.ListType.ofRequired(list.elementId(), elementType); + } + } finally { + this.current = requestedArray; + } + } + + @Override + public Type map(Types.MapType map, Supplier keyResult, Supplier valueResult) { + Preconditions.checkArgument(current instanceof MapType, "Not a map: %s", current); + MapType requestedMap = (MapType) current; + + Preconditions.checkArgument(requestedMap.valueContainsNull() || !map.isValueOptional(), + "Cannot project a map of optional values as required values: %s", map); + + this.current = requestedMap.valueType(); + try { + Type valueType = valueResult.get(); + if (map.valueType() == valueType) { + return map; + } + + if (map.isValueOptional()) { + return Types.MapType.ofOptional(map.keyId(), map.valueId(), map.keyType(), valueType); + } else { + return Types.MapType.ofRequired(map.keyId(), map.valueId(), map.keyType(), valueType); + } + } finally { + this.current = requestedMap; + } + } + + @Override + public Type primitive(Type.PrimitiveType primitive) { + Class expectedType = TYPES.get(primitive.typeId()); + Preconditions.checkArgument(expectedType != null && expectedType.isInstance(current), + "Cannot project %s to incompatible type: %s", primitive, current); + + // additional checks based on type + switch (primitive.typeId()) { + case DECIMAL: + Types.DecimalType decimal = (Types.DecimalType) primitive; + DecimalType requestedDecimal = (DecimalType) current; + Preconditions.checkArgument(requestedDecimal.scale() == decimal.scale(), + "Cannot project decimal with incompatible scale: %s != %s", requestedDecimal.scale(), decimal.scale()); + Preconditions.checkArgument(requestedDecimal.precision() >= decimal.precision(), + "Cannot project decimal with incompatible precision: %s < %s", + requestedDecimal.precision(), decimal.precision()); + break; + default: + } + + return primitive; + } + + private static final ImmutableMap> TYPES = ImmutableMap + .>builder() + .put(TypeID.BOOLEAN, BooleanType.class) + .put(TypeID.INTEGER, IntegerType.class) + .put(TypeID.LONG, LongType.class) + .put(TypeID.FLOAT, FloatType.class) + .put(TypeID.DOUBLE, DoubleType.class) + 
.put(TypeID.DATE, DateType.class) + .put(TypeID.TIMESTAMP, TimestampType.class) + .put(TypeID.DECIMAL, DecimalType.class) + .put(TypeID.UUID, StringType.class) + .put(TypeID.STRING, StringType.class) + .put(TypeID.FIXED, BinaryType.class) + .put(TypeID.BINARY, BinaryType.class) + .build(); +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/RollbackStagedTable.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/RollbackStagedTable.java new file mode 100644 index 000000000000..a27d06e7a1d7 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/RollbackStagedTable.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; +import java.util.function.Function; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.StagedTable; +import org.apache.spark.sql.connector.catalog.SupportsDelete; +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.SupportsWrite; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * An implementation of StagedTable that mimics the behavior of Spark's non-atomic CTAS and RTAS. + *
<p>
+ * A Spark catalog can implement StagingTableCatalog to support atomic operations by producing a StagedTable. But if a + * catalog implements StagingTableCatalog, Spark expects the catalog to be able to produce a StagedTable for any table + * loaded by the catalog. This assumption doesn't always hold, as in the case of {@link SparkSessionCatalog}: it + * supports atomic operations and can produce a StagedTable for Iceberg tables, but it wraps the session catalog and + * cannot necessarily produce a working StagedTable implementation for every table that it loads. + *
<p>
+ * The work-around is this class, which implements the StagedTable interface but does not have atomic behavior. + * Instead, the StagedTable interface is used to implement the behavior of the non-atomic SQL plans: create the table, + * write to it, and drop the table to roll back. + *
<p>
+ * This StagedTable implements SupportsRead, SupportsWrite, and SupportsDelete by passing the calls to the real table. + * Implementing those interfaces is safe because Spark will not use them unless the table supports them and returns the + * corresponding capabilities from {@link #capabilities()}. + */ +public class RollbackStagedTable implements StagedTable, SupportsRead, SupportsWrite, SupportsDelete { + private final TableCatalog catalog; + private final Identifier ident; + private final Table table; + + public RollbackStagedTable(TableCatalog catalog, Identifier ident, Table table) { + this.catalog = catalog; + this.ident = ident; + this.table = table; + } + + @Override + public void commitStagedChanges() { + // the changes have already been committed to the table at the end of the write + } + + @Override + public void abortStagedChanges() { + // roll back changes by dropping the table + catalog.dropTable(ident); + } + + @Override + public String name() { + return table.name(); + } + + @Override + public StructType schema() { + return table.schema(); + } + + @Override + public Transform[] partitioning() { + return table.partitioning(); + } + + @Override + public Map properties() { + return table.properties(); + } + + @Override + public Set capabilities() { + return table.capabilities(); + } + + @Override + public void deleteWhere(Filter[] filters) { + call(SupportsDelete.class, t -> t.deleteWhere(filters)); + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + return callReturning(SupportsRead.class, t -> t.newScanBuilder(options)); + } + + @Override + public WriteBuilder newWriteBuilder(LogicalWriteInfo info) { + return callReturning(SupportsWrite.class, t -> t.newWriteBuilder(info)); + } + + private void call(Class requiredClass, Consumer task) { + callReturning(requiredClass, inst -> { + task.accept(inst); + return null; + }); + } + + private R callReturning(Class requiredClass, Function task) { + if (requiredClass.isInstance(table)) { + return task.apply(requiredClass.cast(table)); + } else { + throw new UnsupportedOperationException(String.format( + "Table does not implement %s: %s (%s)", + requiredClass.getSimpleName(), table.name(), table.getClass().getName())); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java new file mode 100644 index 000000000000..5677f83a95ee --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark; + +import java.util.Map; +import org.apache.iceberg.NullOrder; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortDirection; +import org.apache.iceberg.transforms.SortOrderVisitor; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.NullOrdering; +import org.apache.spark.sql.connector.expressions.SortOrder; + +class SortOrderToSpark implements SortOrderVisitor { + + private final Map quotedNameById; + + SortOrderToSpark(Schema schema) { + this.quotedNameById = SparkSchemaUtil.indexQuotedNameById(schema); + } + + @Override + public SortOrder field(String sourceName, int id, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort(Expressions.column(quotedName(id)), toSpark(direction), toSpark(nullOrder)); + } + + @Override + public SortOrder bucket(String sourceName, int id, int width, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort(Expressions.bucket(width, quotedName(id)), toSpark(direction), toSpark(nullOrder)); + } + + @Override + public SortOrder truncate(String sourceName, int id, int width, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort(Expressions.apply( + "truncate", Expressions.column(quotedName(id)), Expressions.literal(width)), + toSpark(direction), toSpark(nullOrder)); + } + + @Override + public SortOrder year(String sourceName, int id, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort(Expressions.years(quotedName(id)), toSpark(direction), toSpark(nullOrder)); + } + + @Override + public SortOrder month(String sourceName, int id, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort(Expressions.months(quotedName(id)), toSpark(direction), toSpark(nullOrder)); + } + + @Override + public SortOrder day(String sourceName, int id, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort(Expressions.days(quotedName(id)), toSpark(direction), toSpark(nullOrder)); + } + + @Override + public SortOrder hour(String sourceName, int id, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort(Expressions.hours(quotedName(id)), toSpark(direction), toSpark(nullOrder)); + } + + private String quotedName(int id) { + return quotedNameById.get(id); + } + + private org.apache.spark.sql.connector.expressions.SortDirection toSpark(SortDirection direction) { + if (direction == SortDirection.ASC) { + return org.apache.spark.sql.connector.expressions.SortDirection.ASCENDING; + } else { + return org.apache.spark.sql.connector.expressions.SortDirection.DESCENDING; + } + } + + private NullOrdering toSpark(NullOrder nullOrder) { + return nullOrder == NullOrder.NULLS_FIRST ? NullOrdering.NULLS_FIRST : NullOrdering.NULLS_LAST; + } +} + diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java new file mode 100644 index 000000000000..b7d29367548f --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java @@ -0,0 +1,894 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.NullOrder; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.UpdateProperties; +import org.apache.iceberg.UpdateSchema; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.expressions.BoundPredicate; +import org.apache.iceberg.expressions.ExpressionVisitors; +import org.apache.iceberg.expressions.Term; +import org.apache.iceberg.expressions.UnboundPredicate; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.io.BaseEncoding; +import org.apache.iceberg.spark.SparkTableUtil.SparkPartition; +import org.apache.iceberg.spark.source.HasIcebergCatalog; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.iceberg.transforms.PartitionSpecVisitor; +import org.apache.iceberg.transforms.SortOrderVisitor; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.ByteBuffers; +import org.apache.iceberg.util.Pair; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.CatalystTypeConverters; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.catalyst.parser.ParserInterface; +import org.apache.spark.sql.connector.catalog.CatalogManager; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.catalog.TableChange; +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.Transform; +import 
org.apache.spark.sql.execution.datasources.FileStatusCache; +import org.apache.spark.sql.execution.datasources.InMemoryFileIndex; +import org.apache.spark.sql.execution.datasources.PartitionDirectory; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import scala.Option; +import scala.collection.JavaConverters; +import scala.collection.immutable.Seq; + +public class Spark3Util { + + private static final Set RESERVED_PROPERTIES = ImmutableSet.of( + TableCatalog.PROP_LOCATION, TableCatalog.PROP_PROVIDER); + private static final Joiner DOT = Joiner.on("."); + + private Spark3Util() { + } + + public static CaseInsensitiveStringMap setOption(String key, String value, CaseInsensitiveStringMap options) { + Map newOptions = Maps.newHashMap(); + newOptions.putAll(options); + newOptions.put(key, value); + return new CaseInsensitiveStringMap(newOptions); + } + + public static Map rebuildCreateProperties(Map createProperties) { + ImmutableMap.Builder tableProperties = ImmutableMap.builder(); + createProperties.entrySet().stream() + .filter(entry -> !RESERVED_PROPERTIES.contains(entry.getKey())) + .forEach(tableProperties::put); + + String provider = createProperties.get(TableCatalog.PROP_PROVIDER); + if ("parquet".equalsIgnoreCase(provider)) { + tableProperties.put(TableProperties.DEFAULT_FILE_FORMAT, "parquet"); + } else if ("avro".equalsIgnoreCase(provider)) { + tableProperties.put(TableProperties.DEFAULT_FILE_FORMAT, "avro"); + } else if ("orc".equalsIgnoreCase(provider)) { + tableProperties.put(TableProperties.DEFAULT_FILE_FORMAT, "orc"); + } else if (provider != null && !"iceberg".equalsIgnoreCase(provider)) { + throw new IllegalArgumentException("Unsupported format in USING: " + provider); + } + + return tableProperties.build(); + } + + /** + * Applies a list of Spark table changes to an {@link UpdateProperties} operation. + * + * @param pendingUpdate an uncommitted UpdateProperties operation to configure + * @param changes a list of Spark table changes + * @return the UpdateProperties operation configured with the changes + */ + public static UpdateProperties applyPropertyChanges(UpdateProperties pendingUpdate, List changes) { + for (TableChange change : changes) { + if (change instanceof TableChange.SetProperty) { + TableChange.SetProperty set = (TableChange.SetProperty) change; + pendingUpdate.set(set.property(), set.value()); + + } else if (change instanceof TableChange.RemoveProperty) { + TableChange.RemoveProperty remove = (TableChange.RemoveProperty) change; + pendingUpdate.remove(remove.property()); + + } else { + throw new UnsupportedOperationException("Cannot apply unknown table change: " + change); + } + } + + return pendingUpdate; + } + + /** + * Applies a list of Spark table changes to an {@link UpdateSchema} operation. 
+ * + * @param pendingUpdate an uncommitted UpdateSchema operation to configure + * @param changes a list of Spark table changes + * @return the UpdateSchema operation configured with the changes + */ + public static UpdateSchema applySchemaChanges(UpdateSchema pendingUpdate, List changes) { + for (TableChange change : changes) { + if (change instanceof TableChange.AddColumn) { + apply(pendingUpdate, (TableChange.AddColumn) change); + + } else if (change instanceof TableChange.UpdateColumnType) { + TableChange.UpdateColumnType update = (TableChange.UpdateColumnType) change; + Type newType = SparkSchemaUtil.convert(update.newDataType()); + Preconditions.checkArgument(newType.isPrimitiveType(), + "Cannot update '%s', not a primitive type: %s", DOT.join(update.fieldNames()), update.newDataType()); + pendingUpdate.updateColumn(DOT.join(update.fieldNames()), newType.asPrimitiveType()); + + } else if (change instanceof TableChange.UpdateColumnComment) { + TableChange.UpdateColumnComment update = (TableChange.UpdateColumnComment) change; + pendingUpdate.updateColumnDoc(DOT.join(update.fieldNames()), update.newComment()); + + } else if (change instanceof TableChange.RenameColumn) { + TableChange.RenameColumn rename = (TableChange.RenameColumn) change; + pendingUpdate.renameColumn(DOT.join(rename.fieldNames()), rename.newName()); + + } else if (change instanceof TableChange.DeleteColumn) { + TableChange.DeleteColumn delete = (TableChange.DeleteColumn) change; + pendingUpdate.deleteColumn(DOT.join(delete.fieldNames())); + + } else if (change instanceof TableChange.UpdateColumnNullability) { + TableChange.UpdateColumnNullability update = (TableChange.UpdateColumnNullability) change; + if (update.nullable()) { + pendingUpdate.makeColumnOptional(DOT.join(update.fieldNames())); + } else { + pendingUpdate.requireColumn(DOT.join(update.fieldNames())); + } + + } else if (change instanceof TableChange.UpdateColumnPosition) { + apply(pendingUpdate, (TableChange.UpdateColumnPosition) change); + + } else { + throw new UnsupportedOperationException("Cannot apply unknown table change: " + change); + } + } + + return pendingUpdate; + } + + private static void apply(UpdateSchema pendingUpdate, TableChange.UpdateColumnPosition update) { + Preconditions.checkArgument(update.position() != null, "Invalid position: null"); + + if (update.position() instanceof TableChange.After) { + TableChange.After after = (TableChange.After) update.position(); + String referenceField = peerName(update.fieldNames(), after.column()); + pendingUpdate.moveAfter(DOT.join(update.fieldNames()), referenceField); + + } else if (update.position() instanceof TableChange.First) { + pendingUpdate.moveFirst(DOT.join(update.fieldNames())); + + } else { + throw new IllegalArgumentException("Unknown position for reorder: " + update.position()); + } + } + + private static void apply(UpdateSchema pendingUpdate, TableChange.AddColumn add) { + Preconditions.checkArgument(add.isNullable(), + "Incompatible change: cannot add required column: %s", leafName(add.fieldNames())); + Type type = SparkSchemaUtil.convert(add.dataType()); + pendingUpdate.addColumn(parentName(add.fieldNames()), leafName(add.fieldNames()), type, add.comment()); + + if (add.position() instanceof TableChange.After) { + TableChange.After after = (TableChange.After) add.position(); + String referenceField = peerName(add.fieldNames(), after.column()); + pendingUpdate.moveAfter(DOT.join(add.fieldNames()), referenceField); + + } else if (add.position() instanceof TableChange.First) { + 
pendingUpdate.moveFirst(DOT.join(add.fieldNames())); + + } else { + Preconditions.checkArgument(add.position() == null, + "Cannot add '%s' at unknown position: %s", DOT.join(add.fieldNames()), add.position()); + } + } + + public static org.apache.iceberg.Table toIcebergTable(Table table) { + Preconditions.checkArgument(table instanceof SparkTable, "Table %s is not an Iceberg table", table); + SparkTable sparkTable = (SparkTable) table; + return sparkTable.table(); + } + + /** + * Converts a PartitionSpec to Spark transforms. + * + * @param spec a PartitionSpec + * @return an array of Transforms + */ + public static Transform[] toTransforms(PartitionSpec spec) { + Map quotedNameById = SparkSchemaUtil.indexQuotedNameById(spec.schema()); + List transforms = PartitionSpecVisitor.visit(spec, + new PartitionSpecVisitor() { + @Override + public Transform identity(String sourceName, int sourceId) { + return Expressions.identity(quotedName(sourceId)); + } + + @Override + public Transform bucket(String sourceName, int sourceId, int numBuckets) { + return Expressions.bucket(numBuckets, quotedName(sourceId)); + } + + @Override + public Transform truncate(String sourceName, int sourceId, int width) { + return Expressions.apply("truncate", Expressions.column(quotedName(sourceId)), Expressions.literal(width)); + } + + @Override + public Transform year(String sourceName, int sourceId) { + return Expressions.years(quotedName(sourceId)); + } + + @Override + public Transform month(String sourceName, int sourceId) { + return Expressions.months(quotedName(sourceId)); + } + + @Override + public Transform day(String sourceName, int sourceId) { + return Expressions.days(quotedName(sourceId)); + } + + @Override + public Transform hour(String sourceName, int sourceId) { + return Expressions.hours(quotedName(sourceId)); + } + + @Override + public Transform alwaysNull(int fieldId, String sourceName, int sourceId) { + // do nothing for alwaysNull, it doesn't need to be converted to a transform + return null; + } + + @Override + public Transform unknown(int fieldId, String sourceName, int sourceId, String transform) { + return Expressions.apply(transform, Expressions.column(quotedName(sourceId))); + } + + private String quotedName(int id) { + return quotedNameById.get(id); + } + }); + + return transforms.stream().filter(Objects::nonNull).toArray(Transform[]::new); + } + + public static NamedReference toNamedReference(String name) { + return Expressions.column(name); + } + + public static Term toIcebergTerm(Expression expr) { + if (expr instanceof Transform) { + Transform transform = (Transform) expr; + Preconditions.checkArgument(transform.references().length == 1, + "Cannot convert transform with more than one column reference: %s", transform); + String colName = DOT.join(transform.references()[0].fieldNames()); + switch (transform.name()) { + case "identity": + return org.apache.iceberg.expressions.Expressions.ref(colName); + case "bucket": + return org.apache.iceberg.expressions.Expressions.bucket(colName, findWidth(transform)); + case "years": + return org.apache.iceberg.expressions.Expressions.year(colName); + case "months": + return org.apache.iceberg.expressions.Expressions.month(colName); + case "date": + case "days": + return org.apache.iceberg.expressions.Expressions.day(colName); + case "date_hour": + case "hours": + return org.apache.iceberg.expressions.Expressions.hour(colName); + case "truncate": + return org.apache.iceberg.expressions.Expressions.truncate(colName, findWidth(transform)); + default: + 
throw new UnsupportedOperationException("Transform is not supported: " + transform); + } + + } else if (expr instanceof NamedReference) { + NamedReference ref = (NamedReference) expr; + return org.apache.iceberg.expressions.Expressions.ref(DOT.join(ref.fieldNames())); + + } else { + throw new UnsupportedOperationException("Cannot convert unknown expression: " + expr); + } + } + + /** + * Converts Spark transforms into a {@link PartitionSpec}. + * + * @param schema the table schema + * @param partitioning Spark Transforms + * @return a PartitionSpec + */ + public static PartitionSpec toPartitionSpec(Schema schema, Transform[] partitioning) { + if (partitioning == null || partitioning.length == 0) { + return PartitionSpec.unpartitioned(); + } + + PartitionSpec.Builder builder = PartitionSpec.builderFor(schema); + for (Transform transform : partitioning) { + Preconditions.checkArgument(transform.references().length == 1, + "Cannot convert transform with more than one column reference: %s", transform); + String colName = DOT.join(transform.references()[0].fieldNames()); + switch (transform.name()) { + case "identity": + builder.identity(colName); + break; + case "bucket": + builder.bucket(colName, findWidth(transform)); + break; + case "years": + builder.year(colName); + break; + case "months": + builder.month(colName); + break; + case "date": + case "days": + builder.day(colName); + break; + case "date_hour": + case "hours": + builder.hour(colName); + break; + case "truncate": + builder.truncate(colName, findWidth(transform)); + break; + default: + throw new UnsupportedOperationException("Transform is not supported: " + transform); + } + } + + return builder.build(); + } + + @SuppressWarnings("unchecked") + private static int findWidth(Transform transform) { + for (Expression expr : transform.arguments()) { + if (expr instanceof Literal) { + if (((Literal) expr).dataType() instanceof IntegerType) { + Literal lit = (Literal) expr; + Preconditions.checkArgument(lit.value() > 0, + "Unsupported width for transform: %s", transform.describe()); + return lit.value(); + + } else if (((Literal) expr).dataType() instanceof LongType) { + Literal lit = (Literal) expr; + Preconditions.checkArgument(lit.value() > 0 && lit.value() < Integer.MAX_VALUE, + "Unsupported width for transform: %s", transform.describe()); + if (lit.value() > Integer.MAX_VALUE) { + throw new IllegalArgumentException(); + } + return lit.value().intValue(); + } + } + } + + throw new IllegalArgumentException("Cannot find width for transform: " + transform.describe()); + } + + private static String leafName(String[] fieldNames) { + Preconditions.checkArgument(fieldNames.length > 0, "Invalid field name: at least one name is required"); + return fieldNames[fieldNames.length - 1]; + } + + private static String peerName(String[] fieldNames, String fieldName) { + if (fieldNames.length > 1) { + String[] peerNames = Arrays.copyOf(fieldNames, fieldNames.length); + peerNames[fieldNames.length - 1] = fieldName; + return DOT.join(peerNames); + } + return fieldName; + } + + private static String parentName(String[] fieldNames) { + if (fieldNames.length > 1) { + return DOT.join(Arrays.copyOfRange(fieldNames, 0, fieldNames.length - 1)); + } + return null; + } + + public static String describe(org.apache.iceberg.expressions.Expression expr) { + return ExpressionVisitors.visit(expr, DescribeExpressionVisitor.INSTANCE); + } + + public static String describe(Schema schema) { + return TypeUtil.visit(schema, DescribeSchemaVisitor.INSTANCE); + } + + public 
static String describe(Type type) { + return TypeUtil.visit(type, DescribeSchemaVisitor.INSTANCE); + } + + public static String describe(org.apache.iceberg.SortOrder order) { + return Joiner.on(", ").join(SortOrderVisitor.visit(order, DescribeSortOrderVisitor.INSTANCE)); + } + + public static boolean extensionsEnabled(SparkSession spark) { + String extensions = spark.conf().get("spark.sql.extensions", ""); + return extensions.contains("IcebergSparkSessionExtensions"); + } + + public static class DescribeSchemaVisitor extends TypeUtil.SchemaVisitor { + private static final Joiner COMMA = Joiner.on(','); + private static final DescribeSchemaVisitor INSTANCE = new DescribeSchemaVisitor(); + + private DescribeSchemaVisitor() { + } + + @Override + public String schema(Schema schema, String structResult) { + return structResult; + } + + @Override + public String struct(Types.StructType struct, List fieldResults) { + return "struct<" + COMMA.join(fieldResults) + ">"; + } + + @Override + public String field(Types.NestedField field, String fieldResult) { + return field.name() + ": " + fieldResult + (field.isRequired() ? " not null" : ""); + } + + @Override + public String list(Types.ListType list, String elementResult) { + return "list<" + elementResult + ">"; + } + + @Override + public String map(Types.MapType map, String keyResult, String valueResult) { + return "map<" + keyResult + ", " + valueResult + ">"; + } + + @Override + public String primitive(Type.PrimitiveType primitive) { + switch (primitive.typeId()) { + case BOOLEAN: + return "boolean"; + case INTEGER: + return "int"; + case LONG: + return "bigint"; + case FLOAT: + return "float"; + case DOUBLE: + return "double"; + case DATE: + return "date"; + case TIME: + return "time"; + case TIMESTAMP: + return "timestamp"; + case STRING: + case UUID: + return "string"; + case FIXED: + case BINARY: + return "binary"; + case DECIMAL: + Types.DecimalType decimal = (Types.DecimalType) primitive; + return "decimal(" + decimal.precision() + "," + decimal.scale() + ")"; + } + throw new UnsupportedOperationException("Cannot convert type to SQL: " + primitive); + } + } + + private static class DescribeExpressionVisitor extends ExpressionVisitors.ExpressionVisitor { + private static final DescribeExpressionVisitor INSTANCE = new DescribeExpressionVisitor(); + + private DescribeExpressionVisitor() { + } + + @Override + public String alwaysTrue() { + return "true"; + } + + @Override + public String alwaysFalse() { + return "false"; + } + + @Override + public String not(String result) { + return "NOT (" + result + ")"; + } + + @Override + public String and(String leftResult, String rightResult) { + return "(" + leftResult + " AND " + rightResult + ")"; + } + + @Override + public String or(String leftResult, String rightResult) { + return "(" + leftResult + " OR " + rightResult + ")"; + } + + @Override + public String predicate(BoundPredicate pred) { + throw new UnsupportedOperationException("Cannot convert bound predicates to SQL"); + } + + @Override + public String predicate(UnboundPredicate pred) { + switch (pred.op()) { + case IS_NULL: + return pred.ref().name() + " IS NULL"; + case NOT_NULL: + return pred.ref().name() + " IS NOT NULL"; + case IS_NAN: + return "is_nan(" + pred.ref().name() + ")"; + case NOT_NAN: + return "not_nan(" + pred.ref().name() + ")"; + case LT: + return pred.ref().name() + " < " + sqlString(pred.literal()); + case LT_EQ: + return pred.ref().name() + " <= " + sqlString(pred.literal()); + case GT: + return pred.ref().name() + " > " 
+ sqlString(pred.literal()); + case GT_EQ: + return pred.ref().name() + " >= " + sqlString(pred.literal()); + case EQ: + return pred.ref().name() + " = " + sqlString(pred.literal()); + case NOT_EQ: + return pred.ref().name() + " != " + sqlString(pred.literal()); + case STARTS_WITH: + return pred.ref().name() + " LIKE '" + pred.literal() + "%'"; + case NOT_STARTS_WITH: + return pred.ref().name() + " NOT LIKE '" + pred.literal() + "%'"; + case IN: + return pred.ref().name() + " IN (" + sqlString(pred.literals()) + ")"; + case NOT_IN: + return pred.ref().name() + " NOT IN (" + sqlString(pred.literals()) + ")"; + default: + throw new UnsupportedOperationException("Cannot convert predicate to SQL: " + pred); + } + } + + private static String sqlString(List> literals) { + return literals.stream().map(DescribeExpressionVisitor::sqlString).collect(Collectors.joining(", ")); + } + + private static String sqlString(org.apache.iceberg.expressions.Literal lit) { + if (lit.value() instanceof String) { + return "'" + lit.value() + "'"; + } else if (lit.value() instanceof ByteBuffer) { + byte[] bytes = ByteBuffers.toByteArray((ByteBuffer) lit.value()); + return "X'" + BaseEncoding.base16().encode(bytes) + "'"; + } else { + return lit.value().toString(); + } + } + } + + /** + * Returns an Iceberg Table by its name from a Spark V2 Catalog. If cache is enabled in {@link SparkCatalog}, + * the {@link TableOperations} of the table may be stale, please refresh the table to get the latest one. + * + * @param spark SparkSession used for looking up catalog references and tables + * @param name The multipart identifier of the Iceberg table + * @return an Iceberg table + */ + public static org.apache.iceberg.Table loadIcebergTable(SparkSession spark, String name) + throws ParseException, NoSuchTableException { + CatalogAndIdentifier catalogAndIdentifier = catalogAndIdentifier(spark, name); + + TableCatalog catalog = asTableCatalog(catalogAndIdentifier.catalog); + Table sparkTable = catalog.loadTable(catalogAndIdentifier.identifier); + return toIcebergTable(sparkTable); + } + + /** + * Returns the underlying Iceberg Catalog object represented by a Spark Catalog + * @param spark SparkSession used for looking up catalog reference + * @param catalogName The name of the Spark Catalog being referenced + * @return the Iceberg catalog class being wrapped by the Spark Catalog + */ + public static Catalog loadIcebergCatalog(SparkSession spark, String catalogName) { + CatalogPlugin catalogPlugin = spark.sessionState().catalogManager().catalog(catalogName); + Preconditions.checkArgument(catalogPlugin instanceof HasIcebergCatalog, + String.format("Cannot load Iceberg catalog from catalog %s because it does not contain an Iceberg Catalog. 
" + + "Actual Class: %s", + catalogName, catalogPlugin.getClass().getName())); + return ((HasIcebergCatalog) catalogPlugin).icebergCatalog(); + } + + + public static CatalogAndIdentifier catalogAndIdentifier(SparkSession spark, String name) throws ParseException { + return catalogAndIdentifier(spark, name, spark.sessionState().catalogManager().currentCatalog()); + } + + public static CatalogAndIdentifier catalogAndIdentifier(SparkSession spark, String name, + CatalogPlugin defaultCatalog) throws ParseException { + ParserInterface parser = spark.sessionState().sqlParser(); + Seq multiPartIdentifier = parser.parseMultipartIdentifier(name).toIndexedSeq(); + List javaMultiPartIdentifier = JavaConverters.seqAsJavaList(multiPartIdentifier); + return catalogAndIdentifier(spark, javaMultiPartIdentifier, defaultCatalog); + } + + public static CatalogAndIdentifier catalogAndIdentifier(String description, SparkSession spark, String name) { + return catalogAndIdentifier(description, spark, name, spark.sessionState().catalogManager().currentCatalog()); + } + + public static CatalogAndIdentifier catalogAndIdentifier(String description, SparkSession spark, + String name, CatalogPlugin defaultCatalog) { + try { + return catalogAndIdentifier(spark, name, defaultCatalog); + } catch (ParseException e) { + throw new IllegalArgumentException("Cannot parse " + description + ": " + name, e); + } + } + + public static CatalogAndIdentifier catalogAndIdentifier(SparkSession spark, List nameParts) { + return catalogAndIdentifier(spark, nameParts, spark.sessionState().catalogManager().currentCatalog()); + } + + /** + * A modified version of Spark's LookupCatalog.CatalogAndIdentifier.unapply + * Attempts to find the catalog and identifier a multipart identifier represents + * @param spark Spark session to use for resolution + * @param nameParts Multipart identifier representing a table + * @param defaultCatalog Catalog to use if none is specified + * @return The CatalogPlugin and Identifier for the table + */ + public static CatalogAndIdentifier catalogAndIdentifier(SparkSession spark, List nameParts, + CatalogPlugin defaultCatalog) { + CatalogManager catalogManager = spark.sessionState().catalogManager(); + + String[] currentNamespace; + if (defaultCatalog.equals(catalogManager.currentCatalog())) { + currentNamespace = catalogManager.currentNamespace(); + } else { + currentNamespace = defaultCatalog.defaultNamespace(); + } + + Pair catalogIdentifier = SparkUtil.catalogAndIdentifier(nameParts, + catalogName -> { + try { + return catalogManager.catalog(catalogName); + } catch (Exception e) { + return null; + } + }, + Identifier::of, + defaultCatalog, + currentNamespace + ); + return new CatalogAndIdentifier(catalogIdentifier); + } + + private static TableCatalog asTableCatalog(CatalogPlugin catalog) { + if (catalog instanceof TableCatalog) { + return (TableCatalog) catalog; + } + + throw new IllegalArgumentException(String.format( + "Cannot use catalog %s(%s): not a TableCatalog", catalog.name(), catalog.getClass().getName())); + } + + /** + * This mimics a class inside of Spark which is private inside of LookupCatalog. 
+ */ + public static class CatalogAndIdentifier { + private final CatalogPlugin catalog; + private final Identifier identifier; + + + public CatalogAndIdentifier(CatalogPlugin catalog, Identifier identifier) { + this.catalog = catalog; + this.identifier = identifier; + } + + public CatalogAndIdentifier(Pair identifier) { + this.catalog = identifier.first(); + this.identifier = identifier.second(); + } + + public CatalogPlugin catalog() { + return catalog; + } + + public Identifier identifier() { + return identifier; + } + } + + public static TableIdentifier identifierToTableIdentifier(Identifier identifier) { + return TableIdentifier.of(Namespace.of(identifier.namespace()), identifier.name()); + } + + public static String quotedFullIdentifier(String catalogName, Identifier identifier) { + List parts = + ImmutableList.builder() + .add(catalogName) + .addAll(Arrays.asList(identifier.namespace())) + .add(identifier.name()) + .build(); + + return CatalogV2Implicits.MultipartIdentifierHelper( + JavaConverters.asScalaIteratorConverter(parts.iterator()).asScala().toSeq() + ).quoted(); + } + + /** + * Use Spark to list all partitions in the table. + * + * @param spark a Spark session + * @param rootPath a table identifier + * @param format format of the file + * @param partitionFilter partitionFilter of the file + * @return all table's partitions + */ + public static List getPartitions(SparkSession spark, Path rootPath, String format, + Map partitionFilter) { + FileStatusCache fileStatusCache = FileStatusCache.getOrCreate(spark); + + InMemoryFileIndex fileIndex = new InMemoryFileIndex( + spark, + JavaConverters + .collectionAsScalaIterableConverter(ImmutableList.of(rootPath)) + .asScala() + .toSeq(), + scala.collection.immutable.Map$.MODULE$.empty(), + Option.empty(), + fileStatusCache, + Option.empty(), + Option.empty()); + + org.apache.spark.sql.execution.datasources.PartitionSpec spec = fileIndex.partitionSpec(); + StructType schema = spec.partitionColumns(); + if (schema.isEmpty()) { + return Lists.newArrayList(); + } + + List filterExpressions = + SparkUtil.partitionMapToExpression(schema, partitionFilter); + Seq scalaPartitionFilters = + JavaConverters.asScalaBufferConverter(filterExpressions).asScala().toIndexedSeq(); + + List dataFilters = Lists.newArrayList(); + Seq scalaDataFilters = + JavaConverters.asScalaBufferConverter(dataFilters).asScala().toIndexedSeq(); + + Seq filteredPartitions = + fileIndex.listFiles(scalaPartitionFilters, scalaDataFilters).toIndexedSeq(); + + return JavaConverters + .seqAsJavaListConverter(filteredPartitions) + .asJava() + .stream() + .map(partition -> { + Map values = Maps.newHashMap(); + JavaConverters.asJavaIterableConverter(schema).asJava().forEach(field -> { + int fieldIndex = schema.fieldIndex(field.name()); + Object catalystValue = partition.values().get(fieldIndex, field.dataType()); + Object value = CatalystTypeConverters.convertToScala(catalystValue, field.dataType()); + values.put(field.name(), String.valueOf(value)); + }); + + FileStatus fileStatus = + JavaConverters.seqAsJavaListConverter(partition.files()).asJava().get(0); + + return new SparkPartition(values, fileStatus.getPath().getParent().toString(), format); + }).collect(Collectors.toList()); + } + + public static org.apache.spark.sql.catalyst.TableIdentifier toV1TableIdentifier(Identifier identifier) { + String[] namespace = identifier.namespace(); + + Preconditions.checkArgument(namespace.length <= 1, + "Cannot convert %s to a Spark v1 identifier, namespace contains more than 1 part", 
identifier); + + String table = identifier.name(); + Option database = namespace.length == 1 ? Option.apply(namespace[0]) : Option.empty(); + return org.apache.spark.sql.catalyst.TableIdentifier.apply(table, database); + } + + private static class DescribeSortOrderVisitor implements SortOrderVisitor { + private static final DescribeSortOrderVisitor INSTANCE = new DescribeSortOrderVisitor(); + + private DescribeSortOrderVisitor() { + } + + @Override + public String field(String sourceName, int sourceId, + org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { + return String.format("%s %s %s", sourceName, direction, nullOrder); + } + + @Override + public String bucket(String sourceName, int sourceId, int numBuckets, + org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { + return String.format("bucket(%s, %s) %s %s", numBuckets, sourceName, direction, nullOrder); + } + + @Override + public String truncate(String sourceName, int sourceId, int width, + org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { + return String.format("truncate(%s, %s) %s %s", sourceName, width, direction, nullOrder); + } + + @Override + public String year(String sourceName, int sourceId, + org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { + return String.format("years(%s) %s %s", sourceName, direction, nullOrder); + } + + @Override + public String month(String sourceName, int sourceId, + org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { + return String.format("months(%s) %s %s", sourceName, direction, nullOrder); + } + + @Override + public String day(String sourceName, int sourceId, + org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { + return String.format("days(%s) %s %s", sourceName, direction, nullOrder); + } + + @Override + public String hour(String sourceName, int sourceId, + org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { + return String.format("hours(%s) %s %s", sourceName, direction, nullOrder); + } + + @Override + public String unknown(String sourceName, int sourceId, String transform, + org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { + return String.format("%s(%s) %s %s", transform, sourceName, direction, nullOrder); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkCatalog.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkCatalog.java new file mode 100644 index 000000000000..a1ab3cced3f0 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkCatalog.java @@ -0,0 +1,654 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.CachingCatalog; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.Transaction; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.SupportsNamespaces; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.HadoopCatalog; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.base.Splitter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.iceberg.spark.source.StagedSparkTable; +import org.apache.iceberg.util.Pair; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NamespaceAlreadyExistsException; +import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.NamespaceChange; +import org.apache.spark.sql.connector.catalog.StagedTable; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.catalog.TableChange; +import org.apache.spark.sql.connector.catalog.TableChange.ColumnChange; +import org.apache.spark.sql.connector.catalog.TableChange.RemoveProperty; +import org.apache.spark.sql.connector.catalog.TableChange.SetProperty; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +import static org.apache.iceberg.TableProperties.GC_ENABLED; +import static org.apache.iceberg.TableProperties.GC_ENABLED_DEFAULT; + +/** + * A Spark TableCatalog implementation that wraps an Iceberg {@link Catalog}. + *

+ * This supports the following catalog configuration options:
+ * <ul>
+ *   <li><code>type</code> - catalog type, "hive" or "hadoop"</li>
+ *   <li><code>uri</code> - the Hive Metastore URI (Hive catalog only)</li>
+ *   <li><code>warehouse</code> - the warehouse path (Hadoop catalog only)</li>
+ *   <li><code>catalog-impl</code> - a custom {@link Catalog} implementation to use</li>
+ *   <li><code>default-namespace</code> - a namespace to use as the default</li>
+ *   <li><code>cache-enabled</code> - whether to enable catalog cache</li>
+ *   <li><code>cache.expiration-interval-ms</code> - interval in millis before expiring tables from catalog cache</li>
+ * </ul>
+ * <p>
+ * To use a custom catalog that is not a Hive or Hadoop catalog, extend this class and override + * {@link #buildIcebergCatalog(String, CaseInsensitiveStringMap)}. + */ +public class SparkCatalog extends BaseCatalog { + private static final Set DEFAULT_NS_KEYS = ImmutableSet.of(TableCatalog.PROP_OWNER); + private static final Splitter COMMA = Splitter.on(","); + private static final Pattern AT_TIMESTAMP = Pattern.compile("at_timestamp_(\\d+)"); + private static final Pattern SNAPSHOT_ID = Pattern.compile("snapshot_id_(\\d+)"); + + private String catalogName = null; + private Catalog icebergCatalog = null; + private boolean cacheEnabled = CatalogProperties.CACHE_ENABLED_DEFAULT; + private SupportsNamespaces asNamespaceCatalog = null; + private String[] defaultNamespace = null; + private HadoopTables tables; + private boolean useTimestampsWithoutZone; + + /** + * Build an Iceberg {@link Catalog} to be used by this Spark catalog adapter. + * + * @param name Spark's catalog name + * @param options Spark's catalog options + * @return an Iceberg catalog + */ + protected Catalog buildIcebergCatalog(String name, CaseInsensitiveStringMap options) { + Configuration conf = SparkUtil.hadoopConfCatalogOverrides(SparkSession.active(), name); + Map optionsMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + optionsMap.putAll(options); + optionsMap.put(CatalogProperties.APP_ID, SparkSession.active().sparkContext().applicationId()); + optionsMap.put(CatalogProperties.USER, SparkSession.active().sparkContext().sparkUser()); + return CatalogUtil.buildIcebergCatalog(name, optionsMap, conf); + } + + /** + * Build an Iceberg {@link TableIdentifier} for the given Spark identifier. + * + * @param identifier Spark's identifier + * @return an Iceberg identifier + */ + protected TableIdentifier buildIdentifier(Identifier identifier) { + return Spark3Util.identifierToTableIdentifier(identifier); + } + + @Override + public SparkTable loadTable(Identifier ident) throws NoSuchTableException { + try { + Pair icebergTable = load(ident); + return new SparkTable(icebergTable.first(), icebergTable.second(), !cacheEnabled); + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + throw new NoSuchTableException(ident); + } + } + + @Override + public SparkTable loadTable(Identifier ident, String version) throws NoSuchTableException { + try { + Pair icebergTable = load(ident); + Preconditions.checkArgument(icebergTable.second() == null, + "Cannot do time-travel based on both table identifier and AS OF"); + return new SparkTable(icebergTable.first(), Long.parseLong(version), !cacheEnabled); + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + throw new NoSuchTableException(ident); + } + } + + @Override + public SparkTable loadTable(Identifier ident, long timestamp) throws NoSuchTableException { + try { + Pair icebergTable = load(ident); + Preconditions.checkArgument(icebergTable.second() == null, + "Cannot do time-travel based on both table identifier and AS OF"); + long snapshotIdAsOfTime = SnapshotUtil.snapshotIdAsOfTime(icebergTable.first(), timestamp); + return new SparkTable(icebergTable.first(), snapshotIdAsOfTime, !cacheEnabled); + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + throw new NoSuchTableException(ident); + } + } + + @Override + public SparkTable createTable(Identifier ident, StructType schema, + Transform[] transforms, + Map properties) throws TableAlreadyExistsException { + Schema icebergSchema = SparkSchemaUtil.convert(schema, useTimestampsWithoutZone); + try 
{ + Catalog.TableBuilder builder = newBuilder(ident, icebergSchema); + Table icebergTable = builder + .withPartitionSpec(Spark3Util.toPartitionSpec(icebergSchema, transforms)) + .withLocation(properties.get("location")) + .withProperties(Spark3Util.rebuildCreateProperties(properties)) + .create(); + return new SparkTable(icebergTable, !cacheEnabled); + } catch (AlreadyExistsException e) { + throw new TableAlreadyExistsException(ident); + } + } + + @Override + public StagedTable stageCreate(Identifier ident, StructType schema, Transform[] transforms, + Map properties) throws TableAlreadyExistsException { + Schema icebergSchema = SparkSchemaUtil.convert(schema, useTimestampsWithoutZone); + try { + Catalog.TableBuilder builder = newBuilder(ident, icebergSchema); + Transaction transaction = builder.withPartitionSpec(Spark3Util.toPartitionSpec(icebergSchema, transforms)) + .withLocation(properties.get("location")) + .withProperties(Spark3Util.rebuildCreateProperties(properties)) + .createTransaction(); + return new StagedSparkTable(transaction); + } catch (AlreadyExistsException e) { + throw new TableAlreadyExistsException(ident); + } + } + + @Override + public StagedTable stageReplace(Identifier ident, StructType schema, Transform[] transforms, + Map properties) throws NoSuchTableException { + Schema icebergSchema = SparkSchemaUtil.convert(schema, useTimestampsWithoutZone); + try { + Catalog.TableBuilder builder = newBuilder(ident, icebergSchema); + Transaction transaction = builder.withPartitionSpec(Spark3Util.toPartitionSpec(icebergSchema, transforms)) + .withLocation(properties.get("location")) + .withProperties(Spark3Util.rebuildCreateProperties(properties)) + .replaceTransaction(); + return new StagedSparkTable(transaction); + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + throw new NoSuchTableException(ident); + } + } + + @Override + public StagedTable stageCreateOrReplace(Identifier ident, StructType schema, Transform[] transforms, + Map properties) { + Schema icebergSchema = SparkSchemaUtil.convert(schema, useTimestampsWithoutZone); + Catalog.TableBuilder builder = newBuilder(ident, icebergSchema); + Transaction transaction = builder.withPartitionSpec(Spark3Util.toPartitionSpec(icebergSchema, transforms)) + .withLocation(properties.get("location")) + .withProperties(Spark3Util.rebuildCreateProperties(properties)) + .createOrReplaceTransaction(); + return new StagedSparkTable(transaction); + } + + @Override + public SparkTable alterTable(Identifier ident, TableChange... changes) throws NoSuchTableException { + SetProperty setLocation = null; + SetProperty setSnapshotId = null; + SetProperty pickSnapshotId = null; + List propertyChanges = Lists.newArrayList(); + List schemaChanges = Lists.newArrayList(); + + for (TableChange change : changes) { + if (change instanceof SetProperty) { + SetProperty set = (SetProperty) change; + if (TableCatalog.PROP_LOCATION.equalsIgnoreCase(set.property())) { + setLocation = set; + } else if ("current-snapshot-id".equalsIgnoreCase(set.property())) { + setSnapshotId = set; + } else if ("cherry-pick-snapshot-id".equalsIgnoreCase(set.property())) { + pickSnapshotId = set; + } else if ("sort-order".equalsIgnoreCase(set.property())) { + throw new UnsupportedOperationException("Cannot specify the 'sort-order' because it's a reserved table " + + "property. Please use the command 'ALTER TABLE ... 
WRITE ORDERED BY' to specify write sort-orders."); + } else { + propertyChanges.add(set); + } + } else if (change instanceof RemoveProperty) { + propertyChanges.add(change); + } else if (change instanceof ColumnChange) { + schemaChanges.add(change); + } else { + throw new UnsupportedOperationException("Cannot apply unknown table change: " + change); + } + } + + try { + Table table = load(ident).first(); + commitChanges(table, setLocation, setSnapshotId, pickSnapshotId, propertyChanges, schemaChanges); + return new SparkTable(table, true /* refreshEagerly */); + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + throw new NoSuchTableException(ident); + } + } + + @Override + public boolean dropTable(Identifier ident) { + return dropTableWithoutPurging(ident); + } + + @Override + public boolean purgeTable(Identifier ident) { + try { + Table table = load(ident).first(); + ValidationException.check( + PropertyUtil.propertyAsBoolean(table.properties(), GC_ENABLED, GC_ENABLED_DEFAULT), + "Cannot purge table: GC is disabled (deleting files may corrupt other tables)"); + String metadataFileLocation = ((HasTableOperations) table).operations().current().metadataFileLocation(); + + boolean dropped = dropTableWithoutPurging(ident); + + if (dropped) { + // We should check whether the metadata file exists. Because the HadoopCatalog/HadoopTables will drop the + // warehouse directly and ignore the `purge` argument. + boolean metadataFileExists = table.io().newInputFile(metadataFileLocation).exists(); + + if (metadataFileExists) { + SparkActions.get() + .deleteReachableFiles(metadataFileLocation) + .io(table.io()) + .execute(); + } + } + + return dropped; + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + return false; + } + } + + private boolean dropTableWithoutPurging(Identifier ident) { + if (isPathIdentifier(ident)) { + return tables.dropTable(((PathIdentifier) ident).location(), false /* don't purge data */); + } else { + return icebergCatalog.dropTable(buildIdentifier(ident), false /* don't purge data */); + } + } + + @Override + public void renameTable(Identifier from, Identifier to) throws NoSuchTableException, TableAlreadyExistsException { + try { + checkNotPathIdentifier(from, "renameTable"); + checkNotPathIdentifier(to, "renameTable"); + icebergCatalog.renameTable(buildIdentifier(from), buildIdentifier(to)); + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + throw new NoSuchTableException(from); + } catch (AlreadyExistsException e) { + throw new TableAlreadyExistsException(to); + } + } + + @Override + public void invalidateTable(Identifier ident) { + if (!isPathIdentifier(ident)) { + icebergCatalog.invalidateTable(buildIdentifier(ident)); + } + } + + @Override + public Identifier[] listTables(String[] namespace) { + return icebergCatalog.listTables(Namespace.of(namespace)).stream() + .map(ident -> Identifier.of(ident.namespace().levels(), ident.name())) + .toArray(Identifier[]::new); + } + + @Override + public String[] defaultNamespace() { + if (defaultNamespace != null) { + return defaultNamespace; + } + + return new String[0]; + } + + @Override + public String[][] listNamespaces() { + if (asNamespaceCatalog != null) { + return asNamespaceCatalog.listNamespaces().stream() + .map(Namespace::levels) + .toArray(String[][]::new); + } + + return new String[0][]; + } + + @Override + public String[][] listNamespaces(String[] namespace) throws NoSuchNamespaceException { + if (asNamespaceCatalog != null) { + try { + return 
asNamespaceCatalog.listNamespaces(Namespace.of(namespace)).stream() + .map(Namespace::levels) + .toArray(String[][]::new); + } catch (org.apache.iceberg.exceptions.NoSuchNamespaceException e) { + throw new NoSuchNamespaceException(namespace); + } + } + + throw new NoSuchNamespaceException(namespace); + } + + @Override + public Map loadNamespaceMetadata(String[] namespace) throws NoSuchNamespaceException { + if (asNamespaceCatalog != null) { + try { + return asNamespaceCatalog.loadNamespaceMetadata(Namespace.of(namespace)); + } catch (org.apache.iceberg.exceptions.NoSuchNamespaceException e) { + throw new NoSuchNamespaceException(namespace); + } + } + + throw new NoSuchNamespaceException(namespace); + } + + @Override + public void createNamespace(String[] namespace, Map metadata) throws NamespaceAlreadyExistsException { + if (asNamespaceCatalog != null) { + try { + if (asNamespaceCatalog instanceof HadoopCatalog && DEFAULT_NS_KEYS.equals(metadata.keySet())) { + // Hadoop catalog will reject metadata properties, but Spark automatically adds "owner". + // If only the automatic properties are present, replace metadata with an empty map. + asNamespaceCatalog.createNamespace(Namespace.of(namespace), ImmutableMap.of()); + } else { + asNamespaceCatalog.createNamespace(Namespace.of(namespace), metadata); + } + } catch (AlreadyExistsException e) { + throw new NamespaceAlreadyExistsException(namespace); + } + } else { + throw new UnsupportedOperationException("Namespaces are not supported by catalog: " + catalogName); + } + } + + @Override + public void alterNamespace(String[] namespace, NamespaceChange... changes) throws NoSuchNamespaceException { + if (asNamespaceCatalog != null) { + Map updates = Maps.newHashMap(); + Set removals = Sets.newHashSet(); + for (NamespaceChange change : changes) { + if (change instanceof NamespaceChange.SetProperty) { + NamespaceChange.SetProperty set = (NamespaceChange.SetProperty) change; + updates.put(set.property(), set.value()); + } else if (change instanceof NamespaceChange.RemoveProperty) { + removals.add(((NamespaceChange.RemoveProperty) change).property()); + } else { + throw new UnsupportedOperationException("Cannot apply unknown namespace change: " + change); + } + } + + try { + if (!updates.isEmpty()) { + asNamespaceCatalog.setProperties(Namespace.of(namespace), updates); + } + + if (!removals.isEmpty()) { + asNamespaceCatalog.removeProperties(Namespace.of(namespace), removals); + } + + } catch (org.apache.iceberg.exceptions.NoSuchNamespaceException e) { + throw new NoSuchNamespaceException(namespace); + } + } else { + throw new NoSuchNamespaceException(namespace); + } + } + + @Override + public boolean dropNamespace(String[] namespace, boolean cascade) throws NoSuchNamespaceException { + if (asNamespaceCatalog != null) { + try { + return asNamespaceCatalog.dropNamespace(Namespace.of(namespace)); + } catch (org.apache.iceberg.exceptions.NoSuchNamespaceException e) { + throw new NoSuchNamespaceException(namespace); + } + } + + return false; + } + + @Override + public final void initialize(String name, CaseInsensitiveStringMap options) { + this.cacheEnabled = PropertyUtil.propertyAsBoolean(options, + CatalogProperties.CACHE_ENABLED, CatalogProperties.CACHE_ENABLED_DEFAULT); + + long cacheExpirationIntervalMs = PropertyUtil.propertyAsLong(options, + CatalogProperties.CACHE_EXPIRATION_INTERVAL_MS, + CatalogProperties.CACHE_EXPIRATION_INTERVAL_MS_DEFAULT); + + // An expiration interval of 0ms effectively disables caching. + // Do not wrap with CachingCatalog. 
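+    // These values arrive as Spark catalog options, e.g. setting
+    // spark.sql.catalog.<name>.cache-enabled=false or
+    // spark.sql.catalog.<name>.cache.expiration-interval-ms=0 leaves the catalog uncached.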
+ if (cacheExpirationIntervalMs == 0) { + this.cacheEnabled = false; + } + + Catalog catalog = buildIcebergCatalog(name, options); + + this.catalogName = name; + SparkSession sparkSession = SparkSession.active(); + this.useTimestampsWithoutZone = SparkUtil.useTimestampWithoutZoneInNewTables(sparkSession.conf()); + this.tables = new HadoopTables(SparkUtil.hadoopConfCatalogOverrides(SparkSession.active(), name)); + this.icebergCatalog = cacheEnabled ? CachingCatalog.wrap(catalog, cacheExpirationIntervalMs) : catalog; + if (catalog instanceof SupportsNamespaces) { + this.asNamespaceCatalog = (SupportsNamespaces) catalog; + if (options.containsKey("default-namespace")) { + this.defaultNamespace = Splitter.on('.') + .splitToList(options.get("default-namespace")) + .toArray(new String[0]); + } + } + } + + @Override + public String name() { + return catalogName; + } + + private static void commitChanges(Table table, SetProperty setLocation, SetProperty setSnapshotId, + SetProperty pickSnapshotId, List propertyChanges, + List schemaChanges) { + // don't allow setting the snapshot and picking a commit at the same time because order is ambiguous and choosing + // one order leads to different results + Preconditions.checkArgument(setSnapshotId == null || pickSnapshotId == null, + "Cannot set the current the current snapshot ID and cherry-pick snapshot changes"); + + if (setSnapshotId != null) { + long newSnapshotId = Long.parseLong(setSnapshotId.value()); + table.manageSnapshots().setCurrentSnapshot(newSnapshotId).commit(); + } + + // if updating the table snapshot, perform that update first in case it fails + if (pickSnapshotId != null) { + long newSnapshotId = Long.parseLong(pickSnapshotId.value()); + table.manageSnapshots().cherrypick(newSnapshotId).commit(); + } + + Transaction transaction = table.newTransaction(); + + if (setLocation != null) { + transaction.updateLocation() + .setLocation(setLocation.value()) + .commit(); + } + + if (!propertyChanges.isEmpty()) { + Spark3Util.applyPropertyChanges(transaction.updateProperties(), propertyChanges).commit(); + } + + if (!schemaChanges.isEmpty()) { + Spark3Util.applySchemaChanges(transaction.updateSchema(), schemaChanges).commit(); + } + + transaction.commitTransaction(); + } + + private static boolean isPathIdentifier(Identifier ident) { + return ident instanceof PathIdentifier; + } + + private static void checkNotPathIdentifier(Identifier identifier, String method) { + if (identifier instanceof PathIdentifier) { + throw new IllegalArgumentException(String.format("Cannot pass path based identifier to %s method. 
%s is a path.", + method, identifier)); + } + } + + private Pair load(Identifier ident) { + if (isPathIdentifier(ident)) { + return loadFromPathIdentifier((PathIdentifier) ident); + } + + try { + return Pair.of(icebergCatalog.loadTable(buildIdentifier(ident)), null); + + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + if (ident.namespace().length == 0) { + throw e; + } + + // if the original load didn't work, the identifier may be extended and include a snapshot selector + TableIdentifier namespaceAsIdent = buildIdentifier(namespaceToIdentifier(ident.namespace())); + Table table; + try { + table = icebergCatalog.loadTable(namespaceAsIdent); + } catch (Exception ignored) { + // the namespace does not identify a table, so it cannot be a table with a snapshot selector + // throw the original exception + throw e; + } + + // loading the namespace as a table worked, check the name to see if it is a valid selector + Matcher at = AT_TIMESTAMP.matcher(ident.name()); + if (at.matches()) { + long asOfTimestamp = Long.parseLong(at.group(1)); + return Pair.of(table, SnapshotUtil.snapshotIdAsOfTime(table, asOfTimestamp)); + } + + Matcher id = SNAPSHOT_ID.matcher(ident.name()); + if (id.matches()) { + long snapshotId = Long.parseLong(id.group(1)); + return Pair.of(table, snapshotId); + } + + // the name wasn't a valid snapshot selector. throw the original exception + throw e; + } + } + + private Pair> parseLocationString(String location) { + int hashIndex = location.lastIndexOf('#'); + if (hashIndex != -1 && !location.endsWith("#")) { + String baseLocation = location.substring(0, hashIndex); + List metadata = COMMA.splitToList(location.substring(hashIndex + 1)); + return Pair.of(baseLocation, metadata); + } else { + return Pair.of(location, ImmutableList.of()); + } + } + + private Pair loadFromPathIdentifier(PathIdentifier ident) { + Pair> parsed = parseLocationString(ident.location()); + + String metadataTableName = null; + Long asOfTimestamp = null; + Long snapshotId = null; + for (String meta : parsed.second()) { + if (MetadataTableType.from(meta) != null) { + metadataTableName = meta; + continue; + } + + Matcher at = AT_TIMESTAMP.matcher(meta); + if (at.matches()) { + asOfTimestamp = Long.parseLong(at.group(1)); + continue; + } + + Matcher id = SNAPSHOT_ID.matcher(meta); + if (id.matches()) { + snapshotId = Long.parseLong(id.group(1)); + } + } + + Preconditions.checkArgument(asOfTimestamp == null || snapshotId == null, + "Cannot specify both snapshot-id and as-of-timestamp: %s", ident.location()); + + Table table = tables.load(parsed.first() + (metadataTableName != null ? "#" + metadataTableName : "")); + + if (snapshotId != null) { + return Pair.of(table, snapshotId); + } else if (asOfTimestamp != null) { + return Pair.of(table, SnapshotUtil.snapshotIdAsOfTime(table, asOfTimestamp)); + } else { + return Pair.of(table, null); + } + } + + private Identifier namespaceToIdentifier(String[] namespace) { + Preconditions.checkArgument(namespace.length > 0, + "Cannot convert empty namespace to identifier"); + String[] ns = Arrays.copyOf(namespace, namespace.length - 1); + String name = namespace[ns.length]; + return Identifier.of(ns, name); + } + + private Catalog.TableBuilder newBuilder(Identifier ident, Schema schema) { + return isPathIdentifier(ident) ? 
+ tables.buildTable(((PathIdentifier) ident).location(), schema) : + icebergCatalog.buildTable(buildIdentifier(ident), schema); + } + + @Override + public Catalog icebergCatalog() { + return icebergCatalog; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkConfParser.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkConfParser.java new file mode 100644 index 000000000000..f3d89467fcf7 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkConfParser.java @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.function.Function; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.RuntimeConfig; +import org.apache.spark.sql.SparkSession; + +class SparkConfParser { + + private final Map properties; + private final RuntimeConfig sessionConf; + private final Map options; + + SparkConfParser(SparkSession spark, Table table, Map options) { + this.properties = table.properties(); + this.sessionConf = spark.conf(); + this.options = options; + } + + public BooleanConfParser booleanConf() { + return new BooleanConfParser(); + } + + public IntConfParser intConf() { + return new IntConfParser(); + } + + public LongConfParser longConf() { + return new LongConfParser(); + } + + public StringConfParser stringConf() { + return new StringConfParser(); + } + + class BooleanConfParser extends ConfParser { + private Boolean defaultValue; + + @Override + protected BooleanConfParser self() { + return this; + } + + public BooleanConfParser defaultValue(boolean value) { + this.defaultValue = value; + return self(); + } + + public BooleanConfParser defaultValue(String value) { + this.defaultValue = Boolean.parseBoolean(value); + return self(); + } + + public boolean parse() { + Preconditions.checkArgument(defaultValue != null, "Default value cannot be null"); + return parse(Boolean::parseBoolean, defaultValue); + } + } + + class IntConfParser extends ConfParser { + private Integer defaultValue; + + @Override + protected IntConfParser self() { + return this; + } + + public IntConfParser defaultValue(int value) { + this.defaultValue = value; + return self(); + } + + public int parse() { + Preconditions.checkArgument(defaultValue != null, "Default value cannot be null"); + return parse(Integer::parseInt, defaultValue); + } + + public Integer parseOptional() { + return parse(Integer::parseInt, null); + } + } + + class LongConfParser extends ConfParser { + private Long defaultValue; + + @Override + 
protected LongConfParser self() { + return this; + } + + public LongConfParser defaultValue(long value) { + this.defaultValue = value; + return self(); + } + + public long parse() { + Preconditions.checkArgument(defaultValue != null, "Default value cannot be null"); + return parse(Long::parseLong, defaultValue); + } + + public Long parseOptional() { + return parse(Long::parseLong, null); + } + } + + class StringConfParser extends ConfParser { + private String defaultValue; + + @Override + protected StringConfParser self() { + return this; + } + + public StringConfParser defaultValue(String value) { + this.defaultValue = value; + return self(); + } + + public String parse() { + Preconditions.checkArgument(defaultValue != null, "Default value cannot be null"); + return parse(Function.identity(), defaultValue); + } + + public String parseOptional() { + return parse(Function.identity(), null); + } + } + + abstract class ConfParser { + private final List optionNames = Lists.newArrayList(); + private String sessionConfName; + private String tablePropertyName; + + protected abstract ThisT self(); + + public ThisT option(String name) { + this.optionNames.add(name); + return self(); + } + + public ThisT sessionConf(String name) { + this.sessionConfName = name; + return self(); + } + + public ThisT tableProperty(String name) { + this.tablePropertyName = name; + return self(); + } + + protected T parse(Function conversion, T defaultValue) { + if (!optionNames.isEmpty()) { + for (String optionName : optionNames) { + // use lower case comparison as DataSourceOptions.asMap() in Spark 2 returns a lower case map + String optionValue = options.get(optionName.toLowerCase(Locale.ROOT)); + if (optionValue != null) { + return conversion.apply(optionValue); + } + } + } + + if (sessionConfName != null) { + String sessionConfValue = sessionConf.get(sessionConfName, null); + if (sessionConfValue != null) { + return conversion.apply(sessionConfValue); + } + } + + if (tablePropertyName != null) { + String propertyValue = properties.get(tablePropertyName); + if (propertyValue != null) { + return conversion.apply(propertyValue); + } + } + + return defaultValue; + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkDataFile.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkDataFile.java new file mode 100644 index 000000000000..a6390d39c575 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkDataFile.java @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
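The SparkConfParser added above is package-private and is meant to be wrapped by small conf holder classes in the same package. A minimal usage sketch under that assumption; the class name, the `split-size` option name, and the session conf key are illustrative, while `read.split.target-size` is the standard Iceberg table property:

```java
package org.apache.iceberg.spark;

import java.util.Map;
import org.apache.iceberg.Table;
import org.apache.spark.sql.SparkSession;

// Illustrative conf holder, sketched after the pattern used by read/write conf
// classes in this package; not part of this change.
class ExampleReadConf {
  private final SparkConfParser confParser;

  ExampleReadConf(SparkSession spark, Table table, Map<String, String> readOptions) {
    this.confParser = new SparkConfParser(spark, table, readOptions);
  }

  // Resolution order implemented by ConfParser.parse(): read option first,
  // then Spark session conf, then table property, then the supplied default.
  long splitSize() {
    return confParser.longConf()
        .option("split-size")                        // per-scan read option (assumed name)
        .sessionConf("spark.sql.iceberg.split-size") // assumed session conf key
        .tableProperty("read.split.target-size")     // Iceberg table property
        .defaultValue(128L * 1024 * 1024)            // 128 MB
        .parse();
  }
}
```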
+ */ + +package org.apache.iceberg.spark; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.StructType; + +public class SparkDataFile implements DataFile { + + private final int filePathPosition; + private final int fileFormatPosition; + private final int partitionPosition; + private final int recordCountPosition; + private final int fileSizeInBytesPosition; + private final int columnSizesPosition; + private final int valueCountsPosition; + private final int nullValueCountsPosition; + private final int nanValueCountsPosition; + private final int lowerBoundsPosition; + private final int upperBoundsPosition; + private final int keyMetadataPosition; + private final int splitOffsetsPosition; + private final int sortOrderIdPosition; + private final Type lowerBoundsType; + private final Type upperBoundsType; + private final Type keyMetadataType; + + private final SparkStructLike wrappedPartition; + private Row wrapped; + + public SparkDataFile(Types.StructType type, StructType sparkType) { + this.lowerBoundsType = type.fieldType("lower_bounds"); + this.upperBoundsType = type.fieldType("upper_bounds"); + this.keyMetadataType = type.fieldType("key_metadata"); + this.wrappedPartition = new SparkStructLike(type.fieldType("partition").asStructType()); + + Map positions = Maps.newHashMap(); + type.fields().forEach(field -> { + String fieldName = field.name(); + positions.put(fieldName, fieldPosition(fieldName, sparkType)); + }); + + filePathPosition = positions.get("file_path"); + fileFormatPosition = positions.get("file_format"); + partitionPosition = positions.get("partition"); + recordCountPosition = positions.get("record_count"); + fileSizeInBytesPosition = positions.get("file_size_in_bytes"); + columnSizesPosition = positions.get("column_sizes"); + valueCountsPosition = positions.get("value_counts"); + nullValueCountsPosition = positions.get("null_value_counts"); + nanValueCountsPosition = positions.get("nan_value_counts"); + lowerBoundsPosition = positions.get("lower_bounds"); + upperBoundsPosition = positions.get("upper_bounds"); + keyMetadataPosition = positions.get("key_metadata"); + splitOffsetsPosition = positions.get("split_offsets"); + sortOrderIdPosition = positions.get("sort_order_id"); + } + + public SparkDataFile wrap(Row row) { + this.wrapped = row; + if (wrappedPartition.size() > 0) { + this.wrappedPartition.wrap(row.getAs(partitionPosition)); + } + return this; + } + + @Override + public Long pos() { + return null; + } + + @Override + public int specId() { + return -1; + } + + @Override + public CharSequence path() { + return wrapped.getAs(filePathPosition); + } + + @Override + public FileFormat format() { + String formatAsString = wrapped.getString(fileFormatPosition).toUpperCase(Locale.ROOT); + return FileFormat.valueOf(formatAsString); + } + + @Override + public StructLike partition() { + return wrappedPartition; + } + + @Override + public long recordCount() { + return wrapped.getAs(recordCountPosition); + } + + @Override + public long fileSizeInBytes() { + return wrapped.getAs(fileSizeInBytesPosition); + } + + @Override + public Map columnSizes() { + return wrapped.isNullAt(columnSizesPosition) ? 
null : wrapped.getJavaMap(columnSizesPosition); + } + + @Override + public Map valueCounts() { + return wrapped.isNullAt(valueCountsPosition) ? null : wrapped.getJavaMap(valueCountsPosition); + } + + @Override + public Map nullValueCounts() { + return wrapped.isNullAt(nullValueCountsPosition) ? null : wrapped.getJavaMap(nullValueCountsPosition); + } + + @Override + public Map nanValueCounts() { + return wrapped.isNullAt(nanValueCountsPosition) ? null : wrapped.getJavaMap(nanValueCountsPosition); + } + + @Override + public Map lowerBounds() { + Map lowerBounds = wrapped.isNullAt(lowerBoundsPosition) ? null : wrapped.getJavaMap(lowerBoundsPosition); + return convert(lowerBoundsType, lowerBounds); + } + + @Override + public Map upperBounds() { + Map upperBounds = wrapped.isNullAt(upperBoundsPosition) ? null : wrapped.getJavaMap(upperBoundsPosition); + return convert(upperBoundsType, upperBounds); + } + + @Override + public ByteBuffer keyMetadata() { + return convert(keyMetadataType, wrapped.get(keyMetadataPosition)); + } + + @Override + public DataFile copy() { + throw new UnsupportedOperationException("Not implemented: copy"); + } + + @Override + public DataFile copyWithoutStats() { + throw new UnsupportedOperationException("Not implemented: copyWithoutStats"); + } + + @Override + public List splitOffsets() { + return wrapped.isNullAt(splitOffsetsPosition) ? null : wrapped.getList(splitOffsetsPosition); + } + + @Override + public Integer sortOrderId() { + return wrapped.getAs(sortOrderIdPosition); + } + + private int fieldPosition(String name, StructType sparkType) { + try { + return sparkType.fieldIndex(name); + } catch (IllegalArgumentException e) { + // the partition field is absent for unpartitioned tables + if (name.equals("partition") && wrappedPartition.size() == 0) { + return -1; + } + throw e; + } + } + + @SuppressWarnings("unchecked") + private T convert(Type valueType, Object value) { + return (T) SparkValueConverter.convert(valueType, value); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkDistributionAndOrderingUtil.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkDistributionAndOrderingUtil.java new file mode 100644 index 000000000000..e481e1260432 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkDistributionAndOrderingUtil.java @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark; + +import java.util.List; +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ObjectArrays; +import org.apache.iceberg.transforms.SortOrderVisitor; +import org.apache.iceberg.util.SortOrderUtil; +import org.apache.spark.sql.connector.distributions.ClusteredDistribution; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.distributions.Distributions; +import org.apache.spark.sql.connector.distributions.OrderedDistribution; +import org.apache.spark.sql.connector.distributions.UnspecifiedDistribution; +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.SortDirection; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.connector.write.RowLevelOperation.Command; + +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE; + +public class SparkDistributionAndOrderingUtil { + + private static final NamedReference SPEC_ID = Expressions.column(MetadataColumns.SPEC_ID.name()); + private static final NamedReference PARTITION = Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME); + private static final NamedReference FILE_PATH = Expressions.column(MetadataColumns.FILE_PATH.name()); + private static final NamedReference ROW_POSITION = Expressions.column(MetadataColumns.ROW_POSITION.name()); + + private static final SortOrder SPEC_ID_ORDER = Expressions.sort(SPEC_ID, SortDirection.ASCENDING); + private static final SortOrder PARTITION_ORDER = Expressions.sort(PARTITION, SortDirection.ASCENDING); + private static final SortOrder FILE_PATH_ORDER = Expressions.sort(FILE_PATH, SortDirection.ASCENDING); + private static final SortOrder ROW_POSITION_ORDER = Expressions.sort(ROW_POSITION, SortDirection.ASCENDING); + + private static final SortOrder[] EXISTING_FILE_ORDERING = new SortOrder[]{FILE_PATH_ORDER, ROW_POSITION_ORDER}; + private static final SortOrder[] POSITION_DELETE_ORDERING = new SortOrder[]{ + SPEC_ID_ORDER, PARTITION_ORDER, FILE_PATH_ORDER, ROW_POSITION_ORDER + }; + + private SparkDistributionAndOrderingUtil() { + } + + public static Distribution buildRequiredDistribution(Table table, DistributionMode distributionMode) { + switch (distributionMode) { + case NONE: + return Distributions.unspecified(); + + case HASH: + return Distributions.clustered(Spark3Util.toTransforms(table.spec())); + + case RANGE: + return Distributions.ordered(buildTableOrdering(table)); + + default: + throw new IllegalArgumentException("Unsupported distribution mode: " + distributionMode); + } + } + + public static SortOrder[] buildRequiredOrdering(Table table, Distribution distribution) { + if (distribution instanceof OrderedDistribution) { + OrderedDistribution orderedDistribution = (OrderedDistribution) distribution; + return orderedDistribution.ordering(); + } else { + return buildTableOrdering(table); + } + } + + public static Distribution buildCopyOnWriteDistribution(Table table, Command command, + DistributionMode distributionMode) { + if (command == DELETE || command == UPDATE) { + return buildCopyOnWriteDeleteUpdateDistribution(table, 
distributionMode); + } else { + return buildRequiredDistribution(table, distributionMode); + } + } + + private static Distribution buildCopyOnWriteDeleteUpdateDistribution(Table table, DistributionMode distributionMode) { + switch (distributionMode) { + case NONE: + return Distributions.unspecified(); + + case HASH: + Expression[] clustering = new Expression[]{FILE_PATH}; + return Distributions.clustered(clustering); + + case RANGE: + SortOrder[] tableOrdering = buildTableOrdering(table); + if (table.sortOrder().isSorted()) { + return Distributions.ordered(tableOrdering); + } else { + SortOrder[] ordering = ObjectArrays.concat(tableOrdering, EXISTING_FILE_ORDERING, SortOrder.class); + return Distributions.ordered(ordering); + } + + default: + throw new IllegalArgumentException("Unexpected distribution mode: " + distributionMode); + } + } + + public static SortOrder[] buildCopyOnWriteOrdering(Table table, Command command, Distribution distribution) { + if (command == DELETE || command == UPDATE) { + return buildCopyOnWriteDeleteUpdateOrdering(table, distribution); + } else { + return buildRequiredOrdering(table, distribution); + } + } + + private static SortOrder[] buildCopyOnWriteDeleteUpdateOrdering(Table table, Distribution distribution) { + if (distribution instanceof UnspecifiedDistribution) { + return buildTableOrdering(table); + + } else if (distribution instanceof ClusteredDistribution) { + SortOrder[] tableOrdering = buildTableOrdering(table); + if (table.sortOrder().isSorted()) { + return tableOrdering; + } else { + return ObjectArrays.concat(tableOrdering, EXISTING_FILE_ORDERING, SortOrder.class); + } + + } else if (distribution instanceof OrderedDistribution) { + OrderedDistribution orderedDistribution = (OrderedDistribution) distribution; + return orderedDistribution.ordering(); + + } else { + throw new IllegalArgumentException("Unexpected distribution type: " + distribution.getClass().getName()); + } + } + + public static Distribution buildPositionDeltaDistribution(Table table, Command command, + DistributionMode distributionMode) { + if (command == DELETE || command == UPDATE) { + return buildPositionDeleteUpdateDistribution(distributionMode); + } else { + return buildPositionMergeDistribution(table, distributionMode); + } + } + + private static Distribution buildPositionMergeDistribution(Table table, DistributionMode distributionMode) { + switch (distributionMode) { + case NONE: + return Distributions.unspecified(); + + case HASH: + if (table.spec().isUnpartitioned()) { + Expression[] clustering = new Expression[]{SPEC_ID, PARTITION, FILE_PATH}; + return Distributions.clustered(clustering); + } else { + Distribution dataDistribution = buildRequiredDistribution(table, distributionMode); + Expression[] dataClustering = ((ClusteredDistribution) dataDistribution).clustering(); + Expression[] deleteClustering = new Expression[]{SPEC_ID, PARTITION}; + Expression[] clustering = ObjectArrays.concat(deleteClustering, dataClustering, Expression.class); + return Distributions.clustered(clustering); + } + + case RANGE: + Distribution dataDistribution = buildRequiredDistribution(table, distributionMode); + SortOrder[] dataOrdering = ((OrderedDistribution) dataDistribution).ordering(); + SortOrder[] deleteOrdering = new SortOrder[]{SPEC_ID_ORDER, PARTITION_ORDER, FILE_PATH_ORDER}; + SortOrder[] ordering = ObjectArrays.concat(deleteOrdering, dataOrdering, SortOrder.class); + return Distributions.ordered(ordering); + + default: + throw new IllegalArgumentException("Unexpected distribution 
mode: " + distributionMode); + } + } + + private static Distribution buildPositionDeleteUpdateDistribution(DistributionMode distributionMode) { + switch (distributionMode) { + case NONE: + return Distributions.unspecified(); + + case HASH: + Expression[] clustering = new Expression[]{SPEC_ID, PARTITION}; + return Distributions.clustered(clustering); + + case RANGE: + SortOrder[] ordering = new SortOrder[]{SPEC_ID_ORDER, PARTITION_ORDER, FILE_PATH_ORDER}; + return Distributions.ordered(ordering); + + default: + throw new IllegalArgumentException("Unsupported distribution mode: " + distributionMode); + } + } + + public static SortOrder[] buildPositionDeltaOrdering(Table table, Command command) { + if (command == DELETE || command == UPDATE) { + return POSITION_DELETE_ORDERING; + } else { + // all metadata columns like _spec_id, _file, _pos will be null for new data records + SortOrder[] dataOrdering = buildTableOrdering(table); + return ObjectArrays.concat(POSITION_DELETE_ORDERING, dataOrdering, SortOrder.class); + } + } + + public static SortOrder[] convert(org.apache.iceberg.SortOrder sortOrder) { + List converted = SortOrderVisitor.visit(sortOrder, new SortOrderToSpark(sortOrder.schema())); + return converted.toArray(new SortOrder[0]); + } + + private static SortOrder[] buildTableOrdering(Table table) { + return convert(SortOrderUtil.buildSortOrder(table)); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkExceptionUtil.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkExceptionUtil.java new file mode 100644 index 000000000000..2eb53baa688e --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkExceptionUtil.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +import com.google.errorprone.annotations.FormatMethod; +import java.io.IOException; +import org.apache.iceberg.exceptions.NoSuchNamespaceException; +import org.apache.iceberg.exceptions.NoSuchTableException; +import org.apache.iceberg.exceptions.RuntimeIOException; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.spark.sql.AnalysisException; + +public class SparkExceptionUtil { + + private SparkExceptionUtil() { + } + + /** + * Converts checked exceptions to unchecked exceptions. + * + * @param cause a checked exception object which is to be converted to its unchecked equivalent. + * @param message exception message as a format string + * @param args format specifiers + * @return unchecked exception. + */ + @FormatMethod + public static RuntimeException toUncheckedException(final Throwable cause, final String message, + final Object... 
args) { + // Parameters are required to be final to help @FormatMethod do static analysis + if (cause instanceof RuntimeException) { + return (RuntimeException) cause; + + } else if (cause instanceof org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException) { + return new NoSuchNamespaceException(cause, message, args); + + } else if (cause instanceof org.apache.spark.sql.catalyst.analysis.NoSuchTableException) { + return new NoSuchTableException(cause, message, args); + + } else if (cause instanceof AnalysisException) { + return new ValidationException(cause, message, args); + + } else if (cause instanceof IOException) { + return new RuntimeIOException((IOException) cause, message, args); + + } else { + return new RuntimeException(String.format(message, args), cause); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkFilters.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkFilters.java new file mode 100644 index 000000000000..50c108c0b01b --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkFilters.java @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark; + +import java.sql.Date; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.util.Map; +import java.util.Objects; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expression.Operation; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.util.NaNUtil; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.sources.AlwaysFalse; +import org.apache.spark.sql.sources.AlwaysFalse$; +import org.apache.spark.sql.sources.AlwaysTrue; +import org.apache.spark.sql.sources.AlwaysTrue$; +import org.apache.spark.sql.sources.And; +import org.apache.spark.sql.sources.EqualNullSafe; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.GreaterThanOrEqual; +import org.apache.spark.sql.sources.In; +import org.apache.spark.sql.sources.IsNotNull; +import org.apache.spark.sql.sources.IsNull; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.LessThanOrEqual; +import org.apache.spark.sql.sources.Not; +import org.apache.spark.sql.sources.Or; +import org.apache.spark.sql.sources.StringStartsWith; + +import static org.apache.iceberg.expressions.Expressions.and; +import static org.apache.iceberg.expressions.Expressions.equal; +import static org.apache.iceberg.expressions.Expressions.greaterThan; +import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual; +import static org.apache.iceberg.expressions.Expressions.in; +import static org.apache.iceberg.expressions.Expressions.isNaN; +import static org.apache.iceberg.expressions.Expressions.isNull; +import static org.apache.iceberg.expressions.Expressions.lessThan; +import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual; +import static org.apache.iceberg.expressions.Expressions.not; +import static org.apache.iceberg.expressions.Expressions.notIn; +import static org.apache.iceberg.expressions.Expressions.notNull; +import static org.apache.iceberg.expressions.Expressions.or; +import static org.apache.iceberg.expressions.Expressions.startsWith; + +public class SparkFilters { + + private static final Pattern BACKTICKS_PATTERN = Pattern.compile("([`])(.|$)"); + + private SparkFilters() { + } + + private static final Map, Operation> FILTERS = ImmutableMap + ., Operation>builder() + .put(AlwaysTrue.class, Operation.TRUE) + .put(AlwaysTrue$.class, Operation.TRUE) + .put(AlwaysFalse$.class, Operation.FALSE) + .put(AlwaysFalse.class, Operation.FALSE) + .put(EqualTo.class, Operation.EQ) + .put(EqualNullSafe.class, Operation.EQ) + .put(GreaterThan.class, Operation.GT) + .put(GreaterThanOrEqual.class, Operation.GT_EQ) + .put(LessThan.class, Operation.LT) + .put(LessThanOrEqual.class, Operation.LT_EQ) + .put(In.class, Operation.IN) + .put(IsNull.class, Operation.IS_NULL) + .put(IsNotNull.class, Operation.NOT_NULL) + .put(And.class, Operation.AND) + .put(Or.class, Operation.OR) + .put(Not.class, Operation.NOT) + .put(StringStartsWith.class, Operation.STARTS_WITH) + .build(); + + public static Expression convert(Filter[] filters) { + 
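+    // AND together every pushed-down filter; if any filter has no Iceberg equivalent,
+    // convert(Filter) returns null and the precondition below rejects the whole conversion.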
Expression expression = Expressions.alwaysTrue(); + for (Filter filter : filters) { + Expression converted = convert(filter); + Preconditions.checkArgument(converted != null, "Cannot convert filter to Iceberg: %s", filter); + expression = Expressions.and(expression, converted); + } + return expression; + } + + public static Expression convert(Filter filter) { + // avoid using a chain of if instanceof statements by mapping to the expression enum. + Operation op = FILTERS.get(filter.getClass()); + if (op != null) { + switch (op) { + case TRUE: + return Expressions.alwaysTrue(); + + case FALSE: + return Expressions.alwaysFalse(); + + case IS_NULL: + IsNull isNullFilter = (IsNull) filter; + return isNull(unquote(isNullFilter.attribute())); + + case NOT_NULL: + IsNotNull notNullFilter = (IsNotNull) filter; + return notNull(unquote(notNullFilter.attribute())); + + case LT: + LessThan lt = (LessThan) filter; + return lessThan(unquote(lt.attribute()), convertLiteral(lt.value())); + + case LT_EQ: + LessThanOrEqual ltEq = (LessThanOrEqual) filter; + return lessThanOrEqual(unquote(ltEq.attribute()), convertLiteral(ltEq.value())); + + case GT: + GreaterThan gt = (GreaterThan) filter; + return greaterThan(unquote(gt.attribute()), convertLiteral(gt.value())); + + case GT_EQ: + GreaterThanOrEqual gtEq = (GreaterThanOrEqual) filter; + return greaterThanOrEqual(unquote(gtEq.attribute()), convertLiteral(gtEq.value())); + + case EQ: // used for both eq and null-safe-eq + if (filter instanceof EqualTo) { + EqualTo eq = (EqualTo) filter; + // comparison with null in normal equality is always null. this is probably a mistake. + Preconditions.checkNotNull(eq.value(), + "Expression is always false (eq is not null-safe): %s", filter); + return handleEqual(unquote(eq.attribute()), eq.value()); + } else { + EqualNullSafe eq = (EqualNullSafe) filter; + if (eq.value() == null) { + return isNull(unquote(eq.attribute())); + } else { + return handleEqual(unquote(eq.attribute()), eq.value()); + } + } + + case IN: + In inFilter = (In) filter; + return in(unquote(inFilter.attribute()), + Stream.of(inFilter.values()) + .filter(Objects::nonNull) + .map(SparkFilters::convertLiteral) + .collect(Collectors.toList())); + + case NOT: + Not notFilter = (Not) filter; + Filter childFilter = notFilter.child(); + Operation childOp = FILTERS.get(childFilter.getClass()); + if (childOp == Operation.IN) { + // infer an extra notNull predicate for Spark NOT IN filters + // as Iceberg expressions don't follow the 3-value SQL boolean logic + // col NOT IN (1, 2) in Spark is equivalent to notNull(col) && notIn(col, 1, 2) in Iceberg + In childInFilter = (In) childFilter; + Expression notIn = notIn(unquote(childInFilter.attribute()), + Stream.of(childInFilter.values()) + .map(SparkFilters::convertLiteral) + .collect(Collectors.toList())); + return and(notNull(childInFilter.attribute()), notIn); + } else if (hasNoInFilter(childFilter)) { + Expression child = convert(childFilter); + if (child != null) { + return not(child); + } + } + return null; + + case AND: { + And andFilter = (And) filter; + Expression left = convert(andFilter.left()); + Expression right = convert(andFilter.right()); + if (left != null && right != null) { + return and(left, right); + } + return null; + } + + case OR: { + Or orFilter = (Or) filter; + Expression left = convert(orFilter.left()); + Expression right = convert(orFilter.right()); + if (left != null && right != null) { + return or(left, right); + } + return null; + } + + case STARTS_WITH: { + StringStartsWith 
stringStartsWith = (StringStartsWith) filter; + return startsWith(unquote(stringStartsWith.attribute()), stringStartsWith.value()); + } + } + } + + return null; + } + + private static Object convertLiteral(Object value) { + if (value instanceof Timestamp) { + return DateTimeUtils.fromJavaTimestamp((Timestamp) value); + } else if (value instanceof Date) { + return DateTimeUtils.fromJavaDate((Date) value); + } else if (value instanceof Instant) { + return DateTimeUtils.instantToMicros((Instant) value); + } else if (value instanceof LocalDate) { + return DateTimeUtils.localDateToDays((LocalDate) value); + } + return value; + } + + private static Expression handleEqual(String attribute, Object value) { + if (NaNUtil.isNaN(value)) { + return isNaN(attribute); + } else { + return equal(attribute, convertLiteral(value)); + } + } + + private static String unquote(String attributeName) { + Matcher matcher = BACKTICKS_PATTERN.matcher(attributeName); + return matcher.replaceAll("$2"); + } + + private static boolean hasNoInFilter(Filter filter) { + Operation op = FILTERS.get(filter.getClass()); + + if (op != null) { + switch (op) { + case AND: + And andFilter = (And) filter; + return hasNoInFilter(andFilter.left()) && hasNoInFilter(andFilter.right()); + case OR: + Or orFilter = (Or) filter; + return hasNoInFilter(orFilter.left()) && hasNoInFilter(orFilter.right()); + case NOT: + Not notFilter = (Not) filter; + return hasNoInFilter(notFilter.child()); + case IN: + return false; + default: + return true; + } + } + + return false; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkFixupTimestampType.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkFixupTimestampType.java new file mode 100644 index 000000000000..d4dd53d34a97 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkFixupTimestampType.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +import org.apache.iceberg.Schema; +import org.apache.iceberg.types.FixupTypes; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; + +/** + * By default Spark type {@link org.apache.iceberg.types.Types.TimestampType} should be converted to + * {@link Types.TimestampType#withZone()} iceberg type. 
But we also can convert + * {@link org.apache.iceberg.types.Types.TimestampType} to {@link Types.TimestampType#withoutZone()} iceberg type + * by setting {@link SparkSQLProperties#USE_TIMESTAMP_WITHOUT_TIME_ZONE_IN_NEW_TABLES} to 'true' + */ +class SparkFixupTimestampType extends FixupTypes { + + private SparkFixupTimestampType(Schema referenceSchema) { + super(referenceSchema); + } + + static Schema fixup(Schema schema) { + return new Schema(TypeUtil.visit(schema, + new SparkFixupTimestampType(schema)).asStructType().fields()); + } + + @Override + public Type primitive(Type.PrimitiveType primitive) { + if (primitive.typeId() == Type.TypeID.TIMESTAMP) { + return Types.TimestampType.withoutZone(); + } + return primitive; + } + + @Override + protected boolean fixupPrimitive(Type.PrimitiveType type, Type source) { + return Type.TypeID.TIMESTAMP.equals(type.typeId()); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkFixupTypes.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkFixupTypes.java new file mode 100644 index 000000000000..5508965af249 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkFixupTypes.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +import org.apache.iceberg.Schema; +import org.apache.iceberg.types.FixupTypes; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; + +/** + * Some types, like binary and fixed, are converted to the same Spark type. Conversion back + * can produce only one, which may not be correct. 
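As an aside for readers of this new file: the ambiguity described above is why conversions back to Iceberg accept a reference schema. A minimal sketch, assuming a hypothetical schema with uuid and fixed columns, of how `SparkSchemaUtil.convert(Schema, StructType)` (added later in this patch, and which applies `SparkFixupTypes` internally) uses the reference schema to restore the lossy types:

```java
import org.apache.iceberg.Schema;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.types.StructType;

public class FixupTypesExample {
  public static void main(String[] args) {
    // Hypothetical Iceberg schema using types Spark cannot represent exactly.
    Schema icebergSchema = new Schema(
        Types.NestedField.required(1, "id", Types.UUIDType.get()),
        Types.NestedField.optional(2, "payload", Types.FixedType.ofLength(16)));

    // Round-tripping through Spark loses the distinction:
    // uuid becomes string and fixed becomes binary.
    StructType sparkType = SparkSchemaUtil.convert(icebergSchema);

    // Converting back with the original schema as a reference lets
    // SparkFixupTypes restore the uuid/fixed primitives and the field ids.
    Schema restored = SparkSchemaUtil.convert(icebergSchema, sparkType);
    System.out.println(restored.asStruct());
  }
}
```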
+ */ +class SparkFixupTypes extends FixupTypes { + + private SparkFixupTypes(Schema referenceSchema) { + super(referenceSchema); + } + + static Schema fixup(Schema schema, Schema referenceSchema) { + return new Schema(TypeUtil.visit(schema, + new SparkFixupTypes(referenceSchema)).asStructType().fields()); + } + + @Override + protected boolean fixupPrimitive(Type.PrimitiveType type, Type source) { + switch (type.typeId()) { + case STRING: + if (source.typeId() == Type.TypeID.UUID) { + return true; + } + break; + case BINARY: + if (source.typeId() == Type.TypeID.FIXED) { + return true; + } + break; + case TIMESTAMP: + if (source.typeId() == Type.TypeID.TIMESTAMP) { + return true; + } + break; + default: + } + return false; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java new file mode 100644 index 000000000000..c7c01758c3ee --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +import java.util.Map; +import java.util.Set; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopInputFile; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.spark.sql.SparkSession; + +/** + * A class for common Iceberg configs for Spark reads. + *

+ * If a config is set at multiple levels, the following order of precedence is used (top to bottom): + *

+ *   1. Read options
+ *   2. Session configuration
+ *   3. Table metadata
+ *
+ * The most specific value is set in read options and takes precedence over all other configs. + * If no read option is provided, this class checks the session configuration for any overrides. + * If no applicable value is found in the session configuration, this class uses the table metadata. + *
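A hedged sketch of that precedence in action, using the vectorization knob that appears in this patch as the read option `vectorization-enabled`, the session property `spark.sql.iceberg.vectorization.enabled`, and the Iceberg table property `read.parquet.vectorization.enabled`; the table identifier is hypothetical:

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class ReadConfPrecedenceExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("iceberg-read-conf-precedence")
        .getOrCreate();

    // Table metadata (lowest precedence): a property set once on the table.
    spark.sql("ALTER TABLE db.events SET TBLPROPERTIES ('read.parquet.vectorization.enabled'='true')");

    // Session configuration overrides the table property for this session.
    spark.conf().set("spark.sql.iceberg.vectorization.enabled", "false");

    // Read options are the most specific level and win over both.
    Dataset<Row> df = spark.read()
        .format("iceberg")
        .option("vectorization-enabled", "true")
        .load("db.events");

    df.show();
  }
}
```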

+ * Note this class is NOT meant to be serialized and sent to executors. + */ +public class SparkReadConf { + + private static final Set LOCALITY_WHITELIST_FS = ImmutableSet.of("hdfs"); + + private final SparkSession spark; + private final Table table; + private final Map readOptions; + private final SparkConfParser confParser; + + public SparkReadConf(SparkSession spark, Table table, Map readOptions) { + this.spark = spark; + this.table = table; + this.readOptions = readOptions; + this.confParser = new SparkConfParser(spark, table, readOptions); + } + + public boolean caseSensitive() { + return Boolean.parseBoolean(spark.conf().get("spark.sql.caseSensitive")); + } + + public boolean localityEnabled() { + InputFile file = table.io().newInputFile(table.location()); + + if (file instanceof HadoopInputFile) { + String scheme = ((HadoopInputFile) file).getFileSystem().getScheme(); + boolean defaultValue = LOCALITY_WHITELIST_FS.contains(scheme); + return PropertyUtil.propertyAsBoolean( + readOptions, + SparkReadOptions.LOCALITY, + defaultValue); + } + + return false; + } + + public Long snapshotId() { + return confParser.longConf() + .option(SparkReadOptions.SNAPSHOT_ID) + .parseOptional(); + } + + public Long asOfTimestamp() { + return confParser.longConf() + .option(SparkReadOptions.AS_OF_TIMESTAMP) + .parseOptional(); + } + + public Long startSnapshotId() { + return confParser.longConf() + .option(SparkReadOptions.START_SNAPSHOT_ID) + .parseOptional(); + } + + public Long endSnapshotId() { + return confParser.longConf() + .option(SparkReadOptions.END_SNAPSHOT_ID) + .parseOptional(); + } + + public String fileScanTaskSetId() { + return confParser.stringConf() + .option(SparkReadOptions.FILE_SCAN_TASK_SET_ID) + .parseOptional(); + } + + public boolean streamingSkipDeleteSnapshots() { + return confParser.booleanConf() + .option(SparkReadOptions.STREAMING_SKIP_DELETE_SNAPSHOTS) + .defaultValue(SparkReadOptions.STREAMING_SKIP_DELETE_SNAPSHOTS_DEFAULT) + .parse(); + } + + public boolean streamingSkipOverwriteSnapshots() { + return confParser.booleanConf() + .option(SparkReadOptions.STREAMING_SKIP_OVERWRITE_SNAPSHOTS) + .defaultValue(SparkReadOptions.STREAMING_SKIP_OVERWRITE_SNAPSHOTS_DEFAULT) + .parse(); + } + + public boolean parquetVectorizationEnabled() { + return confParser.booleanConf() + .option(SparkReadOptions.VECTORIZATION_ENABLED) + .sessionConf(SparkSQLProperties.VECTORIZATION_ENABLED) + .tableProperty(TableProperties.PARQUET_VECTORIZATION_ENABLED) + .defaultValue(TableProperties.PARQUET_VECTORIZATION_ENABLED_DEFAULT) + .parse(); + } + + public int parquetBatchSize() { + return confParser.intConf() + .option(SparkReadOptions.VECTORIZATION_BATCH_SIZE) + .tableProperty(TableProperties.PARQUET_BATCH_SIZE) + .defaultValue(TableProperties.PARQUET_BATCH_SIZE_DEFAULT) + .parse(); + } + + public boolean orcVectorizationEnabled() { + return confParser.booleanConf() + .option(SparkReadOptions.VECTORIZATION_ENABLED) + .sessionConf(SparkSQLProperties.VECTORIZATION_ENABLED) + .tableProperty(TableProperties.ORC_VECTORIZATION_ENABLED) + .defaultValue(TableProperties.ORC_VECTORIZATION_ENABLED_DEFAULT) + .parse(); + } + + public int orcBatchSize() { + return confParser.intConf() + .option(SparkReadOptions.VECTORIZATION_BATCH_SIZE) + .tableProperty(TableProperties.ORC_BATCH_SIZE) + .defaultValue(TableProperties.ORC_BATCH_SIZE_DEFAULT) + .parse(); + } + + public Long splitSizeOption() { + return confParser.longConf() + .option(SparkReadOptions.SPLIT_SIZE) + .parseOptional(); + } + + public long 
splitSize() { + return confParser.longConf() + .option(SparkReadOptions.SPLIT_SIZE) + .tableProperty(TableProperties.SPLIT_SIZE) + .defaultValue(TableProperties.SPLIT_SIZE_DEFAULT) + .parse(); + } + + public Integer splitLookbackOption() { + return confParser.intConf() + .option(SparkReadOptions.LOOKBACK) + .parseOptional(); + } + + public int splitLookback() { + return confParser.intConf() + .option(SparkReadOptions.LOOKBACK) + .tableProperty(TableProperties.SPLIT_LOOKBACK) + .defaultValue(TableProperties.SPLIT_LOOKBACK_DEFAULT) + .parse(); + } + + public Long splitOpenFileCostOption() { + return confParser.longConf() + .option(SparkReadOptions.FILE_OPEN_COST) + .parseOptional(); + } + + public long splitOpenFileCost() { + return confParser.longConf() + .option(SparkReadOptions.FILE_OPEN_COST) + .tableProperty(TableProperties.SPLIT_OPEN_FILE_COST) + .defaultValue(TableProperties.SPLIT_OPEN_FILE_COST_DEFAULT) + .parse(); + } + + /** + * Enables reading a timestamp without time zone as a timestamp with time zone. + *

+ * Generally, this is not safe: a timestamp without time zone is supposed to represent wall-clock time, + * i.e. 3PM should always be read as 3PM regardless of the reader/writer time zone, + * whereas a timestamp with time zone represents instant semantics, i.e. the value + * is adjusted so that the corresponding time in the reader's time zone is displayed. + *
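A minimal sketch of the two opt-in paths defined in this patch, the session property and the per-read option; the table name is hypothetical:

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class TimestampWithoutZoneExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("iceberg-timestamp-without-zone")
        .getOrCreate();

    // Opt in for the whole session.
    spark.conf().set("spark.sql.iceberg.handle-timestamp-without-timezone", "true");

    // Or opt in for a single read; the read option takes precedence
    // over the session configuration.
    Dataset<Row> df = spark.read()
        .format("iceberg")
        .option("handle-timestamp-without-timezone", "true")
        .load("db.events_without_tz");

    df.printSchema();
  }
}
```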

+ * When set to false (default), an exception must be thrown while reading a timestamp without time zone. + * + * @return boolean indicating if reading timestamps without timezone is allowed + */ + public boolean handleTimestampWithoutZone() { + return confParser.booleanConf() + .option(SparkReadOptions.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE) + .sessionConf(SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE) + .defaultValue(SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE_DEFAULT) + .parse(); + } + + public Long streamFromTimestamp() { + return confParser.longConf() + .option(SparkReadOptions.STREAM_FROM_TIMESTAMP) + .defaultValue(Long.MIN_VALUE) + .parse(); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadOptions.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadOptions.java new file mode 100644 index 000000000000..edcc2300344a --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadOptions.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +/** + * Spark DF read options + */ +public class SparkReadOptions { + + private SparkReadOptions() { + } + + // Snapshot ID of the table snapshot to read + public static final String SNAPSHOT_ID = "snapshot-id"; + + // Start snapshot ID used in incremental scans (exclusive) + public static final String START_SNAPSHOT_ID = "start-snapshot-id"; + + // End snapshot ID used in incremental scans (inclusive) + public static final String END_SNAPSHOT_ID = "end-snapshot-id"; + + // A timestamp in milliseconds; the snapshot used will be the snapshot current at this time. 
+ public static final String AS_OF_TIMESTAMP = "as-of-timestamp"; + + // Overrides the table's read.split.target-size and read.split.metadata-target-size + public static final String SPLIT_SIZE = "split-size"; + + // Overrides the table's read.split.planning-lookback + public static final String LOOKBACK = "lookback"; + + // Overrides the table's read.split.open-file-cost + public static final String FILE_OPEN_COST = "file-open-cost"; + + // Overrides the table's read.split.open-file-cost + public static final String VECTORIZATION_ENABLED = "vectorization-enabled"; + + // Overrides the table's read.parquet.vectorization.batch-size + public static final String VECTORIZATION_BATCH_SIZE = "batch-size"; + + // Set ID that is used to fetch file scan tasks + public static final String FILE_SCAN_TASK_SET_ID = "file-scan-task-set-id"; + + // skip snapshots of type delete while reading stream out of iceberg table + public static final String STREAMING_SKIP_DELETE_SNAPSHOTS = "streaming-skip-delete-snapshots"; + public static final boolean STREAMING_SKIP_DELETE_SNAPSHOTS_DEFAULT = false; + + // skip snapshots of type overwrite while reading stream out of iceberg table + public static final String STREAMING_SKIP_OVERWRITE_SNAPSHOTS = "streaming-skip-overwrite-snapshots"; + public static final boolean STREAMING_SKIP_OVERWRITE_SNAPSHOTS_DEFAULT = false; + + // Controls whether to allow reading timestamps without zone info + public static final String HANDLE_TIMESTAMP_WITHOUT_TIMEZONE = "handle-timestamp-without-timezone"; + + // Controls whether to report locality information to Spark while allocating input partitions + public static final String LOCALITY = "locality"; + + // Timestamp in milliseconds; start a stream from the snapshot that occurs after this timestamp + public static final String STREAM_FROM_TIMESTAMP = "stream-from-timestamp"; +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java new file mode 100644 index 000000000000..f2dcc13bece0 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark; + +public class SparkSQLProperties { + + private SparkSQLProperties() { + } + + // Controls whether vectorized reads are enabled + public static final String VECTORIZATION_ENABLED = "spark.sql.iceberg.vectorization.enabled"; + + // Controls whether reading/writing timestamps without timezones is allowed + public static final String HANDLE_TIMESTAMP_WITHOUT_TIMEZONE = "spark.sql.iceberg.handle-timestamp-without-timezone"; + public static final boolean HANDLE_TIMESTAMP_WITHOUT_TIMEZONE_DEFAULT = false; + + // Controls whether timestamp types for new tables should be stored with timezone info + public static final String USE_TIMESTAMP_WITHOUT_TIME_ZONE_IN_NEW_TABLES = + "spark.sql.iceberg.use-timestamp-without-timezone-in-new-tables"; + public static final boolean USE_TIMESTAMP_WITHOUT_TIME_ZONE_IN_NEW_TABLES_DEFAULT = false; + + // Controls whether to perform the nullability check during writes + public static final String CHECK_NULLABILITY = "spark.sql.iceberg.check-nullability"; + public static final boolean CHECK_NULLABILITY_DEFAULT = true; + + // Controls whether to check the order of fields during writes + public static final String CHECK_ORDERING = "spark.sql.iceberg.check-ordering"; + public static final boolean CHECK_ORDERING_DEFAULT = true; +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSchemaUtil.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSchemaUtil.java new file mode 100644 index 000000000000..c242f8535206 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSchemaUtil.java @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark; + +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Binder; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.relocated.com.google.common.base.Splitter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.math.LongMath; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalog.Column; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.StructType; + +/** + * Helper methods for working with Spark/Hive metadata. + */ +public class SparkSchemaUtil { + private SparkSchemaUtil() { + } + + /** + * Returns a {@link Schema} for the given table with fresh field ids. + *

+ * This creates a Schema for an existing table by looking up the table's schema with Spark and + * converting that schema. Spark/Hive partition columns are included in the schema. + * + * @param spark a Spark session + * @param name a table name and (optional) database + * @return a Schema for the table, if found + */ + public static Schema schemaForTable(SparkSession spark, String name) { + StructType sparkType = spark.table(name).schema(); + Type converted = SparkTypeVisitor.visit(sparkType, new SparkTypeToType(sparkType)); + return new Schema(converted.asNestedType().asStructType().fields()); + } + + /** + * Returns a {@link PartitionSpec} for the given table. + *
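A short hedged example of how `schemaForTable` and the `specForTable` helper defined just below might be used to bootstrap an Iceberg table definition from an existing Hive table; the table name is hypothetical:

```java
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Schema;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.SparkSession;

public class SchemaForTableExample {
  public static void main(String[] args) throws AnalysisException {
    SparkSession spark = SparkSession.builder()
        .appName("iceberg-schema-for-table")
        .enableHiveSupport()
        .getOrCreate();

    // Schema with fresh field ids, including Hive partition columns.
    Schema schema = SparkSchemaUtil.schemaForTable(spark, "db.hive_events");

    // Identity partition spec built from the Hive partition columns.
    PartitionSpec spec = SparkSchemaUtil.specForTable(spark, "db.hive_events");

    System.out.println(schema);
    System.out.println(spec);
  }
}
```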

+ * This creates a partition spec for an existing table by looking up the table's schema and + * creating a spec with identity partitions for each partition column. + * + * @param spark a Spark session + * @param name a table name and (optional) database + * @return a PartitionSpec for the table + * @throws AnalysisException if thrown by the Spark catalog + */ + public static PartitionSpec specForTable(SparkSession spark, String name) throws AnalysisException { + List parts = Lists.newArrayList(Splitter.on('.').limit(2).split(name)); + String db = parts.size() == 1 ? "default" : parts.get(0); + String table = parts.get(parts.size() == 1 ? 0 : 1); + + PartitionSpec spec = identitySpec( + schemaForTable(spark, name), + spark.catalog().listColumns(db, table).collectAsList()); + return spec == null ? PartitionSpec.unpartitioned() : spec; + } + + /** + * Convert a {@link Schema} to a {@link DataType Spark type}. + * + * @param schema a Schema + * @return the equivalent Spark type + * @throws IllegalArgumentException if the type cannot be converted to Spark + */ + public static StructType convert(Schema schema) { + return (StructType) TypeUtil.visit(schema, new TypeToSparkType()); + } + + /** + * Convert a {@link Type} to a {@link DataType Spark type}. + * + * @param type a Type + * @return the equivalent Spark type + * @throws IllegalArgumentException if the type cannot be converted to Spark + */ + public static DataType convert(Type type) { + return TypeUtil.visit(type, new TypeToSparkType()); + } + + /** + * Convert a Spark {@link StructType struct} to a {@link Schema} with new field ids. + *

+ * This conversion assigns fresh ids. + *

+ * Some data types are represented as the same Spark type. These are converted to a default type. + *

+ * To convert using a reference schema for field ids and ambiguous types, use + * {@link #convert(Schema, StructType)}. + * + * @param sparkType a Spark StructType + * @return the equivalent Schema + * @throws IllegalArgumentException if the type cannot be converted + */ + public static Schema convert(StructType sparkType) { + return convert(sparkType, false); + } + + /** + * Convert a Spark {@link StructType struct} to a {@link Schema} with new field ids. + *

+ * This conversion assigns fresh ids. + *

+ * Some data types are represented as the same Spark type. These are converted to a default type. + *

+ * To convert using a reference schema for field ids and ambiguous types, use + * {@link #convert(Schema, StructType)}. + * + * @param sparkType a Spark StructType + * @param useTimestampWithoutZone boolean flag indicates that timestamp should be stored without timezone + * @return the equivalent Schema + * @throws IllegalArgumentException if the type cannot be converted + */ + public static Schema convert(StructType sparkType, boolean useTimestampWithoutZone) { + Type converted = SparkTypeVisitor.visit(sparkType, new SparkTypeToType(sparkType)); + Schema schema = new Schema(converted.asNestedType().asStructType().fields()); + if (useTimestampWithoutZone) { + schema = SparkFixupTimestampType.fixup(schema); + } + return schema; + } + + /** + * Convert a Spark {@link DataType struct} to a {@link Type} with new field ids. + *

+ * This conversion assigns fresh ids. + *

+ * Some data types are represented as the same Spark type. These are converted to a default type. + *

+ * To convert using a reference schema for field ids and ambiguous types, use + * {@link #convert(Schema, StructType)}. + * + * @param sparkType a Spark DataType + * @return the equivalent Type + * @throws IllegalArgumentException if the type cannot be converted + */ + public static Type convert(DataType sparkType) { + return SparkTypeVisitor.visit(sparkType, new SparkTypeToType()); + } + + /** + * Convert a Spark {@link StructType struct} to a {@link Schema} based on the given schema. + *

+ * This conversion does not assign new ids; it uses ids from the base schema. + *

+ * Data types, field order, and nullability will match the spark type. This conversion may return + * a schema that is not compatible with base schema. + * + * @param baseSchema a Schema on which conversion is based + * @param sparkType a Spark StructType + * @return the equivalent Schema + * @throws IllegalArgumentException if the type cannot be converted or there are missing ids + */ + public static Schema convert(Schema baseSchema, StructType sparkType) { + // convert to a type with fresh ids + Types.StructType struct = SparkTypeVisitor.visit(sparkType, new SparkTypeToType(sparkType)).asStructType(); + // reassign ids to match the base schema + Schema schema = TypeUtil.reassignIds(new Schema(struct.fields()), baseSchema); + // fix types that can't be represented in Spark (UUID and Fixed) + return SparkFixupTypes.fixup(schema, baseSchema); + } + + /** + * Convert a Spark {@link StructType struct} to a {@link Schema} based on the given schema. + *

+ * This conversion will assign new ids for fields that are not found in the base schema. + *

+ * Data types, field order, and nullability will match the spark type. This conversion may return + * a schema that is not compatible with base schema. + * + * @param baseSchema a Schema on which conversion is based + * @param sparkType a Spark StructType + * @return the equivalent Schema + * @throws IllegalArgumentException if the type cannot be converted or there are missing ids + */ + public static Schema convertWithFreshIds(Schema baseSchema, StructType sparkType) { + // convert to a type with fresh ids + Types.StructType struct = SparkTypeVisitor.visit(sparkType, new SparkTypeToType(sparkType)).asStructType(); + // reassign ids to match the base schema + Schema schema = TypeUtil.reassignOrRefreshIds(new Schema(struct.fields()), baseSchema); + // fix types that can't be represented in Spark (UUID and Fixed) + return SparkFixupTypes.fixup(schema, baseSchema); + } + + /** + * Prune columns from a {@link Schema} using a {@link StructType Spark type} projection. + *

+ * This requires that the Spark type is a projection of the Schema. Nullability and types must + * match. + * + * @param schema a Schema + * @param requestedType a projection of the Spark representation of the Schema + * @return a Schema corresponding to the Spark projection + * @throws IllegalArgumentException if the Spark type does not match the Schema + */ + public static Schema prune(Schema schema, StructType requestedType) { + return new Schema(TypeUtil.visit(schema, new PruneColumnsWithoutReordering(requestedType, ImmutableSet.of())) + .asNestedType() + .asStructType() + .fields()); + } + + /** + * Prune columns from a {@link Schema} using a {@link StructType Spark type} projection. + *
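A hedged sketch of this projection-based pruning, assuming a hypothetical three-column schema; the requested Spark struct must use the same types and nullability that `convert` would produce:

```java
import org.apache.iceberg.Schema;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class PruneExample {
  public static void main(String[] args) {
    // Hypothetical table schema.
    Schema tableSchema = new Schema(
        Types.NestedField.required(1, "id", Types.LongType.get()),
        Types.NestedField.optional(2, "data", Types.StringType.get()),
        Types.NestedField.optional(3, "category", Types.StringType.get()));

    // A Spark projection that requests only two of the columns, keeping
    // the table's field order, types, and nullability.
    StructType requested = new StructType()
        .add("id", DataTypes.LongType, false)
        .add("category", DataTypes.StringType, true);

    // The pruned schema keeps the original field ids but drops "data".
    Schema pruned = SparkSchemaUtil.prune(tableSchema, requested);
    System.out.println(pruned.asStruct());
  }
}
```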

+ * This requires that the Spark type is a projection of the Schema. Nullability and types must + * match. + *

+ * The filters list of {@link Expression} is used to ensure that columns referenced by filters + * are projected. + * + * @param schema a Schema + * @param requestedType a projection of the Spark representation of the Schema + * @param filters a list of filters + * @return a Schema corresponding to the Spark projection + * @throws IllegalArgumentException if the Spark type does not match the Schema + */ + public static Schema prune(Schema schema, StructType requestedType, List filters) { + Set filterRefs = Binder.boundReferences(schema.asStruct(), filters, true); + return new Schema(TypeUtil.visit(schema, new PruneColumnsWithoutReordering(requestedType, filterRefs)) + .asNestedType() + .asStructType() + .fields()); + } + + /** + * Prune columns from a {@link Schema} using a {@link StructType Spark type} projection. + *

+ * This requires that the Spark type is a projection of the Schema. Nullability and types must + * match. + *

+ * The filters list of {@link Expression} is used to ensure that columns referenced by filters + * are projected. + * + * @param schema a Schema + * @param requestedType a projection of the Spark representation of the Schema + * @param filter a filters + * @return a Schema corresponding to the Spark projection + * @throws IllegalArgumentException if the Spark type does not match the Schema + */ + public static Schema prune(Schema schema, StructType requestedType, Expression filter, boolean caseSensitive) { + Set filterRefs = + Binder.boundReferences(schema.asStruct(), Collections.singletonList(filter), caseSensitive); + + return new Schema(TypeUtil.visit(schema, new PruneColumnsWithoutReordering(requestedType, filterRefs)) + .asNestedType() + .asStructType() + .fields()); + } + + private static PartitionSpec identitySpec(Schema schema, Collection columns) { + List names = Lists.newArrayList(); + for (Column column : columns) { + if (column.isPartition()) { + names.add(column.name()); + } + } + + return identitySpec(schema, names); + } + + private static PartitionSpec identitySpec(Schema schema, List partitionNames) { + if (partitionNames == null || partitionNames.isEmpty()) { + return null; + } + + PartitionSpec.Builder builder = PartitionSpec.builderFor(schema); + for (String partitionName : partitionNames) { + builder.identity(partitionName); + } + + return builder.build(); + } + + /** + * Estimate approximate table size based on Spark schema and total records. + * + * @param tableSchema Spark schema + * @param totalRecords total records in the table + * @return approximate size based on table schema + */ + public static long estimateSize(StructType tableSchema, long totalRecords) { + if (totalRecords == Long.MAX_VALUE) { + return totalRecords; + } + + long result; + try { + result = LongMath.checkedMultiply(tableSchema.defaultSize(), totalRecords); + } catch (ArithmeticException e) { + result = Long.MAX_VALUE; + } + return result; + } + + public static void validateMetadataColumnReferences(Schema tableSchema, Schema readSchema) { + List conflictingColumnNames = readSchema.columns().stream() + .map(Types.NestedField::name) + .filter(name -> MetadataColumns.isMetadataColumn(name) && tableSchema.findField(name) != null) + .collect(Collectors.toList()); + + ValidationException.check( + conflictingColumnNames.isEmpty(), + "Table column names conflict with names reserved for Iceberg metadata columns: %s.\n" + + "Please, use ALTER TABLE statements to rename the conflicting table columns.", + conflictingColumnNames); + } + + public static Map indexQuotedNameById(Schema schema) { + Function quotingFunc = name -> String.format("`%s`", name.replace("`", "``")); + return TypeUtil.indexQuotedNameById(schema.asStruct(), quotingFunc); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSessionCatalog.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSessionCatalog.java new file mode 100644 index 000000000000..4f8e90585bb1 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSessionCatalog.java @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.source.HasIcebergCatalog; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NamespaceAlreadyExistsException; +import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException; +import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException; +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.connector.catalog.CatalogExtension; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.NamespaceChange; +import org.apache.spark.sql.connector.catalog.StagedTable; +import org.apache.spark.sql.connector.catalog.StagingTableCatalog; +import org.apache.spark.sql.connector.catalog.SupportsNamespaces; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.catalog.TableChange; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * A Spark catalog that can also load non-Iceberg tables. + * + * @param CatalogPlugin class to avoid casting to TableCatalog and SupportsNamespaces. + */ +public class SparkSessionCatalog + extends BaseCatalog implements CatalogExtension { + private static final String[] DEFAULT_NAMESPACE = new String[] {"default"}; + + private String catalogName = null; + private TableCatalog icebergCatalog = null; + private StagingTableCatalog asStagingCatalog = null; + private T sessionCatalog = null; + private boolean createParquetAsIceberg = false; + private boolean createAvroAsIceberg = false; + private boolean createOrcAsIceberg = false; + + /** + * Build a {@link SparkCatalog} to be used for Iceberg operations. + *
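Since this catalog is meant to replace Spark's built-in `spark_catalog`, a hedged configuration sketch may help; the `type` and `parquet-enabled` options correspond to the keys read in `initialize` below, and the values are illustrative:

```java
import org.apache.spark.sql.SparkSession;

public class SessionCatalogSetupExample {
  public static void main(String[] args) {
    // Wrap Spark's built-in session catalog (which must keep the name
    // 'spark_catalog') so Iceberg and non-Iceberg tables can coexist.
    SparkSession spark = SparkSession.builder()
        .appName("iceberg-session-catalog")
        .config("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog")
        .config("spark.sql.catalog.spark_catalog.type", "hive")
        // Optionally create parquet-provider tables as Iceberg tables too.
        .config("spark.sql.catalog.spark_catalog.parquet-enabled", "true")
        .enableHiveSupport()
        .getOrCreate();

    spark.sql("SHOW TABLES IN default").show();
  }
}
```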

+ * The default implementation creates a new SparkCatalog with the session catalog's name and options. + * + * @param name catalog name + * @param options catalog options + * @return a SparkCatalog to be used for Iceberg tables + */ + protected TableCatalog buildSparkCatalog(String name, CaseInsensitiveStringMap options) { + SparkCatalog newCatalog = new SparkCatalog(); + newCatalog.initialize(name, options); + return newCatalog; + } + + @Override + public String[] defaultNamespace() { + return DEFAULT_NAMESPACE; + } + + @Override + public String[][] listNamespaces() throws NoSuchNamespaceException { + return getSessionCatalog().listNamespaces(); + } + + @Override + public String[][] listNamespaces(String[] namespace) throws NoSuchNamespaceException { + return getSessionCatalog().listNamespaces(namespace); + } + + @Override + public boolean namespaceExists(String[] namespace) { + return getSessionCatalog().namespaceExists(namespace); + } + + @Override + public Map loadNamespaceMetadata(String[] namespace) throws NoSuchNamespaceException { + return getSessionCatalog().loadNamespaceMetadata(namespace); + } + + @Override + public void createNamespace(String[] namespace, Map metadata) throws NamespaceAlreadyExistsException { + getSessionCatalog().createNamespace(namespace, metadata); + } + + @Override + public void alterNamespace(String[] namespace, NamespaceChange... changes) throws NoSuchNamespaceException { + getSessionCatalog().alterNamespace(namespace, changes); + } + + @Override + public boolean dropNamespace(String[] namespace, boolean cascade) + throws NoSuchNamespaceException, NonEmptyNamespaceException { + return getSessionCatalog().dropNamespace(namespace, cascade); + } + + @Override + public Identifier[] listTables(String[] namespace) throws NoSuchNamespaceException { + // delegate to the session catalog because all tables share the same namespace + return getSessionCatalog().listTables(namespace); + } + + @Override + public Table loadTable(Identifier ident) throws NoSuchTableException { + try { + return icebergCatalog.loadTable(ident); + } catch (NoSuchTableException e) { + return getSessionCatalog().loadTable(ident); + } + } + + @Override + public void invalidateTable(Identifier ident) { + // We do not need to check whether the table exists and whether + // it is an Iceberg table to reduce remote service requests. 
+ icebergCatalog.invalidateTable(ident); + getSessionCatalog().invalidateTable(ident); + } + + @Override + public Table createTable(Identifier ident, StructType schema, Transform[] partitions, + Map properties) + throws TableAlreadyExistsException, NoSuchNamespaceException { + String provider = properties.get("provider"); + if (useIceberg(provider)) { + return icebergCatalog.createTable(ident, schema, partitions, properties); + } else { + // delegate to the session catalog + return getSessionCatalog().createTable(ident, schema, partitions, properties); + } + } + + @Override + public StagedTable stageCreate(Identifier ident, StructType schema, Transform[] partitions, + Map properties) + throws TableAlreadyExistsException, NoSuchNamespaceException { + String provider = properties.get("provider"); + TableCatalog catalog; + if (useIceberg(provider)) { + if (asStagingCatalog != null) { + return asStagingCatalog.stageCreate(ident, schema, partitions, properties); + } + catalog = icebergCatalog; + } else { + catalog = getSessionCatalog(); + } + + // create the table with the session catalog, then wrap it in a staged table that will delete to roll back + Table table = catalog.createTable(ident, schema, partitions, properties); + return new RollbackStagedTable(catalog, ident, table); + } + + @Override + public StagedTable stageReplace(Identifier ident, StructType schema, Transform[] partitions, + Map properties) + throws NoSuchNamespaceException, NoSuchTableException { + String provider = properties.get("provider"); + TableCatalog catalog; + if (useIceberg(provider)) { + if (asStagingCatalog != null) { + return asStagingCatalog.stageReplace(ident, schema, partitions, properties); + } + catalog = icebergCatalog; + } else { + catalog = getSessionCatalog(); + } + + // attempt to drop the table and fail if it doesn't exist + if (!catalog.dropTable(ident)) { + throw new NoSuchTableException(ident); + } + + try { + // create the table with the session catalog, then wrap it in a staged table that will delete to roll back + Table table = catalog.createTable(ident, schema, partitions, properties); + return new RollbackStagedTable(catalog, ident, table); + + } catch (TableAlreadyExistsException e) { + // the table was deleted, but now already exists again. retry the replace. + return stageReplace(ident, schema, partitions, properties); + } + } + + @Override + public StagedTable stageCreateOrReplace(Identifier ident, StructType schema, Transform[] partitions, + Map properties) throws NoSuchNamespaceException { + String provider = properties.get("provider"); + TableCatalog catalog; + if (useIceberg(provider)) { + if (asStagingCatalog != null) { + return asStagingCatalog.stageCreateOrReplace(ident, schema, partitions, properties); + } + catalog = icebergCatalog; + } else { + catalog = getSessionCatalog(); + } + + // drop the table if it exists + catalog.dropTable(ident); + + try { + // create the table with the session catalog, then wrap it in a staged table that will delete to roll back + Table sessionCatalogTable = catalog.createTable(ident, schema, partitions, properties); + return new RollbackStagedTable(catalog, ident, sessionCatalogTable); + + } catch (TableAlreadyExistsException e) { + // the table was deleted, but now already exists again. retry the replace. + return stageCreateOrReplace(ident, schema, partitions, properties); + } + } + + @Override + public Table alterTable(Identifier ident, TableChange... 
changes) throws NoSuchTableException { + if (icebergCatalog.tableExists(ident)) { + return icebergCatalog.alterTable(ident, changes); + } else { + return getSessionCatalog().alterTable(ident, changes); + } + } + + @Override + public boolean dropTable(Identifier ident) { + // no need to check table existence to determine which catalog to use. if a table doesn't exist then both are + // required to return false. + return icebergCatalog.dropTable(ident) || getSessionCatalog().dropTable(ident); + } + + @Override + public boolean purgeTable(Identifier ident) { + // no need to check table existence to determine which catalog to use. if a table doesn't exist then both are + // required to return false. + return icebergCatalog.purgeTable(ident) || getSessionCatalog().purgeTable(ident); + } + + @Override + public void renameTable(Identifier from, Identifier to) throws NoSuchTableException, TableAlreadyExistsException { + // rename is not supported by HadoopCatalog. to avoid UnsupportedOperationException for session catalog tables, + // check table existence first to ensure that the table belongs to the Iceberg catalog. + if (icebergCatalog.tableExists(from)) { + icebergCatalog.renameTable(from, to); + } else { + getSessionCatalog().renameTable(from, to); + } + } + + @Override + public final void initialize(String name, CaseInsensitiveStringMap options) { + if (options.containsKey("type") && options.get("type").equalsIgnoreCase("hive")) { + validateHmsUri(options.get(CatalogProperties.URI)); + } + + this.catalogName = name; + this.icebergCatalog = buildSparkCatalog(name, options); + if (icebergCatalog instanceof StagingTableCatalog) { + this.asStagingCatalog = (StagingTableCatalog) icebergCatalog; + } + + this.createParquetAsIceberg = options.getBoolean("parquet-enabled", createParquetAsIceberg); + this.createAvroAsIceberg = options.getBoolean("avro-enabled", createAvroAsIceberg); + this.createOrcAsIceberg = options.getBoolean("orc-enabled", createOrcAsIceberg); + } + + private void validateHmsUri(String catalogHmsUri) { + if (catalogHmsUri == null) { + return; + } + + Configuration conf = SparkSession.active().sessionState().newHadoopConf(); + String envHmsUri = conf.get(HiveConf.ConfVars.METASTOREURIS.varname, null); + if (envHmsUri == null) { + return; + } + + Preconditions.checkArgument(catalogHmsUri.equals(envHmsUri), + "Inconsistent Hive metastore URIs: %s (Spark session) != %s (spark_catalog)", + envHmsUri, catalogHmsUri); + } + + @Override + @SuppressWarnings("unchecked") + public void setDelegateCatalog(CatalogPlugin sparkSessionCatalog) { + if (sparkSessionCatalog instanceof TableCatalog && sparkSessionCatalog instanceof SupportsNamespaces) { + this.sessionCatalog = (T) sparkSessionCatalog; + } else { + throw new IllegalArgumentException("Invalid session catalog: " + sparkSessionCatalog); + } + } + + @Override + public String name() { + return catalogName; + } + + private boolean useIceberg(String provider) { + if (provider == null || "iceberg".equalsIgnoreCase(provider)) { + return true; + } else if (createParquetAsIceberg && "parquet".equalsIgnoreCase(provider)) { + return true; + } else if (createAvroAsIceberg && "avro".equalsIgnoreCase(provider)) { + return true; + } else if (createOrcAsIceberg && "orc".equalsIgnoreCase(provider)) { + return true; + } + + return false; + } + + private T getSessionCatalog() { + Preconditions.checkNotNull(sessionCatalog, "Delegated SessionCatalog is missing. 
" + + "Please make sure your are replacing Spark's default catalog, named 'spark_catalog'."); + return sessionCatalog; + } + + @Override + public Catalog icebergCatalog() { + Preconditions.checkArgument(icebergCatalog instanceof HasIcebergCatalog, + "Cannot return underlying Iceberg Catalog, wrapped catalog does not contain an Iceberg Catalog"); + return ((HasIcebergCatalog) icebergCatalog).icebergCatalog(); + } + + @Override + public Identifier[] listFunctions(String[] namespace) { + return new Identifier[0]; + } + + @Override + public UnboundFunction loadFunction(Identifier ident) throws NoSuchFunctionException { + throw new NoSuchFunctionException(ident); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkStructLike.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkStructLike.java new file mode 100644 index 000000000000..30509e3381dc --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkStructLike.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +import org.apache.iceberg.StructLike; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Row; + +public class SparkStructLike implements StructLike { + + private final Types.StructType type; + private Row wrapped; + + public SparkStructLike(Types.StructType type) { + this.type = type; + } + + public SparkStructLike wrap(Row row) { + this.wrapped = row; + return this; + } + + @Override + public int size() { + return type.fields().size(); + } + + @Override + public T get(int pos, Class javaClass) { + Types.NestedField field = type.fields().get(pos); + return javaClass.cast(SparkValueConverter.convert(field.type(), wrapped.get(pos))); + } + + @Override + public void set(int pos, T value) { + throw new UnsupportedOperationException("Not implemented: set"); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java new file mode 100644 index 000000000000..2fedb4c74a70 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java @@ -0,0 +1,673 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +import java.io.IOException; +import java.io.Serializable; +import java.net.URI; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.ManifestWriter; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.MetadataTableUtils; +import org.apache.iceberg.MetricsConfig; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.TableMigrationUtil; +import org.apache.iceberg.hadoop.HadoopFileIO; +import org.apache.iceberg.hadoop.SerializableConfiguration; +import org.apache.iceberg.hadoop.Util; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.mapping.NameMapping; +import org.apache.iceberg.mapping.NameMappingParser; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; +import org.apache.iceberg.relocated.com.google.common.base.Objects; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.Tasks; +import org.apache.iceberg.util.ThreadPools; +import org.apache.spark.TaskContext; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.api.java.function.MapPartitionsFunction; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute; +import org.apache.spark.sql.catalyst.catalog.CatalogTable; +import org.apache.spark.sql.catalyst.catalog.CatalogTablePartition; +import org.apache.spark.sql.catalyst.catalog.SessionCatalog; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; 
+import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import scala.Function2; +import scala.Option; +import scala.Some; +import scala.Tuple2; +import scala.collection.JavaConverters; +import scala.collection.immutable.Map$; +import scala.collection.immutable.Seq; +import scala.collection.mutable.Builder; +import scala.runtime.AbstractPartialFunction; + +import static org.apache.spark.sql.functions.col; + +/** + * Java version of the original SparkTableUtil.scala + * https://github.com/apache/iceberg/blob/apache-iceberg-0.8.0-incubating/spark/src/main/scala/org/apache/iceberg/spark/SparkTableUtil.scala + */ +public class SparkTableUtil { + + private static final String DUPLICATE_FILE_MESSAGE = "Cannot complete import because data files " + + "to be imported already exist within the target table: %s. " + + "This is disabled by default as Iceberg is not designed for mulitple references to the same file" + + " within the same table. If you are sure, you may set 'check_duplicate_files' to false to force the import."; + + + private SparkTableUtil() { + } + + /** + * Returns a DataFrame with a row for each partition in the table. + * + * The DataFrame has 3 columns, partition key (a=1/b=2), partition location, and format + * (avro or parquet). + * + * @param spark a Spark session + * @param table a table name and (optional) database + * @return a DataFrame of the table's partitions + */ + public static Dataset partitionDF(SparkSession spark, String table) { + List partitions = getPartitions(spark, table); + return spark.createDataFrame(partitions, SparkPartition.class).toDF("partition", "uri", "format"); + } + + /** + * Returns a DataFrame with a row for each partition that matches the specified 'expression'. + * + * @param spark a Spark session. + * @param table name of the table. + * @param expression The expression whose matching partitions are returned. + * @return a DataFrame of the table partitions. + */ + public static Dataset partitionDFByFilter(SparkSession spark, String table, String expression) { + List partitions = getPartitionsByFilter(spark, table, expression); + return spark.createDataFrame(partitions, SparkPartition.class).toDF("partition", "uri", "format"); + } + + /** + * Returns all partitions in the table. + * + * @param spark a Spark session + * @param table a table name and (optional) database + * @return all table's partitions + */ + public static List getPartitions(SparkSession spark, String table) { + try { + TableIdentifier tableIdent = spark.sessionState().sqlParser().parseTableIdentifier(table); + return getPartitions(spark, tableIdent, null); + } catch (ParseException e) { + throw SparkExceptionUtil.toUncheckedException(e, "Unable to parse table identifier: %s", table); + } + } + + /** + * Returns all partitions in the table. 
+ * + * @param spark a Spark session + * @param tableIdent a table identifier + * @param partitionFilter partition filter, or null if no filter + * @return all table's partitions + */ + public static List getPartitions(SparkSession spark, TableIdentifier tableIdent, + Map partitionFilter) { + try { + SessionCatalog catalog = spark.sessionState().catalog(); + CatalogTable catalogTable = catalog.getTableMetadata(tableIdent); + + Option> scalaPartitionFilter; + if (partitionFilter != null && !partitionFilter.isEmpty()) { + Builder, scala.collection.immutable.Map> builder = + Map$.MODULE$.newBuilder(); + partitionFilter.forEach((key, value) -> builder.$plus$eq(Tuple2.apply(key, value))); + scalaPartitionFilter = Option.apply(builder.result()); + } else { + scalaPartitionFilter = Option.empty(); + } + Seq partitions = catalog.listPartitions(tableIdent, scalaPartitionFilter).toIndexedSeq(); + return JavaConverters + .seqAsJavaListConverter(partitions) + .asJava() + .stream() + .map(catalogPartition -> toSparkPartition(catalogPartition, catalogTable)) + .collect(Collectors.toList()); + } catch (NoSuchDatabaseException e) { + throw SparkExceptionUtil.toUncheckedException(e, "Unknown table: %s. Database not found in catalog.", tableIdent); + } catch (NoSuchTableException e) { + throw SparkExceptionUtil.toUncheckedException(e, "Unknown table: %s. Table not found in catalog.", tableIdent); + } + } + + /** + * Returns partitions that match the specified 'predicate'. + * + * @param spark a Spark session + * @param table a table name and (optional) database + * @param predicate a predicate on partition columns + * @return matching table's partitions + */ + public static List getPartitionsByFilter(SparkSession spark, String table, String predicate) { + TableIdentifier tableIdent; + try { + tableIdent = spark.sessionState().sqlParser().parseTableIdentifier(table); + } catch (ParseException e) { + throw SparkExceptionUtil.toUncheckedException(e, "Unable to parse the table identifier: %s", table); + } + + Expression unresolvedPredicateExpr; + try { + unresolvedPredicateExpr = spark.sessionState().sqlParser().parseExpression(predicate); + } catch (ParseException e) { + throw SparkExceptionUtil.toUncheckedException(e, "Unable to parse the predicate expression: %s", predicate); + } + + Expression resolvedPredicateExpr = resolveAttrs(spark, table, unresolvedPredicateExpr); + return getPartitionsByFilter(spark, tableIdent, resolvedPredicateExpr); + } + + /** + * Returns partitions that match the specified 'predicate'. 
+ * + * @param spark a Spark session + * @param tableIdent a table identifier + * @param predicateExpr a predicate expression on partition columns + * @return matching table's partitions + */ + public static List getPartitionsByFilter(SparkSession spark, TableIdentifier tableIdent, + Expression predicateExpr) { + try { + SessionCatalog catalog = spark.sessionState().catalog(); + CatalogTable catalogTable = catalog.getTableMetadata(tableIdent); + + Expression resolvedPredicateExpr; + if (!predicateExpr.resolved()) { + resolvedPredicateExpr = resolveAttrs(spark, tableIdent.quotedString(), predicateExpr); + } else { + resolvedPredicateExpr = predicateExpr; + } + Seq predicates = JavaConverters + .collectionAsScalaIterableConverter(ImmutableList.of(resolvedPredicateExpr)) + .asScala().toIndexedSeq(); + + Seq partitions = catalog.listPartitionsByFilter(tableIdent, predicates).toIndexedSeq(); + + return JavaConverters + .seqAsJavaListConverter(partitions) + .asJava() + .stream() + .map(catalogPartition -> toSparkPartition(catalogPartition, catalogTable)) + .collect(Collectors.toList()); + } catch (NoSuchDatabaseException e) { + throw SparkExceptionUtil.toUncheckedException(e, "Unknown table: %s. Database not found in catalog.", tableIdent); + } catch (NoSuchTableException e) { + throw SparkExceptionUtil.toUncheckedException(e, "Unknown table: %s. Table not found in catalog.", tableIdent); + } + } + + /** + * Returns the data files in a partition by listing the partition location. + * + * For Parquet and ORC partitions, this will read metrics from the file footer. For Avro partitions, + * metrics are set to null. + * + * @param partition a partition + * @param conf a serializable Hadoop conf + * @param metricsConfig a metrics conf + * @return a List of DataFile + * @deprecated use {@link TableMigrationUtil#listPartition(Map, String, String, PartitionSpec, Configuration, + * MetricsConfig, NameMapping)} + */ + @Deprecated + public static List listPartition(SparkPartition partition, PartitionSpec spec, + SerializableConfiguration conf, MetricsConfig metricsConfig) { + return listPartition(partition, spec, conf, metricsConfig, null); + } + + /** + * Returns the data files in a partition by listing the partition location. + * + * For Parquet and ORC partitions, this will read metrics from the file footer. For Avro partitions, + * metrics are set to null. 
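+ * <p>
+ * A minimal sketch of the replacement call (a {@code null} name mapping is assumed):
+ * <pre>{@code
+ *   TableMigrationUtil.listPartition(partition.getValues(), partition.getUri(), partition.getFormat(),
+ *       spec, conf.get(), metricsConfig, null);
+ * }</pre>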
+ * + * @param partition a partition + * @param conf a serializable Hadoop conf + * @param metricsConfig a metrics conf + * @param mapping a name mapping + * @return a List of DataFile + * @deprecated use {@link TableMigrationUtil#listPartition(Map, String, String, PartitionSpec, Configuration, + * MetricsConfig, NameMapping)} + */ + @Deprecated + public static List listPartition(SparkPartition partition, PartitionSpec spec, + SerializableConfiguration conf, MetricsConfig metricsConfig, + NameMapping mapping) { + return TableMigrationUtil.listPartition(partition.values, partition.uri, partition.format, spec, conf.get(), + metricsConfig, mapping); + } + + + private static SparkPartition toSparkPartition(CatalogTablePartition partition, CatalogTable table) { + Option locationUri = partition.storage().locationUri(); + Option serde = partition.storage().serde(); + + Preconditions.checkArgument(locationUri.nonEmpty(), "Partition URI should be defined"); + Preconditions.checkArgument(serde.nonEmpty() || table.provider().nonEmpty(), + "Partition format should be defined"); + + String uri = Util.uriToString(locationUri.get()); + String format = serde.nonEmpty() ? serde.get() : table.provider().get(); + + Map partitionSpec = JavaConverters.mapAsJavaMapConverter(partition.spec()).asJava(); + return new SparkPartition(partitionSpec, uri, format); + } + + private static Expression resolveAttrs(SparkSession spark, String table, Expression expr) { + Function2 resolver = spark.sessionState().analyzer().resolver(); + LogicalPlan plan = spark.table(table).queryExecution().analyzed(); + return expr.transform(new AbstractPartialFunction() { + @Override + public Expression apply(Expression attr) { + UnresolvedAttribute unresolvedAttribute = (UnresolvedAttribute) attr; + Option namedExpressionOption = plan.resolve(unresolvedAttribute.nameParts(), resolver); + if (namedExpressionOption.isDefined()) { + return (Expression) namedExpressionOption.get(); + } else { + throw new IllegalArgumentException( + String.format("Could not resolve %s using columns: %s", attr, plan.output())); + } + } + + @Override + public boolean isDefinedAt(Expression attr) { + return attr instanceof UnresolvedAttribute; + } + }); + } + + private static Iterator buildManifest(SerializableConfiguration conf, PartitionSpec spec, + String basePath, Iterator> fileTuples) { + if (fileTuples.hasNext()) { + FileIO io = new HadoopFileIO(conf.get()); + TaskContext ctx = TaskContext.get(); + String suffix = String.format("stage-%d-task-%d-manifest", ctx.stageId(), ctx.taskAttemptId()); + Path location = new Path(basePath, suffix); + String outputPath = FileFormat.AVRO.addExtension(location.toString()); + OutputFile outputFile = io.newOutputFile(outputPath); + ManifestWriter writer = ManifestFiles.write(spec, outputFile); + + try (ManifestWriter writerRef = writer) { + fileTuples.forEachRemaining(fileTuple -> writerRef.add(fileTuple._2)); + } catch (IOException e) { + throw SparkExceptionUtil.toUncheckedException(e, "Unable to close the manifest writer: %s", outputPath); + } + + ManifestFile manifestFile = writer.toManifestFile(); + return ImmutableList.of(manifestFile).iterator(); + } else { + return Collections.emptyIterator(); + } + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + * The import uses the Spark session to get table metadata. It assumes no + * operation is going on the original and target table and thus is not + * thread-safe. 
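+ * <p>
+ * A minimal usage sketch; the source table, partition column and staging path are hypothetical:
+ * <pre>{@code
+ *   TableIdentifier source = new TableIdentifier("logs", Some.apply("db"));
+ *   SparkTableUtil.importSparkTable(spark, source, targetTable, "/tmp/iceberg-staging",
+ *       ImmutableMap.of("dt", "2022-08-01"), true);
+ * }</pre>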
+ * + * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + * @param partitionFilter only import partitions whose values match those in the map, can be partially defined + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + */ + public static void importSparkTable(SparkSession spark, TableIdentifier sourceTableIdent, Table targetTable, + String stagingDir, Map partitionFilter, + boolean checkDuplicateFiles) { + SessionCatalog catalog = spark.sessionState().catalog(); + + String db = sourceTableIdent.database().nonEmpty() ? + sourceTableIdent.database().get() : + catalog.getCurrentDatabase(); + TableIdentifier sourceTableIdentWithDB = new TableIdentifier(sourceTableIdent.table(), Some.apply(db)); + + if (!catalog.tableExists(sourceTableIdentWithDB)) { + throw new org.apache.iceberg.exceptions.NoSuchTableException("Table %s does not exist", sourceTableIdentWithDB); + } + + try { + PartitionSpec spec = SparkSchemaUtil.specForTable(spark, sourceTableIdentWithDB.unquotedString()); + + if (Objects.equal(spec, PartitionSpec.unpartitioned())) { + importUnpartitionedSparkTable(spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles); + } else { + List sourceTablePartitions = getPartitions(spark, sourceTableIdent, + partitionFilter); + Preconditions.checkArgument(!sourceTablePartitions.isEmpty(), + "Cannot find any partitions in table %s", sourceTableIdent); + importSparkPartitions(spark, sourceTablePartitions, targetTable, spec, stagingDir, checkDuplicateFiles); + } + } catch (AnalysisException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unable to get partition spec for table: %s", sourceTableIdentWithDB); + } + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + * The import uses the Spark session to get table metadata. It assumes no + * operation is going on the original and target table and thus is not + * thread-safe. + * + * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + */ + public static void importSparkTable(SparkSession spark, TableIdentifier sourceTableIdent, Table targetTable, + String stagingDir, boolean checkDuplicateFiles) { + importSparkTable(spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), checkDuplicateFiles); + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + * The import uses the Spark session to get table metadata. It assumes no + * operation is going on the original and target table and thus is not + * thread-safe. 
+ * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + */ + public static void importSparkTable(SparkSession spark, TableIdentifier sourceTableIdent, Table targetTable, + String stagingDir) { + importSparkTable(spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false); + } + + private static void importUnpartitionedSparkTable(SparkSession spark, TableIdentifier sourceTableIdent, + Table targetTable, boolean checkDuplicateFiles) { + try { + CatalogTable sourceTable = spark.sessionState().catalog().getTableMetadata(sourceTableIdent); + Option format = + sourceTable.storage().serde().nonEmpty() ? sourceTable.storage().serde() : sourceTable.provider(); + Preconditions.checkArgument(format.nonEmpty(), "Could not determine table format"); + + Map partition = Collections.emptyMap(); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Configuration conf = spark.sessionState().newHadoopConf(); + MetricsConfig metricsConfig = MetricsConfig.forTable(targetTable); + String nameMappingString = targetTable.properties().get(TableProperties.DEFAULT_NAME_MAPPING); + NameMapping nameMapping = nameMappingString != null ? NameMappingParser.fromJson(nameMappingString) : null; + + List files = TableMigrationUtil.listPartition( + partition, Util.uriToString(sourceTable.location()), format.get(), spec, conf, metricsConfig, nameMapping); + + if (checkDuplicateFiles) { + Dataset importedFiles = spark.createDataset( + Lists.transform(files, f -> f.path().toString()), Encoders.STRING()).toDF("file_path"); + Dataset existingFiles = loadMetadataTable(spark, targetTable, MetadataTableType.ENTRIES); + Column joinCond = existingFiles.col("data_file.file_path").equalTo(importedFiles.col("file_path")); + Dataset duplicates = importedFiles.join(existingFiles, joinCond) + .select("file_path").as(Encoders.STRING()); + Preconditions.checkState(duplicates.isEmpty(), + String.format(DUPLICATE_FILE_MESSAGE, Joiner.on(",").join((String[]) duplicates.take(10)))); + } + + AppendFiles append = targetTable.newAppend(); + files.forEach(append::appendFile); + append.commit(); + } catch (NoSuchDatabaseException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unknown table: %s. Database not found in catalog.", sourceTableIdent); + } catch (NoSuchTableException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unknown table: %s. Table not found in catalog.", sourceTableIdent); + } + } + + /** + * Import files from given partitions to an Iceberg table. 
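+ * <p>
+ * A minimal usage sketch, assuming the partitions were listed with {@code getPartitions} and the
+ * target table's current spec matches the source layout (names are hypothetical):
+ * <pre>{@code
+ *   List<SparkPartition> partitions = SparkTableUtil.getPartitions(spark, "db.logs");
+ *   SparkTableUtil.importSparkPartitions(spark, partitions, targetTable, targetTable.spec(),
+ *       "/tmp/iceberg-staging", false);
+ * }</pre>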
+ * + * @param spark a Spark session + * @param partitions partitions to import + * @param targetTable an Iceberg table where to import the data + * @param spec a partition spec + * @param stagingDir a staging directory to store temporary manifest files + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + */ + public static void importSparkPartitions(SparkSession spark, List partitions, Table targetTable, + PartitionSpec spec, String stagingDir, boolean checkDuplicateFiles) { + Configuration conf = spark.sessionState().newHadoopConf(); + SerializableConfiguration serializableConf = new SerializableConfiguration(conf); + int parallelism = Math.min(partitions.size(), spark.sessionState().conf().parallelPartitionDiscoveryParallelism()); + int numShufflePartitions = spark.sessionState().conf().numShufflePartitions(); + MetricsConfig metricsConfig = MetricsConfig.fromProperties(targetTable.properties()); + String nameMappingString = targetTable.properties().get(TableProperties.DEFAULT_NAME_MAPPING); + NameMapping nameMapping = nameMappingString != null ? NameMappingParser.fromJson(nameMappingString) : null; + + JavaSparkContext sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + JavaRDD partitionRDD = sparkContext.parallelize(partitions, parallelism); + + Dataset partitionDS = spark.createDataset( + partitionRDD.rdd(), + Encoders.javaSerialization(SparkPartition.class)); + + Dataset filesToImport = partitionDS + .flatMap((FlatMapFunction) sparkPartition -> + listPartition(sparkPartition, spec, serializableConf, metricsConfig, nameMapping).iterator(), + Encoders.javaSerialization(DataFile.class)); + + if (checkDuplicateFiles) { + Dataset importedFiles = filesToImport + .map((MapFunction) f -> f.path().toString(), Encoders.STRING()) + .toDF("file_path"); + Dataset existingFiles = loadMetadataTable(spark, targetTable, MetadataTableType.ENTRIES); + Column joinCond = existingFiles.col("data_file.file_path").equalTo(importedFiles.col("file_path")); + Dataset duplicates = importedFiles.join(existingFiles, joinCond) + .select("file_path").as(Encoders.STRING()); + Preconditions.checkState(duplicates.isEmpty(), + String.format(DUPLICATE_FILE_MESSAGE, Joiner.on(",").join((String[]) duplicates.take(10)))); + } + + List manifests = filesToImport + .repartition(numShufflePartitions) + .map((MapFunction>) file -> + Tuple2.apply(file.path().toString(), file), + Encoders.tuple(Encoders.STRING(), Encoders.javaSerialization(DataFile.class))) + .orderBy(col("_1")) + .mapPartitions( + (MapPartitionsFunction, ManifestFile>) fileTuple -> + buildManifest(serializableConf, spec, stagingDir, fileTuple), + Encoders.javaSerialization(ManifestFile.class)) + .collectAsList(); + + try { + boolean snapshotIdInheritanceEnabled = PropertyUtil.propertyAsBoolean( + targetTable.properties(), + TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, + TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED_DEFAULT); + + AppendFiles append = targetTable.newAppend(); + manifests.forEach(append::appendManifest); + append.commit(); + + if (!snapshotIdInheritanceEnabled) { + // delete original manifests as they were rewritten before the commit + deleteManifests(targetTable.io(), manifests); + } + } catch (Throwable e) { + deleteManifests(targetTable.io(), manifests); + throw e; + } + } + + /** + * Import files from given partitions to an Iceberg table. 
+ * + * @param spark a Spark session + * @param partitions partitions to import + * @param targetTable an Iceberg table where to import the data + * @param spec a partition spec + * @param stagingDir a staging directory to store temporary manifest files + */ + public static void importSparkPartitions(SparkSession spark, List partitions, Table targetTable, + PartitionSpec spec, String stagingDir) { + importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, false); + } + + public static List filterPartitions(List partitions, + Map partitionFilter) { + if (partitionFilter.isEmpty()) { + return partitions; + } else { + return partitions.stream() + .filter(p -> p.getValues().entrySet().containsAll(partitionFilter.entrySet())) + .collect(Collectors.toList()); + } + } + + private static void deleteManifests(FileIO io, List manifests) { + Tasks.foreach(manifests) + .executeWith(ThreadPools.getWorkerPool()) + .noRetry() + .suppressFailureWhenFinished() + .run(item -> io.deleteFile(item.path())); + } + + /** + * Loads a metadata table. + * + * @deprecated since 0.14.0, will be removed in 0.15.0; + * use {@link #loadMetadataTable(SparkSession, Table, MetadataTableType)}. + */ + @Deprecated + public static Dataset loadCatalogMetadataTable(SparkSession spark, Table table, MetadataTableType type) { + return loadMetadataTable(spark, table, type); + } + + public static Dataset loadMetadataTable(SparkSession spark, Table table, MetadataTableType type) { + return loadMetadataTable(spark, table, type, ImmutableMap.of()); + } + + public static Dataset loadMetadataTable(SparkSession spark, Table table, MetadataTableType type, + Map extraOptions) { + SparkTable metadataTable = new SparkTable(MetadataTableUtils.createMetadataTableInstance(table, type), false); + CaseInsensitiveStringMap options = new CaseInsensitiveStringMap(extraOptions); + return Dataset.ofRows(spark, DataSourceV2Relation.create(metadataTable, Some.empty(), Some.empty(), options)); + } + + /** + * Class representing a table partition. + */ + public static class SparkPartition implements Serializable { + private final Map values; + private final String uri; + private final String format; + + public SparkPartition(Map values, String uri, String format) { + this.values = Maps.newHashMap(values); + this.uri = uri; + this.format = format; + } + + public Map getValues() { + return values; + } + + public String getUri() { + return uri; + } + + public String getFormat() { + return format; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("values", values) + .add("uri", uri) + .add("format", format) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SparkPartition that = (SparkPartition) o; + return Objects.equal(values, that.values) && + Objects.equal(uri, that.uri) && + Objects.equal(format, that.format); + } + + @Override + public int hashCode() { + return Objects.hashCode(values, uri, format); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkTypeToType.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkTypeToType.java new file mode 100644 index 000000000000..f0b8b2a9762b --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkTypeToType.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +import java.util.List; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.CharType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.sql.types.VarcharType; + +class SparkTypeToType extends SparkTypeVisitor { + private final StructType root; + private int nextId = 0; + + SparkTypeToType() { + this.root = null; + } + + SparkTypeToType(StructType root) { + this.root = root; + // the root struct's fields use the first ids + this.nextId = root.fields().length; + } + + private int getNextId() { + int next = nextId; + nextId += 1; + return next; + } + + @Override + @SuppressWarnings("ReferenceEquality") + public Type struct(StructType struct, List types) { + StructField[] fields = struct.fields(); + List newFields = Lists.newArrayListWithExpectedSize(fields.length); + boolean isRoot = root == struct; + for (int i = 0; i < fields.length; i += 1) { + StructField field = fields[i]; + Type type = types.get(i); + + int id; + if (isRoot) { + // for new conversions, use ordinals for ids in the root struct + id = i; + } else { + id = getNextId(); + } + + String doc = field.getComment().isDefined() ? 
field.getComment().get() : null; + + if (field.nullable()) { + newFields.add(Types.NestedField.optional(id, field.name(), type, doc)); + } else { + newFields.add(Types.NestedField.required(id, field.name(), type, doc)); + } + } + + return Types.StructType.of(newFields); + } + + @Override + public Type field(StructField field, Type typeResult) { + return typeResult; + } + + @Override + public Type array(ArrayType array, Type elementType) { + if (array.containsNull()) { + return Types.ListType.ofOptional(getNextId(), elementType); + } else { + return Types.ListType.ofRequired(getNextId(), elementType); + } + } + + @Override + public Type map(MapType map, Type keyType, Type valueType) { + if (map.valueContainsNull()) { + return Types.MapType.ofOptional(getNextId(), getNextId(), keyType, valueType); + } else { + return Types.MapType.ofRequired(getNextId(), getNextId(), keyType, valueType); + } + } + + @SuppressWarnings("checkstyle:CyclomaticComplexity") + @Override + public Type atomic(DataType atomic) { + if (atomic instanceof BooleanType) { + return Types.BooleanType.get(); + + } else if ( + atomic instanceof IntegerType || + atomic instanceof ShortType || + atomic instanceof ByteType) { + return Types.IntegerType.get(); + + } else if (atomic instanceof LongType) { + return Types.LongType.get(); + + } else if (atomic instanceof FloatType) { + return Types.FloatType.get(); + + } else if (atomic instanceof DoubleType) { + return Types.DoubleType.get(); + + } else if ( + atomic instanceof StringType || + atomic instanceof CharType || + atomic instanceof VarcharType) { + return Types.StringType.get(); + + } else if (atomic instanceof DateType) { + return Types.DateType.get(); + + } else if (atomic instanceof TimestampType) { + return Types.TimestampType.withZone(); + + } else if (atomic instanceof DecimalType) { + return Types.DecimalType.of( + ((DecimalType) atomic).precision(), + ((DecimalType) atomic).scale()); + } else if (atomic instanceof BinaryType) { + return Types.BinaryType.get(); + } + + throw new UnsupportedOperationException( + "Not a supported type: " + atomic.catalogString()); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkTypeVisitor.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkTypeVisitor.java new file mode 100644 index 000000000000..83b31940711e --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkTypeVisitor.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark; + +import java.util.List; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.UserDefinedType; + +class SparkTypeVisitor { + static T visit(DataType type, SparkTypeVisitor visitor) { + if (type instanceof StructType) { + StructField[] fields = ((StructType) type).fields(); + List fieldResults = Lists.newArrayListWithExpectedSize(fields.length); + + for (StructField field : fields) { + fieldResults.add(visitor.field( + field, + visit(field.dataType(), visitor))); + } + + return visitor.struct((StructType) type, fieldResults); + + } else if (type instanceof MapType) { + return visitor.map((MapType) type, + visit(((MapType) type).keyType(), visitor), + visit(((MapType) type).valueType(), visitor)); + + } else if (type instanceof ArrayType) { + return visitor.array( + (ArrayType) type, + visit(((ArrayType) type).elementType(), visitor)); + + } else if (type instanceof UserDefinedType) { + throw new UnsupportedOperationException( + "User-defined types are not supported"); + + } else { + return visitor.atomic(type); + } + } + + public T struct(StructType struct, List fieldResults) { + return null; + } + + public T field(StructField field, T typeResult) { + return null; + } + + public T array(ArrayType array, T elementResult) { + return null; + } + + public T map(MapType map, T keyResult, T valueResult) { + return null; + } + + public T atomic(DataType atomic) { + return null; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java new file mode 100644 index 000000000000..7754aa406123 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark; + +import java.sql.Date; +import java.sql.Timestamp; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionField; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.hadoop.HadoopConfigurable; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.transforms.Transform; +import org.apache.iceberg.transforms.UnknownTransform; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.Pair; +import org.apache.spark.sql.RuntimeConfig; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.expressions.BoundReference; +import org.apache.spark.sql.catalyst.expressions.EqualTo; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.SerializableConfiguration; +import org.joda.time.DateTime; + +public class SparkUtil { + + public static final String TIMESTAMP_WITHOUT_TIMEZONE_ERROR = String.format("Cannot handle timestamp without" + + " timezone fields in Spark. Spark does not natively support this type but if you would like to handle all" + + " timestamps as timestamp with timezone set '%s' to true. This will not change the underlying values stored" + + " but will change their displayed values in Spark. For more information please see" + + " https://docs.databricks.com/spark/latest/dataframes-datasets/dates-timestamps.html#ansi-sql-and" + + "-spark-sql-timestamps", SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE); + + private static final String SPARK_CATALOG_CONF_PREFIX = "spark.sql.catalog"; + // Format string used as the prefix for spark configuration keys to override hadoop configuration values + // for Iceberg tables from a given catalog. These keys can be specified as `spark.sql.catalog.$catalogName.hadoop.*`, + // similar to using `spark.hadoop.*` to override hadoop configurations globally for a given spark session. + private static final String SPARK_CATALOG_HADOOP_CONF_OVERRIDE_FMT_STR = SPARK_CATALOG_CONF_PREFIX + ".%s.hadoop."; + + private SparkUtil() { + } + + public static FileIO serializableFileIO(Table table) { + if (table.io() instanceof HadoopConfigurable) { + // we need to use Spark's SerializableConfiguration to avoid issues with Kryo serialization + ((HadoopConfigurable) table.io()).serializeConfWith(conf -> new SerializableConfiguration(conf)::value); + } + + return table.io(); + } + + /** + * Check whether the partition transforms in a spec can be used to write data. 
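+ * <p>
+ * A minimal usage sketch: {@code SparkUtil.validatePartitionTransforms(table.spec());} throws if the
+ * spec contains a transform this client does not recognize (an {@code UnknownTransform}).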
+ * + * @param spec a PartitionSpec + * @throws UnsupportedOperationException if the spec contains unknown partition transforms + */ + public static void validatePartitionTransforms(PartitionSpec spec) { + if (spec.fields().stream().anyMatch(field -> field.transform() instanceof UnknownTransform)) { + String unsupported = spec.fields().stream() + .map(PartitionField::transform) + .filter(transform -> transform instanceof UnknownTransform) + .map(Transform::toString) + .collect(Collectors.joining(", ")); + + throw new UnsupportedOperationException( + String.format("Cannot write using unsupported transforms: %s", unsupported)); + } + } + + /** + * A modified version of Spark's LookupCatalog.CatalogAndIdentifier.unapply + * Attempts to find the catalog and identifier a multipart identifier represents + * @param nameParts Multipart identifier representing a table + * @return The CatalogPlugin and Identifier for the table + */ + public static Pair catalogAndIdentifier(List nameParts, + Function catalogProvider, + BiFunction identiferProvider, + C currentCatalog, + String[] currentNamespace) { + Preconditions.checkArgument(!nameParts.isEmpty(), + "Cannot determine catalog and identifier from empty name"); + + int lastElementIndex = nameParts.size() - 1; + String name = nameParts.get(lastElementIndex); + + if (nameParts.size() == 1) { + // Only a single element, use current catalog and namespace + return Pair.of(currentCatalog, identiferProvider.apply(currentNamespace, name)); + } else { + C catalog = catalogProvider.apply(nameParts.get(0)); + if (catalog == null) { + // The first element was not a valid catalog, treat it like part of the namespace + String[] namespace = nameParts.subList(0, lastElementIndex).toArray(new String[0]); + return Pair.of(currentCatalog, identiferProvider.apply(namespace, name)); + } else { + // Assume the first element is a valid catalog + String[] namespace = nameParts.subList(1, lastElementIndex).toArray(new String[0]); + return Pair.of(catalog, identiferProvider.apply(namespace, name)); + } + } + } + + /** + * Responsible for checking if the table schema has a timestamp without timezone column + * @param schema table schema to check if it contains a timestamp without timezone column + * @return boolean indicating if the schema passed in has a timestamp field without a timezone + */ + public static boolean hasTimestampWithoutZone(Schema schema) { + return TypeUtil.find(schema, t -> Types.TimestampType.withoutZone().equals(t)) != null; + } + + /** + * Checks whether timestamp types for new tables should be stored with timezone info. + *

+ * The default value is false and all timestamp fields are stored as {@link Types.TimestampType#withZone()}. + * If enabled, all timestamp fields in new tables will be stored as {@link Types.TimestampType#withoutZone()}. + * + * @param sessionConf a Spark runtime config + * @return true if timestamp types for new tables should be stored with timezone info + */ + public static boolean useTimestampWithoutZoneInNewTables(RuntimeConfig sessionConf) { + String sessionConfValue = sessionConf.get(SparkSQLProperties.USE_TIMESTAMP_WITHOUT_TIME_ZONE_IN_NEW_TABLES, null); + if (sessionConfValue != null) { + return Boolean.parseBoolean(sessionConfValue); + } + return SparkSQLProperties.USE_TIMESTAMP_WITHOUT_TIME_ZONE_IN_NEW_TABLES_DEFAULT; + } + + /** + * Pulls any Catalog specific overrides for the Hadoop conf from the current SparkSession, which can be + * set via `spark.sql.catalog.$catalogName.hadoop.*` + * + * Mirrors the override of hadoop configurations for a given spark session using `spark.hadoop.*`. + * + * The SparkCatalog allows for hadoop configurations to be overridden per catalog, by setting + * them on the SQLConf, where the following will add the property "fs.default.name" with value + * "hdfs://hanksnamenode:8020" to the catalog's hadoop configuration. + * SparkSession.builder() + * .config(s"spark.sql.catalog.$catalogName.hadoop.fs.default.name", "hdfs://hanksnamenode:8020") + * .getOrCreate() + * @param spark The current Spark session + * @param catalogName Name of the catalog to find overrides for. + * @return the Hadoop Configuration that should be used for this catalog, with catalog specific overrides applied. + */ + public static Configuration hadoopConfCatalogOverrides(SparkSession spark, String catalogName) { + // Find keys for the catalog intended to be hadoop configurations + final String hadoopConfCatalogPrefix = hadoopConfPrefixForCatalog(catalogName); + final Configuration conf = spark.sessionState().newHadoopConf(); + spark.sqlContext().conf().settings().forEach((k, v) -> { + // These checks are copied from `spark.sessionState().newHadoopConfWithOptions()`, which we + // avoid using to not have to convert back and forth between scala / java map types. + if (v != null && k != null && k.startsWith(hadoopConfCatalogPrefix)) { + conf.set(k.substring(hadoopConfCatalogPrefix.length()), v); + } + }); + return conf; + } + + private static String hadoopConfPrefixForCatalog(String catalogName) { + return String.format(SPARK_CATALOG_HADOOP_CONF_OVERRIDE_FMT_STR, catalogName); + } + + /** + * Get a List of Spark filter Expression. + * + * @param schema table schema + * @param filters filters in the format of a Map, where key is one of the table column name, + * and value is the specific value to be filtered on the column. + * @return a List of filters in the format of Spark Expression. 
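+ * For example, {@code partitionMapToExpression(schema, ImmutableMap.of("dt", "2022-08-01"))}
+ * (where {@code dt} is a hypothetical string partition column present in {@code schema})
+ * produces a single {@code EqualTo} expression bound to that column.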
+ */ + public static List partitionMapToExpression(StructType schema, + Map filters) { + List filterExpressions = Lists.newArrayList(); + for (Map.Entry entry : filters.entrySet()) { + try { + int index = schema.fieldIndex(entry.getKey()); + DataType dataType = schema.fields()[index].dataType(); + BoundReference ref = new BoundReference(index, dataType, true); + switch (dataType.typeName()) { + case "integer": + filterExpressions.add(new EqualTo(ref, + Literal.create(Integer.parseInt(entry.getValue()), DataTypes.IntegerType))); + break; + case "string": + filterExpressions.add(new EqualTo(ref, Literal.create(entry.getValue(), DataTypes.StringType))); + break; + case "short": + filterExpressions.add(new EqualTo(ref, + Literal.create(Short.parseShort(entry.getValue()), DataTypes.ShortType))); + break; + case "long": + filterExpressions.add(new EqualTo(ref, + Literal.create(Long.parseLong(entry.getValue()), DataTypes.LongType))); + break; + case "float": + filterExpressions.add(new EqualTo(ref, + Literal.create(Float.parseFloat(entry.getValue()), DataTypes.FloatType))); + break; + case "double": + filterExpressions.add(new EqualTo(ref, + Literal.create(Double.parseDouble(entry.getValue()), DataTypes.DoubleType))); + break; + case "date": + filterExpressions.add(new EqualTo(ref, + Literal.create(new Date(DateTime.parse(entry.getValue()).getMillis()), DataTypes.DateType))); + break; + case "timestamp": + filterExpressions.add(new EqualTo(ref, + Literal.create(new Timestamp(DateTime.parse(entry.getValue()).getMillis()), DataTypes.TimestampType))); + break; + default: + throw new IllegalStateException("Unexpected data type in partition filters: " + dataType); + } + } catch (IllegalArgumentException e) { + // ignore if filter is not on table columns + } + } + + return filterExpressions; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java new file mode 100644 index 000000000000..b3e6b2f48887 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark; + +import java.nio.ByteBuffer; +import java.sql.Date; +import java.sql.Timestamp; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; + +/** + * A utility class that converts Spark values to Iceberg's internal representation. + */ +public class SparkValueConverter { + + private SparkValueConverter() { + } + + public static Record convert(Schema schema, Row row) { + return convert(schema.asStruct(), row); + } + + public static Object convert(Type type, Object object) { + if (object == null) { + return null; + } + + switch (type.typeId()) { + case STRUCT: + return convert(type.asStructType(), (Row) object); + + case LIST: + List convertedList = Lists.newArrayList(); + List list = (List) object; + for (Object element : list) { + convertedList.add(convert(type.asListType().elementType(), element)); + } + return convertedList; + + case MAP: + Map convertedMap = Maps.newLinkedHashMap(); + Map map = (Map) object; + for (Map.Entry entry : map.entrySet()) { + convertedMap.put( + convert(type.asMapType().keyType(), entry.getKey()), + convert(type.asMapType().valueType(), entry.getValue())); + } + return convertedMap; + + case DATE: + return DateTimeUtils.fromJavaDate((Date) object); + case TIMESTAMP: + return DateTimeUtils.fromJavaTimestamp((Timestamp) object); + case BINARY: + return ByteBuffer.wrap((byte[]) object); + case INTEGER: + return ((Number) object).intValue(); + case BOOLEAN: + case LONG: + case FLOAT: + case DOUBLE: + case DECIMAL: + case STRING: + case FIXED: + return object; + default: + throw new UnsupportedOperationException("Not a supported type: " + type); + } + } + + private static Record convert(Types.StructType struct, Row row) { + if (row == null) { + return null; + } + + Record record = GenericRecord.create(struct); + List fields = struct.fields(); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + + Type fieldType = field.type(); + + switch (fieldType.typeId()) { + case STRUCT: + record.set(i, convert(fieldType.asStructType(), row.getStruct(i))); + break; + case LIST: + record.set(i, convert(fieldType.asListType(), row.getList(i))); + break; + case MAP: + record.set(i, convert(fieldType.asMapType(), row.getJavaMap(i))); + break; + default: + record.set(i, convert(fieldType, row.get(i))); + } + } + return record; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java new file mode 100644 index 000000000000..047a0b8169ed --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +import java.util.Locale; +import java.util.Map; +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.spark.sql.RuntimeConfig; +import org.apache.spark.sql.SparkSession; + +import static org.apache.iceberg.DistributionMode.HASH; +import static org.apache.iceberg.DistributionMode.NONE; +import static org.apache.iceberg.DistributionMode.RANGE; + +/** + * A class for common Iceberg configs for Spark writes. + *

+ * If a config is set at multiple levels, the following order of precedence is used (top to bottom):
+ * <ul>
+ *   <li>Write options</li>
+ *   <li>Session configuration</li>
+ *   <li>Table metadata</li>
+ * </ul>
+ * The most specific value is set in write options and takes precedence over all other configs. + * If no write option is provided, this class checks the session configuration for any overrides. + * If no applicable value is found in the session configuration, this class uses the table metadata. + *
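+ * A minimal usage sketch (the write option map is illustrative):
+ * <pre>{@code
+ *   Map<String, String> options = ImmutableMap.of(SparkWriteOptions.WRITE_FORMAT, "parquet");
+ *   SparkWriteConf writeConf = new SparkWriteConf(spark, table, options);
+ *   FileFormat format = writeConf.dataFileFormat();
+ *   long targetFileSize = writeConf.targetDataFileSize();
+ * }</pre>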

+ * Note this class is NOT meant to be serialized and sent to executors. + */ +public class SparkWriteConf { + + private final Table table; + private final RuntimeConfig sessionConf; + private final Map writeOptions; + private final SparkConfParser confParser; + + public SparkWriteConf(SparkSession spark, Table table, Map writeOptions) { + this.table = table; + this.sessionConf = spark.conf(); + this.writeOptions = writeOptions; + this.confParser = new SparkConfParser(spark, table, writeOptions); + } + + public boolean checkNullability() { + return confParser.booleanConf() + .option(SparkWriteOptions.CHECK_NULLABILITY) + .sessionConf(SparkSQLProperties.CHECK_NULLABILITY) + .defaultValue(SparkSQLProperties.CHECK_NULLABILITY_DEFAULT) + .parse(); + } + + public boolean checkOrdering() { + return confParser.booleanConf() + .option(SparkWriteOptions.CHECK_ORDERING) + .sessionConf(SparkSQLProperties.CHECK_ORDERING) + .defaultValue(SparkSQLProperties.CHECK_ORDERING_DEFAULT) + .parse(); + } + + /** + * Enables writing a timestamp with time zone as a timestamp without time zone. + *

+ * Generally, this is not safe as a timestamp without time zone is supposed to represent the wall-clock time, + * i.e. no matter the reader/writer timezone 3PM should always be read as 3PM, + * but a timestamp with time zone represents instant semantics, i.e. the timestamp + * is adjusted so that the corresponding time in the reader timezone is displayed. + *

+ * When set to false (default), an exception must be thrown if the table contains a timestamp without time zone. + * + * @return boolean indicating if writing timestamps without timezone is allowed + */ + public boolean handleTimestampWithoutZone() { + return confParser.booleanConf() + .option(SparkWriteOptions.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE) + .sessionConf(SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE) + .defaultValue(SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE_DEFAULT) + .parse(); + } + + public String overwriteMode() { + String overwriteMode = writeOptions.get(SparkWriteOptions.OVERWRITE_MODE); + return overwriteMode != null ? overwriteMode.toLowerCase(Locale.ROOT) : null; + } + + public boolean wapEnabled() { + return confParser.booleanConf() + .tableProperty(TableProperties.WRITE_AUDIT_PUBLISH_ENABLED) + .defaultValue(TableProperties.WRITE_AUDIT_PUBLISH_ENABLED_DEFAULT) + .parse(); + } + + public String wapId() { + return sessionConf.get("spark.wap.id", null); + } + + public boolean mergeSchema() { + return confParser.booleanConf() + .option(SparkWriteOptions.MERGE_SCHEMA) + .option(SparkWriteOptions.SPARK_MERGE_SCHEMA) + .defaultValue(SparkWriteOptions.MERGE_SCHEMA_DEFAULT) + .parse(); + } + + public FileFormat dataFileFormat() { + String valueAsString = confParser.stringConf() + .option(SparkWriteOptions.WRITE_FORMAT) + .tableProperty(TableProperties.DEFAULT_FILE_FORMAT) + .defaultValue(TableProperties.DEFAULT_FILE_FORMAT_DEFAULT) + .parse(); + return FileFormat.valueOf(valueAsString.toUpperCase(Locale.ENGLISH)); + } + + public long targetDataFileSize() { + return confParser.longConf() + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES) + .tableProperty(TableProperties.WRITE_TARGET_FILE_SIZE_BYTES) + .defaultValue(TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT) + .parse(); + } + + public boolean fanoutWriterEnabled() { + return confParser.booleanConf() + .option(SparkWriteOptions.FANOUT_ENABLED) + .tableProperty(TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED) + .defaultValue(TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED_DEFAULT) + .parse(); + } + + public FileFormat deleteFileFormat() { + String valueAsString = confParser.stringConf() + .option(SparkWriteOptions.DELETE_FORMAT) + .tableProperty(TableProperties.DELETE_DEFAULT_FILE_FORMAT) + .parseOptional(); + return valueAsString != null ? 
FileFormat.valueOf(valueAsString.toUpperCase(Locale.ENGLISH)) : dataFileFormat(); + } + + public long targetDeleteFileSize() { + return confParser.longConf() + .option(SparkWriteOptions.TARGET_DELETE_FILE_SIZE_BYTES) + .tableProperty(TableProperties.DELETE_TARGET_FILE_SIZE_BYTES) + .defaultValue(TableProperties.DELETE_TARGET_FILE_SIZE_BYTES_DEFAULT) + .parse(); + } + + public Map extraSnapshotMetadata() { + Map extraSnapshotMetadata = Maps.newHashMap(); + + writeOptions.forEach((key, value) -> { + if (key.startsWith(SnapshotSummary.EXTRA_METADATA_PREFIX)) { + extraSnapshotMetadata.put(key.substring(SnapshotSummary.EXTRA_METADATA_PREFIX.length()), value); + } + }); + + return extraSnapshotMetadata; + } + + public String rewrittenFileSetId() { + return confParser.stringConf() + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID) + .parseOptional(); + } + + public DistributionMode distributionMode() { + String modeName = confParser.stringConf() + .option(SparkWriteOptions.DISTRIBUTION_MODE) + .tableProperty(TableProperties.WRITE_DISTRIBUTION_MODE) + .parseOptional(); + + if (modeName != null) { + DistributionMode mode = DistributionMode.fromName(modeName); + return adjustWriteDistributionMode(mode); + } else { + return table.sortOrder().isSorted() ? RANGE : NONE; + } + } + + private DistributionMode adjustWriteDistributionMode(DistributionMode mode) { + if (mode == RANGE && table.spec().isUnpartitioned() && table.sortOrder().isUnsorted()) { + return NONE; + } else if (mode == HASH && table.spec().isUnpartitioned()) { + return NONE; + } else { + return mode; + } + } + + public DistributionMode deleteDistributionMode() { + String deleteModeName = confParser.stringConf() + .option(SparkWriteOptions.DISTRIBUTION_MODE) + .tableProperty(TableProperties.DELETE_DISTRIBUTION_MODE) + .defaultValue(TableProperties.WRITE_DISTRIBUTION_MODE_HASH) + .parse(); + return DistributionMode.fromName(deleteModeName); + } + + public DistributionMode updateDistributionMode() { + String updateModeName = confParser.stringConf() + .option(SparkWriteOptions.DISTRIBUTION_MODE) + .tableProperty(TableProperties.UPDATE_DISTRIBUTION_MODE) + .defaultValue(TableProperties.WRITE_DISTRIBUTION_MODE_HASH) + .parse(); + return DistributionMode.fromName(updateModeName); + } + + public DistributionMode copyOnWriteMergeDistributionMode() { + String mergeModeName = confParser.stringConf() + .option(SparkWriteOptions.DISTRIBUTION_MODE) + .tableProperty(TableProperties.MERGE_DISTRIBUTION_MODE) + .parseOptional(); + + if (mergeModeName != null) { + DistributionMode mergeMode = DistributionMode.fromName(mergeModeName); + return adjustWriteDistributionMode(mergeMode); + } else { + return distributionMode(); + } + } + + public DistributionMode positionDeltaMergeDistributionMode() { + String mergeModeName = confParser.stringConf() + .option(SparkWriteOptions.DISTRIBUTION_MODE) + .tableProperty(TableProperties.MERGE_DISTRIBUTION_MODE) + .parseOptional(); + return mergeModeName != null ? 
DistributionMode.fromName(mergeModeName) : distributionMode(); + } + + public boolean useTableDistributionAndOrdering() { + return confParser.booleanConf() + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING) + .defaultValue(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING_DEFAULT) + .parse(); + } + + public Long validateFromSnapshotId() { + return confParser.longConf() + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID) + .parseOptional(); + } + + public IsolationLevel isolationLevel() { + String isolationLevelName = confParser.stringConf() + .option(SparkWriteOptions.ISOLATION_LEVEL) + .parseOptional(); + return isolationLevelName != null ? IsolationLevel.fromName(isolationLevelName) : null; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java new file mode 100644 index 000000000000..72de545f1298 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark; + +/** + * Spark DF write options + */ +public class SparkWriteOptions { + + private SparkWriteOptions() { + } + + // Fileformat for write operations(default: Table write.format.default ) + public static final String WRITE_FORMAT = "write-format"; + + // Overrides this table's write.target-file-size-bytes + public static final String TARGET_FILE_SIZE_BYTES = "target-file-size-bytes"; + + // Overrides the default file format for delete files + public static final String DELETE_FORMAT = "delete-format"; + + // Overrides the default size for delete files + public static final String TARGET_DELETE_FILE_SIZE_BYTES = "target-delete-file-size-bytes"; + + // Sets the nullable check on fields(default: true) + public static final String CHECK_NULLABILITY = "check-nullability"; + + // Adds an entry with custom-key and corresponding value in the snapshot summary + // ex: df.write().format(iceberg) + // .option(SparkWriteOptions.SNAPSHOT_PROPERTY_PREFIX."key1", "value1") + // .save(location) + public static final String SNAPSHOT_PROPERTY_PREFIX = "snapshot-property"; + + // Overrides table property write.spark.fanout.enabled(default: false) + public static final String FANOUT_ENABLED = "fanout-enabled"; + + // Checks if input schema and table schema are same(default: true) + public static final String CHECK_ORDERING = "check-ordering"; + + // File scan task set ID that indicates which files must be replaced + public static final String REWRITTEN_FILE_SCAN_TASK_SET_ID = "rewritten-file-scan-task-set-id"; + + // Controls whether to allow writing timestamps without zone info + public static final String HANDLE_TIMESTAMP_WITHOUT_TIMEZONE = "handle-timestamp-without-timezone"; + + public static final String OVERWRITE_MODE = "overwrite-mode"; + + // Overrides the default distribution mode for a write operation + public static final String DISTRIBUTION_MODE = "distribution-mode"; + + // Controls whether to take into account the table distribution and sort order during a write operation + public static final String USE_TABLE_DISTRIBUTION_AND_ORDERING = "use-table-distribution-and-ordering"; + public static final boolean USE_TABLE_DISTRIBUTION_AND_ORDERING_DEFAULT = true; + + public static final String MERGE_SCHEMA = "merge-schema"; + public static final String SPARK_MERGE_SCHEMA = "mergeSchema"; + public static final boolean MERGE_SCHEMA_DEFAULT = false; + + // Identifies snapshot from which to start validating conflicting changes + public static final String VALIDATE_FROM_SNAPSHOT_ID = "validate-from-snapshot-id"; + + // Isolation Level for DataFrame calls. Currently supported by overwritePartitions + public static final String ISOLATION_LEVEL = "isolation-level"; +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/TypeToSparkType.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/TypeToSparkType.java new file mode 100644 index 000000000000..b12d3e5030b7 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/TypeToSparkType.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +import java.util.List; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.types.ArrayType$; +import org.apache.spark.sql.types.BinaryType$; +import org.apache.spark.sql.types.BooleanType$; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType$; +import org.apache.spark.sql.types.DecimalType$; +import org.apache.spark.sql.types.DoubleType$; +import org.apache.spark.sql.types.FloatType$; +import org.apache.spark.sql.types.IntegerType$; +import org.apache.spark.sql.types.LongType$; +import org.apache.spark.sql.types.MapType$; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.MetadataBuilder; +import org.apache.spark.sql.types.StringType$; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType$; +import org.apache.spark.sql.types.TimestampType$; + +class TypeToSparkType extends TypeUtil.SchemaVisitor { + TypeToSparkType() { + } + + public static final String METADATA_COL_ATTR_KEY = "__metadata_col"; + + @Override + public DataType schema(Schema schema, DataType structType) { + return structType; + } + + @Override + public DataType struct(Types.StructType struct, List fieldResults) { + List fields = struct.fields(); + + List sparkFields = Lists.newArrayListWithExpectedSize(fieldResults.size()); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + DataType type = fieldResults.get(i); + Metadata metadata = fieldMetadata(field.fieldId()); + StructField sparkField = StructField.apply(field.name(), type, field.isOptional(), metadata); + if (field.doc() != null) { + sparkField = sparkField.withComment(field.doc()); + } + sparkFields.add(sparkField); + } + + return StructType$.MODULE$.apply(sparkFields); + } + + @Override + public DataType field(Types.NestedField field, DataType fieldResult) { + return fieldResult; + } + + @Override + public DataType list(Types.ListType list, DataType elementResult) { + return ArrayType$.MODULE$.apply(elementResult, list.isElementOptional()); + } + + @Override + public DataType map(Types.MapType map, DataType keyResult, DataType valueResult) { + return MapType$.MODULE$.apply(keyResult, valueResult, map.isValueOptional()); + } + + @Override + public DataType primitive(Type.PrimitiveType primitive) { + switch (primitive.typeId()) { + case BOOLEAN: + return BooleanType$.MODULE$; + case INTEGER: + return IntegerType$.MODULE$; + case LONG: + return LongType$.MODULE$; + case FLOAT: + return FloatType$.MODULE$; + case DOUBLE: + return DoubleType$.MODULE$; + case DATE: + return DateType$.MODULE$; + case TIME: + throw new UnsupportedOperationException( + "Spark does not support time fields"); + case TIMESTAMP: + return TimestampType$.MODULE$; + case STRING: + return StringType$.MODULE$; + case UUID: + // use String + return 
StringType$.MODULE$; + case FIXED: + return BinaryType$.MODULE$; + case BINARY: + return BinaryType$.MODULE$; + case DECIMAL: + Types.DecimalType decimal = (Types.DecimalType) primitive; + return DecimalType$.MODULE$.apply(decimal.precision(), decimal.scale()); + default: + throw new UnsupportedOperationException( + "Cannot convert unknown type to Spark: " + primitive); + } + } + + private Metadata fieldMetadata(int fieldId) { + if (MetadataColumns.metadataFieldIds().contains(fieldId)) { + return new MetadataBuilder() + .putBoolean(METADATA_COL_ATTR_KEY, true) + .build(); + } + + return Metadata.empty(); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseDeleteOrphanFilesSparkAction.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseDeleteOrphanFilesSparkAction.java new file mode 100644 index 000000000000..dc58a05d4d0d --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseDeleteOrphanFilesSparkAction.java @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
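
TypeToSparkType is the visitor behind Iceberg-to-Spark schema conversion; it is package-private, so callers normally reach it through a utility such as SparkSchemaUtil.convert. A small illustrative sketch of that conversion, with a made-up schema:

import org.apache.iceberg.Schema;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.types.StructType;

public class SchemaConversionExample {
  static StructType toSparkSchema() {
    Schema icebergSchema = new Schema(
        Types.NestedField.required(1, "id", Types.LongType.get()),          // -> LongType, not nullable
        Types.NestedField.optional(2, "data", Types.StringType.get()),      // -> StringType, nullable
        Types.NestedField.optional(3, "ts", Types.TimestampType.withZone()) // -> TimestampType
    );
    // SparkSchemaUtil.convert applies the visitor shown above to the whole schema.
    return SparkSchemaUtil.convert(icebergSchema);
  }
}
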
+ */ + +package org.apache.iceberg.spark.actions; + +import java.io.File; +import java.io.IOException; +import java.io.Serializable; +import java.sql.Timestamp; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.fs.PathFilter; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.BaseDeleteOrphanFilesActionResult; +import org.apache.iceberg.actions.DeleteOrphanFiles; +import org.apache.iceberg.exceptions.RuntimeIOException; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.HiddenPathFilter; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.Tasks; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.expressions.UserDefinedFunction; +import org.apache.spark.sql.functions; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.SerializableConfiguration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.iceberg.TableProperties.GC_ENABLED; +import static org.apache.iceberg.TableProperties.GC_ENABLED_DEFAULT; + +/** + * An action that removes orphan metadata, data and delete files by listing a given location and comparing + * the actual files in that location with content and metadata files referenced by all valid snapshots. + * The location must be accessible for listing via the Hadoop {@link FileSystem}. + *
+ * By default, this action cleans up the table location returned by {@link Table#location()} and + * removes unreachable files that are older than 3 days using {@link Table#io()}. The behavior can be modified + * by passing a custom location to {@link #location} and a custom timestamp to {@link #olderThan(long)}. + * For example, someone might point this action to the data folder to clean up only orphan data files. + *
+ * Configure an alternative delete method using {@link #deleteWith(Consumer)}. + *
+ * For full control of the set of files being evaluated, use the {@link #compareToFileList(Dataset)} argument. This + * skips the directory listing - any files in the dataset provided which are not found in table metadata will + * be deleted, using the same {@link Table#location()} and {@link #olderThan(long)} filtering as above. + *
+ * Note: It is dangerous to call this action with a short retention interval as it might corrupt + * the state of the table if another operation is writing at the same time. + */ +public class BaseDeleteOrphanFilesSparkAction + extends BaseSparkAction implements DeleteOrphanFiles { + + private static final Logger LOG = LoggerFactory.getLogger(BaseDeleteOrphanFilesSparkAction.class); + private static final UserDefinedFunction filenameUDF = functions.udf((String path) -> { + int lastIndex = path.lastIndexOf(File.separator); + if (lastIndex == -1) { + return path; + } else { + return path.substring(lastIndex + 1); + } + }, DataTypes.StringType); + + private final SerializableConfiguration hadoopConf; + private final int partitionDiscoveryParallelism; + private final Table table; + private final Consumer defaultDelete = new Consumer() { + @Override + public void accept(String file) { + table.io().deleteFile(file); + } + }; + + private String location = null; + private long olderThanTimestamp = System.currentTimeMillis() - TimeUnit.DAYS.toMillis(3); + private Dataset compareToFileList; + private Consumer deleteFunc = defaultDelete; + private ExecutorService deleteExecutorService = null; + + public BaseDeleteOrphanFilesSparkAction(SparkSession spark, Table table) { + super(spark); + + this.hadoopConf = new SerializableConfiguration(spark.sessionState().newHadoopConf()); + this.partitionDiscoveryParallelism = spark.sessionState().conf().parallelPartitionDiscoveryParallelism(); + this.table = table; + this.location = table.location(); + + ValidationException.check( + PropertyUtil.propertyAsBoolean(table.properties(), GC_ENABLED, GC_ENABLED_DEFAULT), + "Cannot delete orphan files: GC is disabled (deleting files may corrupt other tables)"); + } + + @Override + protected DeleteOrphanFiles self() { + return this; + } + + @Override + public BaseDeleteOrphanFilesSparkAction executeDeleteWith(ExecutorService executorService) { + this.deleteExecutorService = executorService; + return this; + } + + @Override + public BaseDeleteOrphanFilesSparkAction location(String newLocation) { + this.location = newLocation; + return this; + } + + @Override + public BaseDeleteOrphanFilesSparkAction olderThan(long newOlderThanTimestamp) { + this.olderThanTimestamp = newOlderThanTimestamp; + return this; + } + + @Override + public BaseDeleteOrphanFilesSparkAction deleteWith(Consumer newDeleteFunc) { + this.deleteFunc = newDeleteFunc; + return this; + } + + public BaseDeleteOrphanFilesSparkAction compareToFileList(Dataset files) { + StructType schema = files.schema(); + + StructField filePathField = schema.apply(FILE_PATH); + Preconditions.checkArgument( + filePathField.dataType() == DataTypes.StringType, + "Invalid %s column: %s is not a string", + FILE_PATH, + filePathField.dataType()); + + StructField lastModifiedField = schema.apply(LAST_MODIFIED); + Preconditions.checkArgument( + lastModifiedField.dataType() == DataTypes.TimestampType, + "Invalid %s column: %s is not a timestamp", + LAST_MODIFIED, + lastModifiedField.dataType()); + + this.compareToFileList = files; + return this; + } + + private Dataset filteredCompareToFileList() { + Dataset files = compareToFileList; + if (location != null) { + files = files.filter(files.col(FILE_PATH).startsWith(location)); + } + return files + .filter(files.col(LAST_MODIFIED).lt(new Timestamp(olderThanTimestamp))) + .select(files.col(FILE_PATH)); + } + + @Override + public DeleteOrphanFiles.Result execute() { + JobGroupInfo info = newJobGroupInfo("DELETE-ORPHAN-FILES", jobDesc()); 
+ return withJobGroupInfo(info, this::doExecute); + } + + private String jobDesc() { + List options = Lists.newArrayList(); + options.add("older_than=" + olderThanTimestamp); + if (location != null) { + options.add("location=" + location); + } + return String.format("Deleting orphan files (%s) from %s", Joiner.on(',').join(options), table.name()); + } + + private DeleteOrphanFiles.Result doExecute() { + Dataset validContentFileDF = buildValidContentFileDF(table); + Dataset validMetadataFileDF = buildValidMetadataFileDF(table); + Dataset validFileDF = validContentFileDF.union(validMetadataFileDF); + Dataset actualFileDF = compareToFileList == null ? buildActualFileDF() : filteredCompareToFileList(); + + Column actualFileName = filenameUDF.apply(actualFileDF.col(FILE_PATH)); + Column validFileName = filenameUDF.apply(validFileDF.col(FILE_PATH)); + Column nameEqual = actualFileName.equalTo(validFileName); + Column actualContains = actualFileDF.col(FILE_PATH).contains(validFileDF.col(FILE_PATH)); + Column joinCond = nameEqual.and(actualContains); + List orphanFiles = actualFileDF.join(validFileDF, joinCond, "leftanti") + .as(Encoders.STRING()) + .collectAsList(); + + Tasks.foreach(orphanFiles) + .noRetry() + .executeWith(deleteExecutorService) + .suppressFailureWhenFinished() + .onFailure((file, exc) -> LOG.warn("Failed to delete file: {}", file, exc)) + .run(deleteFunc::accept); + + return new BaseDeleteOrphanFilesActionResult(orphanFiles); + } + + private Dataset buildActualFileDF() { + List subDirs = Lists.newArrayList(); + List matchingFiles = Lists.newArrayList(); + + Predicate predicate = file -> file.getModificationTime() < olderThanTimestamp; + PathFilter pathFilter = PartitionAwareHiddenPathFilter.forSpecs(table.specs()); + + // list at most 3 levels and only dirs that have less than 10 direct sub dirs on the driver + listDirRecursively(location, predicate, hadoopConf.value(), 3, 10, subDirs, pathFilter, matchingFiles); + + JavaRDD matchingFileRDD = sparkContext().parallelize(matchingFiles, 1); + + if (subDirs.isEmpty()) { + return spark().createDataset(matchingFileRDD.rdd(), Encoders.STRING()).toDF(FILE_PATH); + } + + int parallelism = Math.min(subDirs.size(), partitionDiscoveryParallelism); + JavaRDD subDirRDD = sparkContext().parallelize(subDirs, parallelism); + + Broadcast conf = sparkContext().broadcast(hadoopConf); + JavaRDD matchingLeafFileRDD = subDirRDD.mapPartitions( + listDirsRecursively(conf, olderThanTimestamp, pathFilter) + ); + + JavaRDD completeMatchingFileRDD = matchingFileRDD.union(matchingLeafFileRDD); + return spark().createDataset(completeMatchingFileRDD.rdd(), Encoders.STRING()).toDF(FILE_PATH); + } + + private static void listDirRecursively( + String dir, Predicate predicate, Configuration conf, int maxDepth, + int maxDirectSubDirs, List remainingSubDirs, PathFilter pathFilter, List matchingFiles) { + + // stop listing whenever we reach the max depth + if (maxDepth <= 0) { + remainingSubDirs.add(dir); + return; + } + + try { + Path path = new Path(dir); + FileSystem fs = path.getFileSystem(conf); + + List subDirs = Lists.newArrayList(); + + for (FileStatus file : fs.listStatus(path, pathFilter)) { + if (file.isDirectory()) { + subDirs.add(file.getPath().toString()); + } else if (file.isFile() && predicate.test(file)) { + matchingFiles.add(file.getPath().toString()); + } + } + + // stop listing if the number of direct sub dirs is bigger than maxDirectSubDirs + if (subDirs.size() > maxDirectSubDirs) { + remainingSubDirs.addAll(subDirs); + return; + } + + for 
(String subDir : subDirs) { + listDirRecursively( + subDir, predicate, conf, maxDepth - 1, maxDirectSubDirs, remainingSubDirs, pathFilter, matchingFiles); + } + } catch (IOException e) { + throw new RuntimeIOException(e); + } + } + + private static FlatMapFunction, String> listDirsRecursively( + Broadcast conf, + long olderThanTimestamp, + PathFilter pathFilter) { + + return dirs -> { + List subDirs = Lists.newArrayList(); + List files = Lists.newArrayList(); + + Predicate predicate = file -> file.getModificationTime() < olderThanTimestamp; + + int maxDepth = 2000; + int maxDirectSubDirs = Integer.MAX_VALUE; + + dirs.forEachRemaining(dir -> { + listDirRecursively( + dir, predicate, conf.value().value(), maxDepth, maxDirectSubDirs, subDirs, pathFilter, files); + }); + + if (!subDirs.isEmpty()) { + throw new RuntimeException("Could not list subdirectories, reached maximum subdirectory depth: " + maxDepth); + } + + return files.iterator(); + }; + } + + /** + * A {@link PathFilter} that filters out hidden path, but does not filter out paths that would be marked + * as hidden by {@link HiddenPathFilter} due to a partition field that starts with one of the characters that + * indicate a hidden path. + */ + @VisibleForTesting + static class PartitionAwareHiddenPathFilter implements PathFilter, Serializable { + + private final Set hiddenPathPartitionNames; + + PartitionAwareHiddenPathFilter(Set hiddenPathPartitionNames) { + this.hiddenPathPartitionNames = hiddenPathPartitionNames; + } + + @Override + public boolean accept(Path path) { + boolean isHiddenPartitionPath = hiddenPathPartitionNames.stream().anyMatch(path.getName()::startsWith); + return isHiddenPartitionPath || HiddenPathFilter.get().accept(path); + } + + static PathFilter forSpecs(Map specs) { + if (specs == null) { + return HiddenPathFilter.get(); + } + + Set partitionNames = specs.values().stream() + .map(PartitionSpec::fields) + .flatMap(List::stream) + .filter(partitionField -> partitionField.name().startsWith("_") || partitionField.name().startsWith(".")) + .map(partitionField -> partitionField.name() + "=") + .collect(Collectors.toSet()); + + return partitionNames.isEmpty() ? HiddenPathFilter.get() : new PartitionAwareHiddenPathFilter(partitionNames); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseDeleteReachableFilesSparkAction.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseDeleteReachableFilesSparkAction.java new file mode 100644 index 000000000000..fc2a5fa5cf00 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseDeleteReachableFilesSparkAction.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
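
As a usage sketch for the orphan-file action above: it is typically reached through SparkActions with a retention window comfortably longer than any in-flight write. The seven-day cutoff here is illustrative.

import java.util.concurrent.TimeUnit;
import org.apache.iceberg.Table;
import org.apache.iceberg.actions.DeleteOrphanFiles;
import org.apache.iceberg.spark.actions.SparkActions;

public class DeleteOrphanFilesExample {
  static void removeOrphans(Table table) {
    long cutoff = System.currentTimeMillis() - TimeUnit.DAYS.toMillis(7); // illustrative retention window
    DeleteOrphanFiles.Result result = SparkActions.get()
        .deleteOrphanFiles(table)
        .olderThan(cutoff)   // never use a short interval while other writers may be active
        .execute();
    result.orphanFileLocations().forEach(path -> System.out.println("Removed orphan: " + path));
  }
}
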
+ */ + +package org.apache.iceberg.spark.actions; + +import java.util.Iterator; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableMetadataParser; +import org.apache.iceberg.actions.BaseDeleteReachableFilesActionResult; +import org.apache.iceberg.actions.DeleteReachableFiles; +import org.apache.iceberg.exceptions.NotFoundException; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.HadoopFileIO; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.Tasks; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.iceberg.TableProperties.GC_ENABLED; +import static org.apache.iceberg.TableProperties.GC_ENABLED_DEFAULT; + +/** + * An implementation of {@link DeleteReachableFiles} that uses metadata tables in Spark + * to determine which files should be deleted. + */ +@SuppressWarnings("UnnecessaryAnonymousClass") +public class BaseDeleteReachableFilesSparkAction + extends BaseSparkAction implements DeleteReachableFiles { + + public static final String STREAM_RESULTS = "stream-results"; + public static final boolean STREAM_RESULTS_DEFAULT = false; + + private static final Logger LOG = LoggerFactory.getLogger(BaseDeleteReachableFilesSparkAction.class); + + private final String metadataFileLocation; + private final Consumer defaultDelete = new Consumer() { + @Override + public void accept(String file) { + io.deleteFile(file); + } + }; + + private Consumer deleteFunc = defaultDelete; + private ExecutorService deleteExecutorService = null; + private FileIO io = new HadoopFileIO(spark().sessionState().newHadoopConf()); + + public BaseDeleteReachableFilesSparkAction(SparkSession spark, String metadataFileLocation) { + super(spark); + this.metadataFileLocation = metadataFileLocation; + } + + @Override + protected DeleteReachableFiles self() { + return this; + } + + @Override + public DeleteReachableFiles io(FileIO fileIO) { + this.io = fileIO; + return this; + } + + @Override + public DeleteReachableFiles deleteWith(Consumer newDeleteFunc) { + this.deleteFunc = newDeleteFunc; + return this; + } + + @Override + public DeleteReachableFiles executeDeleteWith(ExecutorService executorService) { + this.deleteExecutorService = executorService; + return this; + } + + @Override + public Result execute() { + Preconditions.checkArgument(io != null, "File IO cannot be null"); + String jobDesc = String.format("Deleting files reachable from %s", metadataFileLocation); + JobGroupInfo info = newJobGroupInfo("DELETE-REACHABLE-FILES", jobDesc); + return withJobGroupInfo(info, this::doExecute); + } + + private Result doExecute() { + TableMetadata metadata = TableMetadataParser.read(io, metadataFileLocation); + + ValidationException.check( + PropertyUtil.propertyAsBoolean(metadata.properties(), GC_ENABLED, GC_ENABLED_DEFAULT), + "Cannot delete files: GC is disabled (deleting files may corrupt other tables)"); + + Dataset reachableFileDF = buildReachableFileDF(metadata).distinct(); + + boolean streamResults = PropertyUtil.propertyAsBoolean(options(), STREAM_RESULTS, 
STREAM_RESULTS_DEFAULT); + if (streamResults) { + return deleteFiles(reachableFileDF.toLocalIterator()); + } else { + return deleteFiles(reachableFileDF.collectAsList().iterator()); + } + } + + private Dataset buildReachableFileDF(TableMetadata metadata) { + Table staticTable = newStaticTable(metadata, io); + return withFileType(buildValidContentFileDF(staticTable), CONTENT_FILE) + .union(withFileType(buildManifestFileDF(staticTable), MANIFEST)) + .union(withFileType(buildManifestListDF(staticTable), MANIFEST_LIST)) + .union(withFileType(buildAllReachableOtherMetadataFileDF(staticTable), OTHERS)); + } + + /** + * Deletes files passed to it. + * + * @param deleted an Iterator of Spark Rows of the structure (path: String, type: String) + * @return Statistics on which files were deleted + */ + private BaseDeleteReachableFilesActionResult deleteFiles(Iterator deleted) { + AtomicLong dataFileCount = new AtomicLong(0L); + AtomicLong manifestCount = new AtomicLong(0L); + AtomicLong manifestListCount = new AtomicLong(0L); + AtomicLong otherFilesCount = new AtomicLong(0L); + + Tasks.foreach(deleted) + .retry(3).stopRetryOn(NotFoundException.class).suppressFailureWhenFinished() + .executeWith(deleteExecutorService) + .onFailure((fileInfo, exc) -> { + String file = fileInfo.getString(0); + String type = fileInfo.getString(1); + LOG.warn("Delete failed for {}: {}", type, file, exc); + }) + .run(fileInfo -> { + String file = fileInfo.getString(0); + String type = fileInfo.getString(1); + deleteFunc.accept(file); + switch (type) { + case CONTENT_FILE: + dataFileCount.incrementAndGet(); + LOG.trace("Deleted Content File: {}", file); + break; + case MANIFEST: + manifestCount.incrementAndGet(); + LOG.debug("Deleted Manifest: {}", file); + break; + case MANIFEST_LIST: + manifestListCount.incrementAndGet(); + LOG.debug("Deleted Manifest List: {}", file); + break; + case OTHERS: + otherFilesCount.incrementAndGet(); + LOG.debug("Others: {}", file); + break; + } + }); + + long filesCount = dataFileCount.get() + manifestCount.get() + manifestListCount.get() + otherFilesCount.get(); + LOG.info("Total files removed: {}", filesCount); + return new BaseDeleteReachableFilesActionResult(dataFileCount.get(), manifestCount.get(), manifestListCount.get(), + otherFilesCount.get()); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseExpireSnapshotsSparkAction.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseExpireSnapshotsSparkAction.java new file mode 100644 index 000000000000..2216383337ef --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseExpireSnapshotsSparkAction.java @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
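
A sketch of how the reachable-files action above might be driven; the metadata path is hypothetical and the default HadoopFileIO is assumed to be able to resolve it. Because the action removes everything reachable from that metadata file, it is only appropriate when a table is being dropped entirely.

import org.apache.iceberg.actions.DeleteReachableFiles;
import org.apache.iceberg.spark.actions.SparkActions;

public class DeleteReachableFilesExample {
  static void purgeTable() {
    // Hypothetical metadata file of a table that is being removed for good.
    String metadataLocation = "hdfs://nn:8020/warehouse/db/tbl/metadata/v12.metadata.json";

    DeleteReachableFiles.Result result = SparkActions.get()
        .deleteReachableFiles(metadataLocation)
        .option("stream-results", "true")   // matches STREAM_RESULTS above; avoids collecting all paths on the driver
        .execute();

    System.out.println("Deleted " + result.deletedDataFilesCount() + " data files");
  }
}
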
+ */ + +package org.apache.iceberg.spark.actions; + +import java.util.Iterator; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; +import org.apache.iceberg.FileContent; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.actions.BaseExpireSnapshotsActionResult; +import org.apache.iceberg.actions.ExpireSnapshots; +import org.apache.iceberg.exceptions.NotFoundException; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.Tasks; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.iceberg.TableProperties.GC_ENABLED; +import static org.apache.iceberg.TableProperties.GC_ENABLED_DEFAULT; + +/** + * An action that performs the same operation as {@link org.apache.iceberg.ExpireSnapshots} but uses Spark + * to determine the delta in files between the pre and post-expiration table metadata. All of the same + * restrictions of {@link org.apache.iceberg.ExpireSnapshots} also apply to this action. + *
+ * This action first leverages {@link org.apache.iceberg.ExpireSnapshots} to expire snapshots and then + * uses metadata tables to find files that can be safely deleted. This is done by anti-joining two Datasets + * that contain all manifest and content files before and after the expiration. The snapshot expiration + * will be fully committed before any deletes are issued. + *
+ * This operation performs a shuffle so the parallelism can be controlled through 'spark.sql.shuffle.partitions'. + *
+ * Deletes are still performed locally after retrieving the results from the Spark executors. + */ +@SuppressWarnings("UnnecessaryAnonymousClass") +public class BaseExpireSnapshotsSparkAction + extends BaseSparkAction implements ExpireSnapshots { + + public static final String STREAM_RESULTS = "stream-results"; + public static final boolean STREAM_RESULTS_DEFAULT = false; + + private static final Logger LOG = LoggerFactory.getLogger(BaseExpireSnapshotsSparkAction.class); + + private final Table table; + private final TableOperations ops; + private final Consumer defaultDelete = new Consumer() { + @Override + public void accept(String file) { + ops.io().deleteFile(file); + } + }; + + private final Set expiredSnapshotIds = Sets.newHashSet(); + private Long expireOlderThanValue = null; + private Integer retainLastValue = null; + private Consumer deleteFunc = defaultDelete; + private ExecutorService deleteExecutorService = null; + private Dataset expiredFiles = null; + + public BaseExpireSnapshotsSparkAction(SparkSession spark, Table table) { + super(spark); + this.table = table; + this.ops = ((HasTableOperations) table).operations(); + + ValidationException.check( + PropertyUtil.propertyAsBoolean(table.properties(), GC_ENABLED, GC_ENABLED_DEFAULT), + "Cannot expire snapshots: GC is disabled (deleting files may corrupt other tables)"); + } + + @Override + protected ExpireSnapshots self() { + return this; + } + + @Override + public BaseExpireSnapshotsSparkAction executeDeleteWith(ExecutorService executorService) { + this.deleteExecutorService = executorService; + return this; + } + + @Override + public BaseExpireSnapshotsSparkAction expireSnapshotId(long snapshotId) { + expiredSnapshotIds.add(snapshotId); + return this; + } + + @Override + public BaseExpireSnapshotsSparkAction expireOlderThan(long timestampMillis) { + this.expireOlderThanValue = timestampMillis; + return this; + } + + @Override + public BaseExpireSnapshotsSparkAction retainLast(int numSnapshots) { + Preconditions.checkArgument(1 <= numSnapshots, + "Number of snapshots to retain must be at least 1, cannot be: %s", numSnapshots); + this.retainLastValue = numSnapshots; + return this; + } + + @Override + public BaseExpireSnapshotsSparkAction deleteWith(Consumer newDeleteFunc) { + this.deleteFunc = newDeleteFunc; + return this; + } + + /** + * Expires snapshots and commits the changes to the table, returning a Dataset of files to delete. + *
+ * This does not delete data files. To delete data files, run {@link #execute()}. + *
+ * This may be called before or after {@link #execute()} is called to return the expired file list. + * + * @return a Dataset of files that are no longer referenced by the table + */ + public Dataset expire() { + if (expiredFiles == null) { + // fetch metadata before expiration + Dataset originalFiles = buildValidFileDF(ops.current()); + + // perform expiration + org.apache.iceberg.ExpireSnapshots expireSnapshots = table.expireSnapshots().cleanExpiredFiles(false); + for (long id : expiredSnapshotIds) { + expireSnapshots = expireSnapshots.expireSnapshotId(id); + } + + if (expireOlderThanValue != null) { + expireSnapshots = expireSnapshots.expireOlderThan(expireOlderThanValue); + } + + if (retainLastValue != null) { + expireSnapshots = expireSnapshots.retainLast(retainLastValue); + } + + expireSnapshots.commit(); + + // fetch metadata after expiration + Dataset validFiles = buildValidFileDF(ops.refresh()); + + // determine expired files + this.expiredFiles = originalFiles.except(validFiles); + } + + return expiredFiles; + } + + @Override + public ExpireSnapshots.Result execute() { + JobGroupInfo info = newJobGroupInfo("EXPIRE-SNAPSHOTS", jobDesc()); + return withJobGroupInfo(info, this::doExecute); + } + + private String jobDesc() { + List options = Lists.newArrayList(); + + if (expireOlderThanValue != null) { + options.add("older_than=" + expireOlderThanValue); + } + + if (retainLastValue != null) { + options.add("retain_last=" + retainLastValue); + } + + if (!expiredSnapshotIds.isEmpty()) { + Long first = expiredSnapshotIds.stream().findFirst().get(); + if (expiredSnapshotIds.size() > 1) { + options.add(String.format("snapshot_ids: %s (%s more...)", first, expiredSnapshotIds.size() - 1)); + } else { + options.add(String.format("snapshot_id: %s", first)); + } + } + + return String.format("Expiring snapshots (%s) in %s", Joiner.on(',').join(options), table.name()); + } + + private ExpireSnapshots.Result doExecute() { + boolean streamResults = PropertyUtil.propertyAsBoolean(options(), STREAM_RESULTS, STREAM_RESULTS_DEFAULT); + if (streamResults) { + return deleteFiles(expire().toLocalIterator()); + } else { + return deleteFiles(expire().collectAsList().iterator()); + } + } + + private Dataset buildValidFileDF(TableMetadata metadata) { + Table staticTable = newStaticTable(metadata, table.io()); + return buildValidContentFileWithTypeDF(staticTable) + .union(withFileType(buildManifestFileDF(staticTable), MANIFEST)) + .union(withFileType(buildManifestListDF(staticTable), MANIFEST_LIST)); + } + + /** + * Deletes files passed to it based on their type. 
+ * + * @param expired an Iterator of Spark Rows of the structure (path: String, type: String) + * @return Statistics on which files were deleted + */ + private BaseExpireSnapshotsActionResult deleteFiles(Iterator expired) { + AtomicLong dataFileCount = new AtomicLong(0L); + AtomicLong posDeleteFileCount = new AtomicLong(0L); + AtomicLong eqDeleteFileCount = new AtomicLong(0L); + AtomicLong manifestCount = new AtomicLong(0L); + AtomicLong manifestListCount = new AtomicLong(0L); + + Tasks.foreach(expired) + .retry(3).stopRetryOn(NotFoundException.class).suppressFailureWhenFinished() + .executeWith(deleteExecutorService) + .onFailure((fileInfo, exc) -> { + String file = fileInfo.getString(0); + String type = fileInfo.getString(1); + LOG.warn("Delete failed for {}: {}", type, file, exc); + }) + .run(fileInfo -> { + String file = fileInfo.getString(0); + String type = fileInfo.getString(1); + deleteFunc.accept(file); + + if (FileContent.DATA.name().equalsIgnoreCase(type)) { + dataFileCount.incrementAndGet(); + LOG.trace("Deleted Data File: {}", file); + } else if (FileContent.POSITION_DELETES.name().equalsIgnoreCase(type)) { + posDeleteFileCount.incrementAndGet(); + LOG.trace("Deleted Positional Delete File: {}", file); + } else if (FileContent.EQUALITY_DELETES.name().equalsIgnoreCase(type)) { + eqDeleteFileCount.incrementAndGet(); + LOG.trace("Deleted Equality Delete File: {}", file); + } else if (MANIFEST.equals(type)) { + manifestCount.incrementAndGet(); + LOG.debug("Deleted Manifest: {}", file); + } else if (MANIFEST_LIST.equalsIgnoreCase(type)) { + manifestListCount.incrementAndGet(); + LOG.debug("Deleted Manifest List: {}", file); + } else { + throw new ValidationException("Illegal file type: %s", type); + } + }); + + long contentFileCount = dataFileCount.get() + posDeleteFileCount.get() + eqDeleteFileCount.get(); + LOG.info("Deleted {} total files", contentFileCount + manifestCount.get() + manifestListCount.get()); + + return new BaseExpireSnapshotsActionResult(dataFileCount.get(), posDeleteFileCount.get(), + eqDeleteFileCount.get(), manifestCount.get(), manifestListCount.get()); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseMigrateTableSparkAction.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseMigrateTableSparkAction.java new file mode 100644 index 000000000000..ef9f0d3e2583 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseMigrateTableSparkAction.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
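
A usage sketch for the snapshot-expiration action above; the cutoff and retained-snapshot count are illustrative values.

import java.util.concurrent.TimeUnit;
import org.apache.iceberg.Table;
import org.apache.iceberg.actions.ExpireSnapshots;
import org.apache.iceberg.spark.actions.SparkActions;

public class ExpireSnapshotsExample {
  static void expire(Table table) {
    long cutoff = System.currentTimeMillis() - TimeUnit.DAYS.toMillis(30); // illustrative
    ExpireSnapshots.Result result = SparkActions.get()
        .expireSnapshots(table)
        .expireOlderThan(cutoff) // expire snapshots older than ~30 days
        .retainLast(10)          // but always keep the 10 most recent
        .execute();
    System.out.println("Deleted " + result.deletedDataFilesCount() + " data files");
  }
}
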
+ */ + +package org.apache.iceberg.spark.actions; + +import java.util.Map; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.BaseMigrateTableActionResult; +import org.apache.iceberg.actions.MigrateTable; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.exceptions.NoSuchTableException; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.source.StagedSparkTable; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.StagingTableCatalog; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.Some; +import scala.collection.JavaConverters; + +/** + * Takes a Spark table in the source catalog and attempts to transform it into an Iceberg + * table in the same location with the same identifier. Once complete the identifier which + * previously referred to a non-Iceberg table will refer to the newly migrated Iceberg + * table. + */ +public class BaseMigrateTableSparkAction + extends BaseTableCreationSparkAction + implements MigrateTable { + + private static final Logger LOG = LoggerFactory.getLogger(BaseMigrateTableSparkAction.class); + private static final String BACKUP_SUFFIX = "_BACKUP_"; + + private final StagingTableCatalog destCatalog; + private final Identifier destTableIdent; + private final Identifier backupIdent; + + public BaseMigrateTableSparkAction(SparkSession spark, CatalogPlugin sourceCatalog, Identifier sourceTableIdent) { + super(spark, sourceCatalog, sourceTableIdent); + this.destCatalog = checkDestinationCatalog(sourceCatalog); + this.destTableIdent = sourceTableIdent; + String backupName = sourceTableIdent.name() + BACKUP_SUFFIX; + this.backupIdent = Identifier.of(sourceTableIdent.namespace(), backupName); + } + + @Override + protected MigrateTable self() { + return this; + } + + @Override + protected StagingTableCatalog destCatalog() { + return destCatalog; + } + + @Override + protected Identifier destTableIdent() { + return destTableIdent; + } + + @Override + public MigrateTable tableProperties(Map properties) { + setProperties(properties); + return this; + } + + @Override + public MigrateTable tableProperty(String property, String value) { + setProperty(property, value); + return this; + } + + @Override + public MigrateTable.Result execute() { + String desc = String.format("Migrating table %s", destTableIdent().toString()); + JobGroupInfo info = newJobGroupInfo("MIGRATE-TABLE", desc); + return withJobGroupInfo(info, this::doExecute); + } + + private MigrateTable.Result doExecute() { + LOG.info("Starting the migration of {} to Iceberg", sourceTableIdent()); + + // move the source table to a new name, halting all modifications and allowing us to stage + // the creation of a new Iceberg table in its place + renameAndBackupSourceTable(); + + StagedSparkTable stagedTable = null; + Table icebergTable; + boolean threw = true; + try { + LOG.info("Staging a new Iceberg table {}", 
destTableIdent()); + stagedTable = stageDestTable(); + icebergTable = stagedTable.table(); + + LOG.info("Ensuring {} has a valid name mapping", destTableIdent()); + ensureNameMappingPresent(icebergTable); + + Some backupNamespace = Some.apply(backupIdent.namespace()[0]); + TableIdentifier v1BackupIdent = new TableIdentifier(backupIdent.name(), backupNamespace); + String stagingLocation = getMetadataLocation(icebergTable); + LOG.info("Generating Iceberg metadata for {} in {}", destTableIdent(), stagingLocation); + SparkTableUtil.importSparkTable(spark(), v1BackupIdent, icebergTable, stagingLocation); + + LOG.info("Committing staged changes to {}", destTableIdent()); + stagedTable.commitStagedChanges(); + threw = false; + } finally { + if (threw) { + LOG.error("Failed to perform the migration, aborting table creation and restoring the original table"); + + restoreSourceTable(); + + if (stagedTable != null) { + try { + stagedTable.abortStagedChanges(); + } catch (Exception abortException) { + LOG.error("Cannot abort staged changes", abortException); + } + } + } + } + + Snapshot snapshot = icebergTable.currentSnapshot(); + long migratedDataFilesCount = Long.parseLong(snapshot.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP)); + LOG.info("Successfully loaded Iceberg metadata for {} files to {}", migratedDataFilesCount, destTableIdent()); + return new BaseMigrateTableActionResult(migratedDataFilesCount); + } + + @Override + protected Map destTableProps() { + Map properties = Maps.newHashMap(); + + // copy over relevant source table props + properties.putAll(JavaConverters.mapAsJavaMapConverter(v1SourceTable().properties()).asJava()); + EXCLUDED_PROPERTIES.forEach(properties::remove); + + // set default and user-provided props + properties.put(TableCatalog.PROP_PROVIDER, "iceberg"); + properties.putAll(additionalProperties()); + + // make sure we mark this table as migrated + properties.put("migrated", "true"); + + // inherit the source table location + properties.putIfAbsent(LOCATION, sourceTableLocation()); + + return properties; + } + + @Override + protected TableCatalog checkSourceCatalog(CatalogPlugin catalog) { + // currently the import code relies on being able to look up the table in the session catalog + Preconditions.checkArgument(catalog instanceof SparkSessionCatalog, + "Cannot migrate a table from a non-Iceberg Spark Session Catalog. Found %s of class %s as the source catalog.", + catalog.name(), catalog.getClass().getName()); + + return (TableCatalog) catalog; + } + + private void renameAndBackupSourceTable() { + try { + LOG.info("Renaming {} as {} for backup", sourceTableIdent(), backupIdent); + destCatalog().renameTable(sourceTableIdent(), backupIdent); + + } catch (org.apache.spark.sql.catalyst.analysis.NoSuchTableException e) { + throw new NoSuchTableException("Cannot find source table %s", sourceTableIdent()); + + } catch (org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException e) { + throw new AlreadyExistsException( + "Cannot rename %s as %s for backup. 
The backup table already exists.", + sourceTableIdent(), backupIdent); + } + } + + private void restoreSourceTable() { + try { + LOG.info("Restoring {} from {}", sourceTableIdent(), backupIdent); + destCatalog().renameTable(backupIdent, sourceTableIdent()); + + } catch (org.apache.spark.sql.catalyst.analysis.NoSuchTableException e) { + LOG.error("Cannot restore the original table, the backup table {} cannot be found", backupIdent, e); + + } catch (org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException e) { + LOG.error("Cannot restore the original table, a table with the original name exists. " + + "Use the backup table {} to restore the original table manually.", backupIdent, e); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseRewriteDataFilesSparkAction.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseRewriteDataFilesSparkAction.java new file mode 100644 index 000000000000..7572cf6689cf --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseRewriteDataFilesSparkAction.java @@ -0,0 +1,484 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
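
A sketch of invoking the migrate action above; the identifier is hypothetical and must resolve through a session catalog configured as an Iceberg SparkSessionCatalog, as checkSourceCatalog requires.

import org.apache.iceberg.actions.MigrateTable;
import org.apache.iceberg.spark.actions.SparkActions;

public class MigrateTableExample {
  static void migrate() {
    // "spark_catalog.db.logs" is illustrative; the source must be a file-based Spark table.
    MigrateTable.Result result = SparkActions.get()
        .migrateTable("spark_catalog.db.logs")
        .tableProperty("migration.note", "migrated by platform team") // hypothetical extra table property
        .execute();
    System.out.println("Imported " + result.migratedDataFilesCount() + " data files");
  }
}
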
+ */ + +package org.apache.iceberg.spark.actions; + +import java.io.IOException; +import java.math.RoundingMode; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.RewriteJobOrder; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.BaseRewriteDataFilesFileGroupInfo; +import org.apache.iceberg.actions.BaseRewriteDataFilesResult; +import org.apache.iceberg.actions.BinPackStrategy; +import org.apache.iceberg.actions.RewriteDataFiles; +import org.apache.iceberg.actions.RewriteDataFilesCommitManager; +import org.apache.iceberg.actions.RewriteFileGroup; +import org.apache.iceberg.actions.RewriteStrategy; +import org.apache.iceberg.actions.SortStrategy; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.exceptions.CommitFailedException; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Queues; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.relocated.com.google.common.math.IntMath; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.iceberg.types.Types.StructType; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.StructLikeMap; +import org.apache.iceberg.util.Tasks; +import org.apache.spark.sql.SparkSession; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class BaseRewriteDataFilesSparkAction + extends BaseSnapshotUpdateSparkAction implements RewriteDataFiles { + + private static final Logger LOG = LoggerFactory.getLogger(BaseRewriteDataFilesSparkAction.class); + private static final Set VALID_OPTIONS = ImmutableSet.of( + MAX_CONCURRENT_FILE_GROUP_REWRITES, + MAX_FILE_GROUP_SIZE_BYTES, + PARTIAL_PROGRESS_ENABLED, + PARTIAL_PROGRESS_MAX_COMMITS, + TARGET_FILE_SIZE_BYTES, + USE_STARTING_SEQUENCE_NUMBER, + REWRITE_JOB_ORDER + ); + + private final Table table; + private final String fullIdentifier; + + private Expression filter = Expressions.alwaysTrue(); + private int maxConcurrentFileGroupRewrites; + private int maxCommits; + private boolean partialProgressEnabled; + private boolean useStartingSequenceNumber; + private RewriteJobOrder rewriteJobOrder; + private RewriteStrategy strategy = null; + 
+ @Deprecated + protected BaseRewriteDataFilesSparkAction(SparkSession spark, Table table) { + super(spark); + this.table = table; + this.fullIdentifier = null; + } + + protected BaseRewriteDataFilesSparkAction(SparkSession spark, Table table, String fullIdentifier) { + super(spark); + this.table = table; + this.fullIdentifier = fullIdentifier; + } + + @Override + protected RewriteDataFiles self() { + return this; + } + + @Override + public RewriteDataFiles binPack() { + Preconditions.checkArgument(this.strategy == null, + "Cannot set strategy to binpack, it has already been set", this.strategy); + this.strategy = binPackStrategy(); + return this; + } + + @Override + public RewriteDataFiles sort(SortOrder sortOrder) { + Preconditions.checkArgument(this.strategy == null, + "Cannot set strategy to sort, it has already been set to %s", this.strategy); + this.strategy = sortStrategy().sortOrder(sortOrder); + return this; + } + + @Override + public RewriteDataFiles sort() { + Preconditions.checkArgument(this.strategy == null, + "Cannot set strategy to sort, it has already been set to %s", this.strategy); + this.strategy = sortStrategy(); + return this; + } + + @Override + public RewriteDataFiles zOrder(String... columnNames) { + this.strategy = zOrderStrategy(columnNames); + return this; + } + + @Override + public RewriteDataFiles filter(Expression expression) { + filter = Expressions.and(filter, expression); + return this; + } + + @Override + public RewriteDataFiles.Result execute() { + if (table.currentSnapshot() == null) { + return new BaseRewriteDataFilesResult(ImmutableList.of()); + } + + long startingSnapshotId = table.currentSnapshot().snapshotId(); + + // Default to BinPack if no strategy selected + if (this.strategy == null) { + this.strategy = binPackStrategy(); + } + + validateAndInitOptions(); + + Map>> fileGroupsByPartition = planFileGroups(startingSnapshotId); + RewriteExecutionContext ctx = new RewriteExecutionContext(fileGroupsByPartition); + + if (ctx.totalGroupCount() == 0) { + LOG.info("Nothing found to rewrite in {}", table.name()); + return new BaseRewriteDataFilesResult(Collections.emptyList()); + } + + Stream groupStream = toGroupStream(ctx, fileGroupsByPartition); + + RewriteDataFilesCommitManager commitManager = commitManager(startingSnapshotId); + if (partialProgressEnabled) { + return doExecuteWithPartialProgress(ctx, groupStream, commitManager); + } else { + return doExecute(ctx, groupStream, commitManager); + } + } + + Map>> planFileGroups(long startingSnapshotId) { + CloseableIterable fileScanTasks = table.newScan() + .useSnapshot(startingSnapshotId) + .filter(filter) + .ignoreResiduals() + .planFiles(); + + try { + StructType partitionType = table.spec().partitionType(); + StructLikeMap> filesByPartition = StructLikeMap.create(partitionType); + StructLike emptyStruct = GenericRecord.create(partitionType); + + fileScanTasks.forEach(task -> { + // If a task uses an incompatible partition spec the data inside could contain values which + // belong to multiple partitions in the current spec. Treating all such files as un-partitioned and + // grouping them together helps to minimize new files made. + StructLike taskPartition = task.file().specId() == table.spec().specId() ? 
+ task.file().partition() : emptyStruct; + + List files = filesByPartition.get(taskPartition); + if (files == null) { + files = Lists.newArrayList(); + } + + files.add(task); + filesByPartition.put(taskPartition, files); + }); + + StructLikeMap>> fileGroupsByPartition = StructLikeMap.create(partitionType); + + filesByPartition.forEach((partition, tasks) -> { + Iterable filtered = strategy.selectFilesToRewrite(tasks); + Iterable> groupedTasks = strategy.planFileGroups(filtered); + List> fileGroups = ImmutableList.copyOf(groupedTasks); + if (fileGroups.size() > 0) { + fileGroupsByPartition.put(partition, fileGroups); + } + }); + + return fileGroupsByPartition; + } finally { + try { + fileScanTasks.close(); + } catch (IOException io) { + LOG.error("Cannot properly close file iterable while planning for rewrite", io); + } + } + } + + @VisibleForTesting + RewriteFileGroup rewriteFiles(RewriteExecutionContext ctx, RewriteFileGroup fileGroup) { + String desc = jobDesc(fileGroup, ctx); + Set addedFiles = withJobGroupInfo( + newJobGroupInfo("REWRITE-DATA-FILES", desc), + () -> strategy.rewriteFiles(fileGroup.fileScans())); + + fileGroup.setOutputFiles(addedFiles); + LOG.info("Rewrite Files Ready to be Committed - {}", desc); + return fileGroup; + } + + private ExecutorService rewriteService() { + return MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool( + maxConcurrentFileGroupRewrites, + new ThreadFactoryBuilder() + .setNameFormat("Rewrite-Service-%d") + .build())); + } + + @VisibleForTesting + RewriteDataFilesCommitManager commitManager(long startingSnapshotId) { + return new RewriteDataFilesCommitManager(table, startingSnapshotId, useStartingSequenceNumber); + } + + private Result doExecute(RewriteExecutionContext ctx, Stream groupStream, + RewriteDataFilesCommitManager commitManager) { + ExecutorService rewriteService = rewriteService(); + + ConcurrentLinkedQueue rewrittenGroups = Queues.newConcurrentLinkedQueue(); + + Tasks.Builder rewriteTaskBuilder = Tasks.foreach(groupStream) + .executeWith(rewriteService) + .stopOnFailure() + .noRetry() + .onFailure((fileGroup, exception) -> { + LOG.warn("Failure during rewrite process for group {}", fileGroup.info(), exception); + }); + + try { + rewriteTaskBuilder.run(fileGroup -> { + rewrittenGroups.add(rewriteFiles(ctx, fileGroup)); + }); + } catch (Exception e) { + // At least one rewrite group failed, clean up all completed rewrites + LOG.error("Cannot complete rewrite, {} is not enabled and one of the file set groups failed to " + + "be rewritten. This error occurred during the writing of new files, not during the commit process. This " + + "indicates something is wrong that doesn't involve conflicts with other Iceberg operations. Enabling " + + "{} may help in this case but the root cause should be investigated. Cleaning up {} groups which finished " + + "being written.", PARTIAL_PROGRESS_ENABLED, PARTIAL_PROGRESS_ENABLED, rewrittenGroups.size(), e); + + Tasks.foreach(rewrittenGroups) + .suppressFailureWhenFinished() + .run(group -> commitManager.abortFileGroup(group)); + throw e; + } finally { + rewriteService.shutdown(); + } + + try { + commitManager.commitOrClean(Sets.newHashSet(rewrittenGroups)); + } catch (ValidationException | CommitFailedException e) { + String errorMessage = String.format( + "Cannot commit rewrite because of a ValidationException or CommitFailedException. This usually means that " + + "this rewrite has conflicted with another concurrent Iceberg operation. 
To reduce the likelihood of " + + "conflicts, set %s which will break up the rewrite into multiple smaller commits controlled by %s. " + + "Separate smaller rewrite commits can succeed independently while any commits that conflict with " + + "another Iceberg operation will be ignored. This mode will create additional snapshots in the table " + + "history, one for each commit.", + PARTIAL_PROGRESS_ENABLED, PARTIAL_PROGRESS_MAX_COMMITS); + throw new RuntimeException(errorMessage, e); + } + + List rewriteResults = rewrittenGroups.stream() + .map(RewriteFileGroup::asResult) + .collect(Collectors.toList()); + return new BaseRewriteDataFilesResult(rewriteResults); + } + + private Result doExecuteWithPartialProgress(RewriteExecutionContext ctx, Stream groupStream, + RewriteDataFilesCommitManager commitManager) { + ExecutorService rewriteService = rewriteService(); + + // Start Commit Service + int groupsPerCommit = IntMath.divide(ctx.totalGroupCount(), maxCommits, RoundingMode.CEILING); + RewriteDataFilesCommitManager.CommitService commitService = commitManager.service(groupsPerCommit); + commitService.start(); + + // Start rewrite tasks + Tasks.foreach(groupStream) + .suppressFailureWhenFinished() + .executeWith(rewriteService) + .noRetry() + .onFailure((fileGroup, exception) -> LOG.error("Failure during rewrite group {}", fileGroup.info(), exception)) + .run(fileGroup -> commitService.offer(rewriteFiles(ctx, fileGroup))); + rewriteService.shutdown(); + + // Stop Commit service + commitService.close(); + List commitResults = commitService.results(); + if (commitResults.size() == 0) { + LOG.error("{} is true but no rewrite commits succeeded. Check the logs to determine why the individual " + + "commits failed. If this is persistent it may help to increase {} which will break the rewrite operation " + + "into smaller commits.", PARTIAL_PROGRESS_ENABLED, PARTIAL_PROGRESS_MAX_COMMITS); + } + + List rewriteResults = commitResults.stream() + .map(RewriteFileGroup::asResult) + .collect(Collectors.toList()); + return new BaseRewriteDataFilesResult(rewriteResults); + } + + Stream toGroupStream(RewriteExecutionContext ctx, + Map>> fileGroupsByPartition) { + Stream rewriteFileGroupStream = fileGroupsByPartition.entrySet().stream() + .flatMap(e -> { + StructLike partition = e.getKey(); + List> fileGroups = e.getValue(); + return fileGroups.stream().map(tasks -> { + int globalIndex = ctx.currentGlobalIndex(); + int partitionIndex = ctx.currentPartitionIndex(partition); + FileGroupInfo info = new BaseRewriteDataFilesFileGroupInfo(globalIndex, partitionIndex, partition); + return new RewriteFileGroup(info, tasks); + }); + }); + + return rewriteFileGroupStream.sorted(rewriteGroupComparator()); + } + + private Comparator rewriteGroupComparator() { + switch (rewriteJobOrder) { + case BYTES_ASC: + return Comparator.comparing(RewriteFileGroup::sizeInBytes); + case BYTES_DESC: + return Comparator.comparing(RewriteFileGroup::sizeInBytes, Comparator.reverseOrder()); + case FILES_ASC: + return Comparator.comparing(RewriteFileGroup::numFiles); + case FILES_DESC: + return Comparator.comparing(RewriteFileGroup::numFiles, Comparator.reverseOrder()); + default: + return (fileGroupOne, fileGroupTwo) -> 0; + } + } + + void validateAndInitOptions() { + Set validOptions = Sets.newHashSet(strategy.validOptions()); + validOptions.addAll(VALID_OPTIONS); + + Set invalidKeys = Sets.newHashSet(options().keySet()); + invalidKeys.removeAll(validOptions); + + Preconditions.checkArgument(invalidKeys.isEmpty(), + "Cannot use options %s, 
they are not supported by the action or the strategy %s", + invalidKeys, strategy.name()); + + strategy = strategy.options(options()); + + maxConcurrentFileGroupRewrites = PropertyUtil.propertyAsInt(options(), + MAX_CONCURRENT_FILE_GROUP_REWRITES, + MAX_CONCURRENT_FILE_GROUP_REWRITES_DEFAULT); + + maxCommits = PropertyUtil.propertyAsInt(options(), + PARTIAL_PROGRESS_MAX_COMMITS, + PARTIAL_PROGRESS_MAX_COMMITS_DEFAULT); + + partialProgressEnabled = PropertyUtil.propertyAsBoolean(options(), + PARTIAL_PROGRESS_ENABLED, + PARTIAL_PROGRESS_ENABLED_DEFAULT); + + useStartingSequenceNumber = PropertyUtil.propertyAsBoolean(options(), + USE_STARTING_SEQUENCE_NUMBER, + USE_STARTING_SEQUENCE_NUMBER_DEFAULT); + + rewriteJobOrder = RewriteJobOrder.fromName(PropertyUtil.propertyAsString(options(), + REWRITE_JOB_ORDER, + REWRITE_JOB_ORDER_DEFAULT)); + + Preconditions.checkArgument(maxConcurrentFileGroupRewrites >= 1, + "Cannot set %s to %s, the value must be positive.", + MAX_CONCURRENT_FILE_GROUP_REWRITES, maxConcurrentFileGroupRewrites); + + Preconditions.checkArgument(!partialProgressEnabled || maxCommits > 0, + "Cannot set %s to %s, the value must be positive when %s is true", + PARTIAL_PROGRESS_MAX_COMMITS, maxCommits, PARTIAL_PROGRESS_ENABLED); + } + + private String jobDesc(RewriteFileGroup group, RewriteExecutionContext ctx) { + StructLike partition = group.info().partition(); + if (partition.size() > 0) { + return String.format("Rewriting %d files (%s, file group %d/%d, %s (%d/%d)) in %s", + group.rewrittenFiles().size(), + strategy.name(), group.info().globalIndex(), + ctx.totalGroupCount(), partition, group.info().partitionIndex(), ctx.groupsInPartition(partition), + table.name()); + } else { + return String.format("Rewriting %d files (%s, file group %d/%d) in %s", + group.rewrittenFiles().size(), + strategy.name(), group.info().globalIndex(), ctx.totalGroupCount(), + table.name()); + } + } + + private BinPackStrategy binPackStrategy() { + return new SparkBinPackStrategy(table, fullIdentifier, spark()); + } + + private SortStrategy sortStrategy() { + return new SparkSortStrategy(table, spark()); + } + + private SortStrategy zOrderStrategy(String... 
columnNames) { + return new SparkZOrderStrategy(table, spark(), Lists.newArrayList(columnNames)); + } + + @VisibleForTesting + static class RewriteExecutionContext { + private final Map numGroupsByPartition; + private final int totalGroupCount; + private final Map partitionIndexMap; + private final AtomicInteger groupIndex; + + RewriteExecutionContext(Map>> fileGroupsByPartition) { + this.numGroupsByPartition = fileGroupsByPartition.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().size())); + this.totalGroupCount = numGroupsByPartition.values().stream() + .reduce(Integer::sum) + .orElse(0); + this.partitionIndexMap = Maps.newConcurrentMap(); + this.groupIndex = new AtomicInteger(1); + } + + public int currentGlobalIndex() { + return groupIndex.getAndIncrement(); + } + + public int currentPartitionIndex(StructLike partition) { + return partitionIndexMap.merge(partition, 1, Integer::sum); + } + + public int groupsInPartition(StructLike partition) { + return numGroupsByPartition.get(partition); + } + + public int totalGroupCount() { + return totalGroupCount; + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseRewriteManifestsSparkAction.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseRewriteManifestsSparkAction.java new file mode 100644 index 000000000000..1648dd26f85a --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseRewriteManifestsSparkAction.java @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.actions; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.UUID; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.ManifestWriter; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.BaseRewriteManifestsActionResult; +import org.apache.iceberg.actions.RewriteManifests; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.spark.SparkDataFile; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.Tasks; +import org.apache.iceberg.util.ThreadPools; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.api.java.function.MapPartitionsFunction; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.iceberg.MetadataTableType.ENTRIES; + +/** + * An action that rewrites manifests in a distributed manner and co-locates metadata for partitions. + *
<p>
+ * By default, this action rewrites all manifests for the current partition spec and writes the result + * to the metadata folder. The behavior can be modified by passing a custom predicate to {@link #rewriteIf(Predicate)} + * and a custom spec id to {@link #specId(int)}. In addition, there is a way to configure a custom location + * for new manifests via {@link #stagingLocation}. + */ +public class BaseRewriteManifestsSparkAction + extends BaseSnapshotUpdateSparkAction + implements RewriteManifests { + + private static final Logger LOG = LoggerFactory.getLogger(BaseRewriteManifestsSparkAction.class); + + private static final String USE_CACHING = "use-caching"; + private static final boolean USE_CACHING_DEFAULT = true; + + private final Encoder manifestEncoder; + private final Table table; + private final int formatVersion; + private final FileIO fileIO; + private final long targetManifestSizeBytes; + + private PartitionSpec spec = null; + private Predicate predicate = manifest -> true; + private String stagingLocation = null; + + public BaseRewriteManifestsSparkAction(SparkSession spark, Table table) { + super(spark); + this.manifestEncoder = Encoders.javaSerialization(ManifestFile.class); + this.table = table; + this.spec = table.spec(); + this.targetManifestSizeBytes = PropertyUtil.propertyAsLong( + table.properties(), + TableProperties.MANIFEST_TARGET_SIZE_BYTES, + TableProperties.MANIFEST_TARGET_SIZE_BYTES_DEFAULT); + this.fileIO = SparkUtil.serializableFileIO(table); + + // default the staging location to the metadata location + TableOperations ops = ((HasTableOperations) table).operations(); + Path metadataFilePath = new Path(ops.metadataFileLocation("file")); + this.stagingLocation = metadataFilePath.getParent().toString(); + + // use the current table format version for new manifests + this.formatVersion = ops.current().formatVersion(); + } + + @Override + protected RewriteManifests self() { + return this; + } + + @Override + public RewriteManifests specId(int specId) { + Preconditions.checkArgument(table.specs().containsKey(specId), "Invalid spec id %s", specId); + this.spec = table.specs().get(specId); + return this; + } + + @Override + public RewriteManifests rewriteIf(Predicate newPredicate) { + this.predicate = newPredicate; + return this; + } + + @Override + public RewriteManifests stagingLocation(String newStagingLocation) { + this.stagingLocation = newStagingLocation; + return this; + } + + @Override + public RewriteManifests.Result execute() { + String desc = String.format("Rewriting manifests (staging location=%s) of %s", stagingLocation, table.name()); + JobGroupInfo info = newJobGroupInfo("REWRITE-MANIFESTS", desc); + return withJobGroupInfo(info, this::doExecute); + } + + private RewriteManifests.Result doExecute() { + List matchingManifests = findMatchingManifests(); + if (matchingManifests.isEmpty()) { + return BaseRewriteManifestsActionResult.empty(); + } + + long totalSizeBytes = 0L; + int numEntries = 0; + + for (ManifestFile manifest : matchingManifests) { + ValidationException.check(hasFileCounts(manifest), "No file counts in manifest: %s", manifest.path()); + + totalSizeBytes += manifest.length(); + numEntries += manifest.addedFilesCount() + manifest.existingFilesCount() + manifest.deletedFilesCount(); + } + + int targetNumManifests = targetNumManifests(totalSizeBytes); + int targetNumManifestEntries = targetNumManifestEntries(numEntries, targetNumManifests); + + Dataset manifestEntryDF = buildManifestEntryDF(matchingManifests); + + List newManifests; + if 
(spec.fields().size() < 1) { + newManifests = writeManifestsForUnpartitionedTable(manifestEntryDF, targetNumManifests); + } else { + newManifests = writeManifestsForPartitionedTable(manifestEntryDF, targetNumManifests, targetNumManifestEntries); + } + + replaceManifests(matchingManifests, newManifests); + + return new BaseRewriteManifestsActionResult(matchingManifests, newManifests); + } + + private Dataset buildManifestEntryDF(List manifests) { + Dataset manifestDF = spark() + .createDataset(Lists.transform(manifests, ManifestFile::path), Encoders.STRING()) + .toDF("manifest"); + + Dataset manifestEntryDF = loadMetadataTable(table, ENTRIES) + .filter("status < 2") // select only live entries + .selectExpr("input_file_name() as manifest", "snapshot_id", "sequence_number", "data_file"); + + Column joinCond = manifestDF.col("manifest").equalTo(manifestEntryDF.col("manifest")); + return manifestEntryDF + .join(manifestDF, joinCond, "left_semi") + .select("snapshot_id", "sequence_number", "data_file"); + } + + private List writeManifestsForUnpartitionedTable(Dataset manifestEntryDF, int numManifests) { + Broadcast io = sparkContext().broadcast(fileIO); + StructType sparkType = (StructType) manifestEntryDF.schema().apply("data_file").dataType(); + + // we rely only on the target number of manifests for unpartitioned tables + // as we should not worry about having too much metadata per partition + long maxNumManifestEntries = Long.MAX_VALUE; + + return manifestEntryDF + .repartition(numManifests) + .mapPartitions( + toManifests(io, maxNumManifestEntries, stagingLocation, formatVersion, spec, sparkType), + manifestEncoder + ) + .collectAsList(); + } + + private List writeManifestsForPartitionedTable( + Dataset manifestEntryDF, int numManifests, + int targetNumManifestEntries) { + + Broadcast io = sparkContext().broadcast(fileIO); + StructType sparkType = (StructType) manifestEntryDF.schema().apply("data_file").dataType(); + + // we allow the actual size of manifests to be 10% higher if the estimation is not precise enough + long maxNumManifestEntries = (long) (1.1 * targetNumManifestEntries); + + return withReusableDS(manifestEntryDF, df -> { + Column partitionColumn = df.col("data_file.partition"); + return df.repartitionByRange(numManifests, partitionColumn) + .sortWithinPartitions(partitionColumn) + .mapPartitions( + toManifests(io, maxNumManifestEntries, stagingLocation, formatVersion, spec, sparkType), + manifestEncoder + ) + .collectAsList(); + }); + } + + private U withReusableDS(Dataset ds, Function, U> func) { + Dataset reusableDS; + boolean useCaching = PropertyUtil.propertyAsBoolean(options(), USE_CACHING, USE_CACHING_DEFAULT); + if (useCaching) { + reusableDS = ds.cache(); + } else { + int parallelism = SQLConf.get().numShufflePartitions(); + reusableDS = ds.repartition(parallelism).map((MapFunction) value -> value, ds.exprEnc()); + } + + try { + return func.apply(reusableDS); + } finally { + if (useCaching) { + reusableDS.unpersist(false); + } + } + } + + private List findMatchingManifests() { + Snapshot currentSnapshot = table.currentSnapshot(); + + if (currentSnapshot == null) { + return ImmutableList.of(); + } + + return currentSnapshot.dataManifests(fileIO).stream() + .filter(manifest -> manifest.partitionSpecId() == spec.specId() && predicate.test(manifest)) + .collect(Collectors.toList()); + } + + private int targetNumManifests(long totalSizeBytes) { + return (int) ((totalSizeBytes + targetManifestSizeBytes - 1) / targetManifestSizeBytes); + } + + private int 
targetNumManifestEntries(int numEntries, int numManifests) { + return (numEntries + numManifests - 1) / numManifests; + } + + private boolean hasFileCounts(ManifestFile manifest) { + return manifest.addedFilesCount() != null && + manifest.existingFilesCount() != null && + manifest.deletedFilesCount() != null; + } + + private void replaceManifests(Iterable deletedManifests, Iterable addedManifests) { + try { + boolean snapshotIdInheritanceEnabled = PropertyUtil.propertyAsBoolean( + table.properties(), + TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, + TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED_DEFAULT); + + org.apache.iceberg.RewriteManifests rewriteManifests = table.rewriteManifests(); + deletedManifests.forEach(rewriteManifests::deleteManifest); + addedManifests.forEach(rewriteManifests::addManifest); + commit(rewriteManifests); + + if (!snapshotIdInheritanceEnabled) { + // delete new manifests as they were rewritten before the commit + deleteFiles(Iterables.transform(addedManifests, ManifestFile::path)); + } + } catch (CommitStateUnknownException commitStateUnknownException) { + // don't clean up added manifest files, because they may have been successfully committed. + throw commitStateUnknownException; + } catch (Exception e) { + // delete all new manifests because the rewrite failed + deleteFiles(Iterables.transform(addedManifests, ManifestFile::path)); + throw e; + } + } + + private void deleteFiles(Iterable locations) { + Tasks.foreach(locations) + .executeWith(ThreadPools.getWorkerPool()) + .noRetry() + .suppressFailureWhenFinished() + .onFailure((location, exc) -> LOG.warn("Failed to delete: {}", location, exc)) + .run(fileIO::deleteFile); + } + + private static ManifestFile writeManifest( + List rows, int startIndex, int endIndex, Broadcast io, + String location, int format, PartitionSpec spec, StructType sparkType) throws IOException { + + String manifestName = "optimized-m-" + UUID.randomUUID(); + Path manifestPath = new Path(location, manifestName); + OutputFile outputFile = io.value().newOutputFile(FileFormat.AVRO.addExtension(manifestPath.toString())); + + Types.StructType dataFileType = DataFile.getType(spec.partitionType()); + SparkDataFile wrapper = new SparkDataFile(dataFileType, sparkType); + + ManifestWriter writer = ManifestFiles.write(format, spec, outputFile, null); + + try { + for (int index = startIndex; index < endIndex; index++) { + Row row = rows.get(index); + long snapshotId = row.getLong(0); + long sequenceNumber = row.getLong(1); + Row file = row.getStruct(2); + writer.existing(wrapper.wrap(file), snapshotId, sequenceNumber); + } + } finally { + writer.close(); + } + + return writer.toManifestFile(); + } + + private static MapPartitionsFunction toManifests( + Broadcast io, long maxNumManifestEntries, String location, + int format, PartitionSpec spec, StructType sparkType) { + + return rows -> { + List rowsAsList = Lists.newArrayList(rows); + + if (rowsAsList.isEmpty()) { + return Collections.emptyIterator(); + } + + List manifests = Lists.newArrayList(); + if (rowsAsList.size() <= maxNumManifestEntries) { + manifests.add(writeManifest(rowsAsList, 0, rowsAsList.size(), io, location, format, spec, sparkType)); + } else { + int midIndex = rowsAsList.size() / 2; + manifests.add(writeManifest(rowsAsList, 0, midIndex, io, location, format, spec, sparkType)); + manifests.add(writeManifest(rowsAsList, midIndex, rowsAsList.size(), io, location, format, spec, sparkType)); + } + + return manifests.iterator(); + }; + } +} diff --git 
a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSnapshotTableSparkAction.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSnapshotTableSparkAction.java new file mode 100644 index 000000000000..e584d33a3f95 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSnapshotTableSparkAction.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.actions; + +import java.util.Map; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.BaseSnapshotTableActionResult; +import org.apache.iceberg.actions.SnapshotTable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.Spark3Util.CatalogAndIdentifier; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.source.StagedSparkTable; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.StagingTableCatalog; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.collection.JavaConverters; + +/** + * Creates a new Iceberg table based on a source Spark table. The new Iceberg table will + * have a different data and metadata directory allowing it to exist independently of the + * source table. 
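For context, a minimal sketch of how this snapshot action is typically driven through the SparkActions entry point added elsewhere in this patch; the table identifiers, the existing `spark` session, and the extra table property are illustrative assumptions:

  // Hypothetical identifiers; the source table must live in the session catalog (spark_catalog).
  SnapshotTable.Result result = SparkActions.get(spark)
      .snapshotTable("spark_catalog.db.source_table")
      .as("db.iceberg_snapshot")                          // destination catalog/identifier
      .tableProperty("write.format.default", "parquet")   // optional extra table property
      .execute();
  long importedDataFiles = result.importedDataFilesCount();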
+ */ +public class BaseSnapshotTableSparkAction + extends BaseTableCreationSparkAction + implements SnapshotTable { + + private static final Logger LOG = LoggerFactory.getLogger(BaseSnapshotTableSparkAction.class); + + private StagingTableCatalog destCatalog; + private Identifier destTableIdent; + private String destTableLocation = null; + + BaseSnapshotTableSparkAction(SparkSession spark, CatalogPlugin sourceCatalog, Identifier sourceTableIdent) { + super(spark, sourceCatalog, sourceTableIdent); + } + + @Override + protected SnapshotTable self() { + return this; + } + + @Override + protected StagingTableCatalog destCatalog() { + return destCatalog; + } + + @Override + protected Identifier destTableIdent() { + return destTableIdent; + } + + @Override + public SnapshotTable as(String ident) { + String ctx = "snapshot destination"; + CatalogPlugin defaultCatalog = spark().sessionState().catalogManager().currentCatalog(); + CatalogAndIdentifier catalogAndIdent = Spark3Util.catalogAndIdentifier(ctx, spark(), ident, defaultCatalog); + this.destCatalog = checkDestinationCatalog(catalogAndIdent.catalog()); + this.destTableIdent = catalogAndIdent.identifier(); + return this; + } + + @Override + public SnapshotTable tableProperties(Map properties) { + setProperties(properties); + return this; + } + + @Override + public SnapshotTable tableProperty(String property, String value) { + setProperty(property, value); + return this; + } + + @Override + public SnapshotTable.Result execute() { + String desc = String.format("Snapshotting table %s as %s", sourceTableIdent(), destTableIdent); + JobGroupInfo info = newJobGroupInfo("SNAPSHOT-TABLE", desc); + return withJobGroupInfo(info, this::doExecute); + } + + private SnapshotTable.Result doExecute() { + Preconditions.checkArgument(destCatalog() != null && destTableIdent() != null, + "The destination catalog and identifier cannot be null. 
" + + "Make sure to configure the action with a valid destination table identifier via the `as` method."); + + LOG.info("Staging a new Iceberg table {} as a snapshot of {}", destTableIdent(), sourceTableIdent()); + StagedSparkTable stagedTable = stageDestTable(); + Table icebergTable = stagedTable.table(); + + // TODO: Check the dest table location does not overlap with the source table location + + boolean threw = true; + try { + LOG.info("Ensuring {} has a valid name mapping", destTableIdent()); + ensureNameMappingPresent(icebergTable); + + TableIdentifier v1TableIdent = v1SourceTable().identifier(); + String stagingLocation = getMetadataLocation(icebergTable); + LOG.info("Generating Iceberg metadata for {} in {}", destTableIdent(), stagingLocation); + SparkTableUtil.importSparkTable(spark(), v1TableIdent, icebergTable, stagingLocation); + + LOG.info("Committing staged changes to {}", destTableIdent()); + stagedTable.commitStagedChanges(); + threw = false; + } finally { + if (threw) { + LOG.error("Error when populating the staged table with metadata, aborting changes"); + + try { + stagedTable.abortStagedChanges(); + } catch (Exception abortException) { + LOG.error("Cannot abort staged changes", abortException); + } + } + } + + Snapshot snapshot = icebergTable.currentSnapshot(); + long importedDataFilesCount = Long.parseLong(snapshot.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP)); + LOG.info("Successfully loaded Iceberg metadata for {} files to {}", importedDataFilesCount, destTableIdent()); + return new BaseSnapshotTableActionResult(importedDataFilesCount); + } + + @Override + protected Map destTableProps() { + Map properties = Maps.newHashMap(); + + // copy over relevant source table props + properties.putAll(JavaConverters.mapAsJavaMapConverter(v1SourceTable().properties()).asJava()); + EXCLUDED_PROPERTIES.forEach(properties::remove); + + // remove any possible location properties from origin properties + properties.remove(LOCATION); + properties.remove(TableProperties.WRITE_METADATA_LOCATION); + properties.remove(TableProperties.WRITE_FOLDER_STORAGE_LOCATION); + properties.remove(TableProperties.OBJECT_STORE_PATH); + properties.remove(TableProperties.WRITE_DATA_LOCATION); + + // set default and user-provided props + properties.put(TableCatalog.PROP_PROVIDER, "iceberg"); + properties.putAll(additionalProperties()); + + // make sure we mark this table as a snapshot table + properties.put(TableProperties.GC_ENABLED, "false"); + properties.put("snapshot", "true"); + + // set the destination table location if provided + if (destTableLocation != null) { + properties.put(LOCATION, destTableLocation); + } + + return properties; + } + + @Override + protected TableCatalog checkSourceCatalog(CatalogPlugin catalog) { + // currently the import code relies on being able to look up the table in the session catalog + Preconditions.checkArgument(catalog.name().equalsIgnoreCase("spark_catalog"), + "Cannot snapshot a table that isn't in the session catalog (i.e. spark_catalog). " + + "Found source catalog: %s.", catalog.name()); + + Preconditions.checkArgument(catalog instanceof TableCatalog, + "Cannot snapshot as catalog %s of class %s in not a table catalog", + catalog.name(), catalog.getClass().getName()); + + return (TableCatalog) catalog; + } + + @Override + public SnapshotTable tableLocation(String location) { + Preconditions.checkArgument(!sourceTableLocation().equals(location), + "The snapshot table location cannot be same as the source table location. 
" + + "This would mix snapshot table files with original table files."); + this.destTableLocation = location; + return this; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSnapshotUpdateSparkAction.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSnapshotUpdateSparkAction.java new file mode 100644 index 000000000000..53fa06bbb5dc --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSnapshotUpdateSparkAction.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.actions; + +import java.util.Map; +import org.apache.iceberg.actions.SnapshotUpdate; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.spark.sql.SparkSession; + +abstract class BaseSnapshotUpdateSparkAction + extends BaseSparkAction implements SnapshotUpdate { + + private final Map summary = Maps.newHashMap(); + + protected BaseSnapshotUpdateSparkAction(SparkSession spark) { + super(spark); + } + + @Override + public ThisT snapshotProperty(String property, String value) { + summary.put(property, value); + return self(); + } + + protected void commit(org.apache.iceberg.SnapshotUpdate update) { + summary.forEach(update::set); + update.commit(); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSparkAction.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSparkAction.java new file mode 100644 index 000000000000..99a026abf5f0 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSparkAction.java @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.actions; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.ReachableFileUtil; +import org.apache.iceberg.SerializableTable; +import org.apache.iceberg.StaticTableOperations; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.actions.Action; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.ClosingIterator; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.spark.JobGroupUtils; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import scala.Tuple2; + +import static org.apache.iceberg.MetadataTableType.ALL_MANIFESTS; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.lit; + +abstract class BaseSparkAction implements Action { + + protected static final String CONTENT_FILE = "Content File"; + protected static final String MANIFEST = "Manifest"; + protected static final String MANIFEST_LIST = "Manifest List"; + protected static final String OTHERS = "Others"; + + protected static final String FILE_PATH = "file_path"; + protected static final String FILE_TYPE = "file_type"; + protected static final String LAST_MODIFIED = "last_modified"; + + private static final AtomicInteger JOB_COUNTER = new AtomicInteger(); + + private final SparkSession spark; + private final JavaSparkContext sparkContext; + private final Map options = Maps.newHashMap(); + + protected BaseSparkAction(SparkSession spark) { + this.spark = spark; + this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + } + + protected SparkSession spark() { + return spark; + } + + protected JavaSparkContext sparkContext() { + return sparkContext; + } + + protected abstract ThisT self(); + + @Override + public ThisT option(String name, String value) { + options.put(name, value); + return self(); + } + + @Override + public ThisT options(Map newOptions) { + options.putAll(newOptions); + return self(); + } + + protected Map options() { + return options; + } + + protected T withJobGroupInfo(JobGroupInfo info, Supplier supplier) { + SparkContext context = spark().sparkContext(); + JobGroupInfo previousInfo = JobGroupUtils.getJobGroupInfo(context); + try { + JobGroupUtils.setJobGroupInfo(context, info); + return supplier.get(); + } finally { + JobGroupUtils.setJobGroupInfo(context, previousInfo); + } + } + + protected JobGroupInfo newJobGroupInfo(String groupId, String desc) { + return new JobGroupInfo(groupId + "-" + JOB_COUNTER.incrementAndGet(), desc, false); + } + 
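The two helpers above give the job-group pattern that the concrete actions in this patch follow; a condensed, hypothetical sketch of that pattern inside an action's execute() (the group id, description, and `table` field are illustrative):

  public ExpireSnapshots.Result execute() {
    // Group and describe all Spark jobs launched by doExecute() in the Spark UI.
    JobGroupInfo info = newJobGroupInfo("EXPIRE-SNAPSHOTS", "Expiring snapshots of " + table.name());
    return withJobGroupInfo(info, this::doExecute);   // the previous job group is restored afterwards
  }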
+ protected Table newStaticTable(TableMetadata metadata, FileIO io) { + String metadataFileLocation = metadata.metadataFileLocation(); + StaticTableOperations ops = new StaticTableOperations(metadataFileLocation, io); + return new BaseTable(ops, metadataFileLocation); + } + + // builds a DF of delete and data file path and type by reading all manifests + protected Dataset buildValidContentFileWithTypeDF(Table table) { + Broadcast tableBroadcast = sparkContext.broadcast(SerializableTable.copyOf(table)); + + Dataset allManifests = loadMetadataTable(table, ALL_MANIFESTS) + .selectExpr( + "content", + "path", + "length", + "partition_spec_id as partitionSpecId", + "added_snapshot_id as addedSnapshotId") + .dropDuplicates("path") + .repartition(spark.sessionState().conf().numShufflePartitions()) // avoid adaptive execution combining tasks + .as(Encoders.bean(ManifestFileBean.class)); + + return allManifests + .flatMap(new ReadManifest(tableBroadcast), Encoders.tuple(Encoders.STRING(), Encoders.STRING())) + .toDF(FILE_PATH, FILE_TYPE); + } + + // builds a DF of delete and data file paths by reading all manifests + protected Dataset buildValidContentFileDF(Table table) { + return buildValidContentFileWithTypeDF(table).select(FILE_PATH); + } + + protected Dataset buildManifestFileDF(Table table) { + return loadMetadataTable(table, ALL_MANIFESTS).select(col("path").as(FILE_PATH)); + } + + protected Dataset buildManifestListDF(Table table) { + List manifestLists = ReachableFileUtil.manifestListLocations(table); + return spark.createDataset(manifestLists, Encoders.STRING()).toDF(FILE_PATH); + } + + protected Dataset buildOtherMetadataFileDF(Table table) { + return buildOtherMetadataFileDF(table, false /* include all reachable previous metadata locations */); + } + + protected Dataset buildAllReachableOtherMetadataFileDF(Table table) { + return buildOtherMetadataFileDF(table, true /* include all reachable previous metadata locations */); + } + + private Dataset buildOtherMetadataFileDF(Table table, boolean includePreviousMetadataLocations) { + List otherMetadataFiles = Lists.newArrayList(); + otherMetadataFiles.addAll(ReachableFileUtil.metadataFileLocations(table, includePreviousMetadataLocations)); + otherMetadataFiles.add(ReachableFileUtil.versionHintLocation(table)); + return spark.createDataset(otherMetadataFiles, Encoders.STRING()).toDF(FILE_PATH); + } + + protected Dataset buildValidMetadataFileDF(Table table) { + Dataset manifestDF = buildManifestFileDF(table); + Dataset manifestListDF = buildManifestListDF(table); + Dataset otherMetadataFileDF = buildOtherMetadataFileDF(table); + + return manifestDF.union(otherMetadataFileDF).union(manifestListDF); + } + + protected Dataset withFileType(Dataset ds, String type) { + return ds.withColumn(FILE_TYPE, lit(type)); + } + + protected Dataset loadMetadataTable(Table table, MetadataTableType type) { + return SparkTableUtil.loadMetadataTable(spark, table, type); + } + + private static class ReadManifest implements FlatMapFunction> { + private final Broadcast
<Table> table; + + ReadManifest(Broadcast<Table>
table) { + this.table = table; + } + + @Override + public Iterator> call(ManifestFileBean manifest) { + return new ClosingIterator<>(entries(manifest)); + } + + public CloseableIterator> entries(ManifestFileBean manifest) { + FileIO io = table.getValue().io(); + Map specs = table.getValue().specs(); + ImmutableList projection = ImmutableList.of(DataFile.FILE_PATH.name(), DataFile.CONTENT.name()); + + switch (manifest.content()) { + case DATA: + return CloseableIterator.transform( + ManifestFiles.read(manifest, io, specs).select(projection).iterator(), + ReadManifest::contentFileWithType); + case DELETES: + return CloseableIterator.transform( + ManifestFiles.readDeleteManifest(manifest, io, specs).select(projection).iterator(), + ReadManifest::contentFileWithType); + default: + throw new IllegalArgumentException("Unsupported manifest content type:" + manifest.content()); + } + } + + static Tuple2 contentFileWithType(ContentFile file) { + return new Tuple2<>(file.path().toString(), file.content().toString()); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseTableCreationSparkAction.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseTableCreationSparkAction.java new file mode 100644 index 000000000000..6eadece65cb6 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/BaseTableCreationSparkAction.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.actions; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.exceptions.NoSuchNamespaceException; +import org.apache.iceberg.exceptions.NoSuchTableException; +import org.apache.iceberg.mapping.MappingUtil; +import org.apache.iceberg.mapping.NameMapping; +import org.apache.iceberg.mapping.NameMappingParser; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.source.StagedSparkTable; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.catalog.CatalogTable; +import org.apache.spark.sql.catalyst.catalog.CatalogUtils; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.StagingTableCatalog; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.catalog.V1Table; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.StructType; + +abstract class BaseTableCreationSparkAction extends BaseSparkAction { + private static final Set ALLOWED_SOURCES = ImmutableSet.of("parquet", "avro", "orc", "hive"); + protected static final String LOCATION = "location"; + protected static final String ICEBERG_METADATA_FOLDER = "metadata"; + protected static final List EXCLUDED_PROPERTIES = + ImmutableList.of("path", "transient_lastDdlTime", "serialization.format"); + + // Source Fields + private final V1Table sourceTable; + private final CatalogTable sourceCatalogTable; + private final String sourceTableLocation; + private final TableCatalog sourceCatalog; + private final Identifier sourceTableIdent; + + // Optional Parameters for destination + private final Map additionalProperties = Maps.newHashMap(); + + BaseTableCreationSparkAction(SparkSession spark, CatalogPlugin sourceCatalog, Identifier sourceTableIdent) { + super(spark); + + this.sourceCatalog = checkSourceCatalog(sourceCatalog); + this.sourceTableIdent = sourceTableIdent; + + try { + this.sourceTable = (V1Table) this.sourceCatalog.loadTable(sourceTableIdent); + this.sourceCatalogTable = sourceTable.v1Table(); + } catch (org.apache.spark.sql.catalyst.analysis.NoSuchTableException e) { + throw new NoSuchTableException("Cannot not find source table '%s'", sourceTableIdent); + } catch (ClassCastException e) { + throw new IllegalArgumentException(String.format("Cannot use non-v1 table '%s' as a source", sourceTableIdent), + e); + } + validateSourceTable(); + + this.sourceTableLocation = CatalogUtils.URIToString(sourceCatalogTable.storage().locationUri().get()); + } + + protected abstract TableCatalog checkSourceCatalog(CatalogPlugin catalog); + + protected abstract StagingTableCatalog destCatalog(); + + protected abstract Identifier destTableIdent(); + + protected abstract Map destTableProps(); + + protected String sourceTableLocation() { + return sourceTableLocation; + } + + protected CatalogTable v1SourceTable() 
{ + return sourceCatalogTable; + } + + protected TableCatalog sourceCatalog() { + return sourceCatalog; + } + + protected Identifier sourceTableIdent() { + return sourceTableIdent; + } + + protected void setProperties(Map properties) { + additionalProperties.putAll(properties); + } + + protected void setProperty(String key, String value) { + additionalProperties.put(key, value); + } + + protected Map additionalProperties() { + return additionalProperties; + } + + private void validateSourceTable() { + String sourceTableProvider = sourceCatalogTable.provider().get().toLowerCase(Locale.ROOT); + Preconditions.checkArgument(ALLOWED_SOURCES.contains(sourceTableProvider), + "Cannot create an Iceberg table from source provider: '%s'", sourceTableProvider); + Preconditions.checkArgument(!sourceCatalogTable.storage().locationUri().isEmpty(), + "Cannot create an Iceberg table from a source without an explicit location"); + } + + protected StagingTableCatalog checkDestinationCatalog(CatalogPlugin catalog) { + Preconditions.checkArgument(catalog instanceof SparkSessionCatalog || catalog instanceof SparkCatalog, + "Cannot create Iceberg table in non-Iceberg Catalog. " + + "Catalog '%s' was of class '%s' but '%s' or '%s' are required", + catalog.name(), catalog.getClass().getName(), SparkSessionCatalog.class.getName(), + SparkCatalog.class.getName()); + + return (StagingTableCatalog) catalog; + } + + protected StagedSparkTable stageDestTable() { + try { + Map props = destTableProps(); + StructType schema = sourceTable.schema(); + Transform[] partitioning = sourceTable.partitioning(); + return (StagedSparkTable) destCatalog().stageCreate(destTableIdent(), schema, partitioning, props); + } catch (org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException e) { + throw new NoSuchNamespaceException("Cannot create table %s as the namespace does not exist", destTableIdent()); + } catch (org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException e) { + throw new AlreadyExistsException("Cannot create table %s as it already exists", destTableIdent()); + } + } + + protected void ensureNameMappingPresent(Table table) { + if (!table.properties().containsKey(TableProperties.DEFAULT_NAME_MAPPING)) { + NameMapping nameMapping = MappingUtil.create(table.schema()); + String nameMappingJson = NameMappingParser.toJson(nameMapping); + table.updateProperties().set(TableProperties.DEFAULT_NAME_MAPPING, nameMappingJson).commit(); + } + } + + protected String getMetadataLocation(Table table) { + return table.properties().getOrDefault(TableProperties.WRITE_METADATA_LOCATION, + table.location() + "/" + ICEBERG_METADATA_FOLDER); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/ManifestFileBean.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/ManifestFileBean.java new file mode 100644 index 000000000000..269130496dc9 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/ManifestFileBean.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.actions; + +import java.nio.ByteBuffer; +import java.util.List; +import org.apache.iceberg.ManifestContent; +import org.apache.iceberg.ManifestFile; + +public class ManifestFileBean implements ManifestFile { + private String path = null; + private Long length = null; + private Integer partitionSpecId = null; + private Long addedSnapshotId = null; + private Integer content = null; + + public String getPath() { + return path; + } + + public void setPath(String path) { + this.path = path; + } + + public Long getLength() { + return length; + } + + public void setLength(Long length) { + this.length = length; + } + + public Integer getPartitionSpecId() { + return partitionSpecId; + } + + public void setPartitionSpecId(Integer partitionSpecId) { + this.partitionSpecId = partitionSpecId; + } + + public Long getAddedSnapshotId() { + return addedSnapshotId; + } + + public void setAddedSnapshotId(Long addedSnapshotId) { + this.addedSnapshotId = addedSnapshotId; + } + + public Integer getContent() { + return content; + } + + public void setContent(Integer content) { + this.content = content; + } + + @Override + public String path() { + return path; + } + + @Override + public long length() { + return length; + } + + @Override + public int partitionSpecId() { + return partitionSpecId; + } + + @Override + public ManifestContent content() { + return ManifestContent.fromId(content); + } + + @Override + public long sequenceNumber() { + return 0; + } + + @Override + public long minSequenceNumber() { + return 0; + } + + @Override + public Long snapshotId() { + return addedSnapshotId; + } + + @Override + public Integer addedFilesCount() { + return null; + } + + @Override + public Long addedRowsCount() { + return null; + } + + @Override + public Integer existingFilesCount() { + return null; + } + + @Override + public Long existingRowsCount() { + return null; + } + + @Override + public Integer deletedFilesCount() { + return null; + } + + @Override + public Long deletedRowsCount() { + return null; + } + + @Override + public List partitions() { + return null; + } + + @Override + public ByteBuffer keyMetadata() { + return null; + } + + @Override + public ManifestFile copy() { + throw new UnsupportedOperationException("Cannot copy"); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkActions.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkActions.java new file mode 100644 index 000000000000..4f6bb3f00d9f --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkActions.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.actions; + +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.ActionsProvider; +import org.apache.iceberg.actions.DeleteOrphanFiles; +import org.apache.iceberg.actions.DeleteReachableFiles; +import org.apache.iceberg.actions.ExpireSnapshots; +import org.apache.iceberg.actions.MigrateTable; +import org.apache.iceberg.actions.RewriteDataFiles; +import org.apache.iceberg.actions.RewriteManifests; +import org.apache.iceberg.actions.SnapshotTable; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.Spark3Util.CatalogAndIdentifier; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; + +/** + * An implementation of {@link ActionsProvider} for Spark. + *
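As a quick illustration of this entry point, a sketch of driving the data-file rewrite action wired up below; the option keys are assumed to match the constants validated in validateAndInitOptions() earlier in this patch, and the values are examples only:

  // `table` is an already-loaded Iceberg Table.
  RewriteDataFiles.Result result = SparkActions.get(spark)
      .rewriteDataFiles(table)
      .option("partial-progress.enabled", "true")          // keep file groups that commit cleanly
      .option("partial-progress.max-commits", "10")        // split the rewrite into up to 10 commits
      .option("max-concurrent-file-group-rewrites", "5")   // parallel file-group rewrites
      .execute();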
<p>
+ * This class is the primary API for interacting with actions in Spark that users should use + * to instantiate particular actions. + */ +public class SparkActions implements ActionsProvider { + + private final SparkSession spark; + + private SparkActions(SparkSession spark) { + this.spark = spark; + } + + public static SparkActions get(SparkSession spark) { + return new SparkActions(spark); + } + + public static SparkActions get() { + return new SparkActions(SparkSession.active()); + } + + @Override + public SnapshotTable snapshotTable(String tableIdent) { + String ctx = "snapshot source"; + CatalogPlugin defaultCatalog = spark.sessionState().catalogManager().currentCatalog(); + CatalogAndIdentifier catalogAndIdent = Spark3Util.catalogAndIdentifier(ctx, spark, tableIdent, defaultCatalog); + return new BaseSnapshotTableSparkAction(spark, catalogAndIdent.catalog(), catalogAndIdent.identifier()); + } + + @Override + public MigrateTable migrateTable(String tableIdent) { + String ctx = "migrate target"; + CatalogPlugin defaultCatalog = spark.sessionState().catalogManager().currentCatalog(); + CatalogAndIdentifier catalogAndIdent = Spark3Util.catalogAndIdentifier(ctx, spark, tableIdent, defaultCatalog); + return new BaseMigrateTableSparkAction(spark, catalogAndIdent.catalog(), catalogAndIdent.identifier()); + } + + @Override + @Deprecated + public RewriteDataFiles rewriteDataFiles(Table table) { + return new BaseRewriteDataFilesSparkAction(spark, table); + } + + @Override + public RewriteDataFiles rewriteDataFiles(Table table, String fullIdentifier) { + return new BaseRewriteDataFilesSparkAction(spark, table, fullIdentifier); + } + + @Override + public DeleteOrphanFiles deleteOrphanFiles(Table table) { + return new BaseDeleteOrphanFilesSparkAction(spark, table); + } + + @Override + public RewriteManifests rewriteManifests(Table table) { + return new BaseRewriteManifestsSparkAction(spark, table); + } + + @Override + public ExpireSnapshots expireSnapshots(Table table) { + return new BaseExpireSnapshotsSparkAction(spark, table); + } + + @Override + public DeleteReachableFiles deleteReachableFiles(String metadataLocation) { + return new BaseDeleteReachableFilesSparkAction(spark, metadataLocation); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackStrategy.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackStrategy.java new file mode 100644 index 000000000000..fcfaeaa7511f --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackStrategy.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.actions; + +import java.util.List; +import java.util.Set; +import java.util.UUID; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.BinPackStrategy; +import org.apache.iceberg.spark.FileRewriteCoordinator; +import org.apache.iceberg.spark.FileScanTaskSetManager; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.SQLConf; + +public class SparkBinPackStrategy extends BinPackStrategy { + private final Table table; + private final String fullIdentifier; + private final SparkSession spark; + private final FileScanTaskSetManager manager = FileScanTaskSetManager.get(); + private final FileRewriteCoordinator rewriteCoordinator = FileRewriteCoordinator.get(); + + public SparkBinPackStrategy(Table table, String fullIdentifier, SparkSession spark) { + this.table = table; + this.spark = spark; + // Fallback if a quoted identifier is not supplied + this.fullIdentifier = fullIdentifier == null ? table.name() : fullIdentifier; + } + + @Deprecated + public SparkBinPackStrategy(Table table, SparkSession spark) { + this.table = table; + this.spark = spark; + // Fallback if a quoted identifier is not supplied + this.fullIdentifier = table.name(); + } + + @Override + public Table table() { + return table; + } + + @Override + public Set rewriteFiles(List filesToRewrite) { + String groupID = UUID.randomUUID().toString(); + try { + manager.stageTasks(table, groupID, filesToRewrite); + + // Disable Adaptive Query Execution as this may change the output partitioning of our write + SparkSession cloneSession = spark.cloneSession(); + cloneSession.conf().set(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), false); + + Dataset scanDF = cloneSession.read().format("iceberg") + .option(SparkReadOptions.FILE_SCAN_TASK_SET_ID, groupID) + .option(SparkReadOptions.SPLIT_SIZE, splitSize(inputFileSize(filesToRewrite))) + .option(SparkReadOptions.FILE_OPEN_COST, "0") + .load(fullIdentifier); + + // All files within a file group are written with the same spec, so check the first + boolean requiresRepartition = !filesToRewrite.get(0).spec().equals(table.spec()); + + // Invoke a shuffle if the partition spec of the incoming partition does not match the table + String distributionMode = requiresRepartition ? 
DistributionMode.RANGE.modeName() : + DistributionMode.NONE.modeName(); + + // write the packed data into new files where each split becomes a new file + scanDF.write() + .format("iceberg") + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, groupID) + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES, writeMaxFileSize()) + .option(SparkWriteOptions.DISTRIBUTION_MODE, distributionMode) + .mode("append") + .save(fullIdentifier); + + return rewriteCoordinator.fetchNewDataFiles(table, groupID); + } finally { + manager.removeTasks(table, groupID); + rewriteCoordinator.clearRewrite(table, groupID); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSortStrategy.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSortStrategy.java new file mode 100644 index 000000000000..6c8f8c027dba --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSortStrategy.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.actions; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.RewriteStrategy; +import org.apache.iceberg.actions.SortStrategy; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.spark.FileRewriteCoordinator; +import org.apache.iceberg.spark.FileScanTaskSetManager; +import org.apache.iceberg.spark.SparkDistributionAndOrderingUtil; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.SortOrderUtil; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.utils.DistributionAndOrderingUtils$; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.distributions.Distributions; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.internal.SQLConf; + +public class SparkSortStrategy extends SortStrategy { + + /** + * The number of shuffle partitions and consequently the number of output files + * created by the Spark Sort is based on the size of the input data files used + * in this rewrite operation. Due to compression, the disk file sizes may not + * accurately represent the size of files in the output. 
This parameter lets + * the user adjust the file size used for estimating actual output data size. A + * factor greater than 1.0 would generate more files than we would expect based + * on the on-disk file size. A value less than 1.0 would create fewer files than + * we would expect due to the on-disk size. + */ + public static final String COMPRESSION_FACTOR = "compression-factor"; + + private final Table table; + private final SparkSession spark; + private final FileScanTaskSetManager manager = FileScanTaskSetManager.get(); + private final FileRewriteCoordinator rewriteCoordinator = FileRewriteCoordinator.get(); + + private double sizeEstimateMultiple; + + public SparkSortStrategy(Table table, SparkSession spark) { + this.table = table; + this.spark = spark; + } + + @Override + public Table table() { + return table; + } + + @Override + public Set validOptions() { + return ImmutableSet.builder() + .addAll(super.validOptions()) + .add(COMPRESSION_FACTOR) + .build(); + } + + @Override + public RewriteStrategy options(Map options) { + sizeEstimateMultiple = PropertyUtil.propertyAsDouble(options, + COMPRESSION_FACTOR, + 1.0); + + Preconditions.checkArgument(sizeEstimateMultiple > 0, + "Invalid compression factor: %s (not positive)", sizeEstimateMultiple); + + return super.options(options); + } + + @Override + public Set rewriteFiles(List filesToRewrite) { + String groupID = UUID.randomUUID().toString(); + boolean requiresRepartition = !filesToRewrite.get(0).spec().equals(table.spec()); + + SortOrder[] ordering; + if (requiresRepartition) { + // Build in the requirement for Partition Sorting into our sort order + ordering = SparkDistributionAndOrderingUtil.convert(SortOrderUtil.buildSortOrder(table, sortOrder())); + } else { + ordering = SparkDistributionAndOrderingUtil.convert(sortOrder()); + } + + Distribution distribution = Distributions.ordered(ordering); + + try { + manager.stageTasks(table, groupID, filesToRewrite); + + // Disable Adaptive Query Execution as this may change the output partitioning of our write + SparkSession cloneSession = spark.cloneSession(); + cloneSession.conf().set(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), false); + + // Reset Shuffle Partitions for our sort + long numOutputFiles = numOutputFiles((long) (inputFileSize(filesToRewrite) * sizeEstimateMultiple)); + cloneSession.conf().set(SQLConf.SHUFFLE_PARTITIONS().key(), Math.max(1, numOutputFiles)); + + Dataset scanDF = cloneSession.read().format("iceberg") + .option(SparkReadOptions.FILE_SCAN_TASK_SET_ID, groupID) + .load(table.name()); + + // write the packed data into new files where each split becomes a new file + SQLConf sqlConf = cloneSession.sessionState().conf(); + LogicalPlan sortPlan = sortPlan(distribution, ordering, scanDF.logicalPlan(), sqlConf); + Dataset sortedDf = new Dataset<>(cloneSession, sortPlan, scanDF.encoder()); + + sortedDf.write() + .format("iceberg") + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, groupID) + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES, writeMaxFileSize()) + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false") + .mode("append") // This will only write files without modifying the table, see SparkWrite.RewriteFiles + .save(table.name()); + + return rewriteCoordinator.fetchNewDataFiles(table, groupID); + } finally { + manager.removeTasks(table, groupID); + rewriteCoordinator.clearRewrite(table, groupID); + } + } + + protected SparkSession spark() { + return this.spark; + } + + protected LogicalPlan sortPlan(Distribution distribution, 
SortOrder[] ordering, LogicalPlan plan, SQLConf conf) { + return DistributionAndOrderingUtils$.MODULE$.prepareQuery(distribution, ordering, plan, conf); + } + + protected double sizeEstimateMultiple() { + return sizeEstimateMultiple; + } + + protected FileScanTaskSetManager manager() { + return manager; + } + + protected FileRewriteCoordinator rewriteCoordinator() { + return rewriteCoordinator; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderStrategy.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderStrategy.java new file mode 100644 index 000000000000..cdd47fe31372 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderStrategy.java @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.actions; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.NullOrder; +import org.apache.iceberg.PartitionField; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortDirection; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.RewriteStrategy; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.spark.SparkDistributionAndOrderingUtil; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.SortOrderUtil; +import org.apache.iceberg.util.ZOrderByteUtils; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.distributions.Distributions; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.functions; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.StructField; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SparkZOrderStrategy extends SparkSortStrategy { + private static final Logger LOG = LoggerFactory.getLogger(SparkZOrderStrategy.class); + + private static final String Z_COLUMN = "ICEZVALUE"; + private static final 
Schema Z_SCHEMA = new Schema(NestedField.required(0, Z_COLUMN, Types.BinaryType.get())); + private static final org.apache.iceberg.SortOrder Z_SORT_ORDER = org.apache.iceberg.SortOrder.builderFor(Z_SCHEMA) + .sortBy(Z_COLUMN, SortDirection.ASC, NullOrder.NULLS_LAST) + .build(); + + /** + * Controls the amount of bytes interleaved in the ZOrder Algorithm. Default is all bytes being interleaved. + */ + private static final String MAX_OUTPUT_SIZE_KEY = "max-output-size"; + private static final int DEFAULT_MAX_OUTPUT_SIZE = Integer.MAX_VALUE; + + /** + * Controls the number of bytes considered from an input column of a type with variable length (String, Binary). + * Default is to use the same size as primitives {@link ZOrderByteUtils#PRIMITIVE_BUFFER_SIZE} + */ + private static final String VAR_LENGTH_CONTRIBUTION_KEY = "var-length-contribution"; + private static final int DEFAULT_VAR_LENGTH_CONTRIBUTION = ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE; + + private final List zOrderColNames; + + private int maxOutputSize; + private int varLengthContribution; + + @Override + public Set validOptions() { + return ImmutableSet.builder() + .addAll(super.validOptions()) + .add(VAR_LENGTH_CONTRIBUTION_KEY) + .add(MAX_OUTPUT_SIZE_KEY) + .build(); + } + + @Override + public RewriteStrategy options(Map options) { + super.options(options); + + varLengthContribution = PropertyUtil.propertyAsInt(options, VAR_LENGTH_CONTRIBUTION_KEY, + DEFAULT_VAR_LENGTH_CONTRIBUTION); + Preconditions.checkArgument(varLengthContribution > 0, + "Cannot use less than 1 byte for variable length types with zOrder, %s was set to %s", + VAR_LENGTH_CONTRIBUTION_KEY, varLengthContribution); + + maxOutputSize = PropertyUtil.propertyAsInt(options, MAX_OUTPUT_SIZE_KEY, DEFAULT_MAX_OUTPUT_SIZE); + Preconditions.checkArgument(maxOutputSize > 0, + "Cannot have the interleaved ZOrder value use less than 1 byte, %s was set to %s", + MAX_OUTPUT_SIZE_KEY, maxOutputSize); + + return this; + } + + public SparkZOrderStrategy(Table table, SparkSession spark, List zOrderColNames) { + super(table, spark); + + Preconditions.checkArgument(zOrderColNames != null && !zOrderColNames.isEmpty(), + "Cannot ZOrder when no columns are specified"); + + Stream identityPartitionColumns = table.spec().fields().stream() + .filter(f -> f.transform().isIdentity()) + .map(PartitionField::name); + List partZOrderCols = identityPartitionColumns + .filter(zOrderColNames::contains) + .collect(Collectors.toList()); + + if (!partZOrderCols.isEmpty()) { + LOG.warn("Cannot ZOrder on an Identity partition column as these values are constant within a partition " + + "and will be removed from the ZOrder expression: {}", partZOrderCols); + zOrderColNames.removeAll(partZOrderCols); + Preconditions.checkArgument(!zOrderColNames.isEmpty(), + "Cannot perform ZOrdering, all columns provided were identity partition columns and cannot be used."); + } + + this.zOrderColNames = zOrderColNames; + } + + @Override + public String name() { + return "Z-ORDER"; + } + + @Override + protected void validateOptions() { + // Ignore SortStrategy validation + return; + } + + @Override + public Set rewriteFiles(List filesToRewrite) { + SparkZOrderUDF zOrderUDF = new SparkZOrderUDF(zOrderColNames.size(), varLengthContribution, maxOutputSize); + + String groupID = UUID.randomUUID().toString(); + boolean requiresRepartition = !filesToRewrite.get(0).spec().equals(table().spec()); + + SortOrder[] ordering; + if (requiresRepartition) { + ordering = 
SparkDistributionAndOrderingUtil.convert(SortOrderUtil.buildSortOrder(table(), sortOrder())); + } else { + ordering = SparkDistributionAndOrderingUtil.convert(sortOrder()); + } + + Distribution distribution = Distributions.ordered(ordering); + + try { + manager().stageTasks(table(), groupID, filesToRewrite); + + // Disable Adaptive Query Execution as this may change the output partitioning of our write + SparkSession cloneSession = spark().cloneSession(); + cloneSession.conf().set(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), false); + + // Reset Shuffle Partitions for our sort + long numOutputFiles = numOutputFiles((long) (inputFileSize(filesToRewrite) * sizeEstimateMultiple())); + cloneSession.conf().set(SQLConf.SHUFFLE_PARTITIONS().key(), Math.max(1, numOutputFiles)); + + Dataset scanDF = cloneSession.read().format("iceberg") + .option(SparkReadOptions.FILE_SCAN_TASK_SET_ID, groupID) + .load(table().name()); + + Column[] originalColumns = Arrays.stream(scanDF.schema().names()) + .map(n -> functions.col(n)) + .toArray(Column[]::new); + + List zOrderColumns = zOrderColNames.stream() + .map(scanDF.schema()::apply) + .collect(Collectors.toList()); + + Column zvalueArray = functions.array(zOrderColumns.stream().map(colStruct -> + zOrderUDF.sortedLexicographically(functions.col(colStruct.name()), colStruct.dataType()) + ).toArray(Column[]::new)); + + Dataset zvalueDF = scanDF.withColumn(Z_COLUMN, zOrderUDF.interleaveBytes(zvalueArray)); + + SQLConf sqlConf = cloneSession.sessionState().conf(); + LogicalPlan sortPlan = sortPlan(distribution, ordering, zvalueDF.logicalPlan(), sqlConf); + Dataset sortedDf = new Dataset<>(cloneSession, sortPlan, zvalueDF.encoder()); + sortedDf + .select(originalColumns) + .write() + .format("iceberg") + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, groupID) + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES, writeMaxFileSize()) + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false") + .mode("append") + .save(table().name()); + + return rewriteCoordinator().fetchNewDataFiles(table(), groupID); + } finally { + manager().removeTasks(table(), groupID); + rewriteCoordinator().clearRewrite(table(), groupID); + } + } + + @Override + protected org.apache.iceberg.SortOrder sortOrder() { + return Z_SORT_ORDER; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderUDF.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderUDF.java new file mode 100644 index 000000000000..eea3689211e2 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderUDF.java @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.actions; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.nio.charset.CharsetEncoder; +import java.nio.charset.StandardCharsets; +import org.apache.iceberg.util.ZOrderByteUtils; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.expressions.UserDefinedFunction; +import org.apache.spark.sql.functions; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.TimestampType; +import scala.collection.JavaConverters; +import scala.collection.Seq; + +class SparkZOrderUDF implements Serializable { + private static final byte[] PRIMITIVE_EMPTY = new byte[ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE]; + + /** + * Every Spark task runs iteratively on a rows in a single thread so ThreadLocal should protect from + * concurrent access to any of these structures. + */ + private transient ThreadLocal outputBuffer; + private transient ThreadLocal inputHolder; + private transient ThreadLocal inputBuffers; + private transient ThreadLocal encoder; + + private final int numCols; + + private int inputCol = 0; + private int totalOutputBytes = 0; + private final int varTypeSize; + private final int maxOutputSize; + + SparkZOrderUDF(int numCols, int varTypeSize, int maxOutputSize) { + this.numCols = numCols; + this.varTypeSize = varTypeSize; + this.maxOutputSize = maxOutputSize; + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + in.defaultReadObject(); + inputBuffers = ThreadLocal.withInitial(() -> new ByteBuffer[numCols]); + inputHolder = ThreadLocal.withInitial(() -> new byte[numCols][]); + outputBuffer = ThreadLocal.withInitial(() -> ByteBuffer.allocate(totalOutputBytes)); + encoder = ThreadLocal.withInitial(() -> StandardCharsets.UTF_8.newEncoder()); + } + + private ByteBuffer inputBuffer(int position, int size) { + ByteBuffer buffer = inputBuffers.get()[position]; + if (buffer == null) { + buffer = ByteBuffer.allocate(size); + inputBuffers.get()[position] = buffer; + } + return buffer; + } + + byte[] interleaveBits(Seq scalaBinary) { + byte[][] columnsBinary = JavaConverters.seqAsJavaList(scalaBinary).toArray(inputHolder.get()); + return ZOrderByteUtils.interleaveBits(columnsBinary, totalOutputBytes, outputBuffer.get()); + } + + private UserDefinedFunction tinyToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = functions.udf((Byte value) -> { + if (value == null) { + return PRIMITIVE_EMPTY; + } + return ZOrderByteUtils.tinyintToOrderedBytes(value, inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE)) + .array(); + }, DataTypes.BinaryType).withName("TINY_ORDERED_BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + + return udf; + } + + private UserDefinedFunction shortToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = functions.udf((Short value) -> { + if (value == null) { + return PRIMITIVE_EMPTY; + } + 
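+ // Descriptive note: each *ToOrderedBytes helper encodes the value into a fixed-width byte array whose
+ // unsigned lexicographic order matches the natural order of the type, which is what allows the
+ // per-column arrays to be interleaved into a single sortable Z-value.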
return ZOrderByteUtils.shortToOrderedBytes(value, inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE)) + .array(); + }, DataTypes.BinaryType).withName("SHORT_ORDERED_BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + + return udf; + } + + private UserDefinedFunction intToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = functions.udf((Integer value) -> { + if (value == null) { + return PRIMITIVE_EMPTY; + } + return ZOrderByteUtils.intToOrderedBytes(value, inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE)) + .array(); + }, DataTypes.BinaryType).withName("INT_ORDERED_BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + + return udf; + } + + private UserDefinedFunction longToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = functions.udf((Long value) -> { + if (value == null) { + return PRIMITIVE_EMPTY; + } + return ZOrderByteUtils.longToOrderedBytes(value, inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE)) + .array(); + }, DataTypes.BinaryType).withName("LONG_ORDERED_BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + + return udf; + } + + private UserDefinedFunction floatToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = functions.udf((Float value) -> { + if (value == null) { + return PRIMITIVE_EMPTY; + } + return ZOrderByteUtils.floatToOrderedBytes(value, inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE)) + .array(); + }, DataTypes.BinaryType).withName("FLOAT_ORDERED_BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + + return udf; + } + + private UserDefinedFunction doubleToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = functions.udf((Double value) -> { + if (value == null) { + return PRIMITIVE_EMPTY; + } + return ZOrderByteUtils.doubleToOrderedBytes(value, inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE)) + .array(); + }, DataTypes.BinaryType).withName("DOUBLE_ORDERED_BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + + return udf; + } + + private UserDefinedFunction booleanToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = functions.udf((Boolean value) -> { + ByteBuffer buffer = inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + buffer.put(0, (byte) (value ? 
-127 : 0)); + return buffer.array(); + }, DataTypes.BinaryType).withName("BOOLEAN-LEXICAL-BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + return udf; + } + + private UserDefinedFunction stringToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = functions.udf((String value) -> + ZOrderByteUtils.stringToOrderedBytes( + value, + varTypeSize, + inputBuffer(position, varTypeSize), + encoder.get()).array(), DataTypes.BinaryType) + .withName("STRING-LEXICAL-BYTES"); + + this.inputCol++; + increaseOutputSize(varTypeSize); + + return udf; + } + + private UserDefinedFunction bytesTruncateUDF() { + int position = inputCol; + UserDefinedFunction udf = functions.udf((byte[] value) -> + ZOrderByteUtils.byteTruncateOrFill(value, varTypeSize, inputBuffer(position, varTypeSize)).array(), + DataTypes.BinaryType) + .withName("BYTE-TRUNCATE"); + + this.inputCol++; + increaseOutputSize(varTypeSize); + + return udf; + } + + private final UserDefinedFunction interleaveUDF = + functions.udf((Seq arrayBinary) -> interleaveBits(arrayBinary), DataTypes.BinaryType) + .withName("INTERLEAVE_BYTES"); + + Column interleaveBytes(Column arrayBinary) { + return interleaveUDF.apply(arrayBinary); + } + + @SuppressWarnings("checkstyle:CyclomaticComplexity") + Column sortedLexicographically(Column column, DataType type) { + if (type instanceof ByteType) { + return tinyToOrderedBytesUDF().apply(column); + } else if (type instanceof ShortType) { + return shortToOrderedBytesUDF().apply(column); + } else if (type instanceof IntegerType) { + return intToOrderedBytesUDF().apply(column); + } else if (type instanceof LongType) { + return longToOrderedBytesUDF().apply(column); + } else if (type instanceof FloatType) { + return floatToOrderedBytesUDF().apply(column); + } else if (type instanceof DoubleType) { + return doubleToOrderedBytesUDF().apply(column); + } else if (type instanceof StringType) { + return stringToOrderedBytesUDF().apply(column); + } else if (type instanceof BinaryType) { + return bytesTruncateUDF().apply(column); + } else if (type instanceof BooleanType) { + return booleanToOrderedBytesUDF().apply(column); + } else if (type instanceof TimestampType) { + return longToOrderedBytesUDF().apply(column.cast(DataTypes.LongType)); + } else if (type instanceof DateType) { + return longToOrderedBytesUDF().apply(column.cast(DataTypes.LongType)); + } else { + throw new IllegalArgumentException( + String.format("Cannot use column %s of type %s in ZOrdering, the type is unsupported", column, type)); + } + } + + private void increaseOutputSize(int bytes) { + totalOutputBytes = Math.min(totalOutputBytes + bytes, maxOutputSize); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/AvroWithSparkSchemaVisitor.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/AvroWithSparkSchemaVisitor.java new file mode 100644 index 000000000000..40ed05b4ce65 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/AvroWithSparkSchemaVisitor.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data; + +import org.apache.iceberg.avro.AvroWithPartnerByStructureVisitor; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.util.Pair; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public abstract class AvroWithSparkSchemaVisitor extends AvroWithPartnerByStructureVisitor { + + @Override + protected boolean isStringType(DataType dataType) { + return dataType instanceof StringType; + } + + @Override + protected boolean isMapType(DataType dataType) { + return dataType instanceof MapType; + } + + @Override + protected DataType arrayElementType(DataType arrayType) { + Preconditions.checkArgument(arrayType instanceof ArrayType, "Invalid array: %s is not an array", arrayType); + return ((ArrayType) arrayType).elementType(); + } + + @Override + protected DataType mapKeyType(DataType mapType) { + Preconditions.checkArgument(isMapType(mapType), "Invalid map: %s is not a map", mapType); + return ((MapType) mapType).keyType(); + } + + @Override + protected DataType mapValueType(DataType mapType) { + Preconditions.checkArgument(isMapType(mapType), "Invalid map: %s is not a map", mapType); + return ((MapType) mapType).valueType(); + } + + @Override + protected Pair fieldNameAndType(DataType structType, int pos) { + Preconditions.checkArgument(structType instanceof StructType, "Invalid struct: %s is not a struct", structType); + StructField field = ((StructType) structType).apply(pos); + return Pair.of(field.name(), field.dataType()); + } + + @Override + protected DataType nullType() { + return DataTypes.NullType; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/ParquetWithSparkSchemaVisitor.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/ParquetWithSparkSchemaVisitor.java new file mode 100644 index 000000000000..924cc3e2325a --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/ParquetWithSparkSchemaVisitor.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.data; + +import java.util.Deque; +import java.util.List; +import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.OriginalType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.parquet.schema.Type.Repetition; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * Visitor for traversing a Parquet type with a companion Spark type. + * + * @param the Java class returned by the visitor + */ +public class ParquetWithSparkSchemaVisitor { + private final Deque fieldNames = Lists.newLinkedList(); + + public static T visit(DataType sType, Type type, ParquetWithSparkSchemaVisitor visitor) { + Preconditions.checkArgument(sType != null, "Invalid DataType: null"); + if (type instanceof MessageType) { + Preconditions.checkArgument(sType instanceof StructType, "Invalid struct: %s is not a struct", sType); + StructType struct = (StructType) sType; + return visitor.message(struct, (MessageType) type, visitFields(struct, type.asGroupType(), visitor)); + + } else if (type.isPrimitive()) { + return visitor.primitive(sType, type.asPrimitiveType()); + + } else { + // if not a primitive, the typeId must be a group + GroupType group = type.asGroupType(); + OriginalType annotation = group.getOriginalType(); + if (annotation != null) { + switch (annotation) { + case LIST: + Preconditions.checkArgument(!group.isRepetition(Repetition.REPEATED), + "Invalid list: top-level group is repeated: %s", group); + Preconditions.checkArgument(group.getFieldCount() == 1, + "Invalid list: does not contain single repeated field: %s", group); + + GroupType repeatedElement = group.getFields().get(0).asGroupType(); + Preconditions.checkArgument(repeatedElement.isRepetition(Repetition.REPEATED), + "Invalid list: inner group is not repeated"); + Preconditions.checkArgument(repeatedElement.getFieldCount() <= 1, + "Invalid list: repeated group is not a single field: %s", group); + + Preconditions.checkArgument(sType instanceof ArrayType, "Invalid list: %s is not an array", sType); + ArrayType array = (ArrayType) sType; + StructField element = new StructField( + "element", array.elementType(), array.containsNull(), Metadata.empty()); + + visitor.fieldNames.push(repeatedElement.getName()); + try { + T elementResult = null; + if (repeatedElement.getFieldCount() > 0) { + elementResult = visitField(element, repeatedElement.getType(0), visitor); + } + + return visitor.list(array, group, elementResult); + + } finally { + visitor.fieldNames.pop(); + } + + case MAP: + Preconditions.checkArgument(!group.isRepetition(Repetition.REPEATED), + "Invalid map: top-level group is repeated: %s", group); + Preconditions.checkArgument(group.getFieldCount() == 1, + "Invalid map: does not contain single repeated field: %s", group); + + GroupType repeatedKeyValue = group.getType(0).asGroupType(); + Preconditions.checkArgument(repeatedKeyValue.isRepetition(Repetition.REPEATED), + "Invalid map: inner group is not repeated"); + 
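+ // The repeated key_value group may project only the key or only the value, so it can legally carry
+ // one or two fields; the single-field case is resolved by field name further below.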
Preconditions.checkArgument(repeatedKeyValue.getFieldCount() <= 2, + "Invalid map: repeated group does not have 2 fields"); + + Preconditions.checkArgument(sType instanceof MapType, "Invalid map: %s is not a map", sType); + MapType map = (MapType) sType; + StructField keyField = new StructField("key", map.keyType(), false, Metadata.empty()); + StructField valueField = new StructField( + "value", map.valueType(), map.valueContainsNull(), Metadata.empty()); + + visitor.fieldNames.push(repeatedKeyValue.getName()); + try { + T keyResult = null; + T valueResult = null; + switch (repeatedKeyValue.getFieldCount()) { + case 2: + // if there are 2 fields, both key and value are projected + keyResult = visitField(keyField, repeatedKeyValue.getType(0), visitor); + valueResult = visitField(valueField, repeatedKeyValue.getType(1), visitor); + break; + case 1: + // if there is just one, use the name to determine what it is + Type keyOrValue = repeatedKeyValue.getType(0); + if (keyOrValue.getName().equalsIgnoreCase("key")) { + keyResult = visitField(keyField, keyOrValue, visitor); + // value result remains null + } else { + valueResult = visitField(valueField, keyOrValue, visitor); + // key result remains null + } + break; + default: + // both results will remain null + } + + return visitor.map(map, group, keyResult, valueResult); + + } finally { + visitor.fieldNames.pop(); + } + + default: + } + } + + Preconditions.checkArgument(sType instanceof StructType, "Invalid struct: %s is not a struct", sType); + StructType struct = (StructType) sType; + return visitor.struct(struct, group, visitFields(struct, group, visitor)); + } + } + + private static T visitField(StructField sField, Type field, ParquetWithSparkSchemaVisitor visitor) { + visitor.fieldNames.push(field.getName()); + try { + return visit(sField.dataType(), field, visitor); + } finally { + visitor.fieldNames.pop(); + } + } + + private static List visitFields(StructType struct, GroupType group, + ParquetWithSparkSchemaVisitor visitor) { + StructField[] sFields = struct.fields(); + Preconditions.checkArgument(sFields.length == group.getFieldCount(), + "Structs do not match: %s and %s", struct, group); + List results = Lists.newArrayListWithExpectedSize(group.getFieldCount()); + for (int i = 0; i < sFields.length; i += 1) { + Type field = group.getFields().get(i); + StructField sField = sFields[i]; + Preconditions.checkArgument(field.getName().equals(AvroSchemaUtil.makeCompatibleName(sField.name())), + "Structs do not match: field %s != %s", field.getName(), sField.name()); + results.add(visitField(sField, field, visitor)); + } + + return results; + } + + public T message(StructType sStruct, MessageType message, List fields) { + return null; + } + + public T struct(StructType sStruct, GroupType struct, List fields) { + return null; + } + + public T list(ArrayType sArray, GroupType array, T element) { + return null; + } + + public T map(MapType sMap, GroupType map, T key, T value) { + return null; + } + + public T primitive(DataType sPrimitive, PrimitiveType primitive) { + return null; + } + + protected String[] currentPath() { + return Lists.newArrayList(fieldNames.descendingIterator()).toArray(new String[0]); + } + + protected String[] path(String name) { + List list = Lists.newArrayList(fieldNames.descendingIterator()); + list.add(name); + return list.toArray(new String[0]); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java 
b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java new file mode 100644 index 000000000000..c693e2e2c057 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; +import org.apache.avro.LogicalType; +import org.apache.avro.LogicalTypes; +import org.apache.avro.Schema; +import org.apache.avro.io.DatumReader; +import org.apache.avro.io.Decoder; +import org.apache.iceberg.avro.AvroSchemaWithTypeVisitor; +import org.apache.iceberg.avro.SupportsRowPosition; +import org.apache.iceberg.avro.ValueReader; +import org.apache.iceberg.avro.ValueReaders; +import org.apache.iceberg.data.avro.DecoderResolver; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; + + +public class SparkAvroReader implements DatumReader, SupportsRowPosition { + + private final Schema readSchema; + private final ValueReader reader; + private Schema fileSchema = null; + + public SparkAvroReader(org.apache.iceberg.Schema expectedSchema, Schema readSchema) { + this(expectedSchema, readSchema, ImmutableMap.of()); + } + + @SuppressWarnings("unchecked") + public SparkAvroReader(org.apache.iceberg.Schema expectedSchema, Schema readSchema, Map constants) { + this.readSchema = readSchema; + this.reader = (ValueReader) AvroSchemaWithTypeVisitor + .visit(expectedSchema, readSchema, new ReadBuilder(constants)); + } + + @Override + public void setSchema(Schema newFileSchema) { + this.fileSchema = Schema.applyAliases(newFileSchema, readSchema); + } + + @Override + public InternalRow read(InternalRow reuse, Decoder decoder) throws IOException { + return DecoderResolver.resolveAndRead(decoder, readSchema, fileSchema, reader, reuse); + } + + @Override + public void setRowPositionSupplier(Supplier posSupplier) { + if (reader instanceof SupportsRowPosition) { + ((SupportsRowPosition) reader).setRowPositionSupplier(posSupplier); + } + } + + private static class ReadBuilder extends AvroSchemaWithTypeVisitor> { + private final Map idToConstant; + + private ReadBuilder(Map idToConstant) { + this.idToConstant = idToConstant; + } + + @Override + public ValueReader record(Types.StructType expected, Schema record, List names, + List> fields) { + return SparkValueReaders.struct(fields, expected, idToConstant); + } + + @Override + public ValueReader union(Type expected, Schema union, List> options) { + return ValueReaders.union(options); + } + + 
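+ // Descriptive note: the container readers below delegate to SparkValueReaders so that lists and maps
+ // are materialized as Spark's internal ArrayData/MapData representations rather than Java collections.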
@Override + public ValueReader array(Types.ListType expected, Schema array, ValueReader elementReader) { + return SparkValueReaders.array(elementReader); + } + + @Override + public ValueReader map(Types.MapType expected, Schema map, + ValueReader keyReader, ValueReader valueReader) { + return SparkValueReaders.arrayMap(keyReader, valueReader); + } + + @Override + public ValueReader map(Types.MapType expected, Schema map, ValueReader valueReader) { + return SparkValueReaders.map(SparkValueReaders.strings(), valueReader); + } + + @Override + public ValueReader primitive(Type.PrimitiveType expected, Schema primitive) { + LogicalType logicalType = primitive.getLogicalType(); + if (logicalType != null) { + switch (logicalType.getName()) { + case "date": + // Spark uses the same representation + return ValueReaders.ints(); + + case "timestamp-millis": + // adjust to microseconds + ValueReader longs = ValueReaders.longs(); + return (ValueReader) (decoder, ignored) -> longs.read(decoder, null) * 1000L; + + case "timestamp-micros": + // Spark uses the same representation + return ValueReaders.longs(); + + case "decimal": + return SparkValueReaders.decimal( + ValueReaders.decimalBytesReader(primitive), + ((LogicalTypes.Decimal) logicalType).getScale()); + + case "uuid": + return SparkValueReaders.uuids(); + + default: + throw new IllegalArgumentException("Unknown logical type: " + logicalType); + } + } + + switch (primitive.getType()) { + case NULL: + return ValueReaders.nulls(); + case BOOLEAN: + return ValueReaders.booleans(); + case INT: + return ValueReaders.ints(); + case LONG: + return ValueReaders.longs(); + case FLOAT: + return ValueReaders.floats(); + case DOUBLE: + return ValueReaders.doubles(); + case STRING: + return SparkValueReaders.strings(); + case FIXED: + return ValueReaders.fixed(primitive.getFixedSize()); + case BYTES: + return ValueReaders.bytes(); + case ENUM: + return SparkValueReaders.enums(primitive.getEnumSymbols()); + default: + throw new IllegalArgumentException("Unsupported type: " + primitive); + } + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java new file mode 100644 index 000000000000..7582125128a7 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.data; + +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import org.apache.avro.LogicalType; +import org.apache.avro.LogicalTypes; +import org.apache.avro.Schema; +import org.apache.avro.io.Encoder; +import org.apache.iceberg.FieldMetrics; +import org.apache.iceberg.avro.MetricsAwareDatumWriter; +import org.apache.iceberg.avro.ValueWriter; +import org.apache.iceberg.avro.ValueWriters; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StructType; + +public class SparkAvroWriter implements MetricsAwareDatumWriter { + private final StructType dsSchema; + private ValueWriter writer = null; + + public SparkAvroWriter(StructType dsSchema) { + this.dsSchema = dsSchema; + } + + @Override + @SuppressWarnings("unchecked") + public void setSchema(Schema schema) { + this.writer = (ValueWriter) AvroWithSparkSchemaVisitor + .visit(dsSchema, schema, new WriteBuilder()); + } + + @Override + public void write(InternalRow datum, Encoder out) throws IOException { + writer.write(datum, out); + } + + @Override + public Stream metrics() { + return writer.metrics(); + } + + private static class WriteBuilder extends AvroWithSparkSchemaVisitor> { + @Override + public ValueWriter record(DataType struct, Schema record, List names, List> fields) { + return SparkValueWriters.struct(fields, IntStream.range(0, names.size()) + .mapToObj(i -> fieldNameAndType(struct, i).second()).collect(Collectors.toList())); + } + + @Override + public ValueWriter union(DataType type, Schema union, List> options) { + Preconditions.checkArgument(options.contains(ValueWriters.nulls()), + "Cannot create writer for non-option union: %s", union); + Preconditions.checkArgument(options.size() == 2, + "Cannot create writer for non-option union: %s", union); + if (union.getTypes().get(0).getType() == Schema.Type.NULL) { + return ValueWriters.option(0, options.get(1)); + } else { + return ValueWriters.option(1, options.get(0)); + } + } + + @Override + public ValueWriter array(DataType sArray, Schema array, ValueWriter elementWriter) { + return SparkValueWriters.array(elementWriter, arrayElementType(sArray)); + } + + @Override + public ValueWriter map(DataType sMap, Schema map, ValueWriter valueReader) { + return SparkValueWriters.map(SparkValueWriters.strings(), mapKeyType(sMap), valueReader, mapValueType(sMap)); + } + + @Override + public ValueWriter map(DataType sMap, Schema map, ValueWriter keyWriter, ValueWriter valueWriter) { + return SparkValueWriters.arrayMap(keyWriter, mapKeyType(sMap), valueWriter, mapValueType(sMap)); + } + + @Override + public ValueWriter primitive(DataType type, Schema primitive) { + LogicalType logicalType = primitive.getLogicalType(); + if (logicalType != null) { + switch (logicalType.getName()) { + case "date": + // Spark uses the same representation + return ValueWriters.ints(); + + case "timestamp-micros": + // Spark uses the same representation + return ValueWriters.longs(); + + case "decimal": + LogicalTypes.Decimal decimal = (LogicalTypes.Decimal) logicalType; + return SparkValueWriters.decimal(decimal.getPrecision(), decimal.getScale()); + + case "uuid": + return ValueWriters.uuids(); + + default: + throw new 
IllegalArgumentException("Unsupported logical type: " + logicalType); + } + } + + switch (primitive.getType()) { + case NULL: + return ValueWriters.nulls(); + case BOOLEAN: + return ValueWriters.booleans(); + case INT: + if (type instanceof ByteType) { + return ValueWriters.tinyints(); + } else if (type instanceof ShortType) { + return ValueWriters.shorts(); + } + return ValueWriters.ints(); + case LONG: + return ValueWriters.longs(); + case FLOAT: + return ValueWriters.floats(); + case DOUBLE: + return ValueWriters.doubles(); + case STRING: + return SparkValueWriters.strings(); + case FIXED: + return ValueWriters.fixed(primitive.getFixedSize()); + case BYTES: + return ValueWriters.bytes(); + default: + throw new IllegalArgumentException("Unsupported type: " + primitive); + } + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java new file mode 100644 index 000000000000..4ed6420a9aa4 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.orc.OrcRowReader; +import org.apache.iceberg.orc.OrcSchemaWithTypeVisitor; +import org.apache.iceberg.orc.OrcValueReader; +import org.apache.iceberg.orc.OrcValueReaders; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.orc.TypeDescription; +import org.apache.orc.storage.ql.exec.vector.StructColumnVector; +import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch; +import org.apache.spark.sql.catalyst.InternalRow; + +/** + * Converts the OrcIterator, which returns ORC's VectorizedRowBatch to a + * set of Spark's UnsafeRows. + * + * It minimizes allocations by reusing most of the objects in the implementation. 
+ */ +public class SparkOrcReader implements OrcRowReader { + private final OrcValueReader reader; + + public SparkOrcReader(org.apache.iceberg.Schema expectedSchema, TypeDescription readSchema) { + this(expectedSchema, readSchema, ImmutableMap.of()); + } + + @SuppressWarnings("unchecked") + public SparkOrcReader( + org.apache.iceberg.Schema expectedSchema, TypeDescription readOrcSchema, Map idToConstant) { + this.reader = OrcSchemaWithTypeVisitor.visit(expectedSchema, readOrcSchema, new ReadBuilder(idToConstant)); + } + + @Override + public InternalRow read(VectorizedRowBatch batch, int row) { + return (InternalRow) reader.read(new StructColumnVector(batch.size, batch.cols), row); + } + + @Override + public void setBatchContext(long batchOffsetInFile) { + reader.setBatchContext(batchOffsetInFile); + } + + private static class ReadBuilder extends OrcSchemaWithTypeVisitor> { + private final Map idToConstant; + + private ReadBuilder(Map idToConstant) { + this.idToConstant = idToConstant; + } + + @Override + public OrcValueReader record( + Types.StructType expected, TypeDescription record, List names, List> fields) { + return SparkOrcValueReaders.struct(fields, expected, idToConstant); + } + + @Override + public OrcValueReader list(Types.ListType iList, TypeDescription array, OrcValueReader elementReader) { + return SparkOrcValueReaders.array(elementReader); + } + + @Override + public OrcValueReader map( + Types.MapType iMap, TypeDescription map, OrcValueReader keyReader, OrcValueReader valueReader) { + return SparkOrcValueReaders.map(keyReader, valueReader); + } + + @Override + public OrcValueReader primitive(Type.PrimitiveType iPrimitive, TypeDescription primitive) { + switch (primitive.getCategory()) { + case BOOLEAN: + return OrcValueReaders.booleans(); + case BYTE: + // Iceberg does not have a byte type. Use int + case SHORT: + // Iceberg does not have a short type. Use int + case DATE: + case INT: + return OrcValueReaders.ints(); + case LONG: + return OrcValueReaders.longs(); + case FLOAT: + return OrcValueReaders.floats(); + case DOUBLE: + return OrcValueReaders.doubles(); + case TIMESTAMP_INSTANT: + case TIMESTAMP: + return SparkOrcValueReaders.timestampTzs(); + case DECIMAL: + return SparkOrcValueReaders.decimals(primitive.getPrecision(), primitive.getScale()); + case CHAR: + case VARCHAR: + case STRING: + return SparkOrcValueReaders.utf8String(); + case BINARY: + return OrcValueReaders.bytes(); + default: + throw new IllegalArgumentException("Unhandled type " + primitive); + } + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java new file mode 100644 index 000000000000..f35ab7a17c63 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data; + +import java.math.BigDecimal; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.orc.OrcValueReader; +import org.apache.iceberg.orc.OrcValueReaders; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Types; +import org.apache.orc.storage.ql.exec.vector.BytesColumnVector; +import org.apache.orc.storage.ql.exec.vector.ColumnVector; +import org.apache.orc.storage.ql.exec.vector.DecimalColumnVector; +import org.apache.orc.storage.ql.exec.vector.ListColumnVector; +import org.apache.orc.storage.ql.exec.vector.MapColumnVector; +import org.apache.orc.storage.ql.exec.vector.TimestampColumnVector; +import org.apache.orc.storage.serde2.io.HiveDecimalWritable; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; + +public class SparkOrcValueReaders { + private SparkOrcValueReaders() { + } + + public static OrcValueReader utf8String() { + return StringReader.INSTANCE; + } + + public static OrcValueReader timestampTzs() { + return TimestampTzReader.INSTANCE; + } + + public static OrcValueReader decimals(int precision, int scale) { + if (precision <= Decimal.MAX_LONG_DIGITS()) { + return new SparkOrcValueReaders.Decimal18Reader(precision, scale); + } else if (precision <= 38) { + return new SparkOrcValueReaders.Decimal38Reader(precision, scale); + } else { + throw new IllegalArgumentException("Invalid precision: " + precision); + } + } + + static OrcValueReader struct( + List> readers, Types.StructType struct, Map idToConstant) { + return new StructReader(readers, struct, idToConstant); + } + + static OrcValueReader array(OrcValueReader elementReader) { + return new ArrayReader(elementReader); + } + + static OrcValueReader map(OrcValueReader keyReader, OrcValueReader valueReader) { + return new MapReader(keyReader, valueReader); + } + + private static class ArrayReader implements OrcValueReader { + private final OrcValueReader elementReader; + + private ArrayReader(OrcValueReader elementReader) { + this.elementReader = elementReader; + } + + @Override + public ArrayData nonNullRead(ColumnVector vector, int row) { + ListColumnVector listVector = (ListColumnVector) vector; + int offset = (int) listVector.offsets[row]; + int length = (int) listVector.lengths[row]; + List elements = Lists.newArrayListWithExpectedSize(length); + for (int c = 0; c < length; ++c) { + elements.add(elementReader.read(listVector.child, offset + c)); + } + return new GenericArrayData(elements.toArray()); + } + + @Override + public void setBatchContext(long batchOffsetInFile) { + elementReader.setBatchContext(batchOffsetInFile); + } + } 
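+ // Like ListColumnVector above, ORC's MapColumnVector addresses each row's entries with per-row
+ // offsets and lengths into shared child vectors (keys and values), which the reader below walks in step.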
+ + private static class MapReader implements OrcValueReader { + private final OrcValueReader keyReader; + private final OrcValueReader valueReader; + + private MapReader(OrcValueReader keyReader, OrcValueReader valueReader) { + this.keyReader = keyReader; + this.valueReader = valueReader; + } + + @Override + public MapData nonNullRead(ColumnVector vector, int row) { + MapColumnVector mapVector = (MapColumnVector) vector; + int offset = (int) mapVector.offsets[row]; + long length = mapVector.lengths[row]; + List keys = Lists.newArrayListWithExpectedSize((int) length); + List values = Lists.newArrayListWithExpectedSize((int) length); + for (int c = 0; c < length; c++) { + keys.add(keyReader.read(mapVector.keys, offset + c)); + values.add(valueReader.read(mapVector.values, offset + c)); + } + + return new ArrayBasedMapData( + new GenericArrayData(keys.toArray()), + new GenericArrayData(values.toArray())); + } + + @Override + public void setBatchContext(long batchOffsetInFile) { + keyReader.setBatchContext(batchOffsetInFile); + valueReader.setBatchContext(batchOffsetInFile); + } + } + + static class StructReader extends OrcValueReaders.StructReader { + private final int numFields; + + protected StructReader(List> readers, Types.StructType struct, Map idToConstant) { + super(readers, struct, idToConstant); + this.numFields = struct.fields().size(); + } + + @Override + protected InternalRow create() { + return new GenericInternalRow(numFields); + } + + @Override + protected void set(InternalRow struct, int pos, Object value) { + if (value != null) { + struct.update(pos, value); + } else { + struct.setNullAt(pos); + } + } + } + + private static class StringReader implements OrcValueReader { + private static final StringReader INSTANCE = new StringReader(); + + private StringReader() { + } + + @Override + public UTF8String nonNullRead(ColumnVector vector, int row) { + BytesColumnVector bytesVector = (BytesColumnVector) vector; + return UTF8String.fromBytes(bytesVector.vector[row], bytesVector.start[row], bytesVector.length[row]); + } + } + + private static class TimestampTzReader implements OrcValueReader { + private static final TimestampTzReader INSTANCE = new TimestampTzReader(); + + private TimestampTzReader() { + } + + @Override + public Long nonNullRead(ColumnVector vector, int row) { + TimestampColumnVector tcv = (TimestampColumnVector) vector; + return Math.floorDiv(tcv.time[row], 1_000) * 1_000_000 + Math.floorDiv(tcv.nanos[row], 1000); + } + } + + private static class Decimal18Reader implements OrcValueReader { + private final int precision; + private final int scale; + + Decimal18Reader(int precision, int scale) { + this.precision = precision; + this.scale = scale; + } + + @Override + public Decimal nonNullRead(ColumnVector vector, int row) { + HiveDecimalWritable value = ((DecimalColumnVector) vector).vector[row]; + + // The scale of decimal read from hive ORC file may be not equals to the expected scale. For data type + // decimal(10,3) and the value 10.100, the hive ORC writer will remove its trailing zero and store it + // as 101*10^(-1), its scale will adjust from 3 to 1. So here we could not assert that value.scale() == scale. + // we also need to convert the hive orc decimal to a decimal with expected precision and scale. 
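+ // Worked example (illustrative): for decimal(10,3), the stored 101*10^(-1) is rescaled by
+ // serialize64(3) to the unscaled long 10100, and Decimal.set(10100, 10, 3) yields 10.100 again.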
+ Preconditions.checkArgument(value.precision() <= precision, + "Cannot read value as decimal(%s,%s), too large: %s", precision, scale, value); + + return new Decimal().set(value.serialize64(scale), precision, scale); + } + } + + private static class Decimal38Reader implements OrcValueReader { + private final int precision; + private final int scale; + + Decimal38Reader(int precision, int scale) { + this.precision = precision; + this.scale = scale; + } + + @Override + public Decimal nonNullRead(ColumnVector vector, int row) { + BigDecimal value = ((DecimalColumnVector) vector).vector[row] + .getHiveDecimal().bigDecimalValue(); + + Preconditions.checkArgument(value.precision() <= precision, + "Cannot read value as decimal(%s,%s), too large: %s", precision, scale, value); + + return new Decimal().set(new scala.math.BigDecimal(value), precision, scale); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java new file mode 100644 index 000000000000..abb12dffc050 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.data; + +import java.util.List; +import java.util.stream.Stream; +import org.apache.iceberg.FieldMetrics; +import org.apache.iceberg.orc.OrcValueWriter; +import org.apache.orc.TypeDescription; +import org.apache.orc.storage.common.type.HiveDecimal; +import org.apache.orc.storage.ql.exec.vector.BytesColumnVector; +import org.apache.orc.storage.ql.exec.vector.ColumnVector; +import org.apache.orc.storage.ql.exec.vector.DecimalColumnVector; +import org.apache.orc.storage.ql.exec.vector.ListColumnVector; +import org.apache.orc.storage.ql.exec.vector.MapColumnVector; +import org.apache.orc.storage.ql.exec.vector.TimestampColumnVector; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; + +class SparkOrcValueWriters { + private SparkOrcValueWriters() { + } + + static OrcValueWriter strings() { + return StringWriter.INSTANCE; + } + + static OrcValueWriter timestampTz() { + return TimestampTzWriter.INSTANCE; + } + + static OrcValueWriter decimal(int precision, int scale) { + if (precision <= 18) { + return new Decimal18Writer(scale); + } else { + return new Decimal38Writer(); + } + } + + static OrcValueWriter list(OrcValueWriter element, List orcType) { + return new ListWriter<>(element, orcType); + } + + static OrcValueWriter map(OrcValueWriter keyWriter, OrcValueWriter valueWriter, + List orcTypes) { + return new MapWriter<>(keyWriter, valueWriter, orcTypes); + } + + private static class StringWriter implements OrcValueWriter { + private static final StringWriter INSTANCE = new StringWriter(); + + @Override + public void nonNullWrite(int rowId, UTF8String data, ColumnVector output) { + byte[] value = data.getBytes(); + ((BytesColumnVector) output).setRef(rowId, value, 0, value.length); + } + + } + + private static class TimestampTzWriter implements OrcValueWriter { + private static final TimestampTzWriter INSTANCE = new TimestampTzWriter(); + + @Override + public void nonNullWrite(int rowId, Long micros, ColumnVector output) { + TimestampColumnVector cv = (TimestampColumnVector) output; + cv.time[rowId] = Math.floorDiv(micros, 1_000); // millis + cv.nanos[rowId] = (int) Math.floorMod(micros, 1_000_000) * 1_000; // nanos + } + + } + + private static class Decimal18Writer implements OrcValueWriter { + private final int scale; + + Decimal18Writer(int scale) { + this.scale = scale; + } + + @Override + public void nonNullWrite(int rowId, Decimal decimal, ColumnVector output) { + ((DecimalColumnVector) output).vector[rowId].setFromLongAndScale( + decimal.toUnscaledLong(), scale); + } + + } + + private static class Decimal38Writer implements OrcValueWriter { + + @Override + public void nonNullWrite(int rowId, Decimal decimal, ColumnVector output) { + ((DecimalColumnVector) output).vector[rowId].set( + HiveDecimal.create(decimal.toJavaBigDecimal())); + } + + } + + private static class ListWriter implements OrcValueWriter { + private final OrcValueWriter writer; + private final SparkOrcWriter.FieldGetter fieldGetter; + + @SuppressWarnings("unchecked") + ListWriter(OrcValueWriter writer, List orcTypes) { + if (orcTypes.size() != 1) { + throw new IllegalArgumentException("Expected one (and same) ORC type for list elements, got: " + orcTypes); + } + this.writer = writer; + this.fieldGetter = (SparkOrcWriter.FieldGetter) SparkOrcWriter.createFieldGetter(orcTypes.get(0)); + } + + @Override + public void nonNullWrite(int 
rowId, ArrayData value, ColumnVector output) { + ListColumnVector cv = (ListColumnVector) output; + // record the length and start of the list elements + cv.lengths[rowId] = value.numElements(); + cv.offsets[rowId] = cv.childCount; + cv.childCount = (int) (cv.childCount + cv.lengths[rowId]); + // make sure the child is big enough + growColumnVector(cv.child, cv.childCount); + // Add each element + for (int e = 0; e < cv.lengths[rowId]; ++e) { + writer.write((int) (e + cv.offsets[rowId]), fieldGetter.getFieldOrNull(value, e), cv.child); + } + } + + @Override + public Stream> metrics() { + return writer.metrics(); + } + + } + + private static class MapWriter implements OrcValueWriter { + private final OrcValueWriter keyWriter; + private final OrcValueWriter valueWriter; + private final SparkOrcWriter.FieldGetter keyFieldGetter; + private final SparkOrcWriter.FieldGetter valueFieldGetter; + + @SuppressWarnings("unchecked") + MapWriter(OrcValueWriter keyWriter, OrcValueWriter valueWriter, List orcTypes) { + if (orcTypes.size() != 2) { + throw new IllegalArgumentException("Expected two ORC type descriptions for a map, got: " + orcTypes); + } + this.keyWriter = keyWriter; + this.valueWriter = valueWriter; + this.keyFieldGetter = (SparkOrcWriter.FieldGetter) SparkOrcWriter.createFieldGetter(orcTypes.get(0)); + this.valueFieldGetter = (SparkOrcWriter.FieldGetter) SparkOrcWriter.createFieldGetter(orcTypes.get(1)); + } + + @Override + public void nonNullWrite(int rowId, MapData map, ColumnVector output) { + ArrayData key = map.keyArray(); + ArrayData value = map.valueArray(); + MapColumnVector cv = (MapColumnVector) output; + // record the length and start of the list elements + cv.lengths[rowId] = value.numElements(); + cv.offsets[rowId] = cv.childCount; + cv.childCount = (int) (cv.childCount + cv.lengths[rowId]); + // make sure the child is big enough + growColumnVector(cv.keys, cv.childCount); + growColumnVector(cv.values, cv.childCount); + // Add each element + for (int e = 0; e < cv.lengths[rowId]; ++e) { + int pos = (int) (e + cv.offsets[rowId]); + keyWriter.write(pos, keyFieldGetter.getFieldOrNull(key, e), cv.keys); + valueWriter.write(pos, valueFieldGetter.getFieldOrNull(value, e), cv.values); + } + } + + @Override + public Stream> metrics() { + return Stream.concat(keyWriter.metrics(), valueWriter.metrics()); + } + + } + + private static void growColumnVector(ColumnVector cv, int requestedSize) { + if (cv.isNull.length < requestedSize) { + // Use growth factor of 3 to avoid frequent array allocations + cv.ensureSize(requestedSize * 3, true); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java new file mode 100644 index 000000000000..9c7f3a6eb01d --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data; + +import java.io.Serializable; +import java.util.List; +import java.util.stream.Stream; +import javax.annotation.Nullable; +import org.apache.iceberg.FieldMetrics; +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.orc.GenericOrcWriters; +import org.apache.iceberg.orc.ORCSchemaUtil; +import org.apache.iceberg.orc.OrcRowWriter; +import org.apache.iceberg.orc.OrcSchemaWithTypeVisitor; +import org.apache.iceberg.orc.OrcValueWriter; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.orc.TypeDescription; +import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters; + +/** + * This class acts as an adaptor from an OrcFileAppender to a + * FileAppender<InternalRow>. + */ +public class SparkOrcWriter implements OrcRowWriter { + + private final InternalRowWriter writer; + + public SparkOrcWriter(Schema iSchema, TypeDescription orcSchema) { + Preconditions.checkArgument(orcSchema.getCategory() == TypeDescription.Category.STRUCT, + "Top level must be a struct " + orcSchema); + + writer = (InternalRowWriter) OrcSchemaWithTypeVisitor.visit(iSchema, orcSchema, new WriteBuilder()); + } + + @Override + public void write(InternalRow value, VectorizedRowBatch output) { + Preconditions.checkArgument(value != null, "value must not be null"); + writer.writeRow(value, output); + } + + @Override + public List> writers() { + return writer.writers(); + } + + @Override + public Stream> metrics() { + return writer.metrics(); + } + + private static class WriteBuilder extends OrcSchemaWithTypeVisitor> { + private WriteBuilder() { + } + + @Override + public OrcValueWriter record(Types.StructType iStruct, TypeDescription record, + List names, List> fields) { + return new InternalRowWriter(fields, record.getChildren()); + } + + @Override + public OrcValueWriter list(Types.ListType iList, TypeDescription array, + OrcValueWriter element) { + return SparkOrcValueWriters.list(element, array.getChildren()); + } + + @Override + public OrcValueWriter map(Types.MapType iMap, TypeDescription map, + OrcValueWriter key, OrcValueWriter value) { + return SparkOrcValueWriters.map(key, value, map.getChildren()); + } + + @Override + public OrcValueWriter primitive(Type.PrimitiveType iPrimitive, TypeDescription primitive) { + switch (primitive.getCategory()) { + case BOOLEAN: + return GenericOrcWriters.booleans(); + case BYTE: + return GenericOrcWriters.bytes(); + case SHORT: + return GenericOrcWriters.shorts(); + case DATE: + case INT: + return GenericOrcWriters.ints(); + case LONG: + return GenericOrcWriters.longs(); + case FLOAT: + return GenericOrcWriters.floats(ORCSchemaUtil.fieldId(primitive)); + case DOUBLE: + return GenericOrcWriters.doubles(ORCSchemaUtil.fieldId(primitive)); + case BINARY: + return GenericOrcWriters.byteArrays(); + case 
STRING:
+        case CHAR:
+        case VARCHAR:
+          return SparkOrcValueWriters.strings();
+        case DECIMAL:
+          return SparkOrcValueWriters.decimal(primitive.getPrecision(), primitive.getScale());
+        case TIMESTAMP_INSTANT:
+        case TIMESTAMP:
+          return SparkOrcValueWriters.timestampTz();
+        default:
+          throw new IllegalArgumentException("Unhandled type " + primitive);
+      }
+    }
+  }
+
+  private static class InternalRowWriter extends GenericOrcWriters.StructWriter<InternalRow> {
+    private final List<FieldGetter<?>> fieldGetters;
+
+    InternalRowWriter(List<OrcValueWriter<?>> writers, List<TypeDescription> orcTypes) {
+      super(writers);
+      this.fieldGetters = Lists.newArrayListWithExpectedSize(orcTypes.size());
+
+      for (int i = 0; i < orcTypes.size(); i++) {
+        fieldGetters.add(createFieldGetter(orcTypes.get(i)));
+      }
+    }
+
+    @Override
+    protected Object get(InternalRow struct, int index) {
+      return fieldGetters.get(index).getFieldOrNull(struct, index);
+    }
+  }
+
+  static FieldGetter<?> createFieldGetter(TypeDescription fieldType) {
+    final FieldGetter<?> fieldGetter;
+    switch (fieldType.getCategory()) {
+      case BOOLEAN:
+        fieldGetter = SpecializedGetters::getBoolean;
+        break;
+      case BYTE:
+        fieldGetter = SpecializedGetters::getByte;
+        break;
+      case SHORT:
+        fieldGetter = SpecializedGetters::getShort;
+        break;
+      case DATE:
+      case INT:
+        fieldGetter = SpecializedGetters::getInt;
+        break;
+      case LONG:
+      case TIMESTAMP:
+      case TIMESTAMP_INSTANT:
+        fieldGetter = SpecializedGetters::getLong;
+        break;
+      case FLOAT:
+        fieldGetter = SpecializedGetters::getFloat;
+        break;
+      case DOUBLE:
+        fieldGetter = SpecializedGetters::getDouble;
+        break;
+      case BINARY:
+        fieldGetter = SpecializedGetters::getBinary;
+        // getBinary always makes a copy, so we don't need to worry about it
+        // being changed behind our back.
+        break;
+      case DECIMAL:
+        fieldGetter = (row, ordinal) ->
+            row.getDecimal(ordinal, fieldType.getPrecision(), fieldType.getScale());
+        break;
+      case STRING:
+      case CHAR:
+      case VARCHAR:
+        fieldGetter = SpecializedGetters::getUTF8String;
+        break;
+      case STRUCT:
+        fieldGetter = (row, ordinal) -> row.getStruct(ordinal, fieldType.getChildren().size());
+        break;
+      case LIST:
+        fieldGetter = SpecializedGetters::getArray;
+        break;
+      case MAP:
+        fieldGetter = SpecializedGetters::getMap;
+        break;
+      default:
+        throw new IllegalArgumentException("Encountered an unsupported ORC type during a write from Spark.");
+    }
+
+    return (row, ordinal) -> {
+      if (row.isNullAt(ordinal)) {
+        return null;
+      }
+      return fieldGetter.getFieldOrNull(row, ordinal);
+    };
+  }
+
+  interface FieldGetter<T> extends Serializable {
+
+    /**
+     * Returns a value from a complex Spark data holder such as ArrayData, InternalRow, etc.
+     * Calls the appropriate getter for the expected data type.
+     * @param row Spark's data representation
+     * @param ordinal index in the data structure (e.g. column index for InternalRow, list index in ArrayData, etc.)
+     * @return field value at ordinal
+     */
+    @Nullable
+    T getFieldOrNull(SpecializedGetters row, int ordinal);
+  }
+}
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java
new file mode 100644
index 000000000000..35627ace23b0
--- /dev/null
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java
@@ -0,0 +1,755 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.parquet.ParquetSchemaUtil; +import org.apache.iceberg.parquet.ParquetValueReader; +import org.apache.iceberg.parquet.ParquetValueReaders; +import org.apache.iceberg.parquet.ParquetValueReaders.FloatAsDoubleReader; +import org.apache.iceberg.parquet.ParquetValueReaders.IntAsLongReader; +import org.apache.iceberg.parquet.ParquetValueReaders.PrimitiveReader; +import org.apache.iceberg.parquet.ParquetValueReaders.RepeatedKeyValueReader; +import org.apache.iceberg.parquet.ParquetValueReaders.RepeatedReader; +import org.apache.iceberg.parquet.ParquetValueReaders.ReusableEntry; +import org.apache.iceberg.parquet.ParquetValueReaders.StructReader; +import org.apache.iceberg.parquet.ParquetValueReaders.UnboxedReader; +import org.apache.iceberg.parquet.TypeWithSchemaVisitor; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Type.TypeID; +import org.apache.iceberg.types.Types; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.DecimalMetadata; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +public class SparkParquetReaders { + private SparkParquetReaders() { + } + + public static ParquetValueReader buildReader(Schema expectedSchema, + MessageType fileSchema) { + return buildReader(expectedSchema, fileSchema, ImmutableMap.of()); + } + + @SuppressWarnings("unchecked") + public static ParquetValueReader buildReader(Schema expectedSchema, + MessageType fileSchema, + Map idToConstant) { + if (ParquetSchemaUtil.hasIds(fileSchema)) { + return (ParquetValueReader) + TypeWithSchemaVisitor.visit(expectedSchema.asStruct(), fileSchema, + new 
ReadBuilder(fileSchema, idToConstant));
+    } else {
+      return (ParquetValueReader<InternalRow>)
+          TypeWithSchemaVisitor.visit(expectedSchema.asStruct(), fileSchema,
+              new FallbackReadBuilder(fileSchema, idToConstant));
+    }
+  }
+
+  private static class FallbackReadBuilder extends ReadBuilder {
+    FallbackReadBuilder(MessageType type, Map<Integer, ?> idToConstant) {
+      super(type, idToConstant);
+    }
+
+    @Override
+    public ParquetValueReader<?> message(Types.StructType expected, MessageType message,
+                                         List<ParquetValueReader<?>> fieldReaders) {
+      // the top level matches by ID, but the remaining IDs are missing
+      return super.struct(expected, message, fieldReaders);
+    }
+
+    @Override
+    public ParquetValueReader<?> struct(Types.StructType ignored, GroupType struct,
+                                        List<ParquetValueReader<?>> fieldReaders) {
+      // the expected struct is ignored because nested fields are never found when the
+      // file schema's IDs are missing; fields are matched by position instead
+      List<ParquetValueReader<?>> newFields = Lists.newArrayListWithExpectedSize(
+          fieldReaders.size());
+      List<Type> types = Lists.newArrayListWithExpectedSize(fieldReaders.size());
+      List<Type> fields = struct.getFields();
+      for (int i = 0; i < fields.size(); i += 1) {
+        Type fieldType = fields.get(i);
+        int fieldD = type().getMaxDefinitionLevel(path(fieldType.getName())) - 1;
+        newFields.add(ParquetValueReaders.option(fieldType, fieldD, fieldReaders.get(i)));
+        types.add(fieldType);
+      }
+
+      return new InternalRowReader(types, newFields);
+    }
+  }
+
+  private static class ReadBuilder extends TypeWithSchemaVisitor<ParquetValueReader<?>> {
+    private final MessageType type;
+    private final Map<Integer, ?> idToConstant;
+
+    ReadBuilder(MessageType type, Map<Integer, ?> idToConstant) {
+      this.type = type;
+      this.idToConstant = idToConstant;
+    }
+
+    @Override
+    public ParquetValueReader<?> message(Types.StructType expected, MessageType message,
+                                         List<ParquetValueReader<?>> fieldReaders) {
+      return struct(expected, message.asGroupType(), fieldReaders);
+    }
+
+    @Override
+    public ParquetValueReader<?> struct(Types.StructType expected, GroupType struct,
+                                        List<ParquetValueReader<?>> fieldReaders) {
+      // match the expected struct's order
+      Map<Integer, ParquetValueReader<?>> readersById = Maps.newHashMap();
+      Map<Integer, Type> typesById = Maps.newHashMap();
+      List<Type> fields = struct.getFields();
+      for (int i = 0; i < fields.size(); i += 1) {
+        Type fieldType = fields.get(i);
+        int fieldD = type.getMaxDefinitionLevel(path(fieldType.getName())) - 1;
+        if (fieldType.getId() != null) {
+          int id = fieldType.getId().intValue();
+          readersById.put(id, ParquetValueReaders.option(fieldType, fieldD, fieldReaders.get(i)));
+          typesById.put(id, fieldType);
+        }
+      }
+
+      List<Types.NestedField> expectedFields = expected != null ?
+ expected.fields() : ImmutableList.of(); + List> reorderedFields = Lists.newArrayListWithExpectedSize( + expectedFields.size()); + List types = Lists.newArrayListWithExpectedSize(expectedFields.size()); + for (Types.NestedField field : expectedFields) { + int id = field.fieldId(); + if (idToConstant.containsKey(id)) { + // containsKey is used because the constant may be null + reorderedFields.add(ParquetValueReaders.constant(idToConstant.get(id))); + types.add(null); + } else if (id == MetadataColumns.ROW_POSITION.fieldId()) { + reorderedFields.add(ParquetValueReaders.position()); + types.add(null); + } else if (id == MetadataColumns.IS_DELETED.fieldId()) { + reorderedFields.add(ParquetValueReaders.constant(false)); + types.add(null); + } else { + ParquetValueReader reader = readersById.get(id); + if (reader != null) { + reorderedFields.add(reader); + types.add(typesById.get(id)); + } else { + reorderedFields.add(ParquetValueReaders.nulls()); + types.add(null); + } + } + } + + return new InternalRowReader(types, reorderedFields); + } + + @Override + public ParquetValueReader list(Types.ListType expectedList, GroupType array, + ParquetValueReader elementReader) { + String[] repeatedPath = currentPath(); + + int repeatedD = type.getMaxDefinitionLevel(repeatedPath) - 1; + int repeatedR = type.getMaxRepetitionLevel(repeatedPath) - 1; + + Type elementType = ParquetSchemaUtil.determineListElementType(array); + int elementD = type.getMaxDefinitionLevel(path(elementType.getName())) - 1; + + return new ArrayReader<>(repeatedD, repeatedR, ParquetValueReaders.option(elementType, elementD, elementReader)); + } + + @Override + public ParquetValueReader map(Types.MapType expectedMap, GroupType map, + ParquetValueReader keyReader, + ParquetValueReader valueReader) { + GroupType repeatedKeyValue = map.getFields().get(0).asGroupType(); + String[] repeatedPath = currentPath(); + + int repeatedD = type.getMaxDefinitionLevel(repeatedPath) - 1; + int repeatedR = type.getMaxRepetitionLevel(repeatedPath) - 1; + + Type keyType = repeatedKeyValue.getType(0); + int keyD = type.getMaxDefinitionLevel(path(keyType.getName())) - 1; + Type valueType = repeatedKeyValue.getType(1); + int valueD = type.getMaxDefinitionLevel(path(valueType.getName())) - 1; + + return new MapReader<>(repeatedD, repeatedR, + ParquetValueReaders.option(keyType, keyD, keyReader), + ParquetValueReaders.option(valueType, valueD, valueReader)); + } + + @Override + public ParquetValueReader primitive(org.apache.iceberg.types.Type.PrimitiveType expected, + PrimitiveType primitive) { + ColumnDescriptor desc = type.getColumnDescription(currentPath()); + + if (primitive.getOriginalType() != null) { + switch (primitive.getOriginalType()) { + case ENUM: + case JSON: + case UTF8: + return new StringReader(desc); + case INT_8: + case INT_16: + case INT_32: + if (expected != null && expected.typeId() == Types.LongType.get().typeId()) { + return new IntAsLongReader(desc); + } else { + return new UnboxedReader(desc); + } + case DATE: + case INT_64: + case TIMESTAMP_MICROS: + return new UnboxedReader<>(desc); + case TIMESTAMP_MILLIS: + return new TimestampMillisReader(desc); + case DECIMAL: + DecimalMetadata decimal = primitive.getDecimalMetadata(); + switch (primitive.getPrimitiveTypeName()) { + case BINARY: + case FIXED_LEN_BYTE_ARRAY: + return new BinaryDecimalReader(desc, decimal.getScale()); + case INT64: + return new LongDecimalReader(desc, decimal.getPrecision(), decimal.getScale()); + case INT32: + return new IntegerDecimalReader(desc, 
decimal.getPrecision(), decimal.getScale()); + default: + throw new UnsupportedOperationException( + "Unsupported base type for decimal: " + primitive.getPrimitiveTypeName()); + } + case BSON: + return new ParquetValueReaders.ByteArrayReader(desc); + default: + throw new UnsupportedOperationException( + "Unsupported logical type: " + primitive.getOriginalType()); + } + } + + switch (primitive.getPrimitiveTypeName()) { + case FIXED_LEN_BYTE_ARRAY: + case BINARY: + return new ParquetValueReaders.ByteArrayReader(desc); + case INT32: + if (expected != null && expected.typeId() == TypeID.LONG) { + return new IntAsLongReader(desc); + } else { + return new UnboxedReader<>(desc); + } + case FLOAT: + if (expected != null && expected.typeId() == TypeID.DOUBLE) { + return new FloatAsDoubleReader(desc); + } else { + return new UnboxedReader<>(desc); + } + case BOOLEAN: + case INT64: + case DOUBLE: + return new UnboxedReader<>(desc); + case INT96: + // Impala & Spark used to write timestamps as INT96 without a logical type. For backwards + // compatibility we try to read INT96 as timestamps. + return new TimestampInt96Reader(desc); + default: + throw new UnsupportedOperationException("Unsupported type: " + primitive); + } + } + + protected MessageType type() { + return type; + } + } + + private static class BinaryDecimalReader extends PrimitiveReader { + private final int scale; + + BinaryDecimalReader(ColumnDescriptor desc, int scale) { + super(desc); + this.scale = scale; + } + + @Override + public Decimal read(Decimal ignored) { + Binary binary = column.nextBinary(); + return Decimal.fromDecimal(new BigDecimal(new BigInteger(binary.getBytes()), scale)); + } + } + + private static class IntegerDecimalReader extends PrimitiveReader { + private final int precision; + private final int scale; + + IntegerDecimalReader(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + } + + @Override + public Decimal read(Decimal ignored) { + return Decimal.apply(column.nextInteger(), precision, scale); + } + } + + private static class LongDecimalReader extends PrimitiveReader { + private final int precision; + private final int scale; + + LongDecimalReader(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + } + + @Override + public Decimal read(Decimal ignored) { + return Decimal.apply(column.nextLong(), precision, scale); + } + } + + private static class TimestampMillisReader extends UnboxedReader { + TimestampMillisReader(ColumnDescriptor desc) { + super(desc); + } + + @Override + public Long read(Long ignored) { + return readLong(); + } + + @Override + public long readLong() { + return 1000 * column.nextLong(); + } + } + + private static class TimestampInt96Reader extends UnboxedReader { + private static final long UNIX_EPOCH_JULIAN = 2_440_588L; + + TimestampInt96Reader(ColumnDescriptor desc) { + super(desc); + } + + @Override + public Long read(Long ignored) { + return readLong(); + } + + @Override + public long readLong() { + final ByteBuffer byteBuffer = column.nextBinary().toByteBuffer().order(ByteOrder.LITTLE_ENDIAN); + final long timeOfDayNanos = byteBuffer.getLong(); + final int julianDay = byteBuffer.getInt(); + + return TimeUnit.DAYS.toMicros(julianDay - UNIX_EPOCH_JULIAN) + + TimeUnit.NANOSECONDS.toMicros(timeOfDayNanos); + } + } + + private static class StringReader extends PrimitiveReader { + StringReader(ColumnDescriptor desc) { + super(desc); + } + + @Override + 
public UTF8String read(UTF8String ignored) { + Binary binary = column.nextBinary(); + ByteBuffer buffer = binary.toByteBuffer(); + if (buffer.hasArray()) { + return UTF8String.fromBytes( + buffer.array(), buffer.arrayOffset() + buffer.position(), buffer.remaining()); + } else { + return UTF8String.fromBytes(binary.getBytes()); + } + } + } + + private static class ArrayReader extends RepeatedReader { + private int readPos = 0; + private int writePos = 0; + + ArrayReader(int definitionLevel, int repetitionLevel, ParquetValueReader reader) { + super(definitionLevel, repetitionLevel, reader); + } + + @Override + @SuppressWarnings("unchecked") + protected ReusableArrayData newListData(ArrayData reuse) { + this.readPos = 0; + this.writePos = 0; + + if (reuse instanceof ReusableArrayData) { + return (ReusableArrayData) reuse; + } else { + return new ReusableArrayData(); + } + } + + @Override + @SuppressWarnings("unchecked") + protected E getElement(ReusableArrayData list) { + E value = null; + if (readPos < list.capacity()) { + value = (E) list.values[readPos]; + } + + readPos += 1; + + return value; + } + + @Override + protected void addElement(ReusableArrayData reused, E element) { + if (writePos >= reused.capacity()) { + reused.grow(); + } + + reused.values[writePos] = element; + + writePos += 1; + } + + @Override + protected ArrayData buildList(ReusableArrayData list) { + list.setNumElements(writePos); + return list; + } + } + + private static class MapReader extends RepeatedKeyValueReader { + private int readPos = 0; + private int writePos = 0; + + private final ReusableEntry entry = new ReusableEntry<>(); + private final ReusableEntry nullEntry = new ReusableEntry<>(); + + MapReader(int definitionLevel, int repetitionLevel, + ParquetValueReader keyReader, ParquetValueReader valueReader) { + super(definitionLevel, repetitionLevel, keyReader, valueReader); + } + + @Override + @SuppressWarnings("unchecked") + protected ReusableMapData newMapData(MapData reuse) { + this.readPos = 0; + this.writePos = 0; + + if (reuse instanceof ReusableMapData) { + return (ReusableMapData) reuse; + } else { + return new ReusableMapData(); + } + } + + @Override + @SuppressWarnings("unchecked") + protected Map.Entry getPair(ReusableMapData map) { + Map.Entry kv = nullEntry; + if (readPos < map.capacity()) { + entry.set((K) map.keys.values[readPos], (V) map.values.values[readPos]); + kv = entry; + } + + readPos += 1; + + return kv; + } + + @Override + protected void addPair(ReusableMapData map, K key, V value) { + if (writePos >= map.capacity()) { + map.grow(); + } + + map.keys.values[writePos] = key; + map.values.values[writePos] = value; + + writePos += 1; + } + + @Override + protected MapData buildMap(ReusableMapData map) { + map.setNumElements(writePos); + return map; + } + } + + private static class InternalRowReader extends StructReader { + private final int numFields; + + InternalRowReader(List types, List> readers) { + super(types, readers); + this.numFields = readers.size(); + } + + @Override + protected GenericInternalRow newStructData(InternalRow reuse) { + if (reuse instanceof GenericInternalRow) { + return (GenericInternalRow) reuse; + } else { + return new GenericInternalRow(numFields); + } + } + + @Override + protected Object getField(GenericInternalRow intermediate, int pos) { + return intermediate.genericGet(pos); + } + + @Override + protected InternalRow buildStruct(GenericInternalRow struct) { + return struct; + } + + @Override + protected void set(GenericInternalRow row, int pos, Object value) 
{ + row.update(pos, value); + } + + @Override + protected void setNull(GenericInternalRow row, int pos) { + row.setNullAt(pos); + } + + @Override + protected void setBoolean(GenericInternalRow row, int pos, boolean value) { + row.setBoolean(pos, value); + } + + @Override + protected void setInteger(GenericInternalRow row, int pos, int value) { + row.setInt(pos, value); + } + + @Override + protected void setLong(GenericInternalRow row, int pos, long value) { + row.setLong(pos, value); + } + + @Override + protected void setFloat(GenericInternalRow row, int pos, float value) { + row.setFloat(pos, value); + } + + @Override + protected void setDouble(GenericInternalRow row, int pos, double value) { + row.setDouble(pos, value); + } + } + + private static class ReusableMapData extends MapData { + private final ReusableArrayData keys; + private final ReusableArrayData values; + private int numElements; + + private ReusableMapData() { + this.keys = new ReusableArrayData(); + this.values = new ReusableArrayData(); + } + + private void grow() { + keys.grow(); + values.grow(); + } + + private int capacity() { + return keys.capacity(); + } + + public void setNumElements(int numElements) { + this.numElements = numElements; + keys.setNumElements(numElements); + values.setNumElements(numElements); + } + + @Override + public int numElements() { + return numElements; + } + + @Override + public MapData copy() { + return new ArrayBasedMapData(keyArray().copy(), valueArray().copy()); + } + + @Override + public ReusableArrayData keyArray() { + return keys; + } + + @Override + public ReusableArrayData valueArray() { + return values; + } + } + + private static class ReusableArrayData extends ArrayData { + private static final Object[] EMPTY = new Object[0]; + + private Object[] values = EMPTY; + private int numElements = 0; + + private void grow() { + if (values.length == 0) { + this.values = new Object[20]; + } else { + Object[] old = values; + this.values = new Object[old.length << 2]; + // copy the old array in case it has values that can be reused + System.arraycopy(old, 0, values, 0, old.length); + } + } + + private int capacity() { + return values.length; + } + + public void setNumElements(int numElements) { + this.numElements = numElements; + } + + @Override + public Object get(int ordinal, DataType dataType) { + return values[ordinal]; + } + + @Override + public int numElements() { + return numElements; + } + + @Override + public ArrayData copy() { + return new GenericArrayData(array()); + } + + @Override + public Object[] array() { + return Arrays.copyOfRange(values, 0, numElements); + } + + @Override + public void setNullAt(int i) { + values[i] = null; + } + + @Override + public void update(int ordinal, Object value) { + values[ordinal] = value; + } + + @Override + public boolean isNullAt(int ordinal) { + return null == values[ordinal]; + } + + @Override + public boolean getBoolean(int ordinal) { + return (boolean) values[ordinal]; + } + + @Override + public byte getByte(int ordinal) { + return (byte) values[ordinal]; + } + + @Override + public short getShort(int ordinal) { + return (short) values[ordinal]; + } + + @Override + public int getInt(int ordinal) { + return (int) values[ordinal]; + } + + @Override + public long getLong(int ordinal) { + return (long) values[ordinal]; + } + + @Override + public float getFloat(int ordinal) { + return (float) values[ordinal]; + } + + @Override + public double getDouble(int ordinal) { + return (double) values[ordinal]; + } + + @Override + public Decimal 
getDecimal(int ordinal, int precision, int scale) { + return (Decimal) values[ordinal]; + } + + @Override + public UTF8String getUTF8String(int ordinal) { + return (UTF8String) values[ordinal]; + } + + @Override + public byte[] getBinary(int ordinal) { + return (byte[]) values[ordinal]; + } + + @Override + public CalendarInterval getInterval(int ordinal) { + return (CalendarInterval) values[ordinal]; + } + + @Override + public InternalRow getStruct(int ordinal, int numFields) { + return (InternalRow) values[ordinal]; + } + + @Override + public ArrayData getArray(int ordinal) { + return (ArrayData) values[ordinal]; + } + + @Override + public MapData getMap(int ordinal) { + return (MapData) values[ordinal]; + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java new file mode 100644 index 000000000000..5e268d26ed9c --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java @@ -0,0 +1,423 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
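The TimestampInt96Reader in the file above converts the legacy 12-byte INT96 timestamp layout (nanos-of-day followed by Julian day, little-endian) into microseconds from the Unix epoch. A self-contained, JDK-only sketch of the same arithmetic with made-up values; the class name is illustrative and this is not part of the patch:

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.concurrent.TimeUnit;

public class Int96TimestampExample {
  private static final long UNIX_EPOCH_JULIAN = 2_440_588L; // Julian day of 1970-01-01

  public static void main(String[] args) {
    // an INT96 timestamp is 12 bytes: nanos-of-day (8 bytes) then Julian day (4 bytes), little-endian
    ByteBuffer int96 = ByteBuffer.allocate(12).order(ByteOrder.LITTLE_ENDIAN);
    int96.putLong(3_600_000_000_000L); // 01:00:00 expressed as nanos of day
    int96.putInt(2_440_589);           // Julian day of 1970-01-02
    int96.flip();

    long timeOfDayNanos = int96.getLong();
    int julianDay = int96.getInt();
    long micros = TimeUnit.DAYS.toMicros(julianDay - UNIX_EPOCH_JULIAN) +
        TimeUnit.NANOSECONDS.toMicros(timeOfDayNanos);

    // prints 1970-01-02T01:00:00Z
    System.out.println(Instant.EPOCH.plus(micros, ChronoUnit.MICROS));
  }
}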
+ */ + +package org.apache.iceberg.spark.data; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import org.apache.iceberg.parquet.ParquetValueReaders.ReusableEntry; +import org.apache.iceberg.parquet.ParquetValueWriter; +import org.apache.iceberg.parquet.ParquetValueWriters; +import org.apache.iceberg.parquet.ParquetValueWriters.PrimitiveWriter; +import org.apache.iceberg.parquet.ParquetValueWriters.RepeatedKeyValueWriter; +import org.apache.iceberg.parquet.ParquetValueWriters.RepeatedWriter; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.util.DecimalUtil; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.DecimalMetadata; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +public class SparkParquetWriters { + private SparkParquetWriters() { + } + + @SuppressWarnings("unchecked") + public static ParquetValueWriter buildWriter(StructType dfSchema, MessageType type) { + return (ParquetValueWriter) ParquetWithSparkSchemaVisitor.visit(dfSchema, type, new WriteBuilder(type)); + } + + private static class WriteBuilder extends ParquetWithSparkSchemaVisitor> { + private final MessageType type; + + WriteBuilder(MessageType type) { + this.type = type; + } + + @Override + public ParquetValueWriter message(StructType sStruct, MessageType message, + List> fieldWriters) { + return struct(sStruct, message.asGroupType(), fieldWriters); + } + + @Override + public ParquetValueWriter struct(StructType sStruct, GroupType struct, + List> fieldWriters) { + List fields = struct.getFields(); + StructField[] sparkFields = sStruct.fields(); + List> writers = Lists.newArrayListWithExpectedSize(fieldWriters.size()); + List sparkTypes = Lists.newArrayList(); + for (int i = 0; i < fields.size(); i += 1) { + writers.add(newOption(struct.getType(i), fieldWriters.get(i))); + sparkTypes.add(sparkFields[i].dataType()); + } + + return new InternalRowWriter(writers, sparkTypes); + } + + @Override + public ParquetValueWriter list(ArrayType sArray, GroupType array, ParquetValueWriter elementWriter) { + GroupType repeated = array.getFields().get(0).asGroupType(); + String[] repeatedPath = currentPath(); + + int repeatedD = type.getMaxDefinitionLevel(repeatedPath); + int repeatedR = type.getMaxRepetitionLevel(repeatedPath); + + return new ArrayDataWriter<>(repeatedD, repeatedR, + newOption(repeated.getType(0), elementWriter), + sArray.elementType()); + } + + @Override + public ParquetValueWriter map(MapType sMap, GroupType map, + ParquetValueWriter keyWriter, ParquetValueWriter valueWriter) { + GroupType repeatedKeyValue = 
map.getFields().get(0).asGroupType(); + String[] repeatedPath = currentPath(); + + int repeatedD = type.getMaxDefinitionLevel(repeatedPath); + int repeatedR = type.getMaxRepetitionLevel(repeatedPath); + + return new MapDataWriter<>(repeatedD, repeatedR, + newOption(repeatedKeyValue.getType(0), keyWriter), + newOption(repeatedKeyValue.getType(1), valueWriter), + sMap.keyType(), sMap.valueType()); + } + + private ParquetValueWriter newOption(Type fieldType, ParquetValueWriter writer) { + int maxD = type.getMaxDefinitionLevel(path(fieldType.getName())); + return ParquetValueWriters.option(fieldType, maxD, writer); + } + + @Override + public ParquetValueWriter primitive(DataType sType, PrimitiveType primitive) { + ColumnDescriptor desc = type.getColumnDescription(currentPath()); + + if (primitive.getOriginalType() != null) { + switch (primitive.getOriginalType()) { + case ENUM: + case JSON: + case UTF8: + return utf8Strings(desc); + case DATE: + case INT_8: + case INT_16: + case INT_32: + return ints(sType, desc); + case INT_64: + case TIME_MICROS: + case TIMESTAMP_MICROS: + return ParquetValueWriters.longs(desc); + case DECIMAL: + DecimalMetadata decimal = primitive.getDecimalMetadata(); + switch (primitive.getPrimitiveTypeName()) { + case INT32: + return decimalAsInteger(desc, decimal.getPrecision(), decimal.getScale()); + case INT64: + return decimalAsLong(desc, decimal.getPrecision(), decimal.getScale()); + case BINARY: + case FIXED_LEN_BYTE_ARRAY: + return decimalAsFixed(desc, decimal.getPrecision(), decimal.getScale()); + default: + throw new UnsupportedOperationException( + "Unsupported base type for decimal: " + primitive.getPrimitiveTypeName()); + } + case BSON: + return byteArrays(desc); + default: + throw new UnsupportedOperationException( + "Unsupported logical type: " + primitive.getOriginalType()); + } + } + + switch (primitive.getPrimitiveTypeName()) { + case FIXED_LEN_BYTE_ARRAY: + case BINARY: + return byteArrays(desc); + case BOOLEAN: + return ParquetValueWriters.booleans(desc); + case INT32: + return ints(sType, desc); + case INT64: + return ParquetValueWriters.longs(desc); + case FLOAT: + return ParquetValueWriters.floats(desc); + case DOUBLE: + return ParquetValueWriters.doubles(desc); + default: + throw new UnsupportedOperationException("Unsupported type: " + primitive); + } + } + } + + private static PrimitiveWriter ints(DataType type, ColumnDescriptor desc) { + if (type instanceof ByteType) { + return ParquetValueWriters.tinyints(desc); + } else if (type instanceof ShortType) { + return ParquetValueWriters.shorts(desc); + } + return ParquetValueWriters.ints(desc); + } + + private static PrimitiveWriter utf8Strings(ColumnDescriptor desc) { + return new UTF8StringWriter(desc); + } + + private static PrimitiveWriter decimalAsInteger(ColumnDescriptor desc, + int precision, int scale) { + return new IntegerDecimalWriter(desc, precision, scale); + } + + private static PrimitiveWriter decimalAsLong(ColumnDescriptor desc, + int precision, int scale) { + return new LongDecimalWriter(desc, precision, scale); + } + + private static PrimitiveWriter decimalAsFixed(ColumnDescriptor desc, + int precision, int scale) { + return new FixedDecimalWriter(desc, precision, scale); + } + + private static PrimitiveWriter byteArrays(ColumnDescriptor desc) { + return new ByteArrayWriter(desc); + } + + private static class UTF8StringWriter extends PrimitiveWriter { + private UTF8StringWriter(ColumnDescriptor desc) { + super(desc); + } + + @Override + public void write(int repetitionLevel, 
UTF8String value) { + column.writeBinary(repetitionLevel, Binary.fromReusedByteArray(value.getBytes())); + } + } + + private static class IntegerDecimalWriter extends PrimitiveWriter { + private final int precision; + private final int scale; + + private IntegerDecimalWriter(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + } + + @Override + public void write(int repetitionLevel, Decimal decimal) { + Preconditions.checkArgument(decimal.scale() == scale, + "Cannot write value as decimal(%s,%s), wrong scale: %s", precision, scale, decimal); + Preconditions.checkArgument(decimal.precision() <= precision, + "Cannot write value as decimal(%s,%s), too large: %s", precision, scale, decimal); + + column.writeInteger(repetitionLevel, (int) decimal.toUnscaledLong()); + } + } + + private static class LongDecimalWriter extends PrimitiveWriter { + private final int precision; + private final int scale; + + private LongDecimalWriter(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + } + + @Override + public void write(int repetitionLevel, Decimal decimal) { + Preconditions.checkArgument(decimal.scale() == scale, + "Cannot write value as decimal(%s,%s), wrong scale: %s", precision, scale, decimal); + Preconditions.checkArgument(decimal.precision() <= precision, + "Cannot write value as decimal(%s,%s), too large: %s", precision, scale, decimal); + + column.writeLong(repetitionLevel, decimal.toUnscaledLong()); + } + } + + private static class FixedDecimalWriter extends PrimitiveWriter { + private final int precision; + private final int scale; + private final ThreadLocal bytes; + + private FixedDecimalWriter(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + this.bytes = ThreadLocal.withInitial(() -> new byte[TypeUtil.decimalRequiredBytes(precision)]); + } + + @Override + public void write(int repetitionLevel, Decimal decimal) { + byte[] binary = DecimalUtil.toReusedFixLengthBytes(precision, scale, decimal.toJavaBigDecimal(), bytes.get()); + column.writeBinary(repetitionLevel, Binary.fromReusedByteArray(binary)); + } + } + + private static class ByteArrayWriter extends PrimitiveWriter { + private ByteArrayWriter(ColumnDescriptor desc) { + super(desc); + } + + @Override + public void write(int repetitionLevel, byte[] bytes) { + column.writeBinary(repetitionLevel, Binary.fromReusedByteArray(bytes)); + } + } + + private static class ArrayDataWriter extends RepeatedWriter { + private final DataType elementType; + + private ArrayDataWriter(int definitionLevel, int repetitionLevel, + ParquetValueWriter writer, DataType elementType) { + super(definitionLevel, repetitionLevel, writer); + this.elementType = elementType; + } + + @Override + protected Iterator elements(ArrayData list) { + return new ElementIterator<>(list); + } + + private class ElementIterator implements Iterator { + private final int size; + private final ArrayData list; + private int index; + + private ElementIterator(ArrayData list) { + this.list = list; + size = list.numElements(); + index = 0; + } + + @Override + public boolean hasNext() { + return index != size; + } + + @Override + @SuppressWarnings("unchecked") + public E next() { + if (index >= size) { + throw new NoSuchElementException(); + } + + E element; + if (list.isNullAt(index)) { + element = null; + } else { + element = (E) list.get(index, elementType); + } + + index += 1; + + 
return element; + } + } + } + + private static class MapDataWriter extends RepeatedKeyValueWriter { + private final DataType keyType; + private final DataType valueType; + + private MapDataWriter(int definitionLevel, int repetitionLevel, + ParquetValueWriter keyWriter, ParquetValueWriter valueWriter, + DataType keyType, DataType valueType) { + super(definitionLevel, repetitionLevel, keyWriter, valueWriter); + this.keyType = keyType; + this.valueType = valueType; + } + + @Override + protected Iterator> pairs(MapData map) { + return new EntryIterator<>(map); + } + + private class EntryIterator implements Iterator> { + private final int size; + private final ArrayData keys; + private final ArrayData values; + private final ReusableEntry entry; + private int index; + + private EntryIterator(MapData map) { + size = map.numElements(); + keys = map.keyArray(); + values = map.valueArray(); + entry = new ReusableEntry<>(); + index = 0; + } + + @Override + public boolean hasNext() { + return index != size; + } + + @Override + @SuppressWarnings("unchecked") + public Map.Entry next() { + if (index >= size) { + throw new NoSuchElementException(); + } + + if (values.isNullAt(index)) { + entry.set((K) keys.get(index, keyType), null); + } else { + entry.set((K) keys.get(index, keyType), (V) values.get(index, valueType)); + } + + index += 1; + + return entry; + } + } + } + + private static class InternalRowWriter extends ParquetValueWriters.StructWriter { + private final DataType[] types; + + private InternalRowWriter(List> writers, List types) { + super(writers); + this.types = types.toArray(new DataType[types.size()]); + } + + @Override + protected Object get(InternalRow struct, int index) { + return struct.get(index, types[index]); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java new file mode 100644 index 000000000000..0d3ce2b28d0b --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java @@ -0,0 +1,288 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
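The Avro readers in the file that follows (ArrayReader, ArrayMapReader, MapReader) consume collections in chunks via readArrayStart()/arrayNext() and readMapStart()/mapNext(). A self-contained sketch of that chunked decode loop against a plain binary Decoder, using only the standard Avro encoder/decoder API; the class name and values are made up, and this is not part of the patch:

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import org.apache.avro.io.BinaryEncoder;
import org.apache.avro.io.Decoder;
import org.apache.avro.io.DecoderFactory;
import org.apache.avro.io.EncoderFactory;

public class AvroChunkedArrayExample {
  public static void main(String[] args) throws IOException {
    // encode a small long array the way Avro writes collections: start, item count, items, end
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    BinaryEncoder encoder = EncoderFactory.get().binaryEncoder(out, null);
    encoder.writeArrayStart();
    encoder.setItemCount(3);
    for (long v : new long[] {7L, 8L, 9L}) {
      encoder.startItem();
      encoder.writeLong(v);
    }
    encoder.writeArrayEnd();
    encoder.flush();

    // decode it with the same chunked loop the readers below use
    Decoder decoder = DecoderFactory.get().binaryDecoder(out.toByteArray(), null);
    long chunkLength = decoder.readArrayStart();
    while (chunkLength > 0) {
      for (long i = 0; i < chunkLength; i += 1) {
        System.out.println(decoder.readLong()); // 7, 8, 9
      }
      chunkLength = decoder.arrayNext();
    }
  }
}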
+ */ + +package org.apache.iceberg.spark.data; + +import java.io.IOException; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; +import org.apache.avro.io.Decoder; +import org.apache.avro.util.Utf8; +import org.apache.iceberg.avro.ValueReader; +import org.apache.iceberg.avro.ValueReaders; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.UUIDUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; + +public class SparkValueReaders { + + private SparkValueReaders() { + } + + static ValueReader strings() { + return StringReader.INSTANCE; + } + + static ValueReader enums(List symbols) { + return new EnumReader(symbols); + } + + static ValueReader uuids() { + return UUIDReader.INSTANCE; + } + + static ValueReader decimal(ValueReader unscaledReader, int scale) { + return new DecimalReader(unscaledReader, scale); + } + + static ValueReader array(ValueReader elementReader) { + return new ArrayReader(elementReader); + } + + static ValueReader arrayMap(ValueReader keyReader, + ValueReader valueReader) { + return new ArrayMapReader(keyReader, valueReader); + } + + static ValueReader map(ValueReader keyReader, ValueReader valueReader) { + return new MapReader(keyReader, valueReader); + } + + static ValueReader struct(List> readers, Types.StructType struct, + Map idToConstant) { + return new StructReader(readers, struct, idToConstant); + } + + private static class StringReader implements ValueReader { + private static final StringReader INSTANCE = new StringReader(); + + private StringReader() { + } + + @Override + public UTF8String read(Decoder decoder, Object reuse) throws IOException { + // use the decoder's readString(Utf8) method because it may be a resolving decoder + Utf8 utf8 = null; + if (reuse instanceof UTF8String) { + utf8 = new Utf8(((UTF8String) reuse).getBytes()); + } + + Utf8 string = decoder.readString(utf8); + return UTF8String.fromBytes(string.getBytes(), 0, string.getByteLength()); +// int length = decoder.readInt(); +// byte[] bytes = new byte[length]; +// decoder.readFixed(bytes, 0, length); +// return UTF8String.fromBytes(bytes); + } + } + + private static class EnumReader implements ValueReader { + private final UTF8String[] symbols; + + private EnumReader(List symbols) { + this.symbols = new UTF8String[symbols.size()]; + for (int i = 0; i < this.symbols.length; i += 1) { + this.symbols[i] = UTF8String.fromBytes(symbols.get(i).getBytes(StandardCharsets.UTF_8)); + } + } + + @Override + public UTF8String read(Decoder decoder, Object ignore) throws IOException { + int index = decoder.readEnum(); + return symbols[index]; + } + } + + private static class UUIDReader implements ValueReader { + private static final ThreadLocal BUFFER = ThreadLocal.withInitial(() -> { + ByteBuffer buffer = ByteBuffer.allocate(16); + buffer.order(ByteOrder.BIG_ENDIAN); + return buffer; + }); + + private static final UUIDReader INSTANCE = new UUIDReader(); + + private UUIDReader() { + } + + @Override + 
@SuppressWarnings("ByteBufferBackingArray") + public UTF8String read(Decoder decoder, Object reuse) throws IOException { + ByteBuffer buffer = BUFFER.get(); + buffer.rewind(); + + decoder.readFixed(buffer.array(), 0, 16); + + return UTF8String.fromString(UUIDUtil.convert(buffer).toString()); + } + } + + private static class DecimalReader implements ValueReader { + private final ValueReader bytesReader; + private final int scale; + + private DecimalReader(ValueReader bytesReader, int scale) { + this.bytesReader = bytesReader; + this.scale = scale; + } + + @Override + public Decimal read(Decoder decoder, Object reuse) throws IOException { + byte[] bytes = bytesReader.read(decoder, null); + return Decimal.apply(new BigDecimal(new BigInteger(bytes), scale)); + } + } + + private static class ArrayReader implements ValueReader { + private final ValueReader elementReader; + private final List reusedList = Lists.newArrayList(); + + private ArrayReader(ValueReader elementReader) { + this.elementReader = elementReader; + } + + @Override + public GenericArrayData read(Decoder decoder, Object reuse) throws IOException { + reusedList.clear(); + long chunkLength = decoder.readArrayStart(); + + while (chunkLength > 0) { + for (int i = 0; i < chunkLength; i += 1) { + reusedList.add(elementReader.read(decoder, null)); + } + + chunkLength = decoder.arrayNext(); + } + + // this will convert the list to an array so it is okay to reuse the list + return new GenericArrayData(reusedList.toArray()); + } + } + + private static class ArrayMapReader implements ValueReader { + private final ValueReader keyReader; + private final ValueReader valueReader; + + private final List reusedKeyList = Lists.newArrayList(); + private final List reusedValueList = Lists.newArrayList(); + + private ArrayMapReader(ValueReader keyReader, ValueReader valueReader) { + this.keyReader = keyReader; + this.valueReader = valueReader; + } + + @Override + public ArrayBasedMapData read(Decoder decoder, Object reuse) throws IOException { + reusedKeyList.clear(); + reusedValueList.clear(); + + long chunkLength = decoder.readArrayStart(); + + while (chunkLength > 0) { + for (int i = 0; i < chunkLength; i += 1) { + reusedKeyList.add(keyReader.read(decoder, null)); + reusedValueList.add(valueReader.read(decoder, null)); + } + + chunkLength = decoder.arrayNext(); + } + + return new ArrayBasedMapData( + new GenericArrayData(reusedKeyList.toArray()), + new GenericArrayData(reusedValueList.toArray())); + } + } + + private static class MapReader implements ValueReader { + private final ValueReader keyReader; + private final ValueReader valueReader; + + private final List reusedKeyList = Lists.newArrayList(); + private final List reusedValueList = Lists.newArrayList(); + + private MapReader(ValueReader keyReader, ValueReader valueReader) { + this.keyReader = keyReader; + this.valueReader = valueReader; + } + + @Override + public ArrayBasedMapData read(Decoder decoder, Object reuse) throws IOException { + reusedKeyList.clear(); + reusedValueList.clear(); + + long chunkLength = decoder.readMapStart(); + + while (chunkLength > 0) { + for (int i = 0; i < chunkLength; i += 1) { + reusedKeyList.add(keyReader.read(decoder, null)); + reusedValueList.add(valueReader.read(decoder, null)); + } + + chunkLength = decoder.mapNext(); + } + + return new ArrayBasedMapData( + new GenericArrayData(reusedKeyList.toArray()), + new GenericArrayData(reusedValueList.toArray())); + } + } + + static class StructReader extends ValueReaders.StructReader { + private final int 
numFields; + + protected StructReader(List> readers, Types.StructType struct, Map idToConstant) { + super(readers, struct, idToConstant); + this.numFields = readers.size(); + } + + @Override + protected InternalRow reuseOrCreate(Object reuse) { + if (reuse instanceof GenericInternalRow && ((GenericInternalRow) reuse).numFields() == numFields) { + return (InternalRow) reuse; + } + return new GenericInternalRow(numFields); + } + + @Override + protected Object get(InternalRow struct, int pos) { + return null; + } + + @Override + protected void set(InternalRow struct, int pos, Object value) { + if (value != null) { + struct.update(pos, value); + } else { + struct.setNullAt(pos); + } + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueWriters.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueWriters.java new file mode 100644 index 000000000000..24a69c1d7f11 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueWriters.java @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
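The file that follows includes a UUIDWriter that stores a UUID as a 16-byte big-endian fixed value (the UUIDReader in the previous file converts it back to a string through UUIDUtil). A JDK-only sketch of that round trip with an arbitrary UUID; the class name is made up and this is not part of the patch:

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.UUID;

public class UuidFixed16Example {
  public static void main(String[] args) {
    UUID uuid = UUID.fromString("f79c3e09-677c-4bbd-a479-3f349cb785e7");

    // pack into the 16-byte big-endian fixed layout the writer produces
    ByteBuffer buffer = ByteBuffer.allocate(16).order(ByteOrder.BIG_ENDIAN);
    buffer.putLong(uuid.getMostSignificantBits());
    buffer.putLong(uuid.getLeastSignificantBits());

    // unpack it the way a reader would
    buffer.rewind();
    UUID roundTripped = new UUID(buffer.getLong(), buffer.getLong());
    System.out.println(uuid.equals(roundTripped)); // true
  }
}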
+ */ + +package org.apache.iceberg.spark.data; + +import java.io.IOException; +import java.lang.reflect.Array; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.List; +import java.util.UUID; +import org.apache.avro.io.Encoder; +import org.apache.avro.util.Utf8; +import org.apache.iceberg.avro.ValueWriter; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.util.DecimalUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; + +public class SparkValueWriters { + + private SparkValueWriters() { + } + + static ValueWriter strings() { + return StringWriter.INSTANCE; + } + + static ValueWriter uuids() { + return UUIDWriter.INSTANCE; + } + + static ValueWriter decimal(int precision, int scale) { + return new DecimalWriter(precision, scale); + } + + static ValueWriter array(ValueWriter elementWriter, DataType elementType) { + return new ArrayWriter<>(elementWriter, elementType); + } + + static ValueWriter arrayMap( + ValueWriter keyWriter, DataType keyType, ValueWriter valueWriter, DataType valueType) { + return new ArrayMapWriter<>(keyWriter, keyType, valueWriter, valueType); + } + + static ValueWriter map( + ValueWriter keyWriter, DataType keyType, ValueWriter valueWriter, DataType valueType) { + return new MapWriter<>(keyWriter, keyType, valueWriter, valueType); + } + + static ValueWriter struct(List> writers, List types) { + return new StructWriter(writers, types); + } + + private static class StringWriter implements ValueWriter { + private static final StringWriter INSTANCE = new StringWriter(); + + private StringWriter() { + } + + @Override + public void write(UTF8String s, Encoder encoder) throws IOException { + // use getBytes because it may return the backing byte array if available. 
+ // otherwise, it copies to a new byte array, which is still cheaper than Avro + // calling toString, which incurs encoding costs + encoder.writeString(new Utf8(s.getBytes())); + } + } + + private static class UUIDWriter implements ValueWriter { + private static final ThreadLocal BUFFER = ThreadLocal.withInitial(() -> { + ByteBuffer buffer = ByteBuffer.allocate(16); + buffer.order(ByteOrder.BIG_ENDIAN); + return buffer; + }); + + private static final UUIDWriter INSTANCE = new UUIDWriter(); + + private UUIDWriter() { + } + + @Override + @SuppressWarnings("ByteBufferBackingArray") + public void write(UTF8String s, Encoder encoder) throws IOException { + // TODO: direct conversion from string to byte buffer + UUID uuid = UUID.fromString(s.toString()); + ByteBuffer buffer = BUFFER.get(); + buffer.rewind(); + buffer.putLong(uuid.getMostSignificantBits()); + buffer.putLong(uuid.getLeastSignificantBits()); + encoder.writeFixed(buffer.array()); + } + } + + private static class DecimalWriter implements ValueWriter { + private final int precision; + private final int scale; + private final ThreadLocal bytes; + + private DecimalWriter(int precision, int scale) { + this.precision = precision; + this.scale = scale; + this.bytes = ThreadLocal.withInitial(() -> new byte[TypeUtil.decimalRequiredBytes(precision)]); + } + + @Override + public void write(Decimal d, Encoder encoder) throws IOException { + encoder.writeFixed(DecimalUtil.toReusedFixLengthBytes(precision, scale, d.toJavaBigDecimal(), bytes.get())); + } + } + + private static class ArrayWriter implements ValueWriter { + private final ValueWriter elementWriter; + private final DataType elementType; + + private ArrayWriter(ValueWriter elementWriter, DataType elementType) { + this.elementWriter = elementWriter; + this.elementType = elementType; + } + + @Override + @SuppressWarnings("unchecked") + public void write(ArrayData array, Encoder encoder) throws IOException { + encoder.writeArrayStart(); + int numElements = array.numElements(); + encoder.setItemCount(numElements); + for (int i = 0; i < numElements; i += 1) { + encoder.startItem(); + elementWriter.write((T) array.get(i, elementType), encoder); + } + encoder.writeArrayEnd(); + } + } + + private static class ArrayMapWriter implements ValueWriter { + private final ValueWriter keyWriter; + private final ValueWriter valueWriter; + private final DataType keyType; + private final DataType valueType; + + private ArrayMapWriter(ValueWriter keyWriter, DataType keyType, + ValueWriter valueWriter, DataType valueType) { + this.keyWriter = keyWriter; + this.keyType = keyType; + this.valueWriter = valueWriter; + this.valueType = valueType; + } + + @Override + @SuppressWarnings("unchecked") + public void write(MapData map, Encoder encoder) throws IOException { + encoder.writeArrayStart(); + int numElements = map.numElements(); + encoder.setItemCount(numElements); + ArrayData keyArray = map.keyArray(); + ArrayData valueArray = map.valueArray(); + for (int i = 0; i < numElements; i += 1) { + encoder.startItem(); + keyWriter.write((K) keyArray.get(i, keyType), encoder); + valueWriter.write((V) valueArray.get(i, valueType), encoder); + } + encoder.writeArrayEnd(); + } + } + + private static class MapWriter implements ValueWriter { + private final ValueWriter keyWriter; + private final ValueWriter valueWriter; + private final DataType keyType; + private final DataType valueType; + + private MapWriter(ValueWriter keyWriter, DataType keyType, + ValueWriter valueWriter, DataType valueType) { + this.keyWriter = 
keyWriter; + this.keyType = keyType; + this.valueWriter = valueWriter; + this.valueType = valueType; + } + + @Override + @SuppressWarnings("unchecked") + public void write(MapData map, Encoder encoder) throws IOException { + encoder.writeMapStart(); + int numElements = map.numElements(); + encoder.setItemCount(numElements); + ArrayData keyArray = map.keyArray(); + ArrayData valueArray = map.valueArray(); + for (int i = 0; i < numElements; i += 1) { + encoder.startItem(); + keyWriter.write((K) keyArray.get(i, keyType), encoder); + valueWriter.write((V) valueArray.get(i, valueType), encoder); + } + encoder.writeMapEnd(); + } + } + + static class StructWriter implements ValueWriter { + private final ValueWriter[] writers; + private final DataType[] types; + + @SuppressWarnings("unchecked") + private StructWriter(List> writers, List types) { + this.writers = (ValueWriter[]) Array.newInstance(ValueWriter.class, writers.size()); + this.types = new DataType[writers.size()]; + for (int i = 0; i < writers.size(); i += 1) { + this.writers[i] = writers.get(i); + this.types[i] = types.get(i); + } + } + + ValueWriter[] writers() { + return writers; + } + + @Override + public void write(InternalRow row, Encoder encoder) throws IOException { + for (int i = 0; i < types.length; i += 1) { + if (row.isNullAt(i)) { + writers[i].write(null, encoder); + } else { + write(row, i, writers[i], encoder); + } + } + } + + @SuppressWarnings("unchecked") + private void write(InternalRow row, int pos, ValueWriter writer, Encoder encoder) + throws IOException { + writer.write((T) row.get(pos, types[pos]), encoder); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessorFactory.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessorFactory.java new file mode 100644 index 000000000000..505ace508352 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessorFactory.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.data.vectorized; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.iceberg.arrow.vectorized.GenericArrowVectorAccessorFactory; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ArrowColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.unsafe.types.UTF8String; + +final class ArrowVectorAccessorFactory + extends GenericArrowVectorAccessorFactory { + + ArrowVectorAccessorFactory() { + super(DecimalFactoryImpl::new, + StringFactoryImpl::new, + StructChildFactoryImpl::new, + ArrayFactoryImpl::new); + } + + private static final class DecimalFactoryImpl implements DecimalFactory { + @Override + public Class getGenericClass() { + return Decimal.class; + } + + @Override + public Decimal ofLong(long value, int precision, int scale) { + return Decimal.apply(value, precision, scale); + } + + @Override + public Decimal ofBigDecimal(BigDecimal value, int precision, int scale) { + return Decimal.apply(value, precision, scale); + } + } + + private static final class StringFactoryImpl implements StringFactory { + @Override + public Class getGenericClass() { + return UTF8String.class; + } + + @Override + public UTF8String ofRow(VarCharVector vector, int rowId) { + int start = vector.getStartOffset(rowId); + int end = vector.getEndOffset(rowId); + + return UTF8String.fromAddress( + null, + vector.getDataBuffer().memoryAddress() + start, + end - start); + } + + @Override + public UTF8String ofBytes(byte[] bytes) { + return UTF8String.fromBytes(bytes); + } + + @Override + public UTF8String ofByteBuffer(ByteBuffer byteBuffer) { + if (byteBuffer.hasArray()) { + return UTF8String.fromBytes( + byteBuffer.array(), byteBuffer.arrayOffset() + byteBuffer.position(), byteBuffer.remaining()); + } + byte[] bytes = new byte[byteBuffer.remaining()]; + byteBuffer.get(bytes); + return UTF8String.fromBytes(bytes); + } + } + + private static final class ArrayFactoryImpl implements ArrayFactory { + @Override + public ArrowColumnVector ofChild(ValueVector childVector) { + return new ArrowColumnVector(childVector); + } + + @Override + public ColumnarArray ofRow(ValueVector vector, ArrowColumnVector childData, int rowId) { + ArrowBuf offsets = vector.getOffsetBuffer(); + int index = rowId * ListVector.OFFSET_WIDTH; + int start = offsets.getInt(index); + int end = offsets.getInt(index + ListVector.OFFSET_WIDTH); + return new ColumnarArray(childData, start, end - start); + } + } + + private static final class StructChildFactoryImpl implements StructChildFactory { + @Override + public Class getGenericClass() { + return ArrowColumnVector.class; + } + + @Override + public ArrowColumnVector of(ValueVector childVector) { + return new ArrowColumnVector(childVector); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessors.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessors.java new file mode 100644 index 000000000000..f3b3377af2b4 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessors.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.data.vectorized;
+
+import org.apache.iceberg.arrow.vectorized.ArrowVectorAccessor;
+import org.apache.iceberg.arrow.vectorized.VectorHolder;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.vectorized.ArrowColumnVector;
+import org.apache.spark.sql.vectorized.ColumnarArray;
+import org.apache.spark.unsafe.types.UTF8String;
+
+public class ArrowVectorAccessors {
+
+  private static final ArrowVectorAccessorFactory factory = new ArrowVectorAccessorFactory();
+
+  static ArrowVectorAccessor<Decimal, UTF8String, ColumnarArray, ArrowColumnVector>
+      getVectorAccessor(VectorHolder holder) {
+    return factory.getVectorAccessor(holder);
+  }
+
+  private ArrowVectorAccessors() {
+  }
+}
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnVectorWithFilter.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnVectorWithFilter.java
new file mode 100644
index 000000000000..7804c02f9042
--- /dev/null
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnVectorWithFilter.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ + +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.iceberg.arrow.vectorized.VectorHolder; +import org.apache.iceberg.arrow.vectorized.VectorHolder.ConstantVectorHolder; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.unsafe.types.UTF8String; + +public class ColumnVectorWithFilter extends IcebergArrowColumnVector { + private final int[] rowIdMapping; + + public ColumnVectorWithFilter(VectorHolder holder, int[] rowIdMapping) { + super(holder); + this.rowIdMapping = rowIdMapping; + } + + @Override + public boolean isNullAt(int rowId) { + return nullabilityHolder().isNullAt(rowIdMapping[rowId]) == 1; + } + + @Override + public boolean getBoolean(int rowId) { + return accessor().getBoolean(rowIdMapping[rowId]); + } + + @Override + public int getInt(int rowId) { + return accessor().getInt(rowIdMapping[rowId]); + } + + @Override + public long getLong(int rowId) { + return accessor().getLong(rowIdMapping[rowId]); + } + + @Override + public float getFloat(int rowId) { + return accessor().getFloat(rowIdMapping[rowId]); + } + + @Override + public double getDouble(int rowId) { + return accessor().getDouble(rowIdMapping[rowId]); + } + + @Override + public ColumnarArray getArray(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor().getArray(rowIdMapping[rowId]); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) { + return null; + } + return accessor().getDecimal(rowIdMapping[rowId], precision, scale); + } + + @Override + public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor().getUTF8String(rowIdMapping[rowId]); + } + + @Override + public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor().getBinary(rowIdMapping[rowId]); + } + + public static ColumnVector forHolder(VectorHolder holder, int[] rowIdMapping, int numRows) { + return holder.isDummy() ? + new ConstantColumnVector(Types.IntegerType.get(), numRows, ((ConstantVectorHolder) holder).getConstant()) : + new ColumnVectorWithFilter(holder, rowIdMapping); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnarBatchReader.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnarBatchReader.java new file mode 100644 index 000000000000..e18058130eb2 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnarBatchReader.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.data.vectorized; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.arrow.vectorized.BaseBatchReader; +import org.apache.iceberg.arrow.vectorized.VectorizedArrowReader; +import org.apache.iceberg.data.DeleteFilter; +import org.apache.iceberg.deletes.PositionDeleteIndex; +import org.apache.iceberg.parquet.VectorizedReader; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.util.Pair; +import org.apache.parquet.column.page.PageReadStore; +import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData; +import org.apache.parquet.hadoop.metadata.ColumnPath; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** + * {@link VectorizedReader} that returns Spark's {@link ColumnarBatch} to support Spark's vectorized read path. The + * {@link ColumnarBatch} returned is created by passing in the Arrow vectors populated via delegated read calls to + * {@linkplain VectorizedArrowReader VectorReader(s)}. + */ +public class ColumnarBatchReader extends BaseBatchReader { + private DeleteFilter deletes = null; + private long rowStartPosInBatch = 0; + + public ColumnarBatchReader(List> readers) { + super(readers); + } + + @Override + public void setRowGroupInfo(PageReadStore pageStore, Map metaData, + long rowPosition) { + super.setRowGroupInfo(pageStore, metaData, rowPosition); + this.rowStartPosInBatch = rowPosition; + } + + public void setDeleteFilter(DeleteFilter deleteFilter) { + this.deletes = deleteFilter; + } + + @Override + public final ColumnarBatch read(ColumnarBatch reuse, int numRowsToRead) { + if (reuse == null) { + closeVectors(); + } + + ColumnBatchLoader batchLoader = new ColumnBatchLoader(numRowsToRead); + rowStartPosInBatch += numRowsToRead; + return batchLoader.columnarBatch; + } + + private class ColumnBatchLoader { + private int[] rowIdMapping; // the rowId mapping to skip deleted rows for all column vectors inside a batch + private int numRows; + private ColumnarBatch columnarBatch; + + ColumnBatchLoader(int numRowsToRead) { + initRowIdMapping(numRowsToRead); + loadDataToColumnBatch(numRowsToRead); + } + + ColumnarBatch loadDataToColumnBatch(int numRowsToRead) { + Preconditions.checkArgument(numRowsToRead > 0, "Invalid number of rows to read: %s", numRowsToRead); + ColumnVector[] arrowColumnVectors = readDataToColumnVectors(numRowsToRead); + + columnarBatch = new ColumnarBatch(arrowColumnVectors); + columnarBatch.setNumRows(numRows); + + if (hasEqDeletes()) { + applyEqDelete(); + } + return columnarBatch; + } + + ColumnVector[] readDataToColumnVectors(int numRowsToRead) { + ColumnVector[] arrowColumnVectors = new ColumnVector[readers.length]; + + for (int i = 0; i < readers.length; i += 1) { + vectorHolders[i] = readers[i].read(vectorHolders[i], numRowsToRead); + int numRowsInVector = vectorHolders[i].numValues(); + Preconditions.checkState( + numRowsInVector == numRowsToRead, + "Number of rows in the vector %s didn't match expected %s ", numRowsInVector, + numRowsToRead); + + arrowColumnVectors[i] = hasDeletes() ? 
+ ColumnVectorWithFilter.forHolder(vectorHolders[i], rowIdMapping, numRows) : + IcebergArrowColumnVector.forHolder(vectorHolders[i], numRowsInVector); + } + return arrowColumnVectors; + } + + boolean hasDeletes() { + return rowIdMapping != null; + } + + boolean hasEqDeletes() { + return deletes != null && deletes.hasEqDeletes(); + } + + void initRowIdMapping(int numRowsToRead) { + Pair posDeleteRowIdMapping = posDelRowIdMapping(numRowsToRead); + if (posDeleteRowIdMapping != null) { + rowIdMapping = posDeleteRowIdMapping.first(); + numRows = posDeleteRowIdMapping.second(); + } else { + numRows = numRowsToRead; + rowIdMapping = initEqDeleteRowIdMapping(numRowsToRead); + } + } + + Pair posDelRowIdMapping(int numRowsToRead) { + if (deletes != null && deletes.hasPosDeletes()) { + return buildPosDelRowIdMapping(deletes.deletedRowPositions(), numRowsToRead); + } else { + return null; + } + } + + /** + * Build a row id mapping inside a batch, which skips deleted rows. Here is an example of how we delete 2 rows in a + * batch with 8 rows in total. + * [0,1,2,3,4,5,6,7] -- Original status of the row id mapping array + * Position delete 2, 6 + * [0,1,3,4,5,7,-,-] -- After applying position deletes [Set Num records to 6] + * + * @param deletedRowPositions a set of deleted row positions + * @param numRowsToRead the num of rows + * @return the mapping array and the new num of rows in a batch, null if no row is deleted + */ + Pair buildPosDelRowIdMapping(PositionDeleteIndex deletedRowPositions, int numRowsToRead) { + if (deletedRowPositions == null) { + return null; + } + + int[] posDelRowIdMapping = new int[numRowsToRead]; + int originalRowId = 0; + int currentRowId = 0; + while (originalRowId < numRowsToRead) { + if (!deletedRowPositions.isDeleted(originalRowId + rowStartPosInBatch)) { + posDelRowIdMapping[currentRowId] = originalRowId; + currentRowId++; + } + originalRowId++; + } + + if (currentRowId == numRowsToRead) { + // there is no delete in this batch + return null; + } else { + return Pair.of(posDelRowIdMapping, currentRowId); + } + } + + int[] initEqDeleteRowIdMapping(int numRowsToRead) { + int[] eqDeleteRowIdMapping = null; + if (hasEqDeletes()) { + eqDeleteRowIdMapping = new int[numRowsToRead]; + for (int i = 0; i < numRowsToRead; i++) { + eqDeleteRowIdMapping[i] = i; + } + } + return eqDeleteRowIdMapping; + } + + /** + * Filter out the equality deleted rows. 
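+   * The row id mapping built for position deletes (or the identity mapping initialized when only equality
+   * deletes are present) is compacted in place, and the batch row count is set to the number of rows that
+   * pass the equality-delete filter.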
Here is an example, + * [0,1,2,3,4,5,6,7] -- Original status of the row id mapping array + * Position delete 2, 6 + * [0,1,3,4,5,7,-,-] -- After applying position deletes [Set Num records to 6] + * Equality delete 1 <= x <= 3 + * [0,4,5,7,-,-,-,-] -- After applying equality deletes [Set Num records to 4] + */ + void applyEqDelete() { + Iterator it = columnarBatch.rowIterator(); + int rowId = 0; + int currentRowId = 0; + while (it.hasNext()) { + InternalRow row = it.next(); + if (deletes.eqDeletedRowFilter().test(row)) { + // the row is NOT deleted + // skip deleted rows by pointing to the next undeleted row Id + rowIdMapping[currentRowId] = rowIdMapping[rowId]; + currentRowId++; + } + + rowId++; + } + + columnarBatch.setNumRows(currentRowId); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ConstantColumnVector.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ConstantColumnVector.java new file mode 100644 index 000000000000..3cdea65b2877 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ConstantColumnVector.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Type; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +class ConstantColumnVector extends ColumnVector { + + private final Object constant; + private final int batchSize; + + ConstantColumnVector(Type type, int batchSize, Object constant) { + super(SparkSchemaUtil.convert(type)); + this.constant = constant; + this.batchSize = batchSize; + } + + @Override + public void close() { + } + + @Override + public boolean hasNull() { + return constant == null; + } + + @Override + public int numNulls() { + return constant == null ? 
batchSize : 0; + } + + @Override + public boolean isNullAt(int rowId) { + return constant == null; + } + + @Override + public boolean getBoolean(int rowId) { + return (boolean) constant; + } + + @Override + public byte getByte(int rowId) { + return (byte) constant; + } + + @Override + public short getShort(int rowId) { + return (short) constant; + } + + @Override + public int getInt(int rowId) { + return (int) constant; + } + + @Override + public long getLong(int rowId) { + return (long) constant; + } + + @Override + public float getFloat(int rowId) { + return (float) constant; + } + + @Override + public double getDouble(int rowId) { + return (double) constant; + } + + @Override + public ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException("ConstantColumnVector only supports primitives"); + } + + @Override + public ColumnarMap getMap(int ordinal) { + throw new UnsupportedOperationException("ConstantColumnVector only supports primitives"); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + return (Decimal) constant; + } + + @Override + public UTF8String getUTF8String(int rowId) { + return (UTF8String) constant; + } + + @Override + public byte[] getBinary(int rowId) { + return (byte[]) constant; + } + + @Override + public ColumnVector getChild(int ordinal) { + throw new UnsupportedOperationException("ConstantColumnVector only supports primitives"); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java new file mode 100644 index 000000000000..c5b7853d5353 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.iceberg.arrow.vectorized.ArrowVectorAccessor; +import org.apache.iceberg.arrow.vectorized.NullabilityHolder; +import org.apache.iceberg.arrow.vectorized.VectorHolder; +import org.apache.iceberg.arrow.vectorized.VectorHolder.ConstantVectorHolder; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ArrowColumnVector; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Implementation of Spark's {@link ColumnVector} interface. 
The code for this class is heavily inspired from Spark's + * {@link ArrowColumnVector} The main difference is in how nullability checks are made in this class by relying on + * {@link NullabilityHolder} instead of the validity vector in the Arrow vector. + */ +public class IcebergArrowColumnVector extends ColumnVector { + + private final ArrowVectorAccessor accessor; + private final NullabilityHolder nullabilityHolder; + + public IcebergArrowColumnVector(VectorHolder holder) { + super(SparkSchemaUtil.convert(holder.icebergType())); + this.nullabilityHolder = holder.nullabilityHolder(); + this.accessor = ArrowVectorAccessors.getVectorAccessor(holder); + } + + protected ArrowVectorAccessor accessor() { + return accessor; + } + + protected NullabilityHolder nullabilityHolder() { + return nullabilityHolder; + } + + @Override + public void close() { + accessor.close(); + } + + @Override + public boolean hasNull() { + return nullabilityHolder.hasNulls(); + } + + @Override + public int numNulls() { + return nullabilityHolder.numNulls(); + } + + @Override + public boolean isNullAt(int rowId) { + return nullabilityHolder.isNullAt(rowId) == 1; + } + + @Override + public boolean getBoolean(int rowId) { + return accessor.getBoolean(rowId); + } + + @Override + public byte getByte(int rowId) { + throw new UnsupportedOperationException("Unsupported type - byte"); + } + + @Override + public short getShort(int rowId) { + throw new UnsupportedOperationException("Unsupported type - short"); + } + + @Override + public int getInt(int rowId) { + return accessor.getInt(rowId); + } + + @Override + public long getLong(int rowId) { + return accessor.getLong(rowId); + } + + @Override + public float getFloat(int rowId) { + return accessor.getFloat(rowId); + } + + @Override + public double getDouble(int rowId) { + return accessor.getDouble(rowId); + } + + @Override + public ColumnarArray getArray(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor.getArray(rowId); + } + + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException("Unsupported type - map"); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) { + return null; + } + return accessor.getDecimal(rowId, precision, scale); + } + + @Override + public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor.getUTF8String(rowId); + } + + @Override + public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor.getBinary(rowId); + } + + @Override + public ArrowColumnVector getChild(int ordinal) { + return accessor.childColumn(ordinal); + } + + static ColumnVector forHolder(VectorHolder holder, int numRows) { + return holder.isDummy() ? + new ConstantColumnVector(Types.IntegerType.get(), numRows, ((ConstantVectorHolder) holder).getConstant()) : + new IcebergArrowColumnVector(holder); + } + + public ArrowVectorAccessor vectorAccessor() { + return accessor; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/RowPositionColumnVector.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/RowPositionColumnVector.java new file mode 100644 index 000000000000..58db4eb55d04 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/RowPositionColumnVector.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +public class RowPositionColumnVector extends ColumnVector { + + private final long batchOffsetInFile; + + RowPositionColumnVector(long batchOffsetInFile) { + super(SparkSchemaUtil.convert(Types.LongType.get())); + this.batchOffsetInFile = batchOffsetInFile; + } + + @Override + public void close() { + } + + @Override + public boolean hasNull() { + return false; + } + + @Override + public int numNulls() { + return 0; + } + + @Override + public boolean isNullAt(int rowId) { + return false; + } + + @Override + public boolean getBoolean(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte getByte(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public int getInt(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(int rowId) { + return batchOffsetInFile + rowId; + } + + @Override + public float getFloat(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarMap getMap(int ordinal) { + throw new UnsupportedOperationException(); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + throw new UnsupportedOperationException(); + } + + @Override + public UTF8String getUTF8String(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte[] getBinary(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnVector getChild(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java new file mode 100644 index 000000000000..418c25993a7e --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java @@ -0,0 +1,427 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data.vectorized; + +import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.orc.OrcBatchReader; +import org.apache.iceberg.orc.OrcSchemaWithTypeVisitor; +import org.apache.iceberg.orc.OrcValueReader; +import org.apache.iceberg.orc.OrcValueReaders; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.SparkOrcValueReaders; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.orc.TypeDescription; +import org.apache.orc.storage.ql.exec.vector.ListColumnVector; +import org.apache.orc.storage.ql.exec.vector.MapColumnVector; +import org.apache.orc.storage.ql.exec.vector.StructColumnVector; +import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +public class VectorizedSparkOrcReaders { + + private VectorizedSparkOrcReaders() { + } + + public static OrcBatchReader buildReader(Schema expectedSchema, TypeDescription fileSchema, + Map idToConstant) { + Converter converter = OrcSchemaWithTypeVisitor.visit(expectedSchema, fileSchema, new ReadBuilder(idToConstant)); + + return new OrcBatchReader() { + private long batchOffsetInFile; + + @Override + public ColumnarBatch read(VectorizedRowBatch batch) { + BaseOrcColumnVector cv = (BaseOrcColumnVector) converter.convert(new StructColumnVector(batch.size, batch.cols), + batch.size, batchOffsetInFile); + ColumnarBatch columnarBatch = new ColumnarBatch(IntStream.range(0, expectedSchema.columns().size()) + .mapToObj(cv::getChild) + .toArray(ColumnVector[]::new)); + columnarBatch.setNumRows(batch.size); + return columnarBatch; + } + + @Override + public void setBatchContext(long batchOffsetInFile) { + this.batchOffsetInFile = batchOffsetInFile; + } + }; + } + + private interface Converter { + ColumnVector convert(org.apache.orc.storage.ql.exec.vector.ColumnVector columnVector, int batchSize, + long batchOffsetInFile); + } + + private static class ReadBuilder extends OrcSchemaWithTypeVisitor { + private final Map idToConstant; + + private ReadBuilder(Map idToConstant) { + this.idToConstant = idToConstant; + } + + @Override + public Converter record(Types.StructType iStruct, TypeDescription record, List names, + List fields) { + return new StructConverter(iStruct, fields, idToConstant); + } + + @Override + public Converter 
list(Types.ListType iList, TypeDescription array, Converter element) { + return new ArrayConverter(iList, element); + } + + @Override + public Converter map(Types.MapType iMap, TypeDescription map, Converter key, Converter value) { + return new MapConverter(iMap, key, value); + } + + @Override + public Converter primitive(Type.PrimitiveType iPrimitive, TypeDescription primitive) { + final OrcValueReader primitiveValueReader; + switch (primitive.getCategory()) { + case BOOLEAN: + primitiveValueReader = OrcValueReaders.booleans(); + break; + case BYTE: + // Iceberg does not have a byte type. Use int + case SHORT: + // Iceberg does not have a short type. Use int + case DATE: + case INT: + primitiveValueReader = OrcValueReaders.ints(); + break; + case LONG: + primitiveValueReader = OrcValueReaders.longs(); + break; + case FLOAT: + primitiveValueReader = OrcValueReaders.floats(); + break; + case DOUBLE: + primitiveValueReader = OrcValueReaders.doubles(); + break; + case TIMESTAMP_INSTANT: + case TIMESTAMP: + primitiveValueReader = SparkOrcValueReaders.timestampTzs(); + break; + case DECIMAL: + primitiveValueReader = SparkOrcValueReaders.decimals(primitive.getPrecision(), primitive.getScale()); + break; + case CHAR: + case VARCHAR: + case STRING: + primitiveValueReader = SparkOrcValueReaders.utf8String(); + break; + case BINARY: + primitiveValueReader = OrcValueReaders.bytes(); + break; + default: + throw new IllegalArgumentException("Unhandled type " + primitive); + } + return (columnVector, batchSize, batchOffsetInFile) -> + new PrimitiveOrcColumnVector(iPrimitive, batchSize, columnVector, primitiveValueReader, batchOffsetInFile); + } + } + + private abstract static class BaseOrcColumnVector extends ColumnVector { + private final org.apache.orc.storage.ql.exec.vector.ColumnVector vector; + private final int batchSize; + private Integer numNulls; + + BaseOrcColumnVector(Type type, int batchSize, org.apache.orc.storage.ql.exec.vector.ColumnVector vector) { + super(SparkSchemaUtil.convert(type)); + this.vector = vector; + this.batchSize = batchSize; + } + + @Override + public void close() { + } + + @Override + public boolean hasNull() { + return !vector.noNulls; + } + + @Override + public int numNulls() { + if (numNulls == null) { + numNulls = numNullsHelper(); + } + return numNulls; + } + + private int numNullsHelper() { + if (vector.isRepeating) { + if (vector.isNull[0]) { + return batchSize; + } else { + return 0; + } + } else if (vector.noNulls) { + return 0; + } else { + int count = 0; + for (int i = 0; i < batchSize; i++) { + if (vector.isNull[i]) { + count++; + } + } + return count; + } + } + + protected int getRowIndex(int rowId) { + return vector.isRepeating ? 
0 : rowId; + } + + @Override + public boolean isNullAt(int rowId) { + return vector.isNull[getRowIndex(rowId)]; + } + + @Override + public boolean getBoolean(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte getByte(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public int getInt(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public float getFloat(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + throw new UnsupportedOperationException(); + } + + @Override + public UTF8String getUTF8String(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte[] getBinary(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnVector getChild(int ordinal) { + throw new UnsupportedOperationException(); + } + } + + private static class PrimitiveOrcColumnVector extends BaseOrcColumnVector { + private final org.apache.orc.storage.ql.exec.vector.ColumnVector vector; + private final OrcValueReader primitiveValueReader; + private final long batchOffsetInFile; + + PrimitiveOrcColumnVector(Type type, int batchSize, org.apache.orc.storage.ql.exec.vector.ColumnVector vector, + OrcValueReader primitiveValueReader, long batchOffsetInFile) { + super(type, batchSize, vector); + this.vector = vector; + this.primitiveValueReader = primitiveValueReader; + this.batchOffsetInFile = batchOffsetInFile; + } + + @Override + public boolean getBoolean(int rowId) { + return (Boolean) primitiveValueReader.read(vector, rowId); + } + + @Override + public int getInt(int rowId) { + return (Integer) primitiveValueReader.read(vector, rowId); + } + + @Override + public long getLong(int rowId) { + return (Long) primitiveValueReader.read(vector, rowId); + } + + @Override + public float getFloat(int rowId) { + return (Float) primitiveValueReader.read(vector, rowId); + } + + @Override + public double getDouble(int rowId) { + return (Double) primitiveValueReader.read(vector, rowId); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + // TODO: Is it okay to assume that (precision,scale) parameters == (precision,scale) of the decimal type + // and return a Decimal with (precision,scale) of the decimal type? 
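+      // primitiveValueReader was created with this column's ORC precision and scale in ReadBuilder#primitive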
+ return (Decimal) primitiveValueReader.read(vector, rowId); + } + + @Override + public UTF8String getUTF8String(int rowId) { + return (UTF8String) primitiveValueReader.read(vector, rowId); + } + + @Override + public byte[] getBinary(int rowId) { + return (byte[]) primitiveValueReader.read(vector, rowId); + } + } + + private static class ArrayConverter implements Converter { + private final Types.ListType listType; + private final Converter elementConverter; + + private ArrayConverter(Types.ListType listType, Converter elementConverter) { + this.listType = listType; + this.elementConverter = elementConverter; + } + + @Override + public ColumnVector convert(org.apache.orc.storage.ql.exec.vector.ColumnVector vector, int batchSize, + long batchOffsetInFile) { + ListColumnVector listVector = (ListColumnVector) vector; + ColumnVector elementVector = elementConverter.convert(listVector.child, batchSize, batchOffsetInFile); + + return new BaseOrcColumnVector(listType, batchSize, vector) { + @Override + public ColumnarArray getArray(int rowId) { + int index = getRowIndex(rowId); + return new ColumnarArray(elementVector, (int) listVector.offsets[index], (int) listVector.lengths[index]); + } + }; + } + } + + private static class MapConverter implements Converter { + private final Types.MapType mapType; + private final Converter keyConverter; + private final Converter valueConverter; + + private MapConverter(Types.MapType mapType, Converter keyConverter, Converter valueConverter) { + this.mapType = mapType; + this.keyConverter = keyConverter; + this.valueConverter = valueConverter; + } + + @Override + public ColumnVector convert(org.apache.orc.storage.ql.exec.vector.ColumnVector vector, int batchSize, + long batchOffsetInFile) { + MapColumnVector mapVector = (MapColumnVector) vector; + ColumnVector keyVector = keyConverter.convert(mapVector.keys, batchSize, batchOffsetInFile); + ColumnVector valueVector = valueConverter.convert(mapVector.values, batchSize, batchOffsetInFile); + + return new BaseOrcColumnVector(mapType, batchSize, vector) { + @Override + public ColumnarMap getMap(int rowId) { + int index = getRowIndex(rowId); + return new ColumnarMap(keyVector, valueVector, (int) mapVector.offsets[index], + (int) mapVector.lengths[index]); + } + }; + } + } + + private static class StructConverter implements Converter { + private final Types.StructType structType; + private final List fieldConverters; + private final Map idToConstant; + + private StructConverter(Types.StructType structType, List fieldConverters, + Map idToConstant) { + this.structType = structType; + this.fieldConverters = fieldConverters; + this.idToConstant = idToConstant; + } + + @Override + public ColumnVector convert(org.apache.orc.storage.ql.exec.vector.ColumnVector vector, int batchSize, + long batchOffsetInFile) { + StructColumnVector structVector = (StructColumnVector) vector; + List fields = structType.fields(); + List fieldVectors = Lists.newArrayListWithExpectedSize(fields.size()); + for (int pos = 0, vectorIndex = 0; pos < fields.size(); pos += 1) { + Types.NestedField field = fields.get(pos); + if (idToConstant.containsKey(field.fieldId())) { + fieldVectors.add(new ConstantColumnVector(field.type(), batchSize, idToConstant.get(field.fieldId()))); + } else if (field.equals(MetadataColumns.ROW_POSITION)) { + fieldVectors.add(new RowPositionColumnVector(batchOffsetInFile)); + } else if (field.equals(MetadataColumns.IS_DELETED)) { + fieldVectors.add(new ConstantColumnVector(field.type(), batchSize, false)); + } else { + 
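+          // data column: convert the matching ORC field vector; vectorIndex advances only for fields that are
+          // neither constants nor metadata columns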
fieldVectors.add(fieldConverters.get(vectorIndex) + .convert(structVector.fields[vectorIndex], batchSize, batchOffsetInFile)); + vectorIndex++; + } + } + + return new BaseOrcColumnVector(structType, batchSize, vector) { + @Override + public ColumnVector getChild(int ordinal) { + return fieldVectors.get(ordinal); + } + }; + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkParquetReaders.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkParquetReaders.java new file mode 100644 index 000000000000..020b35f52844 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkParquetReaders.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data.vectorized; + +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import org.apache.iceberg.Schema; +import org.apache.iceberg.arrow.vectorized.VectorizedReaderBuilder; +import org.apache.iceberg.data.DeleteFilter; +import org.apache.iceberg.parquet.TypeWithSchemaVisitor; +import org.apache.iceberg.parquet.VectorizedReader; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.parquet.schema.MessageType; +import org.apache.spark.sql.catalyst.InternalRow; + +public class VectorizedSparkParquetReaders { + + private VectorizedSparkParquetReaders() { + } + + public static ColumnarBatchReader buildReader( + Schema expectedSchema, + MessageType fileSchema, + boolean setArrowValidityVector) { + return buildReader(expectedSchema, fileSchema, setArrowValidityVector, Maps.newHashMap()); + } + + public static ColumnarBatchReader buildReader( + Schema expectedSchema, + MessageType fileSchema, + boolean setArrowValidityVector, + Map idToConstant) { + return (ColumnarBatchReader) + TypeWithSchemaVisitor.visit(expectedSchema.asStruct(), fileSchema, + new VectorizedReaderBuilder( + expectedSchema, fileSchema, setArrowValidityVector, + idToConstant, ColumnarBatchReader::new)); + } + + public static ColumnarBatchReader buildReader(Schema expectedSchema, + MessageType fileSchema, + boolean setArrowValidityVector, + Map idToConstant, + DeleteFilter deleteFilter) { + return (ColumnarBatchReader) + TypeWithSchemaVisitor.visit(expectedSchema.asStruct(), fileSchema, + new ReaderBuilder( + expectedSchema, fileSchema, setArrowValidityVector, + idToConstant, ColumnarBatchReader::new, deleteFilter)); + } + + private static class ReaderBuilder extends VectorizedReaderBuilder { + private final DeleteFilter deleteFilter; + + ReaderBuilder(Schema expectedSchema, + MessageType parquetSchema, + boolean setArrowValidityVector, + Map idToConstant, + Function>, 
VectorizedReader> readerFactory, + DeleteFilter deleteFilter) { + super(expectedSchema, parquetSchema, setArrowValidityVector, idToConstant, readerFactory); + this.deleteFilter = deleteFilter; + } + + @Override + protected VectorizedReader vectorizedReader(List> reorderedFields) { + VectorizedReader reader = super.vectorizedReader(reorderedFields); + if (deleteFilter != null) { + ((ColumnarBatchReader) reader).setDeleteFilter(deleteFilter); + } + return reader; + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/AddFilesProcedure.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/AddFilesProcedure.java new file mode 100644 index 000000000000..a902887f2965 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/AddFilesProcedure.java @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.procedures; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.PartitionField; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.mapping.MappingUtil; +import org.apache.iceberg.mapping.NameMapping; +import org.apache.iceberg.mapping.NameMappingParser; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.SparkTableUtil.SparkPartition; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import scala.runtime.BoxedUnit; + +class AddFilesProcedure extends BaseProcedure { + + private static final Joiner.MapJoiner MAP_JOINER = Joiner.on(",").withKeyValueSeparator("="); + + private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[]{ + ProcedureParameter.required("table", 
DataTypes.StringType), + ProcedureParameter.required("source_table", DataTypes.StringType), + ProcedureParameter.optional("partition_filter", STRING_MAP), + ProcedureParameter.optional("check_duplicate_files", DataTypes.BooleanType) + }; + + private static final StructType OUTPUT_TYPE = new StructType(new StructField[]{ + new StructField("added_files_count", DataTypes.LongType, false, Metadata.empty()) + }); + + private AddFilesProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + public static SparkProcedures.ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected AddFilesProcedure doBuild() { + return new AddFilesProcedure(tableCatalog()); + } + }; + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + + CatalogPlugin sessionCat = spark().sessionState().catalogManager().v2SessionCatalog(); + Identifier sourceIdent = toCatalogAndIdentifier(args.getString(1), PARAMETERS[1].name(), sessionCat).identifier(); + + Map partitionFilter = Maps.newHashMap(); + if (!args.isNullAt(2)) { + args.getMap(2).foreach(DataTypes.StringType, DataTypes.StringType, + (k, v) -> { + partitionFilter.put(k.toString(), v.toString()); + return BoxedUnit.UNIT; + }); + } + + boolean checkDuplicateFiles; + if (args.isNullAt(3)) { + checkDuplicateFiles = true; + } else { + checkDuplicateFiles = args.getBoolean(3); + } + + long addedFilesCount = importToIceberg(tableIdent, sourceIdent, partitionFilter, checkDuplicateFiles); + return new InternalRow[]{newInternalRow(addedFilesCount)}; + } + + private boolean isFileIdentifier(Identifier ident) { + String[] namespace = ident.namespace(); + return namespace.length == 1 && + (namespace[0].equalsIgnoreCase("orc") || + namespace[0].equalsIgnoreCase("parquet") || + namespace[0].equalsIgnoreCase("avro")); + } + + private long importToIceberg(Identifier destIdent, Identifier sourceIdent, Map partitionFilter, + boolean checkDuplicateFiles) { + return modifyIcebergTable(destIdent, table -> { + + validatePartitionSpec(table, partitionFilter); + ensureNameMappingPresent(table); + + if (isFileIdentifier(sourceIdent)) { + Path sourcePath = new Path(sourceIdent.name()); + String format = sourceIdent.namespace()[0]; + importFileTable(table, sourcePath, format, partitionFilter, checkDuplicateFiles); + } else { + importCatalogTable(table, sourceIdent, partitionFilter, checkDuplicateFiles); + } + + Snapshot snapshot = table.currentSnapshot(); + return Long.parseLong(snapshot.summary().getOrDefault(SnapshotSummary.ADDED_FILES_PROP, "0")); + }); + } + + private static void ensureNameMappingPresent(Table table) { + if (table.properties().get(TableProperties.DEFAULT_NAME_MAPPING) == null) { + // Forces Name based resolution instead of position based resolution + NameMapping mapping = MappingUtil.create(table.schema()); + String mappingJson = NameMappingParser.toJson(mapping); + table.updateProperties() + .set(TableProperties.DEFAULT_NAME_MAPPING, mappingJson) + .commit(); + } + } + + private void importFileTable(Table table, Path tableLocation, String format, Map partitionFilter, + boolean checkDuplicateFiles) { + // List Partitions via Spark InMemory file search interface + List partitions = + Spark3Util.getPartitions(spark(), tableLocation, format, partitionFilter); + + if 
(table.spec().isUnpartitioned()) { + Preconditions.checkArgument(partitions.isEmpty(), "Cannot add partitioned files to an unpartitioned table"); + Preconditions.checkArgument(partitionFilter.isEmpty(), "Cannot use a partition filter when importing" + + "to an unpartitioned table"); + + // Build a Global Partition for the source + SparkPartition partition = new SparkPartition(Collections.emptyMap(), tableLocation.toString(), format); + importPartitions(table, ImmutableList.of(partition), checkDuplicateFiles); + } else { + Preconditions.checkArgument(!partitions.isEmpty(), + "Cannot find any matching partitions in table %s", partitions); + importPartitions(table, partitions, checkDuplicateFiles); + } + } + + private void importCatalogTable(Table table, Identifier sourceIdent, Map partitionFilter, + boolean checkDuplicateFiles) { + String stagingLocation = getMetadataLocation(table); + TableIdentifier sourceTableIdentifier = Spark3Util.toV1TableIdentifier(sourceIdent); + SparkTableUtil.importSparkTable(spark(), sourceTableIdentifier, table, stagingLocation, partitionFilter, + checkDuplicateFiles); + } + + private void importPartitions(Table table, List partitions, + boolean checkDuplicateFiles) { + String stagingLocation = getMetadataLocation(table); + SparkTableUtil.importSparkPartitions(spark(), partitions, table, table.spec(), stagingLocation, + checkDuplicateFiles); + } + + private String getMetadataLocation(Table table) { + String defaultValue = table.location() + "/metadata"; + return table.properties().getOrDefault(TableProperties.WRITE_METADATA_LOCATION, defaultValue); + } + + @Override + public String description() { + return "AddFiles"; + } + + private void validatePartitionSpec(Table table, Map partitionFilter) { + List partitionFields = table.spec().fields(); + Set partitionNames = table.spec().fields().stream().map(PartitionField::name).collect(Collectors.toSet()); + + boolean tablePartitioned = !partitionFields.isEmpty(); + boolean partitionSpecPassed = !partitionFilter.isEmpty(); + + // Check for any non-identity partition columns + List nonIdentityFields = partitionFields.stream() + .filter(x -> !x.transform().isIdentity()) + .collect(Collectors.toList()); + Preconditions.checkArgument(nonIdentityFields.isEmpty(), + "Cannot add data files to target table %s because that table is partitioned and contains non-identity" + + "partition transforms which will not be compatible. Found non-identity fields %s", + table.name(), nonIdentityFields); + + if (tablePartitioned && partitionSpecPassed) { + // Check to see there are sufficient partition columns to satisfy the filter + Preconditions.checkArgument(partitionFields.size() >= partitionFilter.size(), + "Cannot add data files to target table %s because that table is partitioned, " + + "but the number of columns in the provided partition filter (%s) " + + "is greater than the number of partitioned columns in table (%s)", + table.name(), partitionFilter.size(), partitionFields.size()); + + // Check for any filters of non existent columns + List unMatchedFilters = partitionFilter.keySet().stream() + .filter(filterName -> !partitionNames.contains(filterName)) + .collect(Collectors.toList()); + Preconditions.checkArgument(unMatchedFilters.isEmpty(), + "Cannot add files to target table %s. %s is partitioned but the specified partition filter " + + "refers to columns that are not partitioned: '%s' . 
Valid partition columns %s", + table.name(), table.name(), unMatchedFilters, String.join(",", partitionNames)); + } else { + Preconditions.checkArgument(!partitionSpecPassed, + "Cannot use partition filter with an unpartitioned table %s", + table.name()); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/AncestorsOfProcedure.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/AncestorsOfProcedure.java new file mode 100644 index 000000000000..bfbca05a5744 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/AncestorsOfProcedure.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.procedures; + +import java.util.List; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class AncestorsOfProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.optional("snapshot_id", DataTypes.LongType), + }; + + private static final StructType OUTPUT_TYPE = new StructType(new StructField[] { + new StructField("snapshot_id", DataTypes.LongType, true, Metadata.empty()), + new StructField("timestamp", DataTypes.LongType, true, Metadata.empty()) + }); + + private AncestorsOfProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + public static SparkProcedures.ProcedureBuilder builder() { + return new Builder() { + @Override + protected AncestorsOfProcedure doBuild() { + return new AncestorsOfProcedure(tableCatalog()); + } + }; + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + Long toSnapshotId = args.isNullAt(1) ? 
null : args.getLong(1); + + SparkTable sparkTable = loadSparkTable(tableIdent); + Table icebergTable = sparkTable.table(); + + if (toSnapshotId == null) { + toSnapshotId = icebergTable.currentSnapshot() != null ? icebergTable.currentSnapshot().snapshotId() : -1; + } + + List snapshotIds = Lists.newArrayList( + SnapshotUtil.ancestorIdsBetween(toSnapshotId, null, icebergTable::snapshot)); + + return toOutputRow(icebergTable, snapshotIds); + } + + @Override + public String description() { + return "AncestorsOf"; + } + + private InternalRow[] toOutputRow(Table table, List snapshotIds) { + if (snapshotIds.isEmpty()) { + return new InternalRow[0]; + } + + InternalRow[] internalRows = new InternalRow[snapshotIds.size()]; + for (int i = 0; i < snapshotIds.size(); i++) { + Long snapshotId = snapshotIds.get(i); + internalRows[i] = newInternalRow(snapshotId, table.snapshot(snapshotId).timestampMillis()); + } + + return internalRows; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/BaseProcedure.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/BaseProcedure.java new file mode 100644 index 000000000000..3bb936e32dcb --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/BaseProcedure.java @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.procedures; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.function.Function; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.Spark3Util.CatalogAndIdentifier; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.Procedure; +import org.apache.spark.sql.execution.CacheManager; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import scala.Option; + +abstract class BaseProcedure implements Procedure { + protected static final DataType STRING_MAP = DataTypes.createMapType(DataTypes.StringType, DataTypes.StringType); + + private final SparkSession spark; + private final TableCatalog tableCatalog; + + private SparkActions actions; + private ExecutorService executorService = null; + + protected BaseProcedure(TableCatalog tableCatalog) { + this.spark = SparkSession.active(); + this.tableCatalog = tableCatalog; + } + + protected SparkSession spark() { + return this.spark; + } + + protected SparkActions actions() { + if (actions == null) { + this.actions = SparkActions.get(spark); + } + return actions; + } + + protected TableCatalog tableCatalog() { + return this.tableCatalog; + } + + protected T modifyIcebergTable(Identifier ident, Function func) { + try { + return execute(ident, true, func); + } finally { + closeService(); + } + } + + protected T withIcebergTable(Identifier ident, Function func) { + try { + return execute(ident, false, func); + } finally { + closeService(); + } + } + + private T execute(Identifier ident, boolean refreshSparkCache, Function func) { + SparkTable sparkTable = loadSparkTable(ident); + org.apache.iceberg.Table icebergTable = sparkTable.table(); + + T result = func.apply(icebergTable); + + if (refreshSparkCache) { + refreshSparkCache(ident, sparkTable); + } + + return result; + } + + protected Identifier toIdentifier(String identifierAsString, String argName) { + CatalogAndIdentifier catalogAndIdentifier = toCatalogAndIdentifier(identifierAsString, argName, tableCatalog); + + Preconditions.checkArgument( + catalogAndIdentifier.catalog().equals(tableCatalog), + "Cannot run procedure in catalog '%s': '%s' is a table in catalog '%s'", + tableCatalog.name(), identifierAsString, catalogAndIdentifier.catalog().name()); + + return catalogAndIdentifier.identifier(); + } + + protected CatalogAndIdentifier toCatalogAndIdentifier(String 
identifierAsString, String argName, + CatalogPlugin catalog) { + Preconditions.checkArgument(identifierAsString != null && !identifierAsString.isEmpty(), + "Cannot handle an empty identifier for argument %s", argName); + + return Spark3Util.catalogAndIdentifier("identifier for arg " + argName, spark, identifierAsString, catalog); + } + + protected SparkTable loadSparkTable(Identifier ident) { + try { + Table table = tableCatalog.loadTable(ident); + ValidationException.check(table instanceof SparkTable, "%s is not %s", ident, SparkTable.class.getName()); + return (SparkTable) table; + } catch (NoSuchTableException e) { + String errMsg = String.format("Couldn't load table '%s' in catalog '%s'", ident, tableCatalog.name()); + throw new RuntimeException(errMsg, e); + } + } + + protected void refreshSparkCache(Identifier ident, Table table) { + CacheManager cacheManager = spark.sharedState().cacheManager(); + DataSourceV2Relation relation = DataSourceV2Relation.create(table, Option.apply(tableCatalog), Option.apply(ident)); + cacheManager.recacheByPlan(spark, relation); + } + + protected InternalRow newInternalRow(Object... values) { + return new GenericInternalRow(values); + } + + protected abstract static class Builder implements ProcedureBuilder { + private TableCatalog tableCatalog; + + @Override + public Builder withTableCatalog(TableCatalog newTableCatalog) { + this.tableCatalog = newTableCatalog; + return this; + } + + @Override + public T build() { + return doBuild(); + } + + protected abstract T doBuild(); + + TableCatalog tableCatalog() { + return tableCatalog; + } + } + + /** + * Closes this procedure's executor service if a new one was created with {@link #executorService(int, String)}. Does + * not block for any remaining tasks. + */ + protected void closeService() { + if (executorService != null) { + executorService.shutdown(); + } + } + + /** + * Starts a new executor service which can be used by this procedure in its work. The pool will be automatically + * shut down if {@link #withIcebergTable(Identifier, Function)} or {@link #modifyIcebergTable(Identifier, Function)} + * are called. If these methods are not used then the service can be shut down with {@link #closeService()} or left + * to be closed when this class is finalized. 
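+ * <p>
+ * For example, a procedure can pass the pool to an action, roughly
+ * {@code action.executeDeleteWith(executorService(4, "remove-orphans"))}; the pool size and name
+ * prefix here are illustrative values.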
+ * @param threadPoolSize number of threads in the service + * @param nameFormat name prefix for threads created in this service + * @return the new executor service owned by this procedure + */ + protected ExecutorService executorService(int threadPoolSize, String nameFormat) { + Preconditions.checkArgument(executorService == null, "Cannot create a new executor service, one already exists."); + Preconditions.checkArgument(nameFormat != null, "Cannot create a service with null nameFormat arg"); + this.executorService = MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool( + threadPoolSize, + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat(nameFormat + "-%d") + .build())); + + return executorService; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/CherrypickSnapshotProcedure.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/CherrypickSnapshotProcedure.java new file mode 100644 index 000000000000..6ce40f4ad5c5 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/CherrypickSnapshotProcedure.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A procedure that applies changes in a given snapshot and creates a new snapshot which will + * be set as the current snapshot in a table. + *
<p>
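+ * A minimal usage sketch in Spark SQL, assuming the procedure is registered under the name
+ * {@code cherrypick_snapshot}; the catalog {@code cat}, table {@code db.sample} and snapshot id
+ * are illustrative values:
+ * <pre>
+ *   CALL cat.system.cherrypick_snapshot(table => 'db.sample', snapshot_id => 1)
+ * </pre>
+ * <p>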
+ * Note: this procedure invalidates all cached Spark plans that reference the affected table. + * + * @see org.apache.iceberg.ManageSnapshots#cherrypick(long) + */ +class CherrypickSnapshotProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[]{ + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.required("snapshot_id", DataTypes.LongType) + }; + + private static final StructType OUTPUT_TYPE = new StructType(new StructField[]{ + new StructField("source_snapshot_id", DataTypes.LongType, false, Metadata.empty()), + new StructField("current_snapshot_id", DataTypes.LongType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected CherrypickSnapshotProcedure doBuild() { + return new CherrypickSnapshotProcedure(tableCatalog()); + } + }; + } + + private CherrypickSnapshotProcedure(TableCatalog catalog) { + super(catalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + long snapshotId = args.getLong(1); + + return modifyIcebergTable(tableIdent, table -> { + table.manageSnapshots() + .cherrypick(snapshotId) + .commit(); + + Snapshot currentSnapshot = table.currentSnapshot(); + + InternalRow outputRow = newInternalRow(snapshotId, currentSnapshot.snapshotId()); + return new InternalRow[]{outputRow}; + }); + } + + @Override + public String description() { + return "CherrypickSnapshotProcedure"; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/ExpireSnapshotsProcedure.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/ExpireSnapshotsProcedure.java new file mode 100644 index 000000000000..042fa6e4bb04 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/ExpireSnapshotsProcedure.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.ExpireSnapshots; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.actions.BaseExpireSnapshotsSparkAction; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.iceberg.util.DateTimeUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A procedure that expires snapshots in a table. + * + * @see SparkActions#expireSnapshots(Table) + */ +public class ExpireSnapshotsProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.optional("older_than", DataTypes.TimestampType), + ProcedureParameter.optional("retain_last", DataTypes.IntegerType), + ProcedureParameter.optional("max_concurrent_deletes", DataTypes.IntegerType), + ProcedureParameter.optional("stream_results", DataTypes.BooleanType) + }; + + private static final StructType OUTPUT_TYPE = new StructType(new StructField[]{ + new StructField("deleted_data_files_count", DataTypes.LongType, true, Metadata.empty()), + new StructField("deleted_position_delete_files_count", DataTypes.LongType, true, Metadata.empty()), + new StructField("deleted_equality_delete_files_count", DataTypes.LongType, true, Metadata.empty()), + new StructField("deleted_manifest_files_count", DataTypes.LongType, true, Metadata.empty()), + new StructField("deleted_manifest_lists_count", DataTypes.LongType, true, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected ExpireSnapshotsProcedure doBuild() { + return new ExpireSnapshotsProcedure(tableCatalog()); + } + }; + } + + private ExpireSnapshotsProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + Long olderThanMillis = args.isNullAt(1) ? null : DateTimeUtil.microsToMillis(args.getLong(1)); + Integer retainLastNum = args.isNullAt(2) ? null : args.getInt(2); + Integer maxConcurrentDeletes = args.isNullAt(3) ? null : args.getInt(3); + Boolean streamResult = args.isNullAt(4) ? 
null : args.getBoolean(4); + + Preconditions.checkArgument(maxConcurrentDeletes == null || maxConcurrentDeletes > 0, + "max_concurrent_deletes should have value > 0, value: " + maxConcurrentDeletes); + + return modifyIcebergTable(tableIdent, table -> { + ExpireSnapshots action = actions().expireSnapshots(table); + + if (olderThanMillis != null) { + action.expireOlderThan(olderThanMillis); + } + + if (retainLastNum != null) { + action.retainLast(retainLastNum); + } + + if (maxConcurrentDeletes != null && maxConcurrentDeletes > 0) { + action.executeDeleteWith(executorService(maxConcurrentDeletes, "expire-snapshots")); + } + + if (streamResult != null) { + action.option(BaseExpireSnapshotsSparkAction.STREAM_RESULTS, Boolean.toString(streamResult)); + } + + ExpireSnapshots.Result result = action.execute(); + + return toOutputRows(result); + }); + } + + private InternalRow[] toOutputRows(ExpireSnapshots.Result result) { + InternalRow row = newInternalRow( + result.deletedDataFilesCount(), + result.deletedPositionDeleteFilesCount(), + result.deletedEqualityDeleteFilesCount(), + result.deletedManifestsCount(), + result.deletedManifestListsCount() + ); + return new InternalRow[]{row}; + } + + @Override + public String description() { + return "ExpireSnapshotProcedure"; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java new file mode 100644 index 000000000000..2f6841924f8c --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.procedures; + +import java.util.Map; +import org.apache.iceberg.actions.MigrateTable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import scala.runtime.BoxedUnit; + +class MigrateTableProcedure extends BaseProcedure { + private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[]{ + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.optional("properties", STRING_MAP) + }; + + private static final StructType OUTPUT_TYPE = new StructType(new StructField[]{ + new StructField("migrated_files_count", DataTypes.LongType, false, Metadata.empty()) + }); + + private MigrateTableProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected MigrateTableProcedure doBuild() { + return new MigrateTableProcedure(tableCatalog()); + } + }; + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + String tableName = args.getString(0); + Preconditions.checkArgument(tableName != null && !tableName.isEmpty(), + "Cannot handle an empty identifier for argument table"); + + Map properties = Maps.newHashMap(); + if (!args.isNullAt(1)) { + args.getMap(1).foreach(DataTypes.StringType, DataTypes.StringType, + (k, v) -> { + properties.put(k.toString(), v.toString()); + return BoxedUnit.UNIT; + }); + } + + MigrateTable.Result result = SparkActions.get().migrateTable(tableName).tableProperties(properties).execute(); + return new InternalRow[] {newInternalRow(result.migratedDataFilesCount())}; + } + + @Override + public String description() { + return "MigrateTableProcedure"; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/RegisterTableProcedure.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/RegisterTableProcedure.java new file mode 100644 index 000000000000..d7f1f56d9b57 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/RegisterTableProcedure.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.iceberg.spark.source.HasIcebergCatalog; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +class RegisterTableProcedure extends BaseProcedure { + private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[]{ + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.required("metadata_file", DataTypes.StringType) + }; + + private static final StructType OUTPUT_TYPE = new StructType(new StructField[]{ + new StructField("Current SnapshotId", DataTypes.LongType, true, Metadata.empty()), + new StructField("Rows", DataTypes.LongType, true, Metadata.empty()), + new StructField("Datafiles", DataTypes.LongType, true, Metadata.empty()) + }); + + private RegisterTableProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected RegisterTableProcedure doBuild() { + return new RegisterTableProcedure(tableCatalog()); + } + }; + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + TableIdentifier tableName = Spark3Util.identifierToTableIdentifier(toIdentifier(args.getString(0), "table")); + String metadataFile = args.getString(1); + Preconditions.checkArgument(tableCatalog() instanceof HasIcebergCatalog, + "Cannot use Register Table in a non-Iceberg catalog"); + Preconditions.checkArgument(metadataFile != null && !metadataFile.isEmpty(), + "Cannot handle an empty argument metadata_file"); + + Catalog icebergCatalog = ((HasIcebergCatalog) tableCatalog()).icebergCatalog(); + Table table = icebergCatalog.registerTable(tableName, metadataFile); + Long currentSnapshotId = null; + Long totalDataFiles = null; + Long totalRecords = null; + + Snapshot currentSnapshot = table.currentSnapshot(); + if (currentSnapshot != null) { + currentSnapshotId = currentSnapshot.snapshotId(); + totalDataFiles = Long.parseLong(currentSnapshot.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP)); + totalRecords = Long.parseLong(currentSnapshot.summary().get(SnapshotSummary.TOTAL_RECORDS_PROP)); + } + + return new InternalRow[] {newInternalRow(currentSnapshotId, totalRecords, totalDataFiles)}; + } + + @Override + public String description() { + return "RegisterTableProcedure"; + } +} + diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/RemoveOrphanFilesProcedure.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/RemoveOrphanFilesProcedure.java new file mode 
100644 index 000000000000..23d64719081c --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/RemoveOrphanFilesProcedure.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.procedures; + +import java.util.concurrent.TimeUnit; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.DeleteOrphanFiles; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.actions.BaseDeleteOrphanFilesSparkAction; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.iceberg.util.DateTimeUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A procedure that removes orphan files in a table. 
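+ * <p>
+ * A minimal usage sketch in Spark SQL, assuming the procedure is registered under the name
+ * {@code remove_orphan_files}; the catalog {@code cat} and table {@code db.sample} are illustrative:
+ * <pre>
+ *   CALL cat.system.remove_orphan_files(table => 'db.sample', dry_run => true)
+ * </pre>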
+ * + * @see SparkActions#deleteOrphanFiles(Table) + */ +public class RemoveOrphanFilesProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[]{ + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.optional("older_than", DataTypes.TimestampType), + ProcedureParameter.optional("location", DataTypes.StringType), + ProcedureParameter.optional("dry_run", DataTypes.BooleanType), + ProcedureParameter.optional("max_concurrent_deletes", DataTypes.IntegerType), + ProcedureParameter.optional("file_list_view", DataTypes.StringType) + }; + + private static final StructType OUTPUT_TYPE = new StructType(new StructField[]{ + new StructField("orphan_file_location", DataTypes.StringType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected RemoveOrphanFilesProcedure doBuild() { + return new RemoveOrphanFilesProcedure(tableCatalog()); + } + }; + } + + private RemoveOrphanFilesProcedure(TableCatalog catalog) { + super(catalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + @SuppressWarnings("checkstyle:CyclomaticComplexity") + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + Long olderThanMillis = args.isNullAt(1) ? null : DateTimeUtil.microsToMillis(args.getLong(1)); + String location = args.isNullAt(2) ? null : args.getString(2); + boolean dryRun = args.isNullAt(3) ? false : args.getBoolean(3); + Integer maxConcurrentDeletes = args.isNullAt(4) ? null : args.getInt(4); + String fileListView = args.isNullAt(5) ? 
null : args.getString(5); + + Preconditions.checkArgument(maxConcurrentDeletes == null || maxConcurrentDeletes > 0, + "max_concurrent_deletes should have value > 0, value: " + maxConcurrentDeletes); + + return withIcebergTable(tableIdent, table -> { + DeleteOrphanFiles action = actions().deleteOrphanFiles(table); + + if (olderThanMillis != null) { + boolean isTesting = Boolean.parseBoolean(spark().conf().get("spark.testing", "false")); + if (!isTesting) { + validateInterval(olderThanMillis); + } + action.olderThan(olderThanMillis); + } + + if (location != null) { + action.location(location); + } + + if (dryRun) { + action.deleteWith(file -> { }); + } + + if (maxConcurrentDeletes != null && maxConcurrentDeletes > 0) { + action.executeDeleteWith(executorService(maxConcurrentDeletes, "remove-orphans")); + } + + if (fileListView != null) { + ((BaseDeleteOrphanFilesSparkAction) action).compareToFileList(spark().table(fileListView)); + } + + DeleteOrphanFiles.Result result = action.execute(); + + return toOutputRows(result); + }); + } + + private InternalRow[] toOutputRows(DeleteOrphanFiles.Result result) { + Iterable orphanFileLocations = result.orphanFileLocations(); + + int orphanFileLocationsCount = Iterables.size(orphanFileLocations); + InternalRow[] rows = new InternalRow[orphanFileLocationsCount]; + + int index = 0; + for (String fileLocation : orphanFileLocations) { + rows[index] = newInternalRow(UTF8String.fromString(fileLocation)); + index++; + } + + return rows; + } + + private void validateInterval(long olderThanMillis) { + long intervalMillis = System.currentTimeMillis() - olderThanMillis; + if (intervalMillis < TimeUnit.DAYS.toMillis(1)) { + throw new IllegalArgumentException( + "Cannot remove orphan files with an interval less than 24 hours. Executing this " + + "procedure with a short interval may corrupt the table if other operations are happening " + + "at the same time. If you are absolutely confident that no concurrent operations will be " + + "affected by removing orphan files with such a short interval, you can use the Action API " + + "to remove orphan files with an arbitrary interval."); + } + } + + @Override + public String description() { + return "RemoveOrphanFilesProcedure"; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteDataFilesProcedure.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteDataFilesProcedure.java new file mode 100644 index 000000000000..1f6b34fd2918 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteDataFilesProcedure.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.procedures; + +import java.util.Map; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.RewriteDataFiles; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.SetWriteDistributionAndOrdering; +import org.apache.spark.sql.catalyst.plans.logical.SortOrderParserUtil; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.execution.datasources.SparkExpressionConverter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import scala.runtime.BoxedUnit; + +/** + * A procedure that rewrites datafiles in a table. + * + * @see org.apache.iceberg.spark.actions.SparkActions#rewriteDataFiles(Table, String) + */ +class RewriteDataFilesProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[]{ + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.optional("strategy", DataTypes.StringType), + ProcedureParameter.optional("sort_order", DataTypes.StringType), + ProcedureParameter.optional("options", STRING_MAP), + ProcedureParameter.optional("where", DataTypes.StringType) + }; + + // counts are not nullable since the action result is never null + private static final StructType OUTPUT_TYPE = new StructType(new StructField[]{ + new StructField("rewritten_data_files_count", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("added_data_files_count", DataTypes.IntegerType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new Builder() { + @Override + protected RewriteDataFilesProcedure doBuild() { + return new RewriteDataFilesProcedure(tableCatalog()); + } + }; + } + + private RewriteDataFilesProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + + return modifyIcebergTable(tableIdent, table -> { + String quotedFullIdentifier = Spark3Util.quotedFullIdentifier(tableCatalog().name(), tableIdent); + RewriteDataFiles action = actions().rewriteDataFiles(table, quotedFullIdentifier); + + String strategy = args.isNullAt(1) ? null : args.getString(1); + String sortOrderString = args.isNullAt(2) ? 
null : args.getString(2); + SortOrder sortOrder = null; + if (sortOrderString != null) { + sortOrder = collectSortOrders(table, sortOrderString); + } + if (strategy != null || sortOrder != null) { + action = checkAndApplyStrategy(action, strategy, sortOrder); + } + + if (!args.isNullAt(3)) { + action = checkAndApplyOptions(args, action); + } + + String where = args.isNullAt(4) ? null : args.getString(4); + + action = checkAndApplyFilter(action, where, quotedFullIdentifier); + + RewriteDataFiles.Result result = action.execute(); + + return toOutputRows(result); + }); + } + + private RewriteDataFiles checkAndApplyFilter(RewriteDataFiles action, String where, String tableName) { + if (where != null) { + try { + Expression expression = SparkExpressionConverter.collectResolvedSparkExpression(spark(), tableName, where); + return action.filter(SparkExpressionConverter.convertToIcebergExpression(expression)); + } catch (AnalysisException e) { + throw new IllegalArgumentException("Cannot parse predicates in where option: " + where); + } + } + return action; + } + + private RewriteDataFiles checkAndApplyOptions(InternalRow args, RewriteDataFiles action) { + Map options = Maps.newHashMap(); + args.getMap(3).foreach(DataTypes.StringType, DataTypes.StringType, + (k, v) -> { + options.put(k.toString(), v.toString()); + return BoxedUnit.UNIT; + }); + return action.options(options); + } + + private RewriteDataFiles checkAndApplyStrategy(RewriteDataFiles action, String strategy, SortOrder sortOrder) { + // caller of this function ensures that between strategy and sortOrder, at least one of them is not null. + if (strategy == null || strategy.equalsIgnoreCase("sort")) { + return action.sort(sortOrder); + } + if (strategy.equalsIgnoreCase("binpack")) { + RewriteDataFiles rewriteDataFiles = action.binPack(); + if (sortOrder != null) { + // calling below method to throw the error as user has set both binpack strategy and sort order + return rewriteDataFiles.sort(sortOrder); + } + return rewriteDataFiles; + } else { + throw new IllegalArgumentException("unsupported strategy: " + strategy + ". Only binpack,sort is supported"); + } + } + + private SortOrder collectSortOrders(Table table, String sortOrderStr) { + String prefix = "ALTER TABLE temp WRITE ORDERED BY "; + try { + // Note: Reusing the existing Iceberg sql parser to avoid implementing the custom parser for sort orders. + // To reuse the existing parser, adding a prefix of "ALTER TABLE temp WRITE ORDERED BY" + // along with input sort order and parsing it as a plan to collect the sortOrder. 
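+ // For example, a sort_order argument such as "id DESC NULLS LAST, name ASC" (column names are
+ // illustrative) is parsed as "ALTER TABLE temp WRITE ORDERED BY id DESC NULLS LAST, name ASC",
+ // which might come from a call like:
+ //   CALL cat.system.rewrite_data_files(table => 'db.sample', strategy => 'sort',
+ //     sort_order => 'id DESC NULLS LAST, name ASC')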
+ LogicalPlan logicalPlan = spark().sessionState().sqlParser().parsePlan(prefix + sortOrderStr); + return (new SortOrderParserUtil()).collectSortOrder( + table.schema(), + ((SetWriteDistributionAndOrdering) logicalPlan).sortOrder()); + } catch (AnalysisException ex) { + throw new IllegalArgumentException("Unable to parse sortOrder: " + sortOrderStr); + } + } + + private InternalRow[] toOutputRows(RewriteDataFiles.Result result) { + int rewrittenDataFilesCount = result.rewrittenDataFilesCount(); + int addedDataFilesCount = result.addedDataFilesCount(); + InternalRow row = newInternalRow(rewrittenDataFilesCount, addedDataFilesCount); + return new InternalRow[]{row}; + } + + @Override + public String description() { + return "RewriteDataFilesProcedure"; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteManifestsProcedure.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteManifestsProcedure.java new file mode 100644 index 000000000000..eae7ce0fca97 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteManifestsProcedure.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.RewriteManifests; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A procedure that rewrites manifests in a table. + *
<p>
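+ * A minimal usage sketch in Spark SQL, assuming the procedure is registered under the name
+ * {@code rewrite_manifests}; the catalog {@code cat} and table {@code db.sample} are illustrative:
+ * <pre>
+ *   CALL cat.system.rewrite_manifests(table => 'db.sample', use_caching => false)
+ * </pre>
+ * <p>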
+ * Note: this procedure invalidates all cached Spark plans that reference the affected table. + * + * @see SparkActions#rewriteManifests(Table) () + */ +class RewriteManifestsProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[]{ + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.optional("use_caching", DataTypes.BooleanType) + }; + + // counts are not nullable since the action result is never null + private static final StructType OUTPUT_TYPE = new StructType(new StructField[]{ + new StructField("rewritten_manifests_count", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("added_manifests_count", DataTypes.IntegerType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected RewriteManifestsProcedure doBuild() { + return new RewriteManifestsProcedure(tableCatalog()); + } + }; + } + + private RewriteManifestsProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + Boolean useCaching = args.isNullAt(1) ? null : args.getBoolean(1); + + return modifyIcebergTable(tableIdent, table -> { + RewriteManifests action = actions().rewriteManifests(table); + + if (useCaching != null) { + action.option("use-caching", useCaching.toString()); + } + + RewriteManifests.Result result = action.execute(); + + return toOutputRows(result); + }); + } + + private InternalRow[] toOutputRows(RewriteManifests.Result result) { + int rewrittenManifestsCount = Iterables.size(result.rewrittenManifests()); + int addedManifestsCount = Iterables.size(result.addedManifests()); + InternalRow row = newInternalRow(rewrittenManifestsCount, addedManifestsCount); + return new InternalRow[]{row}; + } + + @Override + public String description() { + return "RewriteManifestsProcedure"; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/RollbackToSnapshotProcedure.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/RollbackToSnapshotProcedure.java new file mode 100644 index 000000000000..7cf5b0c77bb2 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/RollbackToSnapshotProcedure.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A procedure that rollbacks a table to a specific snapshot id. + *
<p>
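+ * A minimal usage sketch in Spark SQL, assuming the procedure is registered under the name
+ * {@code rollback_to_snapshot}; the catalog {@code cat}, table {@code db.sample} and snapshot id
+ * are illustrative values:
+ * <pre>
+ *   CALL cat.system.rollback_to_snapshot(table => 'db.sample', snapshot_id => 1)
+ * </pre>
+ * <p>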
+ * Note: this procedure invalidates all cached Spark plans that reference the affected table. + * + * @see org.apache.iceberg.ManageSnapshots#rollbackTo(long) + */ +class RollbackToSnapshotProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[]{ + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.required("snapshot_id", DataTypes.LongType) + }; + + private static final StructType OUTPUT_TYPE = new StructType(new StructField[]{ + new StructField("previous_snapshot_id", DataTypes.LongType, false, Metadata.empty()), + new StructField("current_snapshot_id", DataTypes.LongType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + public RollbackToSnapshotProcedure doBuild() { + return new RollbackToSnapshotProcedure(tableCatalog()); + } + }; + } + + private RollbackToSnapshotProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + long snapshotId = args.getLong(1); + + return modifyIcebergTable(tableIdent, table -> { + Snapshot previousSnapshot = table.currentSnapshot(); + + table.manageSnapshots() + .rollbackTo(snapshotId) + .commit(); + + InternalRow outputRow = newInternalRow(previousSnapshot.snapshotId(), snapshotId); + return new InternalRow[]{outputRow}; + }); + } + + @Override + public String description() { + return "RollbackToSnapshotProcedure"; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/RollbackToTimestampProcedure.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/RollbackToTimestampProcedure.java new file mode 100644 index 000000000000..519a46c6dbb8 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/RollbackToTimestampProcedure.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.iceberg.util.DateTimeUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A procedure that rollbacks a table to a given point in time. + *
<p>
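+ * A minimal usage sketch in Spark SQL, assuming the procedure is registered under the name
+ * {@code rollback_to_timestamp}; the catalog {@code cat}, table {@code db.sample} and timestamp
+ * are illustrative values:
+ * <pre>
+ *   CALL cat.system.rollback_to_timestamp(table => 'db.sample', timestamp => TIMESTAMP '2022-06-30 00:00:00')
+ * </pre>
+ * <p>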
+ * Note: this procedure invalidates all cached Spark plans that reference the affected table. + * + * @see org.apache.iceberg.ManageSnapshots#rollbackToTime(long) + */ +class RollbackToTimestampProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[]{ + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.required("timestamp", DataTypes.TimestampType) + }; + + private static final StructType OUTPUT_TYPE = new StructType(new StructField[]{ + new StructField("previous_snapshot_id", DataTypes.LongType, false, Metadata.empty()), + new StructField("current_snapshot_id", DataTypes.LongType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected RollbackToTimestampProcedure doBuild() { + return new RollbackToTimestampProcedure(tableCatalog()); + } + }; + } + + private RollbackToTimestampProcedure(TableCatalog catalog) { + super(catalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + // timestamps in Spark have microsecond precision so this conversion is lossy + long timestampMillis = DateTimeUtil.microsToMillis(args.getLong(1)); + + return modifyIcebergTable(tableIdent, table -> { + Snapshot previousSnapshot = table.currentSnapshot(); + + table.manageSnapshots() + .rollbackToTime(timestampMillis) + .commit(); + + Snapshot currentSnapshot = table.currentSnapshot(); + + InternalRow outputRow = newInternalRow(previousSnapshot.snapshotId(), currentSnapshot.snapshotId()); + return new InternalRow[]{outputRow}; + }); + } + + @Override + public String description() { + return "RollbackToTimestampProcedure"; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java new file mode 100644 index 000000000000..274ca19fc107 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A procedure that sets the current snapshot in a table. + *

+ * Note: this procedure invalidates all cached Spark plans that reference the affected table. + * + * @see org.apache.iceberg.ManageSnapshots#setCurrentSnapshot(long) + */ +class SetCurrentSnapshotProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.required("snapshot_id", DataTypes.LongType) + }; + + private static final StructType OUTPUT_TYPE = new StructType(new StructField[]{ + new StructField("previous_snapshot_id", DataTypes.LongType, true, Metadata.empty()), + new StructField("current_snapshot_id", DataTypes.LongType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected SetCurrentSnapshotProcedure doBuild() { + return new SetCurrentSnapshotProcedure(tableCatalog()); + } + }; + } + + private SetCurrentSnapshotProcedure(TableCatalog catalog) { + super(catalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + long snapshotId = args.getLong(1); + + return modifyIcebergTable(tableIdent, table -> { + Snapshot previousSnapshot = table.currentSnapshot(); + Long previousSnapshotId = previousSnapshot != null ? previousSnapshot.snapshotId() : null; + + table.manageSnapshots() + .setCurrentSnapshot(snapshotId) + .commit(); + + InternalRow outputRow = newInternalRow(previousSnapshotId, snapshotId); + return new InternalRow[]{outputRow}; + }); + } + + @Override + public String description() { + return "SetCurrentSnapshotProcedure"; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java new file mode 100644 index 000000000000..96e293d6b1da --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
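Similarly, a hedged sketch of calling set_current_snapshot through Spark SQL; the catalog, table, and snapshot id are placeholders, and the Iceberg SQL extensions are assumed to be configured.

package com.example;  // hypothetical example package, not part of the patch

import org.apache.spark.sql.SparkSession;

public class SetCurrentSnapshotExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().getOrCreate();

    // Points db.sample at the given snapshot id; the output row reports
    // previous_snapshot_id (nullable, per OUTPUT_TYPE) and current_snapshot_id.
    spark.sql("CALL my_catalog.system.set_current_snapshot("
        + "table => 'db.sample', snapshot_id => 1234567890)").show();
  }
}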
+ */ + +package org.apache.iceberg.spark.procedures; + +import java.util.Map; +import org.apache.iceberg.actions.SnapshotTable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import scala.runtime.BoxedUnit; + +class SnapshotTableProcedure extends BaseProcedure { + private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[]{ + ProcedureParameter.required("source_table", DataTypes.StringType), + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.optional("location", DataTypes.StringType), + ProcedureParameter.optional("properties", STRING_MAP) + }; + + private static final StructType OUTPUT_TYPE = new StructType(new StructField[]{ + new StructField("imported_files_count", DataTypes.LongType, false, Metadata.empty()) + }); + + private SnapshotTableProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + public static SparkProcedures.ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected SnapshotTableProcedure doBuild() { + return new SnapshotTableProcedure(tableCatalog()); + } + }; + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + String source = args.getString(0); + Preconditions.checkArgument(source != null && !source.isEmpty(), + "Cannot handle an empty identifier for argument source_table"); + String dest = args.getString(1); + Preconditions.checkArgument(dest != null && !dest.isEmpty(), + "Cannot handle an empty identifier for argument table"); + String snapshotLocation = args.isNullAt(2) ? null : args.getString(2); + + Map properties = Maps.newHashMap(); + if (!args.isNullAt(3)) { + args.getMap(3).foreach(DataTypes.StringType, DataTypes.StringType, + (k, v) -> { + properties.put(k.toString(), v.toString()); + return BoxedUnit.UNIT; + }); + } + + Preconditions.checkArgument(!source.equals(dest), + "Cannot create a snapshot with the same name as the source of the snapshot."); + SnapshotTable action = SparkActions.get().snapshotTable(source).as(dest); + + if (snapshotLocation != null) { + action.tableLocation(snapshotLocation); + } + + SnapshotTable.Result result = action.tableProperties(properties).execute(); + return new InternalRow[] {newInternalRow(result.importedDataFilesCount())}; + } + + @Override + public String description() { + return "SnapshotTableProcedure"; + } + +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/SparkProcedures.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/SparkProcedures.java new file mode 100644 index 000000000000..1e944fcf0651 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/SparkProcedures.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
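As an illustrative aside, SnapshotTableProcedure above is registered as "snapshot" in SparkProcedures (next file) and could be invoked as sketched below; all names are placeholders and the Iceberg SQL extensions are assumed to be enabled.

package com.example;  // hypothetical example package, not part of the patch

import org.apache.spark.sql.SparkSession;

public class SnapshotTableExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().getOrCreate();

    // Creates a new Iceberg table db.sample_snapshot over the data files of db.sample
    // (source_table and table must differ, per the Preconditions check above);
    // the single output column is imported_files_count.
    spark.sql("CALL my_catalog.system.snapshot("
        + "source_table => 'db.sample', table => 'db.sample_snapshot')").show();
  }
}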
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.procedures; + +import java.util.Locale; +import java.util.Map; +import java.util.function.Supplier; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.Procedure; + +public class SparkProcedures { + + private static final Map> BUILDERS = initProcedureBuilders(); + + private SparkProcedures() { + } + + public static ProcedureBuilder newBuilder(String name) { + // procedure resolution is case insensitive to match the existing Spark behavior for functions + Supplier builderSupplier = BUILDERS.get(name.toLowerCase(Locale.ROOT)); + return builderSupplier != null ? builderSupplier.get() : null; + } + + private static Map> initProcedureBuilders() { + ImmutableMap.Builder> mapBuilder = ImmutableMap.builder(); + mapBuilder.put("rollback_to_snapshot", RollbackToSnapshotProcedure::builder); + mapBuilder.put("rollback_to_timestamp", RollbackToTimestampProcedure::builder); + mapBuilder.put("set_current_snapshot", SetCurrentSnapshotProcedure::builder); + mapBuilder.put("cherrypick_snapshot", CherrypickSnapshotProcedure::builder); + mapBuilder.put("rewrite_data_files", RewriteDataFilesProcedure::builder); + mapBuilder.put("rewrite_manifests", RewriteManifestsProcedure::builder); + mapBuilder.put("remove_orphan_files", RemoveOrphanFilesProcedure::builder); + mapBuilder.put("expire_snapshots", ExpireSnapshotsProcedure::builder); + mapBuilder.put("migrate", MigrateTableProcedure::builder); + mapBuilder.put("snapshot", SnapshotTableProcedure::builder); + mapBuilder.put("add_files", AddFilesProcedure::builder); + mapBuilder.put("ancestors_of", AncestorsOfProcedure::builder); + mapBuilder.put("register_table", RegisterTableProcedure::builder); + return mapBuilder.build(); + } + + public interface ProcedureBuilder { + ProcedureBuilder withTableCatalog(TableCatalog tableCatalog); + Procedure build(); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseDataReader.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseDataReader.java new file mode 100644 index 000000000000..f0664c7e8e29 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseDataReader.java @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
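A small sketch of the lookup contract in SparkProcedures.newBuilder above: the name is lower-cased before the map lookup, so resolution is case-insensitive, and unknown names return null rather than throwing. Only the public API visible in this file is used.

package com.example;  // hypothetical example package, not part of the patch

import org.apache.iceberg.spark.procedures.SparkProcedures;
import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder;

public class ProcedureLookupExample {
  public static void main(String[] args) {
    // Mixed case resolves because newBuilder lower-cases the name first.
    ProcedureBuilder builder = SparkProcedures.newBuilder("Rollback_To_Snapshot");
    System.out.println(builder != null);   // expected: true

    // Unregistered names are not an error; the caller simply gets null back.
    System.out.println(SparkProcedures.newBuilder("no_such_procedure"));  // expected: null
  }
}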
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.io.Closeable; +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; +import org.apache.avro.generic.GenericData; +import org.apache.avro.util.Utf8; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.encryption.EncryptedFiles; +import org.apache.iceberg.encryption.EncryptedInputFile; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.types.Types.StructType; +import org.apache.iceberg.util.ByteBuffers; +import org.apache.iceberg.util.PartitionUtil; +import org.apache.spark.rdd.InputFileBlockHolder; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Base class of Spark readers. + * + * @param is the Java class returned by this reader whose objects contain one or more rows. 
+ */ +abstract class BaseDataReader implements Closeable { + private static final Logger LOG = LoggerFactory.getLogger(BaseDataReader.class); + + private final Table table; + private final Iterator tasks; + private final Map inputFiles; + + private CloseableIterator currentIterator; + private T current = null; + private FileScanTask currentTask = null; + + BaseDataReader(Table table, CombinedScanTask task) { + this.table = table; + this.tasks = task.files().iterator(); + Map keyMetadata = Maps.newHashMap(); + task.files().stream() + .flatMap(fileScanTask -> Stream.concat(Stream.of(fileScanTask.file()), fileScanTask.deletes().stream())) + .forEach(file -> keyMetadata.put(file.path().toString(), file.keyMetadata())); + Stream encrypted = keyMetadata.entrySet().stream() + .map(entry -> EncryptedFiles.encryptedInput(table.io().newInputFile(entry.getKey()), entry.getValue())); + + // decrypt with the batch call to avoid multiple RPCs to a key server, if possible + Iterable decryptedFiles = table.encryption().decrypt(encrypted::iterator); + + Map files = Maps.newHashMapWithExpectedSize(task.files().size()); + decryptedFiles.forEach(decrypted -> files.putIfAbsent(decrypted.location(), decrypted)); + this.inputFiles = ImmutableMap.copyOf(files); + + this.currentIterator = CloseableIterator.empty(); + } + + protected Table table() { + return table; + } + + public boolean next() throws IOException { + try { + while (true) { + if (currentIterator.hasNext()) { + this.current = currentIterator.next(); + return true; + } else if (tasks.hasNext()) { + this.currentIterator.close(); + this.currentTask = tasks.next(); + this.currentIterator = open(currentTask); + } else { + this.currentIterator.close(); + return false; + } + } + } catch (IOException | RuntimeException e) { + if (currentTask != null && !currentTask.isDataTask()) { + LOG.error("Error reading file: {}", getInputFile(currentTask).location(), e); + } + throw e; + } + } + + public T get() { + return current; + } + + abstract CloseableIterator open(FileScanTask task); + + @Override + public void close() throws IOException { + InputFileBlockHolder.unset(); + + // close the current iterator + this.currentIterator.close(); + + // exhaust the task iterator + while (tasks.hasNext()) { + tasks.next(); + } + } + + protected InputFile getInputFile(FileScanTask task) { + Preconditions.checkArgument(!task.isDataTask(), "Invalid task type"); + return inputFiles.get(task.file().path().toString()); + } + + protected InputFile getInputFile(String location) { + return inputFiles.get(location); + } + + protected Map constantsMap(FileScanTask task, Schema readSchema) { + if (readSchema.findField(MetadataColumns.PARTITION_COLUMN_ID) != null) { + StructType partitionType = Partitioning.partitionType(table); + return PartitionUtil.constantsMap(task, partitionType, BaseDataReader::convertConstant); + } else { + return PartitionUtil.constantsMap(task, BaseDataReader::convertConstant); + } + } + + protected static Object convertConstant(Type type, Object value) { + if (value == null) { + return null; + } + + switch (type.typeId()) { + case DECIMAL: + return Decimal.apply((BigDecimal) value); + case STRING: + if (value instanceof Utf8) { + Utf8 utf8 = (Utf8) value; + return UTF8String.fromBytes(utf8.getBytes(), 0, utf8.getByteLength()); + } + return UTF8String.fromString(value.toString()); + case FIXED: + if (value instanceof byte[]) { + return value; + } else if (value instanceof GenericData.Fixed) { + return ((GenericData.Fixed) value).bytes(); + } + return 
ByteBuffers.toByteArray((ByteBuffer) value); + case BINARY: + return ByteBuffers.toByteArray((ByteBuffer) value); + case STRUCT: + StructType structType = (StructType) type; + + if (structType.fields().isEmpty()) { + return new GenericInternalRow(); + } + + List fields = structType.fields(); + Object[] values = new Object[fields.size()]; + StructLike struct = (StructLike) value; + + for (int index = 0; index < fields.size(); index++) { + NestedField field = fields.get(index); + Type fieldType = field.type(); + values[index] = convertConstant(fieldType, struct.get(index, fieldType.typeId().javaClass())); + } + + return new GenericInternalRow(values); + default: + } + return value; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java new file mode 100644 index 000000000000..5f643fa37d5d --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
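For readers following the control flow, a sketch of how a BaseDataReader is typically consumed (compare RowDataRewriter later in this patch): next() lazily opens each FileScanTask, get() returns the current record, and close() releases the open iterator. The helper is illustrative only and sits in the same package because the class is package-private.

package org.apache.iceberg.spark.source;  // same package: BaseDataReader is package-private

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

final class ReaderDrainExample {
  private ReaderDrainExample() {
  }

  // Drains a reader into a list; T is InternalRow for RowDataReader and
  // ColumnarBatch for BatchDataReader.
  static <T> List<T> drain(BaseDataReader<T> reader) throws IOException {
    List<T> rows = new ArrayList<>();
    try {
      while (reader.next()) {
        rows.add(reader.get());
      }
    } finally {
      reader.close();
    }
    return rows;
  }
}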
+ */ + +package org.apache.iceberg.spark.source; + +import java.util.Map; +import java.util.Set; +import org.apache.arrow.vector.NullCheckingForGet; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.DeleteFilter; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.mapping.NameMappingParser; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkOrcReaders; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkParquetReaders; +import org.apache.iceberg.types.TypeUtil; +import org.apache.spark.rdd.InputFileBlockHolder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +class BatchDataReader extends BaseDataReader { + private final Schema expectedSchema; + private final String nameMapping; + private final boolean caseSensitive; + private final int batchSize; + + BatchDataReader(CombinedScanTask task, Table table, Schema expectedSchema, boolean caseSensitive, int size) { + super(table, task); + this.expectedSchema = expectedSchema; + this.nameMapping = table.properties().get(TableProperties.DEFAULT_NAME_MAPPING); + this.caseSensitive = caseSensitive; + this.batchSize = size; + } + + @Override + CloseableIterator open(FileScanTask task) { + DataFile file = task.file(); + + // update the current file for Spark's filename() function + InputFileBlockHolder.set(file.path().toString(), task.start(), task.length()); + + Map idToConstant = constantsMap(task, expectedSchema); + + CloseableIterable iter; + InputFile location = getInputFile(task); + Preconditions.checkNotNull(location, "Could not find InputFile associated with FileScanTask"); + if (task.file().format() == FileFormat.PARQUET) { + SparkDeleteFilter deleteFilter = deleteFilter(task); + // get required schema for filtering out equality-delete rows in case equality-delete uses columns are + // not selected. + Schema requiredSchema = requiredSchema(deleteFilter); + + Parquet.ReadBuilder builder = Parquet.read(location) + .project(requiredSchema) + .split(task.start(), task.length()) + .createBatchedReaderFunc(fileSchema -> VectorizedSparkParquetReaders.buildReader(requiredSchema, + fileSchema, /* setArrowValidityVector */ NullCheckingForGet.NULL_CHECKING_ENABLED, idToConstant, + deleteFilter)) + .recordsPerBatch(batchSize) + .filter(task.residual()) + .caseSensitive(caseSensitive) + // Spark eagerly consumes the batches. So the underlying memory allocated could be reused + // without worrying about subsequent reads clobbering over each other. This improves + // read performance as every batch read doesn't have to pay the cost of allocating memory. 
+ .reuseContainers(); + + if (nameMapping != null) { + builder.withNameMapping(NameMappingParser.fromJson(nameMapping)); + } + + iter = builder.build(); + } else if (task.file().format() == FileFormat.ORC) { + Set constantFieldIds = idToConstant.keySet(); + Set metadataFieldIds = MetadataColumns.metadataFieldIds(); + Sets.SetView constantAndMetadataFieldIds = Sets.union(constantFieldIds, metadataFieldIds); + Schema schemaWithoutConstantAndMetadataFields = TypeUtil.selectNot(expectedSchema, constantAndMetadataFieldIds); + ORC.ReadBuilder builder = ORC.read(location) + .project(schemaWithoutConstantAndMetadataFields) + .split(task.start(), task.length()) + .createBatchedReaderFunc(fileSchema -> VectorizedSparkOrcReaders.buildReader(expectedSchema, fileSchema, + idToConstant)) + .recordsPerBatch(batchSize) + .filter(task.residual()) + .caseSensitive(caseSensitive); + + if (nameMapping != null) { + builder.withNameMapping(NameMappingParser.fromJson(nameMapping)); + } + + iter = builder.build(); + } else { + throw new UnsupportedOperationException( + "Format: " + task.file().format() + " not supported for batched reads"); + } + return iter.iterator(); + } + + private SparkDeleteFilter deleteFilter(FileScanTask task) { + return task.deletes().isEmpty() ? null : new SparkDeleteFilter(task, table().schema(), expectedSchema); + } + + private Schema requiredSchema(DeleteFilter deleteFilter) { + if (deleteFilter != null && deleteFilter.hasEqDeletes()) { + return deleteFilter.requiredSchema(); + } else { + return expectedSchema; + } + } + + private class SparkDeleteFilter extends DeleteFilter { + private final InternalRowWrapper asStructLike; + + SparkDeleteFilter(FileScanTask task, Schema tableSchema, Schema requestedSchema) { + super(task.file().path().toString(), task.deletes(), tableSchema, requestedSchema); + this.asStructLike = new InternalRowWrapper(SparkSchemaUtil.convert(requiredSchema())); + } + + @Override + protected StructLike asStructLike(InternalRow row) { + return asStructLike.wrap(row); + } + + @Override + protected InputFile getInputFile(String location) { + return BatchDataReader.this.getInputFile(location); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/EqualityDeleteRowReader.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/EqualityDeleteRowReader.java new file mode 100644 index 000000000000..ce2226f4f75e --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/EqualityDeleteRowReader.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.spark.rdd.InputFileBlockHolder; +import org.apache.spark.sql.catalyst.InternalRow; + +public class EqualityDeleteRowReader extends RowDataReader { + private final Schema expectedSchema; + + public EqualityDeleteRowReader(CombinedScanTask task, Table table, Schema expectedSchema, boolean caseSensitive) { + super(task, table, table.schema(), caseSensitive); + this.expectedSchema = expectedSchema; + } + + @Override + CloseableIterator open(FileScanTask task) { + SparkDeleteFilter matches = new SparkDeleteFilter(task, tableSchema(), expectedSchema); + + // schema or rows returned by readers + Schema requiredSchema = matches.requiredSchema(); + Map idToConstant = constantsMap(task, expectedSchema); + DataFile file = task.file(); + + // update the current file for Spark's filename() function + InputFileBlockHolder.set(file.path().toString(), task.start(), task.length()); + + return matches.findEqualityDeleteRows(open(task, requiredSchema, idToConstant)).iterator(); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/HasIcebergCatalog.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/HasIcebergCatalog.java new file mode 100644 index 000000000000..e2579a0059b0 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/HasIcebergCatalog.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.catalog.Catalog; +import org.apache.spark.sql.connector.catalog.TableCatalog; + +public interface HasIcebergCatalog extends TableCatalog { + + /** + * Returns the underlying {@link org.apache.iceberg.catalog.Catalog} backing this Spark Catalog + */ + Catalog icebergCatalog(); +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/IcebergSource.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/IcebergSource.java new file mode 100644 index 000000000000..7aa66b2223fc --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/IcebergSource.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.util.Arrays; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.PathIdentifier; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.connector.catalog.CatalogManager; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.SupportsCatalogOptions; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.sources.DataSourceRegister; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * The IcebergSource loads/writes tables with format "iceberg". It can load paths and tables. + * + * How paths/tables are loaded when using spark.read().format("iceberg").path(table) + * + * table = "file:/path/to/table" -> loads a HadoopTable at given path + * table = "tablename" -> loads currentCatalog.currentNamespace.tablename + * table = "catalog.tablename" -> load "tablename" from the specified catalog. + * table = "namespace.tablename" -> load "namespace.tablename" from current catalog + * table = "catalog.namespace.tablename" -> "namespace.tablename" from the specified catalog. + * table = "namespace1.namespace2.tablename" -> load "namespace1.namespace2.tablename" from current catalog + * + * The above list is in order of priority. For example: a matching catalog will take priority over any namespace + * resolution. + */ +public class IcebergSource implements DataSourceRegister, SupportsCatalogOptions { + private static final String DEFAULT_CATALOG_NAME = "default_iceberg"; + private static final String DEFAULT_CATALOG = "spark.sql.catalog." 
+ DEFAULT_CATALOG_NAME; + private static final String AT_TIMESTAMP = "at_timestamp_"; + private static final String SNAPSHOT_ID = "snapshot_id_"; + + @Override + public String shortName() { + return "iceberg"; + } + + @Override + public StructType inferSchema(CaseInsensitiveStringMap options) { + return null; + } + + @Override + public Transform[] inferPartitioning(CaseInsensitiveStringMap options) { + return getTable(null, null, options).partitioning(); + } + + @Override + public boolean supportsExternalMetadata() { + return true; + } + + @Override + public Table getTable(StructType schema, Transform[] partitioning, Map options) { + Spark3Util.CatalogAndIdentifier catalogIdentifier = catalogAndIdentifier(new CaseInsensitiveStringMap(options)); + CatalogPlugin catalog = catalogIdentifier.catalog(); + Identifier ident = catalogIdentifier.identifier(); + + try { + if (catalog instanceof TableCatalog) { + return ((TableCatalog) catalog).loadTable(ident); + } + } catch (NoSuchTableException e) { + // throwing an iceberg NoSuchTableException because the Spark one is typed and cant be thrown from this interface + throw new org.apache.iceberg.exceptions.NoSuchTableException(e, "Cannot find table for %s.", ident); + } + + // throwing an iceberg NoSuchTableException because the Spark one is typed and cant be thrown from this interface + throw new org.apache.iceberg.exceptions.NoSuchTableException("Cannot find table for %s.", ident); + } + + private Spark3Util.CatalogAndIdentifier catalogAndIdentifier(CaseInsensitiveStringMap options) { + Preconditions.checkArgument(options.containsKey("path"), "Cannot open table: path is not set"); + SparkSession spark = SparkSession.active(); + setupDefaultSparkCatalog(spark); + String path = options.get("path"); + + Long snapshotId = propertyAsLong(options, SparkReadOptions.SNAPSHOT_ID); + Long asOfTimestamp = propertyAsLong(options, SparkReadOptions.AS_OF_TIMESTAMP); + Preconditions.checkArgument(asOfTimestamp == null || snapshotId == null, + "Cannot specify both snapshot-id (%s) and as-of-timestamp (%s)", snapshotId, asOfTimestamp); + + String selector = null; + + if (snapshotId != null) { + selector = SNAPSHOT_ID + snapshotId; + } + + if (asOfTimestamp != null) { + selector = AT_TIMESTAMP + asOfTimestamp; + } + + CatalogManager catalogManager = spark.sessionState().catalogManager(); + + if (path.contains("/")) { + // contains a path. Return iceberg default catalog and a PathIdentifier + String newPath = (selector == null) ? path : path + "#" + selector; + return new Spark3Util.CatalogAndIdentifier(catalogManager.catalog(DEFAULT_CATALOG_NAME), + new PathIdentifier(newPath)); + } + + final Spark3Util.CatalogAndIdentifier catalogAndIdentifier = Spark3Util.catalogAndIdentifier( + "path or identifier", spark, path); + + Identifier ident = identifierWithSelector(catalogAndIdentifier.identifier(), selector); + if (catalogAndIdentifier.catalog().name().equals("spark_catalog") && + !(catalogAndIdentifier.catalog() instanceof SparkSessionCatalog)) { + // catalog is a session catalog but does not support Iceberg. Use Iceberg instead. 
+ return new Spark3Util.CatalogAndIdentifier(catalogManager.catalog(DEFAULT_CATALOG_NAME), ident); + } else { + return new Spark3Util.CatalogAndIdentifier(catalogAndIdentifier.catalog(), ident); + } + } + + private Identifier identifierWithSelector(Identifier ident, String selector) { + if (selector == null) { + return ident; + } else { + String[] namespace = ident.namespace(); + String[] ns = Arrays.copyOf(namespace, namespace.length + 1); + ns[namespace.length] = ident.name(); + return Identifier.of(ns, selector); + } + } + + @Override + public Identifier extractIdentifier(CaseInsensitiveStringMap options) { + return catalogAndIdentifier(options).identifier(); + } + + @Override + public String extractCatalog(CaseInsensitiveStringMap options) { + return catalogAndIdentifier(options).catalog().name(); + } + + private static Long propertyAsLong(CaseInsensitiveStringMap options, String property) { + String value = options.get(property); + if (value != null) { + return Long.parseLong(value); + } + + return null; + } + + private static void setupDefaultSparkCatalog(SparkSession spark) { + if (spark.conf().contains(DEFAULT_CATALOG)) { + return; + } + ImmutableMap config = ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "cache-enabled", "false" // the source should not use a cache + ); + String catalogName = "org.apache.iceberg.spark.SparkCatalog"; + spark.conf().set(DEFAULT_CATALOG, catalogName); + config.forEach((key, value) -> spark.conf().set(DEFAULT_CATALOG + "." + key, value)); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/InternalRowWrapper.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/InternalRowWrapper.java new file mode 100644 index 000000000000..ef1eb08d873c --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/InternalRowWrapper.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
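To make the resolution rules in the IcebergSource Javadoc concrete, a read sketch under stated assumptions: the catalog and table names are placeholders, and the "snapshot-id" / "as-of-timestamp" option keys correspond to the SparkReadOptions constants validated in catalogAndIdentifier (only one of the two may be set).

package com.example;  // hypothetical example package, not part of the patch

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class IcebergReadExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().getOrCreate();

    // Catalog-qualified identifier: resolved through Spark's catalog manager.
    Dataset<Row> byName = spark.read().format("iceberg").load("my_catalog.db.sample");

    // Path-based load: a value containing '/' is routed to the default_iceberg catalog
    // with a PathIdentifier, as in catalogAndIdentifier above.
    Dataset<Row> byPath = spark.read().format("iceberg").load("file:/tmp/warehouse/db/sample");

    // Time travel: snapshot-id and as-of-timestamp are mutually exclusive read options.
    Dataset<Row> atSnapshot = spark.read()
        .format("iceberg")
        .option("snapshot-id", "1234567890")
        .load("my_catalog.db.sample");

    byName.show();
    byPath.show();
    atSnapshot.show();
  }
}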
+ */ + +package org.apache.iceberg.spark.source; + +import java.nio.ByteBuffer; +import java.util.function.BiFunction; +import java.util.stream.Stream; +import org.apache.iceberg.StructLike; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * Class to adapt a Spark {@code InternalRow} to Iceberg {@link StructLike} for uses like + * {@link org.apache.iceberg.PartitionKey#partition(StructLike)} + */ +class InternalRowWrapper implements StructLike { + private final DataType[] types; + private final BiFunction[] getters; + private InternalRow row = null; + + @SuppressWarnings("unchecked") + InternalRowWrapper(StructType rowType) { + this.types = Stream.of(rowType.fields()) + .map(StructField::dataType) + .toArray(DataType[]::new); + this.getters = Stream.of(types) + .map(InternalRowWrapper::getter) + .toArray(BiFunction[]::new); + } + + InternalRowWrapper wrap(InternalRow internalRow) { + this.row = internalRow; + return this; + } + + @Override + public int size() { + return types.length; + } + + @Override + public T get(int pos, Class javaClass) { + if (row.isNullAt(pos)) { + return null; + } else if (getters[pos] != null) { + return javaClass.cast(getters[pos].apply(row, pos)); + } + + return javaClass.cast(row.get(pos, types[pos])); + } + + @Override + public void set(int pos, T value) { + row.update(pos, value); + } + + private static BiFunction getter(DataType type) { + if (type instanceof StringType) { + return (row, pos) -> row.getUTF8String(pos).toString(); + } else if (type instanceof DecimalType) { + DecimalType decimal = (DecimalType) type; + return (row, pos) -> + row.getDecimal(pos, decimal.precision(), decimal.scale()).toJavaBigDecimal(); + } else if (type instanceof BinaryType) { + return (row, pos) -> ByteBuffer.wrap(row.getBinary(pos)); + } else if (type instanceof StructType) { + StructType structType = (StructType) type; + InternalRowWrapper nestedWrapper = new InternalRowWrapper(structType); + return (row, pos) -> nestedWrapper.wrap(row.getStruct(pos, structType.size())); + } + + return null; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java new file mode 100644 index 000000000000..6ebceee2f62a --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
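A short sketch of the wrapper above, placed in the same package because InternalRowWrapper is package-private; the two-column schema is an arbitrary example. It shows the StructLike view that PartitionKey.partition(StructLike) consumes, including the UTF8String-to-String conversion performed by the string getter.

package org.apache.iceberg.spark.source;  // same package: InternalRowWrapper is package-private

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.UTF8String;

final class InternalRowWrapperExample {
  private InternalRowWrapperExample() {
  }

  public static void main(String[] args) {
    // Example Spark schema: (id bigint, data string)
    StructType sparkType = new StructType()
        .add("id", DataTypes.LongType)
        .add("data", DataTypes.StringType);

    InternalRow row = new GenericInternalRow(new Object[] {1L, UTF8String.fromString("a")});

    InternalRowWrapper wrapper = new InternalRowWrapper(sparkType);
    wrapper.wrap(row);

    Long id = wrapper.get(0, Long.class);        // 1
    String data = wrapper.get(1, String.class);  // "a", converted from UTF8String
    System.out.println(id + " " + data);
  }
}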
+ */ + +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataTask; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.data.DeleteFilter; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.mapping.NameMappingParser; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.SparkAvroReader; +import org.apache.iceberg.spark.data.SparkOrcReader; +import org.apache.iceberg.spark.data.SparkParquetReaders; +import org.apache.iceberg.types.TypeUtil; +import org.apache.spark.rdd.InputFileBlockHolder; +import org.apache.spark.sql.catalyst.InternalRow; + +class RowDataReader extends BaseDataReader { + + private final Schema tableSchema; + private final Schema expectedSchema; + private final String nameMapping; + private final boolean caseSensitive; + + RowDataReader(CombinedScanTask task, Table table, Schema expectedSchema, boolean caseSensitive) { + super(table, task); + this.tableSchema = table.schema(); + this.expectedSchema = expectedSchema; + this.nameMapping = table.properties().get(TableProperties.DEFAULT_NAME_MAPPING); + this.caseSensitive = caseSensitive; + } + + @Override + CloseableIterator open(FileScanTask task) { + SparkDeleteFilter deletes = new SparkDeleteFilter(task, tableSchema, expectedSchema); + + // schema or rows returned by readers + Schema requiredSchema = deletes.requiredSchema(); + Map idToConstant = constantsMap(task, expectedSchema); + DataFile file = task.file(); + + // update the current file for Spark's filename() function + InputFileBlockHolder.set(file.path().toString(), task.start(), task.length()); + + return deletes.filter(open(task, requiredSchema, idToConstant)).iterator(); + } + + protected Schema tableSchema() { + return tableSchema; + } + + protected CloseableIterable open(FileScanTask task, Schema readSchema, Map idToConstant) { + CloseableIterable iter; + if (task.isDataTask()) { + iter = newDataIterable(task.asDataTask(), readSchema); + } else { + InputFile location = getInputFile(task); + Preconditions.checkNotNull(location, "Could not find InputFile associated with FileScanTask"); + + switch (task.file().format()) { + case PARQUET: + iter = newParquetIterable(location, task, readSchema, idToConstant); + break; + + case AVRO: + iter = newAvroIterable(location, task, readSchema, idToConstant); + break; + + case ORC: + iter = newOrcIterable(location, task, readSchema, idToConstant); + break; + + default: + throw new UnsupportedOperationException( + "Cannot read unknown format: " + task.file().format()); + } + } + + return iter; + } + + private CloseableIterable newAvroIterable( + InputFile location, + FileScanTask task, + Schema projection, + Map idToConstant) { + Avro.ReadBuilder builder = Avro.read(location) + .reuseContainers() + .project(projection) + .split(task.start(), task.length()) + .createReaderFunc(readSchema -> 
new SparkAvroReader(projection, readSchema, idToConstant)); + + if (nameMapping != null) { + builder.withNameMapping(NameMappingParser.fromJson(nameMapping)); + } + + return builder.build(); + } + + private CloseableIterable newParquetIterable( + InputFile location, + FileScanTask task, + Schema readSchema, + Map idToConstant) { + Parquet.ReadBuilder builder = Parquet.read(location) + .reuseContainers() + .split(task.start(), task.length()) + .project(readSchema) + .createReaderFunc(fileSchema -> SparkParquetReaders.buildReader(readSchema, fileSchema, idToConstant)) + .filter(task.residual()) + .caseSensitive(caseSensitive); + + if (nameMapping != null) { + builder.withNameMapping(NameMappingParser.fromJson(nameMapping)); + } + + return builder.build(); + } + + private CloseableIterable newOrcIterable( + InputFile location, + FileScanTask task, + Schema readSchema, + Map idToConstant) { + Schema readSchemaWithoutConstantAndMetadataFields = TypeUtil.selectNot(readSchema, + Sets.union(idToConstant.keySet(), MetadataColumns.metadataFieldIds())); + + ORC.ReadBuilder builder = ORC.read(location) + .project(readSchemaWithoutConstantAndMetadataFields) + .split(task.start(), task.length()) + .createReaderFunc(readOrcSchema -> new SparkOrcReader(readSchema, readOrcSchema, idToConstant)) + .filter(task.residual()) + .caseSensitive(caseSensitive); + + if (nameMapping != null) { + builder.withNameMapping(NameMappingParser.fromJson(nameMapping)); + } + + return builder.build(); + } + + private CloseableIterable newDataIterable(DataTask task, Schema readSchema) { + StructInternalRow row = new StructInternalRow(readSchema.asStruct()); + CloseableIterable asSparkRows = CloseableIterable.transform( + task.asDataTask().rows(), row::setStruct); + return asSparkRows; + } + + protected class SparkDeleteFilter extends DeleteFilter { + private final InternalRowWrapper asStructLike; + + SparkDeleteFilter(FileScanTask task, Schema tableSchema, Schema requestedSchema) { + super(task.file().path().toString(), task.deletes(), tableSchema, requestedSchema); + this.asStructLike = new InternalRowWrapper(SparkSchemaUtil.convert(requiredSchema())); + } + + @Override + protected StructLike asStructLike(InternalRow row) { + return asStructLike.wrap(row); + } + + @Override + protected InputFile getInputFile(String location) { + return RowDataReader.this.getInputFile(location); + } + + @Override + protected void markRowDeleted(InternalRow row) { + row.setBoolean(columnIsDeletedPosition(), true); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/RowDataRewriter.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/RowDataRewriter.java new file mode 100644 index 000000000000..63cc3a466c1a --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/RowDataRewriter.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.io.Serializable; +import java.util.Collection; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.io.TaskWriter; +import org.apache.iceberg.io.UnpartitionedWriter; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.spark.TaskContext; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class RowDataRewriter implements Serializable { + + private static final Logger LOG = LoggerFactory.getLogger(RowDataRewriter.class); + + private final Broadcast
<Table> tableBroadcast; + private final PartitionSpec spec; + private final FileFormat format; + private final boolean caseSensitive; + public RowDataRewriter(Broadcast<Table>
tableBroadcast, PartitionSpec spec, boolean caseSensitive) { + this.tableBroadcast = tableBroadcast; + this.spec = spec; + this.caseSensitive = caseSensitive; + + Table table = tableBroadcast.value(); + String formatString = table.properties().getOrDefault( + TableProperties.DEFAULT_FILE_FORMAT, TableProperties.DEFAULT_FILE_FORMAT_DEFAULT); + this.format = FileFormat.valueOf(formatString.toUpperCase(Locale.ENGLISH)); + } + + public List rewriteDataForTasks(JavaRDD taskRDD) { + JavaRDD> dataFilesRDD = taskRDD.map(this::rewriteDataForTask); + + return dataFilesRDD.collect().stream() + .flatMap(Collection::stream) + .collect(Collectors.toList()); + } + + private List rewriteDataForTask(CombinedScanTask task) throws Exception { + TaskContext context = TaskContext.get(); + int partitionId = context.partitionId(); + long taskId = context.taskAttemptId(); + + Table table = tableBroadcast.value(); + Schema schema = table.schema(); + Map properties = table.properties(); + + RowDataReader dataReader = new RowDataReader(task, table, schema, caseSensitive); + + StructType structType = SparkSchemaUtil.convert(schema); + SparkAppenderFactory appenderFactory = SparkAppenderFactory.builderFor(table, schema, structType) + .spec(spec) + .build(); + OutputFileFactory fileFactory = OutputFileFactory.builderFor(table, partitionId, taskId) + .defaultSpec(spec) + .format(format) + .build(); + + TaskWriter writer; + if (spec.isUnpartitioned()) { + writer = new UnpartitionedWriter<>(spec, format, appenderFactory, fileFactory, table.io(), + Long.MAX_VALUE); + } else if (PropertyUtil.propertyAsBoolean(properties, + TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED, + TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED_DEFAULT)) { + writer = new SparkPartitionedFanoutWriter( + spec, format, appenderFactory, fileFactory, table.io(), Long.MAX_VALUE, schema, + structType); + } else { + writer = new SparkPartitionedWriter( + spec, format, appenderFactory, fileFactory, table.io(), Long.MAX_VALUE, schema, + structType); + } + + try { + while (dataReader.next()) { + InternalRow row = dataReader.get(); + writer.write(row); + } + + dataReader.close(); + dataReader = null; + + writer.close(); + return Lists.newArrayList(writer.dataFiles()); + + } catch (Throwable originalThrowable) { + try { + LOG.error("Aborting task", originalThrowable); + context.markTaskFailed(originalThrowable); + + LOG.error("Aborting commit for partition {} (task {}, attempt {}, stage {}.{})", + partitionId, taskId, context.attemptNumber(), context.stageId(), context.stageAttemptNumber()); + if (dataReader != null) { + dataReader.close(); + } + writer.abort(); + LOG.error("Aborted commit for partition {} (task {}, attempt {}, stage {}.{})", + partitionId, taskId, context.taskAttemptId(), context.stageId(), context.stageAttemptNumber()); + + } catch (Throwable inner) { + if (originalThrowable != inner) { + originalThrowable.addSuppressed(inner); + LOG.warn("Suppressing exception in catch: {}", inner.getMessage(), inner); + } + } + + if (originalThrowable instanceof Exception) { + throw originalThrowable; + } else { + throw new RuntimeException(originalThrowable); + } + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkAppenderFactory.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkAppenderFactory.java new file mode 100644 index 000000000000..4becf666ed3e --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkAppenderFactory.java @@ -0,0 +1,295 @@ 
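Before the next file, a hedged sketch of how RowDataRewriter above might be driven; the caller supplies a loaded Table and already-planned CombinedScanTasks, caseSensitive is hard-coded to false purely for illustration, and SerializableTable.copyOf is used so the broadcast value is safe to ship to executors.

package com.example;  // hypothetical example package, not part of the patch

import java.util.List;
import org.apache.iceberg.CombinedScanTask;
import org.apache.iceberg.DataFile;
import org.apache.iceberg.SerializableTable;
import org.apache.iceberg.Table;
import org.apache.iceberg.spark.source.RowDataRewriter;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

public class RowDataRewriterExample {

  // Rewrites the rows of the given scan tasks into new data files using the table's current spec.
  public static List<DataFile> rewrite(JavaSparkContext sparkContext, Table table,
                                       List<CombinedScanTask> tasks) {
    Broadcast<Table> tableBroadcast = sparkContext.broadcast(SerializableTable.copyOf(table));
    RowDataRewriter rewriter = new RowDataRewriter(tableBroadcast, table.spec(), false /* caseSensitive */);
    return rewriter.rewriteDataForTasks(sparkContext.parallelize(tasks, Math.max(1, tasks.size())));
  }
}

The returned DataFiles would still need to be committed back to the table, which is outside the scope of this sketch.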
+/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Map; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.MetricsConfig; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.deletes.EqualityDeleteWriter; +import org.apache.iceberg.deletes.PositionDeleteWriter; +import org.apache.iceberg.encryption.EncryptedOutputFile; +import org.apache.iceberg.exceptions.RuntimeIOException; +import org.apache.iceberg.io.DataWriter; +import org.apache.iceberg.io.DeleteSchemaUtil; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.io.FileAppenderFactory; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.SparkAvroWriter; +import org.apache.iceberg.spark.data.SparkOrcWriter; +import org.apache.iceberg.spark.data.SparkParquetWriters; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +class SparkAppenderFactory implements FileAppenderFactory { + private final Map properties; + private final Schema writeSchema; + private final StructType dsSchema; + private final PartitionSpec spec; + private final int[] equalityFieldIds; + private final Schema eqDeleteRowSchema; + private final Schema posDeleteRowSchema; + + private StructType eqDeleteSparkType = null; + private StructType posDeleteSparkType = null; + + SparkAppenderFactory(Map properties, Schema writeSchema, StructType dsSchema, PartitionSpec spec, + int[] equalityFieldIds, Schema eqDeleteRowSchema, Schema posDeleteRowSchema) { + this.properties = properties; + this.writeSchema = writeSchema; + this.dsSchema = dsSchema; + this.spec = spec; + this.equalityFieldIds = equalityFieldIds; + this.eqDeleteRowSchema = eqDeleteRowSchema; + this.posDeleteRowSchema = posDeleteRowSchema; + } + + static Builder builderFor(Table table, Schema writeSchema, StructType dsSchema) { + return new Builder(table, writeSchema, dsSchema); + } + + static class Builder { + private final Table table; + private final Schema writeSchema; + private final StructType dsSchema; + private PartitionSpec spec; + private int[] equalityFieldIds; + private Schema eqDeleteRowSchema; + private Schema posDeleteRowSchema; + + + Builder(Table table, Schema writeSchema, StructType dsSchema) { + 
this.table = table; + this.spec = table.spec(); + this.writeSchema = writeSchema; + this.dsSchema = dsSchema; + } + + Builder spec(PartitionSpec newSpec) { + this.spec = newSpec; + return this; + } + + Builder equalityFieldIds(int[] newEqualityFieldIds) { + this.equalityFieldIds = newEqualityFieldIds; + return this; + } + + Builder eqDeleteRowSchema(Schema newEqDeleteRowSchema) { + this.eqDeleteRowSchema = newEqDeleteRowSchema; + return this; + } + + Builder posDelRowSchema(Schema newPosDelRowSchema) { + this.posDeleteRowSchema = newPosDelRowSchema; + return this; + } + + SparkAppenderFactory build() { + Preconditions.checkNotNull(table, "Table must not be null"); + Preconditions.checkNotNull(writeSchema, "Write Schema must not be null"); + Preconditions.checkNotNull(dsSchema, "DS Schema must not be null"); + if (equalityFieldIds != null) { + Preconditions.checkNotNull(eqDeleteRowSchema, "Equality Field Ids and Equality Delete Row Schema" + + " must be set together"); + } + if (eqDeleteRowSchema != null) { + Preconditions.checkNotNull(equalityFieldIds, "Equality Field Ids and Equality Delete Row Schema" + + " must be set together"); + } + + return new SparkAppenderFactory(table.properties(), writeSchema, dsSchema, spec, equalityFieldIds, + eqDeleteRowSchema, posDeleteRowSchema); + } + } + + private StructType lazyEqDeleteSparkType() { + if (eqDeleteSparkType == null) { + Preconditions.checkNotNull(eqDeleteRowSchema, "Equality delete row schema shouldn't be null"); + this.eqDeleteSparkType = SparkSchemaUtil.convert(eqDeleteRowSchema); + } + return eqDeleteSparkType; + } + + private StructType lazyPosDeleteSparkType() { + if (posDeleteSparkType == null) { + Preconditions.checkNotNull(posDeleteRowSchema, "Position delete row schema shouldn't be null"); + this.posDeleteSparkType = SparkSchemaUtil.convert(posDeleteRowSchema); + } + return posDeleteSparkType; + } + + @Override + public FileAppender newAppender(OutputFile file, FileFormat fileFormat) { + MetricsConfig metricsConfig = MetricsConfig.fromProperties(properties); + try { + switch (fileFormat) { + case PARQUET: + return Parquet.write(file) + .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(dsSchema, msgType)) + .setAll(properties) + .metricsConfig(metricsConfig) + .schema(writeSchema) + .overwrite() + .build(); + + case AVRO: + return Avro.write(file) + .createWriterFunc(ignored -> new SparkAvroWriter(dsSchema)) + .setAll(properties) + .schema(writeSchema) + .overwrite() + .build(); + + case ORC: + return ORC.write(file) + .createWriterFunc(SparkOrcWriter::new) + .setAll(properties) + .metricsConfig(metricsConfig) + .schema(writeSchema) + .overwrite() + .build(); + + default: + throw new UnsupportedOperationException("Cannot write unknown format: " + fileFormat); + } + } catch (IOException e) { + throw new RuntimeIOException(e); + } + } + + @Override + public DataWriter newDataWriter(EncryptedOutputFile file, FileFormat format, StructLike partition) { + return new DataWriter<>(newAppender(file.encryptingOutputFile(), format), format, + file.encryptingOutputFile().location(), spec, partition, file.keyMetadata()); + } + + @Override + public EqualityDeleteWriter newEqDeleteWriter(EncryptedOutputFile file, FileFormat format, + StructLike partition) { + Preconditions.checkState(equalityFieldIds != null && equalityFieldIds.length > 0, + "Equality field ids shouldn't be null or empty when creating equality-delete writer"); + Preconditions.checkNotNull(eqDeleteRowSchema, + "Equality delete row schema shouldn't be null when creating 
equality-delete writer"); + + try { + switch (format) { + case PARQUET: + return Parquet.writeDeletes(file.encryptingOutputFile()) + .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(lazyEqDeleteSparkType(), msgType)) + .overwrite() + .rowSchema(eqDeleteRowSchema) + .withSpec(spec) + .withPartition(partition) + .equalityFieldIds(equalityFieldIds) + .withKeyMetadata(file.keyMetadata()) + .buildEqualityWriter(); + + case AVRO: + return Avro.writeDeletes(file.encryptingOutputFile()) + .createWriterFunc(ignored -> new SparkAvroWriter(lazyEqDeleteSparkType())) + .overwrite() + .rowSchema(eqDeleteRowSchema) + .withSpec(spec) + .withPartition(partition) + .equalityFieldIds(equalityFieldIds) + .withKeyMetadata(file.keyMetadata()) + .buildEqualityWriter(); + + case ORC: + return ORC.writeDeletes(file.encryptingOutputFile()) + .createWriterFunc(SparkOrcWriter::new) + .overwrite() + .rowSchema(eqDeleteRowSchema) + .withSpec(spec) + .withPartition(partition) + .equalityFieldIds(equalityFieldIds) + .withKeyMetadata(file.keyMetadata()) + .buildEqualityWriter(); + + default: + throw new UnsupportedOperationException( + "Cannot write equality-deletes for unsupported file format: " + format); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to create new equality delete writer", e); + } + } + + @Override + public PositionDeleteWriter newPosDeleteWriter(EncryptedOutputFile file, FileFormat format, + StructLike partition) { + try { + switch (format) { + case PARQUET: + StructType sparkPosDeleteSchema = + SparkSchemaUtil.convert(DeleteSchemaUtil.posDeleteSchema(posDeleteRowSchema)); + return Parquet.writeDeletes(file.encryptingOutputFile()) + .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(sparkPosDeleteSchema, msgType)) + .overwrite() + .rowSchema(posDeleteRowSchema) + .withSpec(spec) + .withPartition(partition) + .withKeyMetadata(file.keyMetadata()) + .transformPaths(path -> UTF8String.fromString(path.toString())) + .buildPositionWriter(); + + case AVRO: + return Avro.writeDeletes(file.encryptingOutputFile()) + .createWriterFunc(ignored -> new SparkAvroWriter(lazyPosDeleteSparkType())) + .overwrite() + .rowSchema(posDeleteRowSchema) + .withSpec(spec) + .withPartition(partition) + .withKeyMetadata(file.keyMetadata()) + .buildPositionWriter(); + + case ORC: + return ORC.writeDeletes(file.encryptingOutputFile()) + .createWriterFunc(SparkOrcWriter::new) + .overwrite() + .rowSchema(posDeleteRowSchema) + .withSpec(spec) + .withPartition(partition) + .withKeyMetadata(file.keyMetadata()) + .transformPaths(path -> UTF8String.fromString(path.toString())) + .buildPositionWriter(); + + default: + throw new UnsupportedOperationException( + "Cannot write pos-deletes for unsupported file format: " + format); + } + + } catch (IOException e) { + throw new UncheckedIOException("Failed to create new equality delete writer", e); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java new file mode 100644 index 000000000000..ee4674f27bb3 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.util.List; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SchemaParser; +import org.apache.iceberg.SerializableTable; +import org.apache.iceberg.Table; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.spark.source.SparkScan.ReadTask; +import org.apache.iceberg.spark.source.SparkScan.ReaderFactory; +import org.apache.iceberg.util.TableScanUtil; +import org.apache.iceberg.util.Tasks; +import org.apache.iceberg.util.ThreadPools; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; + +class SparkBatch implements Batch { + + private final JavaSparkContext sparkContext; + private final Table table; + private final SparkReadConf readConf; + private final List tasks; + private final Schema expectedSchema; + private final boolean caseSensitive; + private final boolean localityEnabled; + + SparkBatch(JavaSparkContext sparkContext, Table table, SparkReadConf readConf, + List tasks, Schema expectedSchema) { + this.sparkContext = sparkContext; + this.table = table; + this.readConf = readConf; + this.tasks = tasks; + this.expectedSchema = expectedSchema; + this.caseSensitive = readConf.caseSensitive(); + this.localityEnabled = readConf.localityEnabled(); + } + + @Override + public InputPartition[] planInputPartitions() { + // broadcast the table metadata as input partitions will be sent to executors + Broadcast
tableBroadcast = sparkContext.broadcast(SerializableTable.copyOf(table)); + String expectedSchemaString = SchemaParser.toJson(expectedSchema); + + InputPartition[] readTasks = new InputPartition[tasks.size()]; + + Tasks.range(readTasks.length) + .stopOnFailure() + .executeWith(localityEnabled ? ThreadPools.getWorkerPool() : null) + .run(index -> readTasks[index] = new ReadTask( + tasks.get(index), tableBroadcast, expectedSchemaString, + caseSensitive, localityEnabled)); + + return readTasks; + } + + @Override + public PartitionReaderFactory createReaderFactory() { + return new ReaderFactory(batchSize()); + } + + private int batchSize() { + if (parquetOnly() && parquetBatchReadsEnabled()) { + return readConf.parquetBatchSize(); + } else if (orcOnly() && orcBatchReadsEnabled()) { + return readConf.orcBatchSize(); + } else { + return 0; + } + } + + private boolean parquetOnly() { + return tasks.stream().allMatch(task -> !task.isDataTask() && onlyFileFormat(task, FileFormat.PARQUET)); + } + + private boolean parquetBatchReadsEnabled() { + return readConf.parquetVectorizationEnabled() && // vectorization enabled + expectedSchema.columns().size() > 0 && // at least one column is projected + expectedSchema.columns().stream().allMatch(c -> c.type().isPrimitiveType()); // only primitives + } + + private boolean orcOnly() { + return tasks.stream().allMatch(task -> !task.isDataTask() && onlyFileFormat(task, FileFormat.ORC)); + } + + private boolean orcBatchReadsEnabled() { + return readConf.orcVectorizationEnabled() && // vectorization enabled + tasks.stream().noneMatch(TableScanUtil::hasDeletes); // no delete files + } + + private boolean onlyFileFormat(CombinedScanTask task, FileFormat fileFormat) { + return task.files().stream().allMatch(fileScanTask -> fileScanTask.file().format().equals(fileFormat)); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java new file mode 100644 index 000000000000..651a411ebd7b --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
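SparkBatch above only reports a non-zero batch size, and therefore columnar reads, when vectorization is enabled and every task is Parquet with an all-primitive projection, or ORC with no delete files. A hedged sketch of toggling this from the read path; the "vectorization-enabled" option string and the table identifier are assumptions, not definitions from this file:

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

class VectorizedReadSketch {
  static Dataset<Row> read(SparkSession spark) {
    return spark.read()
        .format("iceberg")
        .option("vectorization-enabled", "true")   // assumed to match the SparkReadOptions key
        .load("db.events");                        // placeholder table identifier
  }
}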
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.PartitionField; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableScan; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Binder; +import org.apache.iceberg.expressions.Evaluator; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.expressions.Projections; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkFilters; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.iceberg.util.TableScanUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering; +import org.apache.spark.sql.sources.Filter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class SparkBatchQueryScan extends SparkScan implements SupportsRuntimeFiltering { + + private static final Logger LOG = LoggerFactory.getLogger(SparkBatchQueryScan.class); + + private final TableScan scan; + private final Long snapshotId; + private final Long startSnapshotId; + private final Long endSnapshotId; + private final Long asOfTimestamp; + private final List runtimeFilterExpressions; + + private Set specIds = null; // lazy cache of scanned spec IDs + private List files = null; // lazy cache of files + private List tasks = null; // lazy cache of tasks + + SparkBatchQueryScan(SparkSession spark, Table table, TableScan scan, SparkReadConf readConf, + Schema expectedSchema, List filters) { + + super(spark, table, readConf, expectedSchema, filters); + + this.scan = scan; + this.snapshotId = readConf.snapshotId(); + this.startSnapshotId = readConf.startSnapshotId(); + this.endSnapshotId = readConf.endSnapshotId(); + this.asOfTimestamp = readConf.asOfTimestamp(); + this.runtimeFilterExpressions = Lists.newArrayList(); + + if (scan == null) { + this.specIds = Collections.emptySet(); + this.files = Collections.emptyList(); + this.tasks = Collections.emptyList(); + } + } + + Long snapshotId() { + return snapshotId; + } + + private Set specIds() { + if (specIds == null) { + Set specIdSet = Sets.newHashSet(); + for (FileScanTask file : files()) { + specIdSet.add(file.spec().specId()); + } + this.specIds = specIdSet; + } + + return specIds; + } + + private List files() { + if (files == null) { + try (CloseableIterable filesIterable = scan.planFiles()) { + this.files = Lists.newArrayList(filesIterable); + } catch (IOException e) { + throw new UncheckedIOException("Failed to close table scan: " + scan, e); + } + } + + return files; + } + 
+ @Override + protected List tasks() { + if (tasks == null) { + CloseableIterable splitFiles = TableScanUtil.splitFiles( + CloseableIterable.withNoopClose(files()), + scan.targetSplitSize()); + CloseableIterable scanTasks = TableScanUtil.planTasks( + splitFiles, scan.targetSplitSize(), + scan.splitLookback(), scan.splitOpenFileCost()); + tasks = Lists.newArrayList(scanTasks); + } + + return tasks; + } + + @Override + public NamedReference[] filterAttributes() { + Set partitionFieldSourceIds = Sets.newHashSet(); + + for (Integer specId : specIds()) { + PartitionSpec spec = table().specs().get(specId); + for (PartitionField field : spec.fields()) { + partitionFieldSourceIds.add(field.sourceId()); + } + } + + Map quotedNameById = SparkSchemaUtil.indexQuotedNameById(expectedSchema()); + + // the optimizer will look for an equality condition with filter attributes in a join + // as the scan has been already planned, filtering can only be done on projected attributes + // that's why only partition source fields that are part of the read schema can be reported + + return partitionFieldSourceIds.stream() + .filter(fieldId -> expectedSchema().findField(fieldId) != null) + .map(fieldId -> Spark3Util.toNamedReference(quotedNameById.get(fieldId))) + .toArray(NamedReference[]::new); + } + + @Override + public void filter(Filter[] filters) { + Expression runtimeFilterExpr = convertRuntimeFilters(filters); + + if (runtimeFilterExpr != Expressions.alwaysTrue()) { + Map evaluatorsBySpecId = Maps.newHashMap(); + + for (Integer specId : specIds()) { + PartitionSpec spec = table().specs().get(specId); + Expression inclusiveExpr = Projections.inclusive(spec, caseSensitive()).project(runtimeFilterExpr); + Evaluator inclusive = new Evaluator(spec.partitionType(), inclusiveExpr); + evaluatorsBySpecId.put(specId, inclusive); + } + + LOG.info("Trying to filter {} files using runtime filter {}", files().size(), runtimeFilterExpr); + + List filteredFiles = files().stream() + .filter(file -> { + Evaluator evaluator = evaluatorsBySpecId.get(file.spec().specId()); + return evaluator.eval(file.file().partition()); + }) + .collect(Collectors.toList()); + + LOG.info("{}/{} files matched runtime filter {}", filteredFiles.size(), files().size(), runtimeFilterExpr); + + // don't invalidate tasks if the runtime filter had no effect to avoid planning splits again + if (filteredFiles.size() < files().size()) { + this.specIds = null; + this.files = filteredFiles; + this.tasks = null; + } + + // save the evaluated filter for equals/hashCode + runtimeFilterExpressions.add(runtimeFilterExpr); + } + } + + // at this moment, Spark can only pass IN filters for a single attribute + // if there are multiple filter attributes, Spark will pass two separate IN filters + private Expression convertRuntimeFilters(Filter[] filters) { + Expression runtimeFilterExpr = Expressions.alwaysTrue(); + + for (Filter filter : filters) { + Expression expr = SparkFilters.convert(filter); + if (expr != null) { + try { + Binder.bind(expectedSchema().asStruct(), expr, caseSensitive()); + runtimeFilterExpr = Expressions.and(runtimeFilterExpr, expr); + } catch (ValidationException e) { + LOG.warn("Failed to bind {} to expected schema, skipping runtime filter", expr, e); + } + } else { + LOG.warn("Unsupported runtime filter {}", filter); + } + } + + return runtimeFilterExpr; + } + + @Override + public Statistics estimateStatistics() { + if (scan == null) { + return estimateStatistics(null); + + } else if (snapshotId != null) { + Snapshot snapshot = 
table().snapshot(snapshotId); + return estimateStatistics(snapshot); + + } else if (asOfTimestamp != null) { + long snapshotIdAsOfTime = SnapshotUtil.snapshotIdAsOfTime(table(), asOfTimestamp); + Snapshot snapshot = table().snapshot(snapshotIdAsOfTime); + return estimateStatistics(snapshot); + + } else { + Snapshot snapshot = table().currentSnapshot(); + return estimateStatistics(snapshot); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + SparkBatchQueryScan that = (SparkBatchQueryScan) o; + return table().name().equals(that.table().name()) && + readSchema().equals(that.readSchema()) && // compare Spark schemas to ignore field ids + filterExpressions().toString().equals(that.filterExpressions().toString()) && + runtimeFilterExpressions.toString().equals(that.runtimeFilterExpressions.toString()) && + Objects.equals(snapshotId, that.snapshotId) && + Objects.equals(startSnapshotId, that.startSnapshotId) && + Objects.equals(endSnapshotId, that.endSnapshotId) && + Objects.equals(asOfTimestamp, that.asOfTimestamp); + } + + @Override + public int hashCode() { + return Objects.hash( + table().name(), readSchema(), filterExpressions().toString(), runtimeFilterExpressions.toString(), + snapshotId, startSnapshotId, endSnapshotId, asOfTimestamp); + } + + @Override + public String toString() { + return String.format( + "IcebergScan(table=%s, type=%s, filters=%s, runtimeFilters=%s, caseSensitive=%s)", + table(), expectedSchema().asStruct(), filterExpressions(), runtimeFilterExpressions, caseSensitive()); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java new file mode 100644 index 000000000000..19b4ab7a4087 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
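SparkBatchQueryScan above applies Spark runtime filters by projecting the row-level predicate onto each partition spec and evaluating the result against file partition tuples. The same two-step check in isolation; the spec and file arguments and the "dept" predicate are placeholders:

import org.apache.iceberg.DataFile;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.expressions.Evaluator;
import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.expressions.Expressions;
import org.apache.iceberg.expressions.Projections;

class RuntimeFilterSketch {
  static boolean matches(PartitionSpec spec, DataFile file) {
    Expression rowFilter = Expressions.in("dept", "hr", "ops");                    // placeholder predicate
    Expression partFilter = Projections.inclusive(spec, true).project(rowFilter);  // case-sensitive projection
    Evaluator inclusive = new Evaluator(spec.partitionType(), partFilter);
    return inclusive.eval(file.partition());                                       // true => keep the file
  }
}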
+ */ + +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.RowLevelOperation; +import org.apache.spark.sql.connector.write.RowLevelOperationInfo; +import org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE; + +class SparkCopyOnWriteOperation implements RowLevelOperation { + + private final SparkSession spark; + private final Table table; + private final Command command; + private final IsolationLevel isolationLevel; + + // lazy vars + private ScanBuilder lazyScanBuilder; + private Scan configuredScan; + private WriteBuilder lazyWriteBuilder; + + SparkCopyOnWriteOperation(SparkSession spark, Table table, RowLevelOperationInfo info, + IsolationLevel isolationLevel) { + this.spark = spark; + this.table = table; + this.command = info.command(); + this.isolationLevel = isolationLevel; + } + + @Override + public Command command() { + return command; + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + if (lazyScanBuilder == null) { + lazyScanBuilder = new SparkScanBuilder(spark, table, options) { + @Override + public Scan build() { + Scan scan = super.buildCopyOnWriteScan(); + SparkCopyOnWriteOperation.this.configuredScan = scan; + return scan; + } + }; + } + + return lazyScanBuilder; + } + + @Override + public WriteBuilder newWriteBuilder(LogicalWriteInfo info) { + if (lazyWriteBuilder == null) { + Preconditions.checkState(configuredScan != null, "Write must be configured after scan"); + SparkWriteBuilder writeBuilder = new SparkWriteBuilder(spark, table, info); + lazyWriteBuilder = writeBuilder.overwriteFiles(configuredScan, command, isolationLevel); + } + + return lazyWriteBuilder; + } + + @Override + public NamedReference[] requiredMetadataAttributes() { + NamedReference file = Expressions.column(MetadataColumns.FILE_PATH.name()); + NamedReference pos = Expressions.column(MetadataColumns.ROW_POSITION.name()); + + if (command == DELETE || command == UPDATE) { + return new NamedReference[]{file, pos}; + } else { + return new NamedReference[]{file}; + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java new file mode 100644 index 000000000000..49d90f5f2b13 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
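SparkCopyOnWriteOperation above backs copy-on-write row-level commands; per requiredMetadataAttributes(), DELETE and UPDATE ask for both _file and _pos while other commands only need _file. A sketch of SQL that reaches this path, with placeholder table and column names:

import org.apache.spark.sql.SparkSession;

class CopyOnWriteSketch {
  static void rowLevelCommands(SparkSession spark) {
    spark.sql("DELETE FROM db.events WHERE event_ts < TIMESTAMP '2022-01-01 00:00:00'");
    spark.sql("UPDATE db.events SET status = 'archived' WHERE status = 'stale'");
  }
}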
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableScan; +import org.apache.iceberg.exceptions.RuntimeIOException; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.util.TableScanUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.In; + +class SparkCopyOnWriteScan extends SparkScan implements SupportsRuntimeFiltering { + + private final TableScan scan; + private final Snapshot snapshot; + + // lazy variables + private List files = null; // lazy cache of files + private List tasks = null; // lazy cache of tasks + private Set filteredLocations = null; + + SparkCopyOnWriteScan(SparkSession spark, Table table, SparkReadConf readConf, + Schema expectedSchema, List filters) { + this(spark, table, null, null, readConf, expectedSchema, filters); + } + + SparkCopyOnWriteScan(SparkSession spark, Table table, TableScan scan, Snapshot snapshot, + SparkReadConf readConf, Schema expectedSchema, List filters) { + + super(spark, table, readConf, expectedSchema, filters); + + this.scan = scan; + this.snapshot = snapshot; + + if (scan == null) { + this.files = Collections.emptyList(); + this.tasks = Collections.emptyList(); + this.filteredLocations = Collections.emptySet(); + } + } + + Long snapshotId() { + return snapshot != null ? 
snapshot.snapshotId() : null; + } + + @Override + public Statistics estimateStatistics() { + return estimateStatistics(snapshot); + } + + public NamedReference[] filterAttributes() { + NamedReference file = Expressions.column(MetadataColumns.FILE_PATH.name()); + return new NamedReference[]{file}; + } + + @Override + public void filter(Filter[] filters) { + for (Filter filter : filters) { + // Spark can only pass In filters at the moment + if (filter instanceof In && ((In) filter).attribute().equalsIgnoreCase(MetadataColumns.FILE_PATH.name())) { + In in = (In) filter; + + Set fileLocations = Sets.newHashSet(); + for (Object value : in.values()) { + fileLocations.add((String) value); + } + + // Spark may call this multiple times for UPDATEs with subqueries + // as such cases are rewritten using UNION and the same scan on both sides + // so filter files only if it is beneficial + if (filteredLocations == null || fileLocations.size() < filteredLocations.size()) { + this.tasks = null; + this.filteredLocations = fileLocations; + this.files = files().stream() + .filter(file -> fileLocations.contains(file.file().path().toString())) + .collect(Collectors.toList()); + } + } + } + } + + // should be accessible to the write + synchronized List files() { + if (files == null) { + try (CloseableIterable filesIterable = scan.planFiles()) { + this.files = Lists.newArrayList(filesIterable); + } catch (IOException e) { + throw new RuntimeIOException(e, "Failed to close table scan: %s", scan); + } + } + + return files; + } + + @Override + protected synchronized List tasks() { + if (tasks == null) { + CloseableIterable splitFiles = TableScanUtil.splitFiles( + CloseableIterable.withNoopClose(files()), + scan.targetSplitSize()); + CloseableIterable scanTasks = TableScanUtil.planTasks( + splitFiles, scan.targetSplitSize(), + scan.splitLookback(), scan.splitOpenFileCost()); + tasks = Lists.newArrayList(scanTasks); + } + + return tasks; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + SparkCopyOnWriteScan that = (SparkCopyOnWriteScan) o; + return table().name().equals(that.table().name()) && + readSchema().equals(that.readSchema()) && // compare Spark schemas to ignore field ids + filterExpressions().toString().equals(that.filterExpressions().toString()) && + Objects.equals(snapshotId(), that.snapshotId()) && + Objects.equals(filteredLocations, that.filteredLocations); + } + + @Override + public int hashCode() { + return Objects.hash( + table().name(), readSchema(), filterExpressions().toString(), + snapshotId(), filteredLocations); + } + + @Override + public String toString() { + return String.format( + "IcebergCopyOnWriteScan(table=%s, type=%s, filters=%s, caseSensitive=%s)", + table(), expectedSchema().asStruct(), filterExpressions(), caseSensitive()); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java new file mode 100644 index 000000000000..beaa7c295024 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.util.Locale; +import java.util.Map; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.data.BaseFileWriterFactory; +import org.apache.iceberg.io.DeleteSchemaUtil; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.SparkAvroWriter; +import org.apache.iceberg.spark.data.SparkOrcWriter; +import org.apache.iceberg.spark.data.SparkParquetWriters; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +import static org.apache.iceberg.MetadataColumns.DELETE_FILE_ROW_FIELD_NAME; +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT_DEFAULT; +import static org.apache.iceberg.TableProperties.DELETE_DEFAULT_FILE_FORMAT; + +class SparkFileWriterFactory extends BaseFileWriterFactory { + private StructType dataSparkType; + private StructType equalityDeleteSparkType; + private StructType positionDeleteSparkType; + + SparkFileWriterFactory(Table table, FileFormat dataFileFormat, Schema dataSchema, StructType dataSparkType, + SortOrder dataSortOrder, FileFormat deleteFileFormat, + int[] equalityFieldIds, Schema equalityDeleteRowSchema, StructType equalityDeleteSparkType, + SortOrder equalityDeleteSortOrder, Schema positionDeleteRowSchema, + StructType positionDeleteSparkType) { + + super(table, dataFileFormat, dataSchema, dataSortOrder, deleteFileFormat, equalityFieldIds, + equalityDeleteRowSchema, equalityDeleteSortOrder, positionDeleteRowSchema); + + this.dataSparkType = dataSparkType; + this.equalityDeleteSparkType = equalityDeleteSparkType; + this.positionDeleteSparkType = positionDeleteSparkType; + } + + static Builder builderFor(Table table) { + return new Builder(table); + } + + @Override + protected void configureDataWrite(Avro.DataWriteBuilder builder) { + builder.createWriterFunc(ignored -> new SparkAvroWriter(dataSparkType())); + } + + @Override + protected void configureEqualityDelete(Avro.DeleteWriteBuilder builder) { + builder.createWriterFunc(ignored -> new SparkAvroWriter(equalityDeleteSparkType())); + } + + @Override + protected void configurePositionDelete(Avro.DeleteWriteBuilder builder) { + boolean withRow = positionDeleteSparkType().getFieldIndex(DELETE_FILE_ROW_FIELD_NAME).isDefined(); + if (withRow) { + // SparkAvroWriter accepts just the Spark type of the row ignoring the path and pos + StructField rowField = positionDeleteSparkType().apply(DELETE_FILE_ROW_FIELD_NAME); 
+ StructType positionDeleteRowSparkType = (StructType) rowField.dataType(); + builder.createWriterFunc(ignored -> new SparkAvroWriter(positionDeleteRowSparkType)); + } + } + + @Override + protected void configureDataWrite(Parquet.DataWriteBuilder builder) { + builder.createWriterFunc(msgType -> SparkParquetWriters.buildWriter(dataSparkType(), msgType)); + } + + @Override + protected void configureEqualityDelete(Parquet.DeleteWriteBuilder builder) { + builder.createWriterFunc(msgType -> SparkParquetWriters.buildWriter(equalityDeleteSparkType(), msgType)); + } + + @Override + protected void configurePositionDelete(Parquet.DeleteWriteBuilder builder) { + builder.createWriterFunc(msgType -> SparkParquetWriters.buildWriter(positionDeleteSparkType(), msgType)); + builder.transformPaths(path -> UTF8String.fromString(path.toString())); + } + + @Override + protected void configureDataWrite(ORC.DataWriteBuilder builder) { + builder.createWriterFunc(SparkOrcWriter::new); + } + + @Override + protected void configureEqualityDelete(ORC.DeleteWriteBuilder builder) { + builder.createWriterFunc(SparkOrcWriter::new); + } + + @Override + protected void configurePositionDelete(ORC.DeleteWriteBuilder builder) { + builder.createWriterFunc(SparkOrcWriter::new); + builder.transformPaths(path -> UTF8String.fromString(path.toString())); + } + + private StructType dataSparkType() { + if (dataSparkType == null) { + Preconditions.checkNotNull(dataSchema(), "Data schema must not be null"); + this.dataSparkType = SparkSchemaUtil.convert(dataSchema()); + } + + return dataSparkType; + } + + private StructType equalityDeleteSparkType() { + if (equalityDeleteSparkType == null) { + Preconditions.checkNotNull(equalityDeleteRowSchema(), "Equality delete schema must not be null"); + this.equalityDeleteSparkType = SparkSchemaUtil.convert(equalityDeleteRowSchema()); + } + + return equalityDeleteSparkType; + } + + private StructType positionDeleteSparkType() { + if (positionDeleteSparkType == null) { + // wrap the optional row schema into the position delete schema that contains path and position + Schema positionDeleteSchema = DeleteSchemaUtil.posDeleteSchema(positionDeleteRowSchema()); + this.positionDeleteSparkType = SparkSchemaUtil.convert(positionDeleteSchema); + } + + return positionDeleteSparkType; + } + + static class Builder { + private final Table table; + private FileFormat dataFileFormat; + private Schema dataSchema; + private StructType dataSparkType; + private SortOrder dataSortOrder; + private FileFormat deleteFileFormat; + private int[] equalityFieldIds; + private Schema equalityDeleteRowSchema; + private StructType equalityDeleteSparkType; + private SortOrder equalityDeleteSortOrder; + private Schema positionDeleteRowSchema; + private StructType positionDeleteSparkType; + + Builder(Table table) { + this.table = table; + + Map properties = table.properties(); + + String dataFileFormatName = properties.getOrDefault(DEFAULT_FILE_FORMAT, DEFAULT_FILE_FORMAT_DEFAULT); + this.dataFileFormat = FileFormat.valueOf(dataFileFormatName.toUpperCase(Locale.ENGLISH)); + + String deleteFileFormatName = properties.getOrDefault(DELETE_DEFAULT_FILE_FORMAT, dataFileFormatName); + this.deleteFileFormat = FileFormat.valueOf(deleteFileFormatName.toUpperCase(Locale.ENGLISH)); + } + + Builder dataFileFormat(FileFormat newDataFileFormat) { + this.dataFileFormat = newDataFileFormat; + return this; + } + + Builder dataSchema(Schema newDataSchema) { + this.dataSchema = newDataSchema; + return this; + } + + Builder dataSparkType(StructType 
newDataSparkType) { + this.dataSparkType = newDataSparkType; + return this; + } + + Builder dataSortOrder(SortOrder newDataSortOrder) { + this.dataSortOrder = newDataSortOrder; + return this; + } + + Builder deleteFileFormat(FileFormat newDeleteFileFormat) { + this.deleteFileFormat = newDeleteFileFormat; + return this; + } + + Builder equalityFieldIds(int[] newEqualityFieldIds) { + this.equalityFieldIds = newEqualityFieldIds; + return this; + } + + Builder equalityDeleteRowSchema(Schema newEqualityDeleteRowSchema) { + this.equalityDeleteRowSchema = newEqualityDeleteRowSchema; + return this; + } + + Builder equalityDeleteSparkType(StructType newEqualityDeleteSparkType) { + this.equalityDeleteSparkType = newEqualityDeleteSparkType; + return this; + } + + Builder equalityDeleteSortOrder(SortOrder newEqualityDeleteSortOrder) { + this.equalityDeleteSortOrder = newEqualityDeleteSortOrder; + return this; + } + + Builder positionDeleteRowSchema(Schema newPositionDeleteRowSchema) { + this.positionDeleteRowSchema = newPositionDeleteRowSchema; + return this; + } + + Builder positionDeleteSparkType(StructType newPositionDeleteSparkType) { + this.positionDeleteSparkType = newPositionDeleteSparkType; + return this; + } + + SparkFileWriterFactory build() { + boolean noEqualityDeleteConf = equalityFieldIds == null && equalityDeleteRowSchema == null; + boolean fullEqualityDeleteConf = equalityFieldIds != null && equalityDeleteRowSchema != null; + Preconditions.checkArgument(noEqualityDeleteConf || fullEqualityDeleteConf, + "Equality field IDs and equality delete row schema must be set together"); + + return new SparkFileWriterFactory( + table, dataFileFormat, dataSchema, dataSparkType, dataSortOrder, deleteFileFormat, + equalityFieldIds, equalityDeleteRowSchema, equalityDeleteSparkType, equalityDeleteSortOrder, + positionDeleteRowSchema, positionDeleteSparkType); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkFilesScan.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkFilesScan.java new file mode 100644 index 000000000000..285bdaec851f --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkFilesScan.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
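A usage sketch for the SparkFileWriterFactory builder above; every call is taken from the builder methods in this file except the placeholder "id" column, and the sketch assumes it lives in the same package because the factory is package-private:

package org.apache.iceberg.spark.source;   // sketch only: the factory and builder are package-private

import org.apache.iceberg.FileFormat;
import org.apache.iceberg.Table;

class SparkFileWriterFactoryUsageSketch {
  static SparkFileWriterFactory forEqualityDeletes(Table table) {
    int idFieldId = table.schema().findField("id").fieldId();    // "id" is a placeholder column
    return SparkFileWriterFactory.builderFor(table)
        .dataFileFormat(FileFormat.PARQUET)
        .dataSchema(table.schema())
        .deleteFileFormat(FileFormat.PARQUET)
        .equalityFieldIds(new int[] {idFieldId})
        .equalityDeleteRowSchema(table.schema().select("id"))    // must be set together with the field IDs
        .build();
  }
}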
+ */ + +package org.apache.iceberg.spark.source; + +import java.util.List; +import java.util.Objects; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.FileScanTaskSetManager; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.util.TableScanUtil; +import org.apache.spark.sql.SparkSession; + +class SparkFilesScan extends SparkScan { + private final String taskSetID; + private final long splitSize; + private final int splitLookback; + private final long splitOpenFileCost; + + private List tasks = null; // lazy cache of tasks + + SparkFilesScan(SparkSession spark, Table table, SparkReadConf readConf) { + super(spark, table, readConf, table.schema(), ImmutableList.of()); + + this.taskSetID = readConf.fileScanTaskSetId(); + this.splitSize = readConf.splitSize(); + this.splitLookback = readConf.splitLookback(); + this.splitOpenFileCost = readConf.splitOpenFileCost(); + } + + @Override + protected List tasks() { + if (tasks == null) { + FileScanTaskSetManager taskSetManager = FileScanTaskSetManager.get(); + List files = taskSetManager.fetchTasks(table(), taskSetID); + ValidationException.check(files != null, + "Task set manager has no tasks for table %s with id %s", + table(), taskSetID); + + CloseableIterable splitFiles = TableScanUtil.splitFiles( + CloseableIterable.withNoopClose(files), + splitSize); + CloseableIterable scanTasks = TableScanUtil.planTasks( + splitFiles, splitSize, + splitLookback, splitOpenFileCost); + this.tasks = Lists.newArrayList(scanTasks); + } + + return tasks; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + + if (other == null || getClass() != other.getClass()) { + return false; + } + + SparkFilesScan that = (SparkFilesScan) other; + return table().name().equals(that.table().name()) && + Objects.equals(taskSetID, that.taskSetID) && + Objects.equals(splitSize, that.splitSize) && + Objects.equals(splitLookback, that.splitLookback) && + Objects.equals(splitOpenFileCost, that.splitOpenFileCost); + } + + @Override + public int hashCode() { + return Objects.hash(table().name(), taskSetID, splitSize, splitSize, splitOpenFileCost); + } + + @Override + public String toString() { + return String.format( + "IcebergFilesScan(table=%s, type=%s, taskSetID=%s, caseSensitive=%s)", + table(), expectedSchema().asStruct(), taskSetID, caseSensitive()); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkFilesScanBuilder.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkFilesScanBuilder.java new file mode 100644 index 000000000000..6ae63b37d3b4 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkFilesScanBuilder.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.Table; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +class SparkFilesScanBuilder implements ScanBuilder { + + private final SparkSession spark; + private final Table table; + private final SparkReadConf readConf; + + SparkFilesScanBuilder(SparkSession spark, Table table, CaseInsensitiveStringMap options) { + this.spark = spark; + this.table = table; + this.readConf = new SparkReadConf(spark, table, options); + } + + @Override + public Scan build() { + return new SparkFilesScan(spark, table, readConf); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkMetadataColumn.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkMetadataColumn.java new file mode 100644 index 000000000000..638cb7275638 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkMetadataColumn.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import org.apache.spark.sql.connector.catalog.MetadataColumn; +import org.apache.spark.sql.types.DataType; + +public class SparkMetadataColumn implements MetadataColumn { + + private final String name; + private final DataType dataType; + private final boolean isNullable; + + public SparkMetadataColumn(String name, DataType dataType, boolean isNullable) { + this.name = name; + this.dataType = dataType; + this.isNullable = isNullable; + } + + @Override + public String name() { + return name; + } + + @Override + public DataType dataType() { + return dataType; + } + + @Override + public boolean isNullable() { + return isNullable; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java new file mode 100644 index 000000000000..3f8b1bf0f84e --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java @@ -0,0 +1,309 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
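SparkMetadataColumn above is the adapter that surfaces Iceberg metadata columns through Spark's MetadataColumn API, so they can be selected like regular columns. A sketch, assuming the _file and _pos names used by MetadataColumns and a placeholder table:

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

class MetadataColumnQuerySketch {
  static Dataset<Row> withFileAndPosition(SparkSession spark) {
    return spark.sql("SELECT _file, _pos, * FROM db.events");   // placeholder table name
  }
}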
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.BufferedWriter; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.io.UncheckedIOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Locale; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.DataOperations; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.MicroBatches; +import org.apache.iceberg.MicroBatches.MicroBatch; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SchemaParser; +import org.apache.iceberg.SerializableTable; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.source.SparkScan.ReadTask; +import org.apache.iceberg.spark.source.SparkScan.ReaderFactory; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.iceberg.util.TableScanUtil; +import org.apache.iceberg.util.Tasks; +import org.apache.iceberg.util.ThreadPools; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.connector.read.streaming.MicroBatchStream; +import org.apache.spark.sql.connector.read.streaming.Offset; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SparkMicroBatchStream implements MicroBatchStream { + private static final Joiner SLASH = Joiner.on("/"); + private static final Logger LOG = LoggerFactory.getLogger(SparkMicroBatchStream.class); + + private final Table table; + private final boolean caseSensitive; + private final String expectedSchema; + private final Broadcast
tableBroadcast; + private final Long splitSize; + private final Integer splitLookback; + private final Long splitOpenFileCost; + private final boolean localityPreferred; + private final StreamingOffset initialOffset; + private final boolean skipDelete; + private final boolean skipOverwrite; + private final Long fromTimestamp; + + SparkMicroBatchStream(JavaSparkContext sparkContext, Table table, SparkReadConf readConf, + Schema expectedSchema, String checkpointLocation) { + this.table = table; + this.caseSensitive = readConf.caseSensitive(); + this.expectedSchema = SchemaParser.toJson(expectedSchema); + this.localityPreferred = readConf.localityEnabled(); + this.tableBroadcast = sparkContext.broadcast(SerializableTable.copyOf(table)); + this.splitSize = readConf.splitSize(); + this.splitLookback = readConf.splitLookback(); + this.splitOpenFileCost = readConf.splitOpenFileCost(); + this.fromTimestamp = readConf.streamFromTimestamp(); + + InitialOffsetStore initialOffsetStore = new InitialOffsetStore(table, checkpointLocation, fromTimestamp); + this.initialOffset = initialOffsetStore.initialOffset(); + + this.skipDelete = readConf.streamingSkipDeleteSnapshots(); + this.skipOverwrite = readConf.streamingSkipOverwriteSnapshots(); + } + + @Override + public Offset latestOffset() { + table.refresh(); + if (table.currentSnapshot() == null) { + return StreamingOffset.START_OFFSET; + } + + if (table.currentSnapshot().timestampMillis() < fromTimestamp) { + return StreamingOffset.START_OFFSET; + } + + Snapshot latestSnapshot = table.currentSnapshot(); + long addedFilesCount = PropertyUtil.propertyAsLong(latestSnapshot.summary(), SnapshotSummary.ADDED_FILES_PROP, -1); + // If snapshotSummary doesn't have SnapshotSummary.ADDED_FILES_PROP, iterate through addedFiles iterator to find + // addedFilesCount. + addedFilesCount = addedFilesCount == -1 ? Iterables.size(latestSnapshot.addedFiles()) : addedFilesCount; + + return new StreamingOffset(latestSnapshot.snapshotId(), addedFilesCount, false); + } + + @Override + public InputPartition[] planInputPartitions(Offset start, Offset end) { + Preconditions.checkArgument(end instanceof StreamingOffset, "Invalid end offset: %s is not a StreamingOffset", end); + Preconditions.checkArgument( + start instanceof StreamingOffset, "Invalid start offset: %s is not a StreamingOffset", start); + + if (end.equals(StreamingOffset.START_OFFSET)) { + return new InputPartition[0]; + } + + StreamingOffset endOffset = (StreamingOffset) end; + StreamingOffset startOffset = (StreamingOffset) start; + + List fileScanTasks = planFiles(startOffset, endOffset); + + CloseableIterable splitTasks = TableScanUtil.splitFiles( + CloseableIterable.withNoopClose(fileScanTasks), + splitSize); + List combinedScanTasks = Lists.newArrayList( + TableScanUtil.planTasks(splitTasks, splitSize, splitLookback, splitOpenFileCost)); + InputPartition[] readTasks = new InputPartition[combinedScanTasks.size()]; + + Tasks.range(readTasks.length) + .stopOnFailure() + .executeWith(localityPreferred ? 
ThreadPools.getWorkerPool() : null) + .run(index -> readTasks[index] = new ReadTask( + combinedScanTasks.get(index), tableBroadcast, expectedSchema, + caseSensitive, localityPreferred)); + + return readTasks; + } + + @Override + public PartitionReaderFactory createReaderFactory() { + return new ReaderFactory(0); + } + + @Override + public Offset initialOffset() { + return initialOffset; + } + + @Override + public Offset deserializeOffset(String json) { + return StreamingOffset.fromJson(json); + } + + @Override + public void commit(Offset end) { + } + + @Override + public void stop() { + } + + private List planFiles(StreamingOffset startOffset, StreamingOffset endOffset) { + List fileScanTasks = Lists.newArrayList(); + StreamingOffset batchStartOffset = StreamingOffset.START_OFFSET.equals(startOffset) ? + determineStartingOffset(table, fromTimestamp) : startOffset; + + StreamingOffset currentOffset = null; + + do { + if (currentOffset == null) { + currentOffset = batchStartOffset; + } else { + Snapshot snapshotAfter = SnapshotUtil.snapshotAfter(table, currentOffset.snapshotId()); + currentOffset = new StreamingOffset(snapshotAfter.snapshotId(), 0L, false); + } + + if (!shouldProcess(table.snapshot(currentOffset.snapshotId()))) { + LOG.debug("Skipping snapshot: {} of table {}", currentOffset.snapshotId(), table.name()); + continue; + } + + MicroBatch latestMicroBatch = MicroBatches.from(table.snapshot(currentOffset.snapshotId()), table.io()) + .caseSensitive(caseSensitive) + .specsById(table.specs()) + .generate(currentOffset.position(), Long.MAX_VALUE, currentOffset.shouldScanAllFiles()); + + fileScanTasks.addAll(latestMicroBatch.tasks()); + } while (currentOffset.snapshotId() != endOffset.snapshotId()); + + return fileScanTasks; + } + + private boolean shouldProcess(Snapshot snapshot) { + String op = snapshot.operation(); + switch (op) { + case DataOperations.APPEND: + return true; + case DataOperations.REPLACE: + return false; + case DataOperations.DELETE: + Preconditions.checkState(skipDelete, + "Cannot process delete snapshot: %s, to ignore deletes, set %s=true", + snapshot.snapshotId(), SparkReadOptions.STREAMING_SKIP_DELETE_SNAPSHOTS); + return false; + case DataOperations.OVERWRITE: + Preconditions.checkState(skipOverwrite, + "Cannot process overwrite snapshot: %s, to ignore overwrites, set %s=true", + snapshot.snapshotId(), SparkReadOptions.STREAMING_SKIP_OVERWRITE_SNAPSHOTS); + return false; + default: + throw new IllegalStateException(String.format( + "Cannot process unknown snapshot operation: %s (snapshot id %s)", + op.toLowerCase(Locale.ROOT), snapshot.snapshotId())); + } + } + + private static StreamingOffset determineStartingOffset(Table table, Long fromTimestamp) { + if (table.currentSnapshot() == null) { + return StreamingOffset.START_OFFSET; + } + + if (fromTimestamp == null) { + // match existing behavior and start from the oldest snapshot + return new StreamingOffset(SnapshotUtil.oldestAncestor(table).snapshotId(), 0, false); + } + + if (table.currentSnapshot().timestampMillis() < fromTimestamp) { + return StreamingOffset.START_OFFSET; + } + + try { + Snapshot snapshot = SnapshotUtil.oldestAncestorAfter(table, fromTimestamp); + if (snapshot != null) { + return new StreamingOffset(snapshot.snapshotId(), 0, false); + } else { + return StreamingOffset.START_OFFSET; + } + } catch (IllegalStateException e) { + // could not determine the first snapshot after the timestamp. 
use the oldest ancestor instead + return new StreamingOffset(SnapshotUtil.oldestAncestor(table).snapshotId(), 0, false); + } + } + + private static class InitialOffsetStore { + private final Table table; + private final FileIO io; + private final String initialOffsetLocation; + private final Long fromTimestamp; + + InitialOffsetStore(Table table, String checkpointLocation, Long fromTimestamp) { + this.table = table; + this.io = table.io(); + this.initialOffsetLocation = SLASH.join(checkpointLocation, "offsets/0"); + this.fromTimestamp = fromTimestamp; + } + + public StreamingOffset initialOffset() { + InputFile inputFile = io.newInputFile(initialOffsetLocation); + if (inputFile.exists()) { + return readOffset(inputFile); + } + + table.refresh(); + StreamingOffset offset = determineStartingOffset(table, fromTimestamp); + + OutputFile outputFile = io.newOutputFile(initialOffsetLocation); + writeOffset(offset, outputFile); + + return offset; + } + + private void writeOffset(StreamingOffset offset, OutputFile file) { + try (OutputStream outputStream = file.create()) { + BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(outputStream, StandardCharsets.UTF_8)); + writer.write(offset.json()); + writer.flush(); + } catch (IOException ioException) { + throw new UncheckedIOException( + String.format("Failed writing offset to: %s", initialOffsetLocation), ioException); + } + } + + private StreamingOffset readOffset(InputFile file) { + try (InputStream in = file.newStream()) { + return StreamingOffset.fromJson(in); + } catch (IOException ioException) { + throw new UncheckedIOException( + String.format("Failed reading offset from: %s", initialOffsetLocation), ioException); + } + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitionedFanoutWriter.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitionedFanoutWriter.java new file mode 100644 index 000000000000..d38ae2f40316 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitionedFanoutWriter.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
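SparkMicroBatchStream above stores its initial offset under the checkpoint location and then advances snapshot by snapshot, skipping replace snapshots and, optionally, deletes and overwrites. A hedged sketch of a streaming read that exercises it; the option strings are assumed to match the SparkReadOptions constants referenced in shouldProcess(), and the table and checkpoint paths are placeholders:

import java.util.concurrent.TimeoutException;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.streaming.StreamingQuery;

class MicroBatchReadSketch {
  static StreamingQuery start(SparkSession spark, long startMillis) throws TimeoutException {
    Dataset<Row> stream = spark.readStream()
        .format("iceberg")
        .option("stream-from-timestamp", String.valueOf(startMillis))   // assumed option name
        .option("streaming-skip-delete-snapshots", "true")              // assumed option name
        .load("db.events");                                             // placeholder table identifier
    return stream.writeStream()
        .format("console")
        .option("checkpointLocation", "/tmp/checkpoints/events")        // placeholder checkpoint path
        .start();
  }
}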
+ */ + +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionKey; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileAppenderFactory; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.io.PartitionedFanoutWriter; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; + +public class SparkPartitionedFanoutWriter extends PartitionedFanoutWriter { + private final PartitionKey partitionKey; + private final InternalRowWrapper internalRowWrapper; + + public SparkPartitionedFanoutWriter(PartitionSpec spec, FileFormat format, + FileAppenderFactory appenderFactory, + OutputFileFactory fileFactory, FileIO io, long targetFileSize, + Schema schema, StructType sparkSchema) { + super(spec, format, appenderFactory, fileFactory, io, targetFileSize); + this.partitionKey = new PartitionKey(spec, schema); + this.internalRowWrapper = new InternalRowWrapper(sparkSchema); + } + + @Override + protected PartitionKey partition(InternalRow row) { + partitionKey.partition(internalRowWrapper.wrap(row)); + return partitionKey; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitionedWriter.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitionedWriter.java new file mode 100644 index 000000000000..f81a09926d85 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitionedWriter.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionKey; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileAppenderFactory; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.io.PartitionedWriter; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; + +public class SparkPartitionedWriter extends PartitionedWriter { + private final PartitionKey partitionKey; + private final InternalRowWrapper internalRowWrapper; + + public SparkPartitionedWriter(PartitionSpec spec, FileFormat format, + FileAppenderFactory appenderFactory, + OutputFileFactory fileFactory, FileIO io, long targetFileSize, + Schema schema, StructType sparkSchema) { + super(spec, format, appenderFactory, fileFactory, io, targetFileSize); + this.partitionKey = new PartitionKey(spec, schema); + this.internalRowWrapper = new InternalRowWrapper(sparkSchema); + } + + @Override + protected PartitionKey partition(InternalRow row) { + partitionKey.partition(internalRowWrapper.wrap(row)); + return partitionKey; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaOperation.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaOperation.java new file mode 100644 index 000000000000..7bdb23cd49ae --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaOperation.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.iceberg.write.DeltaWriteBuilder; +import org.apache.spark.sql.connector.iceberg.write.ExtendedLogicalWriteInfo; +import org.apache.spark.sql.connector.iceberg.write.SupportsDelta; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.RowLevelOperation; +import org.apache.spark.sql.connector.write.RowLevelOperationInfo; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +class SparkPositionDeltaOperation implements RowLevelOperation, SupportsDelta { + + private final SparkSession spark; + private final Table table; + private final Command command; + private final IsolationLevel isolationLevel; + + // lazy vars + private ScanBuilder lazyScanBuilder; + private Scan configuredScan; + private DeltaWriteBuilder lazyWriteBuilder; + + SparkPositionDeltaOperation(SparkSession spark, Table table, RowLevelOperationInfo info, + IsolationLevel isolationLevel) { + this.spark = spark; + this.table = table; + this.command = info.command(); + this.isolationLevel = isolationLevel; + } + + @Override + public Command command() { + return command; + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + if (lazyScanBuilder == null) { + this.lazyScanBuilder = new SparkScanBuilder(spark, table, options) { + @Override + public Scan build() { + Scan scan = super.buildMergeOnReadScan(); + SparkPositionDeltaOperation.this.configuredScan = scan; + return scan; + } + }; + } + + return lazyScanBuilder; + } + + @Override + public DeltaWriteBuilder newWriteBuilder(LogicalWriteInfo info) { + if (lazyWriteBuilder == null) { + Preconditions.checkArgument(info instanceof ExtendedLogicalWriteInfo, "info must be ExtendedLogicalWriteInfo"); + // don't validate the scan is not null as if the condition evaluates to false, + // the optimizer replaces the original scan relation with a local relation + lazyWriteBuilder = new SparkPositionDeltaWriteBuilder( + spark, table, command, configuredScan, + isolationLevel, (ExtendedLogicalWriteInfo) info); + } + + return lazyWriteBuilder; + } + + @Override + public NamedReference[] requiredMetadataAttributes() { + NamedReference specId = Expressions.column(MetadataColumns.SPEC_ID.name()); + NamedReference partition = Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME); + return new NamedReference[]{specId, partition}; + } + + @Override + public NamedReference[] rowId() { + NamedReference file = Expressions.column(MetadataColumns.FILE_PATH.name()); + NamedReference pos = Expressions.column(MetadataColumns.ROW_POSITION.name()); + return new NamedReference[]{file, pos}; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWrite.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWrite.java new file mode 100644 index 000000000000..65200a6c9139 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWrite.java 
@@ -0,0 +1,668 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Map; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.PartitionKey; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.RowDelta; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SerializableTable; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.SnapshotUpdate; +import org.apache.iceberg.Table; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.BasePositionDeltaWriter; +import org.apache.iceberg.io.ClusteredDataWriter; +import org.apache.iceberg.io.ClusteredPositionDeleteWriter; +import org.apache.iceberg.io.DataWriteResult; +import org.apache.iceberg.io.DeleteWriteResult; +import org.apache.iceberg.io.FanoutDataWriter; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.io.PartitioningWriter; +import org.apache.iceberg.io.PositionDeltaWriter; +import org.apache.iceberg.io.WriteResult; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.CommitMetadata; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkWriteConf; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.CharSequenceSet; +import org.apache.iceberg.util.StructProjection; +import org.apache.iceberg.util.Tasks; +import org.apache.iceberg.util.ThreadPools; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.connector.iceberg.write.DeltaBatchWrite; +import org.apache.spark.sql.connector.iceberg.write.DeltaWrite; +import org.apache.spark.sql.connector.iceberg.write.DeltaWriter; +import org.apache.spark.sql.connector.iceberg.write.DeltaWriterFactory; +import org.apache.spark.sql.connector.iceberg.write.ExtendedLogicalWriteInfo; +import org.apache.spark.sql.connector.write.PhysicalWriteInfo; +import 
org.apache.spark.sql.connector.write.RequiresDistributionAndOrdering; +import org.apache.spark.sql.connector.write.RowLevelOperation.Command; +import org.apache.spark.sql.connector.write.WriterCommitMessage; +import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.iceberg.IsolationLevel.SERIALIZABLE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE; + +class SparkPositionDeltaWrite implements DeltaWrite, RequiresDistributionAndOrdering { + + private static final Logger LOG = LoggerFactory.getLogger(SparkPositionDeltaWrite.class); + + private final JavaSparkContext sparkContext; + private final Table table; + private final Command command; + private final SparkBatchQueryScan scan; + private final IsolationLevel isolationLevel; + private final Context context; + private final String applicationId; + private final boolean wapEnabled; + private final String wapId; + private final Map extraSnapshotMetadata; + private final Distribution requiredDistribution; + private final SortOrder[] requiredOrdering; + + private boolean cleanupOnAbort = true; + + SparkPositionDeltaWrite(SparkSession spark, Table table, Command command, SparkBatchQueryScan scan, + IsolationLevel isolationLevel, SparkWriteConf writeConf, + ExtendedLogicalWriteInfo info, Schema dataSchema, + Distribution requiredDistribution, SortOrder[] requiredOrdering) { + this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + this.table = table; + this.command = command; + this.scan = scan; + this.isolationLevel = isolationLevel; + this.context = new Context(dataSchema, writeConf, info); + this.applicationId = spark.sparkContext().applicationId(); + this.wapEnabled = writeConf.wapEnabled(); + this.wapId = writeConf.wapId(); + this.extraSnapshotMetadata = writeConf.extraSnapshotMetadata(); + this.requiredDistribution = requiredDistribution; + this.requiredOrdering = requiredOrdering; + } + + @Override + public Distribution requiredDistribution() { + return requiredDistribution; + } + + @Override + public SortOrder[] requiredOrdering() { + return requiredOrdering; + } + + @Override + public DeltaBatchWrite toBatch() { + return new PositionDeltaBatchWrite(); + } + + private static > void cleanFiles(FileIO io, Iterable files) { + Tasks.foreach(files) + .executeWith(ThreadPools.getWorkerPool()) + .throwFailureWhenFinished() + .noRetry() + .run(file -> io.deleteFile(file.path().toString())); + } + + private class PositionDeltaBatchWrite implements DeltaBatchWrite { + + @Override + public DeltaWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) { + // broadcast the table metadata as the writer factory will be sent to executors + Broadcast
tableBroadcast = sparkContext.broadcast(SerializableTable.copyOf(table)); + return new PositionDeltaWriteFactory(tableBroadcast, command, context); + } + + @Override + public void commit(WriterCommitMessage[] messages) { + RowDelta rowDelta = table.newRowDelta(); + + CharSequenceSet referencedDataFiles = CharSequenceSet.empty(); + + int addedDataFilesCount = 0; + int addedDeleteFilesCount = 0; + + for (WriterCommitMessage message : messages) { + DeltaTaskCommit taskCommit = (DeltaTaskCommit) message; + + for (DataFile dataFile : taskCommit.dataFiles()) { + rowDelta.addRows(dataFile); + addedDataFilesCount += 1; + } + + for (DeleteFile deleteFile : taskCommit.deleteFiles()) { + rowDelta.addDeletes(deleteFile); + addedDeleteFilesCount += 1; + } + + referencedDataFiles.addAll(Arrays.asList(taskCommit.referencedDataFiles())); + } + + // the scan may be null if the optimizer replaces it with an empty relation (e.g. the cond is false) + // no validation is needed in this case as the command does not depend on the scanned table state + if (scan != null) { + Expression conflictDetectionFilter = conflictDetectionFilter(scan); + rowDelta.conflictDetectionFilter(conflictDetectionFilter); + + rowDelta.validateDataFilesExist(referencedDataFiles); + + if (scan.snapshotId() != null) { + // set the read snapshot ID to check only snapshots that happened after the table was read + // otherwise, the validation will go through all snapshots present in the table + rowDelta.validateFromSnapshot(scan.snapshotId()); + } + + if (command == UPDATE || command == MERGE) { + rowDelta.validateDeletedFiles(); + rowDelta.validateNoConflictingDeleteFiles(); + } + + if (isolationLevel == SERIALIZABLE) { + rowDelta.validateNoConflictingDataFiles(); + } + + String commitMsg = String.format( + "position delta with %d data files and %d delete files " + + "(scanSnapshotId: %d, conflictDetectionFilter: %s, isolationLevel: %s)", + addedDataFilesCount, addedDeleteFilesCount, scan.snapshotId(), conflictDetectionFilter, isolationLevel); + commitOperation(rowDelta, commitMsg); + + } else { + String commitMsg = String.format( + "position delta with %d data files and %d delete files (no validation required)", + addedDataFilesCount, addedDeleteFilesCount); + commitOperation(rowDelta, commitMsg); + } + } + + private Expression conflictDetectionFilter(SparkBatchQueryScan queryScan) { + Expression filter = Expressions.alwaysTrue(); + + for (Expression expr : queryScan.filterExpressions()) { + filter = Expressions.and(filter, expr); + } + + return filter; + } + + @Override + public void abort(WriterCommitMessage[] messages) { + if (!cleanupOnAbort) { + return; + } + + for (WriterCommitMessage message : messages) { + if (message != null) { + DeltaTaskCommit taskCommit = (DeltaTaskCommit) message; + cleanFiles(table.io(), Arrays.asList(taskCommit.dataFiles())); + cleanFiles(table.io(), Arrays.asList(taskCommit.deleteFiles())); + } + } + } + + private void commitOperation(SnapshotUpdate operation, String description) { + LOG.info("Committing {} to table {}", description, table); + if (applicationId != null) { + operation.set("spark.app.id", applicationId); + } + + extraSnapshotMetadata.forEach(operation::set); + + if (!CommitMetadata.commitProperties().isEmpty()) { + CommitMetadata.commitProperties().forEach(operation::set); + } + + if (wapEnabled && wapId != null) { + // write-audit-publish is enabled for this table and job + // stage the changes without changing the current snapshot + operation.set(SnapshotSummary.STAGED_WAP_ID_PROP, wapId); 
+ operation.stageOnly(); + } + + try { + long start = System.currentTimeMillis(); + operation.commit(); // abort is automatically called if this fails + long duration = System.currentTimeMillis() - start; + LOG.info("Committed in {} ms", duration); + } catch (CommitStateUnknownException commitStateUnknownException) { + cleanupOnAbort = false; + throw commitStateUnknownException; + } + } + } + + public static class DeltaTaskCommit implements WriterCommitMessage { + private final DataFile[] dataFiles; + private final DeleteFile[] deleteFiles; + private final CharSequence[] referencedDataFiles; + + DeltaTaskCommit(WriteResult result) { + this.dataFiles = result.dataFiles(); + this.deleteFiles = result.deleteFiles(); + this.referencedDataFiles = result.referencedDataFiles(); + } + + DeltaTaskCommit(DeleteWriteResult result) { + this.dataFiles = new DataFile[0]; + this.deleteFiles = result.deleteFiles().toArray(new DeleteFile[0]); + this.referencedDataFiles = result.referencedDataFiles().toArray(new CharSequence[0]); + } + + DataFile[] dataFiles() { + return dataFiles; + } + + DeleteFile[] deleteFiles() { + return deleteFiles; + } + + CharSequence[] referencedDataFiles() { + return referencedDataFiles; + } + } + + private static class PositionDeltaWriteFactory implements DeltaWriterFactory { + private final Broadcast
<Table> tableBroadcast; + private final Command command; + private final Context context; + + PositionDeltaWriteFactory(Broadcast<Table>
tableBroadcast, Command command, Context context) { + this.tableBroadcast = tableBroadcast; + this.command = command; + this.context = context; + } + + @Override + public DeltaWriter createWriter(int partitionId, long taskId) { + Table table = tableBroadcast.value(); + + OutputFileFactory dataFileFactory = OutputFileFactory.builderFor(table, partitionId, taskId) + .format(context.dataFileFormat()) + .build(); + OutputFileFactory deleteFileFactory = OutputFileFactory.builderFor(table, partitionId, taskId) + .format(context.deleteFileFormat()) + .build(); + + SparkFileWriterFactory writerFactory = SparkFileWriterFactory.builderFor(table) + .dataFileFormat(context.dataFileFormat()) + .dataSchema(context.dataSchema()) + .dataSparkType(context.dataSparkType()) + .deleteFileFormat(context.deleteFileFormat()) + .positionDeleteSparkType(context.deleteSparkType()) + .build(); + + if (command == DELETE) { + return new DeleteOnlyDeltaWriter(table, writerFactory, deleteFileFactory, context); + + } else if (table.spec().isUnpartitioned()) { + return new UnpartitionedDeltaWriter(table, writerFactory, dataFileFactory, deleteFileFactory, context); + + } else { + return new PartitionedDeltaWriter(table, writerFactory, dataFileFactory, deleteFileFactory, context); + } + } + } + + private abstract static class BaseDeltaWriter implements DeltaWriter { + + protected InternalRowWrapper initPartitionRowWrapper(Types.StructType partitionType) { + StructType sparkPartitionType = (StructType) SparkSchemaUtil.convert(partitionType); + return new InternalRowWrapper(sparkPartitionType); + } + + protected Map buildPartitionProjections(Types.StructType partitionType, + Map specs) { + Map partitionProjections = Maps.newHashMap(); + specs.forEach((specID, spec) -> + partitionProjections.put(specID, StructProjection.create(partitionType, spec.partitionType())) + ); + return partitionProjections; + } + } + + private static class DeleteOnlyDeltaWriter extends BaseDeltaWriter { + private final ClusteredPositionDeleteWriter delegate; + private final PositionDelete positionDelete; + private final FileIO io; + private final Map specs; + private final InternalRowWrapper partitionRowWrapper; + private final Map partitionProjections; + private final int specIdOrdinal; + private final int partitionOrdinal; + private final int fileOrdinal; + private final int positionOrdinal; + + private boolean closed = false; + + DeleteOnlyDeltaWriter(Table table, SparkFileWriterFactory writerFactory, + OutputFileFactory deleteFileFactory, Context context) { + + this.delegate = new ClusteredPositionDeleteWriter<>( + writerFactory, deleteFileFactory, table.io(), context.targetDeleteFileSize()); + this.positionDelete = PositionDelete.create(); + this.io = table.io(); + this.specs = table.specs(); + + Types.StructType partitionType = Partitioning.partitionType(table); + this.partitionRowWrapper = initPartitionRowWrapper(partitionType); + this.partitionProjections = buildPartitionProjections(partitionType, specs); + + this.specIdOrdinal = context.metadataSparkType().fieldIndex(MetadataColumns.SPEC_ID.name()); + this.partitionOrdinal = context.metadataSparkType().fieldIndex(MetadataColumns.PARTITION_COLUMN_NAME); + this.fileOrdinal = context.deleteSparkType().fieldIndex(MetadataColumns.FILE_PATH.name()); + this.positionOrdinal = context.deleteSparkType().fieldIndex(MetadataColumns.ROW_POSITION.name()); + } + + @Override + public void delete(InternalRow metadata, InternalRow id) throws IOException { + int specId = metadata.getInt(specIdOrdinal); + 
PartitionSpec spec = specs.get(specId); + + InternalRow partition = metadata.getStruct(partitionOrdinal, partitionRowWrapper.size()); + StructProjection partitionProjection = partitionProjections.get(specId); + partitionProjection.wrap(partitionRowWrapper.wrap(partition)); + + String file = id.getString(fileOrdinal); + long position = id.getLong(positionOrdinal); + positionDelete.set(file, position, null); + delegate.write(positionDelete, spec, partitionProjection); + } + + @Override + public void update(InternalRow metadata, InternalRow id, InternalRow row) { + throw new UnsupportedOperationException(this.getClass().getName() + " does not implement update"); + } + + @Override + public void insert(InternalRow row) throws IOException { + throw new UnsupportedOperationException(this.getClass().getName() + " does not implement insert"); + } + + @Override + public WriterCommitMessage commit() throws IOException { + close(); + + DeleteWriteResult result = delegate.result(); + return new DeltaTaskCommit(result); + } + + @Override + public void abort() throws IOException { + close(); + + DeleteWriteResult result = delegate.result(); + cleanFiles(io, result.deleteFiles()); + } + + @Override + public void close() throws IOException { + if (!closed) { + delegate.close(); + this.closed = true; + } + } + } + + @SuppressWarnings("checkstyle:VisibilityModifier") + private abstract static class DeleteAndDataDeltaWriter extends BaseDeltaWriter { + protected final PositionDeltaWriter delegate; + private final FileIO io; + private final Map specs; + private final InternalRowWrapper deletePartitionRowWrapper; + private final Map deletePartitionProjections; + private final int specIdOrdinal; + private final int partitionOrdinal; + private final int fileOrdinal; + private final int positionOrdinal; + + private boolean closed = false; + + DeleteAndDataDeltaWriter(Table table, SparkFileWriterFactory writerFactory, + OutputFileFactory dataFileFactory, OutputFileFactory deleteFileFactory, + Context context) { + this.delegate = new BasePositionDeltaWriter<>( + newInsertWriter(table, writerFactory, dataFileFactory, context), + newUpdateWriter(table, writerFactory, dataFileFactory, context), + newDeleteWriter(table, writerFactory, deleteFileFactory, context)); + this.io = table.io(); + this.specs = table.specs(); + + Types.StructType partitionType = Partitioning.partitionType(table); + this.deletePartitionRowWrapper = initPartitionRowWrapper(partitionType); + this.deletePartitionProjections = buildPartitionProjections(partitionType, specs); + + this.specIdOrdinal = context.metadataSparkType().fieldIndex(MetadataColumns.SPEC_ID.name()); + this.partitionOrdinal = context.metadataSparkType().fieldIndex(MetadataColumns.PARTITION_COLUMN_NAME); + this.fileOrdinal = context.deleteSparkType().fieldIndex(MetadataColumns.FILE_PATH.name()); + this.positionOrdinal = context.deleteSparkType().fieldIndex(MetadataColumns.ROW_POSITION.name()); + } + + @Override + public void delete(InternalRow meta, InternalRow id) throws IOException { + int specId = meta.getInt(specIdOrdinal); + PartitionSpec spec = specs.get(specId); + + InternalRow partition = meta.getStruct(partitionOrdinal, deletePartitionRowWrapper.size()); + StructProjection partitionProjection = deletePartitionProjections.get(specId); + partitionProjection.wrap(deletePartitionRowWrapper.wrap(partition)); + + String file = id.getString(fileOrdinal); + long position = id.getLong(positionOrdinal); + delegate.delete(file, position, spec, partitionProjection); + } + + @Override + 
public WriterCommitMessage commit() throws IOException { + close(); + + WriteResult result = delegate.result(); + return new DeltaTaskCommit(result); + } + + @Override + public void abort() throws IOException { + close(); + + WriteResult result = delegate.result(); + cleanFiles(io, Arrays.asList(result.dataFiles())); + cleanFiles(io, Arrays.asList(result.deleteFiles())); + } + + @Override + public void close() throws IOException { + if (!closed) { + delegate.close(); + this.closed = true; + } + } + + private PartitioningWriter newInsertWriter(Table table, + SparkFileWriterFactory writerFactory, + OutputFileFactory fileFactory, + Context context) { + long targetFileSize = context.targetDataFileSize(); + + if (table.spec().isPartitioned() && context.fanoutWriterEnabled()) { + return new FanoutDataWriter<>(writerFactory, fileFactory, table.io(), targetFileSize); + } else { + return new ClusteredDataWriter<>(writerFactory, fileFactory, table.io(), targetFileSize); + } + } + + private PartitioningWriter newUpdateWriter(Table table, + SparkFileWriterFactory writerFactory, + OutputFileFactory fileFactory, + Context context) { + long targetFileSize = context.targetDataFileSize(); + + if (table.spec().isPartitioned()) { + // use a fanout writer for partitioned tables to write updates as they may be out of order + return new FanoutDataWriter<>(writerFactory, fileFactory, table.io(), targetFileSize); + } else { + return new ClusteredDataWriter<>(writerFactory, fileFactory, table.io(), targetFileSize); + } + } + + private ClusteredPositionDeleteWriter newDeleteWriter(Table table, + SparkFileWriterFactory writerFactory, + OutputFileFactory fileFactory, + Context context) { + long targetFileSize = context.targetDeleteFileSize(); + return new ClusteredPositionDeleteWriter<>(writerFactory, fileFactory, table.io(), targetFileSize); + } + } + + private static class UnpartitionedDeltaWriter extends DeleteAndDataDeltaWriter { + private final PartitionSpec dataSpec; + + UnpartitionedDeltaWriter(Table table, SparkFileWriterFactory writerFactory, + OutputFileFactory dataFileFactory, OutputFileFactory deleteFileFactory, + Context context) { + super(table, writerFactory, dataFileFactory, deleteFileFactory, context); + this.dataSpec = table.spec(); + } + + @Override + public void update(InternalRow meta, InternalRow id, InternalRow row) throws IOException { + delete(meta, id); + delegate.update(row, dataSpec, null); + } + + @Override + public void insert(InternalRow row) throws IOException { + delegate.insert(row, dataSpec, null); + } + } + + private static class PartitionedDeltaWriter extends DeleteAndDataDeltaWriter { + private final PartitionSpec dataSpec; + private final PartitionKey dataPartitionKey; + private final InternalRowWrapper internalRowDataWrapper; + + PartitionedDeltaWriter(Table table, SparkFileWriterFactory writerFactory, + OutputFileFactory dataFileFactory, OutputFileFactory deleteFileFactory, + Context context) { + super(table, writerFactory, dataFileFactory, deleteFileFactory, context); + + this.dataSpec = table.spec(); + this.dataPartitionKey = new PartitionKey(dataSpec, context.dataSchema()); + this.internalRowDataWrapper = new InternalRowWrapper(context.dataSparkType()); + } + + @Override + public void update(InternalRow meta, InternalRow id, InternalRow row) throws IOException { + delete(meta, id); + dataPartitionKey.partition(internalRowDataWrapper.wrap(row)); + delegate.update(row, dataSpec, dataPartitionKey); + } + + @Override + public void insert(InternalRow row) throws IOException { + 
dataPartitionKey.partition(internalRowDataWrapper.wrap(row)); + delegate.insert(row, dataSpec, dataPartitionKey); + } + } + + // a serializable helper class for common parameters required to configure writers + private static class Context implements Serializable { + private final Schema dataSchema; + private final StructType dataSparkType; + private final FileFormat dataFileFormat; + private final long targetDataFileSize; + private final StructType deleteSparkType; + private final StructType metadataSparkType; + private final FileFormat deleteFileFormat; + private final long targetDeleteFileSize; + private final boolean fanoutWriterEnabled; + + Context(Schema dataSchema, SparkWriteConf writeConf, ExtendedLogicalWriteInfo info) { + this.dataSchema = dataSchema; + this.dataSparkType = info.schema(); + this.dataFileFormat = writeConf.dataFileFormat(); + this.targetDataFileSize = writeConf.targetDataFileSize(); + this.deleteSparkType = info.rowIdSchema(); + this.deleteFileFormat = writeConf.deleteFileFormat(); + this.targetDeleteFileSize = writeConf.targetDeleteFileSize(); + this.metadataSparkType = info.metadataSchema(); + this.fanoutWriterEnabled = writeConf.fanoutWriterEnabled(); + } + + Schema dataSchema() { + return dataSchema; + } + + StructType dataSparkType() { + return dataSparkType; + } + + FileFormat dataFileFormat() { + return dataFileFormat; + } + + long targetDataFileSize() { + return targetDataFileSize; + } + + StructType deleteSparkType() { + return deleteSparkType; + } + + StructType metadataSparkType() { + return metadataSparkType; + } + + FileFormat deleteFileFormat() { + return deleteFileFormat; + } + + long targetDeleteFileSize() { + return targetDeleteFileSize; + } + + boolean fanoutWriterEnabled() { + return fanoutWriterEnabled; + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWriteBuilder.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWriteBuilder.java new file mode 100644 index 000000000000..167e0853f055 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWriteBuilder.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.SparkDistributionAndOrderingUtil; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.spark.SparkWriteConf; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types.NestedField; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.connector.iceberg.write.DeltaWrite; +import org.apache.spark.sql.connector.iceberg.write.DeltaWriteBuilder; +import org.apache.spark.sql.connector.iceberg.write.ExtendedLogicalWriteInfo; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.write.RowLevelOperation.Command; +import org.apache.spark.sql.types.StructType; + +class SparkPositionDeltaWriteBuilder implements DeltaWriteBuilder { + + private static final Schema EXPECTED_ROW_ID_SCHEMA = new Schema( + MetadataColumns.FILE_PATH, + MetadataColumns.ROW_POSITION + ); + + private final SparkSession spark; + private final Table table; + private final Command command; + private final SparkBatchQueryScan scan; + private final IsolationLevel isolationLevel; + private final SparkWriteConf writeConf; + private final ExtendedLogicalWriteInfo info; + private final boolean handleTimestampWithoutZone; + private final boolean checkNullability; + private final boolean checkOrdering; + + SparkPositionDeltaWriteBuilder(SparkSession spark, Table table, Command command, Scan scan, + IsolationLevel isolationLevel, ExtendedLogicalWriteInfo info) { + this.spark = spark; + this.table = table; + this.command = command; + this.scan = (SparkBatchQueryScan) scan; + this.isolationLevel = isolationLevel; + this.writeConf = new SparkWriteConf(spark, table, info.options()); + this.info = info; + this.handleTimestampWithoutZone = writeConf.handleTimestampWithoutZone(); + this.checkNullability = writeConf.checkNullability(); + this.checkOrdering = writeConf.checkOrdering(); + } + + @Override + public DeltaWrite build() { + Preconditions.checkArgument(handleTimestampWithoutZone || !SparkUtil.hasTimestampWithoutZone(table.schema()), + SparkUtil.TIMESTAMP_WITHOUT_TIMEZONE_ERROR); + + Schema dataSchema = dataSchema(); + if (dataSchema != null) { + TypeUtil.validateWriteSchema(table.schema(), dataSchema, checkNullability, checkOrdering); + } + + Schema rowIdSchema = SparkSchemaUtil.convert(EXPECTED_ROW_ID_SCHEMA, info.rowIdSchema()); + TypeUtil.validateSchema("row ID", EXPECTED_ROW_ID_SCHEMA, rowIdSchema, checkNullability, checkOrdering); + + NestedField partition = MetadataColumns.metadataColumn(table, MetadataColumns.PARTITION_COLUMN_NAME); + Schema expectedMetadataSchema = new Schema(MetadataColumns.SPEC_ID, partition); + Schema metadataSchema = SparkSchemaUtil.convert(expectedMetadataSchema, info.metadataSchema()); + TypeUtil.validateSchema("metadata", expectedMetadataSchema, metadataSchema, checkNullability, checkOrdering); + + SparkUtil.validatePartitionTransforms(table.spec()); + + Distribution distribution = SparkDistributionAndOrderingUtil.buildPositionDeltaDistribution( + table, command, distributionMode()); + 
SortOrder[] ordering = SparkDistributionAndOrderingUtil.buildPositionDeltaOrdering( + table, command); + + return new SparkPositionDeltaWrite( + spark, table, command, scan, isolationLevel, writeConf, + info, dataSchema, distribution, ordering); + } + + private Schema dataSchema() { + StructType dataSparkType = info.schema(); + return dataSparkType != null ? SparkSchemaUtil.convert(table.schema(), dataSparkType) : null; + } + + private DistributionMode distributionMode() { + switch (command) { + case DELETE: + return writeConf.deleteDistributionMode(); + case UPDATE: + return writeConf.updateDistributionMode(); + case MERGE: + return writeConf.positionDeltaMergeDistributionMode(); + default: + throw new IllegalArgumentException("Unexpected command: " + command); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkRowLevelOperationBuilder.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkRowLevelOperationBuilder.java new file mode 100644 index 000000000000..3cdb585f9745 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkRowLevelOperationBuilder.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Table; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.write.RowLevelOperation; +import org.apache.spark.sql.connector.write.RowLevelOperation.Command; +import org.apache.spark.sql.connector.write.RowLevelOperationBuilder; +import org.apache.spark.sql.connector.write.RowLevelOperationInfo; + +import static org.apache.iceberg.TableProperties.DELETE_ISOLATION_LEVEL; +import static org.apache.iceberg.TableProperties.DELETE_ISOLATION_LEVEL_DEFAULT; +import static org.apache.iceberg.TableProperties.DELETE_MODE; +import static org.apache.iceberg.TableProperties.DELETE_MODE_DEFAULT; +import static org.apache.iceberg.TableProperties.MERGE_ISOLATION_LEVEL; +import static org.apache.iceberg.TableProperties.MERGE_ISOLATION_LEVEL_DEFAULT; +import static org.apache.iceberg.TableProperties.MERGE_MODE; +import static org.apache.iceberg.TableProperties.MERGE_MODE_DEFAULT; +import static org.apache.iceberg.TableProperties.UPDATE_ISOLATION_LEVEL; +import static org.apache.iceberg.TableProperties.UPDATE_ISOLATION_LEVEL_DEFAULT; +import static org.apache.iceberg.TableProperties.UPDATE_MODE; +import static org.apache.iceberg.TableProperties.UPDATE_MODE_DEFAULT; + +class SparkRowLevelOperationBuilder implements RowLevelOperationBuilder { + + private final SparkSession spark; + private final Table table; + private final RowLevelOperationInfo info; + private final RowLevelOperationMode mode; + private final IsolationLevel isolationLevel; + + SparkRowLevelOperationBuilder(SparkSession spark, Table table, RowLevelOperationInfo info) { + this.spark = spark; + this.table = table; + this.info = info; + this.mode = mode(table.properties(), info.command()); + this.isolationLevel = isolationLevel(table.properties(), info.command()); + } + + @Override + public RowLevelOperation build() { + switch (mode) { + case COPY_ON_WRITE: + return new SparkCopyOnWriteOperation(spark, table, info, isolationLevel); + case MERGE_ON_READ: + return new SparkPositionDeltaOperation(spark, table, info, isolationLevel); + default: + throw new IllegalArgumentException("Unsupported operation mode: " + mode); + } + } + + private RowLevelOperationMode mode(Map properties, Command command) { + String modeName; + + switch (command) { + case DELETE: + modeName = properties.getOrDefault(DELETE_MODE, DELETE_MODE_DEFAULT); + break; + case UPDATE: + modeName = properties.getOrDefault(UPDATE_MODE, UPDATE_MODE_DEFAULT); + break; + case MERGE: + modeName = properties.getOrDefault(MERGE_MODE, MERGE_MODE_DEFAULT); + break; + default: + throw new IllegalArgumentException("Unsupported command: " + command); + } + + return RowLevelOperationMode.fromName(modeName); + } + + private IsolationLevel isolationLevel(Map properties, Command command) { + String levelName; + + switch (command) { + case DELETE: + levelName = properties.getOrDefault(DELETE_ISOLATION_LEVEL, DELETE_ISOLATION_LEVEL_DEFAULT); + break; + case UPDATE: + levelName = properties.getOrDefault(UPDATE_ISOLATION_LEVEL, UPDATE_ISOLATION_LEVEL_DEFAULT); + break; + case MERGE: + levelName = properties.getOrDefault(MERGE_ISOLATION_LEVEL, MERGE_ISOLATION_LEVEL_DEFAULT); + break; + default: + throw new IllegalArgumentException("Unsupported command: " + command); + } + + return IsolationLevel.fromName(levelName); + } +} diff --git 
a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java new file mode 100644 index 000000000000..7d93ad66e1e8 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.io.Serializable; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SchemaParser; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.hadoop.HadoopInputFile; +import org.apache.iceberg.hadoop.Util; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.connector.read.SupportsReportStatistics; +import org.apache.spark.sql.connector.read.streaming.MicroBatchStream; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +abstract class SparkScan implements Scan, SupportsReportStatistics { + private static final Logger LOG = LoggerFactory.getLogger(SparkScan.class); + + private final JavaSparkContext sparkContext; + private final Table table; + private final SparkReadConf readConf; + private final boolean caseSensitive; + private final Schema expectedSchema; + private final List filterExpressions; + private final boolean readTimestampWithoutZone; + + // lazy variables + private StructType readSchema = null; + + SparkScan(SparkSession spark, Table table, SparkReadConf readConf, + Schema expectedSchema, List filters) { + + 
SparkSchemaUtil.validateMetadataColumnReferences(table.schema(), expectedSchema); + + this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + this.table = table; + this.readConf = readConf; + this.caseSensitive = readConf.caseSensitive(); + this.expectedSchema = expectedSchema; + this.filterExpressions = filters != null ? filters : Collections.emptyList(); + this.readTimestampWithoutZone = readConf.handleTimestampWithoutZone(); + } + + protected Table table() { + return table; + } + + protected boolean caseSensitive() { + return caseSensitive; + } + + protected Schema expectedSchema() { + return expectedSchema; + } + + protected List filterExpressions() { + return filterExpressions; + } + + protected abstract List tasks(); + + @Override + public Batch toBatch() { + return new SparkBatch(sparkContext, table, readConf, tasks(), expectedSchema); + } + + @Override + public MicroBatchStream toMicroBatchStream(String checkpointLocation) { + return new SparkMicroBatchStream(sparkContext, table, readConf, expectedSchema, checkpointLocation); + } + + @Override + public StructType readSchema() { + if (readSchema == null) { + Preconditions.checkArgument(readTimestampWithoutZone || !SparkUtil.hasTimestampWithoutZone(expectedSchema), + SparkUtil.TIMESTAMP_WITHOUT_TIMEZONE_ERROR); + this.readSchema = SparkSchemaUtil.convert(expectedSchema); + } + return readSchema; + } + + @Override + public Statistics estimateStatistics() { + return estimateStatistics(table.currentSnapshot()); + } + + protected Statistics estimateStatistics(Snapshot snapshot) { + // its a fresh table, no data + if (snapshot == null) { + return new Stats(0L, 0L); + } + + // estimate stats using snapshot summary only for partitioned tables (metadata tables are unpartitioned) + if (!table.spec().isUnpartitioned() && filterExpressions.isEmpty()) { + LOG.debug("using table metadata to estimate table statistics"); + long totalRecords = PropertyUtil.propertyAsLong(snapshot.summary(), + SnapshotSummary.TOTAL_RECORDS_PROP, Long.MAX_VALUE); + return new Stats( + SparkSchemaUtil.estimateSize(readSchema(), totalRecords), + totalRecords); + } + + long numRows = 0L; + + for (CombinedScanTask task : tasks()) { + for (FileScanTask file : task.files()) { + // TODO: if possible, take deletes also into consideration. 
+ double fractionOfFileScanned = ((double) file.length()) / file.file().fileSizeInBytes(); + numRows += (fractionOfFileScanned * file.file().recordCount()); + } + } + + long sizeInBytes = SparkSchemaUtil.estimateSize(readSchema(), numRows); + return new Stats(sizeInBytes, numRows); + } + + @Override + public String description() { + String filters = filterExpressions.stream().map(Spark3Util::describe).collect(Collectors.joining(", ")); + return String.format("%s [filters=%s]", table, filters); + } + + static class ReaderFactory implements PartitionReaderFactory { + private final int batchSize; + + ReaderFactory(int batchSize) { + this.batchSize = batchSize; + } + + @Override + public PartitionReader createReader(InputPartition partition) { + if (partition instanceof ReadTask) { + return new RowReader((ReadTask) partition); + } else { + throw new UnsupportedOperationException("Incorrect input partition type: " + partition); + } + } + + @Override + public PartitionReader createColumnarReader(InputPartition partition) { + if (partition instanceof ReadTask) { + return new BatchReader((ReadTask) partition, batchSize); + } else { + throw new UnsupportedOperationException("Incorrect input partition type: " + partition); + } + } + + @Override + public boolean supportColumnarReads(InputPartition partition) { + return batchSize > 1; + } + } + + private static class RowReader extends RowDataReader implements PartitionReader { + RowReader(ReadTask task) { + super(task.task, task.table(), task.expectedSchema(), task.isCaseSensitive()); + } + } + + private static class BatchReader extends BatchDataReader implements PartitionReader { + BatchReader(ReadTask task, int batchSize) { + super(task.task, task.table(), task.expectedSchema(), task.isCaseSensitive(), batchSize); + } + } + + static class ReadTask implements InputPartition, Serializable { + private final CombinedScanTask task; + private final Broadcast
<Table> tableBroadcast; + private final String expectedSchemaString; + private final boolean caseSensitive; + + private transient Schema expectedSchema = null; + private transient String[] preferredLocations = null; + + ReadTask(CombinedScanTask task, Broadcast<Table>
tableBroadcast, String expectedSchemaString, + boolean caseSensitive, boolean localityPreferred) { + this.task = task; + this.tableBroadcast = tableBroadcast; + this.expectedSchemaString = expectedSchemaString; + this.caseSensitive = caseSensitive; + if (localityPreferred) { + Table table = tableBroadcast.value(); + this.preferredLocations = Util.blockLocations(table.io(), task); + } else { + this.preferredLocations = HadoopInputFile.NO_LOCATION_PREFERENCE; + } + } + + @Override + public String[] preferredLocations() { + return preferredLocations; + } + + public Collection files() { + return task.files(); + } + + public Table table() { + return tableBroadcast.value(); + } + + public boolean isCaseSensitive() { + return caseSensitive; + } + + private Schema expectedSchema() { + if (expectedSchema == null) { + this.expectedSchema = SchemaParser.fromJson(expectedSchemaString); + } + return expectedSchema; + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java new file mode 100644 index 000000000000..aa2214bc7009 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TableScan; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Binder; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkFilters; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.connector.read.SupportsPushDownFilters; +import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns; +import org.apache.spark.sql.connector.read.SupportsReportStatistics; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +public class SparkScanBuilder implements ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns, + SupportsReportStatistics { + + private static final Filter[] NO_FILTERS = new Filter[0]; + + private final SparkSession spark; + private final Table table; + private final CaseInsensitiveStringMap options; + private final SparkReadConf readConf; + private final List metaColumns = Lists.newArrayList(); + + private Schema schema = null; + private boolean caseSensitive; + private List filterExpressions = null; + private Filter[] pushedFilters = NO_FILTERS; + + SparkScanBuilder(SparkSession spark, Table table, Schema schema, CaseInsensitiveStringMap options) { + this.spark = spark; + this.table = table; + this.schema = schema; + this.options = options; + this.readConf = new SparkReadConf(spark, table, options); + this.caseSensitive = readConf.caseSensitive(); + } + + SparkScanBuilder(SparkSession spark, Table table, CaseInsensitiveStringMap options) { + this(spark, table, table.schema(), options); + } + + private Expression filterExpression() { + if (filterExpressions != null) { + return filterExpressions.stream().reduce(Expressions.alwaysTrue(), Expressions::and); + } + return Expressions.alwaysTrue(); + } + + public SparkScanBuilder caseSensitive(boolean isCaseSensitive) { + this.caseSensitive = isCaseSensitive; + return this; + } + + @Override + public Filter[] pushFilters(Filter[] filters) { + List expressions = Lists.newArrayListWithExpectedSize(filters.length); + List pushed = Lists.newArrayListWithExpectedSize(filters.length); + + for (Filter filter : filters) { + Expression expr = SparkFilters.convert(filter); + if (expr != null) { + try { + Binder.bind(schema.asStruct(), expr, caseSensitive); + expressions.add(expr); + pushed.add(filter); + } catch (ValidationException e) { + // binding to the table schema failed, so this 
expression cannot be pushed down + } + } + } + + this.filterExpressions = expressions; + this.pushedFilters = pushed.toArray(new Filter[0]); + + // Spark doesn't support residuals per task, so return all filters + // to get Spark to handle record-level filtering + return filters; + } + + @Override + public Filter[] pushedFilters() { + return pushedFilters; + } + + @Override + public void pruneColumns(StructType requestedSchema) { + StructType requestedProjection = new StructType(Stream.of(requestedSchema.fields()) + .filter(field -> MetadataColumns.nonMetadataColumn(field.name())) + .toArray(StructField[]::new)); + + // the projection should include all columns that will be returned, including those only used in filters + this.schema = SparkSchemaUtil.prune(schema, requestedProjection, filterExpression(), caseSensitive); + + Stream.of(requestedSchema.fields()) + .map(StructField::name) + .filter(MetadataColumns::isMetadataColumn) + .distinct() + .forEach(metaColumns::add); + } + + private Schema schemaWithMetadataColumns() { + // metadata columns + List fields = metaColumns.stream() + .distinct() + .map(name -> MetadataColumns.metadataColumn(table, name)) + .collect(Collectors.toList()); + Schema meta = new Schema(fields); + + // schema or rows returned by readers + return TypeUtil.join(schema, meta); + } + + @Override + public Scan build() { + Long snapshotId = readConf.snapshotId(); + Long asOfTimestamp = readConf.asOfTimestamp(); + + Preconditions.checkArgument(snapshotId == null || asOfTimestamp == null, + "Cannot set both %s and %s to select which table snapshot to scan", + SparkReadOptions.SNAPSHOT_ID, SparkReadOptions.AS_OF_TIMESTAMP); + + Long startSnapshotId = readConf.startSnapshotId(); + Long endSnapshotId = readConf.endSnapshotId(); + + if (snapshotId != null || asOfTimestamp != null) { + Preconditions.checkArgument(startSnapshotId == null && endSnapshotId == null, + "Cannot set %s and %s for incremental scans when either %s or %s is set", + SparkReadOptions.START_SNAPSHOT_ID, SparkReadOptions.END_SNAPSHOT_ID, + SparkReadOptions.SNAPSHOT_ID, SparkReadOptions.AS_OF_TIMESTAMP); + } + + Preconditions.checkArgument(startSnapshotId != null || endSnapshotId == null, + "Cannot set only %s for incremental scans. 
Please, set %s too.", + SparkReadOptions.END_SNAPSHOT_ID, SparkReadOptions.START_SNAPSHOT_ID); + + Schema expectedSchema = schemaWithMetadataColumns(); + + TableScan scan = table.newScan() + .caseSensitive(caseSensitive) + .filter(filterExpression()) + .project(expectedSchema); + + if (snapshotId != null) { + scan = scan.useSnapshot(snapshotId); + } + + if (asOfTimestamp != null) { + scan = scan.asOfTime(asOfTimestamp); + } + + if (startSnapshotId != null) { + if (endSnapshotId != null) { + scan = scan.appendsBetween(startSnapshotId, endSnapshotId); + } else { + scan = scan.appendsAfter(startSnapshotId); + } + } + + scan = configureSplitPlanning(scan); + + return new SparkBatchQueryScan(spark, table, scan, readConf, expectedSchema, filterExpressions); + } + + public Scan buildMergeOnReadScan() { + Preconditions.checkArgument(readConf.snapshotId() == null && readConf.asOfTimestamp() == null, + "Cannot set time travel options %s and %s for row-level command scans", + SparkReadOptions.SNAPSHOT_ID, SparkReadOptions.AS_OF_TIMESTAMP); + + Preconditions.checkArgument(readConf.startSnapshotId() == null && readConf.endSnapshotId() == null, + "Cannot set incremental scan options %s and %s for row-level command scans", + SparkReadOptions.START_SNAPSHOT_ID, SparkReadOptions.END_SNAPSHOT_ID); + + Snapshot snapshot = table.currentSnapshot(); + + if (snapshot == null) { + return new SparkBatchQueryScan(spark, table, null, readConf, schemaWithMetadataColumns(), filterExpressions); + } + + // remember the current snapshot ID for commit validation + long snapshotId = snapshot.snapshotId(); + + CaseInsensitiveStringMap adjustedOptions = Spark3Util.setOption( + SparkReadOptions.SNAPSHOT_ID, + Long.toString(snapshotId), + options); + SparkReadConf adjustedReadConf = new SparkReadConf(spark, table, adjustedOptions); + + Schema expectedSchema = schemaWithMetadataColumns(); + + TableScan scan = table.newScan() + .useSnapshot(snapshotId) + .caseSensitive(caseSensitive) + .filter(filterExpression()) + .project(expectedSchema); + + scan = configureSplitPlanning(scan); + + return new SparkBatchQueryScan(spark, table, scan, adjustedReadConf, expectedSchema, filterExpressions); + } + + public Scan buildCopyOnWriteScan() { + Snapshot snapshot = table.currentSnapshot(); + + if (snapshot == null) { + return new SparkCopyOnWriteScan(spark, table, readConf, schemaWithMetadataColumns(), filterExpressions); + } + + Schema expectedSchema = schemaWithMetadataColumns(); + + TableScan scan = table.newScan() + .useSnapshot(snapshot.snapshotId()) + .ignoreResiduals() + .caseSensitive(caseSensitive) + .filter(filterExpression()) + .project(expectedSchema); + + scan = configureSplitPlanning(scan); + + return new SparkCopyOnWriteScan(spark, table, scan, snapshot, readConf, expectedSchema, filterExpressions); + } + + private TableScan configureSplitPlanning(TableScan scan) { + TableScan configuredScan = scan; + + Long splitSize = readConf.splitSizeOption(); + if (splitSize != null) { + configuredScan = configuredScan.option(TableProperties.SPLIT_SIZE, String.valueOf(splitSize)); + } + + Integer splitLookback = readConf.splitLookbackOption(); + if (splitLookback != null) { + configuredScan = configuredScan.option(TableProperties.SPLIT_LOOKBACK, String.valueOf(splitLookback)); + } + + Long splitOpenFileCost = readConf.splitOpenFileCostOption(); + if (splitOpenFileCost != null) { + configuredScan = configuredScan.option(TableProperties.SPLIT_OPEN_FILE_COST, String.valueOf(splitOpenFileCost)); + } + + return configuredScan; + } + + 
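For reference, the time-travel and incremental options that build() validates above are normally supplied through the DataFrame reader API. The sketch below is illustrative only and is not part of this patch: it assumes a Spark session with an Iceberg catalog configured so that the hypothetical table name db.events resolves to an Iceberg table, and the snapshot IDs are made-up values.

import org.apache.iceberg.spark.SparkReadOptions;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class TimeTravelReadExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().master("local[*]").appName("time-travel").getOrCreate();

    // Read a fixed snapshot; SNAPSHOT_ID and AS_OF_TIMESTAMP are mutually exclusive.
    Dataset<Row> atSnapshot = spark.read()
        .format("iceberg")
        .option(SparkReadOptions.SNAPSHOT_ID, 10963874102873L)
        .load("db.events");

    // Incremental append scan between two snapshots; setting only END_SNAPSHOT_ID is rejected.
    Dataset<Row> appended = spark.read()
        .format("iceberg")
        .option(SparkReadOptions.START_SNAPSHOT_ID, 10963874102873L)
        .option(SparkReadOptions.END_SNAPSHOT_ID, 63874143573109L)
        .load("db.events");

    atSnapshot.show();
    appended.show();
  }
}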
@Override + public Statistics estimateStatistics() { + return ((SparkScan) build()).estimateStatistics(); + } + + @Override + public StructType readSchema() { + return build().readSchema(); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java new file mode 100644 index 000000000000..590f9336966b --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java @@ -0,0 +1,339 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.Map; +import java.util.Set; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TableScan; +import org.apache.iceberg.expressions.Evaluator; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.expressions.Projections; +import org.apache.iceberg.expressions.StrictMetricsEvaluator; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkFilters; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.catalog.MetadataColumn; +import org.apache.spark.sql.connector.catalog.SupportsDelete; +import org.apache.spark.sql.connector.catalog.SupportsMetadataColumns; +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations; +import org.apache.spark.sql.connector.catalog.SupportsWrite; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.connector.read.ScanBuilder; +import 
org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.RowLevelOperationBuilder; +import org.apache.spark.sql.connector.write.RowLevelOperationInfo; +import org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.iceberg.TableProperties.CURRENT_SNAPSHOT_ID; +import static org.apache.iceberg.TableProperties.FORMAT_VERSION; + +public class SparkTable implements org.apache.spark.sql.connector.catalog.Table, + SupportsRead, SupportsWrite, SupportsDelete, SupportsRowLevelOperations, SupportsMetadataColumns { + + private static final Logger LOG = LoggerFactory.getLogger(SparkTable.class); + + private static final Set RESERVED_PROPERTIES = + ImmutableSet.of("provider", "format", CURRENT_SNAPSHOT_ID, "location", FORMAT_VERSION, "sort-order", + "identifier-fields"); + private static final Set CAPABILITIES = ImmutableSet.of( + TableCapability.BATCH_READ, + TableCapability.BATCH_WRITE, + TableCapability.MICRO_BATCH_READ, + TableCapability.STREAMING_WRITE, + TableCapability.OVERWRITE_BY_FILTER, + TableCapability.OVERWRITE_DYNAMIC); + private static final Set CAPABILITIES_WITH_ACCEPT_ANY_SCHEMA = + ImmutableSet.builder() + .addAll(CAPABILITIES) + .add(TableCapability.ACCEPT_ANY_SCHEMA) + .build(); + + private final Table icebergTable; + private final Long snapshotId; + private final boolean refreshEagerly; + private final Set capabilities; + private StructType lazyTableSchema = null; + private SparkSession lazySpark = null; + + public SparkTable(Table icebergTable, boolean refreshEagerly) { + this(icebergTable, null, refreshEagerly); + } + + public SparkTable(Table icebergTable, Long snapshotId, boolean refreshEagerly) { + this.icebergTable = icebergTable; + this.snapshotId = snapshotId; + this.refreshEagerly = refreshEagerly; + + boolean acceptAnySchema = PropertyUtil.propertyAsBoolean(icebergTable.properties(), + TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA, TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA_DEFAULT); + this.capabilities = acceptAnySchema ? CAPABILITIES_WITH_ACCEPT_ANY_SCHEMA : CAPABILITIES; + } + + private SparkSession sparkSession() { + if (lazySpark == null) { + this.lazySpark = SparkSession.active(); + } + + return lazySpark; + } + + public Table table() { + return icebergTable; + } + + @Override + public String name() { + return icebergTable.toString(); + } + + private Schema snapshotSchema() { + return SnapshotUtil.schemaFor(icebergTable, snapshotId, null); + } + + @Override + public StructType schema() { + if (lazyTableSchema == null) { + this.lazyTableSchema = SparkSchemaUtil.convert(snapshotSchema()); + } + + return lazyTableSchema; + } + + @Override + public Transform[] partitioning() { + return Spark3Util.toTransforms(icebergTable.spec()); + } + + @Override + public Map properties() { + ImmutableMap.Builder propsBuilder = ImmutableMap.builder(); + + String fileFormat = icebergTable.properties() + .getOrDefault(TableProperties.DEFAULT_FILE_FORMAT, TableProperties.DEFAULT_FILE_FORMAT_DEFAULT); + propsBuilder.put("format", "iceberg/" + fileFormat); + propsBuilder.put("provider", "iceberg"); + String currentSnapshotId = icebergTable.currentSnapshot() != null ? 
+ String.valueOf(icebergTable.currentSnapshot().snapshotId()) : "none"; + propsBuilder.put(CURRENT_SNAPSHOT_ID, currentSnapshotId); + propsBuilder.put("location", icebergTable.location()); + + if (icebergTable instanceof BaseTable) { + TableOperations ops = ((BaseTable) icebergTable).operations(); + propsBuilder.put(FORMAT_VERSION, String.valueOf(ops.current().formatVersion())); + } + + if (!icebergTable.sortOrder().isUnsorted()) { + propsBuilder.put("sort-order", Spark3Util.describe(icebergTable.sortOrder())); + } + + Set identifierFields = icebergTable.schema().identifierFieldNames(); + if (!identifierFields.isEmpty()) { + propsBuilder.put("identifier-fields", "[" + String.join(",", identifierFields) + "]"); + } + + icebergTable.properties().entrySet().stream() + .filter(entry -> !RESERVED_PROPERTIES.contains(entry.getKey())) + .forEach(propsBuilder::put); + + return propsBuilder.build(); + } + + @Override + public Set capabilities() { + return capabilities; + } + + @Override + public MetadataColumn[] metadataColumns() { + DataType sparkPartitionType = SparkSchemaUtil.convert(Partitioning.partitionType(table())); + return new MetadataColumn[] { + new SparkMetadataColumn(MetadataColumns.SPEC_ID.name(), DataTypes.IntegerType, false), + new SparkMetadataColumn(MetadataColumns.PARTITION_COLUMN_NAME, sparkPartitionType, true), + new SparkMetadataColumn(MetadataColumns.FILE_PATH.name(), DataTypes.StringType, false), + new SparkMetadataColumn(MetadataColumns.ROW_POSITION.name(), DataTypes.LongType, false), + new SparkMetadataColumn(MetadataColumns.IS_DELETED.name(), DataTypes.BooleanType, false) + }; + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + if (options.containsKey(SparkReadOptions.FILE_SCAN_TASK_SET_ID)) { + // skip planning the job and fetch already staged file scan tasks + return new SparkFilesScanBuilder(sparkSession(), icebergTable, options); + } + + if (refreshEagerly) { + icebergTable.refresh(); + } + + CaseInsensitiveStringMap scanOptions = addSnapshotId(options, snapshotId); + return new SparkScanBuilder(sparkSession(), icebergTable, snapshotSchema(), scanOptions); + } + + @Override + public WriteBuilder newWriteBuilder(LogicalWriteInfo info) { + Preconditions.checkArgument( + snapshotId == null, + "Cannot write to table at a specific snapshot: %s", snapshotId); + + return new SparkWriteBuilder(sparkSession(), icebergTable, info); + } + + @Override + public RowLevelOperationBuilder newRowLevelOperationBuilder(RowLevelOperationInfo info) { + return new SparkRowLevelOperationBuilder(sparkSession(), icebergTable, info); + } + + @Override + public boolean canDeleteWhere(Filter[] filters) { + Preconditions.checkArgument( + snapshotId == null, + "Cannot delete from table at a specific snapshot: %s", snapshotId); + + Expression deleteExpr = Expressions.alwaysTrue(); + + for (Filter filter : filters) { + Expression expr = SparkFilters.convert(filter); + if (expr != null) { + deleteExpr = Expressions.and(deleteExpr, expr); + } else { + return false; + } + } + + return deleteExpr == Expressions.alwaysTrue() || canDeleteUsingMetadata(deleteExpr); + } + + // a metadata delete is possible iff matching files can be deleted entirely + private boolean canDeleteUsingMetadata(Expression deleteExpr) { + boolean caseSensitive = Boolean.parseBoolean(sparkSession().conf().get("spark.sql.caseSensitive")); + TableScan scan = table().newScan() + .filter(deleteExpr) + .caseSensitive(caseSensitive) + .includeColumnStats() + .ignoreResiduals(); + + try 
(CloseableIterable tasks = scan.planFiles()) { + Map evaluators = Maps.newHashMap(); + StrictMetricsEvaluator metricsEvaluator = new StrictMetricsEvaluator(table().schema(), deleteExpr); + + return Iterables.all(tasks, task -> { + DataFile file = task.file(); + PartitionSpec spec = task.spec(); + Evaluator evaluator = evaluators.computeIfAbsent( + spec.specId(), + specId -> new Evaluator(spec.partitionType(), Projections.strict(spec).project(deleteExpr))); + return evaluator.eval(file.partition()) || metricsEvaluator.eval(file); + }); + + } catch (IOException ioe) { + LOG.warn("Failed to close task iterable", ioe); + return false; + } + } + + @Override + public void deleteWhere(Filter[] filters) { + Expression deleteExpr = SparkFilters.convert(filters); + + if (deleteExpr == Expressions.alwaysFalse()) { + LOG.info("Skipping the delete operation as the condition is always false"); + return; + } + + icebergTable.newDelete() + .set("spark.app.id", sparkSession().sparkContext().applicationId()) + .deleteFromRowFilter(deleteExpr) + .commit(); + } + + @Override + public String toString() { + return icebergTable.toString(); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } else if (other == null || getClass() != other.getClass()) { + return false; + } + + // use only name in order to correctly invalidate Spark cache + SparkTable that = (SparkTable) other; + return icebergTable.name().equals(that.icebergTable.name()); + } + + @Override + public int hashCode() { + // use only name in order to correctly invalidate Spark cache + return icebergTable.name().hashCode(); + } + + private static CaseInsensitiveStringMap addSnapshotId(CaseInsensitiveStringMap options, Long snapshotId) { + if (snapshotId != null) { + String snapshotIdFromOptions = options.get(SparkReadOptions.SNAPSHOT_ID); + String value = snapshotId.toString(); + Preconditions.checkArgument(snapshotIdFromOptions == null || snapshotIdFromOptions.equals(value), + "Cannot override snapshot ID more than once: %s", snapshotIdFromOptions); + + Map scanOptions = Maps.newHashMap(); + scanOptions.putAll(options.asCaseSensitiveMap()); + scanOptions.put(SparkReadOptions.SNAPSHOT_ID, value); + scanOptions.remove(SparkReadOptions.AS_OF_TIMESTAMP); + + return new CaseInsensitiveStringMap(scanOptions); + } + + return options; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java new file mode 100644 index 000000000000..b714b805566e --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java @@ -0,0 +1,740 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.OverwriteFiles; +import org.apache.iceberg.PartitionKey; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.ReplacePartitions; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SerializableTable; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.SnapshotUpdate; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.ClusteredDataWriter; +import org.apache.iceberg.io.DataWriteResult; +import org.apache.iceberg.io.FanoutDataWriter; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.FileWriter; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.io.PartitioningWriter; +import org.apache.iceberg.io.RollingDataWriter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.CommitMetadata; +import org.apache.iceberg.spark.FileRewriteCoordinator; +import org.apache.iceberg.spark.SparkWriteConf; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.Tasks; +import org.apache.iceberg.util.ThreadPools; +import org.apache.spark.TaskContext; +import org.apache.spark.TaskContext$; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.executor.OutputMetrics; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.connector.write.BatchWrite; +import org.apache.spark.sql.connector.write.DataWriter; +import org.apache.spark.sql.connector.write.DataWriterFactory; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.PhysicalWriteInfo; +import org.apache.spark.sql.connector.write.RequiresDistributionAndOrdering; +import org.apache.spark.sql.connector.write.Write; +import org.apache.spark.sql.connector.write.WriterCommitMessage; +import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory; +import org.apache.spark.sql.connector.write.streaming.StreamingWrite; +import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.iceberg.IsolationLevel.SERIALIZABLE; +import static org.apache.iceberg.IsolationLevel.SNAPSHOT; +import static org.apache.iceberg.TableProperties.COMMIT_MAX_RETRY_WAIT_MS; +import static org.apache.iceberg.TableProperties.COMMIT_MAX_RETRY_WAIT_MS_DEFAULT; +import static org.apache.iceberg.TableProperties.COMMIT_MIN_RETRY_WAIT_MS; +import static 
org.apache.iceberg.TableProperties.COMMIT_MIN_RETRY_WAIT_MS_DEFAULT; +import static org.apache.iceberg.TableProperties.COMMIT_NUM_RETRIES; +import static org.apache.iceberg.TableProperties.COMMIT_NUM_RETRIES_DEFAULT; +import static org.apache.iceberg.TableProperties.COMMIT_TOTAL_RETRY_TIME_MS; +import static org.apache.iceberg.TableProperties.COMMIT_TOTAL_RETRY_TIME_MS_DEFAULT; + +abstract class SparkWrite implements Write, RequiresDistributionAndOrdering { + private static final Logger LOG = LoggerFactory.getLogger(SparkWrite.class); + + private final JavaSparkContext sparkContext; + private final SparkWriteConf writeConf; + private final Table table; + private final String queryId; + private final FileFormat format; + private final String applicationId; + private final boolean wapEnabled; + private final String wapId; + private final long targetFileSize; + private final Schema writeSchema; + private final StructType dsSchema; + private final Map extraSnapshotMetadata; + private final boolean partitionedFanoutEnabled; + private final Distribution requiredDistribution; + private final SortOrder[] requiredOrdering; + + private boolean cleanupOnAbort = true; + + SparkWrite(SparkSession spark, Table table, SparkWriteConf writeConf, + LogicalWriteInfo writeInfo, String applicationId, + Schema writeSchema, StructType dsSchema, + Distribution requiredDistribution, SortOrder[] requiredOrdering) { + this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + this.table = table; + this.writeConf = writeConf; + this.queryId = writeInfo.queryId(); + this.format = writeConf.dataFileFormat(); + this.applicationId = applicationId; + this.wapEnabled = writeConf.wapEnabled(); + this.wapId = writeConf.wapId(); + this.targetFileSize = writeConf.targetDataFileSize(); + this.writeSchema = writeSchema; + this.dsSchema = dsSchema; + this.extraSnapshotMetadata = writeConf.extraSnapshotMetadata(); + this.partitionedFanoutEnabled = writeConf.fanoutWriterEnabled(); + this.requiredDistribution = requiredDistribution; + this.requiredOrdering = requiredOrdering; + } + + @Override + public Distribution requiredDistribution() { + return requiredDistribution; + } + + @Override + public SortOrder[] requiredOrdering() { + return requiredOrdering; + } + + BatchWrite asBatchAppend() { + return new BatchAppend(); + } + + BatchWrite asDynamicOverwrite() { + return new DynamicOverwrite(); + } + + BatchWrite asOverwriteByFilter(Expression overwriteExpr) { + return new OverwriteByFilter(overwriteExpr); + } + + BatchWrite asCopyOnWriteOperation(SparkCopyOnWriteScan scan, IsolationLevel isolationLevel) { + return new CopyOnWriteOperation(scan, isolationLevel); + } + + BatchWrite asRewrite(String fileSetID) { + return new RewriteFiles(fileSetID); + } + + StreamingWrite asStreamingAppend() { + return new StreamingAppend(); + } + + StreamingWrite asStreamingOverwrite() { + return new StreamingOverwrite(); + } + + // the writer factory works for both batch and streaming + private WriterFactory createWriterFactory() { + // broadcast the table metadata as the writer factory will be sent to executors + Broadcast
tableBroadcast = sparkContext.broadcast(SerializableTable.copyOf(table)); + return new WriterFactory(tableBroadcast, format, targetFileSize, writeSchema, dsSchema, partitionedFanoutEnabled); + } + + private void commitOperation(SnapshotUpdate operation, String description) { + LOG.info("Committing {} to table {}", description, table); + if (applicationId != null) { + operation.set("spark.app.id", applicationId); + } + + if (!extraSnapshotMetadata.isEmpty()) { + extraSnapshotMetadata.forEach(operation::set); + } + + if (!CommitMetadata.commitProperties().isEmpty()) { + CommitMetadata.commitProperties().forEach(operation::set); + } + + if (wapEnabled && wapId != null) { + // write-audit-publish is enabled for this table and job + // stage the changes without changing the current snapshot + operation.set(SnapshotSummary.STAGED_WAP_ID_PROP, wapId); + operation.stageOnly(); + } + + try { + long start = System.currentTimeMillis(); + operation.commit(); // abort is automatically called if this fails + long duration = System.currentTimeMillis() - start; + LOG.info("Committed in {} ms", duration); + } catch (CommitStateUnknownException commitStateUnknownException) { + cleanupOnAbort = false; + throw commitStateUnknownException; + } + } + + private void abort(WriterCommitMessage[] messages) { + if (cleanupOnAbort) { + Map props = table.properties(); + Tasks.foreach(files(messages)) + .executeWith(ThreadPools.getWorkerPool()) + .retry(PropertyUtil.propertyAsInt(props, COMMIT_NUM_RETRIES, COMMIT_NUM_RETRIES_DEFAULT)) + .exponentialBackoff( + PropertyUtil.propertyAsInt(props, COMMIT_MIN_RETRY_WAIT_MS, COMMIT_MIN_RETRY_WAIT_MS_DEFAULT), + PropertyUtil.propertyAsInt(props, COMMIT_MAX_RETRY_WAIT_MS, COMMIT_MAX_RETRY_WAIT_MS_DEFAULT), + PropertyUtil.propertyAsInt(props, COMMIT_TOTAL_RETRY_TIME_MS, COMMIT_TOTAL_RETRY_TIME_MS_DEFAULT), + 2.0 /* exponential */) + .throwFailureWhenFinished() + .run(file -> { + table.io().deleteFile(file.path().toString()); + }); + } else { + LOG.warn("Skipping cleaning up of data files because Iceberg was unable to determine the final commit state"); + } + } + + private Iterable files(WriterCommitMessage[] messages) { + if (messages.length > 0) { + return Iterables.concat(Iterables.transform(Arrays.asList(messages), message -> message != null ? 
+ ImmutableList.copyOf(((TaskCommit) message).files()) : + ImmutableList.of())); + } + return ImmutableList.of(); + } + + @Override + public String toString() { + return String.format("IcebergWrite(table=%s, format=%s)", table, format); + } + + private abstract class BaseBatchWrite implements BatchWrite { + @Override + public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) { + return createWriterFactory(); + } + + @Override + public void abort(WriterCommitMessage[] messages) { + SparkWrite.this.abort(messages); + } + + @Override + public String toString() { + return String.format("IcebergBatchWrite(table=%s, format=%s)", table, format); + } + } + + private class BatchAppend extends BaseBatchWrite { + @Override + public void commit(WriterCommitMessage[] messages) { + AppendFiles append = table.newAppend(); + + int numFiles = 0; + for (DataFile file : files(messages)) { + numFiles += 1; + append.appendFile(file); + } + + commitOperation(append, String.format("append with %d new data files", numFiles)); + } + } + + private class DynamicOverwrite extends BaseBatchWrite { + @Override + public void commit(WriterCommitMessage[] messages) { + Iterable files = files(messages); + + if (!files.iterator().hasNext()) { + LOG.info("Dynamic overwrite is empty, skipping commit"); + return; + } + + ReplacePartitions dynamicOverwrite = table.newReplacePartitions(); + IsolationLevel isolationLevel = writeConf.isolationLevel(); + Long validateFromSnapshotId = writeConf.validateFromSnapshotId(); + + if (isolationLevel != null && validateFromSnapshotId != null) { + dynamicOverwrite.validateFromSnapshot(validateFromSnapshotId); + } + + if (isolationLevel == SERIALIZABLE) { + dynamicOverwrite.validateNoConflictingData(); + dynamicOverwrite.validateNoConflictingDeletes(); + + } else if (isolationLevel == SNAPSHOT) { + dynamicOverwrite.validateNoConflictingDeletes(); + } + + int numFiles = 0; + for (DataFile file : files) { + numFiles += 1; + dynamicOverwrite.addFile(file); + } + + commitOperation(dynamicOverwrite, String.format("dynamic partition overwrite with %d new data files", numFiles)); + } + } + + private class OverwriteByFilter extends BaseBatchWrite { + private final Expression overwriteExpr; + + private OverwriteByFilter(Expression overwriteExpr) { + this.overwriteExpr = overwriteExpr; + } + + @Override + public void commit(WriterCommitMessage[] messages) { + OverwriteFiles overwriteFiles = table.newOverwrite(); + overwriteFiles.overwriteByRowFilter(overwriteExpr); + + int numFiles = 0; + for (DataFile file : files(messages)) { + numFiles += 1; + overwriteFiles.addFile(file); + } + + IsolationLevel isolationLevel = writeConf.isolationLevel(); + Long validateFromSnapshotId = writeConf.validateFromSnapshotId(); + + if (isolationLevel != null && validateFromSnapshotId != null) { + overwriteFiles.validateFromSnapshot(validateFromSnapshotId); + } + + if (isolationLevel == SERIALIZABLE) { + overwriteFiles.validateNoConflictingDeletes(); + overwriteFiles.validateNoConflictingData(); + + } else if (isolationLevel == SNAPSHOT) { + overwriteFiles.validateNoConflictingDeletes(); + } + + String commitMsg = String.format("overwrite by filter %s with %d new data files", overwriteExpr, numFiles); + commitOperation(overwriteFiles, commitMsg); + } + } + + private class CopyOnWriteOperation extends BaseBatchWrite { + private final SparkCopyOnWriteScan scan; + private final IsolationLevel isolationLevel; + + private CopyOnWriteOperation(SparkCopyOnWriteScan scan, IsolationLevel isolationLevel) { + 
this.scan = scan; + this.isolationLevel = isolationLevel; + } + + private List overwrittenFiles() { + return scan.files().stream().map(FileScanTask::file).collect(Collectors.toList()); + } + + private Expression conflictDetectionFilter() { + // the list of filter expressions may be empty but is never null + List scanFilterExpressions = scan.filterExpressions(); + + Expression filter = Expressions.alwaysTrue(); + + for (Expression expr : scanFilterExpressions) { + filter = Expressions.and(filter, expr); + } + + return filter; + } + + @Override + public void commit(WriterCommitMessage[] messages) { + OverwriteFiles overwriteFiles = table.newOverwrite(); + + List overwrittenFiles = overwrittenFiles(); + int numOverwrittenFiles = overwrittenFiles.size(); + for (DataFile overwrittenFile : overwrittenFiles) { + overwriteFiles.deleteFile(overwrittenFile); + } + + int numAddedFiles = 0; + for (DataFile file : files(messages)) { + numAddedFiles += 1; + overwriteFiles.addFile(file); + } + + if (isolationLevel == SERIALIZABLE) { + commitWithSerializableIsolation(overwriteFiles, numOverwrittenFiles, numAddedFiles); + } else if (isolationLevel == SNAPSHOT) { + commitWithSnapshotIsolation(overwriteFiles, numOverwrittenFiles, numAddedFiles); + } else { + throw new IllegalArgumentException("Unsupported isolation level: " + isolationLevel); + } + } + + private void commitWithSerializableIsolation(OverwriteFiles overwriteFiles, + int numOverwrittenFiles, + int numAddedFiles) { + Long scanSnapshotId = scan.snapshotId(); + if (scanSnapshotId != null) { + overwriteFiles.validateFromSnapshot(scanSnapshotId); + } + + Expression conflictDetectionFilter = conflictDetectionFilter(); + overwriteFiles.conflictDetectionFilter(conflictDetectionFilter); + overwriteFiles.validateNoConflictingData(); + overwriteFiles.validateNoConflictingDeletes(); + + String commitMsg = String.format( + "overwrite of %d data files with %d new data files, scanSnapshotId: %d, conflictDetectionFilter: %s", + numOverwrittenFiles, numAddedFiles, scanSnapshotId, conflictDetectionFilter); + commitOperation(overwriteFiles, commitMsg); + } + + private void commitWithSnapshotIsolation(OverwriteFiles overwriteFiles, + int numOverwrittenFiles, + int numAddedFiles) { + Long scanSnapshotId = scan.snapshotId(); + if (scanSnapshotId != null) { + overwriteFiles.validateFromSnapshot(scanSnapshotId); + } + + Expression conflictDetectionFilter = conflictDetectionFilter(); + overwriteFiles.conflictDetectionFilter(conflictDetectionFilter); + overwriteFiles.validateNoConflictingDeletes(); + + String commitMsg = String.format( + "overwrite of %d data files with %d new data files", + numOverwrittenFiles, numAddedFiles); + commitOperation(overwriteFiles, commitMsg); + } + } + + private class RewriteFiles extends BaseBatchWrite { + private final String fileSetID; + + private RewriteFiles(String fileSetID) { + this.fileSetID = fileSetID; + } + + @Override + public void commit(WriterCommitMessage[] messages) { + FileRewriteCoordinator coordinator = FileRewriteCoordinator.get(); + + Set newDataFiles = Sets.newHashSetWithExpectedSize(messages.length); + for (DataFile file : files(messages)) { + newDataFiles.add(file); + } + + coordinator.stageRewrite(table, fileSetID, Collections.unmodifiableSet(newDataFiles)); + } + } + + private abstract class BaseStreamingWrite implements StreamingWrite { + private static final String QUERY_ID_PROPERTY = "spark.sql.streaming.queryId"; + private static final String EPOCH_ID_PROPERTY = "spark.sql.streaming.epochId"; + + protected 
abstract String mode(); + + @Override + public StreamingDataWriterFactory createStreamingWriterFactory(PhysicalWriteInfo info) { + return createWriterFactory(); + } + + @Override + public final void commit(long epochId, WriterCommitMessage[] messages) { + LOG.info("Committing epoch {} for query {} in {} mode", epochId, queryId, mode()); + + table.refresh(); + + Long lastCommittedEpochId = findLastCommittedEpochId(); + if (lastCommittedEpochId != null && epochId <= lastCommittedEpochId) { + LOG.info("Skipping epoch {} for query {} as it was already committed", epochId, queryId); + return; + } + + doCommit(epochId, messages); + } + + protected abstract void doCommit(long epochId, WriterCommitMessage[] messages); + + protected void commit(SnapshotUpdate snapshotUpdate, long epochId, String description) { + snapshotUpdate.set(QUERY_ID_PROPERTY, queryId); + snapshotUpdate.set(EPOCH_ID_PROPERTY, Long.toString(epochId)); + commitOperation(snapshotUpdate, description); + } + + private Long findLastCommittedEpochId() { + Snapshot snapshot = table.currentSnapshot(); + Long lastCommittedEpochId = null; + while (snapshot != null) { + Map summary = snapshot.summary(); + String snapshotQueryId = summary.get(QUERY_ID_PROPERTY); + if (queryId.equals(snapshotQueryId)) { + lastCommittedEpochId = Long.valueOf(summary.get(EPOCH_ID_PROPERTY)); + break; + } + Long parentSnapshotId = snapshot.parentId(); + snapshot = parentSnapshotId != null ? table.snapshot(parentSnapshotId) : null; + } + return lastCommittedEpochId; + } + + @Override + public void abort(long epochId, WriterCommitMessage[] messages) { + SparkWrite.this.abort(messages); + } + + @Override + public String toString() { + return String.format("IcebergStreamingWrite(table=%s, format=%s)", table, format); + } + } + + private class StreamingAppend extends BaseStreamingWrite { + @Override + protected String mode() { + return "append"; + } + + @Override + protected void doCommit(long epochId, WriterCommitMessage[] messages) { + AppendFiles append = table.newFastAppend(); + int numFiles = 0; + for (DataFile file : files(messages)) { + append.appendFile(file); + numFiles++; + } + commit(append, epochId, String.format("streaming append with %d new data files", numFiles)); + } + } + + private class StreamingOverwrite extends BaseStreamingWrite { + @Override + protected String mode() { + return "complete"; + } + + @Override + public void doCommit(long epochId, WriterCommitMessage[] messages) { + OverwriteFiles overwriteFiles = table.newOverwrite(); + overwriteFiles.overwriteByRowFilter(Expressions.alwaysTrue()); + int numFiles = 0; + for (DataFile file : files(messages)) { + overwriteFiles.addFile(file); + numFiles++; + } + commit(overwriteFiles, epochId, String.format("streaming complete overwrite with %d new data files", numFiles)); + } + } + + public static class TaskCommit implements WriterCommitMessage { + private final DataFile[] taskFiles; + + TaskCommit(DataFile[] taskFiles) { + this.taskFiles = taskFiles; + } + + // Reports bytesWritten and recordsWritten to the Spark output metrics. + // Can only be called in executor. 
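+      // TaskContext$.MODULE$.get() returns null when no task is running (e.g. on the driver), hence the null check below.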
+ void reportOutputMetrics() { + long bytesWritten = 0L; + long recordsWritten = 0L; + for (DataFile dataFile : taskFiles) { + bytesWritten += dataFile.fileSizeInBytes(); + recordsWritten += dataFile.recordCount(); + } + + TaskContext taskContext = TaskContext$.MODULE$.get(); + if (taskContext != null) { + OutputMetrics outputMetrics = taskContext.taskMetrics().outputMetrics(); + outputMetrics.setBytesWritten(bytesWritten); + outputMetrics.setRecordsWritten(recordsWritten); + } + } + + DataFile[] files() { + return taskFiles; + } + } + + private static class WriterFactory implements DataWriterFactory, StreamingDataWriterFactory { + private final Broadcast<Table>
tableBroadcast; + private final FileFormat format; + private final long targetFileSize; + private final Schema writeSchema; + private final StructType dsSchema; + private final boolean partitionedFanoutEnabled; + + protected WriterFactory(Broadcast<Table>
tableBroadcast, FileFormat format, long targetFileSize, + Schema writeSchema, StructType dsSchema, boolean partitionedFanoutEnabled) { + this.tableBroadcast = tableBroadcast; + this.format = format; + this.targetFileSize = targetFileSize; + this.writeSchema = writeSchema; + this.dsSchema = dsSchema; + this.partitionedFanoutEnabled = partitionedFanoutEnabled; + } + + @Override + public DataWriter<InternalRow> createWriter(int partitionId, long taskId) { + return createWriter(partitionId, taskId, 0); + } + + @Override + public DataWriter<InternalRow> createWriter(int partitionId, long taskId, long epochId) { + Table table = tableBroadcast.value(); + PartitionSpec spec = table.spec(); + FileIO io = table.io(); + + OutputFileFactory fileFactory = OutputFileFactory.builderFor(table, partitionId, taskId) + .format(format) + .build(); + SparkFileWriterFactory writerFactory = SparkFileWriterFactory.builderFor(table) + .dataFileFormat(format) + .dataSchema(writeSchema) + .dataSparkType(dsSchema) + .build(); + + if (spec.isUnpartitioned()) { + return new UnpartitionedDataWriter(writerFactory, fileFactory, io, spec, targetFileSize); + + } else { + return new PartitionedDataWriter( + writerFactory, fileFactory, io, spec, writeSchema, dsSchema, targetFileSize, partitionedFanoutEnabled); + } + } + } + + private static <F extends ContentFile<F>> void deleteFiles(FileIO io, List<F> files) { + Tasks.foreach(files) + .executeWith(ThreadPools.getWorkerPool()) + .throwFailureWhenFinished() + .noRetry() + .run(file -> io.deleteFile(file.path().toString())); + } + + private static class UnpartitionedDataWriter implements DataWriter<InternalRow> { + private final FileWriter<InternalRow, DataWriteResult> delegate; + private final FileIO io; + + private UnpartitionedDataWriter(SparkFileWriterFactory writerFactory, OutputFileFactory fileFactory, + FileIO io, PartitionSpec spec, long targetFileSize) { + this.delegate = new RollingDataWriter<>(writerFactory, fileFactory, io, targetFileSize, spec, null); + this.io = io; + } + + @Override + public void write(InternalRow record) throws IOException { + delegate.write(record); + } + + @Override + public WriterCommitMessage commit() throws IOException { + close(); + + DataWriteResult result = delegate.result(); + TaskCommit taskCommit = new TaskCommit(result.dataFiles().toArray(new DataFile[0])); + taskCommit.reportOutputMetrics(); + return taskCommit; + } + + @Override + public void abort() throws IOException { + close(); + + DataWriteResult result = delegate.result(); + deleteFiles(io, result.dataFiles()); + } + + @Override + public void close() throws IOException { + delegate.close(); + } + } + + private static class PartitionedDataWriter implements DataWriter<InternalRow> { + private final PartitioningWriter<InternalRow, DataWriteResult> delegate; + private final FileIO io; + private final PartitionSpec spec; + private final PartitionKey partitionKey; + private final InternalRowWrapper internalRowWrapper; + + private PartitionedDataWriter(SparkFileWriterFactory writerFactory, OutputFileFactory fileFactory, + FileIO io, PartitionSpec spec, Schema dataSchema, + StructType dataSparkType, long targetFileSize, boolean fanoutEnabled) { + if (fanoutEnabled) { + this.delegate = new FanoutDataWriter<>(writerFactory, fileFactory, io, targetFileSize); + } else { + this.delegate = new ClusteredDataWriter<>(writerFactory, fileFactory, io, targetFileSize); + } + this.io = io; + this.spec = spec; + this.partitionKey = new PartitionKey(spec, dataSchema); + this.internalRowWrapper = new InternalRowWrapper(dataSparkType); + } + + @Override + public void write(InternalRow row) throws IOException { + 
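+ // Wrap the row to evaluate the partition transforms, then route the write by (spec, partition key).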
partitionKey.partition(internalRowWrapper.wrap(row)); + delegate.write(row, spec, partitionKey); + } + + @Override + public WriterCommitMessage commit() throws IOException { + close(); + + DataWriteResult result = delegate.result(); + TaskCommit taskCommit = new TaskCommit(result.dataFiles().toArray(new DataFile[0])); + taskCommit.reportOutputMetrics(); + return taskCommit; + } + + @Override + public void abort() throws IOException { + close(); + + DataWriteResult result = delegate.result(); + deleteFiles(io, result.dataFiles()); + } + + @Override + public void close() throws IOException { + delegate.close(); + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java new file mode 100644 index 000000000000..dd0fc8e3ce37 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.UpdateSchema; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkDistributionAndOrderingUtil; +import org.apache.iceberg.spark.SparkFilters; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.spark.SparkWriteConf; +import org.apache.iceberg.types.TypeUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.distributions.Distributions; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.write.BatchWrite; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.RowLevelOperation.Command; +import org.apache.spark.sql.connector.write.SupportsDynamicOverwrite; +import org.apache.spark.sql.connector.write.SupportsOverwrite; +import org.apache.spark.sql.connector.write.Write; +import org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.connector.write.streaming.StreamingWrite; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class SparkWriteBuilder implements WriteBuilder, SupportsDynamicOverwrite, SupportsOverwrite { + private static final Logger LOG = LoggerFactory.getLogger(SparkWriteBuilder.class); + private static final SortOrder[] NO_ORDERING = new SortOrder[0]; + + private final SparkSession spark; + private final Table table; + private final SparkWriteConf writeConf; + private final LogicalWriteInfo writeInfo; + private final StructType dsSchema; + private final String overwriteMode; + private final String rewrittenFileSetId; + private final boolean handleTimestampWithoutZone; + private final boolean useTableDistributionAndOrdering; + private boolean overwriteDynamic = false; + private boolean overwriteByFilter = false; + private Expression overwriteExpr = null; + private boolean overwriteFiles = false; + private SparkCopyOnWriteScan copyOnWriteScan = null; + private Command copyOnWriteCommand = null; + private IsolationLevel copyOnWriteIsolationLevel = null; + + SparkWriteBuilder(SparkSession spark, Table table, LogicalWriteInfo info) { + this.spark = spark; + this.table = table; + this.writeConf = new SparkWriteConf(spark, table, info.options()); + this.writeInfo = info; + this.dsSchema = info.schema(); + this.overwriteMode = writeConf.overwriteMode(); + this.rewrittenFileSetId = writeConf.rewrittenFileSetId(); + this.handleTimestampWithoutZone = writeConf.handleTimestampWithoutZone(); + this.useTableDistributionAndOrdering = writeConf.useTableDistributionAndOrdering(); + } + + public WriteBuilder overwriteFiles(Scan scan, Command command, IsolationLevel isolationLevel) { + Preconditions.checkArgument(scan instanceof SparkCopyOnWriteScan, "%s is not SparkCopyOnWriteScan", scan); + Preconditions.checkState(!overwriteByFilter, "Cannot overwrite individual files and by 
filter"); + Preconditions.checkState(!overwriteDynamic, "Cannot overwrite individual files and dynamically"); + Preconditions.checkState(rewrittenFileSetId == null, "Cannot overwrite individual files and rewrite"); + + this.overwriteFiles = true; + this.copyOnWriteScan = (SparkCopyOnWriteScan) scan; + this.copyOnWriteCommand = command; + this.copyOnWriteIsolationLevel = isolationLevel; + return this; + } + + @Override + public WriteBuilder overwriteDynamicPartitions() { + Preconditions.checkState(!overwriteByFilter, "Cannot overwrite dynamically and by filter: %s", overwriteExpr); + Preconditions.checkState(!overwriteFiles, "Cannot overwrite individual files and dynamically"); + Preconditions.checkState(rewrittenFileSetId == null, "Cannot overwrite dynamically and rewrite"); + + this.overwriteDynamic = true; + return this; + } + + @Override + public WriteBuilder overwrite(Filter[] filters) { + Preconditions.checkState(!overwriteFiles, "Cannot overwrite individual files and using filters"); + Preconditions.checkState(rewrittenFileSetId == null, "Cannot overwrite and rewrite"); + + this.overwriteExpr = SparkFilters.convert(filters); + if (overwriteExpr == Expressions.alwaysTrue() && "dynamic".equals(overwriteMode)) { + // use the write option to override truncating the table. use dynamic overwrite instead. + this.overwriteDynamic = true; + } else { + Preconditions.checkState(!overwriteDynamic, "Cannot overwrite dynamically and by filter: %s", overwriteExpr); + this.overwriteByFilter = true; + } + return this; + } + + @Override + public Write build() { + // Validate + Preconditions.checkArgument(handleTimestampWithoutZone || !SparkUtil.hasTimestampWithoutZone(table.schema()), + SparkUtil.TIMESTAMP_WITHOUT_TIMEZONE_ERROR); + + Schema writeSchema = validateOrMergeWriteSchema(table, dsSchema, writeConf); + SparkUtil.validatePartitionTransforms(table.spec()); + + // Get application id + String appId = spark.sparkContext().applicationId(); + + Distribution distribution; + SortOrder[] ordering; + + if (useTableDistributionAndOrdering) { + if (Spark3Util.extensionsEnabled(spark) || allIdentityTransforms(table.spec())) { + distribution = buildRequiredDistribution(); + ordering = buildRequiredOrdering(distribution); + } else { + LOG.warn("Skipping distribution/ordering: extensions are disabled and spec contains unsupported transforms"); + distribution = Distributions.unspecified(); + ordering = NO_ORDERING; + } + } else { + LOG.info("Skipping distribution/ordering: disabled per job configuration"); + distribution = Distributions.unspecified(); + ordering = NO_ORDERING; + } + + return new SparkWrite(spark, table, writeConf, writeInfo, appId, writeSchema, dsSchema, distribution, ordering) { + + @Override + public BatchWrite toBatch() { + if (rewrittenFileSetId != null) { + return asRewrite(rewrittenFileSetId); + } else if (overwriteByFilter) { + return asOverwriteByFilter(overwriteExpr); + } else if (overwriteDynamic) { + return asDynamicOverwrite(); + } else if (overwriteFiles) { + return asCopyOnWriteOperation(copyOnWriteScan, copyOnWriteIsolationLevel); + } else { + return asBatchAppend(); + } + } + + @Override + public StreamingWrite toStreaming() { + Preconditions.checkState(!overwriteDynamic, + "Unsupported streaming operation: dynamic partition overwrite"); + Preconditions.checkState(!overwriteByFilter || overwriteExpr == Expressions.alwaysTrue(), + "Unsupported streaming operation: overwrite by filter: %s", overwriteExpr); + Preconditions.checkState(rewrittenFileSetId == null, + "Unsupported 
streaming operation: rewrite"); + + if (overwriteByFilter) { + return asStreamingOverwrite(); + } else { + return asStreamingAppend(); + } + } + }; + } + + private Distribution buildRequiredDistribution() { + if (overwriteFiles) { + DistributionMode distributionMode = copyOnWriteDistributionMode(); + return SparkDistributionAndOrderingUtil.buildCopyOnWriteDistribution(table, copyOnWriteCommand, distributionMode); + } else { + DistributionMode distributionMode = writeConf.distributionMode(); + return SparkDistributionAndOrderingUtil.buildRequiredDistribution(table, distributionMode); + } + } + + private DistributionMode copyOnWriteDistributionMode() { + switch (copyOnWriteCommand) { + case DELETE: + return writeConf.deleteDistributionMode(); + case UPDATE: + return writeConf.updateDistributionMode(); + case MERGE: + return writeConf.copyOnWriteMergeDistributionMode(); + default: + throw new IllegalArgumentException("Unexpected command: " + copyOnWriteCommand); + } + } + + private SortOrder[] buildRequiredOrdering(Distribution requiredDistribution) { + if (overwriteFiles) { + return SparkDistributionAndOrderingUtil.buildCopyOnWriteOrdering(table, copyOnWriteCommand, requiredDistribution); + } else { + return SparkDistributionAndOrderingUtil.buildRequiredOrdering(table, requiredDistribution); + } + } + + private boolean allIdentityTransforms(PartitionSpec spec) { + return spec.fields().stream().allMatch(field -> field.transform().isIdentity()); + } + + private static Schema validateOrMergeWriteSchema(Table table, StructType dsSchema, SparkWriteConf writeConf) { + Schema writeSchema; + if (writeConf.mergeSchema()) { + // convert the dataset schema and assign fresh ids for new fields + Schema newSchema = SparkSchemaUtil.convertWithFreshIds(table.schema(), dsSchema); + + // update the table to get final id assignments and validate the changes + UpdateSchema update = table.updateSchema().unionByNameWith(newSchema); + Schema mergedSchema = update.apply(); + + // reconvert the dsSchema without assignment to use the ids assigned by UpdateSchema + writeSchema = SparkSchemaUtil.convert(mergedSchema, dsSchema); + + TypeUtil.validateWriteSchema( + mergedSchema, writeSchema, writeConf.checkNullability(), writeConf.checkOrdering()); + + // if the validation passed, update the table schema + update.commit(); + } else { + writeSchema = SparkSchemaUtil.convert(table.schema(), dsSchema); + TypeUtil.validateWriteSchema( + table.schema(), writeSchema, writeConf.checkNullability(), writeConf.checkOrdering()); + } + + return writeSchema; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/StagedSparkTable.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/StagedSparkTable.java new file mode 100644 index 000000000000..2e018cb09496 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/StagedSparkTable.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.Transaction; +import org.apache.spark.sql.connector.catalog.StagedTable; + +public class StagedSparkTable extends SparkTable implements StagedTable { + private final Transaction transaction; + + public StagedSparkTable(Transaction transaction) { + super(transaction.table(), false); + this.transaction = transaction; + } + + @Override + public void commitStagedChanges() { + transaction.commitTransaction(); + } + + @Override + public void abortStagedChanges() { + // TODO: clean up + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/Stats.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/Stats.java new file mode 100644 index 000000000000..939b07a0af61 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/Stats.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.util.OptionalLong; +import org.apache.spark.sql.connector.read.Statistics; + +class Stats implements Statistics { + private final OptionalLong sizeInBytes; + private final OptionalLong numRows; + + Stats(long sizeInBytes, long numRows) { + this.sizeInBytes = OptionalLong.of(sizeInBytes); + this.numRows = OptionalLong.of(numRows); + } + + @Override + public OptionalLong sizeInBytes() { + return sizeInBytes; + } + + @Override + public OptionalLong numRows() { + return numRows; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/StreamingOffset.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/StreamingOffset.java new file mode 100644 index 000000000000..64277ecf3be5 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/StreamingOffset.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.JsonNode; +import java.io.IOException; +import java.io.InputStream; +import java.io.StringWriter; +import java.io.UncheckedIOException; +import org.apache.iceberg.relocated.com.google.common.base.Objects; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.util.JsonUtil; +import org.apache.spark.sql.connector.read.streaming.Offset; + +class StreamingOffset extends Offset { + static final StreamingOffset START_OFFSET = new StreamingOffset(-1L, -1, false); + + private static final int CURR_VERSION = 1; + private static final String VERSION = "version"; + private static final String SNAPSHOT_ID = "snapshot_id"; + private static final String POSITION = "position"; + private static final String SCAN_ALL_FILES = "scan_all_files"; + + private final long snapshotId; + private final long position; + private final boolean scanAllFiles; + + /** + * An implementation of Spark Structured Streaming Offset, to track the current processed files of + * Iceberg table. + * + * @param snapshotId The current processed snapshot id. + * @param position The position of last scanned file in snapshot. + * @param scanAllFiles whether to scan all files in a snapshot; for example, to read + * all data when starting a stream. 
+ */ + StreamingOffset(long snapshotId, long position, boolean scanAllFiles) { + this.snapshotId = snapshotId; + this.position = position; + this.scanAllFiles = scanAllFiles; + } + + static StreamingOffset fromJson(String json) { + Preconditions.checkNotNull(json, "Cannot parse StreamingOffset JSON: null"); + + try { + JsonNode node = JsonUtil.mapper().readValue(json, JsonNode.class); + return fromJsonNode(node); + } catch (IOException e) { + throw new UncheckedIOException(String.format("Failed to parse StreamingOffset from JSON string %s", json), e); + } + } + + static StreamingOffset fromJson(InputStream inputStream) { + Preconditions.checkNotNull(inputStream, "Cannot parse StreamingOffset from inputStream: null"); + + JsonNode node; + try { + node = JsonUtil.mapper().readValue(inputStream, JsonNode.class); + } catch (IOException e) { + throw new UncheckedIOException("Failed to read StreamingOffset from json", e); + } + + return fromJsonNode(node); + } + + @Override + public String json() { + StringWriter writer = new StringWriter(); + try { + JsonGenerator generator = JsonUtil.factory().createGenerator(writer); + generator.writeStartObject(); + generator.writeNumberField(VERSION, CURR_VERSION); + generator.writeNumberField(SNAPSHOT_ID, snapshotId); + generator.writeNumberField(POSITION, position); + generator.writeBooleanField(SCAN_ALL_FILES, scanAllFiles); + generator.writeEndObject(); + generator.flush(); + + } catch (IOException e) { + throw new UncheckedIOException("Failed to write StreamingOffset to json", e); + } + + return writer.toString(); + } + + long snapshotId() { + return snapshotId; + } + + long position() { + return position; + } + + boolean shouldScanAllFiles() { + return scanAllFiles; + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof StreamingOffset) { + StreamingOffset offset = (StreamingOffset) obj; + return offset.snapshotId == snapshotId && + offset.position == position && + offset.scanAllFiles == scanAllFiles; + } else { + return false; + } + } + + @Override + public int hashCode() { + return Objects.hashCode(snapshotId, position, scanAllFiles); + } + + @Override + public String toString() { + return String.format("Streaming Offset[%d: position (%d) scan_all_files (%b)]", + snapshotId, position, scanAllFiles); + } + + private static StreamingOffset fromJsonNode(JsonNode node) { + // The version of StreamingOffset. The offset was created with a version number + // used to validate when deserializing from json string. + int version = JsonUtil.getInt(VERSION, node); + Preconditions.checkArgument(version == CURR_VERSION, + "This version of Iceberg source only supports version %s. Version %s is not supported.", + CURR_VERSION, version); + + long snapshotId = JsonUtil.getLong(SNAPSHOT_ID, node); + int position = JsonUtil.getInt(POSITION, node); + boolean shouldScanAllFiles = JsonUtil.getBool(SCAN_ALL_FILES, node); + + return new StreamingOffset(snapshotId, position, shouldScanAllFiles); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/StructInternalRow.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/StructInternalRow.java new file mode 100644 index 000000000000..a2288ef3edd7 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/StructInternalRow.java @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
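Given the json() and fromJson() methods of StreamingOffset above, an offset round-trips through a small versioned JSON document. A sketch, placed in the same package because the class is package-private; the snapshot id and position are made-up values:

package org.apache.iceberg.spark.source;

class StreamingOffsetRoundTripExample {
  static void roundTrip() {
    StreamingOffset offset = new StreamingOffset(5459484007828529905L, 3, false);
    String json = offset.json();
    // json is {"version":1,"snapshot_id":5459484007828529905,"position":3,"scan_all_files":false}
    StreamingOffset restored = StreamingOffset.fromJson(json);
    // restored.equals(offset) is true, since equality compares the snapshot id, position and scan flag
  }
}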
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.time.Instant; +import java.time.LocalDate; +import java.time.OffsetDateTime; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.Function; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.ByteBuffers; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +class StructInternalRow extends InternalRow { + private final Types.StructType type; + private StructLike struct; + + StructInternalRow(Types.StructType type) { + this.type = type; + } + + private StructInternalRow(Types.StructType type, StructLike struct) { + this.type = type; + this.struct = struct; + } + + public StructInternalRow setStruct(StructLike newStruct) { + this.struct = newStruct; + return this; + } + + @Override + public int numFields() { + return struct.size(); + } + + @Override + public void setNullAt(int i) { + throw new UnsupportedOperationException("StructInternalRow is read-only"); + } + + @Override + public void update(int i, Object value) { + throw new UnsupportedOperationException("StructInternalRow is read-only"); + } + + @Override + public InternalRow copy() { + return this; + } + + @Override + public boolean isNullAt(int ordinal) { + return struct.get(ordinal, Object.class) == null; + } + + @Override + public boolean getBoolean(int ordinal) 
{ + return struct.get(ordinal, Boolean.class); + } + + @Override + public byte getByte(int ordinal) { + return (byte) (int) struct.get(ordinal, Integer.class); + } + + @Override + public short getShort(int ordinal) { + return (short) (int) struct.get(ordinal, Integer.class); + } + + @Override + public int getInt(int ordinal) { + Object integer = struct.get(ordinal, Object.class); + + if (integer instanceof Integer) { + return (int) integer; + } else if (integer instanceof LocalDate) { + return (int) ((LocalDate) integer).toEpochDay(); + } else { + throw new IllegalStateException("Unknown type for int field. Type name: " + integer.getClass().getName()); + } + } + + @Override + public long getLong(int ordinal) { + Object longVal = struct.get(ordinal, Object.class); + + if (longVal instanceof Long) { + return (long) longVal; + } else if (longVal instanceof OffsetDateTime) { + return Duration.between(Instant.EPOCH, (OffsetDateTime) longVal).toNanos() / 1000; + } else if (longVal instanceof LocalDate) { + return ((LocalDate) longVal).toEpochDay(); + } else { + throw new IllegalStateException("Unknown type for long field. Type name: " + longVal.getClass().getName()); + } + } + + @Override + public float getFloat(int ordinal) { + return struct.get(ordinal, Float.class); + } + + @Override + public double getDouble(int ordinal) { + return struct.get(ordinal, Double.class); + } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + return isNullAt(ordinal) ? null : getDecimalInternal(ordinal, precision, scale); + } + + private Decimal getDecimalInternal(int ordinal, int precision, int scale) { + return Decimal.apply(struct.get(ordinal, BigDecimal.class)); + } + + @Override + public UTF8String getUTF8String(int ordinal) { + return isNullAt(ordinal) ? null : getUTF8StringInternal(ordinal); + } + + private UTF8String getUTF8StringInternal(int ordinal) { + CharSequence seq = struct.get(ordinal, CharSequence.class); + return UTF8String.fromString(seq.toString()); + } + + @Override + public byte[] getBinary(int ordinal) { + return isNullAt(ordinal) ? null : getBinaryInternal(ordinal); + } + + private byte[] getBinaryInternal(int ordinal) { + Object bytes = struct.get(ordinal, Object.class); + + // should only be either ByteBuffer or byte[] + if (bytes instanceof ByteBuffer) { + return ByteBuffers.toByteArray((ByteBuffer) bytes); + } else if (bytes instanceof byte[]) { + return (byte[]) bytes; + } else { + throw new IllegalStateException("Unknown type for binary field. Type name: " + bytes.getClass().getName()); + } + } + + @Override + public CalendarInterval getInterval(int ordinal) { + throw new UnsupportedOperationException("Unsupported type: interval"); + } + + @Override + public InternalRow getStruct(int ordinal, int numFields) { + return isNullAt(ordinal) ? null : getStructInternal(ordinal, numFields); + } + + private InternalRow getStructInternal(int ordinal, int numFields) { + return new StructInternalRow( + type.fields().get(ordinal).type().asStructType(), + struct.get(ordinal, StructLike.class)); + } + + @Override + public ArrayData getArray(int ordinal) { + return isNullAt(ordinal) ? null : getArrayInternal(ordinal); + } + + private ArrayData getArrayInternal(int ordinal) { + return collectionToArrayData( + type.fields().get(ordinal).type().asListType().elementType(), + struct.get(ordinal, Collection.class)); + } + + @Override + public MapData getMap(int ordinal) { + return isNullAt(ordinal) ? 
null : getMapInternal(ordinal); + } + + private MapData getMapInternal(int ordinal) { + return mapToMapData(type.fields().get(ordinal).type().asMapType(), struct.get(ordinal, Map.class)); + } + + @Override + @SuppressWarnings("checkstyle:CyclomaticComplexity") + public Object get(int ordinal, DataType dataType) { + if (isNullAt(ordinal)) { + return null; + } + + if (dataType instanceof IntegerType) { + return getInt(ordinal); + } else if (dataType instanceof LongType) { + return getLong(ordinal); + } else if (dataType instanceof StringType) { + return getUTF8StringInternal(ordinal); + } else if (dataType instanceof FloatType) { + return getFloat(ordinal); + } else if (dataType instanceof DoubleType) { + return getDouble(ordinal); + } else if (dataType instanceof DecimalType) { + DecimalType decimalType = (DecimalType) dataType; + return getDecimalInternal(ordinal, decimalType.precision(), decimalType.scale()); + } else if (dataType instanceof BinaryType) { + return getBinaryInternal(ordinal); + } else if (dataType instanceof StructType) { + return getStructInternal(ordinal, ((StructType) dataType).size()); + } else if (dataType instanceof ArrayType) { + return getArrayInternal(ordinal); + } else if (dataType instanceof MapType) { + return getMapInternal(ordinal); + } else if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } else if (dataType instanceof ByteType) { + return getByte(ordinal); + } else if (dataType instanceof ShortType) { + return getShort(ordinal); + } else if (dataType instanceof DateType) { + return getInt(ordinal); + } else if (dataType instanceof TimestampType) { + return getLong(ordinal); + } + return null; + } + + private MapData mapToMapData(Types.MapType mapType, Map map) { + // make a defensive copy to ensure entries do not change + List> entries = ImmutableList.copyOf(map.entrySet()); + return new ArrayBasedMapData( + collectionToArrayData(mapType.keyType(), Lists.transform(entries, Map.Entry::getKey)), + collectionToArrayData(mapType.valueType(), Lists.transform(entries, Map.Entry::getValue))); + } + + private ArrayData collectionToArrayData(Type elementType, Collection values) { + switch (elementType.typeId()) { + case BOOLEAN: + case INTEGER: + case DATE: + case TIME: + case LONG: + case TIMESTAMP: + case FLOAT: + case DOUBLE: + return fillArray(values, array -> (pos, value) -> array[pos] = value); + case STRING: + return fillArray(values, array -> + (BiConsumer) (pos, seq) -> array[pos] = UTF8String.fromString(seq.toString())); + case FIXED: + case BINARY: + return fillArray(values, array -> + (BiConsumer) (pos, buf) -> array[pos] = ByteBuffers.toByteArray(buf)); + case DECIMAL: + return fillArray(values, array -> + (BiConsumer) (pos, dec) -> array[pos] = Decimal.apply(dec)); + case STRUCT: + return fillArray(values, array -> (BiConsumer) (pos, tuple) -> + array[pos] = new StructInternalRow(elementType.asStructType(), tuple)); + case LIST: + return fillArray(values, array -> (BiConsumer>) (pos, list) -> + array[pos] = collectionToArrayData(elementType.asListType().elementType(), list)); + case MAP: + return fillArray(values, array -> (BiConsumer>) (pos, map) -> + array[pos] = mapToMapData(elementType.asMapType(), map)); + default: + throw new UnsupportedOperationException("Unsupported array element type: " + elementType); + } + } + + @SuppressWarnings("unchecked") + private GenericArrayData fillArray(Collection values, Function> makeSetter) { + Object[] array = new Object[values.size()]; + BiConsumer setter = makeSetter.apply(array); + + int 
index = 0; + for (Object value : values) { + if (value == null) { + array[index] = null; + } else { + setter.accept(index, (T) value); + } + + index += 1; + } + + return new GenericArrayData(array); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/spark/sql/catalyst/analysis/NoSuchProcedureException.java b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/catalyst/analysis/NoSuchProcedureException.java new file mode 100644 index 000000000000..25f6368cc8c3 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/catalyst/analysis/NoSuchProcedureException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis; + +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.connector.catalog.Identifier; +import scala.Option; + +public class NoSuchProcedureException extends AnalysisException { + public NoSuchProcedureException(Identifier ident) { + super("Procedure " + ident + " not found", Option.empty(), Option.empty(), Option.empty(), + Option.empty(), Option.empty(), new String[0]); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/Procedure.java b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/Procedure.java new file mode 100644 index 000000000000..8f7a70b9f9fc --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/Procedure.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.connector.iceberg.catalog; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; + +/** + * An interface representing a stored procedure available for execution. + */ +public interface Procedure { + /** + * Returns the input parameters of this procedure. + */ + ProcedureParameter[] parameters(); + + /** + * Returns the type of rows produced by this procedure. 
+ */ + StructType outputType(); + + /** + * Executes this procedure. + *
<p>
+ * Spark will align the provided arguments according to the input parameters + * defined in {@link #parameters()} either by position or by name before execution. + *
<p>
+ * Implementations may provide a summary of execution by returning one or many rows + * as a result. The schema of output rows must match the defined output type + * in {@link #outputType()}. + * + * @param args input arguments + * @return the result of executing this procedure with the given arguments + */ + InternalRow[] call(InternalRow args); + + /** + * Returns the description of this procedure. + */ + default String description() { + return this.getClass().toString(); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureCatalog.java b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureCatalog.java new file mode 100644 index 000000000000..314bd659460e --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureCatalog.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.connector.iceberg.catalog; + +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; + +/** + * A catalog API for working with stored procedures. + *
<p>
+ * Implementations should implement this interface if they expose stored procedures that + * can be called via CALL statements. + */ +public interface ProcedureCatalog extends CatalogPlugin { + /** + * Load a {@link Procedure stored procedure} by {@link Identifier identifier}. + * + * @param ident a stored procedure identifier + * @return the stored procedure's metadata + * @throws NoSuchProcedureException if there is no matching stored procedure + */ + Procedure loadProcedure(Identifier ident) throws NoSuchProcedureException; +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureParameter.java b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureParameter.java new file mode 100644 index 000000000000..b341dc1e3282 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureParameter.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.connector.iceberg.catalog; + +import org.apache.spark.sql.types.DataType; + +/** + * An input parameter of a {@link Procedure stored procedure}. + */ +public interface ProcedureParameter { + + /** + * Creates a required input parameter. + * + * @param name the name of the parameter + * @param dataType the type of the parameter + * @return the constructed stored procedure parameter + */ + static ProcedureParameter required(String name, DataType dataType) { + return new ProcedureParameterImpl(name, dataType, true); + } + + /** + * Creates an optional input parameter. + * + * @param name the name of the parameter. + * @param dataType the type of the parameter. + * @return the constructed optional stored procedure parameter + */ + static ProcedureParameter optional(String name, DataType dataType) { + return new ProcedureParameterImpl(name, dataType, false); + } + + /** + * Returns the name of this parameter. + */ + String name(); + + /** + * Returns the type of this parameter. + */ + DataType dataType(); + + /** + * Returns true if this parameter is required. + */ + boolean required(); +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureParameterImpl.java b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureParameterImpl.java new file mode 100644 index 000000000000..cea1e80f4051 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureParameterImpl.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
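To make the Procedure, ProcedureParameter and ProcedureCatalog contracts above concrete, here is a minimal hypothetical procedure; the class name, parameter and output column are invented for illustration. A ProcedureCatalog would return it from loadProcedure so that it could be invoked through a CALL statement.

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.connector.iceberg.catalog.Procedure;
import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

class EchoProcedure implements Procedure {
  // one required string parameter, aligned by position or name by Spark before call()
  private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[] {
      ProcedureParameter.required("message", DataTypes.StringType)
  };

  // a single output column describing the produced rows
  private static final StructType OUTPUT_TYPE = new StructType(new StructField[] {
      new StructField("echoed", DataTypes.StringType, false, Metadata.empty())
  });

  @Override
  public ProcedureParameter[] parameters() {
    return PARAMETERS;
  }

  @Override
  public StructType outputType() {
    return OUTPUT_TYPE;
  }

  @Override
  public InternalRow[] call(InternalRow args) {
    // echo the single argument back as one result row matching outputType()
    return new InternalRow[] { new GenericInternalRow(new Object[] { args.getUTF8String(0) }) };
  }
}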
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.connector.iceberg.catalog; + +import java.util.Objects; +import org.apache.spark.sql.types.DataType; + +/** + * A {@link ProcedureParameter} implementation. + */ +class ProcedureParameterImpl implements ProcedureParameter { + private final String name; + private final DataType dataType; + private final boolean required; + + ProcedureParameterImpl(String name, DataType dataType, boolean required) { + this.name = name; + this.dataType = dataType; + this.required = required; + } + + @Override + public String name() { + return name; + } + + @Override + public DataType dataType() { + return dataType; + } + + @Override + public boolean required() { + return required; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } else if (other == null || getClass() != other.getClass()) { + return false; + } + + ProcedureParameterImpl that = (ProcedureParameterImpl) other; + return required == that.required && + Objects.equals(name, that.name) && + Objects.equals(dataType, that.dataType); + } + + @Override + public int hashCode() { + return Objects.hash(name, dataType, required); + } + + @Override + public String toString() { + return String.format("ProcedureParameter(name='%s', type=%s, required=%b)", name, dataType, required); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/DeltaBatchWrite.java b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/DeltaBatchWrite.java new file mode 100644 index 000000000000..a1fbf0fb8a24 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/DeltaBatchWrite.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.connector.iceberg.write; + +import org.apache.spark.sql.connector.write.BatchWrite; +import org.apache.spark.sql.connector.write.PhysicalWriteInfo; + +/** + * An interface that defines how to write a delta of rows during batch processing. 
+ */ +public interface DeltaBatchWrite extends BatchWrite { + @Override + DeltaWriterFactory createBatchWriterFactory(PhysicalWriteInfo info); +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/DeltaWrite.java b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/DeltaWrite.java new file mode 100644 index 000000000000..7643e2d103ec --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/DeltaWrite.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.connector.iceberg.write; + +import org.apache.spark.sql.connector.write.Write; + +/** + * A logical representation of a data source write that handles a delta of rows. + * A delta of rows is a set of instructions that indicate which records need to be deleted, + * updated, or inserted. Data sources that support deltas allow Spark to discard unchanged rows + * and pass only the information about what rows have changed during a row-level operation. + */ +public interface DeltaWrite extends Write { + @Override + DeltaBatchWrite toBatch(); +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/DeltaWriteBuilder.java b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/DeltaWriteBuilder.java new file mode 100644 index 000000000000..1af1c6680c89 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/DeltaWriteBuilder.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.connector.iceberg.write; + +import org.apache.spark.sql.connector.write.WriteBuilder; + +/** + * An interface for building delta writes. + */ +public interface DeltaWriteBuilder extends WriteBuilder { + /** + * Returns a logical delta write. 
+ */ + @Override + default DeltaWrite build() { + throw new UnsupportedOperationException("Not implemented: build"); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/DeltaWriter.java b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/DeltaWriter.java new file mode 100644 index 000000000000..a17ee33b13d7 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/DeltaWriter.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.connector.iceberg.write; + +import java.io.IOException; +import org.apache.spark.sql.connector.write.DataWriter; + +/** + * A data writer responsible for writing a delta of rows. + */ +public interface DeltaWriter extends DataWriter { + /** + * Passes information for a row that must be deleted. + * + * @param metadata values for metadata columns that were projected but are not part of the row ID + * @param id a row ID to delete + * @throws IOException if the write process encounters an error + */ + void delete(T metadata, T id) throws IOException; + + /** + * Passes information for a row that must be updated together with the updated row. + * + * @param metadata values for metadata columns that were projected but are not part of the row ID + * @param id a row ID to update + * @param row a row with updated values + * @throws IOException if the write process encounters an error + */ + void update(T metadata, T id, T row) throws IOException; + + /** + * Passes a row to insert. + * + * @param row a row to insert + * @throws IOException if the write process encounters an error + */ + void insert(T row) throws IOException; + + @Override + default void write(T row) throws IOException { + insert(row); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/DeltaWriterFactory.java b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/DeltaWriterFactory.java new file mode 100644 index 000000000000..77af70958e41 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/DeltaWriterFactory.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.connector.iceberg.write; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.write.DataWriterFactory; + +/** + * A factory for creating and initializing delta writers at the executor side. + */ +public interface DeltaWriterFactory extends DataWriterFactory { + @Override + DeltaWriter createWriter(int partitionId, long taskId); +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/ExtendedLogicalWriteInfo.java b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/ExtendedLogicalWriteInfo.java new file mode 100644 index 000000000000..3c39a4f0f1b5 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/ExtendedLogicalWriteInfo.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.connector.iceberg.write; + +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.types.StructType; + +/** + * A class that holds logical write information not covered by LogicalWriteInfo in Spark. + */ +public interface ExtendedLogicalWriteInfo extends LogicalWriteInfo { + /** + * The schema of the input metadata from Spark to data source. + */ + StructType metadataSchema(); + + /** + * The schema of the ID columns from Spark to data source. + */ + StructType rowIdSchema(); +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/SupportsDelta.java b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/SupportsDelta.java new file mode 100644 index 000000000000..2b0926326470 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/spark/sql/connector/iceberg/write/SupportsDelta.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.connector.iceberg.write; + +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.RowLevelOperation; + +/** + * A mix-in interface for RowLevelOperation. Data sources can implement this interface + * to indicate they support handling deltas of rows. + */ +public interface SupportsDelta extends RowLevelOperation { + @Override + DeltaWriteBuilder newWriteBuilder(LogicalWriteInfo info); + + /** + * Returns the row ID column references that should be used for row equality. + */ + NamedReference[] rowId(); +} diff --git a/spark/v3.3/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/v3.3/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 000000000000..01a6c4e0670d --- /dev/null +++ b/spark/v3.3/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +org.apache.iceberg.spark.source.IcebergSource diff --git a/spark/v3.3/spark/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpressions.scala b/spark/v3.3/spark/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpressions.scala new file mode 100644 index 000000000000..2d2c2fb9aae5 --- /dev/null +++ b/spark/v3.3/spark/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpressions.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
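The delta write interfaces above let Spark hand a writer only the rows that changed. A schematic DeltaWriter with no real file I/O, just to show how delete, update and insert calls are expected to flow; the update-as-delete-plus-insert strategy is one possible choice, not the only one:

import java.io.IOException;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.iceberg.write.DeltaWriter;
import org.apache.spark.sql.connector.write.WriterCommitMessage;

class SketchDeltaWriter implements DeltaWriter<InternalRow> {
  @Override
  public void delete(InternalRow metadata, InternalRow id) throws IOException {
    // a real writer would record a delete for the row identified by `id`,
    // using `metadata` (e.g. file or partition information) to route it
  }

  @Override
  public void update(InternalRow metadata, InternalRow id, InternalRow row) throws IOException {
    // one strategy: treat an update as a delete of the old row plus an insert of the new one
    delete(metadata, id);
    insert(row);
  }

  @Override
  public void insert(InternalRow row) throws IOException {
    // a real writer would buffer `row` into a data file
  }

  @Override
  public WriterCommitMessage commit() throws IOException {
    // report the files produced by this task; nothing to report in this sketch
    return new WriterCommitMessage() { };
  }

  @Override
  public void abort() throws IOException {
    // drop any partially written files
  }

  @Override
  public void close() throws IOException {
  }
}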
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.nio.ByteBuffer +import java.nio.CharBuffer +import java.nio.charset.StandardCharsets +import org.apache.iceberg.spark.SparkSchemaUtil +import org.apache.iceberg.transforms.Transform +import org.apache.iceberg.transforms.Transforms +import org.apache.iceberg.types.Type +import org.apache.iceberg.types.Types +import org.apache.iceberg.util.ByteBuffers +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.types.AbstractDataType +import org.apache.spark.sql.types.BinaryType +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.TimestampType +import org.apache.spark.unsafe.types.UTF8String + +abstract class IcebergTransformExpression + extends UnaryExpression with CodegenFallback with NullIntolerant { + + @transient lazy val icebergInputType: Type = SparkSchemaUtil.convert(child.dataType) +} + +abstract class IcebergTimeTransform + extends IcebergTransformExpression with ImplicitCastInputTypes { + + def transform: Transform[Any, Integer] + + override protected def nullSafeEval(value: Any): Any = { + transform(value).toInt + } + + override def dataType: DataType = IntegerType + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) +} + +case class IcebergYearTransform(child: Expression) + extends IcebergTimeTransform { + + @transient lazy val transform: Transform[Any, Integer] = Transforms.year[Any](icebergInputType) + + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(child = newChild) + } +} + +case class IcebergMonthTransform(child: Expression) + extends IcebergTimeTransform { + + @transient lazy val transform: Transform[Any, Integer] = Transforms.month[Any](icebergInputType) + + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(child = newChild) + } +} + +case class IcebergDayTransform(child: Expression) + extends IcebergTimeTransform { + + @transient lazy val transform: Transform[Any, Integer] = Transforms.day[Any](icebergInputType) + + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(child = newChild) + } +} + +case class IcebergHourTransform(child: Expression) + extends IcebergTimeTransform { + + @transient lazy val transform: Transform[Any, Integer] = Transforms.hour[Any](icebergInputType) + + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(child = newChild) + } +} + +case class IcebergBucketTransform(numBuckets: Int, child: Expression) extends IcebergTransformExpression { + + @transient lazy val bucketFunc: Any => Int = child.dataType match { + case _: DecimalType => + val t = Transforms.bucket[Any](icebergInputType, numBuckets) + d: Any => t(d.asInstanceOf[Decimal].toJavaBigDecimal).toInt + case _: StringType => + // the spec requires that the hash of a string is equal to the hash of its UTF-8 encoded bytes + // TODO: pass bytes without the copy out of the InternalRow + val t = Transforms.bucket[ByteBuffer](Types.BinaryType.get(), numBuckets) + s: Any => t(ByteBuffer.wrap(s.asInstanceOf[UTF8String].getBytes)).toInt + case _ => + val t = Transforms.bucket[Any](icebergInputType, 
numBuckets) + a: Any => t(a).toInt + } + + override protected def nullSafeEval(value: Any): Any = { + bucketFunc(value) + } + + override def dataType: DataType = IntegerType + + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(child = newChild) + } +} + +case class IcebergTruncateTransform(child: Expression, width: Int) extends IcebergTransformExpression { + + @transient lazy val truncateFunc: Any => Any = child.dataType match { + case _: DecimalType => + val t = Transforms.truncate[java.math.BigDecimal](icebergInputType, width) + d: Any => Decimal.apply(t(d.asInstanceOf[Decimal].toJavaBigDecimal)) + case _: StringType => + val t = Transforms.truncate[CharSequence](icebergInputType, width) + s: Any => { + val charSequence = t(StandardCharsets.UTF_8.decode(ByteBuffer.wrap(s.asInstanceOf[UTF8String].getBytes))) + val bb = StandardCharsets.UTF_8.encode(CharBuffer.wrap(charSequence)); + UTF8String.fromBytes(ByteBuffers.toByteArray(bb)) + } + case _: BinaryType => + val t = Transforms.truncate[ByteBuffer](icebergInputType, width) + s: Any => ByteBuffers.toByteArray(t(ByteBuffer.wrap(s.asInstanceOf[Array[Byte]]))) + case _ => + val t = Transforms.truncate[Any](icebergInputType, width) + a: Any => t(a) + } + + override protected def nullSafeEval(value: Any): Any = { + truncateFunc(value) + } + + override def dataType: DataType = child.dataType + + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(child = newChild) + } +} diff --git a/spark/v3.3/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetWriteDistributionAndOrdering.scala b/spark/v3.3/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetWriteDistributionAndOrdering.scala new file mode 100644 index 000000000000..0a0234cdfe34 --- /dev/null +++ b/spark/v3.3/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetWriteDistributionAndOrdering.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.iceberg.DistributionMode +import org.apache.iceberg.NullOrder +import org.apache.iceberg.SortDirection +import org.apache.iceberg.expressions.Term +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits + +case class SetWriteDistributionAndOrdering( + table: Seq[String], + distributionMode: DistributionMode, + sortOrder: Seq[(Term, SortDirection, NullOrder)]) extends LeafCommand { + + import CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + val order = sortOrder.map { + case (term, direction, nullOrder) => s"$term $direction $nullOrder" + }.mkString(", ") + s"SetWriteDistributionAndOrdering ${table.quoted} $distributionMode $order" + } +} diff --git a/spark/v3.3/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SortOrderParserUtil.scala b/spark/v3.3/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SortOrderParserUtil.scala new file mode 100644 index 000000000000..bf19ef8a2167 --- /dev/null +++ b/spark/v3.3/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SortOrderParserUtil.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.iceberg.NullOrder +import org.apache.iceberg.Schema +import org.apache.iceberg.SortDirection +import org.apache.iceberg.SortOrder +import org.apache.iceberg.expressions.Term + +class SortOrderParserUtil { + + def collectSortOrder(tableSchema:Schema, sortOrder: Seq[(Term, SortDirection, NullOrder)]): SortOrder = { + val orderBuilder = SortOrder.builderFor(tableSchema) + sortOrder.foreach { + case (term, SortDirection.ASC, nullOrder) => + orderBuilder.asc(term, nullOrder) + case (term, SortDirection.DESC, nullOrder) => + orderBuilder.desc(term, nullOrder) + } + orderBuilder.build(); + } +} diff --git a/spark/v3.3/spark/src/main/scala/org/apache/spark/sql/catalyst/utils/DistributionAndOrderingUtils.scala b/spark/v3.3/spark/src/main/scala/org/apache/spark/sql/catalyst/utils/DistributionAndOrderingUtils.scala new file mode 100644 index 000000000000..94b6f651a0df --- /dev/null +++ b/spark/v3.3/spark/src/main/scala/org/apache/spark/sql/catalyst/utils/DistributionAndOrderingUtils.scala @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
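The SetWriteDistributionAndOrdering plan above is the logical form behind the write-ordering DDL added by the Iceberg SQL extensions. The statement below mirrors the syntax already supported by the Spark 3.2 module; the catalog and table names are hypothetical.

import org.apache.spark.sql.SparkSession;

class WriteOrderExample {
  // Requests that future writes to the table be sorted by category, then id.
  static void setWriteOrder(SparkSession spark) {
    spark.sql("ALTER TABLE local.db.events WRITE ORDERED BY category, id");
  }
}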
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.utils + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.IcebergBucketTransform +import org.apache.spark.sql.catalyst.expressions.IcebergDayTransform +import org.apache.spark.sql.catalyst.expressions.IcebergHourTransform +import org.apache.spark.sql.catalyst.expressions.IcebergMonthTransform +import org.apache.spark.sql.catalyst.expressions.IcebergTruncateTransform +import org.apache.spark.sql.catalyst.expressions.IcebergYearTransform +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.RepartitionByExpression +import org.apache.spark.sql.catalyst.plans.logical.Sort +import org.apache.spark.sql.connector.distributions.ClusteredDistribution +import org.apache.spark.sql.connector.distributions.Distribution +import org.apache.spark.sql.connector.distributions.OrderedDistribution +import org.apache.spark.sql.connector.distributions.UnspecifiedDistribution +import org.apache.spark.sql.connector.expressions.ApplyTransform +import org.apache.spark.sql.connector.expressions.BucketTransform +import org.apache.spark.sql.connector.expressions.DaysTransform +import org.apache.spark.sql.connector.expressions.Expression +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.HoursTransform +import org.apache.spark.sql.connector.expressions.IdentityTransform +import org.apache.spark.sql.connector.expressions.Literal +import org.apache.spark.sql.connector.expressions.MonthsTransform +import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.expressions.NullOrdering +import org.apache.spark.sql.connector.expressions.SortDirection +import org.apache.spark.sql.connector.expressions.SortOrder +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.expressions.YearsTransform +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.IntegerType +import scala.collection.compat.immutable.ArraySeq + +object DistributionAndOrderingUtils { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + def prepareQuery( + requiredDistribution: Distribution, + requiredOrdering: Array[SortOrder], + query: LogicalPlan, + conf: SQLConf): LogicalPlan = { + + val resolver = conf.resolver + + val distribution = requiredDistribution match { + case d: OrderedDistribution => + d.ordering.map(e => toCatalyst(e, query, resolver)) + case d: ClusteredDistribution => + d.clustering.map(e => toCatalyst(e, query, resolver)) + case _: UnspecifiedDistribution => + Array.empty[catalyst.expressions.Expression] + } + + val 
queryWithDistribution = if (distribution.nonEmpty) { + // the conversion to catalyst expressions above produces SortOrder expressions + // for OrderedDistribution and generic expressions for ClusteredDistribution + // this allows RepartitionByExpression to pick either range or hash partitioning + RepartitionByExpression(distribution.toSeq, query, None) + } else { + query + } + + val ordering = requiredOrdering + .map(e => toCatalyst(e, query, resolver).asInstanceOf[catalyst.expressions.SortOrder]) + + val queryWithDistributionAndOrdering = if (ordering.nonEmpty) { + Sort(ArraySeq.unsafeWrapArray(ordering), global = false, queryWithDistribution) + } else { + queryWithDistribution + } + + queryWithDistributionAndOrdering + } + + private def toCatalyst( + expr: Expression, + query: LogicalPlan, + resolver: Resolver): catalyst.expressions.Expression = { + + // we cannot perform the resolution in the analyzer since we need to optimize expressions + // in nodes like OverwriteByExpression before constructing a logical write + def resolve(parts: Seq[String]): NamedExpression = { + query.resolve(parts, resolver) match { + case Some(attr) => + attr + case None => + val ref = parts.quoted + throw new AnalysisException(s"Cannot resolve '$ref' using ${query.output}") + } + } + + expr match { + case s: SortOrder => + val catalystChild = toCatalyst(s.expression(), query, resolver) + catalyst.expressions.SortOrder(catalystChild, toCatalyst(s.direction), toCatalyst(s.nullOrdering), Seq.empty) + case it: IdentityTransform => + resolve(ArraySeq.unsafeWrapArray(it.ref.fieldNames)) + case BucketTransform(numBuckets, ref) => + IcebergBucketTransform(numBuckets, resolve(ArraySeq.unsafeWrapArray(ref.fieldNames))) + case TruncateTransform(ref, width) => + IcebergTruncateTransform(resolve(ArraySeq.unsafeWrapArray(ref.fieldNames)), width) + case yt: YearsTransform => + IcebergYearTransform(resolve(ArraySeq.unsafeWrapArray(yt.ref.fieldNames))) + case mt: MonthsTransform => + IcebergMonthTransform(resolve(ArraySeq.unsafeWrapArray(mt.ref.fieldNames))) + case dt: DaysTransform => + IcebergDayTransform(resolve(ArraySeq.unsafeWrapArray(dt.ref.fieldNames))) + case ht: HoursTransform => + IcebergHourTransform(resolve(ArraySeq.unsafeWrapArray(ht.ref.fieldNames))) + case ref: FieldReference => + resolve(ArraySeq.unsafeWrapArray(ref.fieldNames)) + case _ => + throw new RuntimeException(s"$expr is not currently supported") + + } + } + + private def toCatalyst(direction: SortDirection): catalyst.expressions.SortDirection = { + direction match { + case SortDirection.ASCENDING => catalyst.expressions.Ascending + case SortDirection.DESCENDING => catalyst.expressions.Descending + } + } + + private def toCatalyst(nullOrdering: NullOrdering): catalyst.expressions.NullOrdering = { + nullOrdering match { + case NullOrdering.NULLS_FIRST => catalyst.expressions.NullsFirst + case NullOrdering.NULLS_LAST => catalyst.expressions.NullsLast + } + } + + private object BucketTransform { + def unapply(transform: Transform): Option[(Int, FieldReference)] = transform match { + case bt: BucketTransform => bt.columns match { + case Seq(nf: NamedReference) => + Some(bt.numBuckets.value(), FieldReference(ArraySeq.unsafeWrapArray(nf.fieldNames()))) + case _ => + None + } + case _ => None + } + } + + private object Lit { + def unapply[T](literal: Literal[T]): Some[(T, DataType)] = { + Some((literal.value, literal.dataType)) + } + } + + private object TruncateTransform { + def unapply(transform: Transform): Option[(FieldReference, Int)] = transform match 
{ + case at @ ApplyTransform(name, _) if name.equalsIgnoreCase("truncate") => at.args match { + case Seq(nf: NamedReference, Lit(value: Int, IntegerType)) => + Some(FieldReference(ArraySeq.unsafeWrapArray(nf.fieldNames())), value) + case Seq(Lit(value: Int, IntegerType), nf: NamedReference) => + Some(FieldReference(ArraySeq.unsafeWrapArray(nf.fieldNames())), value) + case _ => + None + } + case _ => None + } + } +} diff --git a/spark/v3.3/spark/src/main/scala/org/apache/spark/sql/catalyst/utils/PlanUtils.scala b/spark/v3.3/spark/src/main/scala/org/apache/spark/sql/catalyst/utils/PlanUtils.scala new file mode 100644 index 000000000000..aa9e9c553346 --- /dev/null +++ b/spark/v3.3/spark/src/main/scala/org/apache/spark/sql/catalyst/utils/PlanUtils.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.utils + +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import scala.annotation.tailrec + +object PlanUtils { + @tailrec + def isIcebergRelation(plan: LogicalPlan): Boolean = { + def isIcebergTable(relation: DataSourceV2Relation): Boolean = relation.table match { + case _: SparkTable => true + case _ => false + } + + plan match { + case s: SubqueryAlias => isIcebergRelation(s.child) + case r: DataSourceV2Relation => isIcebergTable(r) + case _ => false + } + } +} diff --git a/spark/v3.3/spark/src/main/scala/org/apache/spark/sql/execution/datasources/SparkExpressionConverter.scala b/spark/v3.3/spark/src/main/scala/org/apache/spark/sql/execution/datasources/SparkExpressionConverter.scala new file mode 100644 index 000000000000..c4b5a7c0ce14 --- /dev/null +++ b/spark/v3.3/spark/src/main/scala/org/apache/spark/sql/execution/datasources/SparkExpressionConverter.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.iceberg.spark.SparkFilters
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.execution.CommandExecutionMode
+
+object SparkExpressionConverter {
+
+  def convertToIcebergExpression(sparkExpression: Expression): org.apache.iceberg.expressions.Expression = {
+    // This is currently a double conversion: the Spark expression is converted to a Spark filter,
+    // and the Spark filter is then converted to an Iceberg expression. Both conversions already
+    // exist and are well tested, so we go with this approach.
+    SparkFilters.convert(DataSourceStrategy.translateFilter(sparkExpression, supportNestedPredicatePushdown = true).get)
+  }
+
+  @throws[AnalysisException]
+  def collectResolvedSparkExpression(session: SparkSession, tableName: String, where: String): Expression = {
+    var expression: Expression = null
+    // Prefix the predicate with a dummy SELECT on the table so the resolved Spark expression
+    // can be collected from the optimized plan.
+    val prefix = String.format("SELECT 42 from %s where ", tableName)
+    val logicalPlan = session.sessionState.sqlParser.parsePlan(prefix + where)
+    val optimizedLogicalPlan = session.sessionState.executePlan(logicalPlan, CommandExecutionMode.ALL).optimizedPlan
+    optimizedLogicalPlan.collectFirst {
+      case filter: Filter =>
+        expression = filter.expressions.head
+    }
+    expression
+  }
+}
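For illustration, a minimal sketch of how these two helpers can be combined from the Java side; the SparkSession setup, the table name (db.sample), and the predicate are hypothetical and serve only as an example:

import org.apache.iceberg.expressions.Expression;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.execution.datasources.SparkExpressionConverter;

public class ExpressionConverterExample {
  public static void main(String[] args) throws Exception {
    // hypothetical local session; any existing session that can resolve the table would do
    SparkSession spark = SparkSession.builder()
        .master("local[2]")
        .getOrCreate();

    // resolve the WHERE clause against the (hypothetical) table db.sample,
    // then run the double conversion described in the converter above
    org.apache.spark.sql.catalyst.expressions.Expression sparkExpr =
        SparkExpressionConverter.collectResolvedSparkExpression(spark, "db.sample", "id > 10");
    Expression icebergExpr = SparkExpressionConverter.convertToIcebergExpression(sparkExpr);

    System.out.println(icebergExpr);
    spark.stop();
  }
}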
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/KryoHelpers.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/KryoHelpers.java
new file mode 100644
index 000000000000..ee0f0a73959a
--- /dev/null
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/KryoHelpers.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import org.apache.spark.SparkConf;
+import org.apache.spark.serializer.KryoSerializer;
+
+public class KryoHelpers {
+
+  private KryoHelpers() {
+  }
+
+  @SuppressWarnings("unchecked")
+  public static <T> T roundTripSerialize(T obj) throws IOException {
+    Kryo kryo = new KryoSerializer(new SparkConf()).newKryo();
+
+    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
+
+    try (Output out = new Output(new ObjectOutputStream(bytes))) {
+      kryo.writeClassAndObject(out, obj);
+    }
+
+    try (Input in = new Input(new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray())))) {
+      return (T) kryo.readClassAndObject(in);
+    }
+  }
+}
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/TaskCheckHelper.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/TaskCheckHelper.java
new file mode 100644
index 000000000000..dd187c0b6bf0
--- /dev/null
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/TaskCheckHelper.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg;
+
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.junit.Assert;
+
+public final class TaskCheckHelper {
+  private TaskCheckHelper() {
+  }
+
+  public static void assertEquals(ScanTaskGroup<FileScanTask> expected, ScanTaskGroup<FileScanTask> actual) {
+    List<FileScanTask> expectedTasks = getFileScanTasksInFilePathOrder(expected);
+    List<FileScanTask> actualTasks = getFileScanTasksInFilePathOrder(actual);
+
+    Assert.assertEquals("The number of file scan tasks should match",
+        expectedTasks.size(), actualTasks.size());
+
+    for (int i = 0; i < expectedTasks.size(); i++) {
+      FileScanTask expectedTask = expectedTasks.get(i);
+      FileScanTask actualTask = actualTasks.get(i);
+      assertEquals(expectedTask, actualTask);
+    }
+  }
+
+  public static void assertEquals(FileScanTask expected, FileScanTask actual) {
+    assertEquals(expected.file(), actual.file());
+
+    // PartitionSpec implements its own equals method
+    Assert.assertEquals("PartitionSpec doesn't match", expected.spec(), actual.spec());
+
+    Assert.assertEquals("starting position doesn't match", expected.start(), actual.start());
+
+    Assert.assertEquals("the number of bytes to scan doesn't match", expected.length(), actual.length());
+
+    // simplify comparison on residual expression via comparing toString
+    Assert.assertEquals("Residual expression doesn't match",
+        expected.residual().toString(), actual.residual().toString());
+  }
+
+  public static void assertEquals(DataFile expected, DataFile actual) {
+    Assert.assertEquals("Should match the serialized record path",
+        expected.path(), actual.path());
+    Assert.assertEquals("Should match the serialized record format",
+        expected.format(), actual.format());
+    Assert.assertEquals("Should match the serialized record partition",
+        expected.partition().get(0, Object.class), actual.partition().get(0, Object.class));
+    Assert.assertEquals("Should match the serialized record count",
+        expected.recordCount(), actual.recordCount());
+    Assert.assertEquals("Should match the serialized record size",
+        expected.fileSizeInBytes(), actual.fileSizeInBytes());
+    Assert.assertEquals("Should match the serialized record value counts",
+        expected.valueCounts(), actual.valueCounts());
+    Assert.assertEquals("Should match the serialized record null value counts",
+        expected.nullValueCounts(), actual.nullValueCounts());
+    Assert.assertEquals("Should match the serialized record lower bounds",
+        expected.lowerBounds(), actual.lowerBounds());
+    Assert.assertEquals("Should match the serialized record upper bounds",
+        expected.upperBounds(), actual.upperBounds());
+    Assert.assertEquals("Should match the serialized record key metadata",
+        expected.keyMetadata(), actual.keyMetadata());
+    Assert.assertEquals("Should match the serialized record offsets",
+        expected.splitOffsets(), actual.splitOffsets());
+  }
+
+  private static List<FileScanTask> getFileScanTasksInFilePathOrder(ScanTaskGroup<FileScanTask> taskGroup) {
+    return taskGroup.tasks().stream()
+        // use file path + start position to differentiate the tasks
+        .sorted(Comparator.comparing(o -> o.file().path().toString() + "##" + o.start()))
+        .collect(Collectors.toList());
+  }
+}
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/TestDataFileSerialization.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/TestDataFileSerialization.java
new file mode 100644
index 000000000000..12fa8b2fc539
--- /dev/null
+++
b/spark/v3.3/spark/src/test/java/org/apache/iceberg/TestDataFileSerialization.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Map; +import java.util.UUID; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkParquetWriters; +import org.apache.iceberg.types.Types; +import org.apache.spark.SparkConf; +import org.apache.spark.serializer.KryoSerializer; +import org.apache.spark.sql.catalyst.InternalRow; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.TaskCheckHelper.assertEquals; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +public class TestDataFileSerialization { + + private static final Schema DATE_SCHEMA = new Schema( + required(1, "id", Types.LongType.get()), + optional(2, "data", Types.StringType.get()), + required(3, "date", Types.StringType.get()), + optional(4, "double", Types.DoubleType.get())); + + private static final PartitionSpec PARTITION_SPEC = PartitionSpec + .builderFor(DATE_SCHEMA) + .identity("date") + .build(); + + private static final Map VALUE_COUNTS = Maps.newHashMap(); + private static final Map NULL_VALUE_COUNTS = Maps.newHashMap(); + private static final Map NAN_VALUE_COUNTS = Maps.newHashMap(); + private static final Map LOWER_BOUNDS = Maps.newHashMap(); + private static final Map UPPER_BOUNDS = Maps.newHashMap(); + + static { + VALUE_COUNTS.put(1, 5L); + VALUE_COUNTS.put(2, 3L); + VALUE_COUNTS.put(4, 2L); + NULL_VALUE_COUNTS.put(1, 0L); + NULL_VALUE_COUNTS.put(2, 2L); + NAN_VALUE_COUNTS.put(4, 1L); + LOWER_BOUNDS.put(1, longToBuffer(0L)); + UPPER_BOUNDS.put(1, longToBuffer(4L)); + } + + private static final DataFile DATA_FILE = DataFiles + .builder(PARTITION_SPEC) + .withPath("/path/to/data-1.parquet") + .withFileSizeInBytes(1234) + 
.withPartitionPath("date=2018-06-08") + .withMetrics(new Metrics( + 5L, null, VALUE_COUNTS, NULL_VALUE_COUNTS, NAN_VALUE_COUNTS, LOWER_BOUNDS, UPPER_BOUNDS)) + .withSplitOffsets(ImmutableList.of(4L)) + .withEncryptionKeyMetadata(ByteBuffer.allocate(4).putInt(34)) + .withSortOrder(SortOrder.unsorted()) + .build(); + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + @Test + public void testDataFileKryoSerialization() throws Exception { + File data = temp.newFile(); + Assert.assertTrue(data.delete()); + Kryo kryo = new KryoSerializer(new SparkConf()).newKryo(); + + try (Output out = new Output(new FileOutputStream(data))) { + kryo.writeClassAndObject(out, DATA_FILE); + kryo.writeClassAndObject(out, DATA_FILE.copy()); + } + + try (Input in = new Input(new FileInputStream(data))) { + for (int i = 0; i < 2; i += 1) { + Object obj = kryo.readClassAndObject(in); + Assertions.assertThat(obj).as("Should be a DataFile").isInstanceOf(DataFile.class); + assertEquals(DATA_FILE, (DataFile) obj); + } + } + } + + @Test + public void testDataFileJavaSerialization() throws Exception { + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + try (ObjectOutputStream out = new ObjectOutputStream(bytes)) { + out.writeObject(DATA_FILE); + out.writeObject(DATA_FILE.copy()); + } + + try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) { + for (int i = 0; i < 2; i += 1) { + Object obj = in.readObject(); + Assertions.assertThat(obj).as("Should be a DataFile").isInstanceOf(DataFile.class); + assertEquals(DATA_FILE, (DataFile) obj); + } + } + } + + @Test + public void testParquetWriterSplitOffsets() throws IOException { + Iterable records = RandomData.generateSpark(DATE_SCHEMA, 1, 33L); + File parquetFile = new File( + temp.getRoot(), + FileFormat.PARQUET.addExtension(UUID.randomUUID().toString())); + FileAppender writer = + Parquet.write(Files.localOutput(parquetFile)) + .schema(DATE_SCHEMA) + .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(SparkSchemaUtil.convert(DATE_SCHEMA), msgType)) + .build(); + try { + writer.addAll(records); + } finally { + writer.close(); + } + + Kryo kryo = new KryoSerializer(new SparkConf()).newKryo(); + File dataFile = temp.newFile(); + try (Output out = new Output(new FileOutputStream(dataFile))) { + kryo.writeClassAndObject(out, writer.splitOffsets()); + } + try (Input in = new Input(new FileInputStream(dataFile))) { + kryo.readClassAndObject(in); + } + } + + private static ByteBuffer longToBuffer(long value) { + return ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putLong(0, value); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/TestFileIOSerialization.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/TestFileIOSerialization.java new file mode 100644 index 000000000000..7c14485ff9a0 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/TestFileIOSerialization.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg; + +import java.io.File; +import java.io.IOException; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.hadoop.HadoopFileIO; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Types; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +public class TestFileIOSerialization { + + private static final Configuration CONF = new Configuration(); + private static final HadoopTables TABLES = new HadoopTables(CONF); + + private static final Schema SCHEMA = new Schema( + required(1, "id", Types.LongType.get()), + optional(2, "data", Types.StringType.get()), + required(3, "date", Types.StringType.get()), + optional(4, "double", Types.DoubleType.get())); + + private static final PartitionSpec SPEC = PartitionSpec + .builderFor(SCHEMA) + .identity("date") + .build(); + + private static final SortOrder SORT_ORDER = SortOrder.builderFor(SCHEMA) + .asc("id") + .build(); + + static { + CONF.set("k1", "v1"); + CONF.set("k2", "v2"); + } + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + private Table table; + + @Before + public void initTable() throws IOException { + Map props = ImmutableMap.of("k1", "v1", "k2", "v2"); + + File tableLocation = temp.newFolder(); + Assert.assertTrue(tableLocation.delete()); + + this.table = TABLES.create(SCHEMA, SPEC, SORT_ORDER, props, tableLocation.toString()); + } + + @Test + public void testHadoopFileIOKryoSerialization() throws IOException { + FileIO io = table.io(); + Configuration expectedConf = ((HadoopFileIO) io).conf(); + + Table serializableTable = SerializableTable.copyOf(table); + FileIO deserializedIO = KryoHelpers.roundTripSerialize(serializableTable.io()); + Configuration actualConf = ((HadoopFileIO) deserializedIO).conf(); + + Assert.assertEquals("Conf pairs must match", toMap(expectedConf), toMap(actualConf)); + Assert.assertEquals("Conf values must be present", "v1", actualConf.get("k1")); + Assert.assertEquals("Conf values must be present", "v2", actualConf.get("k2")); + } + + @Test + public void testHadoopFileIOJavaSerialization() throws IOException, ClassNotFoundException { + FileIO io = table.io(); + Configuration expectedConf = ((HadoopFileIO) io).conf(); + + Table serializableTable = SerializableTable.copyOf(table); + FileIO deserializedIO = TestHelpers.roundTripSerialize(serializableTable.io()); + Configuration actualConf = ((HadoopFileIO) deserializedIO).conf(); + + Assert.assertEquals("Conf pairs must match", toMap(expectedConf), toMap(actualConf)); + Assert.assertEquals("Conf values must be present", "v1", actualConf.get("k1")); + Assert.assertEquals("Conf values must be present", "v2", 
actualConf.get("k2")); + } + + private Map toMap(Configuration conf) { + Map map = Maps.newHashMapWithExpectedSize(conf.size()); + conf.forEach(entry -> map.put(entry.getKey(), entry.getValue())); + return map; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/TestHadoopMetricsContextSerialization.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/TestHadoopMetricsContextSerialization.java new file mode 100644 index 000000000000..e1003a175f70 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/TestHadoopMetricsContextSerialization.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg; + +import java.io.IOException; +import org.apache.iceberg.hadoop.HadoopMetricsContext; +import org.apache.iceberg.io.FileIOMetricsContext; +import org.apache.iceberg.metrics.MetricsContext; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.junit.Test; + +public class TestHadoopMetricsContextSerialization { + + @Test(expected = Test.None.class) + public void testHadoopMetricsContextKryoSerialization() throws IOException { + MetricsContext metricsContext = new HadoopMetricsContext("s3"); + + metricsContext.initialize(Maps.newHashMap()); + + MetricsContext deserializedMetricContext = KryoHelpers.roundTripSerialize(metricsContext); + // statistics are properly re-initialized post de-serialization + deserializedMetricContext + .counter(FileIOMetricsContext.WRITE_BYTES, Long.class, MetricsContext.Unit.BYTES) + .increment(); + } + + @Test(expected = Test.None.class) + public void testHadoopMetricsContextJavaSerialization() throws IOException, ClassNotFoundException { + MetricsContext metricsContext = new HadoopMetricsContext("s3"); + + metricsContext.initialize(Maps.newHashMap()); + + MetricsContext deserializedMetricContext = TestHelpers.roundTripSerialize(metricsContext); + // statistics are properly re-initialized post de-serialization + deserializedMetricContext + .counter(FileIOMetricsContext.WRITE_BYTES, Long.class, MetricsContext.Unit.BYTES) + .increment(); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/TestManifestFileSerialization.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/TestManifestFileSerialization.java new file mode 100644 index 000000000000..25004aa110e4 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/TestManifestFileSerialization.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.ManifestFile.PartitionFieldSummary; +import org.apache.iceberg.hadoop.HadoopFileIO; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.types.Types; +import org.apache.spark.SparkConf; +import org.apache.spark.serializer.KryoSerializer; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +public class TestManifestFileSerialization { + + private static final Schema SCHEMA = new Schema( + required(1, "id", Types.LongType.get()), + optional(2, "data", Types.StringType.get()), + required(3, "date", Types.StringType.get()), + required(4, "double", Types.DoubleType.get())); + + private static final PartitionSpec SPEC = PartitionSpec + .builderFor(SCHEMA) + .identity("double") + .build(); + + private static final DataFile FILE_A = DataFiles.builder(SPEC) + .withPath("/path/to/data-1.parquet") + .withFileSizeInBytes(0) + .withPartition(TestHelpers.Row.of(1D)) + .withPartitionPath("double=1") + .withMetrics(new Metrics(5L, + null, // no column sizes + ImmutableMap.of(1, 5L, 2, 3L), // value count + ImmutableMap.of(1, 0L, 2, 2L), // null count + ImmutableMap.of(), // nan count + ImmutableMap.of(1, longToBuffer(0L)), // lower bounds + ImmutableMap.of(1, longToBuffer(4L)) // upper bounds + )) + .build(); + + private static final DataFile FILE_B = DataFiles.builder(SPEC) + .withPath("/path/to/data-2.parquet") + .withFileSizeInBytes(0) + .withPartition(TestHelpers.Row.of(Double.NaN)) + .withPartitionPath("double=NaN") + .withMetrics(new Metrics(1L, + null, // no column sizes + ImmutableMap.of(1, 1L, 4, 1L), // value count + ImmutableMap.of(1, 0L, 2, 0L), // null count + ImmutableMap.of(4, 1L), // nan count + ImmutableMap.of(1, longToBuffer(0L)), // lower bounds + ImmutableMap.of(1, longToBuffer(1L)) // upper bounds + )) + .build(); + + private static final FileIO FILE_IO = new HadoopFileIO(new Configuration()); + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + @Test + public void testManifestFileKryoSerialization() throws IOException { + File data = temp.newFile(); + 
Assert.assertTrue(data.delete()); + + Kryo kryo = new KryoSerializer(new SparkConf()).newKryo(); + + ManifestFile manifest = writeManifest(FILE_A, FILE_B); + + try (Output out = new Output(new FileOutputStream(data))) { + kryo.writeClassAndObject(out, manifest); + kryo.writeClassAndObject(out, manifest.copy()); + kryo.writeClassAndObject(out, GenericManifestFile.copyOf(manifest).build()); + } + + try (Input in = new Input(new FileInputStream(data))) { + for (int i = 0; i < 3; i += 1) { + Object obj = kryo.readClassAndObject(in); + Assertions.assertThat(obj).as("Should be a ManifestFile").isInstanceOf(ManifestFile.class); + checkManifestFile(manifest, (ManifestFile) obj); + } + } + } + + @Test + public void testManifestFileJavaSerialization() throws Exception { + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + + ManifestFile manifest = writeManifest(FILE_A, FILE_B); + + try (ObjectOutputStream out = new ObjectOutputStream(bytes)) { + out.writeObject(manifest); + out.writeObject(manifest.copy()); + out.writeObject(GenericManifestFile.copyOf(manifest).build()); + } + + try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) { + for (int i = 0; i < 3; i += 1) { + Object obj = in.readObject(); + Assertions.assertThat(obj).as("Should be a ManifestFile").isInstanceOf(ManifestFile.class); + checkManifestFile(manifest, (ManifestFile) obj); + } + } + } + + private void checkManifestFile(ManifestFile expected, ManifestFile actual) { + Assert.assertEquals("Path must match", expected.path(), actual.path()); + Assert.assertEquals("Length must match", expected.length(), actual.length()); + Assert.assertEquals("Spec id must match", expected.partitionSpecId(), actual.partitionSpecId()); + Assert.assertEquals("Snapshot id must match", expected.snapshotId(), actual.snapshotId()); + Assert.assertEquals("Added files flag must match", expected.hasAddedFiles(), actual.hasAddedFiles()); + Assert.assertEquals("Added files count must match", expected.addedFilesCount(), actual.addedFilesCount()); + Assert.assertEquals("Added rows count must match", expected.addedRowsCount(), actual.addedRowsCount()); + Assert.assertEquals("Existing files flag must match", expected.hasExistingFiles(), actual.hasExistingFiles()); + Assert.assertEquals("Existing files count must match", expected.existingFilesCount(), actual.existingFilesCount()); + Assert.assertEquals("Existing rows count must match", expected.existingRowsCount(), actual.existingRowsCount()); + Assert.assertEquals("Deleted files flag must match", expected.hasDeletedFiles(), actual.hasDeletedFiles()); + Assert.assertEquals("Deleted files count must match", expected.deletedFilesCount(), actual.deletedFilesCount()); + Assert.assertEquals("Deleted rows count must match", expected.deletedRowsCount(), actual.deletedRowsCount()); + + PartitionFieldSummary expectedPartition = expected.partitions().get(0); + PartitionFieldSummary actualPartition = actual.partitions().get(0); + + Assert.assertEquals("Null flag in partition must match", + expectedPartition.containsNull(), actualPartition.containsNull()); + Assert.assertEquals("NaN flag in partition must match", + expectedPartition.containsNaN(), actualPartition.containsNaN()); + Assert.assertEquals("Lower bounds in partition must match", + expectedPartition.lowerBound(), actualPartition.lowerBound()); + Assert.assertEquals("Upper bounds in partition must match", + expectedPartition.upperBound(), actualPartition.upperBound()); + } + + private ManifestFile 
writeManifest(DataFile... files) throws IOException { + File manifestFile = temp.newFile("input.m0.avro"); + Assert.assertTrue(manifestFile.delete()); + OutputFile outputFile = FILE_IO.newOutputFile(manifestFile.getCanonicalPath()); + + ManifestWriter writer = ManifestFiles.write(SPEC, outputFile); + try { + for (DataFile file : files) { + writer.add(file); + } + } finally { + writer.close(); + } + + return writer.toManifestFile(); + } + + private static ByteBuffer longToBuffer(long value) { + return ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putLong(0, value); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/TestScanTaskSerialization.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/TestScanTaskSerialization.java new file mode 100644 index 000000000000..1fe736b51aad --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/TestScanTaskSerialization.java @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.nio.file.Files; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.spark.source.ThreeColumnRecord; +import org.apache.iceberg.types.Types; +import org.apache.spark.SparkConf; +import org.apache.spark.serializer.KryoSerializer; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.types.Types.NestedField.optional; + +public class TestScanTaskSerialization extends SparkTestBase { + + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + private static final Schema SCHEMA = new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get()) + ); + + @Rule + public TemporaryFolder temp = new 
TemporaryFolder(); + + private String tableLocation = null; + + @Before + public void setupTableLocation() throws Exception { + File tableDir = temp.newFolder(); + this.tableLocation = tableDir.toURI().toString(); + } + + @Test + public void testBaseCombinedScanTaskKryoSerialization() throws Exception { + BaseCombinedScanTask scanTask = prepareBaseCombinedScanTaskForSerDeTest(); + + File data = temp.newFile(); + Assert.assertTrue(data.delete()); + Kryo kryo = new KryoSerializer(new SparkConf()).newKryo(); + + try (Output out = new Output(new FileOutputStream(data))) { + kryo.writeClassAndObject(out, scanTask); + } + + try (Input in = new Input(new FileInputStream(data))) { + Object obj = kryo.readClassAndObject(in); + Assertions.assertThat(obj).as("Should be a BaseCombinedScanTask").isInstanceOf(BaseCombinedScanTask.class); + TaskCheckHelper.assertEquals(scanTask, (BaseCombinedScanTask) obj); + } + } + + @Test + public void testBaseCombinedScanTaskJavaSerialization() throws Exception { + BaseCombinedScanTask scanTask = prepareBaseCombinedScanTaskForSerDeTest(); + + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + try (ObjectOutputStream out = new ObjectOutputStream(bytes)) { + out.writeObject(scanTask); + } + + try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) { + Object obj = in.readObject(); + Assertions.assertThat(obj).as("Should be a BaseCombinedScanTask").isInstanceOf(BaseCombinedScanTask.class); + TaskCheckHelper.assertEquals(scanTask, (BaseCombinedScanTask) obj); + } + } + + @Test + @SuppressWarnings("unchecked") + public void testBaseScanTaskGroupKryoSerialization() throws Exception { + BaseScanTaskGroup taskGroup = prepareBaseScanTaskGroupForSerDeTest(); + + Assert.assertTrue("Task group can't be empty", taskGroup.tasks().size() > 0); + + File data = temp.newFile(); + Assert.assertTrue(data.delete()); + Kryo kryo = new KryoSerializer(new SparkConf()).newKryo(); + + try (Output out = new Output(Files.newOutputStream(data.toPath()))) { + kryo.writeClassAndObject(out, taskGroup); + } + + try (Input in = new Input(Files.newInputStream(data.toPath()))) { + Object obj = kryo.readClassAndObject(in); + Assertions.assertThat(obj).as("should be a BaseScanTaskGroup").isInstanceOf(BaseScanTaskGroup.class); + TaskCheckHelper.assertEquals(taskGroup, (BaseScanTaskGroup) obj); + } + } + + @Test + @SuppressWarnings("unchecked") + public void testBaseScanTaskGroupJavaSerialization() throws Exception { + BaseScanTaskGroup taskGroup = prepareBaseScanTaskGroupForSerDeTest(); + + Assert.assertTrue("Task group can't be empty", taskGroup.tasks().size() > 0); + + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + try (ObjectOutputStream out = new ObjectOutputStream(bytes)) { + out.writeObject(taskGroup); + } + + try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) { + Object obj = in.readObject(); + Assertions.assertThat(obj).as("should be a BaseScanTaskGroup").isInstanceOf(BaseScanTaskGroup.class); + TaskCheckHelper.assertEquals(taskGroup, (BaseScanTaskGroup) obj); + } + } + + private BaseCombinedScanTask prepareBaseCombinedScanTaskForSerDeTest() { + Table table = initTable(); + CloseableIterable tasks = table.newScan().planFiles(); + return new BaseCombinedScanTask(Lists.newArrayList(tasks)); + } + + private BaseScanTaskGroup prepareBaseScanTaskGroupForSerDeTest() { + Table table = initTable(); + CloseableIterable tasks = table.newScan().planFiles(); + return new 
BaseScanTaskGroup<>(ImmutableList.copyOf(tasks)); + } + + private Table initTable() { + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + List records1 = Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), + new ThreeColumnRecord(1, "BBBBBBBBBB", "BBBB") + ); + writeRecords(records1); + + List records2 = Lists.newArrayList( + new ThreeColumnRecord(2, "CCCCCCCCCC", "CCCC"), + new ThreeColumnRecord(2, "DDDDDDDDDD", "DDDD") + ); + writeRecords(records2); + + table.refresh(); + + return table; + } + + private void writeRecords(List records) { + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + writeDF(df); + } + + private void writeDF(Dataset df) { + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/TestTableSerialization.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/TestTableSerialization.java new file mode 100644 index 000000000000..88e30c0ece68 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/TestTableSerialization.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg; + +import java.io.File; +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.types.Types; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +public class TestTableSerialization { + + private static final HadoopTables TABLES = new HadoopTables(); + + private static final Schema SCHEMA = new Schema( + required(1, "id", Types.LongType.get()), + optional(2, "data", Types.StringType.get()), + required(3, "date", Types.StringType.get()), + optional(4, "double", Types.DoubleType.get())); + + private static final PartitionSpec SPEC = PartitionSpec + .builderFor(SCHEMA) + .identity("date") + .build(); + + private static final SortOrder SORT_ORDER = SortOrder.builderFor(SCHEMA) + .asc("id") + .build(); + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + private Table table; + + @Before + public void initTable() throws IOException { + Map props = ImmutableMap.of("k1", "v1", "k2", "v2"); + + File tableLocation = temp.newFolder(); + Assert.assertTrue(tableLocation.delete()); + + this.table = TABLES.create(SCHEMA, SPEC, SORT_ORDER, props, tableLocation.toString()); + } + + @Test + public void testSerializableTableKryoSerialization() throws IOException { + Table serializableTable = SerializableTable.copyOf(table); + TestHelpers.assertSerializedAndLoadedMetadata(table, KryoHelpers.roundTripSerialize(serializableTable)); + } + + @Test + public void testSerializableMetadataTableKryoSerialization() throws IOException { + for (MetadataTableType type : MetadataTableType.values()) { + TableOperations ops = ((HasTableOperations) table).operations(); + Table metadataTable = MetadataTableUtils.createMetadataTableInstance(ops, table.name(), "meta", type); + Table serializableMetadataTable = SerializableTable.copyOf(metadataTable); + + TestHelpers.assertSerializedAndLoadedMetadata( + metadataTable, + KryoHelpers.roundTripSerialize(serializableMetadataTable)); + } + } + + @Test + public void testSerializableTransactionTableKryoSerialization() throws IOException { + Transaction txn = table.newTransaction(); + + txn.updateProperties() + .set("k1", "v1") + .commit(); + + Table txnTable = txn.table(); + Table serializableTxnTable = SerializableTable.copyOf(txnTable); + + TestHelpers.assertSerializedMetadata(txnTable, KryoHelpers.roundTripSerialize(serializableTxnTable)); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/SparkCatalogConfig.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/SparkCatalogConfig.java new file mode 100644 index 000000000000..5d5dfebf9532 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/SparkCatalogConfig.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; + +public enum SparkCatalogConfig { + HIVE("testhive", SparkCatalog.class.getName(), ImmutableMap.of( + "type", "hive", + "default-namespace", "default" + )), + HADOOP("testhadoop", SparkCatalog.class.getName(), ImmutableMap.of( + "type", "hadoop" + )), + SPARK("spark_catalog", SparkSessionCatalog.class.getName(), ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "parquet-enabled", "true", + "cache-enabled", "false" // Spark will delete tables using v1, leaving the cache out of sync + )); + + private final String catalogName; + private final String implementation; + private final Map properties; + + SparkCatalogConfig(String catalogName, String implementation, Map properties) { + this.catalogName = catalogName; + this.implementation = implementation; + this.properties = properties; + } + + public String catalogName() { + return catalogName; + } + + public String implementation() { + return implementation; + } + + public Map properties() { + return properties; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/SparkCatalogTestBase.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/SparkCatalogTestBase.java new file mode 100644 index 000000000000..774e81328b2b --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/SparkCatalogTestBase.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark; + +import java.util.Map; +import org.junit.Rule; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public abstract class SparkCatalogTestBase extends SparkTestBaseWithCatalog { + + // these parameters are broken out to avoid changes that need to modify lots of test suites + @Parameterized.Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][] {{ + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + SparkCatalogConfig.HIVE.properties() + }, { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties() + }, { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties() + }}; + } + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + public SparkCatalogTestBase(SparkCatalogConfig config) { + super(config); + } + + public SparkCatalogTestBase(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/SparkTestBase.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/SparkTestBase.java new file mode 100644 index 000000000000..8022e8696b63 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/SparkTestBase.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.iceberg.spark;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.net.URI;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.iceberg.CatalogUtil;
+import org.apache.iceberg.ContentFile;
+import org.apache.iceberg.catalog.Namespace;
+import org.apache.iceberg.exceptions.AlreadyExistsException;
+import org.apache.iceberg.hive.HiveCatalog;
+import org.apache.iceberg.hive.TestHiveMetastore;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.internal.SQLConf;
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+
+import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS;
+
+public abstract class SparkTestBase {
+
+  protected static final Object ANY = new Object();
+
+  protected static TestHiveMetastore metastore = null;
+  protected static HiveConf hiveConf = null;
+  protected static SparkSession spark = null;
+  protected static HiveCatalog catalog = null;
+
+  @BeforeClass
+  public static void startMetastoreAndSpark() {
+    SparkTestBase.metastore = new TestHiveMetastore();
+    metastore.start();
+    SparkTestBase.hiveConf = metastore.hiveConf();
+
+    SparkTestBase.spark = SparkSession.builder()
+        .master("local[2]")
+        .config(SQLConf.PARTITION_OVERWRITE_MODE().key(), "dynamic")
+        .config("spark.hadoop." + METASTOREURIS.varname, hiveConf.get(METASTOREURIS.varname))
+        .enableHiveSupport()
+        .getOrCreate();
+
+    SparkTestBase.catalog = (HiveCatalog)
+        CatalogUtil.loadCatalog(HiveCatalog.class.getName(), "hive", ImmutableMap.of(), hiveConf);
+
+    try {
+      catalog.createNamespace(Namespace.of("default"));
+    } catch (AlreadyExistsException ignored) {
+      // the default namespace already exists. ignore the create error
+    }
+  }
+
+  @AfterClass
+  public static void stopMetastoreAndSpark() throws Exception {
+    SparkTestBase.catalog = null;
+    metastore.stop();
+    SparkTestBase.metastore = null;
+    spark.stop();
+    SparkTestBase.spark = null;
+  }
+
+  protected long waitUntilAfter(long timestampMillis) {
+    long current = System.currentTimeMillis();
+    while (current <= timestampMillis) {
+      current = System.currentTimeMillis();
+    }
+    return current;
+  }
+
+  protected List<Object[]> sql(String query, Object... args) {
+    List<Row> rows = spark.sql(String.format(query, args)).collectAsList();
+    if (rows.size() < 1) {
+      return ImmutableList.of();
+    }
+
+    return rowsToJava(rows);
+  }
+
+  protected List<Object[]> rowsToJava(List<Row> rows) {
+    return rows.stream().map(this::toJava).collect(Collectors.toList());
+  }
+
+  private Object[] toJava(Row row) {
+    return IntStream.range(0, row.size())
+        .mapToObj(pos -> {
+          if (row.isNullAt(pos)) {
+            return null;
+          }
+
+          Object value = row.get(pos);
+          if (value instanceof Row) {
+            return toJava((Row) value);
+          } else if (value instanceof scala.collection.Seq) {
+            return row.getList(pos);
+          } else if (value instanceof scala.collection.Map) {
+            return row.getJavaMap(pos);
+          } else {
+            return value;
+          }
+        })
+        .toArray(Object[]::new);
+  }
+
+  protected Object scalarSql(String query, Object... args) {
+    List<Object[]> rows = sql(query, args);
+    Assert.assertEquals("Scalar SQL should return one row", 1, rows.size());
+    Object[] row = Iterables.getOnlyElement(rows);
+    Assert.assertEquals("Scalar SQL should return one value", 1, row.length);
+    return row[0];
+  }
+
+  protected Object[] row(Object... values) {
+    return values;
+  }
+
+  protected void assertEquals(String context, List<Object[]> expectedRows, List<Object[]> actualRows) {
+    Assert.assertEquals(context + ": number of results should match", expectedRows.size(), actualRows.size());
+    for (int row = 0; row < expectedRows.size(); row += 1) {
+      Object[] expected = expectedRows.get(row);
+      Object[] actual = actualRows.get(row);
+      Assert.assertEquals("Number of columns should match", expected.length, actual.length);
+      for (int col = 0; col < actualRows.get(row).length; col += 1) {
+        String newContext = String.format("%s: row %d col %d", context, row + 1, col + 1);
+        assertEquals(newContext, expected, actual);
+      }
+    }
+  }
+
+  private void assertEquals(String context, Object[] expectedRow, Object[] actualRow) {
+    Assert.assertEquals("Number of columns should match", expectedRow.length, actualRow.length);
+    for (int col = 0; col < actualRow.length; col += 1) {
+      Object expectedValue = expectedRow[col];
+      Object actualValue = actualRow[col];
+      if (expectedValue != null && expectedValue.getClass().isArray()) {
+        String newContext = String.format("%s (nested col %d)", context, col + 1);
+        if (expectedValue instanceof byte[]) {
+          Assert.assertArrayEquals(newContext, (byte[]) expectedValue, (byte[]) actualValue);
+        } else {
+          assertEquals(newContext, (Object[]) expectedValue, (Object[]) actualValue);
+        }
+      } else if (expectedValue != ANY) {
+        Assert.assertEquals(context + " contents should match", expectedValue, actualValue);
+      }
+    }
+  }
+
+  protected static String dbPath(String dbName) {
+    return metastore.getDatabasePath(dbName);
+  }
+
+  protected void withUnavailableFiles(Iterable<? extends ContentFile<?>> files, Action action) {
+    Iterable<String> fileLocations = Iterables.transform(files, file -> file.path().toString());
+    withUnavailableLocations(fileLocations, action);
+  }
+
+  private void move(String location, String newLocation) {
+    Path path = Paths.get(URI.create(location));
+    Path tempPath = Paths.get(URI.create(newLocation));
+
+    try {
+      Files.move(path, tempPath);
+    } catch (IOException e) {
+      throw new UncheckedIOException("Failed to move: " + location, e);
+    }
+  }
+
+  protected void withUnavailableLocations(Iterable<String> locations, Action action) {
+    for (String location : locations) {
+      move(location, location + "_temp");
+    }
+
+    try {
+      action.invoke();
+    } finally {
+      for (String location : locations) {
+        move(location + "_temp", location);
+      }
+    }
+  }
+
+  protected void withSQLConf(Map<String, String> conf, Action action) {
+    SQLConf sqlConf = SQLConf.get();
+
+    Map<String, String> currentConfValues = Maps.newHashMap();
+    conf.keySet().forEach(confKey -> {
+      if (sqlConf.contains(confKey)) {
+        String currentConfValue = sqlConf.getConfString(confKey);
+        currentConfValues.put(confKey, currentConfValue);
+      }
+    });
+
+    conf.forEach((confKey, confValue) -> {
+      if (SQLConf.isStaticConfigKey(confKey)) {
+        throw new RuntimeException("Cannot modify the value of a static config: " + confKey);
+      }
+      sqlConf.setConfString(confKey, confValue);
+    });
+
+    try {
+      action.invoke();
+    } finally {
+      conf.forEach((confKey, confValue) -> {
+        if (currentConfValues.containsKey(confKey)) {
+          sqlConf.setConfString(confKey, currentConfValues.get(confKey));
+        } else {
+          sqlConf.unsetConf(confKey);
+        }
+      });
+    }
+  }
+
+  protected Dataset<Row> jsonToDF(String schema, String... records) {
+    Dataset<String> jsonDF = spark.createDataset(ImmutableList.copyOf(records), Encoders.STRING());
+    return spark.read().schema(schema).json(jsonDF);
+  }
+
+  @FunctionalInterface
+  protected interface Action {
+    void invoke();
+  }
+}
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/SparkTestBaseWithCatalog.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/SparkTestBaseWithCatalog.java
new file mode 100644
index 000000000000..fc358f76ae9c
--- /dev/null
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/SparkTestBaseWithCatalog.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Map;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.iceberg.catalog.Catalog;
+import org.apache.iceberg.catalog.Namespace;
+import org.apache.iceberg.catalog.SupportsNamespaces;
+import org.apache.iceberg.catalog.TableIdentifier;
+import org.apache.iceberg.hadoop.HadoopCatalog;
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.rules.TemporaryFolder;
+
+public abstract class SparkTestBaseWithCatalog extends SparkTestBase {
+  private static File warehouse = null;
+
+  @BeforeClass
+  public static void createWarehouse() throws IOException {
+    SparkTestBaseWithCatalog.warehouse = File.createTempFile("warehouse", null);
+    Assert.assertTrue(warehouse.delete());
+  }
+
+  @AfterClass
+  public static void dropWarehouse() throws IOException {
+    if (warehouse != null && warehouse.exists()) {
+      Path warehousePath = new Path(warehouse.getAbsolutePath());
+      FileSystem fs = warehousePath.getFileSystem(hiveConf);
+      Assert.assertTrue("Failed to delete " + warehousePath, fs.delete(warehousePath, true));
+    }
+  }
+
+  @Rule
+  public TemporaryFolder temp = new TemporaryFolder();
+
+  protected final String catalogName;
+  protected final Catalog validationCatalog;
+  protected final SupportsNamespaces validationNamespaceCatalog;
+  protected final TableIdentifier tableIdent = TableIdentifier.of(Namespace.of("default"), "table");
+  protected final String tableName;
+
+  public SparkTestBaseWithCatalog() {
+    this(SparkCatalogConfig.HADOOP);
+  }
+
+  public SparkTestBaseWithCatalog(SparkCatalogConfig config) {
+    this(config.catalogName(), config.implementation(), config.properties());
+  }
+
+  public SparkTestBaseWithCatalog(String catalogName, String implementation, Map<String, String> config) {
+    this.catalogName = catalogName;
+    this.validationCatalog = catalogName.equals("testhadoop") ?
+        new HadoopCatalog(spark.sessionState().newHadoopConf(), "file:" + warehouse) :
+        catalog;
+    this.validationNamespaceCatalog = (SupportsNamespaces) validationCatalog;
+
+    spark.conf().set("spark.sql.catalog." + catalogName, implementation);
+    config.forEach((key, value) -> spark.conf().set("spark.sql.catalog." + catalogName + "." + key, value));
+
+    if (config.get("type").equalsIgnoreCase("hadoop")) {
+      spark.conf().set("spark.sql.catalog." + catalogName + ".warehouse", "file:" + warehouse);
+    }
+
+    this.tableName = (catalogName.equals("spark_catalog") ? "" : catalogName + ".") + "default.table";
+
+    sql("CREATE NAMESPACE IF NOT EXISTS default");
+  }
+
+  protected String tableName(String name) {
+    return (catalogName.equals("spark_catalog") ? "" : catalogName + ".") + "default." + name;
+  }
+}
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestFileRewriteCoordinator.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestFileRewriteCoordinator.java
new file mode 100644
index 000000000000..8aa5cd6faec1
--- /dev/null
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestFileRewriteCoordinator.java
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestFileRewriteCoordinator extends SparkCatalogTestBase { + + public TestFileRewriteCoordinator(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testBinPackRewrite() throws NoSuchTableException, IOException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + + Dataset df = newDF(1000); + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should produce 4 snapshots", 4, Iterables.size(table.snapshots())); + + Dataset fileDF = spark.read().format("iceberg").load(tableName(tableIdent.name() + ".files")); + List fileSizes = fileDF.select("file_size_in_bytes").as(Encoders.LONG()).collectAsList(); + long avgFileSize = fileSizes.stream().mapToLong(i -> i).sum() / fileSizes.size(); + + try (CloseableIterable fileScanTasks = table.newScan().planFiles()) { + String fileSetID = UUID.randomUUID().toString(); + + FileScanTaskSetManager taskSetManager = FileScanTaskSetManager.get(); + taskSetManager.stageTasks(table, fileSetID, Lists.newArrayList(fileScanTasks)); + + // read and pack original 4 files into 2 splits + Dataset scanDF = spark.read().format("iceberg") + .option(SparkReadOptions.FILE_SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.SPLIT_SIZE, Long.toString(avgFileSize * 2)) + .option(SparkReadOptions.FILE_OPEN_COST, "0") + .load(tableName); + + // write the packed data into new files where each split becomes a new file + scanDF.writeTo(tableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + + // commit the rewrite + FileRewriteCoordinator rewriteCoordinator = FileRewriteCoordinator.get(); + 
Set rewrittenFiles = taskSetManager.fetchTasks(table, fileSetID).stream() + .map(FileScanTask::file) + .collect(Collectors.toSet()); + Set addedFiles = rewriteCoordinator.fetchNewDataFiles(table, fileSetID); + table.newRewrite().rewriteFiles(rewrittenFiles, addedFiles).commit(); + } + + table.refresh(); + + Map summary = table.currentSnapshot().summary(); + Assert.assertEquals("Deleted files count must match", "4", summary.get("deleted-data-files")); + Assert.assertEquals("Added files count must match", "2", summary.get("added-data-files")); + + Object rowCount = scalarSql("SELECT count(*) FROM %s", tableName); + Assert.assertEquals("Row count must match", 4000L, rowCount); + } + + @Test + public void testSortRewrite() throws NoSuchTableException, IOException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + + Dataset df = newDF(1000); + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should produce 4 snapshots", 4, Iterables.size(table.snapshots())); + + try (CloseableIterable fileScanTasks = table.newScan().planFiles()) { + String fileSetID = UUID.randomUUID().toString(); + + FileScanTaskSetManager taskSetManager = FileScanTaskSetManager.get(); + taskSetManager.stageTasks(table, fileSetID, Lists.newArrayList(fileScanTasks)); + + // read original 4 files as 4 splits + Dataset scanDF = spark.read().format("iceberg") + .option(SparkReadOptions.FILE_SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.SPLIT_SIZE, "134217728") + .option(SparkReadOptions.FILE_OPEN_COST, "134217728") + .load(tableName); + + // make sure we disable AQE and set the number of shuffle partitions as the target num files + ImmutableMap sqlConf = ImmutableMap.of( + "spark.sql.shuffle.partitions", "2", + "spark.sql.adaptive.enabled", "false" + ); + + withSQLConf(sqlConf, () -> { + try { + // write new files with sorted records + scanDF.sort("id").writeTo(tableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + } catch (NoSuchTableException e) { + throw new RuntimeException("Could not replace files", e); + } + }); + + // commit the rewrite + FileRewriteCoordinator rewriteCoordinator = FileRewriteCoordinator.get(); + Set rewrittenFiles = taskSetManager.fetchTasks(table, fileSetID).stream() + .map(FileScanTask::file) + .collect(Collectors.toSet()); + Set addedFiles = rewriteCoordinator.fetchNewDataFiles(table, fileSetID); + table.newRewrite().rewriteFiles(rewrittenFiles, addedFiles).commit(); + } + + table.refresh(); + + Map summary = table.currentSnapshot().summary(); + Assert.assertEquals("Deleted files count must match", "4", summary.get("deleted-data-files")); + Assert.assertEquals("Added files count must match", "2", summary.get("added-data-files")); + + Object rowCount = scalarSql("SELECT count(*) FROM %s", tableName); + Assert.assertEquals("Row count must match", 4000L, rowCount); + } + + @Test + public void testCommitMultipleRewrites() throws NoSuchTableException, IOException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + + Dataset df = newDF(1000); + + // add first two files + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + + String firstFileSetID = UUID.randomUUID().toString(); + long 
firstFileSetSnapshotId = table.currentSnapshot().snapshotId(); + + FileScanTaskSetManager taskSetManager = FileScanTaskSetManager.get(); + + try (CloseableIterable tasks = table.newScan().planFiles()) { + // stage first 2 files for compaction + taskSetManager.stageTasks(table, firstFileSetID, Lists.newArrayList(tasks)); + } + + // add two more files + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + + table.refresh(); + + String secondFileSetID = UUID.randomUUID().toString(); + + try (CloseableIterable tasks = table.newScan().appendsAfter(firstFileSetSnapshotId).planFiles()) { + // stage 2 more files for compaction + taskSetManager.stageTasks(table, secondFileSetID, Lists.newArrayList(tasks)); + } + + ImmutableSet fileSetIDs = ImmutableSet.of(firstFileSetID, secondFileSetID); + + for (String fileSetID : fileSetIDs) { + // read and pack 2 files into 1 split + Dataset scanDF = spark.read().format("iceberg") + .option(SparkReadOptions.FILE_SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.SPLIT_SIZE, Long.MAX_VALUE) + .load(tableName); + + // write the combined data as one file + scanDF.writeTo(tableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + } + + // commit both rewrites at the same time + FileRewriteCoordinator rewriteCoordinator = FileRewriteCoordinator.get(); + Set rewrittenFiles = fileSetIDs.stream().flatMap(fileSetID -> + taskSetManager.fetchTasks(table, fileSetID).stream()) + .map(FileScanTask::file) + .collect(Collectors.toSet()); + Set addedFiles = fileSetIDs.stream() + .flatMap(fileSetID -> rewriteCoordinator.fetchNewDataFiles(table, fileSetID).stream()) + .collect(Collectors.toSet()); + table.newRewrite().rewriteFiles(rewrittenFiles, addedFiles).commit(); + + table.refresh(); + + Assert.assertEquals("Should produce 5 snapshots", 5, Iterables.size(table.snapshots())); + + Map summary = table.currentSnapshot().summary(); + Assert.assertEquals("Deleted files count must match", "4", summary.get("deleted-data-files")); + Assert.assertEquals("Added files count must match", "2", summary.get("added-data-files")); + + Object rowCount = scalarSql("SELECT count(*) FROM %s", tableName); + Assert.assertEquals("Row count must match", 4000L, rowCount); + } + + private Dataset newDF(int numRecords) { + List data = Lists.newArrayListWithExpectedSize(numRecords); + for (int index = 0; index < numRecords; index++) { + data.add(new SimpleRecord(index, Integer.toString(index))); + } + return spark.createDataFrame(data, SimpleRecord.class); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSpark3Util.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSpark3Util.java new file mode 100644 index 000000000000..be6130f63741 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSpark3Util.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark;
+
+import org.apache.iceberg.CachingCatalog;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.SortOrderParser;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.catalog.Catalog;
+import org.apache.iceberg.types.Types;
+import org.junit.Assert;
+import org.junit.Test;
+
+import static org.apache.iceberg.NullOrder.NULLS_FIRST;
+import static org.apache.iceberg.NullOrder.NULLS_LAST;
+import static org.apache.iceberg.types.Types.NestedField.optional;
+import static org.apache.iceberg.types.Types.NestedField.required;
+
+public class TestSpark3Util extends SparkTestBase {
+  @Test
+  public void testDescribeSortOrder() {
+    Schema schema = new Schema(
+        required(1, "data", Types.StringType.get()),
+        required(2, "time", Types.TimestampType.withoutZone())
+    );
+
+    Assert.assertEquals("Sort order isn't correct.", "data DESC NULLS FIRST",
+        Spark3Util.describe(buildSortOrder("Identity", schema, 1)));
+    Assert.assertEquals("Sort order isn't correct.", "bucket(1, data) DESC NULLS FIRST",
+        Spark3Util.describe(buildSortOrder("bucket[1]", schema, 1)));
+    Assert.assertEquals("Sort order isn't correct.", "truncate(data, 3) DESC NULLS FIRST",
+        Spark3Util.describe(buildSortOrder("truncate[3]", schema, 1)));
+    Assert.assertEquals("Sort order isn't correct.", "years(time) DESC NULLS FIRST",
+        Spark3Util.describe(buildSortOrder("year", schema, 2)));
+    Assert.assertEquals("Sort order isn't correct.", "months(time) DESC NULLS FIRST",
+        Spark3Util.describe(buildSortOrder("month", schema, 2)));
+    Assert.assertEquals("Sort order isn't correct.", "days(time) DESC NULLS FIRST",
+        Spark3Util.describe(buildSortOrder("day", schema, 2)));
+    Assert.assertEquals("Sort order isn't correct.", "hours(time) DESC NULLS FIRST",
+        Spark3Util.describe(buildSortOrder("hour", schema, 2)));
+    Assert.assertEquals("Sort order isn't correct.", "unknown(data) DESC NULLS FIRST",
+        Spark3Util.describe(buildSortOrder("unknown", schema, 1)));
+
+    // multiple sort orders
+    SortOrder multiOrder = SortOrder.builderFor(schema)
+        .asc("time", NULLS_FIRST)
+        .asc("data", NULLS_LAST)
+        .build();
+    Assert.assertEquals("Sort order isn't correct.", "time ASC NULLS FIRST, data ASC NULLS LAST",
+        Spark3Util.describe(multiOrder));
+  }
+
+  @Test
+  public void testDescribeSchema() {
+    Schema schema = new Schema(
+        required(1, "data", Types.ListType.ofRequired(2, Types.StringType.get())),
+        optional(3, "pairs", Types.MapType.ofOptional(4, 5, Types.StringType.get(), Types.LongType.get())),
+        required(6, "time", Types.TimestampType.withoutZone())
+    );
+
+    Assert.assertEquals("Schema description isn't correct.",
+        "struct<data: list<string> not null,pairs: map<string, bigint>,time: timestamp not null>",
+        Spark3Util.describe(schema));
+  }
+
+  @Test
+  public void testLoadIcebergTable() throws Exception {
+    spark.conf().set("spark.sql.catalog.hive", SparkCatalog.class.getName());
+    spark.conf().set("spark.sql.catalog.hive.type", "hive");
+    spark.conf().set("spark.sql.catalog.hive.default-namespace", "default");
+
+    String tableFullName = "hive.default.tbl";
+    sql("CREATE TABLE %s (c1 bigint, c2 string, c3 string) USING iceberg", tableFullName);
+
+    Table table = Spark3Util.loadIcebergTable(spark, tableFullName);
+    Assert.assertTrue(table.name().equals(tableFullName));
+  }
+
+  @Test
+  public void testLoadIcebergCatalog() throws Exception {
+    spark.conf().set("spark.sql.catalog.test_cat", SparkCatalog.class.getName());
+    spark.conf().set("spark.sql.catalog.test_cat.type", "hive");
+    Catalog catalog = Spark3Util.loadIcebergCatalog(spark, "test_cat");
+    Assert.assertTrue("Should retrieve underlying catalog class", catalog instanceof CachingCatalog);
+  }
+
+  private SortOrder buildSortOrder(String transform, Schema schema, int sourceId) {
+    String jsonString = "{\n" +
+        "  \"order-id\" : 10,\n" +
+        "  \"fields\" : [ {\n" +
+        "    \"transform\" : \"" + transform + "\",\n" +
+        "    \"source-id\" : " + sourceId + ",\n" +
+        "    \"direction\" : \"desc\",\n" +
+        "    \"null-order\" : \"nulls-first\"\n" +
+        "  } ]\n" +
+        "}";
+
+    return SortOrderParser.fromJson(schema, jsonString);
+  }
+}
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkCatalogOperations.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkCatalogOperations.java
new file mode 100644
index 000000000000..5c9f3c4cb189
--- /dev/null
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkCatalogOperations.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark;
+
+import java.util.Map;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.catalog.Catalog;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.connector.catalog.Identifier;
+import org.apache.spark.sql.connector.catalog.Table;
+import org.apache.spark.sql.connector.catalog.TableChange;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+public class TestSparkCatalogOperations extends SparkCatalogTestBase {
+  public TestSparkCatalogOperations(
+      String catalogName, String implementation, Map<String, String> config) {
+    super(catalogName, implementation, config);
+  }
+
+  @Before
+  public void createTable() {
+    sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+  }
+
+  @After
+  public void removeTable() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+  }
+
+  @Test
+  public void testAlterTable() throws NoSuchTableException {
+    BaseCatalog catalog = (BaseCatalog) spark.sessionState().catalogManager().catalog(catalogName);
+    Identifier identifier = Identifier.of(tableIdent.namespace().levels(), tableIdent.name());
+
+    String fieldName = "location";
+    String propsKey = "note";
+    String propsValue = "jazz";
+    Table table =
+        catalog.alterTable(
+            identifier,
+            TableChange.addColumn(new String[] {fieldName}, DataTypes.StringType, true),
+            TableChange.setProperty(propsKey, propsValue));
+
+    Assert.assertNotNull("Should return updated table", table);
+
+    StructField expectedField = DataTypes.createStructField(fieldName, DataTypes.StringType, true);
+    Assert.assertEquals(
+        "Adding a column to a table should return the updated table with the new column",
+        table.schema().fields()[2],
+        expectedField);
+
+    Assert.assertTrue(
+        "Adding a property to a table should return the updated table with the new property",
+        table.properties().containsKey(propsKey));
+    Assert.assertEquals(
+        "Altering a table to add a new property should add the correct value",
+        propsValue,
+        table.properties().get(propsKey));
+  }
+
+  @Test
+  public void testInvalidateTable() {
+    // load table to CachingCatalog
+    sql("SELECT count(1) FROM %s", tableName);
+
+    // recreate table from another catalog or program
+    Catalog anotherCatalog = validationCatalog;
+    Schema schema = anotherCatalog.loadTable(tableIdent).schema();
+    anotherCatalog.dropTable(tableIdent);
+    anotherCatalog.createTable(tableIdent, schema);
+
+    // invalidate and reload table
+    sql("REFRESH TABLE %s", tableName);
+    sql("SELECT count(1) FROM %s", tableName);
+  }
+}
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkDistributionAndOrderingUtil.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkDistributionAndOrderingUtil.java
new file mode 100644
index 000000000000..e1cf58ab2a1f
--- /dev/null
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkDistributionAndOrderingUtil.java
@@ -0,0 +1,2022 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.distributions.Distributions; +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.SortDirection; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.connector.write.RowLevelOperation.Command; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +import static org.apache.iceberg.TableProperties.DELETE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.MERGE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.UPDATE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_HASH; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_NONE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_RANGE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE; + +public class TestSparkDistributionAndOrderingUtil extends SparkTestBaseWithCatalog { + + private static final Distribution UNSPECIFIED_DISTRIBUTION = Distributions.unspecified(); + private static final Distribution FILE_CLUSTERED_DISTRIBUTION = Distributions.clustered(new Expression[]{ + Expressions.column(MetadataColumns.FILE_PATH.name()) + }); + private static final Distribution SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION = Distributions.clustered(new Expression[]{ + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME) + }); + + private static final SortOrder[] EMPTY_ORDERING = new SortOrder[]{}; + private static final SortOrder[] FILE_POSITION_ORDERING = new SortOrder[]{ + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING) + }; + private static final SortOrder[] SPEC_ID_PARTITION_FILE_ORDERING = new SortOrder[]{ + Expressions.sort(Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING) + }; + private static final SortOrder[] SPEC_ID_PARTITION_FILE_POSITION_ORDERING = new SortOrder[]{ + Expressions.sort(Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + 
Expressions.sort(Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING) + }; + + @After + public void dropTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testDefaultWriteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + checkWriteDistributionAndOrdering(table, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @Test + public void testHashWriteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + checkWriteDistributionAndOrdering(table, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @Test + public void testRangeWriteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + checkWriteDistributionAndOrdering(table, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @Test + public void testDefaultWriteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder() + .asc("id") + .asc("data") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + } + + @Test + public void testHashWriteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + table.replaceSortOrder() + .asc("id") + .asc("data") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkWriteDistributionAndOrdering(table, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testRangeWriteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + table.replaceSortOrder() + .asc("id") + .asc("data") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + } 
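The write-path cases above (and the copy-on-write variants that follow) are all driven by the table's `write.distribution-mode` property family together with its declared sort order. For reference, a minimal sketch of the table configuration that the hash-distributed, locally sorted expectations correspond to, using the same `Table` API these tests call; `table` is assumed to be an already-loaded Iceberg table handle, and the constants come from `org.apache.iceberg.TableProperties` as in the static imports above:

// Sketch only: configure hash write distribution plus a table sort order,
// mirroring what the hash-mode tests above do before checking requirements.
table.updateProperties()
    .set(TableProperties.WRITE_DISTRIBUTION_MODE, TableProperties.WRITE_DISTRIBUTION_MODE_HASH)
    .commit();

table.replaceSortOrder()
    .asc("id")
    .asc("data")
    .commit();

For a partitioned table configured this way, the tests above then expect writes to be clustered by the partition expressions, with the declared sort appended to the local ordering; for an unpartitioned table the distribution stays unspecified and only the local sort applies.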
+ + @Test + public void testDefaultWritePartitionedUnsortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkWriteDistributionAndOrdering(table, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashWritePartitionedUnsortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + Expression[] expectedClustering = new Expression[]{ + Expressions.identity("date"), + Expressions.days("ts") + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + } + + @Test + public void testRangeWritePartitionedUnsortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + } + + @Test + public void testDefaultWritePartitionedSortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder() + .desc("id") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + } + + @Test + public void testHashWritePartitionedSortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + table.replaceSortOrder() + .asc("id") + .commit(); + + Expression[] expectedClustering = new Expression[]{ + Expressions.identity("date"), + Expressions.bucket(8, "data") + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = new 
SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + } + + @Test + public void testRangeWritePartitionedSortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder() + .asc("id") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + } + + // ============================================================= + // Distribution and ordering for copy-on-write DELETE operations + // ============================================================= + // + // UNPARTITIONED UNORDERED + // ------------------------------------------------------------------------- + // delete mode is NOT SET -> CLUSTER BY _file + LOCALLY ORDER BY _file, _pos + // delete mode is NONE -> unspecified distribution + empty ordering + // delete mode is HASH -> CLUSTER BY _file + LOCALLY ORDER BY _file, _pos + // delete mode is RANGE -> ORDER BY _file, _pos + // + // UNPARTITIONED ORDERED BY id, data + // ------------------------------------------------------------------------- + // delete mode is NOT SET -> CLUSTER BY _file + LOCALLY ORDER BY id, data + // delete mode is NONE -> unspecified distribution + LOCALLY ORDER BY id, data + // delete mode is HASH -> CLUSTER BY _file + LOCALLY ORDER BY id, data + // delete mode is RANGE -> ORDER BY id, data + // + // PARTITIONED BY date, days(ts) UNORDERED + // ------------------------------------------------------------------------- + // delete mode is NOT SET -> CLUSTER BY _file + LOCALLY ORDER BY date, days(ts), _file, _pos + // delete mode is NONE -> unspecified distribution + LOCALLY ORDERED BY date, days(ts) + // delete mode is HASH -> CLUSTER BY _file + LOCALLY ORDER BY date, days(ts), _file, _pos + // delete mode is RANGE -> ORDER BY date, days(ts), _file, _pos + // + // PARTITIONED BY date ORDERED BY id + // ------------------------------------------------------------------------- + // delete mode is NOT SET -> CLUSTER BY _file + LOCALLY ORDER BY date, id + // delete mode is NONE -> unspecified distribution + LOCALLY ORDERED BY date, id + // delete mode is HASH -> CLUSTER BY _file + LOCALLY ORDER BY date, id + // delete mode is RANGE -> ORDERED BY date, id + + @Test + public void testDefaultCopyOnWriteDeleteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, FILE_CLUSTERED_DISTRIBUTION, FILE_POSITION_ORDERING); + } + + @Test + public void testNoneCopyOnWriteDeleteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + 
checkCopyOnWriteDistributionAndOrdering(table, DELETE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @Test + public void testHashCopyOnWriteDeleteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, FILE_CLUSTERED_DISTRIBUTION, FILE_POSITION_ORDERING); + } + + @Test + public void testRangeCopyOnWriteDeleteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + Distribution expectedDistribution = Distributions.ordered(FILE_POSITION_ORDERING); + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, expectedDistribution, FILE_POSITION_ORDERING); + } + + @Test + public void testDefaultCopyOnWriteDeleteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder() + .asc("id") + .asc("data") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testNoneCopyOnWriteDeleteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + table.replaceSortOrder() + .asc("id") + .asc("data") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashCopyOnWriteDeleteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + table.replaceSortOrder() + .asc("id") + .asc("data") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testRangeCopyOnWriteDeleteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + table.replaceSortOrder() + .asc("id") + .asc("data") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("id"), 
SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, expectedDistribution, expectedOrdering); + } + + @Test + public void testDefaultCopyOnWriteDeletePartitionedUnsortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testNoneCopyOnWriteDeletePartitionedUnsortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashCopyOnWriteDeletePartitionedUnsortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testRangeCopyOnWriteDeletePartitionedUnsortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + 
checkCopyOnWriteDistributionAndOrdering(table, DELETE, expectedDistribution, expectedOrdering); + } + + @Test + public void testDefaultCopyOnWriteDeletePartitionedSortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder() + .desc("id") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testNoneCopyOnWriteDeletePartitionedSortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + table.replaceSortOrder() + .desc("id") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashCopyOnWriteDeletePartitionedSortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + table.replaceSortOrder() + .asc("id") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testRangeCopyOnWriteDeletePartitionedSortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + table.replaceSortOrder() + .asc("id") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, expectedDistribution, expectedOrdering); + } + + // ============================================================= + // Distribution and ordering for copy-on-write UPDATE operations + // ============================================================= + // + // UNPARTITIONED UNORDERED + // ------------------------------------------------------------------------- + // update mode is NOT SET -> CLUSTER BY _file + LOCALLY ORDER BY _file, _pos + // update mode is 
NONE -> unspecified distribution + empty ordering + // update mode is HASH -> CLUSTER BY _file + LOCALLY ORDER BY _file, _pos + // update mode is RANGE -> ORDER BY _file, _pos + // + // UNPARTITIONED ORDERED BY id, data + // ------------------------------------------------------------------------- + // update mode is NOT SET -> CLUSTER BY _file + LOCALLY ORDER BY id, data + // update mode is NONE -> unspecified distribution + LOCALLY ORDER BY id, data + // update mode is HASH -> CLUSTER BY _file + LOCALLY ORDER BY id, data + // update mode is RANGE -> ORDER BY id, data + // + // PARTITIONED BY date, days(ts) UNORDERED + // ------------------------------------------------------------------------- + // update mode is NOT SET -> CLUSTER BY _file + LOCALLY ORDER BY date, days(ts), _file, _pos + // update mode is NONE -> unspecified distribution + LOCALLY ORDERED BY date, days(ts) + // update mode is HASH -> CLUSTER BY _file + LOCALLY ORDER BY date, days(ts), _file, _pos + // update mode is RANGE -> ORDER BY date, days(ts), _file, _pos + // + // PARTITIONED BY date ORDERED BY id + // ------------------------------------------------------------------------- + // update mode is NOT SET -> CLUSTER BY _file + LOCALLY ORDER BY date, id + // update mode is NONE -> unspecified distribution + LOCALLY ORDERED BY date, id + // update mode is HASH -> CLUSTER BY _file + LOCALLY ORDER BY date, id + // update mode is RANGE -> ORDERED BY date, id + + @Test + public void testDefaultCopyOnWriteUpdateUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, FILE_CLUSTERED_DISTRIBUTION, FILE_POSITION_ORDERING); + } + + @Test + public void testNoneCopyOnWriteUpdateUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @Test + public void testHashCopyOnWriteUpdateUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, FILE_CLUSTERED_DISTRIBUTION, FILE_POSITION_ORDERING); + } + + @Test + public void testRangeCopyOnWriteUpdateUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + Distribution expectedDistribution = Distributions.ordered(FILE_POSITION_ORDERING); + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, expectedDistribution, FILE_POSITION_ORDERING); + } + + @Test + public void testDefaultCopyOnWriteUpdateUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder() + .asc("id") + .asc("data") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("id"), 
SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testNoneCopyOnWriteUpdateUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + table.replaceSortOrder() + .asc("id") + .asc("data") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashCopyOnWriteUpdateUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + table.replaceSortOrder() + .asc("id") + .asc("data") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testRangeCopyOnWriteUpdateUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + table.replaceSortOrder() + .asc("id") + .asc("data") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, expectedDistribution, expectedOrdering); + } + + @Test + public void testDefaultCopyOnWriteUpdatePartitionedUnsortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testNoneCopyOnWriteUpdatePartitionedUnsortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(UPDATE_DISTRIBUTION_MODE, 
WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashCopyOnWriteUpdatePartitionedUnsortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testRangeCopyOnWriteUpdatePartitionedUnsortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, expectedDistribution, expectedOrdering); + } + + @Test + public void testDefaultCopyOnWriteUpdatePartitionedSortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder() + .desc("id") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testNoneCopyOnWriteUpdatePartitionedSortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + table.replaceSortOrder() + .desc("id") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, 
UPDATE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashCopyOnWriteUpdatePartitionedSortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + table.replaceSortOrder() + .asc("id") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testRangeCopyOnWriteUpdatePartitionedSortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + table.replaceSortOrder() + .asc("id") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, expectedDistribution, expectedOrdering); + } + + // ============================================================= + // Distribution and ordering for copy-on-write MERGE operations + // ============================================================= + // + // UNPARTITIONED UNORDERED + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> use write distribution and ordering + // merge mode is NONE -> unspecified distribution + empty ordering + // merge mode is HASH -> unspecified distribution + empty ordering + // merge mode is RANGE -> unspecified distribution + empty ordering + // + // UNPARTITIONED ORDERED BY id, data + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> use write distribution and ordering + // merge mode is NONE -> unspecified distribution + LOCALLY ORDER BY id, data + // merge mode is HASH -> unspecified distribution + LOCALLY ORDER BY id, data + // merge mode is RANGE -> ORDER BY id, data + // + // PARTITIONED BY date, days(ts) UNORDERED + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> use write distribution and ordering + // merge mode is NONE -> unspecified distribution + LOCALLY ORDERED BY date, days(ts) + // merge mode is HASH -> CLUSTER BY date, days(ts) + LOCALLY ORDER BY date, days(ts) + // merge mode is RANGE -> ORDER BY date, days(ts) + // + // PARTITIONED BY date ORDERED BY id + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> use write distribution and ordering + // merge mode is NONE -> unspecified distribution + LOCALLY ORDERED BY date, id + // merge mode is HASH -> CLUSTER BY date + LOCALLY ORDER BY date, id + // merge mode is RANGE -> ORDERED BY date, id + + @Test + public void 
testNoneCopyOnWriteMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @Test + public void testHashCopyOnWriteMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @Test + public void testRangeCopyOnWriteMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @Test + public void testNoneCopyOnWriteMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + table.replaceSortOrder() + .asc("id") + .asc("data") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashCopyOnWriteMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + table.replaceSortOrder() + .asc("id") + .asc("data") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testRangeCopyOnWriteMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + table.replaceSortOrder() + .asc("id") + .asc("data") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testNoneCopyOnWriteMergePartitionedUnsortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) 
" + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashCopyOnWriteMergePartitionedUnsortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + Expression[] expectedClustering = new Expression[]{ + Expressions.identity("date"), + Expressions.days("ts") + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testRangeCopyOnWriteMergePartitionedUnsortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testNoneCopyOnWriteMergePartitionedSortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + table.replaceSortOrder() + .desc("id") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashCopyOnWriteMergePartitionedSortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + table.replaceSortOrder() + .asc("id") + .commit(); + + Expression[] expectedClustering = new Expression[]{ + Expressions.identity("date"), + Expressions.bucket(8, "data") + }; + Distribution expectedDistribution = 
Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testRangeCopyOnWriteMergePartitionedSortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + table.replaceSortOrder() + .asc("id") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + // =================================================================================== + // Distribution and ordering for merge-on-read DELETE operations with position deletes + // =================================================================================== + // + // delete mode is NOT SET -> CLUSTER BY _spec_id, _partition + LOCALLY ORDER BY _spec_id, _partition, _file, _pos + // delete mode is NONE -> unspecified distribution + LOCALLY ORDER BY _spec_id, _partition, _file, _pos + // delete mode is HASH -> CLUSTER BY _spec_id, _partition + LOCALLY ORDER BY _spec_id, _partition, _file, _pos + // delete mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, _file + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos + + @Test + public void testDefaultPositionDeltaDeleteUnpartitionedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testNonePositionDeltaDeleteUnpartitionedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, UNSPECIFIED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testHashPositionDeltaDeleteUnpartitionedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testRangePositionDeltaDeleteUnpartitionedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(DELETE_DISTRIBUTION_MODE, 
WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + Distribution expectedDistribution = Distributions.ordered(SPEC_ID_PARTITION_FILE_ORDERING); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, expectedDistribution, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testDefaultPositionDeltaDeletePartitionedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testNonePositionDeltaDeletePartitionedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, UNSPECIFIED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testHashPositionDeltaDeletePartitionedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testRangePositionDeltaDeletePartitionedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + Distribution expectedDistribution = Distributions.ordered(SPEC_ID_PARTITION_FILE_ORDERING); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, expectedDistribution, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + // =================================================================================== + // Distribution and ordering for merge-on-read UPDATE operations with position deletes + // =================================================================================== + // + // update mode is NOT SET -> CLUSTER BY _spec_id, _partition + LOCALLY ORDER BY _spec_id, _partition, _file, _pos + // update mode is NONE -> unspecified distribution + LOCALLY ORDER BY _spec_id, _partition, _file, _pos + // update mode is HASH -> CLUSTER BY _spec_id, _partition + LOCALLY ORDER BY _spec_id, _partition, _file, _pos + // update mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, _file + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos + + @Test + public void testDefaultPositionDeltaUpdateUnpartitionedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void 
testNonePositionDeltaUpdateUnpartitionedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, UNSPECIFIED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testHashPositionDeltaUpdateUnpartitionedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testRangePositionDeltaUpdateUnpartitionedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + Distribution expectedDistribution = Distributions.ordered(SPEC_ID_PARTITION_FILE_ORDERING); + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, expectedDistribution, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testDefaultPositionDeltaUpdatePartitionedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testNonePositionDeltaUpdatePartitionedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, UNSPECIFIED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testHashPositionDeltaUpdatePartitionedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testRangePositionDeltaUpdatePartitionedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + Distribution expectedDistribution = Distributions.ordered(SPEC_ID_PARTITION_FILE_ORDERING); + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, expectedDistribution, 
SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + // ================================================================================== + // Distribution and ordering for merge-on-read MERGE operations with position deletes + // ================================================================================== + // + // IMPORTANT: metadata columns like _spec_id and _partition are NULL for new rows + // + // UNPARTITIONED UNORDERED + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> use write mode + // merge mode is NONE -> unspecified distribution + LOCALLY ORDER BY _spec_id, _partition, _file, _pos + // merge mode is HASH -> CLUSTER BY _spec_id, _partition, _file + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos + // merge mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, _file + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos + // + // UNPARTITIONED ORDERED BY id, data + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> use write mode + // merge mode is NONE -> unspecified distribution + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, id, data + // merge mode is HASH -> CLUSTER BY _spec_id, _partition, _file + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, id, data + // merge mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, _file, id, data + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, id, data + // + // PARTITIONED BY date, days(ts) UNORDERED + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> use write mode + // merge mode is NONE -> unspecified distribution + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, date, days(ts) + // merge mode is HASH -> CLUSTER BY _spec_id, _partition, date, days(ts) + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, date, days(ts) + // merge mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, _file, date, days(ts) + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, date, days(ts) + // + // PARTITIONED BY date ORDERED BY id + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> use write mode + // merge mode is NONE -> unspecified distribution + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, date, id + // merge mode is HASH -> CLUSTER BY _spec_id, _partition, date + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, date, id + // merge mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, _file, date, id + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, date, id + + @Test + public void testNonePositionDeltaMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + checkPositionDeltaDistributionAndOrdering( + table, MERGE, UNSPECIFIED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testHashPositionDeltaMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + Expression[] expectedClustering = new Expression[]{ + 
Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.column(MetadataColumns.FILE_PATH.name()) + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + checkPositionDeltaDistributionAndOrdering( + table, MERGE, expectedDistribution, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testRangePositionDeltaMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + SortOrder[] expectedDistributionOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING) + }; + Distribution expectedDistribution = Distributions.ordered(expectedDistributionOrdering); + + checkPositionDeltaDistributionAndOrdering( + table, MERGE, expectedDistribution, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testNonePositionDeltaMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + table.replaceSortOrder() + .asc("id") + .asc("data") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashPositionDeltaMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + table.replaceSortOrder() + .asc("id") + .asc("data") + .commit(); + + Expression[] expectedClustering = new Expression[]{ + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.column(MetadataColumns.FILE_PATH.name()) + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.ROW_POSITION.name()), 
SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testRangePositionDeltaMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + table.replaceSortOrder() + .asc("id") + .asc("data") + .commit(); + + SortOrder[] expectedDistributionOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + Distribution expectedDistribution = Distributions.ordered(expectedDistributionOrdering); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testNonePositionDeltaMergePartitionedUnsortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashPositionDeltaMergePartitionedUnsortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + Expression[] expectedClustering = new Expression[]{ + 
Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.identity("date"), + Expressions.days("ts") + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testRangePositionDeltaMergePartitionedUnsortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + SortOrder[] expectedDistributionOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + Distribution expectedDistribution = Distributions.ordered(expectedDistributionOrdering); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testNonePositionDeltaMergePartitionedSortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + table.replaceSortOrder() + .desc("id") + .commit(); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + 
Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashPositionDeltaMergePartitionedSortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + table.replaceSortOrder() + .asc("id") + .commit(); + + Expression[] expectedClustering = new Expression[]{ + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.identity("date"), + Expressions.bucket(8, "data") + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testRangePositionDeltaMergePartitionedSortedTable() { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + table.replaceSortOrder() + .asc("id") + .commit(); + + SortOrder[] expectedDistributionOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + Distribution expectedDistribution = Distributions.ordered(expectedDistributionOrdering); + + SortOrder[] expectedOrdering = new SortOrder[]{ + Expressions.sort(Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } 
+ + private void checkWriteDistributionAndOrdering(Table table, Distribution expectedDistribution, + SortOrder[] expectedOrdering) { + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + DistributionMode distributionMode = writeConf.distributionMode(); + Distribution distribution = SparkDistributionAndOrderingUtil.buildRequiredDistribution(table, distributionMode); + Assert.assertEquals("Distribution must match", expectedDistribution, distribution); + + SortOrder[] ordering = SparkDistributionAndOrderingUtil.buildRequiredOrdering(table, distribution); + Assert.assertArrayEquals("Ordering must match", expectedOrdering, ordering); + } + + private void checkCopyOnWriteDistributionAndOrdering(Table table, Command command, + Distribution expectedDistribution, + SortOrder[] expectedOrdering) { + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + + DistributionMode mode = copyOnWriteDistributionMode(command, writeConf); + + Distribution distribution = SparkDistributionAndOrderingUtil.buildCopyOnWriteDistribution(table, command, mode); + Assert.assertEquals("Distribution must match", expectedDistribution, distribution); + + SortOrder[] ordering = SparkDistributionAndOrderingUtil.buildCopyOnWriteOrdering(table, command, distribution); + Assert.assertArrayEquals("Ordering must match", expectedOrdering, ordering); + } + + private DistributionMode copyOnWriteDistributionMode(Command command, SparkWriteConf writeConf) { + switch (command) { + case DELETE: + return writeConf.deleteDistributionMode(); + case UPDATE: + return writeConf.updateDistributionMode(); + case MERGE: + return writeConf.copyOnWriteMergeDistributionMode(); + default: + throw new IllegalArgumentException("Unexpected command: " + command); + } + } + + private void checkPositionDeltaDistributionAndOrdering(Table table, Command command, + Distribution expectedDistribution, + SortOrder[] expectedOrdering) { + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + + DistributionMode mode = positionDeltaDistributionMode(command, writeConf); + + Distribution distribution = SparkDistributionAndOrderingUtil.buildPositionDeltaDistribution(table, command, mode); + Assert.assertEquals("Distribution must match", expectedDistribution, distribution); + + SortOrder[] ordering = SparkDistributionAndOrderingUtil.buildPositionDeltaOrdering(table, command); + Assert.assertArrayEquals("Ordering must match", expectedOrdering, ordering); + } + + private DistributionMode positionDeltaDistributionMode(Command command, SparkWriteConf writeConf) { + switch (command) { + case DELETE: + return writeConf.deleteDistributionMode(); + case UPDATE: + return writeConf.updateDistributionMode(); + case MERGE: + return writeConf.positionDeltaMergeDistributionMode(); + default: + throw new IllegalArgumentException("Unexpected command: " + command); + } + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkFilters.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkFilters.java new file mode 100644 index 000000000000..ccd526a54618 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkFilters.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +import java.sql.Date; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.time.temporal.ChronoUnit; +import java.util.Map; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.spark.sql.sources.And; +import org.apache.spark.sql.sources.EqualNullSafe; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.GreaterThanOrEqual; +import org.apache.spark.sql.sources.In; +import org.apache.spark.sql.sources.IsNotNull; +import org.apache.spark.sql.sources.IsNull; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.LessThanOrEqual; +import org.apache.spark.sql.sources.Not; +import org.junit.Assert; +import org.junit.Test; + +public class TestSparkFilters { + + @Test + public void testQuotedAttributes() { + Map attrMap = Maps.newHashMap(); + attrMap.put("id", "id"); + attrMap.put("`i.d`", "i.d"); + attrMap.put("`i``d`", "i`d"); + attrMap.put("`d`.b.`dd```", "d.b.dd`"); + attrMap.put("a.`aa```.c", "a.aa`.c"); + + attrMap.forEach((quoted, unquoted) -> { + IsNull isNull = IsNull.apply(quoted); + Expression expectedIsNull = Expressions.isNull(unquoted); + Expression actualIsNull = SparkFilters.convert(isNull); + Assert.assertEquals("IsNull must match", expectedIsNull.toString(), actualIsNull.toString()); + + IsNotNull isNotNull = IsNotNull.apply(quoted); + Expression expectedIsNotNull = Expressions.notNull(unquoted); + Expression actualIsNotNull = SparkFilters.convert(isNotNull); + Assert.assertEquals("IsNotNull must match", expectedIsNotNull.toString(), actualIsNotNull.toString()); + + LessThan lt = LessThan.apply(quoted, 1); + Expression expectedLt = Expressions.lessThan(unquoted, 1); + Expression actualLt = SparkFilters.convert(lt); + Assert.assertEquals("LessThan must match", expectedLt.toString(), actualLt.toString()); + + LessThanOrEqual ltEq = LessThanOrEqual.apply(quoted, 1); + Expression expectedLtEq = Expressions.lessThanOrEqual(unquoted, 1); + Expression actualLtEq = SparkFilters.convert(ltEq); + Assert.assertEquals("LessThanOrEqual must match", expectedLtEq.toString(), actualLtEq.toString()); + + GreaterThan gt = GreaterThan.apply(quoted, 1); + Expression expectedGt = Expressions.greaterThan(unquoted, 1); + Expression actualGt = SparkFilters.convert(gt); + Assert.assertEquals("GreaterThan must match", expectedGt.toString(), actualGt.toString()); + + GreaterThanOrEqual gtEq = GreaterThanOrEqual.apply(quoted, 1); + Expression expectedGtEq = Expressions.greaterThanOrEqual(unquoted, 1); + Expression actualGtEq = SparkFilters.convert(gtEq); + Assert.assertEquals("GreaterThanOrEqual must match", expectedGtEq.toString(), actualGtEq.toString()); + + EqualTo eq = EqualTo.apply(quoted, 1); 
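+        // Both EqualTo and EqualNullSafe with a non-null literal below are expected to convert to Expressions.equal.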
+ Expression expectedEq = Expressions.equal(unquoted, 1); + Expression actualEq = SparkFilters.convert(eq); + Assert.assertEquals("EqualTo must match", expectedEq.toString(), actualEq.toString()); + + EqualNullSafe eqNullSafe = EqualNullSafe.apply(quoted, 1); + Expression expectedEqNullSafe = Expressions.equal(unquoted, 1); + Expression actualEqNullSafe = SparkFilters.convert(eqNullSafe); + Assert.assertEquals("EqualNullSafe must match", expectedEqNullSafe.toString(), actualEqNullSafe.toString()); + + In in = In.apply(quoted, new Integer[]{1}); + Expression expectedIn = Expressions.in(unquoted, 1); + Expression actualIn = SparkFilters.convert(in); + Assert.assertEquals("In must match", expectedIn.toString(), actualIn.toString()); + }); + } + + @Test + public void testTimestampFilterConversion() { + Instant instant = Instant.parse("2018-10-18T00:00:57.907Z"); + Timestamp timestamp = Timestamp.from(instant); + long epochMicros = ChronoUnit.MICROS.between(Instant.EPOCH, instant); + + Expression instantExpression = SparkFilters.convert(GreaterThan.apply("x", instant)); + Expression timestampExpression = SparkFilters.convert(GreaterThan.apply("x", timestamp)); + Expression rawExpression = Expressions.greaterThan("x", epochMicros); + + Assert.assertEquals("Generated Timestamp expression should be correct", + rawExpression.toString(), timestampExpression.toString()); + Assert.assertEquals("Generated Instant expression should be correct", + rawExpression.toString(), instantExpression.toString()); + } + + @Test + public void testDateFilterConversion() { + LocalDate localDate = LocalDate.parse("2018-10-18"); + Date date = Date.valueOf(localDate); + long epochDay = localDate.toEpochDay(); + + Expression localDateExpression = SparkFilters.convert(GreaterThan.apply("x", localDate)); + Expression dateExpression = SparkFilters.convert(GreaterThan.apply("x", date)); + Expression rawExpression = Expressions.greaterThan("x", epochDay); + + Assert.assertEquals("Generated localdate expression should be correct", + rawExpression.toString(), localDateExpression.toString()); + + Assert.assertEquals("Generated date expression should be correct", + rawExpression.toString(), dateExpression.toString()); + } + + @Test + public void testNestedInInsideNot() { + Not filter = Not.apply(And.apply(EqualTo.apply("col1", 1), In.apply("col2", new Integer[]{1, 2}))); + Expression converted = SparkFilters.convert(filter); + Assert.assertNull("Expression should not be converted", converted); + } + + @Test + public void testNotIn() { + Not filter = Not.apply(In.apply("col", new Integer[]{1, 2})); + Expression actual = SparkFilters.convert(filter); + Expression expected = Expressions.and(Expressions.notNull("col"), Expressions.notIn("col", 1, 2)); + Assert.assertEquals("Expressions should match", expected.toString(), actual.toString()); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkSchemaUtil.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkSchemaUtil.java new file mode 100644 index 000000000000..38fe5734fe5f --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkSchemaUtil.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark;
+
+import java.io.IOException;
+import java.util.List;
+import org.apache.iceberg.MetadataColumns;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.types.Types;
+import org.apache.spark.sql.catalyst.expressions.AttributeReference;
+import org.apache.spark.sql.catalyst.expressions.MetadataAttribute;
+import org.apache.spark.sql.types.StructType;
+import org.junit.Assert;
+import org.junit.Test;
+
+import static org.apache.iceberg.types.Types.NestedField.optional;
+
+public class TestSparkSchemaUtil {
+  private static final Schema TEST_SCHEMA = new Schema(
+      optional(1, "id", Types.IntegerType.get()),
+      optional(2, "data", Types.StringType.get())
+  );
+
+  private static final Schema TEST_SCHEMA_WITH_METADATA_COLS = new Schema(
+      optional(1, "id", Types.IntegerType.get()),
+      optional(2, "data", Types.StringType.get()),
+      MetadataColumns.FILE_PATH,
+      MetadataColumns.ROW_POSITION
+  );
+
+  @Test
+  public void testEstimateSizeMaxValue() throws IOException {
+    Assert.assertEquals("estimateSize returns Long max value", Long.MAX_VALUE,
+        SparkSchemaUtil.estimateSize(
+            null,
+            Long.MAX_VALUE));
+  }
+
+  @Test
+  public void testEstimateSizeWithOverflow() throws IOException {
+    long tableSize = SparkSchemaUtil.estimateSize(SparkSchemaUtil.convert(TEST_SCHEMA), Long.MAX_VALUE - 1);
+    Assert.assertEquals("estimateSize handles overflow", Long.MAX_VALUE, tableSize);
+  }
+
+  @Test
+  public void testEstimateSize() throws IOException {
+    long tableSize = SparkSchemaUtil.estimateSize(SparkSchemaUtil.convert(TEST_SCHEMA), 1);
+    Assert.assertEquals("estimateSize matches with expected approximation", 24, tableSize);
+  }
+
+  @Test
+  public void testSchemaConversionWithMetaDataColumnSchema() {
+    StructType structType = SparkSchemaUtil.convert(TEST_SCHEMA_WITH_METADATA_COLS);
+    List<AttributeReference> attrRefs = scala.collection.JavaConverters.seqAsJavaList(structType.toAttributes());
+    for (AttributeReference attrRef : attrRefs) {
+      if (MetadataColumns.isMetadataColumn(attrRef.name())) {
+        Assert.assertTrue("metadata columns should have __metadata_col in attribute metadata",
+            MetadataAttribute.unapply(attrRef).isDefined());
+      } else {
+        Assert.assertFalse("non metadata columns should not have __metadata_col in attribute metadata",
+            MetadataAttribute.unapply(attrRef).isDefined());
+      }
+    }
+  }
+}
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkSessionCatalog.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkSessionCatalog.java
new file mode 100644
index 000000000000..3ab2d6b23a6f
--- /dev/null
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkSessionCatalog.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS;
+
+public class TestSparkSessionCatalog extends SparkTestBase {
+  @Test
+  public void testValidateHmsUri() {
+    String envHmsUriKey = "spark.hadoop." + METASTOREURIS.varname;
+    String catalogHmsUriKey = "spark.sql.catalog.spark_catalog.uri";
+    String hmsUri = hiveConf.get(METASTOREURIS.varname);
+
+    spark.conf().set("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog");
+    spark.conf().set("spark.sql.catalog.spark_catalog.type", "hive");
+
+    // HMS uris match
+    spark.sessionState().catalogManager().reset();
+    spark.conf().set(envHmsUriKey, hmsUri);
+    spark.conf().set(catalogHmsUriKey, hmsUri);
+    Assert.assertTrue(spark.sessionState().catalogManager().v2SessionCatalog().defaultNamespace()[0].equals("default"));
+
+    // HMS uris don't match
+    spark.sessionState().catalogManager().reset();
+    String catalogHmsUri = "RandomString";
+    spark.conf().set(envHmsUriKey, hmsUri);
+    spark.conf().set(catalogHmsUriKey, catalogHmsUri);
+    IllegalArgumentException exception = Assert.assertThrows(IllegalArgumentException.class,
+        () -> spark.sessionState().catalogManager().v2SessionCatalog());
+    String errorMessage = String.format("Inconsistent Hive metastore URIs: %s (Spark session) != %s (spark_catalog)",
+        hmsUri, catalogHmsUri);
+    Assert.assertEquals(errorMessage, exception.getMessage());
+
+    // no env HMS uri, only catalog HMS uri
+    spark.sessionState().catalogManager().reset();
+    spark.conf().set(catalogHmsUriKey, hmsUri);
+    spark.conf().unset(envHmsUriKey);
+    Assert.assertTrue(spark.sessionState().catalogManager().v2SessionCatalog().defaultNamespace()[0].equals("default"));
+
+    // no catalog HMS uri, only env HMS uri
+    spark.sessionState().catalogManager().reset();
+    spark.conf().set(envHmsUriKey, hmsUri);
+    spark.conf().unset(catalogHmsUriKey);
+    Assert.assertTrue(spark.sessionState().catalogManager().v2SessionCatalog().defaultNamespace()[0].equals("default"));
+  }
+}
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkTableUtil.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkTableUtil.java
new file mode 100644
index 000000000000..61af0185d8bd
--- /dev/null
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkTableUtil.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark;
+
+import java.io.IOException;
+import java.util.Map;
+import org.apache.iceberg.KryoHelpers;
+import org.apache.iceberg.MetricsConfig;
+import org.apache.iceberg.MetricsModes;
+import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.TestHelpers;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.spark.SparkTableUtil.SparkPartition;
+import org.assertj.core.api.Assertions;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestSparkTableUtil {
+  @Test
+  public void testSparkPartitionOKryoSerialization() throws IOException {
+    Map<String, String> values = ImmutableMap.of("id", "2");
+    String uri = "s3://bucket/table/data/id=2";
+    String format = "parquet";
+    SparkPartition sparkPartition = new SparkPartition(values, uri, format);
+
+    SparkPartition deserialized = KryoHelpers.roundTripSerialize(sparkPartition);
+    Assertions.assertThat(sparkPartition).isEqualTo(deserialized);
+  }
+
+  @Test
+  public void testSparkPartitionJavaSerialization() throws IOException, ClassNotFoundException {
+    Map<String, String> values = ImmutableMap.of("id", "2");
+    String uri = "s3://bucket/table/data/id=2";
+    String format = "parquet";
+    SparkPartition sparkPartition = new SparkPartition(values, uri, format);
+
+    SparkPartition deserialized = TestHelpers.roundTripSerialize(sparkPartition);
+    Assertions.assertThat(sparkPartition).isEqualTo(deserialized);
+  }
+
+  @Test
+  public void testMetricsConfigKryoSerialization() throws Exception {
+    Map<String, String> metricsConfig = ImmutableMap.of(
+        TableProperties.DEFAULT_WRITE_METRICS_MODE, "counts",
+        TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX + "col1", "full",
+        TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX + "col2", "truncate(16)");
+
+    MetricsConfig config = MetricsConfig.fromProperties(metricsConfig);
+    MetricsConfig deserialized = KryoHelpers.roundTripSerialize(config);
+
+    Assert.assertEquals(MetricsModes.Full.get().toString(), deserialized.columnMode("col1").toString());
+    Assert.assertEquals(MetricsModes.Truncate.withLength(16).toString(), deserialized.columnMode("col2").toString());
+    Assert.assertEquals(MetricsModes.Counts.get().toString(), deserialized.columnMode("col3").toString());
+  }
+
+  @Test
+  public void testMetricsConfigJavaSerialization() throws Exception {
+    Map<String, String> metricsConfig = ImmutableMap.of(
+        TableProperties.DEFAULT_WRITE_METRICS_MODE, "counts",
+        TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX + "col1", "full",
+        TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX + "col2", "truncate(16)");
+
+    MetricsConfig config = MetricsConfig.fromProperties(metricsConfig);
+    MetricsConfig deserialized = TestHelpers.roundTripSerialize(config);
+
+    Assert.assertEquals(MetricsModes.Full.get().toString(), deserialized.columnMode("col1").toString());
+    Assert.assertEquals(MetricsModes.Truncate.withLength(16).toString(), deserialized.columnMode("col2").toString());
+    Assert.assertEquals(MetricsModes.Counts.get().toString(), deserialized.columnMode("col3").toString());
+  }
+}
diff --git
a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkValueConverter.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkValueConverter.java new file mode 100644 index 000000000000..57941b8c7940 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestSparkValueConverter.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark; + +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.junit.Assert; +import org.junit.Test; + +public class TestSparkValueConverter { + @Test + public void testSparkNullMapConvert() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(5, "locations", Types.MapType.ofOptional(6, 7, + Types.StringType.get(), + Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()), + Types.NestedField.required(2, "long", Types.FloatType.get()) + ) + )) + ); + + assertCorrectNullConversion(schema); + } + + @Test + public void testSparkNullListConvert() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(5, "locations", + Types.ListType.ofOptional(6, Types.StringType.get()) + ) + ); + + assertCorrectNullConversion(schema); + } + + @Test + public void testSparkNullStructConvert() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(5, "location", Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()), + Types.NestedField.required(2, "long", Types.FloatType.get()) + )) + ); + + assertCorrectNullConversion(schema); + } + + @Test + public void testSparkNullPrimitiveConvert() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(5, "location", Types.StringType.get()) + ); + assertCorrectNullConversion(schema); + } + + private void assertCorrectNullConversion(Schema schema) { + Row sparkRow = RowFactory.create(1, null); + Record record = GenericRecord.create(schema); + record.set(0, 1); + Assert.assertEquals("Round-trip conversion should produce original value", + record, + SparkValueConverter.convert(schema, sparkRow)); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestCreateActions.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestCreateActions.java new file mode 100644 index 000000000000..77ff233e25b6 --- /dev/null +++ 
b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestCreateActions.java @@ -0,0 +1,899 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.actions; + +import java.io.File; +import java.io.FilenameFilter; +import java.io.IOException; +import java.net.URI; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.commons.io.FileUtils; +import org.apache.commons.io.filefilter.TrueFileFilter; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.MigrateTable; +import org.apache.iceberg.actions.SnapshotTable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.iceberg.types.Types; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.util.HadoopInputFile; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.MessageTypeParser; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.catalog.CatalogTable; +import org.apache.spark.sql.catalyst.catalog.CatalogTablePartition; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import 
org.junit.runners.Parameterized; +import scala.Option; +import scala.Some; +import scala.collection.JavaConverters; +import scala.collection.Seq; + +public class TestCreateActions extends SparkCatalogTestBase { + private static final String CREATE_PARTITIONED_PARQUET = "CREATE TABLE %s (id INT, data STRING) " + + "using parquet PARTITIONED BY (id) LOCATION '%s'"; + private static final String CREATE_PARQUET = "CREATE TABLE %s (id INT, data STRING) " + + "using parquet LOCATION '%s'"; + private static final String CREATE_HIVE_EXTERNAL_PARQUET = "CREATE EXTERNAL TABLE %s (data STRING) " + + "PARTITIONED BY (id INT) STORED AS parquet LOCATION '%s'"; + private static final String CREATE_HIVE_PARQUET = "CREATE TABLE %s (data STRING) " + + "PARTITIONED BY (id INT) STORED AS parquet"; + + private static final String NAMESPACE = "default"; + + @Parameterized.Parameters(name = "Catalog Name {0} - Options {2}") + public static Object[][] parameters() { + return new Object[][] { + new Object[] {"spark_catalog", SparkSessionCatalog.class.getName(), ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "parquet-enabled", "true", + "cache-enabled", "false" // Spark will delete tables using v1, leaving the cache out of sync + )}, + new Object[] {"spark_catalog", SparkSessionCatalog.class.getName(), ImmutableMap.of( + "type", "hadoop", + "default-namespace", "default", + "parquet-enabled", "true", + "cache-enabled", "false" // Spark will delete tables using v1, leaving the cache out of sync + )}, + new Object[] { "testhive", SparkCatalog.class.getName(), ImmutableMap.of( + "type", "hive", + "default-namespace", "default" + )}, + new Object[] { "testhadoop", SparkCatalog.class.getName(), ImmutableMap.of( + "type", "hadoop", + "default-namespace", "default" + )} + }; + } + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private String baseTableName = "baseTable"; + private File tableDir; + private String tableLocation; + private final String type; + private final TableCatalog catalog; + + public TestCreateActions( + String catalogName, + String implementation, + Map config) { + super(catalogName, implementation, config); + this.catalog = (TableCatalog) spark.sessionState().catalogManager().catalog(catalogName); + this.type = config.get("type"); + } + + @Before + public void before() { + try { + this.tableDir = temp.newFolder(); + } catch (IOException e) { + throw new RuntimeException(e); + } + this.tableLocation = tableDir.toURI().toString(); + + spark.conf().set("hive.exec.dynamic.partition", "true"); + spark.conf().set("hive.exec.dynamic.partition.mode", "nonstrict"); + spark.conf().set("spark.sql.parquet.writeLegacyFormat", false); + spark.conf().set("spark.sql.parquet.writeLegacyFormat", false); + spark.sql(String.format("DROP TABLE IF EXISTS %s", baseTableName)); + + List expected = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + df.select("id", "data").orderBy("data").write() + .mode("append") + .option("path", tableLocation) + .saveAsTable(baseTableName); + } + + @After + public void after() throws IOException { + // Drop the hive table. 
+ spark.sql(String.format("DROP TABLE IF EXISTS %s", baseTableName)); + } + + @Test + public void testMigratePartitioned() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + Assume.assumeTrue("Can only migrate from Spark Session Catalog", catalog.name().equals("spark_catalog")); + String source = sourceName("test_migrate_partitioned_table"); + String dest = source; + createSourceTable(CREATE_PARTITIONED_PARQUET, source); + assertMigratedFileCount(SparkActions.get().migrateTable(source), source, dest); + } + + @Test + public void testPartitionedTableWithUnRecoveredPartitions() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + Assume.assumeTrue("Can only migrate from Spark Session Catalog", catalog.name().equals("spark_catalog")); + String source = sourceName("test_unrecovered_partitions"); + String dest = source; + File location = temp.newFolder(); + sql(CREATE_PARTITIONED_PARQUET, source, location); + + // Data generation and partition addition + spark.range(5) + .selectExpr("id", "cast(id as STRING) as data") + .write() + .partitionBy("id").mode(SaveMode.Overwrite) + .parquet(location.toURI().toString()); + sql("ALTER TABLE %s ADD PARTITION(id=0)", source); + + assertMigratedFileCount(SparkActions.get().migrateTable(source), source, dest); + } + + @Test + public void testPartitionedTableWithCustomPartitions() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + Assume.assumeTrue("Can only migrate from Spark Session Catalog", catalog.name().equals("spark_catalog")); + String source = sourceName("test_custom_parts"); + String dest = source; + File tblLocation = temp.newFolder(); + File partitionDataLoc = temp.newFolder(); + + // Data generation and partition addition + spark.sql(String.format(CREATE_PARTITIONED_PARQUET, source, tblLocation)); + spark.range(10) + .selectExpr("cast(id as STRING) as data") + .write() + .mode(SaveMode.Overwrite).parquet(partitionDataLoc.toURI().toString()); + sql("ALTER TABLE %s ADD PARTITION(id=0) LOCATION '%s'", source, partitionDataLoc.toURI().toString()); + assertMigratedFileCount(SparkActions.get().migrateTable(source), source, dest); + } + + @Test + public void testAddColumnOnMigratedTableAtEnd() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + Assume.assumeTrue("Can only migrate from Spark Session Catalog", catalog.name().equals("spark_catalog")); + String source = sourceName("test_add_column_migrated_table"); + String dest = source; + createSourceTable(CREATE_PARQUET, source); + List expected1 = sql("select *, null from %s order by id", source); + List expected2 = sql("select *, null, null from %s order by id", source); + + // migrate table + SparkActions.get().migrateTable(source).execute(); + SparkTable sparkTable = loadTable(dest); + Table table = sparkTable.table(); + + // test column addition on migrated table + Schema beforeSchema = table.schema(); + String newCol1 = "newCol1"; + sparkTable.table().updateSchema().addColumn(newCol1, Types.IntegerType.get()).commit(); + Schema afterSchema = table.schema(); + Assert.assertNull(beforeSchema.findField(newCol1)); + Assert.assertNotNull(afterSchema.findField(newCol1)); + + // reads should succeed without any exceptions + List results1 = sql("select * from %s order by id", dest); + Assert.assertTrue(results1.size() > 0); + assertEquals("Output must match", results1, 
expected1); + + String newCol2 = "newCol2"; + sql("ALTER TABLE %s ADD COLUMN %s INT", dest, newCol2); + StructType schema = spark.table(dest).schema(); + Assert.assertTrue(Arrays.asList(schema.fieldNames()).contains(newCol2)); + + // reads should succeed without any exceptions + List results2 = sql("select * from %s order by id", dest); + Assert.assertTrue(results2.size() > 0); + assertEquals("Output must match", results2, expected2); + } + + @Test + public void testAddColumnOnMigratedTableAtMiddle() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + Assume.assumeTrue("Can only migrate from Spark Session Catalog", catalog.name().equals("spark_catalog")); + String source = sourceName("test_add_column_migrated_table_middle"); + String dest = source; + createSourceTable(CREATE_PARQUET, source); + + // migrate table + SparkActions.get().migrateTable(source).execute(); + SparkTable sparkTable = loadTable(dest); + Table table = sparkTable.table(); + List expected = sql("select id, null, data from %s order by id", source); + + // test column addition on migrated table + Schema beforeSchema = table.schema(); + String newCol1 = "newCol"; + sparkTable.table().updateSchema().addColumn("newCol", Types.IntegerType.get()) + .moveAfter(newCol1, "id") + .commit(); + Schema afterSchema = table.schema(); + Assert.assertNull(beforeSchema.findField(newCol1)); + Assert.assertNotNull(afterSchema.findField(newCol1)); + + // reads should succeed + List results = sql("select * from %s order by id", dest); + Assert.assertTrue(results.size() > 0); + assertEquals("Output must match", results, expected); + } + + @Test + public void removeColumnsAtEnd() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + Assume.assumeTrue("Can only migrate from Spark Session Catalog", catalog.name().equals("spark_catalog")); + String source = sourceName("test_remove_column_migrated_table"); + String dest = source; + + String colName1 = "newCol1"; + String colName2 = "newCol2"; + File location = temp.newFolder(); + spark.range(10).selectExpr("cast(id as INT)", "CAST(id as INT) " + colName1, "CAST(id as INT) " + colName2) + .write() + .mode(SaveMode.Overwrite).saveAsTable(dest); + List expected1 = sql("select id, %s from %s order by id", colName1, source); + List expected2 = sql("select id from %s order by id", source); + + // migrate table + SparkActions.get().migrateTable(source).execute(); + SparkTable sparkTable = loadTable(dest); + Table table = sparkTable.table(); + + // test column removal on migrated table + Schema beforeSchema = table.schema(); + sparkTable.table().updateSchema().deleteColumn(colName1).commit(); + Schema afterSchema = table.schema(); + Assert.assertNotNull(beforeSchema.findField(colName1)); + Assert.assertNull(afterSchema.findField(colName1)); + + // reads should succeed without any exceptions + List results1 = sql("select * from %s order by id", dest); + Assert.assertTrue(results1.size() > 0); + assertEquals("Output must match", expected1, results1); + + sql("ALTER TABLE %s DROP COLUMN %s", dest, colName2); + StructType schema = spark.table(dest).schema(); + Assert.assertFalse(Arrays.asList(schema.fieldNames()).contains(colName2)); + + // reads should succeed without any exceptions + List results2 = sql("select * from %s order by id", dest); + Assert.assertTrue(results2.size() > 0); + assertEquals("Output must match", expected2, results2); + } + + @Test + public void removeColumnFromMiddle() 
throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + Assume.assumeTrue("Can only migrate from Spark Session Catalog", catalog.name().equals("spark_catalog")); + String source = sourceName("test_remove_column_migrated_table_from_middle"); + String dest = source; + String dropColumnName = "col1"; + + spark.range(10).selectExpr("cast(id as INT)", "CAST(id as INT) as " + + dropColumnName, "CAST(id as INT) as col2").write().mode(SaveMode.Overwrite).saveAsTable(dest); + List expected = sql("select id, col2 from %s order by id", source); + + // migrate table + SparkActions.get().migrateTable(source).execute(); + + // drop column + sql("ALTER TABLE %s DROP COLUMN %s", dest, "col1"); + StructType schema = spark.table(dest).schema(); + Assert.assertFalse(Arrays.asList(schema.fieldNames()).contains(dropColumnName)); + + // reads should return same output as that of non-iceberg table + List results = sql("select * from %s order by id", dest); + Assert.assertTrue(results.size() > 0); + assertEquals("Output must match", expected, results); + } + + @Test + public void testMigrateUnpartitioned() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + Assume.assumeTrue("Can only migrate from Spark Session Catalog", catalog.name().equals("spark_catalog")); + String source = sourceName("test_migrate_unpartitioned_table"); + String dest = source; + createSourceTable(CREATE_PARQUET, source); + assertMigratedFileCount(SparkActions.get().migrateTable(source), source, dest); + } + + @Test + public void testSnapshotPartitioned() throws Exception { + Assume.assumeTrue("Cannot snapshot with arbitrary location in a hadoop based catalog", + !type.equals("hadoop")); + File location = temp.newFolder(); + String source = sourceName("test_snapshot_partitioned_table"); + String dest = destName("iceberg_snapshot_partitioned"); + createSourceTable(CREATE_PARTITIONED_PARQUET, source); + assertSnapshotFileCount( + SparkActions.get().snapshotTable(source).as(dest).tableLocation(location.toString()), source, dest); + assertIsolatedSnapshot(source, dest); + } + + @Test + public void testSnapshotUnpartitioned() throws Exception { + Assume.assumeTrue("Cannot snapshot with arbitrary location in a hadoop based catalog", + !type.equals("hadoop")); + File location = temp.newFolder(); + String source = sourceName("test_snapshot_unpartitioned_table"); + String dest = destName("iceberg_snapshot_unpartitioned"); + createSourceTable(CREATE_PARQUET, source); + assertSnapshotFileCount( + SparkActions.get().snapshotTable(source).as(dest).tableLocation(location.toString()), source, dest); + assertIsolatedSnapshot(source, dest); + } + + @Test + public void testSnapshotHiveTable() throws Exception { + Assume.assumeTrue("Cannot snapshot with arbitrary location in a hadoop based catalog", + !type.equals("hadoop")); + File location = temp.newFolder(); + String source = sourceName("snapshot_hive_table"); + String dest = destName("iceberg_snapshot_hive_table"); + createSourceTable(CREATE_HIVE_EXTERNAL_PARQUET, source); + assertSnapshotFileCount( + SparkActions.get().snapshotTable(source).as(dest).tableLocation(location.toString()), source, dest); + assertIsolatedSnapshot(source, dest); + } + + @Test + public void testMigrateHiveTable() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + String source = sourceName("migrate_hive_table"); + String dest = source; + 
createSourceTable(CREATE_HIVE_EXTERNAL_PARQUET, source); + assertMigratedFileCount(SparkActions.get().migrateTable(source), source, dest); + } + + @Test + public void testSnapshotManagedHiveTable() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + File location = temp.newFolder(); + String source = sourceName("snapshot_managed_hive_table"); + String dest = destName("iceberg_snapshot_managed_hive_table"); + createSourceTable(CREATE_HIVE_PARQUET, source); + assertSnapshotFileCount( + SparkActions.get().snapshotTable(source).as(dest).tableLocation(location.toString()), source, dest); + assertIsolatedSnapshot(source, dest); + } + + @Test + public void testMigrateManagedHiveTable() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + File location = temp.newFolder(); + String source = sourceName("migrate_managed_hive_table"); + String dest = destName("iceberg_migrate_managed_hive_table"); + createSourceTable(CREATE_HIVE_PARQUET, source); + assertSnapshotFileCount( + SparkActions.get().snapshotTable(source).as(dest).tableLocation(location.toString()), source, dest); + } + + @Test + public void testProperties() throws Exception { + String source = sourceName("test_properties_table"); + String dest = destName("iceberg_properties"); + Map props = Maps.newHashMap(); + props.put("city", "New Orleans"); + props.put("note", "Jazz"); + createSourceTable(CREATE_PARQUET, source); + for (Map.Entry keyValue : props.entrySet()) { + spark.sql(String.format("ALTER TABLE %s SET TBLPROPERTIES (\"%s\" = \"%s\")", + source, keyValue.getKey(), keyValue.getValue())); + } + assertSnapshotFileCount( + SparkActions.get().snapshotTable(source).as(dest).tableProperty("dogs", "sundance"), source, dest); + SparkTable table = loadTable(dest); + + Map expectedProps = Maps.newHashMap(); + expectedProps.putAll(props); + expectedProps.put("dogs", "sundance"); + + for (Map.Entry entry : expectedProps.entrySet()) { + Assert.assertTrue( + "Created table missing property " + entry.getKey(), + table.properties().containsKey(entry.getKey())); + Assert.assertEquals("Property value is not the expected value", + entry.getValue(), table.properties().get(entry.getKey())); + } + } + + @Test + public void testSparkTableReservedProperties() throws Exception { + String destTableName = "iceberg_reserved_properties"; + String source = sourceName("test_reserved_properties_table"); + String dest = destName(destTableName); + createSourceTable(CREATE_PARQUET, source); + assertSnapshotFileCount(SparkActions.get().snapshotTable(source).as(dest), source, dest); + SparkTable table = loadTable(dest); + // set sort orders + table.table().replaceSortOrder().asc("id").desc("data").commit(); + + String[] keys = {"provider", "format", "current-snapshot-id", "location", "sort-order"}; + + for (String entry : keys) { + Assert.assertTrue("Created table missing reserved property " + entry, table.properties().containsKey(entry)); + } + + Assert.assertEquals("Unexpected provider", "iceberg", table.properties().get("provider")); + Assert.assertEquals("Unexpected format", "iceberg/parquet", table.properties().get("format")); + Assert.assertNotEquals("No current-snapshot-id found", "none", table.properties().get("current-snapshot-id")); + Assert.assertTrue("Location isn't correct", table.properties().get("location").endsWith(destTableName)); + + Assert.assertEquals("Unexpected format-version", "1", table.properties().get("format-version")); + 
table.table().updateProperties().set("format-version", "2").commit(); + Assert.assertEquals("Unexpected format-version", "2", table.properties().get("format-version")); + + Assert.assertEquals("Sort-order isn't correct", "id ASC NULLS FIRST, data DESC NULLS LAST", + table.properties().get("sort-order")); + Assert.assertNull("Identifier fields should be null", table.properties().get("identifier-fields")); + + table.table().updateSchema().allowIncompatibleChanges().requireColumn("id").setIdentifierFields("id").commit(); + Assert.assertEquals("Identifier fields aren't correct", "[id]", table.properties().get("identifier-fields")); + } + + @Test + public void testSnapshotDefaultLocation() throws Exception { + String source = sourceName("test_snapshot_default"); + String dest = destName("iceberg_snapshot_default"); + createSourceTable(CREATE_PARTITIONED_PARQUET, source); + assertSnapshotFileCount(SparkActions.get().snapshotTable(source).as(dest), source, dest); + assertIsolatedSnapshot(source, dest); + } + + @Test + public void schemaEvolutionTestWithSparkAPI() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + Assume.assumeTrue("Can only migrate from Spark Session Catalog", catalog.name().equals("spark_catalog")); + + File location = temp.newFolder(); + String tblName = sourceName("schema_evolution_test"); + + // Data generation and partition addition + spark.range(0, 5) + .selectExpr("CAST(id as INT) as col0", "CAST(id AS FLOAT) col2", "CAST(id AS LONG) col3") + .write() + .mode(SaveMode.Append) + .parquet(location.toURI().toString()); + Dataset rowDataset = spark.range(6, 10) + .selectExpr("CAST(id as INT) as col0", "CAST(id AS STRING) col1", + "CAST(id AS FLOAT) col2", "CAST(id AS LONG) col3"); + rowDataset + .write() + .mode(SaveMode.Append) + .parquet(location.toURI().toString()); + spark.read() + .schema(rowDataset.schema()) + .parquet(location.toURI().toString()).write().saveAsTable(tblName); + List expectedBeforeAddColumn = sql("SELECT * FROM %s ORDER BY col0", tblName); + List expectedAfterAddColumn = sql("SELECT col0, null, col1, col2, col3 FROM %s ORDER BY col0", + tblName); + + // Migrate table + SparkActions.get().migrateTable(tblName).execute(); + + // check if iceberg and non-iceberg output + List afterMigarteBeforeAddResults = sql("SELECT * FROM %s ORDER BY col0", tblName); + assertEquals("Output must match", expectedBeforeAddColumn, afterMigarteBeforeAddResults); + + // Update schema and check output correctness + SparkTable sparkTable = loadTable(tblName); + sparkTable.table().updateSchema().addColumn("newCol", Types.IntegerType.get()) + .moveAfter("newCol", "col0") + .commit(); + List afterMigarteAfterAddResults = sql("SELECT * FROM %s ORDER BY col0", tblName); + assertEquals("Output must match", expectedAfterAddColumn, afterMigarteAfterAddResults); + } + + @Test + public void schemaEvolutionTestWithSparkSQL() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + Assume.assumeTrue("Can only migrate from Spark Session Catalog", catalog.name().equals("spark_catalog")); + String tblName = sourceName("schema_evolution_test_sql"); + + // Data generation and partition addition + spark.range(0, 5) + .selectExpr("CAST(id as INT) col0", "CAST(id AS FLOAT) col1", "CAST(id AS STRING) col2") + .write() + .mode(SaveMode.Append) + .saveAsTable(tblName); + sql("ALTER TABLE %s ADD COLUMN col3 INT", tblName); + spark.range(6, 10) + .selectExpr("CAST(id AS INT) col0", "CAST(id AS 
FLOAT) col1", "CAST(id AS STRING) col2", "CAST(id AS INT) col3") + .registerTempTable("tempdata"); + sql("INSERT INTO TABLE %s SELECT * FROM tempdata", tblName); + List expectedBeforeAddColumn = sql("SELECT * FROM %s ORDER BY col0", tblName); + List expectedAfterAddColumn = sql("SELECT col0, null, col1, col2, col3 FROM %s ORDER BY col0", + tblName); + + // Migrate table + SparkActions.get().migrateTable(tblName).execute(); + + // check if iceberg and non-iceberg output + List afterMigarteBeforeAddResults = sql("SELECT * FROM %s ORDER BY col0", tblName); + assertEquals("Output must match", expectedBeforeAddColumn, afterMigarteBeforeAddResults); + + // Update schema and check output correctness + SparkTable sparkTable = loadTable(tblName); + sparkTable.table().updateSchema().addColumn("newCol", Types.IntegerType.get()) + .moveAfter("newCol", "col0") + .commit(); + List afterMigarteAfterAddResults = sql("SELECT * FROM %s ORDER BY col0", tblName); + assertEquals("Output must match", expectedAfterAddColumn, afterMigarteAfterAddResults); + } + + @Test + public void testHiveStyleThreeLevelList() throws Exception { + threeLevelList(true); + } + + @Test + public void testThreeLevelList() throws Exception { + threeLevelList(false); + } + + @Test + public void testHiveStyleThreeLevelListWithNestedStruct() throws Exception { + threeLevelListWithNestedStruct(true); + } + + @Test + public void testThreeLevelListWithNestedStruct() throws Exception { + threeLevelListWithNestedStruct(false); + } + + @Test + public void testHiveStyleThreeLevelLists() throws Exception { + threeLevelLists(true); + } + + @Test + public void testThreeLevelLists() throws Exception { + threeLevelLists(false); + } + + @Test + public void testHiveStyleStructOfThreeLevelLists() throws Exception { + structOfThreeLevelLists(true); + } + + @Test + public void testStructOfThreeLevelLists() throws Exception { + structOfThreeLevelLists(false); + } + + // TODO: revisit why Spark is no longer writing the legacy format. + @Ignore + @Test + public void testTwoLevelList() throws IOException { + spark.conf().set("spark.sql.parquet.writeLegacyFormat", true); + + String tableName = sourceName("testTwoLevelList"); + File location = temp.newFolder(); + + StructType sparkSchema = + new StructType( + new StructField[]{ + new StructField( + "col1", new ArrayType( + new StructType( + new StructField[]{ + new StructField( + "col2", + DataTypes.IntegerType, + false, + Metadata.empty()) + }), false), true, Metadata.empty())}); + + // even though this list looks like three level list, it is actually a 2-level list where the items are + // structs with 1 field. 
+    String expectedParquetSchema =
+        "message spark_schema {\n" +
+            "  optional group col1 (LIST) {\n" +
+            "    repeated group array {\n" +
+            "      required int32 col2;\n" +
+            "    }\n" +
+            "  }\n" +
+            "}\n";
+
+    // generate parquet file with required schema
+    List<String> testData = Collections.singletonList("{\"col1\": [{\"col2\": 1}]}");
+    spark.read().schema(sparkSchema).json(
+        JavaSparkContext.fromSparkContext(spark.sparkContext()).parallelize(testData))
+        .coalesce(1).write().format("parquet").mode(SaveMode.Append).save(location.getPath());
+
+    File parquetFile = Arrays.stream(Preconditions.checkNotNull(location.listFiles(new FilenameFilter() {
+      @Override
+      public boolean accept(File dir, String name) {
+        return name.endsWith("parquet");
+      }
+    }))).findAny().get();
+
+    // verify generated parquet file has expected schema
+    ParquetFileReader pqReader = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(parquetFile.getPath()),
+        spark.sessionState().newHadoopConf()));
+    MessageType schema = pqReader.getFooter().getFileMetaData().getSchema();
+    Assert.assertEquals(MessageTypeParser.parseMessageType(expectedParquetSchema), schema);
+
+    // create sql table on top of it
+    sql("CREATE EXTERNAL TABLE %s (col1 ARRAY<STRUCT<col2: INT>>)" +
+        " STORED AS parquet" +
+        " LOCATION '%s'", tableName, location);
+    List<Object[]> expected = sql("select array(struct(1))");
+
+    // migrate table
+    SparkActions.get().migrateTable(tableName).execute();
+
+    // check migrated table is returning expected result
+    List<Object[]> results = sql("SELECT * FROM %s", tableName);
+    Assert.assertTrue(results.size() > 0);
+    assertEquals("Output must match", expected, results);
+  }
+
+  private void threeLevelList(boolean useLegacyMode) throws Exception {
+    spark.conf().set("spark.sql.parquet.writeLegacyFormat", useLegacyMode);
+
+    String tableName = sourceName(String.format("threeLevelList_%s", useLegacyMode));
+    File location = temp.newFolder();
+    sql("CREATE TABLE %s (col1 ARRAY<STRUCT<col2: INT>>)" +
+        " STORED AS parquet" +
+        " LOCATION '%s'", tableName, location);
+
+    int testValue = 12345;
+    sql("INSERT INTO %s VALUES (ARRAY(STRUCT(%s)))", tableName, testValue);
+    List<Object[]> expected = sql(String.format("SELECT * FROM %s", tableName));
+
+    // migrate table
+    SparkActions.get().migrateTable(tableName).execute();
+
+    // check migrated table is returning expected result
+    List<Object[]> results = sql("SELECT * FROM %s", tableName);
+    Assert.assertTrue(results.size() > 0);
+    assertEquals("Output must match", expected, results);
+  }
+
+  private void threeLevelListWithNestedStruct(boolean useLegacyMode) throws Exception {
+    spark.conf().set("spark.sql.parquet.writeLegacyFormat", useLegacyMode);
+
+    String tableName = sourceName(String.format("threeLevelListWithNestedStruct_%s", useLegacyMode));
+    File location = temp.newFolder();
+    sql("CREATE TABLE %s (col1 ARRAY<STRUCT<col2: STRUCT<col3: INT>>>)" +
+        " STORED AS parquet" +
+        " LOCATION '%s'", tableName, location);
+
+    int testValue = 12345;
+    sql("INSERT INTO %s VALUES (ARRAY(STRUCT(STRUCT(%s))))", tableName, testValue);
+    List<Object[]> expected = sql(String.format("SELECT * FROM %s", tableName));
+
+    // migrate table
+    SparkActions.get().migrateTable(tableName).execute();
+
+    // check migrated table is returning expected result
+    List<Object[]> results = sql("SELECT * FROM %s", tableName);
+    Assert.assertTrue(results.size() > 0);
+    assertEquals("Output must match", expected, results);
+  }
+
+  private void threeLevelLists(boolean useLegacyMode) throws Exception {
+    spark.conf().set("spark.sql.parquet.writeLegacyFormat", useLegacyMode);
+
+    String tableName = sourceName(String.format("threeLevelLists_%s", useLegacyMode));
+    File location = temp.newFolder();
+    sql("CREATE TABLE %s (col1 ARRAY<STRUCT<col2: INT>>, col3 ARRAY<STRUCT<col4: INT>>)" +
+        " STORED AS parquet" +
+        " LOCATION '%s'", tableName, location);
+
+    int testValue1 = 12345;
+    int testValue2 = 987654;
+    sql("INSERT INTO %s VALUES (ARRAY(STRUCT(%s)), ARRAY(STRUCT(%s)))",
+        tableName, testValue1, testValue2);
+    List<Object[]> expected = sql(String.format("SELECT * FROM %s", tableName));
+
+    // migrate table
+    SparkActions.get().migrateTable(tableName).execute();
+
+    // check migrated table is returning expected result
+    List<Object[]> results = sql("SELECT * FROM %s", tableName);
+    Assert.assertTrue(results.size() > 0);
+    assertEquals("Output must match", expected, results);
+  }
+
+  private void structOfThreeLevelLists(boolean useLegacyMode) throws Exception {
+    spark.conf().set("spark.sql.parquet.writeLegacyFormat", useLegacyMode);
+
+    String tableName = sourceName(String.format("structOfThreeLevelLists_%s", useLegacyMode));
+    File location = temp.newFolder();
+    sql("CREATE TABLE %s (col1 STRUCT<col2: ARRAY<STRUCT<col3: INT>>>)" +
+        " STORED AS parquet" +
+        " LOCATION '%s'", tableName, location);
+
+    int testValue1 = 12345;
+    sql("INSERT INTO %s VALUES (STRUCT(STRUCT(ARRAY(STRUCT(%s)))))",
+        tableName, testValue1);
+    List<Object[]> expected = sql(String.format("SELECT * FROM %s", tableName));
+
+    // migrate table
+    SparkActions.get().migrateTable(tableName).execute();
+
+    // check migrated table is returning expected result
+    List<Object[]> results = sql("SELECT * FROM %s", tableName);
+    Assert.assertTrue(results.size() > 0);
+    assertEquals("Output must match", expected, results);
+  }
+
+
+  private SparkTable loadTable(String name) throws NoSuchTableException, ParseException {
+    return (SparkTable) catalog.loadTable(Spark3Util.catalogAndIdentifier(spark, name).identifier());
+  }
+
+  private CatalogTable loadSessionTable(String name)
+      throws NoSuchTableException, NoSuchDatabaseException, ParseException {
+    Identifier identifier = Spark3Util.catalogAndIdentifier(spark, name).identifier();
+    Some<String> namespace = Some.apply(identifier.namespace()[0]);
+    return spark.sessionState().catalog().getTableMetadata(new TableIdentifier(identifier.name(), namespace));
+  }
+
+  private void createSourceTable(String createStatement, String tableName)
+      throws IOException, NoSuchTableException, NoSuchDatabaseException, ParseException {
+    File location = temp.newFolder();
+    spark.sql(String.format(createStatement, tableName, location));
+    CatalogTable table = loadSessionTable(tableName);
+    Seq<String> partitionColumns = table.partitionColumnNames();
+    String format = table.provider().get();
+    spark.table(baseTableName).write().mode(SaveMode.Append).format(format).partitionBy(partitionColumns.toSeq())
+        .saveAsTable(tableName);
+  }
+
+  // Counts the number of files in the source table, makes sure the same files exist in the destination table
+  private void assertMigratedFileCount(MigrateTable migrateAction, String source, String dest)
+      throws NoSuchTableException, NoSuchDatabaseException, ParseException {
+    long expectedFiles = expectedFilesCount(source);
+    MigrateTable.Result migratedFiles = migrateAction.execute();
+    validateTables(source, dest);
+    Assert.assertEquals("Expected number of migrated files",
+        expectedFiles, migratedFiles.migratedDataFilesCount());
+  }
+
+  // Counts the number of files in the source table, makes sure the same files exist in the destination table
+  private void assertSnapshotFileCount(SnapshotTable snapshotTable, String source, String dest)
+      throws NoSuchTableException,
NoSuchDatabaseException, ParseException { + long expectedFiles = expectedFilesCount(source); + SnapshotTable.Result snapshotTableResult = snapshotTable.execute(); + validateTables(source, dest); + Assert.assertEquals("Expected number of imported snapshot files", expectedFiles, + snapshotTableResult.importedDataFilesCount()); + } + + private void validateTables(String source, String dest) throws NoSuchTableException, ParseException { + List expected = spark.table(source).collectAsList(); + SparkTable destTable = loadTable(dest); + Assert.assertEquals("Provider should be iceberg", "iceberg", + destTable.properties().get(TableCatalog.PROP_PROVIDER)); + List actual = spark.table(dest).collectAsList(); + Assert.assertTrue(String.format("Rows in migrated table did not match\nExpected :%s rows \nFound :%s", + expected, actual), expected.containsAll(actual) && actual.containsAll(expected)); + } + + private long expectedFilesCount(String source) throws NoSuchDatabaseException, NoSuchTableException, ParseException { + CatalogTable sourceTable = loadSessionTable(source); + List uris; + if (sourceTable.partitionColumnNames().size() == 0) { + uris = Lists.newArrayList(); + uris.add(sourceTable.location()); + } else { + Seq catalogTablePartitionSeq = + spark.sessionState().catalog().listPartitions(sourceTable.identifier(), Option.apply(null)); + uris = JavaConverters.seqAsJavaList(catalogTablePartitionSeq) + .stream() + .map(CatalogTablePartition::location) + .collect(Collectors.toList()); + } + return uris.stream() + .flatMap(uri -> + FileUtils.listFiles(Paths.get(uri).toFile(), + TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).stream()) + .filter(file -> !file.toString().endsWith("crc") && !file.toString().contains("_SUCCESS")).count(); + } + + // Insert records into the destination, makes sure those records exist and source table is unchanged + private void assertIsolatedSnapshot(String source, String dest) { + List expected = spark.sql(String.format("SELECT * FROM %s", source)).collectAsList(); + + List extraData = Lists.newArrayList( + new SimpleRecord(4, "d") + ); + Dataset df = spark.createDataFrame(extraData, SimpleRecord.class); + df.write().format("iceberg").mode("append").saveAsTable(dest); + + List result = spark.sql(String.format("SELECT * FROM %s", source)).collectAsList(); + Assert.assertEquals("No additional rows should be added to the original table", expected.size(), + result.size()); + + List snapshot = spark.sql(String.format("SELECT * FROM %s WHERE id = 4 AND data = 'd'", dest)).collectAsList(); + Assert.assertEquals("Added row not found in snapshot", 1, snapshot.size()); + } + + private String sourceName(String source) { + return NAMESPACE + "." + catalog.name() + "_" + type + "_" + source; + } + + private String destName(String dest) { + if (catalog.name().equals("spark_catalog")) { + return NAMESPACE + "." + catalog.name() + "_" + type + "_" + dest; + } else { + return catalog.name() + "." + NAMESPACE + "." + catalog.name() + "_" + type + "_" + dest; + } + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestDeleteReachableFilesAction.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestDeleteReachableFilesAction.java new file mode 100644 index 000000000000..9792d772b12d --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestDeleteReachableFilesAction.java @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.actions; + +import java.io.File; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.StreamSupport; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.actions.ActionsProvider; +import org.apache.iceberg.actions.DeleteOrphanFiles; +import org.apache.iceberg.actions.DeleteReachableFiles; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.types.Types; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.types.Types.NestedField.optional; + +public class TestDeleteReachableFilesAction extends SparkTestBase { + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + private static final Schema SCHEMA = new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get()) + ); + private static final int SHUFFLE_PARTITIONS = 2; + + private static final PartitionSpec SPEC = PartitionSpec.builderFor(SCHEMA).identity("c1").build(); + + static final DataFile FILE_A = DataFiles.builder(SPEC) + .withPath("/path/to/data-a.parquet") + .withFileSizeInBytes(10) + .withPartition(TestHelpers.Row.of(0)) + .withRecordCount(1) + .build(); + static final DataFile FILE_B = DataFiles.builder(SPEC) + .withPath("/path/to/data-b.parquet") + .withFileSizeInBytes(10) + .withPartition(TestHelpers.Row.of(1)) + .withRecordCount(1) + .build(); + static final DataFile FILE_C = DataFiles.builder(SPEC) + .withPath("/path/to/data-c.parquet") + .withFileSizeInBytes(10) + .withPartition(TestHelpers.Row.of(2)) + .withRecordCount(1) + .build(); + static final DataFile FILE_D = DataFiles.builder(SPEC) + 
.withPath("/path/to/data-d.parquet") + .withFileSizeInBytes(10) + .withPartition(TestHelpers.Row.of(3)) + .withRecordCount(1) + .build(); + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private Table table; + + @Before + public void setupTableLocation() throws Exception { + File tableDir = temp.newFolder(); + String tableLocation = tableDir.toURI().toString(); + this.table = TABLES.create(SCHEMA, SPEC, Maps.newHashMap(), tableLocation); + spark.conf().set("spark.sql.shuffle.partitions", SHUFFLE_PARTITIONS); + } + + private void checkRemoveFilesResults(long expectedDatafiles, long expectedManifestsDeleted, + long expectedManifestListsDeleted, long expectedOtherFilesDeleted, + DeleteReachableFiles.Result results) { + Assert.assertEquals("Incorrect number of manifest files deleted", + expectedManifestsDeleted, results.deletedManifestsCount()); + Assert.assertEquals("Incorrect number of datafiles deleted", + expectedDatafiles, results.deletedDataFilesCount()); + Assert.assertEquals("Incorrect number of manifest lists deleted", + expectedManifestListsDeleted, results.deletedManifestListsCount()); + Assert.assertEquals("Incorrect number of other lists deleted", + expectedOtherFilesDeleted, results.deletedOtherFilesCount()); + } + + @Test + public void dataFilesCleanupWithParallelTasks() { + table.newFastAppend() + .appendFile(FILE_A) + .commit(); + + table.newFastAppend() + .appendFile(FILE_B) + .commit(); + + table.newRewrite() + .rewriteFiles(ImmutableSet.of(FILE_B), ImmutableSet.of(FILE_D)) + .commit(); + + table.newRewrite() + .rewriteFiles(ImmutableSet.of(FILE_A), ImmutableSet.of(FILE_C)) + .commit(); + + Set deletedFiles = ConcurrentHashMap.newKeySet(); + Set deleteThreads = ConcurrentHashMap.newKeySet(); + AtomicInteger deleteThreadsIndex = new AtomicInteger(0); + + DeleteReachableFiles.Result result = sparkActions().deleteReachableFiles(metadataLocation(table)) + .io(table.io()) + .executeDeleteWith(Executors.newFixedThreadPool(4, runnable -> { + Thread thread = new Thread(runnable); + thread.setName("remove-files-" + deleteThreadsIndex.getAndIncrement()); + thread.setDaemon(true); // daemon threads will be terminated abruptly when the JVM exits + return thread; + })) + .deleteWith(s -> { + deleteThreads.add(Thread.currentThread().getName()); + deletedFiles.add(s); + }) + .execute(); + + // Verifies that the delete methods ran in the threads created by the provided ExecutorService ThreadFactory + Assert.assertEquals(deleteThreads, + Sets.newHashSet("remove-files-0", "remove-files-1", "remove-files-2", "remove-files-3")); + + Lists.newArrayList(FILE_A, FILE_B, FILE_C, FILE_D).forEach(file -> + Assert.assertTrue("FILE_A should be deleted", deletedFiles.contains(FILE_A.path().toString())) + ); + checkRemoveFilesResults(4L, 6L, 4L, 6, result); + } + + @Test + public void testWithExpiringDanglingStageCommit() { + table.location(); + // `A` commit + table.newAppend() + .appendFile(FILE_A) + .commit(); + + // `B` staged commit + table.newAppend() + .appendFile(FILE_B) + .stageOnly() + .commit(); + + // `C` commit + table.newAppend() + .appendFile(FILE_C) + .commit(); + + DeleteReachableFiles.Result result = sparkActions().deleteReachableFiles(metadataLocation(table)) + .io(table.io()) + .execute(); + + checkRemoveFilesResults(3L, 3L, 3L, 5, result); + } + + @Test + public void testRemoveFileActionOnEmptyTable() { + DeleteReachableFiles.Result result = sparkActions().deleteReachableFiles(metadataLocation(table)) + .io(table.io()) + .execute(); + + checkRemoveFilesResults(0, 
0, 0, 2, result); + } + + @Test + public void testRemoveFilesActionWithReducedVersionsTable() { + table.updateProperties() + .set(TableProperties.METADATA_PREVIOUS_VERSIONS_MAX, "2").commit(); + table.newAppend() + .appendFile(FILE_A) + .commit(); + + table.newAppend() + .appendFile(FILE_B) + .commit(); + + table.newAppend() + .appendFile(FILE_B) + .commit(); + + table.newAppend() + .appendFile(FILE_C) + .commit(); + + table.newAppend() + .appendFile(FILE_D) + .commit(); + + DeleteReachableFiles baseRemoveFilesSparkAction = sparkActions() + .deleteReachableFiles(metadataLocation(table)) + .io(table.io()); + DeleteReachableFiles.Result result = baseRemoveFilesSparkAction.execute(); + + checkRemoveFilesResults(4, 5, 5, 8, result); + } + + @Test + public void testRemoveFilesAction() { + table.newAppend() + .appendFile(FILE_A) + .commit(); + + table.newAppend() + .appendFile(FILE_B) + .commit(); + + DeleteReachableFiles baseRemoveFilesSparkAction = sparkActions() + .deleteReachableFiles(metadataLocation(table)) + .io(table.io()); + checkRemoveFilesResults(2, 2, 2, 4, baseRemoveFilesSparkAction.execute()); + } + + @Test + public void testRemoveFilesActionWithDefaultIO() { + table.newAppend() + .appendFile(FILE_A) + .commit(); + + table.newAppend() + .appendFile(FILE_B) + .commit(); + + // IO not set explicitly on removeReachableFiles action + // IO defaults to HadoopFileIO + DeleteReachableFiles baseRemoveFilesSparkAction = sparkActions() + .deleteReachableFiles(metadataLocation(table)); + checkRemoveFilesResults(2, 2, 2, 4, baseRemoveFilesSparkAction.execute()); + } + + @Test + public void testUseLocalIterator() { + table.newFastAppend() + .appendFile(FILE_A) + .commit(); + + table.newOverwrite() + .deleteFile(FILE_A) + .addFile(FILE_B) + .commit(); + + table.newFastAppend() + .appendFile(FILE_C) + .commit(); + + int jobsBefore = spark.sparkContext().dagScheduler().nextJobId().get(); + + withSQLConf(ImmutableMap.of("spark.sql.adaptive.enabled", "false"), () -> { + DeleteReachableFiles.Result results = sparkActions().deleteReachableFiles(metadataLocation(table)) + .io(table.io()) + .option("stream-results", "true").execute(); + + int jobsAfter = spark.sparkContext().dagScheduler().nextJobId().get(); + int totalJobsRun = jobsAfter - jobsBefore; + + checkRemoveFilesResults(3L, 4L, 3L, 5, results); + + Assert.assertEquals( + "Expected total jobs to be equal to total number of shuffle partitions", + totalJobsRun, SHUFFLE_PARTITIONS); + }); + } + + @Test + public void testIgnoreMetadataFilesNotFound() { + table.updateProperties() + .set(TableProperties.METADATA_PREVIOUS_VERSIONS_MAX, "1").commit(); + + table.newAppend() + .appendFile(FILE_A) + .commit(); + // There are three metadata json files at this point + DeleteOrphanFiles.Result result = + sparkActions().deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + Assert.assertEquals("Should delete 1 file", 1, Iterables.size(result.orphanFileLocations())); + Assert.assertTrue("Should remove v1 file", + StreamSupport.stream(result.orphanFileLocations().spliterator(), false) + .anyMatch(file -> file.contains("v1.metadata.json"))); + + DeleteReachableFiles baseRemoveFilesSparkAction = sparkActions() + .deleteReachableFiles(metadataLocation(table)) + .io(table.io()); + DeleteReachableFiles.Result res = baseRemoveFilesSparkAction.execute(); + + checkRemoveFilesResults(1, 1, 1, 4, res); + } + + @Test + public void testEmptyIOThrowsException() { + DeleteReachableFiles baseRemoveFilesSparkAction = sparkActions() + 
.deleteReachableFiles(metadataLocation(table)) + .io(null); + AssertHelpers.assertThrows("FileIO can't be null in DeleteReachableFiles action", + IllegalArgumentException.class, "File IO cannot be null", + baseRemoveFilesSparkAction::execute); + } + + @Test + public void testRemoveFilesActionWhenGarbageCollectionDisabled() { + table.updateProperties() + .set(TableProperties.GC_ENABLED, "false") + .commit(); + + AssertHelpers.assertThrows("Should complain about removing files when GC is disabled", + ValidationException.class, "Cannot delete files: GC is disabled (deleting files may corrupt other tables)", + () -> sparkActions().deleteReachableFiles(metadataLocation(table)).execute()); + } + + private String metadataLocation(Table tbl) { + return ((HasTableOperations) tbl).operations().current().metadataFileLocation(); + } + + private ActionsProvider sparkActions() { + return SparkActions.get(); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestExpireSnapshotsAction.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestExpireSnapshotsAction.java new file mode 100644 index 000000000000..b278bd08a608 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestExpireSnapshotsAction.java @@ -0,0 +1,1153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.actions; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileMetadata; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.ExpireSnapshots; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.types.Types.NestedField.optional; + +public class TestExpireSnapshotsAction extends SparkTestBase { + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + private static final Schema SCHEMA = new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get()) + ); + private static final int SHUFFLE_PARTITIONS = 2; + + private static final PartitionSpec SPEC = PartitionSpec.builderFor(SCHEMA).identity("c1").build(); + + static final DataFile FILE_A = DataFiles.builder(SPEC) + .withPath("/path/to/data-a.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=0") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DataFile FILE_B = DataFiles.builder(SPEC) + .withPath("/path/to/data-b.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=1") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DataFile FILE_C = DataFiles.builder(SPEC) + .withPath("/path/to/data-c.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=2") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DataFile FILE_D = DataFiles.builder(SPEC) + .withPath("/path/to/data-d.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=3") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DeleteFile FILE_A_POS_DELETES = FileMetadata.deleteFileBuilder(SPEC) + .ofPositionDeletes() + .withPath("/path/to/data-a-pos-deletes.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=0") // easy way to set partition 
data for now + .withRecordCount(1) + .build(); + static final DeleteFile FILE_A_EQ_DELETES = FileMetadata.deleteFileBuilder(SPEC) + .ofEqualityDeletes() + .withPath("/path/to/data-a-eq-deletes.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=0") // easy way to set partition data for now + .withRecordCount(1) + .build(); + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private File tableDir; + private String tableLocation; + private Table table; + + @Before + public void setupTableLocation() throws Exception { + this.tableDir = temp.newFolder(); + this.tableLocation = tableDir.toURI().toString(); + this.table = TABLES.create(SCHEMA, SPEC, Maps.newHashMap(), tableLocation); + spark.conf().set("spark.sql.shuffle.partitions", SHUFFLE_PARTITIONS); + } + + private Long rightAfterSnapshot() { + return rightAfterSnapshot(table.currentSnapshot().snapshotId()); + } + + private Long rightAfterSnapshot(long snapshotId) { + Long end = System.currentTimeMillis(); + while (end <= table.snapshot(snapshotId).timestampMillis()) { + end = System.currentTimeMillis(); + } + return end; + } + + private void checkExpirationResults(long expectedDatafiles, long expectedPosDeleteFiles, long expectedEqDeleteFiles, + long expectedManifestsDeleted, long expectedManifestListsDeleted, + ExpireSnapshots.Result results) { + + Assert.assertEquals("Incorrect number of manifest files deleted", + expectedManifestsDeleted, results.deletedManifestsCount()); + Assert.assertEquals("Incorrect number of datafiles deleted", + expectedDatafiles, results.deletedDataFilesCount()); + Assert.assertEquals("Incorrect number of pos deletefiles deleted", + expectedPosDeleteFiles, results.deletedPositionDeleteFilesCount()); + Assert.assertEquals("Incorrect number of eq deletefiles deleted", + expectedEqDeleteFiles, results.deletedEqualityDeleteFilesCount()); + Assert.assertEquals("Incorrect number of manifest lists deleted", + expectedManifestListsDeleted, results.deletedManifestListsCount()); + } + + @Test + public void testFilesCleaned() throws Exception { + table.newFastAppend() + .appendFile(FILE_A) + .commit(); + + table.newOverwrite() + .deleteFile(FILE_A) + .addFile(FILE_B) + .commit(); + + table.newFastAppend() + .appendFile(FILE_C) + .commit(); + + long end = rightAfterSnapshot(); + + ExpireSnapshots.Result results = SparkActions.get().expireSnapshots(table).expireOlderThan(end).execute(); + + Assert.assertEquals("Table does not have 1 snapshot after expiration", 1, Iterables.size(table.snapshots())); + + checkExpirationResults(1L, 0L, 0L, 1L, 2L, results); + } + + @Test + public void dataFilesCleanupWithParallelTasks() throws IOException { + + table.newFastAppend() + .appendFile(FILE_A) + .commit(); + + table.newFastAppend() + .appendFile(FILE_B) + .commit(); + + table.newRewrite() + .rewriteFiles(ImmutableSet.of(FILE_B), ImmutableSet.of(FILE_D)) + .commit(); + + table.newRewrite() + .rewriteFiles(ImmutableSet.of(FILE_A), ImmutableSet.of(FILE_C)) + .commit(); + + long t4 = rightAfterSnapshot(); + + Set deletedFiles = Sets.newHashSet(); + Set deleteThreads = ConcurrentHashMap.newKeySet(); + AtomicInteger deleteThreadsIndex = new AtomicInteger(0); + + ExpireSnapshots.Result result = SparkActions.get().expireSnapshots(table) + .executeDeleteWith(Executors.newFixedThreadPool(4, runnable -> { + Thread thread = new Thread(runnable); + thread.setName("remove-snapshot-" + deleteThreadsIndex.getAndIncrement()); + thread.setDaemon(true); // daemon threads will be terminated abruptly when the JVM exits + return 
thread;
+        }))
+        .expireOlderThan(t4)
+        .deleteWith(s -> {
+          deleteThreads.add(Thread.currentThread().getName());
+          deletedFiles.add(s);
+        })
+        .execute();
+
+    // Verifies that the delete methods ran in the threads created by the provided ExecutorService ThreadFactory
+    Assert.assertEquals(deleteThreads,
+        Sets.newHashSet("remove-snapshot-0", "remove-snapshot-1", "remove-snapshot-2", "remove-snapshot-3"));
+
+    Assert.assertTrue("FILE_A should be deleted", deletedFiles.contains(FILE_A.path().toString()));
+    Assert.assertTrue("FILE_B should be deleted", deletedFiles.contains(FILE_B.path().toString()));
+
+    checkExpirationResults(2L, 0L, 0L, 3L, 3L, result);
+  }
+
+  @Test
+  public void testNoFilesDeletedWhenNoSnapshotsExpired() throws Exception {
+    table.newFastAppend()
+        .appendFile(FILE_A)
+        .commit();
+
+    ExpireSnapshots.Result results = SparkActions.get().expireSnapshots(table).execute();
+    checkExpirationResults(0L, 0L, 0L, 0L, 0L, results);
+  }
+
+  @Test
+  public void testCleanupRepeatedOverwrites() throws Exception {
+    table.newFastAppend()
+        .appendFile(FILE_A)
+        .commit();
+
+    for (int i = 0; i < 10; i++) {
+      table.newOverwrite()
+          .deleteFile(FILE_A)
+          .addFile(FILE_B)
+          .commit();
+
+      table.newOverwrite()
+          .deleteFile(FILE_B)
+          .addFile(FILE_A)
+          .commit();
+    }
+
+    long end = rightAfterSnapshot();
+    ExpireSnapshots.Result results = SparkActions.get().expireSnapshots(table).expireOlderThan(end).execute();
+    checkExpirationResults(1L, 0L, 0L, 39L, 20L, results);
+  }
+
+  @Test
+  public void testRetainLastWithExpireOlderThan() {
+    table.newAppend()
+        .appendFile(FILE_A) // data_bucket=0
+        .commit();
+    long firstSnapshotId = table.currentSnapshot().snapshotId();
+    long t1 = System.currentTimeMillis();
+    while (t1 <= table.currentSnapshot().timestampMillis()) {
+      t1 = System.currentTimeMillis();
+    }
+
+    table.newAppend()
+        .appendFile(FILE_B) // data_bucket=1
+        .commit();
+
+    table.newAppend()
+        .appendFile(FILE_C) // data_bucket=2
+        .commit();
+
+    long t3 = rightAfterSnapshot();
+
+    // Retain last 2 snapshots
+    SparkActions.get().expireSnapshots(table)
+        .expireOlderThan(t3)
+        .retainLast(2)
+        .execute();
+
+    Assert.assertEquals("Should have two snapshots.", 2, Lists.newArrayList(table.snapshots()).size());
+    Assert.assertNull("First snapshot should not be present.", table.snapshot(firstSnapshotId));
+  }
+
+  @Test
+  public void testExpireTwoSnapshotsById() throws Exception {
+    table.newAppend()
+        .appendFile(FILE_A) // data_bucket=0
+        .commit();
+    long firstSnapshotId = table.currentSnapshot().snapshotId();
+
+    table.newAppend()
+        .appendFile(FILE_B) // data_bucket=1
+        .commit();
+
+    long secondSnapshotID = table.currentSnapshot().snapshotId();
+
+    table.newAppend()
+        .appendFile(FILE_C) // data_bucket=2
+        .commit();
+
+    // Expire the first and second snapshots by id
+    ExpireSnapshots.Result result = SparkActions.get().expireSnapshots(table)
+        .expireSnapshotId(firstSnapshotId)
+        .expireSnapshotId(secondSnapshotID)
+        .execute();
+
+    Assert.assertEquals("Should have one snapshot.", 1, Lists.newArrayList(table.snapshots()).size());
+    Assert.assertNull("First snapshot should not be present.", table.snapshot(firstSnapshotId));
+    Assert.assertNull("Second snapshot should not be present.", table.snapshot(secondSnapshotID));
+
+    checkExpirationResults(0L, 0L, 0L, 0L, 2L, result);
+  }
+
+  @Test
+  public void testRetainLastWithExpireById() {
+    table.newAppend()
+        .appendFile(FILE_A) // data_bucket=0
+        .commit();
+    long firstSnapshotId = table.currentSnapshot().snapshotId();
+
+    table.newAppend()
+        .appendFile(FILE_B) // data_bucket=1
+        .commit();
+
+    table.newAppend()
+        .appendFile(FILE_C) // data_bucket=2
+        .commit();
+
+    // Retain last 3 snapshots, but explicitly remove the first snapshot
+    ExpireSnapshots.Result result = SparkActions.get().expireSnapshots(table)
+        .expireSnapshotId(firstSnapshotId)
+        .retainLast(3)
+        .execute();
+
+    Assert.assertEquals("Should have two snapshots.", 2, Lists.newArrayList(table.snapshots()).size());
+    Assert.assertNull("First snapshot should not be present.", table.snapshot(firstSnapshotId));
+    checkExpirationResults(0L, 0L, 0L, 0L, 1L, result);
+  }
+
+  @Test
+  public void testRetainLastWithTooFewSnapshots() {
+    table.newAppend()
+        .appendFile(FILE_A) // data_bucket=0
+        .appendFile(FILE_B) // data_bucket=1
+        .commit();
+    long firstSnapshotId = table.currentSnapshot().snapshotId();
+
+    table.newAppend()
+        .appendFile(FILE_C) // data_bucket=2
+        .commit();
+
+    long t2 = rightAfterSnapshot();
+
+    // Retain last 3 snapshots
+    ExpireSnapshots.Result result = SparkActions.get().expireSnapshots(table)
+        .expireOlderThan(t2)
+        .retainLast(3)
+        .execute();
+
+    Assert.assertEquals("Should have two snapshots", 2, Lists.newArrayList(table.snapshots()).size());
+    Assert.assertEquals("First snapshot should still be present",
+        firstSnapshotId, table.snapshot(firstSnapshotId).snapshotId());
+    checkExpirationResults(0L, 0L, 0L, 0L, 0L, result);
+  }
+
+  @Test
+  public void testRetainLastKeepsExpiringSnapshot() {
+    table.newAppend()
+        .appendFile(FILE_A) // data_bucket=0
+        .commit();
+
+    table.newAppend()
+        .appendFile(FILE_B) // data_bucket=1
+        .commit();
+
+    Snapshot secondSnapshot = table.currentSnapshot();
+
+    table.newAppend()
+        .appendFile(FILE_C) // data_bucket=2
+        .commit();
+
+    table.newAppend()
+        .appendFile(FILE_D) // data_bucket=3
+        .commit();
+
+    // Retain last 2 snapshots and expire older than the second snapshot's timestamp
+    ExpireSnapshots.Result result = SparkActions.get().expireSnapshots(table)
+        .expireOlderThan(secondSnapshot.timestampMillis())
+        .retainLast(2)
+        .execute();
+
+    Assert.assertEquals("Should have three snapshots.", 3, Lists.newArrayList(table.snapshots()).size());
+    Assert.assertNotNull("Second snapshot should be present.", table.snapshot(secondSnapshot.snapshotId()));
+    checkExpirationResults(0L, 0L, 0L, 0L, 1L, result);
+  }
+
+  @Test
+  public void testExpireSnapshotsWithDisabledGarbageCollection() {
+    table.updateProperties()
+        .set(TableProperties.GC_ENABLED, "false")
+        .commit();
+
+    table.newAppend()
+        .appendFile(FILE_A)
+        .commit();
+
+    AssertHelpers.assertThrows("Should complain about expiring snapshots",
+        ValidationException.class, "Cannot expire snapshots: GC is disabled",
+        () -> SparkActions.get().expireSnapshots(table));
+  }
+
+  @Test
+  public void testExpireOlderThanMultipleCalls() {
+    table.newAppend()
+        .appendFile(FILE_A) // data_bucket=0
+        .commit();
+
+    table.newAppend()
+        .appendFile(FILE_B) // data_bucket=1
+        .commit();
+
+    Snapshot secondSnapshot = table.currentSnapshot();
+
+    table.newAppend()
+        .appendFile(FILE_C) // data_bucket=2
+        .commit();
+
+    Snapshot thirdSnapshot = table.currentSnapshot();
+
+    // Multiple expireOlderThan calls; the latest timestamp takes effect
+    ExpireSnapshots.Result result = SparkActions.get().expireSnapshots(table)
+        .expireOlderThan(secondSnapshot.timestampMillis())
+        .expireOlderThan(thirdSnapshot.timestampMillis())
+        .execute();
+
+    Assert.assertEquals("Should have one snapshot.", 1, Lists.newArrayList(table.snapshots()).size());
+    Assert.assertNull("Second snapshot should not be present.", table.snapshot(secondSnapshot.snapshotId()));
+    checkExpirationResults(0L, 0L, 0L, 0L, 2L, result);
+  }
+
+  @Test
+  public void testRetainLastMultipleCalls() {
+    table.newAppend()
+        .appendFile(FILE_A) // data_bucket=0
+        .commit();
+
+    table.newAppend()
+        .appendFile(FILE_B) // data_bucket=1
+        .commit();
+
+    Snapshot secondSnapshot = table.currentSnapshot();
+
+    table.newAppend()
+        .appendFile(FILE_C) // data_bucket=2
+        .commit();
+
+    long t3 = rightAfterSnapshot();
+
+    // Expire older than t3; with multiple retainLast calls, the last value (1) takes effect
+    ExpireSnapshots.Result result = SparkActions.get().expireSnapshots(table)
+        .expireOlderThan(t3)
+        .retainLast(2)
+        .retainLast(1)
+        .execute();
+
+    Assert.assertEquals("Should have one snapshot.", 1, Lists.newArrayList(table.snapshots()).size());
+    Assert.assertNull("Second snapshot should not be present.", table.snapshot(secondSnapshot.snapshotId()));
+    checkExpirationResults(0L, 0L, 0L, 0L, 2L, result);
+  }
+
+  @Test
+  public void testRetainZeroSnapshots() {
+    AssertHelpers.assertThrows("Should fail retain 0 snapshots " +
+        "because number of snapshots to retain cannot be zero",
+        IllegalArgumentException.class,
+        "Number of snapshots to retain must be at least 1, cannot be: 0",
+        () -> SparkActions.get().expireSnapshots(table).retainLast(0).execute());
+  }
+
+  @Test
+  public void testScanExpiredManifestInValidSnapshotAppend() {
+    table.newAppend()
+        .appendFile(FILE_A)
+        .appendFile(FILE_B)
+        .commit();
+
+    table.newOverwrite()
+        .addFile(FILE_C)
+        .deleteFile(FILE_A)
+        .commit();
+
+    table.newAppend()
+        .appendFile(FILE_D)
+        .commit();
+
+    long t3 = rightAfterSnapshot();
+
+    Set<String> deletedFiles = Sets.newHashSet();
+
+    ExpireSnapshots.Result result = SparkActions.get().expireSnapshots(table)
+        .expireOlderThan(t3)
+        .deleteWith(deletedFiles::add)
+        .execute();
+
+    Assert.assertTrue("FILE_A should be deleted", deletedFiles.contains(FILE_A.path().toString()));
+    checkExpirationResults(1L, 0L, 0L, 1L, 2L, result);
+  }
+
+  @Test
+  public void testScanExpiredManifestInValidSnapshotFastAppend() {
+    table.updateProperties()
+        .set(TableProperties.MANIFEST_MERGE_ENABLED, "true")
+        .set(TableProperties.MANIFEST_MIN_MERGE_COUNT, "1")
+        .commit();
+
+    table.newAppend()
+        .appendFile(FILE_A)
+        .appendFile(FILE_B)
+        .commit();
+
+    table.newOverwrite()
+        .addFile(FILE_C)
+        .deleteFile(FILE_A)
+        .commit();
+
+    table.newFastAppend()
+        .appendFile(FILE_D)
+        .commit();
+
+    long t3 = rightAfterSnapshot();
+
+    Set<String> deletedFiles = Sets.newHashSet();
+
+    ExpireSnapshots.Result result = SparkActions.get().expireSnapshots(table)
+        .expireOlderThan(t3)
+        .deleteWith(deletedFiles::add)
+        .execute();
+
+    Assert.assertTrue("FILE_A should be deleted", deletedFiles.contains(FILE_A.path().toString()));
+    checkExpirationResults(1L, 0L, 0L, 1L, 2L, result);
+  }
+
+  /**
+   * Test on the table below, expiring the staged commit `B` using the `expireOlderThan` API.
+   * Table: A - C
+   *          ` B (staged)
+   */
+  @Test
+  public void testWithExpiringDanglingStageCommit() {
+    // `A` commit
+    table.newAppend()
+        .appendFile(FILE_A)
+        .commit();
+
+    // `B` staged commit
+    table.newAppend()
+        .appendFile(FILE_B)
+        .stageOnly()
+        .commit();
+
+    TableMetadata base = ((BaseTable) table).operations().current();
+    Snapshot snapshotA = base.snapshots().get(0);
+    Snapshot snapshotB = base.snapshots().get(1);
+
+    // `C` commit
+    table.newAppend()
+        .appendFile(FILE_C)
+        .commit();
+
+    Set<String> deletedFiles = Sets.newHashSet();
+
+    // Expire all commits including the dangling staged snapshot
+ ExpireSnapshots.Result result = SparkActions.get().expireSnapshots(table) + .deleteWith(deletedFiles::add) + .expireOlderThan(snapshotB.timestampMillis() + 1) + .execute(); + + checkExpirationResults(1L, 0L, 0L, 1L, 2L, result); + + Set expectedDeletes = Sets.newHashSet(); + expectedDeletes.add(snapshotA.manifestListLocation()); + + // Files should be deleted of dangling staged snapshot + snapshotB.addedFiles(table.io()).forEach(i -> { + expectedDeletes.add(i.path().toString()); + }); + + // ManifestList should be deleted too + expectedDeletes.add(snapshotB.manifestListLocation()); + snapshotB.dataManifests(table.io()).forEach(file -> { + // Only the manifest of B should be deleted. + if (file.snapshotId() == snapshotB.snapshotId()) { + expectedDeletes.add(file.path()); + } + }); + Assert.assertSame("Files deleted count should be expected", expectedDeletes.size(), deletedFiles.size()); + // Take the diff + expectedDeletes.removeAll(deletedFiles); + Assert.assertTrue("Exactly same files should be deleted", expectedDeletes.isEmpty()); + } + + /** + * Expire cherry-pick the commit as shown below, when `B` is in table's current state + * Table: + * A - B - C <--current snapshot + * `- D (source=B) + */ + @Test + public void testWithCherryPickTableSnapshot() { + // `A` commit + table.newAppend() + .appendFile(FILE_A) + .commit(); + Snapshot snapshotA = table.currentSnapshot(); + + // `B` commit + Set deletedAFiles = Sets.newHashSet(); + table.newOverwrite() + .addFile(FILE_B) + .deleteFile(FILE_A) + .deleteWith(deletedAFiles::add) + .commit(); + Assert.assertTrue("No files should be physically deleted", deletedAFiles.isEmpty()); + + // pick the snapshot 'B` + Snapshot snapshotB = table.currentSnapshot(); + + // `C` commit to let cherry-pick take effect, and avoid fast-forward of `B` with cherry-pick + table.newAppend() + .appendFile(FILE_C) + .commit(); + Snapshot snapshotC = table.currentSnapshot(); + + // Move the table back to `A` + table.manageSnapshots() + .setCurrentSnapshot(snapshotA.snapshotId()) + .commit(); + + // Generate A -> `D (B)` + table.manageSnapshots() + .cherrypick(snapshotB.snapshotId()) + .commit(); + Snapshot snapshotD = table.currentSnapshot(); + + // Move the table back to `C` + table.manageSnapshots() + .setCurrentSnapshot(snapshotC.snapshotId()) + .commit(); + List deletedFiles = Lists.newArrayList(); + + // Expire `C` + ExpireSnapshots.Result result = SparkActions.get().expireSnapshots(table) + .deleteWith(deletedFiles::add) + .expireOlderThan(snapshotC.timestampMillis() + 1) + .execute(); + + // Make sure no dataFiles are deleted for the B, C, D snapshot + Lists.newArrayList(snapshotB, snapshotC, snapshotD).forEach(i -> { + i.addedFiles(table.io()).forEach(item -> { + Assert.assertFalse(deletedFiles.contains(item.path().toString())); + }); + }); + + checkExpirationResults(1L, 0L, 0L, 2L, 2L, result); + } + + /** + * Test on table below, and expiring `B` which is not in current table state. 
+ * 1) Expire `B` + * 2) All commit + * Table: A - C - D (B) + * ` B (staged) + */ + @Test + public void testWithExpiringStagedThenCherrypick() { + // `A` commit + table.newAppend() + .appendFile(FILE_A) + .commit(); + + // `B` commit + table.newAppend() + .appendFile(FILE_B) + .stageOnly() + .commit(); + + // pick the snapshot that's staged but not committed + TableMetadata base = ((BaseTable) table).operations().current(); + Snapshot snapshotB = base.snapshots().get(1); + + // `C` commit to let cherry-pick take effect, and avoid fast-forward of `B` with cherry-pick + table.newAppend() + .appendFile(FILE_C) + .commit(); + + // `D (B)` cherry-pick commit + table.manageSnapshots() + .cherrypick(snapshotB.snapshotId()) + .commit(); + + base = ((BaseTable) table).operations().current(); + Snapshot snapshotD = base.snapshots().get(3); + + List deletedFiles = Lists.newArrayList(); + + // Expire `B` commit. + ExpireSnapshots.Result firstResult = SparkActions.get().expireSnapshots(table) + .deleteWith(deletedFiles::add) + .expireSnapshotId(snapshotB.snapshotId()) + .execute(); + + // Make sure no dataFiles are deleted for the staged snapshot + Lists.newArrayList(snapshotB).forEach(i -> { + i.addedFiles(table.io()).forEach(item -> { + Assert.assertFalse(deletedFiles.contains(item.path().toString())); + }); + }); + checkExpirationResults(0L, 0L, 0L, 1L, 1L, firstResult); + + // Expire all snapshots including cherry-pick + ExpireSnapshots.Result secondResult = SparkActions.get().expireSnapshots(table) + .deleteWith(deletedFiles::add) + .expireOlderThan(table.currentSnapshot().timestampMillis() + 1) + .execute(); + + // Make sure no dataFiles are deleted for the staged and cherry-pick + Lists.newArrayList(snapshotB, snapshotD).forEach(i -> { + i.addedFiles(table.io()).forEach(item -> { + Assert.assertFalse(deletedFiles.contains(item.path().toString())); + }); + }); + checkExpirationResults(0L, 0L, 0L, 0L, 2L, secondResult); + } + + @Test + public void testExpireOlderThan() { + table.newAppend() + .appendFile(FILE_A) + .commit(); + + Snapshot firstSnapshot = table.currentSnapshot(); + + rightAfterSnapshot(); + + table.newAppend() + .appendFile(FILE_B) + .commit(); + + long snapshotId = table.currentSnapshot().snapshotId(); + + long tAfterCommits = rightAfterSnapshot(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = SparkActions.get().expireSnapshots(table) + .expireOlderThan(tAfterCommits) + .deleteWith(deletedFiles::add) + .execute(); + + Assert.assertEquals("Expire should not change current snapshot", snapshotId, table.currentSnapshot().snapshotId()); + Assert.assertNull("Expire should remove the oldest snapshot", table.snapshot(firstSnapshot.snapshotId())); + Assert.assertEquals("Should remove only the expired manifest list location", + Sets.newHashSet(firstSnapshot.manifestListLocation()), deletedFiles); + + checkExpirationResults(0, 0, 0, 0, 1, result); + } + + @Test + public void testExpireOlderThanWithDelete() { + table.newAppend() + .appendFile(FILE_A) + .commit(); + + Snapshot firstSnapshot = table.currentSnapshot(); + Assert.assertEquals("Should create one manifest", + 1, firstSnapshot.allManifests(table.io()).size()); + + rightAfterSnapshot(); + + table.newDelete() + .deleteFile(FILE_A) + .commit(); + + Snapshot secondSnapshot = table.currentSnapshot(); + Assert.assertEquals("Should create replace manifest with a rewritten manifest", + 1, secondSnapshot.allManifests(table.io()).size()); + + table.newAppend() + .appendFile(FILE_B) + .commit(); + + 
rightAfterSnapshot(); + + long snapshotId = table.currentSnapshot().snapshotId(); + + long tAfterCommits = rightAfterSnapshot(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = SparkActions.get().expireSnapshots(table) + .expireOlderThan(tAfterCommits) + .deleteWith(deletedFiles::add) + .execute(); + + Assert.assertEquals("Expire should not change current snapshot", snapshotId, table.currentSnapshot().snapshotId()); + Assert.assertNull("Expire should remove the oldest snapshot", table.snapshot(firstSnapshot.snapshotId())); + Assert.assertNull("Expire should remove the second oldest snapshot", table.snapshot(secondSnapshot.snapshotId())); + + Assert.assertEquals("Should remove expired manifest lists and deleted data file", + Sets.newHashSet( + firstSnapshot.manifestListLocation(), // snapshot expired + firstSnapshot.allManifests(table.io()).get(0).path(), // manifest was rewritten for delete + secondSnapshot.manifestListLocation(), // snapshot expired + secondSnapshot.allManifests(table.io()).get(0).path(), // manifest contained only deletes, was dropped + FILE_A.path()), // deleted + deletedFiles); + + checkExpirationResults(1, 0, 0, 2, 2, result); + } + + @Test + public void testExpireOlderThanWithDeleteInMergedManifests() { + // merge every commit + table.updateProperties() + .set(TableProperties.MANIFEST_MIN_MERGE_COUNT, "0") + .commit(); + + table.newAppend() + .appendFile(FILE_A) + .appendFile(FILE_B) + .commit(); + + Snapshot firstSnapshot = table.currentSnapshot(); + Assert.assertEquals("Should create one manifest", + 1, firstSnapshot.allManifests(table.io()).size()); + + rightAfterSnapshot(); + + table.newDelete() + .deleteFile(FILE_A) // FILE_B is still in the dataset + .commit(); + + Snapshot secondSnapshot = table.currentSnapshot(); + Assert.assertEquals("Should replace manifest with a rewritten manifest", + 1, secondSnapshot.allManifests(table.io()).size()); + + table.newFastAppend() // do not merge to keep the last snapshot's manifest valid + .appendFile(FILE_C) + .commit(); + + rightAfterSnapshot(); + + long snapshotId = table.currentSnapshot().snapshotId(); + + long tAfterCommits = rightAfterSnapshot(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = SparkActions.get().expireSnapshots(table) + .expireOlderThan(tAfterCommits) + .deleteWith(deletedFiles::add) + .execute(); + + Assert.assertEquals("Expire should not change current snapshot", snapshotId, table.currentSnapshot().snapshotId()); + Assert.assertNull("Expire should remove the oldest snapshot", table.snapshot(firstSnapshot.snapshotId())); + Assert.assertNull("Expire should remove the second oldest snapshot", table.snapshot(secondSnapshot.snapshotId())); + + Assert.assertEquals("Should remove expired manifest lists and deleted data file", + Sets.newHashSet( + firstSnapshot.manifestListLocation(), // snapshot expired + firstSnapshot.allManifests(table.io()).get(0).path(), // manifest was rewritten for delete + secondSnapshot.manifestListLocation(), // snapshot expired + FILE_A.path()), // deleted + deletedFiles); + + checkExpirationResults(1, 0, 0, 1, 2, result); + } + + @Test + public void testExpireOlderThanWithRollback() { + // merge every commit + table.updateProperties() + .set(TableProperties.MANIFEST_MIN_MERGE_COUNT, "0") + .commit(); + + table.newAppend() + .appendFile(FILE_A) + .appendFile(FILE_B) + .commit(); + + Snapshot firstSnapshot = table.currentSnapshot(); + Assert.assertEquals("Should create one manifest", + 1, 
firstSnapshot.allManifests(table.io()).size()); + + rightAfterSnapshot(); + + table.newDelete() + .deleteFile(FILE_B) + .commit(); + + Snapshot secondSnapshot = table.currentSnapshot(); + Set secondSnapshotManifests = Sets.newHashSet(secondSnapshot.allManifests(table.io())); + secondSnapshotManifests.removeAll(firstSnapshot.allManifests(table.io())); + Assert.assertEquals("Should add one new manifest for append", 1, secondSnapshotManifests.size()); + + table.manageSnapshots() + .rollbackTo(firstSnapshot.snapshotId()) + .commit(); + + long tAfterCommits = rightAfterSnapshot(secondSnapshot.snapshotId()); + + long snapshotId = table.currentSnapshot().snapshotId(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = SparkActions.get().expireSnapshots(table) + .expireOlderThan(tAfterCommits) + .deleteWith(deletedFiles::add) + .execute(); + + Assert.assertEquals("Expire should not change current snapshot", snapshotId, table.currentSnapshot().snapshotId()); + Assert.assertNotNull("Expire should keep the oldest snapshot, current", table.snapshot(firstSnapshot.snapshotId())); + Assert.assertNull("Expire should remove the orphaned snapshot", table.snapshot(secondSnapshot.snapshotId())); + + Assert.assertEquals("Should remove expired manifest lists and reverted appended data file", + Sets.newHashSet( + secondSnapshot.manifestListLocation(), // snapshot expired + Iterables.getOnlyElement(secondSnapshotManifests).path()), // manifest is no longer referenced + deletedFiles); + + checkExpirationResults(0, 0, 0, 1, 1, result); + } + + @Test + public void testExpireOlderThanWithRollbackAndMergedManifests() { + table.newAppend() + .appendFile(FILE_A) + .commit(); + + Snapshot firstSnapshot = table.currentSnapshot(); + Assert.assertEquals("Should create one manifest", + 1, firstSnapshot.allManifests(table.io()).size()); + + rightAfterSnapshot(); + + table.newAppend() + .appendFile(FILE_B) + .commit(); + + Snapshot secondSnapshot = table.currentSnapshot(); + Set secondSnapshotManifests = Sets.newHashSet(secondSnapshot.allManifests(table.io())); + secondSnapshotManifests.removeAll(firstSnapshot.allManifests(table.io())); + Assert.assertEquals("Should add one new manifest for append", 1, secondSnapshotManifests.size()); + + table.manageSnapshots() + .rollbackTo(firstSnapshot.snapshotId()) + .commit(); + + long tAfterCommits = rightAfterSnapshot(secondSnapshot.snapshotId()); + + long snapshotId = table.currentSnapshot().snapshotId(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = SparkActions.get().expireSnapshots(table) + .expireOlderThan(tAfterCommits) + .deleteWith(deletedFiles::add) + .execute(); + + Assert.assertEquals("Expire should not change current snapshot", snapshotId, table.currentSnapshot().snapshotId()); + Assert.assertNotNull("Expire should keep the oldest snapshot, current", table.snapshot(firstSnapshot.snapshotId())); + Assert.assertNull("Expire should remove the orphaned snapshot", table.snapshot(secondSnapshot.snapshotId())); + + Assert.assertEquals("Should remove expired manifest lists and reverted appended data file", + Sets.newHashSet( + secondSnapshot.manifestListLocation(), // snapshot expired + Iterables.getOnlyElement(secondSnapshotManifests).path(), // manifest is no longer referenced + FILE_B.path()), // added, but rolled back + deletedFiles); + + checkExpirationResults(1, 0, 0, 1, 1, result); + } + + @Test + public void testExpireOlderThanWithDeleteFile() { + table.updateProperties() + .set(TableProperties.FORMAT_VERSION, "2") 
+        .set(TableProperties.MANIFEST_MERGE_ENABLED, "false")
+        .commit();
+
+    // Add Data File
+    table.newAppend()
+        .appendFile(FILE_A)
+        .commit();
+    Snapshot firstSnapshot = table.currentSnapshot();
+
+    // Add POS Delete
+    table.newRowDelta()
+        .addDeletes(FILE_A_POS_DELETES)
+        .commit();
+    Snapshot secondSnapshot = table.currentSnapshot();
+
+    // Add EQ Delete
+    table.newRowDelta()
+        .addDeletes(FILE_A_EQ_DELETES)
+        .commit();
+    Snapshot thirdSnapshot = table.currentSnapshot();
+
+    // Move files to DELETED
+    table.newDelete()
+        .deleteFromRowFilter(Expressions.alwaysTrue())
+        .commit();
+    Snapshot fourthSnapshot = table.currentSnapshot();
+
+    long afterAllDeleted = rightAfterSnapshot();
+
+    table.newAppend()
+        .appendFile(FILE_B)
+        .commit();
+
+    Set<String> deletedFiles = Sets.newHashSet();
+
+    ExpireSnapshots.Result result = SparkActions.get().expireSnapshots(table)
+        .expireOlderThan(afterAllDeleted)
+        .deleteWith(deletedFiles::add)
+        .execute();
+
+    Set<String> expectedDeletes = Sets.newHashSet(
+        firstSnapshot.manifestListLocation(),
+        secondSnapshot.manifestListLocation(),
+        thirdSnapshot.manifestListLocation(),
+        fourthSnapshot.manifestListLocation(),
+        FILE_A.path().toString(),
+        FILE_A_POS_DELETES.path().toString(),
+        FILE_A_EQ_DELETES.path().toString());
+
+    expectedDeletes.addAll(
+        thirdSnapshot.allManifests(table.io()).stream()
+            .map(ManifestFile::path)
+            .map(CharSequence::toString).collect(Collectors.toSet()));
+    // Delete operation (fourth snapshot) generates new manifest files
+    expectedDeletes.addAll(
+        fourthSnapshot.allManifests(table.io()).stream()
+            .map(ManifestFile::path)
+            .map(CharSequence::toString).collect(Collectors.toSet()));
+
+    Assert.assertEquals("Should remove expired manifest lists and deleted data file",
+        expectedDeletes,
+        deletedFiles);
+
+    checkExpirationResults(1, 1, 1, 6, 4, result);
+  }
+
+  @Test
+  public void testExpireOnEmptyTable() {
+    Set<String> deletedFiles = Sets.newHashSet();
+
+    // table has no data, so ExpireSnapshots should not fail when there is no snapshot
+    ExpireSnapshots.Result result = SparkActions.get().expireSnapshots(table)
+        .expireOlderThan(System.currentTimeMillis())
+        .deleteWith(deletedFiles::add)
+        .execute();
+
+    checkExpirationResults(0, 0, 0, 0, 0, result);
+  }
+
+  @Test
+  public void testExpireAction() {
+    table.newAppend()
+        .appendFile(FILE_A)
+        .commit();
+
+    Snapshot firstSnapshot = table.currentSnapshot();
+
+    rightAfterSnapshot();
+
+    table.newAppend()
+        .appendFile(FILE_B)
+        .commit();
+
+    long snapshotId = table.currentSnapshot().snapshotId();
+
+    long tAfterCommits = rightAfterSnapshot();
+
+    Set<String> deletedFiles = Sets.newHashSet();
+
+    BaseExpireSnapshotsSparkAction action = (BaseExpireSnapshotsSparkAction) SparkActions.get().expireSnapshots(table)
+        .expireOlderThan(tAfterCommits)
+        .deleteWith(deletedFiles::add);
+    Dataset<Row> pendingDeletes = action.expire();
+
+    List<Row> pending = pendingDeletes.collectAsList();
+
+    Assert.assertEquals("Should not change current snapshot", snapshotId, table.currentSnapshot().snapshotId());
+    Assert.assertNull("Should remove the oldest snapshot", table.snapshot(firstSnapshot.snapshotId()));
+
+    Assert.assertEquals("Pending deletes should contain one row", 1, pending.size());
+    Assert.assertEquals("Pending delete should be the expired manifest list location",
+        firstSnapshot.manifestListLocation(), pending.get(0).getString(0));
+    Assert.assertEquals("Pending delete should be a manifest list",
+        "Manifest List", pending.get(0).getString(1));
+
+    Assert.assertEquals("Should not delete any files", 0,
deletedFiles.size()); + + Assert.assertSame("Multiple calls to expire should return the same deleted files", + pendingDeletes, action.expire()); + } + + @Test + public void testUseLocalIterator() { + table.newFastAppend() + .appendFile(FILE_A) + .commit(); + + table.newOverwrite() + .deleteFile(FILE_A) + .addFile(FILE_B) + .commit(); + + table.newFastAppend() + .appendFile(FILE_C) + .commit(); + + long end = rightAfterSnapshot(); + + int jobsBeforeStreamResults = spark.sparkContext().dagScheduler().nextJobId().get(); + + withSQLConf(ImmutableMap.of("spark.sql.adaptive.enabled", "false"), () -> { + ExpireSnapshots.Result results = SparkActions.get().expireSnapshots(table).expireOlderThan(end) + .option("stream-results", "true").execute(); + + int jobsAfterStreamResults = spark.sparkContext().dagScheduler().nextJobId().get(); + int jobsRunDuringStreamResults = jobsAfterStreamResults - jobsBeforeStreamResults; + + checkExpirationResults(1L, 0L, 0L, 1L, 2L, results); + + Assert.assertEquals("Expected total number of jobs with stream-results should match the expected number", + 5L, jobsRunDuringStreamResults); + }); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction.java new file mode 100644 index 000000000000..a56e888207b2 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction.java @@ -0,0 +1,958 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.actions; + +import java.io.File; +import java.io.IOException; +import java.sql.Timestamp; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.DeleteOrphanFiles; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.HadoopCatalog; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.hadoop.HiddenPathFilter; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.spark.source.FilePathLastModifiedRecord; +import org.apache.iceberg.spark.source.ThreeColumnRecord; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.types.Types.NestedField.optional; + +public abstract class TestRemoveOrphanFilesAction extends SparkTestBase { + + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + protected static final Schema SCHEMA = new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get()) + ); + protected static final PartitionSpec SPEC = PartitionSpec.builderFor(SCHEMA) + .truncate("c2", 2) + .identity("c3") + .build(); + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + private File tableDir = null; + protected String tableLocation = null; + + @Before + public void setupTableLocation() throws Exception { + this.tableDir = temp.newFolder(); + this.tableLocation = tableDir.toURI().toString(); + } + + @Test + public void testDryRun() throws IOException, InterruptedException { + Table table = TABLES.create(SCHEMA, PartitionSpec.unpartitioned(), Maps.newHashMap(), tableLocation); + + List records = Lists.newArrayList( + new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA") + ); + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + 
.save(tableLocation); + + List validFiles = spark.read().format("iceberg") + .load(tableLocation + "#files") + .select("file_path") + .as(Encoders.STRING()) + .collectAsList(); + Assert.assertEquals("Should be 2 valid files", 2, validFiles.size()); + + df.write().mode("append").parquet(tableLocation + "/data"); + + Path dataPath = new Path(tableLocation + "/data"); + FileSystem fs = dataPath.getFileSystem(spark.sessionState().newHadoopConf()); + List allFiles = Arrays.stream(fs.listStatus(dataPath, HiddenPathFilter.get())) + .filter(FileStatus::isFile) + .map(file -> file.getPath().toString()) + .collect(Collectors.toList()); + Assert.assertEquals("Should be 3 files", 3, allFiles.size()); + + List invalidFiles = Lists.newArrayList(allFiles); + invalidFiles.removeAll(validFiles); + Assert.assertEquals("Should be 1 invalid file", 1, invalidFiles.size()); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result1 = actions.deleteOrphanFiles(table) + .deleteWith(s -> { }) + .execute(); + Assert.assertTrue("Default olderThan interval should be safe", Iterables.isEmpty(result1.orphanFileLocations())); + + DeleteOrphanFiles.Result result2 = actions.deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis()) + .deleteWith(s -> { }) + .execute(); + Assert.assertEquals("Action should find 1 file", invalidFiles, result2.orphanFileLocations()); + Assert.assertTrue("Invalid file should be present", fs.exists(new Path(invalidFiles.get(0)))); + + DeleteOrphanFiles.Result result3 = actions.deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis()) + .execute(); + Assert.assertEquals("Action should delete 1 file", invalidFiles, result3.orphanFileLocations()); + Assert.assertFalse("Invalid file should not be present", fs.exists(new Path(invalidFiles.get(0)))); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(records); + expectedRecords.addAll(records); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = resultDF + .as(Encoders.bean(ThreeColumnRecord.class)) + .collectAsList(); + Assert.assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testAllValidFilesAreKept() throws IOException, InterruptedException { + Table table = TABLES.create(SCHEMA, SPEC, Maps.newHashMap(), tableLocation); + + List records1 = Lists.newArrayList( + new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA") + ); + Dataset df1 = spark.createDataFrame(records1, ThreeColumnRecord.class).coalesce(1); + + // original append + df1.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + List records2 = Lists.newArrayList( + new ThreeColumnRecord(2, "AAAAAAAAAA", "AAAA") + ); + Dataset df2 = spark.createDataFrame(records2, ThreeColumnRecord.class).coalesce(1); + + // dynamic partition overwrite + df2.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("overwrite") + .save(tableLocation); + + // second append + df2.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + List snapshots = Lists.newArrayList(table.snapshots()); + + List snapshotFiles1 = snapshotFiles(snapshots.get(0).snapshotId()); + Assert.assertEquals(1, snapshotFiles1.size()); + + List snapshotFiles2 = snapshotFiles(snapshots.get(1).snapshotId()); + Assert.assertEquals(1, snapshotFiles2.size()); + + List snapshotFiles3 = snapshotFiles(snapshots.get(2).snapshotId()); + 
Assert.assertEquals(2, snapshotFiles3.size());
+
+    df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data");
+    df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data/c2_trunc=AA");
+    df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data/c2_trunc=AA/c3=AAAA");
+    df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data/invalid/invalid");
+
+    waitUntilAfter(System.currentTimeMillis());
+
+    SparkActions actions = SparkActions.get();
+
+    DeleteOrphanFiles.Result result = actions.deleteOrphanFiles(table)
+        .olderThan(System.currentTimeMillis())
+        .execute();
+
+    Assert.assertEquals("Should delete 4 files", 4, Iterables.size(result.orphanFileLocations()));
+
+    Path dataPath = new Path(tableLocation + "/data");
+    FileSystem fs = dataPath.getFileSystem(spark.sessionState().newHadoopConf());
+
+    for (String fileLocation : snapshotFiles1) {
+      Assert.assertTrue("All snapshot files must remain", fs.exists(new Path(fileLocation)));
+    }
+
+    for (String fileLocation : snapshotFiles2) {
+      Assert.assertTrue("All snapshot files must remain", fs.exists(new Path(fileLocation)));
+    }
+
+    for (String fileLocation : snapshotFiles3) {
+      Assert.assertTrue("All snapshot files must remain", fs.exists(new Path(fileLocation)));
+    }
+  }
+
+  @Test
+  public void orphanedFileRemovedWithParallelTasks() throws InterruptedException, IOException {
+    Table table = TABLES.create(SCHEMA, SPEC, Maps.newHashMap(), tableLocation);
+
+    List<ThreeColumnRecord> records1 = Lists.newArrayList(
+        new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")
+    );
+    Dataset<Row> df1 = spark.createDataFrame(records1, ThreeColumnRecord.class).coalesce(1);
+
+    // original append
+    df1.select("c1", "c2", "c3")
+        .write()
+        .format("iceberg")
+        .mode("append")
+        .save(tableLocation);
+
+    List<ThreeColumnRecord> records2 = Lists.newArrayList(
+        new ThreeColumnRecord(2, "AAAAAAAAAA", "AAAA")
+    );
+    Dataset<Row> df2 = spark.createDataFrame(records2, ThreeColumnRecord.class).coalesce(1);
+
+    // dynamic partition overwrite
+    df2.select("c1", "c2", "c3")
+        .write()
+        .format("iceberg")
+        .mode("overwrite")
+        .save(tableLocation);
+
+    // second append
+    df2.select("c1", "c2", "c3")
+        .write()
+        .format("iceberg")
+        .mode("append")
+        .save(tableLocation);
+
+    df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data");
+    df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data/c2_trunc=AA");
+    df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data/c2_trunc=AA/c3=AAAA");
+    df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data/invalid/invalid");
+
+    waitUntilAfter(System.currentTimeMillis());
+
+    Set<String> deletedFiles = Sets.newHashSet();
+    Set<String> deleteThreads = ConcurrentHashMap.newKeySet();
+    AtomicInteger deleteThreadsIndex = new AtomicInteger(0);
+
+    ExecutorService executorService = Executors.newFixedThreadPool(4, runnable -> {
+      Thread thread = new Thread(runnable);
+      thread.setName("remove-orphan-" + deleteThreadsIndex.getAndIncrement());
+      thread.setDaemon(true);
+      return thread;
+    });
+
+    DeleteOrphanFiles.Result result = SparkActions.get().deleteOrphanFiles(table)
+        .executeDeleteWith(executorService)
+        .olderThan(System.currentTimeMillis() + 5000) // Ensure all orphan files are selected
+        .deleteWith(file -> {
+          deleteThreads.add(Thread.currentThread().getName());
+          deletedFiles.add(file);
+        })
+        .execute();
+
+    // Verifies that the delete methods ran in the threads created by the provided ExecutorService ThreadFactory
+    Assert.assertEquals(deleteThreads,
+        Sets.newHashSet("remove-orphan-0",
"remove-orphan-1", "remove-orphan-2", "remove-orphan-3")); + + Assert.assertEquals("Should delete 4 files", 4, deletedFiles.size()); + } + + @Test + public void testWapFilesAreKept() throws InterruptedException { + Map props = Maps.newHashMap(); + props.put(TableProperties.WRITE_AUDIT_PUBLISH_ENABLED, "true"); + Table table = TABLES.create(SCHEMA, SPEC, props, tableLocation); + + List records = Lists.newArrayList( + new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA") + ); + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + + // normal write + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + spark.conf().set("spark.wap.id", "1"); + + // wap write + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = resultDF + .as(Encoders.bean(ThreeColumnRecord.class)) + .collectAsList(); + Assert.assertEquals("Should not return data from the staged snapshot", records, actualRecords); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = actions.deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis()) + .execute(); + + Assert.assertTrue("Should not delete any files", Iterables.isEmpty(result.orphanFileLocations())); + } + + @Test + public void testMetadataFolderIsIntact() throws InterruptedException { + // write data directly to the table location + Map props = Maps.newHashMap(); + props.put(TableProperties.WRITE_DATA_LOCATION, tableLocation); + Table table = TABLES.create(SCHEMA, SPEC, props, tableLocation); + + List records = Lists.newArrayList( + new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA") + ); + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + df.write().mode("append").parquet(tableLocation + "/c2_trunc=AA/c3=AAAA"); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = actions.deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis()) + .execute(); + + Assert.assertEquals("Should delete 1 file", 1, Iterables.size(result.orphanFileLocations())); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = resultDF + .as(Encoders.bean(ThreeColumnRecord.class)) + .collectAsList(); + Assert.assertEquals("Rows must match", records, actualRecords); + } + + @Test + public void testOlderThanTimestamp() throws InterruptedException { + Table table = TABLES.create(SCHEMA, SPEC, Maps.newHashMap(), tableLocation); + + List records = Lists.newArrayList( + new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA") + ); + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + df.write().mode("append").parquet(tableLocation + "/data/c2_trunc=AA/c3=AAAA"); + df.write().mode("append").parquet(tableLocation + "/data/c2_trunc=AA/c3=AAAA"); + + waitUntilAfter(System.currentTimeMillis()); + + long timestamp = System.currentTimeMillis(); + + waitUntilAfter(System.currentTimeMillis() + 1000L); + + df.write().mode("append").parquet(tableLocation + "/data/c2_trunc=AA/c3=AAAA"); + + SparkActions actions = SparkActions.get(); + + 
DeleteOrphanFiles.Result result = actions.deleteOrphanFiles(table) + .olderThan(timestamp) + .execute(); + + Assert.assertEquals("Should delete only 2 files", 2, Iterables.size(result.orphanFileLocations())); + } + + @Test + public void testRemoveUnreachableMetadataVersionFiles() throws InterruptedException { + Map props = Maps.newHashMap(); + props.put(TableProperties.WRITE_DATA_LOCATION, tableLocation); + props.put(TableProperties.METADATA_PREVIOUS_VERSIONS_MAX, "1"); + Table table = TABLES.create(SCHEMA, SPEC, props, tableLocation); + + List records = Lists.newArrayList( + new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA") + ); + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = actions.deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis()) + .execute(); + + Assert.assertEquals("Should delete 1 file", 1, Iterables.size(result.orphanFileLocations())); + Assert.assertTrue("Should remove v1 file", StreamSupport.stream(result.orphanFileLocations().spliterator(), false) + .anyMatch(file -> file.contains("v1.metadata.json"))); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(records); + expectedRecords.addAll(records); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = resultDF + .as(Encoders.bean(ThreeColumnRecord.class)) + .collectAsList(); + Assert.assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testManyTopLevelPartitions() throws InterruptedException { + Table table = TABLES.create(SCHEMA, SPEC, Maps.newHashMap(), tableLocation); + + List records = Lists.newArrayList(); + for (int i = 0; i < 100; i++) { + records.add(new ThreeColumnRecord(i, String.valueOf(i), String.valueOf(i))); + } + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = actions.deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis()) + .execute(); + + Assert.assertTrue("Should not delete any files", Iterables.isEmpty(result.orphanFileLocations())); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = resultDF + .as(Encoders.bean(ThreeColumnRecord.class)) + .collectAsList(); + Assert.assertEquals("Rows must match", records, actualRecords); + } + + @Test + public void testManyLeafPartitions() throws InterruptedException { + Table table = TABLES.create(SCHEMA, SPEC, Maps.newHashMap(), tableLocation); + + List records = Lists.newArrayList(); + for (int i = 0; i < 100; i++) { + records.add(new ThreeColumnRecord(i, String.valueOf(i % 3), String.valueOf(i))); + } + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = actions.deleteOrphanFiles(table) + 
.olderThan(System.currentTimeMillis()) + .execute(); + + Assert.assertTrue("Should not delete any files", Iterables.isEmpty(result.orphanFileLocations())); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = resultDF + .as(Encoders.bean(ThreeColumnRecord.class)) + .collectAsList(); + Assert.assertEquals("Rows must match", records, actualRecords); + } + + @Test + public void testHiddenPartitionPaths() throws InterruptedException { + Schema schema = new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "_c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get()) + ); + PartitionSpec spec = PartitionSpec.builderFor(schema) + .truncate("_c2", 2) + .identity("c3") + .build(); + Table table = TABLES.create(schema, spec, Maps.newHashMap(), tableLocation); + + StructType structType = new StructType() + .add("c1", DataTypes.IntegerType) + .add("_c2", DataTypes.StringType) + .add("c3", DataTypes.StringType); + List records = Lists.newArrayList( + RowFactory.create(1, "AAAAAAAAAA", "AAAA") + ); + Dataset df = spark.createDataFrame(records, structType).coalesce(1); + + df.select("c1", "_c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + df.write().mode("append").parquet(tableLocation + "/data/_c2_trunc=AA/c3=AAAA"); + df.write().mode("append").parquet(tableLocation + "/data/_c2_trunc=AA/c3=AAAA"); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = actions.deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis()) + .execute(); + + Assert.assertEquals("Should delete 2 files", 2, Iterables.size(result.orphanFileLocations())); + } + + @Test + public void testHiddenPartitionPathsWithPartitionEvolution() throws InterruptedException { + Schema schema = new Schema( + optional(1, "_c1", Types.IntegerType.get()), + optional(2, "_c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get()) + ); + PartitionSpec spec = PartitionSpec.builderFor(schema) + .truncate("_c2", 2) + .build(); + Table table = TABLES.create(schema, spec, Maps.newHashMap(), tableLocation); + + StructType structType = new StructType() + .add("_c1", DataTypes.IntegerType) + .add("_c2", DataTypes.StringType) + .add("c3", DataTypes.StringType); + List records = Lists.newArrayList( + RowFactory.create(1, "AAAAAAAAAA", "AAAA") + ); + Dataset df = spark.createDataFrame(records, structType).coalesce(1); + + df.select("_c1", "_c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + df.write().mode("append").parquet(tableLocation + "/data/_c2_trunc=AA"); + + table.updateSpec() + .addField("_c1") + .commit(); + + df.write().mode("append").parquet(tableLocation + "/data/_c2_trunc=AA/_c1=1"); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = actions.deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis()) + .execute(); + + Assert.assertEquals("Should delete 2 files", 2, Iterables.size(result.orphanFileLocations())); + } + + @Test + public void testHiddenPathsStartingWithPartitionNamesAreIgnored() throws InterruptedException, IOException { + Schema schema = new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "_c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get()) + ); + PartitionSpec spec = PartitionSpec.builderFor(schema) + .truncate("_c2", 2) + .identity("c3") + .build(); + 
Table table = TABLES.create(schema, spec, Maps.newHashMap(), tableLocation); + + StructType structType = new StructType() + .add("c1", DataTypes.IntegerType) + .add("_c2", DataTypes.StringType) + .add("c3", DataTypes.StringType); + List records = Lists.newArrayList( + RowFactory.create(1, "AAAAAAAAAA", "AAAA") + ); + Dataset df = spark.createDataFrame(records, structType).coalesce(1); + + df.select("c1", "_c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + Path dataPath = new Path(tableLocation + "/data"); + FileSystem fs = dataPath.getFileSystem(spark.sessionState().newHadoopConf()); + Path pathToFileInHiddenFolder = new Path(dataPath, "_c2_trunc/file.txt"); + fs.createNewFile(pathToFileInHiddenFolder); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = actions.deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis()) + .execute(); + + Assert.assertEquals("Should delete 0 files", 0, Iterables.size(result.orphanFileLocations())); + Assert.assertTrue(fs.exists(pathToFileInHiddenFolder)); + } + + private List snapshotFiles(long snapshotId) { + return spark.read().format("iceberg") + .option("snapshot-id", snapshotId) + .load(tableLocation + "#files") + .select("file_path") + .as(Encoders.STRING()) + .collectAsList(); + } + + @Test + public void testRemoveOrphanFilesWithRelativeFilePath() throws IOException, InterruptedException { + Table table = TABLES.create(SCHEMA, PartitionSpec.unpartitioned(), Maps.newHashMap(), tableDir.getAbsolutePath()); + + List records = Lists.newArrayList( + new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA") + ); + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(tableDir.getAbsolutePath()); + + List validFiles = spark.read().format("iceberg") + .load(tableLocation + "#files") + .select("file_path") + .as(Encoders.STRING()) + .collectAsList(); + Assert.assertEquals("Should be 1 valid files", 1, validFiles.size()); + String validFile = validFiles.get(0); + + df.write().mode("append").parquet(tableLocation + "/data"); + + Path dataPath = new Path(tableLocation + "/data"); + FileSystem fs = dataPath.getFileSystem(spark.sessionState().newHadoopConf()); + List allFiles = Arrays.stream(fs.listStatus(dataPath, HiddenPathFilter.get())) + .filter(FileStatus::isFile) + .map(file -> file.getPath().toString()) + .collect(Collectors.toList()); + Assert.assertEquals("Should be 2 files", 2, allFiles.size()); + + List invalidFiles = Lists.newArrayList(allFiles); + invalidFiles.removeIf(file -> file.contains(validFile)); + Assert.assertEquals("Should be 1 invalid file", 1, invalidFiles.size()); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + DeleteOrphanFiles.Result result = actions.deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis()) + .deleteWith(s -> { }) + .execute(); + Assert.assertEquals("Action should find 1 file", invalidFiles, result.orphanFileLocations()); + Assert.assertTrue("Invalid file should be present", fs.exists(new Path(invalidFiles.get(0)))); + } + + @Test + public void testRemoveOrphanFilesWithHadoopCatalog() throws InterruptedException { + HadoopCatalog catalog = new HadoopCatalog(new Configuration(), tableLocation); + String namespaceName = "testDb"; + String tableName = "testTb"; + + Namespace namespace = Namespace.of(namespaceName); + TableIdentifier 
tableIdentifier = TableIdentifier.of(namespace, tableName); + Table table = catalog.createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned(), Maps.newHashMap()); + + List records = Lists.newArrayList( + new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA") + ); + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(table.location()); + + df.write().mode("append").parquet(table.location() + "/data"); + + waitUntilAfter(System.currentTimeMillis()); + + table.refresh(); + + DeleteOrphanFiles.Result result = SparkActions.get() + .deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis()) + .execute(); + + Assert.assertEquals("Should delete only 1 files", 1, Iterables.size(result.orphanFileLocations())); + + Dataset resultDF = spark.read().format("iceberg").load(table.location()); + List actualRecords = resultDF + .as(Encoders.bean(ThreeColumnRecord.class)) + .collectAsList(); + Assert.assertEquals("Rows must match", records, actualRecords); + } + + @Test + public void testHiveCatalogTable() throws IOException { + Table table = catalog.createTable(TableIdentifier.of("default", "hivetestorphan"), SCHEMA, SPEC, tableLocation, + Maps.newHashMap()); + + List records = Lists.newArrayList( + new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA") + ); + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save("default.hivetestorphan"); + + String location = table.location().replaceFirst("file:", ""); + new File(location + "/data/trashfile").createNewFile(); + + DeleteOrphanFiles.Result result = SparkActions.get().deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis() + 1000).execute(); + Assert.assertTrue("trash file should be removed", + StreamSupport.stream(result.orphanFileLocations().spliterator(), false) + .anyMatch(file -> file.contains("file:" + location + "data/trashfile"))); + } + + @Test + public void testGarbageCollectionDisabled() { + Table table = TABLES.create(SCHEMA, PartitionSpec.unpartitioned(), Maps.newHashMap(), tableLocation); + + List records = Lists.newArrayList( + new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA") + ); + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + table.updateProperties() + .set(TableProperties.GC_ENABLED, "false") + .commit(); + + AssertHelpers.assertThrows("Should complain about removing orphan files", + ValidationException.class, "Cannot delete orphan files: GC is disabled", + () -> SparkActions.get().deleteOrphanFiles(table).execute()); + } + + @Test + public void testCompareToFileList() throws IOException, InterruptedException { + Table table = + TABLES.create(SCHEMA, PartitionSpec.unpartitioned(), Maps.newHashMap(), tableLocation); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + Path dataPath = new Path(tableLocation + "/data"); + FileSystem fs = dataPath.getFileSystem(spark.sessionState().newHadoopConf()); + List validFiles = + 
Arrays.stream(fs.listStatus(dataPath, HiddenPathFilter.get())) + .filter(FileStatus::isFile) + .map( + file -> + new FilePathLastModifiedRecord( + file.getPath().toString(), new Timestamp(file.getModificationTime()))) + .collect(Collectors.toList()); + + Assert.assertEquals("Should be 2 valid files", 2, validFiles.size()); + + df.write().mode("append").parquet(tableLocation + "/data"); + + List allFiles = + Arrays.stream(fs.listStatus(dataPath, HiddenPathFilter.get())) + .filter(FileStatus::isFile) + .map( + file -> + new FilePathLastModifiedRecord( + file.getPath().toString(), new Timestamp(file.getModificationTime()))) + .collect(Collectors.toList()); + + Assert.assertEquals("Should be 3 files", 3, allFiles.size()); + + List invalidFiles = Lists.newArrayList(allFiles); + invalidFiles.removeAll(validFiles); + List invalidFilePaths = + invalidFiles.stream() + .map(FilePathLastModifiedRecord::getFilePath) + .collect(Collectors.toList()); + Assert.assertEquals("Should be 1 invalid file", 1, invalidFiles.size()); + + // sleep for 1 second to ensure files will be old enough + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + Dataset compareToFileList = + spark + .createDataFrame(allFiles, FilePathLastModifiedRecord.class) + .withColumnRenamed("filePath", "file_path") + .withColumnRenamed("lastModified", "last_modified"); + + DeleteOrphanFiles.Result result1 = + ((BaseDeleteOrphanFilesSparkAction) actions.deleteOrphanFiles(table)) + .compareToFileList(compareToFileList) + .deleteWith(s -> { }) + .execute(); + Assert.assertTrue( + "Default olderThan interval should be safe", + Iterables.isEmpty(result1.orphanFileLocations())); + + DeleteOrphanFiles.Result result2 = + ((BaseDeleteOrphanFilesSparkAction) actions.deleteOrphanFiles(table)) + .compareToFileList(compareToFileList) + .olderThan(System.currentTimeMillis()) + .deleteWith(s -> { }) + .execute(); + Assert.assertEquals( + "Action should find 1 file", invalidFilePaths, result2.orphanFileLocations()); + Assert.assertTrue( + "Invalid file should be present", fs.exists(new Path(invalidFilePaths.get(0)))); + + DeleteOrphanFiles.Result result3 = + ((BaseDeleteOrphanFilesSparkAction) actions.deleteOrphanFiles(table)) + .compareToFileList(compareToFileList) + .olderThan(System.currentTimeMillis()) + .execute(); + Assert.assertEquals( + "Action should delete 1 file", invalidFilePaths, result3.orphanFileLocations()); + Assert.assertFalse( + "Invalid file should not be present", fs.exists(new Path(invalidFilePaths.get(0)))); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(records); + expectedRecords.addAll(records); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + Assert.assertEquals("Rows must match", expectedRecords, actualRecords); + + List outsideLocationMockFiles = + Lists.newArrayList(new FilePathLastModifiedRecord("/tmp/mock1", new Timestamp(0L))); + + Dataset compareToFileListWithOutsideLocation = + spark + .createDataFrame(outsideLocationMockFiles, FilePathLastModifiedRecord.class) + .withColumnRenamed("filePath", "file_path") + .withColumnRenamed("lastModified", "last_modified"); + + DeleteOrphanFiles.Result result4 = + ((BaseDeleteOrphanFilesSparkAction) actions.deleteOrphanFiles(table)) + .compareToFileList(compareToFileListWithOutsideLocation) + .deleteWith(s -> { }) + .execute(); + Assert.assertEquals( + "Action should find nothing", 
Lists.newArrayList(), result4.orphanFileLocations());
+  }
+
+  protected long waitUntilAfter(long timestampMillis) {
+    long current = System.currentTimeMillis();
+    while (current <= timestampMillis) {
+      current = System.currentTimeMillis();
+    }
+    return current;
+  }
+}
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction3.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction3.java
new file mode 100644
index 000000000000..77eb23a6dffc
--- /dev/null
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction3.java
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.actions;
+
+import java.io.File;
+import java.util.Map;
+import java.util.stream.StreamSupport;
+import org.apache.iceberg.actions.DeleteOrphanFiles;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.spark.SparkCatalog;
+import org.apache.iceberg.spark.SparkSchemaUtil;
+import org.apache.iceberg.spark.SparkSessionCatalog;
+import org.apache.iceberg.spark.source.SparkTable;
+import org.apache.spark.sql.connector.catalog.Identifier;
+import org.apache.spark.sql.connector.expressions.Transform;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestRemoveOrphanFilesAction3 extends TestRemoveOrphanFilesAction {
+  @Test
+  public void testSparkCatalogTable() throws Exception {
+    spark.conf().set("spark.sql.catalog.mycat", "org.apache.iceberg.spark.SparkCatalog");
+    spark.conf().set("spark.sql.catalog.mycat.type", "hadoop");
+    spark.conf().set("spark.sql.catalog.mycat.warehouse", tableLocation);
+    SparkCatalog cat = (SparkCatalog) spark.sessionState().catalogManager().catalog("mycat");
+
+    String[] database = {"default"};
+    Identifier id = Identifier.of(database, "table");
+    Map<String, String> options = Maps.newHashMap();
+    Transform[] transforms = {};
+    cat.createTable(id, SparkSchemaUtil.convert(SCHEMA), transforms, options);
+    SparkTable table = cat.loadTable(id);
+
+    spark.sql("INSERT INTO mycat.default.table VALUES (1,1,1)");
+
+    String location = table.table().location().replaceFirst("file:", "");
+    new File(location + "/data/trashfile").createNewFile();
+
+    DeleteOrphanFiles.Result results = SparkActions.get().deleteOrphanFiles(table.table())
+        .olderThan(System.currentTimeMillis() + 1000).execute();
+    Assert.assertTrue("trash file should be removed",
+        StreamSupport.stream(results.orphanFileLocations().spliterator(), false)
+            .anyMatch(file -> file.contains("file:" + location + "/data/trashfile")));
+  }
+
+  @Test
+  public void testSparkCatalogNamedHadoopTable() throws Exception {
+
spark.conf().set("spark.sql.catalog.hadoop", "org.apache.iceberg.spark.SparkCatalog"); + spark.conf().set("spark.sql.catalog.hadoop.type", "hadoop"); + spark.conf().set("spark.sql.catalog.hadoop.warehouse", tableLocation); + SparkCatalog cat = (SparkCatalog) spark.sessionState().catalogManager().catalog("hadoop"); + + String[] database = {"default"}; + Identifier id = Identifier.of(database, "table"); + Map options = Maps.newHashMap(); + Transform[] transforms = {}; + cat.createTable(id, SparkSchemaUtil.convert(SCHEMA), transforms, options); + SparkTable table = cat.loadTable(id); + + spark.sql("INSERT INTO hadoop.default.table VALUES (1,1,1)"); + + String location = table.table().location().replaceFirst("file:", ""); + new File(location + "/data/trashfile").createNewFile(); + + DeleteOrphanFiles.Result results = SparkActions.get().deleteOrphanFiles(table.table()) + .olderThan(System.currentTimeMillis() + 1000).execute(); + Assert.assertTrue("trash file should be removed", + StreamSupport.stream(results.orphanFileLocations().spliterator(), false) + .anyMatch(file -> file.contains("file:" + location + "/data/trashfile"))); + } + + @Test + public void testSparkCatalogNamedHiveTable() throws Exception { + spark.conf().set("spark.sql.catalog.hive", "org.apache.iceberg.spark.SparkCatalog"); + spark.conf().set("spark.sql.catalog.hive.type", "hadoop"); + spark.conf().set("spark.sql.catalog.hive.warehouse", tableLocation); + SparkCatalog cat = (SparkCatalog) spark.sessionState().catalogManager().catalog("hive"); + + String[] database = {"default"}; + Identifier id = Identifier.of(database, "table"); + Map options = Maps.newHashMap(); + Transform[] transforms = {}; + cat.createTable(id, SparkSchemaUtil.convert(SCHEMA), transforms, options); + SparkTable table = cat.loadTable(id); + + spark.sql("INSERT INTO hive.default.table VALUES (1,1,1)"); + + String location = table.table().location().replaceFirst("file:", ""); + new File(location + "/data/trashfile").createNewFile(); + + DeleteOrphanFiles.Result results = SparkActions.get().deleteOrphanFiles(table.table()) + .olderThan(System.currentTimeMillis() + 1000).execute(); + Assert.assertTrue("trash file should be removed", + StreamSupport.stream(results.orphanFileLocations().spliterator(), false) + .anyMatch(file -> file.contains("file:" + location + "/data/trashfile"))); + } + + @Test + public void testSparkSessionCatalogHadoopTable() throws Exception { + spark.conf().set("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog"); + spark.conf().set("spark.sql.catalog.spark_catalog.type", "hadoop"); + spark.conf().set("spark.sql.catalog.spark_catalog.warehouse", tableLocation); + SparkSessionCatalog cat = (SparkSessionCatalog) spark.sessionState().catalogManager().v2SessionCatalog(); + + String[] database = {"default"}; + Identifier id = Identifier.of(database, "table"); + Map options = Maps.newHashMap(); + Transform[] transforms = {}; + cat.createTable(id, SparkSchemaUtil.convert(SCHEMA), transforms, options); + SparkTable table = (SparkTable) cat.loadTable(id); + + spark.sql("INSERT INTO default.table VALUES (1,1,1)"); + + String location = table.table().location().replaceFirst("file:", ""); + new File(location + "/data/trashfile").createNewFile(); + + DeleteOrphanFiles.Result results = SparkActions.get().deleteOrphanFiles(table.table()) + .olderThan(System.currentTimeMillis() + 1000).execute(); + Assert.assertTrue("trash file should be removed", + StreamSupport.stream(results.orphanFileLocations().spliterator(), false) + 
.anyMatch(file -> file.contains("file:" + location + "/data/trashfile"))); + } + + @Test + public void testSparkSessionCatalogHiveTable() throws Exception { + spark.conf().set("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog"); + spark.conf().set("spark.sql.catalog.spark_catalog.type", "hive"); + SparkSessionCatalog cat = (SparkSessionCatalog) spark.sessionState().catalogManager().v2SessionCatalog(); + + String[] database = {"default"}; + Identifier id = Identifier.of(database, "sessioncattest"); + Map options = Maps.newHashMap(); + Transform[] transforms = {}; + cat.dropTable(id); + cat.createTable(id, SparkSchemaUtil.convert(SCHEMA), transforms, options); + SparkTable table = (SparkTable) cat.loadTable(id); + + spark.sql("INSERT INTO default.sessioncattest VALUES (1,1,1)"); + + String location = table.table().location().replaceFirst("file:", ""); + new File(location + "/data/trashfile").createNewFile(); + + DeleteOrphanFiles.Result results = SparkActions.get().deleteOrphanFiles(table.table()) + .olderThan(System.currentTimeMillis() + 1000).execute(); + Assert.assertTrue("trash file should be removed", + StreamSupport.stream(results.orphanFileLocations().spliterator(), false) + .anyMatch(file -> file.contains("file:" + location + "/data/trashfile"))); + } + + @After + public void resetSparkSessionCatalog() throws Exception { + spark.conf().unset("spark.sql.catalog.spark_catalog"); + spark.conf().unset("spark.sql.catalog.spark_catalog.type"); + spark.conf().unset("spark.sql.catalog.spark_catalog.warehouse"); + } + +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteDataFilesAction.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteDataFilesAction.java new file mode 100644 index 000000000000..f7394f645016 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteDataFilesAction.java @@ -0,0 +1,1564 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.actions; + +import java.io.File; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.RewriteJobOrder; +import org.apache.iceberg.RowDelta; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.ActionsProvider; +import org.apache.iceberg.actions.BinPackStrategy; +import org.apache.iceberg.actions.RewriteDataFiles; +import org.apache.iceberg.actions.RewriteDataFiles.Result; +import org.apache.iceberg.actions.RewriteDataFilesCommitManager; +import org.apache.iceberg.actions.RewriteFileGroup; +import org.apache.iceberg.actions.SortStrategy; +import org.apache.iceberg.data.GenericAppenderFactory; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.deletes.PositionDeleteWriter; +import org.apache.iceberg.encryption.EncryptedFiles; +import org.apache.iceberg.encryption.EncryptedOutputFile; +import org.apache.iceberg.encryption.EncryptionKeyMetadata; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Streams; +import org.apache.iceberg.spark.FileRewriteCoordinator; +import org.apache.iceberg.spark.FileScanTaskSetManager; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.spark.actions.BaseRewriteDataFilesSparkAction.RewriteExecutionContext; +import org.apache.iceberg.spark.source.ThreeColumnRecord; +import org.apache.iceberg.types.Comparators; +import org.apache.iceberg.types.Conversions; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.util.Pair; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.mockito.ArgumentMatcher; +import org.mockito.Mockito; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import 
static org.apache.iceberg.types.Types.NestedField.required; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doCallRealMethod; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; + +public class TestRewriteDataFilesAction extends SparkTestBase { + + private static final int SCALE = 400000; + + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + private static final Schema SCHEMA = new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get()) + ); + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private final FileRewriteCoordinator coordinator = FileRewriteCoordinator.get(); + private final FileScanTaskSetManager manager = FileScanTaskSetManager.get(); + private String tableLocation = null; + + @Before + public void setupTableLocation() throws Exception { + File tableDir = temp.newFolder(); + this.tableLocation = tableDir.toURI().toString(); + } + + private RewriteDataFiles basicRewrite(Table table) { + // Always compact regardless of input files + table.refresh(); + return actions().rewriteDataFiles(table, tableLocation).option(BinPackStrategy.MIN_INPUT_FILES, "1"); + } + + @Test + public void testEmptyTable() { + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + Assert.assertNull("Table must be empty", table.currentSnapshot()); + + basicRewrite(table).execute(); + + Assert.assertNull("Table must stay empty", table.currentSnapshot()); + } + + @Test + public void testBinPackUnpartitionedTable() { + Table table = createTable(4); + shouldHaveFiles(table, 4); + List expectedRecords = currentData(); + + Result result = basicRewrite(table).execute(); + Assert.assertEquals("Action should rewrite 4 data files", 4, result.rewrittenDataFilesCount()); + Assert.assertEquals("Action should add 1 data file", 1, result.addedDataFilesCount()); + + shouldHaveFiles(table, 1); + List actual = currentData(); + + assertEquals("Rows must match", expectedRecords, actual); + } + + @Test + public void testBinPackPartitionedTable() { + Table table = createTablePartitioned(4, 2); + shouldHaveFiles(table, 8); + List expectedRecords = currentData(); + + Result result = basicRewrite(table).execute(); + Assert.assertEquals("Action should rewrite 8 data files", 8, result.rewrittenDataFilesCount()); + Assert.assertEquals("Action should add 4 data file", 4, result.addedDataFilesCount()); + + shouldHaveFiles(table, 4); + List actualRecords = currentData(); + + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testBinPackWithFilter() { + Table table = createTablePartitioned(4, 2); + shouldHaveFiles(table, 8); + List expectedRecords = currentData(); + + Result result = basicRewrite(table) + .filter(Expressions.equal("c1", 1)) + .filter(Expressions.startsWith("c2", "foo")) + .execute(); + + Assert.assertEquals("Action should rewrite 2 data files", 2, result.rewrittenDataFilesCount()); + Assert.assertEquals("Action should add 1 data file", 1, result.addedDataFilesCount()); 
+ + shouldHaveFiles(table, 7); + + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testBinPackAfterPartitionChange() { + Table table = createTable(); + + writeRecords(20, SCALE, 20); + shouldHaveFiles(table, 20); + table.updateSpec().addField(Expressions.ref("c1")).commit(); + + List originalData = currentData(); + + RewriteDataFiles.Result result = + basicRewrite(table) + .option(SortStrategy.MIN_INPUT_FILES, "1") + .option(SortStrategy.MIN_FILE_SIZE_BYTES, Integer.toString(averageFileSize(table) + 1000)) + .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(averageFileSize(table) + 1001)) + .execute(); + + Assert.assertEquals("Should have 1 fileGroup because all files were not correctly partitioned", + 1, result.rewriteResults().size()); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + shouldHaveFiles(table, 20); + } + + @Test + public void testBinPackWithDeletes() throws Exception { + Table table = createTablePartitioned(4, 2); + table.updateProperties().set(TableProperties.FORMAT_VERSION, "2").commit(); + shouldHaveFiles(table, 8); + table.refresh(); + + CloseableIterable tasks = table.newScan().planFiles(); + List dataFiles = Lists.newArrayList(CloseableIterable.transform(tasks, FileScanTask::file)); + int total = (int) dataFiles.stream().mapToLong(ContentFile::recordCount).sum(); + + RowDelta rowDelta = table.newRowDelta(); + // add 1 delete file for data files 0, 1, 2 + for (int i = 0; i < 3; i++) { + writePosDeletesToFile(table, dataFiles.get(i), 1) + .forEach(rowDelta::addDeletes); + } + + // add 2 delete files for data files 3, 4 + for (int i = 3; i < 5; i++) { + writePosDeletesToFile(table, dataFiles.get(i), 2) + .forEach(rowDelta::addDeletes); + } + + rowDelta.commit(); + table.refresh(); + List expectedRecords = currentData(); + Result result = actions().rewriteDataFiles(table, tableLocation) + // do not include any file based on bin pack file size configs + .option(BinPackStrategy.MIN_FILE_SIZE_BYTES, "0") + .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Long.toString(Long.MAX_VALUE - 1)) + .option(BinPackStrategy.MAX_FILE_SIZE_BYTES, Long.toString(Long.MAX_VALUE)) + .option(BinPackStrategy.DELETE_FILE_THRESHOLD, "2") + .execute(); + Assert.assertEquals("Action should rewrite 2 data files", 2, result.rewrittenDataFilesCount()); + + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + Assert.assertEquals("7 rows are removed", total - 7, actualRecords.size()); + } + + @Test + public void testBinPackWithDeleteAllData() { + Map options = Maps.newHashMap(); + options.put(TableProperties.FORMAT_VERSION, "2"); + Table table = createTablePartitioned(1, 1, 1, options); + shouldHaveFiles(table, 1); + table.refresh(); + + CloseableIterable tasks = table.newScan().planFiles(); + List dataFiles = Lists.newArrayList(CloseableIterable.transform(tasks, FileScanTask::file)); + int total = (int) dataFiles.stream().mapToLong(ContentFile::recordCount).sum(); + + RowDelta rowDelta = table.newRowDelta(); + // remove all data + writePosDeletesToFile(table, dataFiles.get(0), total) + .forEach(rowDelta::addDeletes); + + rowDelta.commit(); + table.refresh(); + List expectedRecords = currentData(); + Result result = actions().rewriteDataFiles(table, tableLocation) + 
.option(BinPackStrategy.DELETE_FILE_THRESHOLD, "1") + .execute(); + Assert.assertEquals("Action should rewrite 1 data files", 1, result.rewrittenDataFilesCount()); + + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + Assert.assertEquals( + "Data manifest should not have existing data file", + 0, + (long) table.currentSnapshot().dataManifests(table.io()).get(0).existingFilesCount()); + Assert.assertEquals("Data manifest should have 1 delete data file", + 1L, + (long) table.currentSnapshot().dataManifests(table.io()).get(0).deletedFilesCount()); + Assert.assertEquals( + "Delete manifest added row count should equal total count", + total, + (long) table.currentSnapshot().deleteManifests(table.io()).get(0).addedRowsCount()); + } + + @Test + public void testBinPackWithStartingSequenceNumber() { + Table table = createTablePartitioned(4, 2); + shouldHaveFiles(table, 8); + List expectedRecords = currentData(); + table.updateProperties().set(TableProperties.FORMAT_VERSION, "2").commit(); + table.refresh(); + long oldSequenceNumber = table.currentSnapshot().sequenceNumber(); + + Result result = basicRewrite(table) + .option(RewriteDataFiles.USE_STARTING_SEQUENCE_NUMBER, "true") + .execute(); + Assert.assertEquals("Action should rewrite 8 data files", 8, result.rewrittenDataFilesCount()); + Assert.assertEquals("Action should add 4 data file", 4, result.addedDataFilesCount()); + + shouldHaveFiles(table, 4); + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + + table.refresh(); + Assert.assertTrue("Table sequence number should be incremented", + oldSequenceNumber < table.currentSnapshot().sequenceNumber()); + + Dataset rows = SparkTableUtil.loadMetadataTable(spark, table, MetadataTableType.ENTRIES); + for (Row row : rows.collectAsList()) { + if (row.getInt(0) == 1) { + Assert.assertEquals("Expect old sequence number for added entries", oldSequenceNumber, row.getLong(2)); + } + } + } + + @Test + public void testBinPackWithStartingSequenceNumberV1Compatibility() { + Table table = createTablePartitioned(4, 2); + shouldHaveFiles(table, 8); + List expectedRecords = currentData(); + table.refresh(); + long oldSequenceNumber = table.currentSnapshot().sequenceNumber(); + Assert.assertEquals("Table sequence number should be 0", 0, oldSequenceNumber); + + Result result = basicRewrite(table) + .option(RewriteDataFiles.USE_STARTING_SEQUENCE_NUMBER, "true") + .execute(); + Assert.assertEquals("Action should rewrite 8 data files", 8, result.rewrittenDataFilesCount()); + Assert.assertEquals("Action should add 4 data file", 4, result.addedDataFilesCount()); + + shouldHaveFiles(table, 4); + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + + table.refresh(); + Assert.assertEquals("Table sequence number should still be 0", + oldSequenceNumber, table.currentSnapshot().sequenceNumber()); + + Dataset rows = SparkTableUtil.loadMetadataTable(spark, table, MetadataTableType.ENTRIES); + for (Row row : rows.collectAsList()) { + Assert.assertEquals("Expect sequence number 0 for all entries", + oldSequenceNumber, row.getLong(2)); + } + } + + @Test + public void testRewriteLargeTableHasResiduals() { + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).build(); + Map options = Maps.newHashMap(); + options.put(TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES, "100"); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + // all records belong to the 
same partition + List records = Lists.newArrayList(); + for (int i = 0; i < 100; i++) { + records.add(new ThreeColumnRecord(i, String.valueOf(i), String.valueOf(i % 4))); + } + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + writeDF(df); + + List expectedRecords = currentData(); + + table.refresh(); + + CloseableIterable tasks = table.newScan() + .ignoreResiduals() + .filter(Expressions.equal("c3", "0")) + .planFiles(); + for (FileScanTask task : tasks) { + Assert.assertEquals("Residuals must be ignored", Expressions.alwaysTrue(), task.residual()); + } + + shouldHaveFiles(table, 2); + + Result result = basicRewrite(table) + .filter(Expressions.equal("c3", "0")) + .execute(); + Assert.assertEquals("Action should rewrite 2 data files", 2, result.rewrittenDataFilesCount()); + Assert.assertEquals("Action should add 1 data file", 1, result.addedDataFilesCount()); + + List actualRecords = currentData(); + + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testBinPackSplitLargeFile() { + Table table = createTable(1); + shouldHaveFiles(table, 1); + + List expectedRecords = currentData(); + long targetSize = testDataSize(table) / 2; + + Result result = basicRewrite(table) + .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Long.toString(targetSize)) + .option(BinPackStrategy.MAX_FILE_SIZE_BYTES, Long.toString(targetSize * 2 - 2000)) + .execute(); + + Assert.assertEquals("Action should delete 1 data files", 1, result.rewrittenDataFilesCount()); + Assert.assertEquals("Action should add 2 data files", 2, result.addedDataFilesCount()); + + shouldHaveFiles(table, 2); + + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testBinPackCombineMixedFiles() { + Table table = createTable(1); // 400000 + shouldHaveFiles(table, 1); + + // Add one more small file, and one large file + writeRecords(1, SCALE); + writeRecords(1, SCALE * 3); + shouldHaveFiles(table, 3); + + List expectedRecords = currentData(); + + int targetSize = averageFileSize(table); + + Result result = basicRewrite(table) + .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(targetSize + 1000)) + .option(BinPackStrategy.MAX_FILE_SIZE_BYTES, Integer.toString(targetSize + 80000)) + .option(BinPackStrategy.MIN_FILE_SIZE_BYTES, Integer.toString(targetSize - 1000)) + .execute(); + + Assert.assertEquals("Action should delete 3 data files", 3, result.rewrittenDataFilesCount()); + // Should Split the big files into 3 pieces, one of which should be combined with the two smaller files + Assert.assertEquals("Action should add 3 data files", 3, result.addedDataFilesCount()); + + shouldHaveFiles(table, 3); + + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testBinPackCombineMediumFiles() { + Table table = createTable(4); + shouldHaveFiles(table, 4); + + List expectedRecords = currentData(); + int targetSize = ((int) testDataSize(table) / 3); + // The test is to see if we can combine parts of files to make files of the correct size + + Result result = basicRewrite(table) + .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(targetSize)) + .option(BinPackStrategy.MAX_FILE_SIZE_BYTES, Integer.toString((int) (targetSize * 1.8))) + .option(BinPackStrategy.MIN_FILE_SIZE_BYTES, Integer.toString(targetSize - 100)) // All files too small + .execute(); + + Assert.assertEquals("Action should delete 4 data files", 
4, result.rewrittenDataFilesCount()); + Assert.assertEquals("Action should add 3 data files", 3, result.addedDataFilesCount()); + + shouldHaveFiles(table, 3); + + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testPartialProgressEnabled() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + + // Perform a rewrite but only allow 2 files to be compacted at a time + RewriteDataFiles.Result result = + basicRewrite(table) + .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true") + .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "10") + .execute(); + + Assert.assertEquals("Should have 10 fileGroups", result.rewriteResults().size(), 10); + + table.refresh(); + + shouldHaveSnapshots(table, 11); + shouldHaveACleanCache(table); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + } + + @Test + public void testMultipleGroups() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + + // Perform a rewrite but only allow 2 files to be compacted at a time + RewriteDataFiles.Result result = + basicRewrite(table) + .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(BinPackStrategy.MIN_INPUT_FILES, "1") + .execute(); + + Assert.assertEquals("Should have 10 fileGroups", result.rewriteResults().size(), 10); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + } + + @Test + public void testPartialProgressMaxCommits() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + + // Perform a rewrite but only allow 2 files to be compacted at a time + RewriteDataFiles.Result result = + basicRewrite(table) + .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true") + .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "3") + .execute(); + + Assert.assertEquals("Should have 10 fileGroups", result.rewriteResults().size(), 10); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 4); + shouldHaveACleanCache(table); + } + + @Test + public void testSingleCommitWithRewriteFailure() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + + BaseRewriteDataFilesSparkAction realRewrite = + (org.apache.iceberg.spark.actions.BaseRewriteDataFilesSparkAction) + basicRewrite(table) + .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)); + + BaseRewriteDataFilesSparkAction spyRewrite = Mockito.spy(realRewrite); + + // Fail groups 1, 3, and 7 during rewrite + GroupInfoMatcher failGroup = new GroupInfoMatcher(1, 3, 7); + doThrow(new RuntimeException("Rewrite Failed")) + .when(spyRewrite) + .rewriteFiles(any(), argThat(failGroup)); + + AssertHelpers.assertThrows("Should fail entire rewrite if part fails", RuntimeException.class, + () -> spyRewrite.execute()); + + 
table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 1); + shouldHaveNoOrphans(table); + shouldHaveACleanCache(table); + } + + @Test + public void testSingleCommitWithCommitFailure() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + + BaseRewriteDataFilesSparkAction realRewrite = + (org.apache.iceberg.spark.actions.BaseRewriteDataFilesSparkAction) + basicRewrite(table) + .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)); + + BaseRewriteDataFilesSparkAction spyRewrite = spy(realRewrite); + RewriteDataFilesCommitManager util = spy(new RewriteDataFilesCommitManager(table)); + + // Fail to commit + doThrow(new RuntimeException("Commit Failure")) + .when(util) + .commitFileGroups(any()); + + doReturn(util) + .when(spyRewrite) + .commitManager(table.currentSnapshot().snapshotId()); + + AssertHelpers.assertThrows("Should fail entire rewrite if commit fails", RuntimeException.class, + () -> spyRewrite.execute()); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 1); + shouldHaveNoOrphans(table); + shouldHaveACleanCache(table); + } + + @Test + public void testParallelSingleCommitWithRewriteFailure() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + + BaseRewriteDataFilesSparkAction realRewrite = + (org.apache.iceberg.spark.actions.BaseRewriteDataFilesSparkAction) + basicRewrite(table) + .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(RewriteDataFiles.MAX_CONCURRENT_FILE_GROUP_REWRITES, "3"); + + BaseRewriteDataFilesSparkAction spyRewrite = Mockito.spy(realRewrite); + + // Fail groups 1, 3, and 7 during rewrite + GroupInfoMatcher failGroup = new GroupInfoMatcher(1, 3, 7); + doThrow(new RuntimeException("Rewrite Failed")) + .when(spyRewrite) + .rewriteFiles(any(), argThat(failGroup)); + + AssertHelpers.assertThrows("Should fail entire rewrite if part fails", RuntimeException.class, + () -> spyRewrite.execute()); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 1); + shouldHaveNoOrphans(table); + shouldHaveACleanCache(table); + } + + @Test + public void testPartialProgressWithRewriteFailure() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + + BaseRewriteDataFilesSparkAction realRewrite = + (org.apache.iceberg.spark.actions.BaseRewriteDataFilesSparkAction) + basicRewrite(table) + .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true") + .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "3"); + + BaseRewriteDataFilesSparkAction spyRewrite = Mockito.spy(realRewrite); + + // Fail groups 1, 3, and 7 during rewrite + GroupInfoMatcher failGroup = new GroupInfoMatcher(1, 3, 7); + doThrow(new RuntimeException("Rewrite Failed")) + .when(spyRewrite) + .rewriteFiles(any(), argThat(failGroup)); + + RewriteDataFiles.Result result = spyRewrite.execute(); + + Assert.assertEquals("Should have 7 fileGroups", result.rewriteResults().size(), 7); + + 
table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + // With 10 original groups and Max Commits of 3, we should have commits with 4, 4, and 2. + // removing 3 groups leaves us with only 2 new commits, 4 and 3 + shouldHaveSnapshots(table, 3); + shouldHaveNoOrphans(table); + shouldHaveACleanCache(table); + } + + @Test + public void testParallelPartialProgressWithRewriteFailure() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + + BaseRewriteDataFilesSparkAction realRewrite = + (org.apache.iceberg.spark.actions.BaseRewriteDataFilesSparkAction) + basicRewrite(table) + .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(RewriteDataFiles.MAX_CONCURRENT_FILE_GROUP_REWRITES, "3") + .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true") + .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "3"); + + BaseRewriteDataFilesSparkAction spyRewrite = Mockito.spy(realRewrite); + + // Fail groups 1, 3, and 7 during rewrite + GroupInfoMatcher failGroup = new GroupInfoMatcher(1, 3, 7); + doThrow(new RuntimeException("Rewrite Failed")) + .when(spyRewrite) + .rewriteFiles(any(), argThat(failGroup)); + + RewriteDataFiles.Result result = spyRewrite.execute(); + + Assert.assertEquals("Should have 7 fileGroups", result.rewriteResults().size(), 7); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + // With 10 original groups and Max Commits of 3, we should have commits with 4, 4, and 2. + // removing 3 groups leaves us with only 2 new commits, 4 and 3 + shouldHaveSnapshots(table, 3); + shouldHaveNoOrphans(table); + shouldHaveACleanCache(table); + } + + @Test + public void testParallelPartialProgressWithCommitFailure() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + + BaseRewriteDataFilesSparkAction realRewrite = + (org.apache.iceberg.spark.actions.BaseRewriteDataFilesSparkAction) + basicRewrite(table) + .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(RewriteDataFiles.MAX_CONCURRENT_FILE_GROUP_REWRITES, "3") + .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true") + .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "3"); + + BaseRewriteDataFilesSparkAction spyRewrite = spy(realRewrite); + RewriteDataFilesCommitManager util = spy(new RewriteDataFilesCommitManager(table)); + + // First and Third commits work, second does not + doCallRealMethod() + .doThrow(new RuntimeException("Commit Failed")) + .doCallRealMethod() + .when(util) + .commitFileGroups(any()); + + doReturn(util) + .when(spyRewrite) + .commitManager(table.currentSnapshot().snapshotId()); + + RewriteDataFiles.Result result = spyRewrite.execute(); + + // Commit 1: 4/4 + Commit 2 failed 0/4 + Commit 3: 2/2 == 6 out of 10 total groups comitted + Assert.assertEquals("Should have 6 fileGroups", 6, result.rewriteResults().size()); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + // Only 2 new commits because we broke one + shouldHaveSnapshots(table, 3); + shouldHaveNoOrphans(table); + shouldHaveACleanCache(table); + } + + @Test + public void testInvalidOptions() { + Table table = createTable(20); + + 
AssertHelpers.assertThrows("No negative values for partial progress max commits", + IllegalArgumentException.class, + () -> basicRewrite(table) + .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true") + .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "-5") + .execute()); + + AssertHelpers.assertThrows("No negative values for max concurrent groups", + IllegalArgumentException.class, + () -> basicRewrite(table) + .option(RewriteDataFiles.MAX_CONCURRENT_FILE_GROUP_REWRITES, "-5") + .execute()); + + AssertHelpers.assertThrows("No unknown options allowed", + IllegalArgumentException.class, + () -> basicRewrite(table) + .option("foobarity", "-5") + .execute()); + + AssertHelpers.assertThrows("Cannot set rewrite-job-order to foo", + IllegalArgumentException.class, + () -> basicRewrite(table) + .option(RewriteDataFiles.REWRITE_JOB_ORDER, "foo") + .execute()); + } + + @Test + public void testSortMultipleGroups() { + Table table = createTable(20); + shouldHaveFiles(table, 20); + table.replaceSortOrder().asc("c2").commit(); + shouldHaveLastCommitUnsorted(table, "c2"); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + + // Perform a rewrite but only allow 2 files to be compacted at a time + RewriteDataFiles.Result result = + basicRewrite(table) + .sort() + .option(SortStrategy.REWRITE_ALL, "true") + .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .execute(); + + Assert.assertEquals("Should have 10 fileGroups", result.rewriteResults().size(), 10); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + } + + @Test + public void testSimpleSort() { + Table table = createTable(20); + shouldHaveFiles(table, 20); + table.replaceSortOrder().asc("c2").commit(); + shouldHaveLastCommitUnsorted(table, "c2"); + + List originalData = currentData(); + + RewriteDataFiles.Result result = + basicRewrite(table) + .sort() + .option(SortStrategy.MIN_INPUT_FILES, "1") + .option(SortStrategy.REWRITE_ALL, "true") + .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(averageFileSize(table))) + .execute(); + + Assert.assertEquals("Should have 1 fileGroups", result.rewriteResults().size(), 1); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + shouldHaveMultipleFiles(table); + shouldHaveLastCommitSorted(table, "c2"); + } + + @Test + public void testSortAfterPartitionChange() { + Table table = createTable(20); + shouldHaveFiles(table, 20); + table.updateSpec().addField(Expressions.bucket("c1", 4)).commit(); + table.replaceSortOrder().asc("c2").commit(); + shouldHaveLastCommitUnsorted(table, "c2"); + + List originalData = currentData(); + + RewriteDataFiles.Result result = + basicRewrite(table) + .sort() + .option(SortStrategy.MIN_INPUT_FILES, "1") + .option(SortStrategy.REWRITE_ALL, "true") + .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(averageFileSize(table))) + .execute(); + + Assert.assertEquals("Should have 1 fileGroup because all files were not correctly partitioned", + result.rewriteResults().size(), 1); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 
2); + shouldHaveACleanCache(table); + shouldHaveMultipleFiles(table); + shouldHaveLastCommitSorted(table, "c2"); + } + + @Test + public void testSortCustomSortOrder() { + Table table = createTable(20); + shouldHaveLastCommitUnsorted(table, "c2"); + shouldHaveFiles(table, 20); + + List originalData = currentData(); + + RewriteDataFiles.Result result = + basicRewrite(table) + .sort(SortOrder.builderFor(table.schema()).asc("c2").build()) + .option(SortStrategy.REWRITE_ALL, "true") + .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(averageFileSize(table))) + .execute(); + + Assert.assertEquals("Should have 1 fileGroups", result.rewriteResults().size(), 1); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + shouldHaveMultipleFiles(table); + shouldHaveLastCommitSorted(table, "c2"); + } + + @Test + public void testSortCustomSortOrderRequiresRepartition() { + int partitions = 4; + Table table = createTable(); + writeRecords(20, SCALE, partitions); + shouldHaveLastCommitUnsorted(table, "c3"); + + // Add a partition column so this requires repartitioning + table.updateSpec().addField("c1").commit(); + // Add a sort order which our repartitioning needs to ignore + table.replaceSortOrder().asc("c2").apply(); + shouldHaveFiles(table, 20); + + List originalData = currentData(); + + RewriteDataFiles.Result result = + basicRewrite(table) + .sort(SortOrder.builderFor(table.schema()).asc("c3").build()) + .option(SortStrategy.REWRITE_ALL, "true") + .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(averageFileSize(table) / partitions)) + .execute(); + + Assert.assertEquals("Should have 1 fileGroups", result.rewriteResults().size(), 1); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + shouldHaveMultipleFiles(table); + shouldHaveLastCommitUnsorted(table, "c2"); + shouldHaveLastCommitSorted(table, "c3"); + } + + @Test + public void testAutoSortShuffleOutput() { + Table table = createTable(20); + shouldHaveLastCommitUnsorted(table, "c2"); + shouldHaveFiles(table, 20); + + List originalData = currentData(); + + RewriteDataFiles.Result result = + basicRewrite(table) + .sort(SortOrder.builderFor(table.schema()).asc("c2").build()) + .option(SortStrategy.MAX_FILE_SIZE_BYTES, Integer.toString((averageFileSize(table) / 2) + 2)) + // Divide files in 2 + .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(averageFileSize(table) / 2)) + .option(SortStrategy.MIN_INPUT_FILES, "1") + .execute(); + + Assert.assertEquals("Should have 1 fileGroups", result.rewriteResults().size(), 1); + Assert.assertTrue("Should have written 40+ files", + Iterables.size(table.currentSnapshot().addedFiles(table.io())) >= 40); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + shouldHaveMultipleFiles(table); + shouldHaveLastCommitSorted(table, "c2"); + } + + @Test + public void testCommitStateUnknownException() { + Table table = createTable(20); + shouldHaveFiles(table, 20); + + List originalData = currentData(); + + BaseRewriteDataFilesSparkAction action = (BaseRewriteDataFilesSparkAction) basicRewrite(table); 
+ BaseRewriteDataFilesSparkAction spyAction = spy(action); + RewriteDataFilesCommitManager util = spy(new RewriteDataFilesCommitManager(table)); + + doAnswer(invocationOnMock -> { + invocationOnMock.callRealMethod(); + throw new CommitStateUnknownException(new RuntimeException("Unknown State")); + }).when(util).commitFileGroups(any()); + + doReturn(util) + .when(spyAction) + .commitManager(table.currentSnapshot().snapshotId()); + + AssertHelpers.assertThrows("Should propagate CommitStateUnknown Exception", + CommitStateUnknownException.class, () -> spyAction.execute()); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); // Commit actually Succeeded + } + + @Test + public void testZOrderSort() { + int originalFiles = 20; + Table table = createTable(originalFiles); + shouldHaveLastCommitUnsorted(table, "c2"); + shouldHaveFiles(table, originalFiles); + + List originalData = currentData(); + double originalFilesC2 = percentFilesRequired(table, "c2", "foo23"); + double originalFilesC3 = percentFilesRequired(table, "c3", "bar21"); + double originalFilesC2C3 = percentFilesRequired(table, new String[]{"c2", "c3"}, new String[]{"foo23", "bar23"}); + + Assert.assertTrue("Should require all files to scan c2", originalFilesC2 > 0.99); + Assert.assertTrue("Should require all files to scan c3", originalFilesC3 > 0.99); + + RewriteDataFiles.Result result = + basicRewrite(table) + .zOrder("c2", "c3") + .option(SortStrategy.MAX_FILE_SIZE_BYTES, Integer.toString((averageFileSize(table) / 2) + 2)) + // Divide files in 2 + .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(averageFileSize(table) / 2)) + .option(SortStrategy.MIN_INPUT_FILES, "1") + .execute(); + + Assert.assertEquals("Should have 1 fileGroups", 1, result.rewriteResults().size()); + int zOrderedFilesTotal = Iterables.size(table.currentSnapshot().addedFiles(table.io())); + Assert.assertTrue("Should have written 40+ files", zOrderedFilesTotal >= 40); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + + double filesScannedC2 = percentFilesRequired(table, "c2", "foo23"); + double filesScannedC3 = percentFilesRequired(table, "c3", "bar21"); + double filesScannedC2C3 = percentFilesRequired(table, new String[]{"c2", "c3"}, new String[]{"foo23", "bar23"}); + + Assert.assertTrue("Should have reduced the number of files required for c2", + filesScannedC2 < originalFilesC2); + Assert.assertTrue("Should have reduced the number of files required for c3", + filesScannedC3 < originalFilesC3); + Assert.assertTrue("Should have reduced the number of files required for a c2,c3 predicate", + filesScannedC2C3 < originalFilesC2C3); + } + + @Test + public void testZOrderAllTypesSort() { + Table table = createTypeTestTable(); + shouldHaveFiles(table, 10); + + List originalRaw = spark.read().format("iceberg").load(tableLocation).sort("longCol").collectAsList(); + List originalData = rowsToJava(originalRaw); + + // TODO add in UUID when it is supported in Spark + RewriteDataFiles.Result result = + basicRewrite(table) + .zOrder("longCol", "intCol", "floatCol", "doubleCol", "dateCol", "timestampCol", "stringCol", "binaryCol", + "booleanCol") + .option(SortStrategy.MIN_INPUT_FILES, "1") + .option(SortStrategy.REWRITE_ALL, "true") + .execute(); + + Assert.assertEquals("Should have 1 
fileGroups", 1, result.rewriteResults().size()); + int zOrderedFilesTotal = Iterables.size(table.currentSnapshot().addedFiles(table.io())); + Assert.assertEquals("Should have written 1 file", 1, zOrderedFilesTotal); + + table.refresh(); + + List postRaw = spark.read().format("iceberg").load(tableLocation).sort("longCol").collectAsList(); + List postRewriteData = rowsToJava(postRaw); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + } + + @Test + public void testInvalidAPIUsage() { + Table table = createTable(1); + + AssertHelpers.assertThrows("Should be unable to set Strategy more than once", IllegalArgumentException.class, + "Cannot set strategy", () -> actions().rewriteDataFiles(table, tableLocation).binPack().sort()); + + AssertHelpers.assertThrows("Should be unable to set Strategy more than once", IllegalArgumentException.class, + "Cannot set strategy", () -> actions().rewriteDataFiles(table, tableLocation).sort().binPack()); + + AssertHelpers.assertThrows( + "Should be unable to set Strategy more than once", + IllegalArgumentException.class, + "Cannot set strategy", + () -> + actions().rewriteDataFiles(table, tableLocation).sort(SortOrder.unsorted()).binPack()); + } + + @Test + public void testRewriteJobOrderBytesAsc() { + Table table = createTablePartitioned(4, 2); + writeRecords(1, SCALE, 1); + writeRecords(2, SCALE, 2); + writeRecords(3, SCALE, 3); + writeRecords(4, SCALE, 4); + table.updateProperties().set(TableProperties.FORMAT_VERSION, "2").commit(); + + BaseRewriteDataFilesSparkAction basicRewrite = + (BaseRewriteDataFilesSparkAction) basicRewrite(table) + .binPack(); + List expected = toGroupStream(table, basicRewrite) + .mapToLong(RewriteFileGroup::sizeInBytes) + .boxed() + .collect(Collectors.toList()); + + BaseRewriteDataFilesSparkAction jobOrderRewrite = + (BaseRewriteDataFilesSparkAction) basicRewrite(table) + .option(RewriteDataFiles.REWRITE_JOB_ORDER, RewriteJobOrder.BYTES_ASC.orderName()) + .binPack(); + List actual = toGroupStream(table, jobOrderRewrite) + .mapToLong(RewriteFileGroup::sizeInBytes) + .boxed() + .collect(Collectors.toList()); + + expected.sort(Comparator.naturalOrder()); + Assert.assertEquals("Size in bytes order should be ascending", actual, expected); + Collections.reverse(expected); + Assert.assertNotEquals("Size in bytes order should not be descending", actual, expected); + } + + @Test + public void testRewriteJobOrderBytesDesc() { + Table table = createTablePartitioned(4, 2); + writeRecords(1, SCALE, 1); + writeRecords(2, SCALE, 2); + writeRecords(3, SCALE, 3); + writeRecords(4, SCALE, 4); + table.updateProperties().set(TableProperties.FORMAT_VERSION, "2").commit(); + + BaseRewriteDataFilesSparkAction basicRewrite = + (BaseRewriteDataFilesSparkAction) basicRewrite(table) + .binPack(); + List expected = toGroupStream(table, basicRewrite) + .mapToLong(RewriteFileGroup::sizeInBytes) + .boxed() + .collect(Collectors.toList()); + + BaseRewriteDataFilesSparkAction jobOrderRewrite = + (BaseRewriteDataFilesSparkAction) basicRewrite(table) + .option(RewriteDataFiles.REWRITE_JOB_ORDER, RewriteJobOrder.BYTES_DESC.orderName()) + .binPack(); + List actual = toGroupStream(table, jobOrderRewrite) + .mapToLong(RewriteFileGroup::sizeInBytes) + .boxed() + .collect(Collectors.toList()); + + expected.sort(Comparator.reverseOrder()); + Assert.assertEquals("Size in bytes order should be descending", actual, expected); + Collections.reverse(expected); + 
Assert.assertNotEquals("Size in bytes order should not be ascending", actual, expected); + } + + @Test + public void testRewriteJobOrderFilesAsc() { + Table table = createTablePartitioned(4, 2); + writeRecords(1, SCALE, 1); + writeRecords(2, SCALE, 2); + writeRecords(3, SCALE, 3); + writeRecords(4, SCALE, 4); + table.updateProperties().set(TableProperties.FORMAT_VERSION, "2").commit(); + + BaseRewriteDataFilesSparkAction basicRewrite = + (BaseRewriteDataFilesSparkAction) basicRewrite(table) + .binPack(); + List expected = toGroupStream(table, basicRewrite) + .mapToLong(RewriteFileGroup::numFiles) + .boxed() + .collect(Collectors.toList()); + + BaseRewriteDataFilesSparkAction jobOrderRewrite = + (BaseRewriteDataFilesSparkAction) basicRewrite(table) + .option(RewriteDataFiles.REWRITE_JOB_ORDER, RewriteJobOrder.FILES_ASC.orderName()) + .binPack(); + List actual = toGroupStream(table, jobOrderRewrite) + .mapToLong(RewriteFileGroup::numFiles) + .boxed() + .collect(Collectors.toList()); + + expected.sort(Comparator.naturalOrder()); + Assert.assertEquals("Number of files order should be ascending", actual, expected); + Collections.reverse(expected); + Assert.assertNotEquals("Number of files order should not be descending", actual, expected); + } + + @Test + public void testRewriteJobOrderFilesDesc() { + Table table = createTablePartitioned(4, 2); + writeRecords(1, SCALE, 1); + writeRecords(2, SCALE, 2); + writeRecords(3, SCALE, 3); + writeRecords(4, SCALE, 4); + table.updateProperties().set(TableProperties.FORMAT_VERSION, "2").commit(); + + BaseRewriteDataFilesSparkAction basicRewrite = + (BaseRewriteDataFilesSparkAction) basicRewrite(table) + .binPack(); + List expected = toGroupStream(table, basicRewrite) + .mapToLong(RewriteFileGroup::numFiles) + .boxed() + .collect(Collectors.toList()); + + BaseRewriteDataFilesSparkAction jobOrderRewrite = + (BaseRewriteDataFilesSparkAction) basicRewrite(table) + .option(RewriteDataFiles.REWRITE_JOB_ORDER, RewriteJobOrder.FILES_DESC.orderName()) + .binPack(); + List actual = toGroupStream(table, jobOrderRewrite) + .mapToLong(RewriteFileGroup::numFiles) + .boxed() + .collect(Collectors.toList()); + + expected.sort(Comparator.reverseOrder()); + Assert.assertEquals("Number of files order should be descending", actual, expected); + Collections.reverse(expected); + Assert.assertNotEquals("Number of files order should not be ascending", actual, expected); + } + + private Stream toGroupStream(Table table, + BaseRewriteDataFilesSparkAction rewrite) { + rewrite.validateAndInitOptions(); + Map>> fileGroupsByPartition = + rewrite.planFileGroups(table.currentSnapshot().snapshotId()); + + return rewrite.toGroupStream( + new RewriteExecutionContext(fileGroupsByPartition), fileGroupsByPartition); + } + + protected List currentData() { + return rowsToJava(spark.read().format("iceberg").load(tableLocation) + .sort("c1", "c2", "c3") + .collectAsList()); + } + + protected long testDataSize(Table table) { + return Streams.stream(table.newScan().planFiles()).mapToLong(FileScanTask::length).sum(); + } + + protected void shouldHaveMultipleFiles(Table table) { + table.refresh(); + int numFiles = Iterables.size(table.newScan().planFiles()); + Assert.assertTrue(String.format("Should have multiple files, had %d", numFiles), numFiles > 1); + } + + protected void shouldHaveFiles(Table table, int numExpected) { + table.refresh(); + int numFiles = Iterables.size(table.newScan().planFiles()); + Assert.assertEquals("Did not have the expected number of files", numExpected, numFiles); + } + + 
protected void shouldHaveSnapshots(Table table, int expectedSnapshots) { + table.refresh(); + int actualSnapshots = Iterables.size(table.snapshots()); + Assert.assertEquals("Table did not have the expected number of snapshots", + expectedSnapshots, actualSnapshots); + } + + protected void shouldHaveNoOrphans(Table table) { + Assert.assertEquals("Should not have found any orphan files", ImmutableList.of(), + actions().deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis()) + .execute() + .orphanFileLocations()); + } + + protected void shouldHaveACleanCache(Table table) { + Assert.assertEquals("Should not have any entries in cache", ImmutableSet.of(), + cacheContents(table)); + } + + protected void shouldHaveLastCommitSorted(Table table, String column) { + List, Pair>> + overlappingFiles = checkForOverlappingFiles(table, column); + + Assert.assertEquals("Found overlapping files", Collections.emptyList(), overlappingFiles); + } + + protected void shouldHaveLastCommitUnsorted(Table table, String column) { + List, Pair>> + overlappingFiles = checkForOverlappingFiles(table, column); + + Assert.assertNotEquals("Found no overlapping files", Collections.emptyList(), overlappingFiles); + } + + private Pair boundsOf(DataFile file, NestedField field, Class javaClass) { + int columnId = field.fieldId(); + return Pair.of(javaClass.cast(Conversions.fromByteBuffer(field.type(), file.lowerBounds().get(columnId))), + javaClass.cast(Conversions.fromByteBuffer(field.type(), file.upperBounds().get(columnId)))); + } + + + private List, Pair>> checkForOverlappingFiles(Table table, String column) { + table.refresh(); + NestedField field = table.schema().caseInsensitiveFindField(column); + Class javaClass = (Class) field.type().typeId().javaClass(); + + Map> filesByPartition = Streams.stream(table.currentSnapshot().addedFiles(table.io())) + .collect(Collectors.groupingBy(DataFile::partition)); + + Stream, Pair>> overlaps = + filesByPartition.entrySet().stream().flatMap(entry -> { + List datafiles = entry.getValue(); + Preconditions.checkArgument(datafiles.size() > 1, + "This test is checking for overlaps in a situation where no overlaps can actually occur because the " + + "partition %s does not contain multiple datafiles", entry.getKey()); + + List, Pair>> boundComparisons = Lists.cartesianProduct(datafiles, datafiles).stream() + .filter(tuple -> tuple.get(0) != tuple.get(1)) + .map(tuple -> Pair.of(boundsOf(tuple.get(0), field, javaClass), boundsOf(tuple.get(1), field, javaClass))) + .collect(Collectors.toList()); + + Comparator comparator = Comparators.forType(field.type().asPrimitiveType()); + + List, Pair>> overlappingFiles = boundComparisons.stream() + .filter(filePair -> { + Pair left = filePair.first(); + T lMin = left.first(); + T lMax = left.second(); + Pair right = filePair.second(); + T rMin = right.first(); + T rMax = right.second(); + boolean boundsDoNotOverlap = + // Min and Max of a range are greater than or equal to the max value of the other range + (comparator.compare(rMax, lMax) >= 0 && comparator.compare(rMin, lMax) >= 0) || + (comparator.compare(lMax, rMax) >= 0 && comparator.compare(lMin, rMax) >= 0); + + return !boundsDoNotOverlap; + }).collect(Collectors.toList()); + return overlappingFiles.stream(); + }); + + return overlaps.collect(Collectors.toList()); + } + + protected Table createTable() { + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + 
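// write with small (20 KB) Parquet row groups so the generated data files stay small for the rewrite tests +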
table.updateProperties().set(TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES, Integer.toString(20 * 1024)).commit(); + Assert.assertNull("Table must be empty", table.currentSnapshot()); + return table; + } + + /** + * Create a table with a certain number of files, returns the size of a file + * @param files number of files to create + * @return the created table + */ + protected Table createTable(int files) { + Table table = createTable(); + writeRecords(files, SCALE); + return table; + } + + protected Table createTablePartitioned(int partitions, int files, + int numRecords, Map options) { + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA) + .identity("c1") + .truncate("c2", 2) + .build(); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + Assert.assertNull("Table must be empty", table.currentSnapshot()); + + writeRecords(files, numRecords, partitions); + return table; + } + + protected Table createTablePartitioned(int partitions, int files) { + return createTablePartitioned(partitions, files, SCALE, Maps.newHashMap()); + } + + private Table createTypeTestTable() { + Schema schema = new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "dateCol", Types.DateType.get()), + optional(6, "timestampCol", Types.TimestampType.withZone()), + optional(7, "stringCol", Types.StringType.get()), + optional(8, "booleanCol", Types.BooleanType.get()), + optional(9, "binaryCol", Types.BinaryType.get())); + + Map options = Maps.newHashMap(); + Table table = TABLES.create(schema, PartitionSpec.unpartitioned(), options, tableLocation); + + spark.range(0, 10, 1, 10) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("dateCol", date_add(current_date(), 1)) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")) + .withColumn("booleanCol", expr("longCol > 5")) + .withColumn("binaryCol", expr("CAST(longCol AS BINARY)")) + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + return table; + } + + protected int averageFileSize(Table table) { + table.refresh(); + return (int) Streams.stream(table.newScan().planFiles()).mapToLong(FileScanTask::length).average().getAsDouble(); + } + + private void writeRecords(int files, int numRecords) { + writeRecords(files, numRecords, 0); + } + + private void writeRecords(int files, int numRecords, int partitions) { + List records = Lists.newArrayList(); + int rowDimension = (int) Math.ceil(Math.sqrt(numRecords)); + List> data = + IntStream.range(0, rowDimension).boxed().flatMap(x -> + IntStream.range(0, rowDimension).boxed().map(y -> Pair.of(x, y))) + .collect(Collectors.toList()); + Collections.shuffle(data, new Random(42)); + if (partitions > 0) { + data.forEach(i -> records.add(new ThreeColumnRecord( + i.first() % partitions, + "foo" + i.first(), + "bar" + i.second()))); + } else { + data.forEach(i -> records.add(new ThreeColumnRecord( + i.first(), + "foo" + i.first(), + "bar" + i.second()))); + } + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).repartition(files); + writeDF(df); + } + + private void writeDF(Dataset df) { + df.select("c1", "c2", "c3") + .sortWithinPartitions("c1", "c2") + .write() + 
.format("iceberg") + .mode("append") + .save(tableLocation); + } + + private List writePosDeletesToFile(Table table, DataFile dataFile, int outputDeleteFiles) { + return writePosDeletes(table, dataFile.partition(), dataFile.path().toString(), outputDeleteFiles); + } + + private List writePosDeletes(Table table, StructLike partition, String path, int outputDeleteFiles) { + List results = Lists.newArrayList(); + int rowPosition = 0; + for (int file = 0; file < outputDeleteFiles; file++) { + OutputFile outputFile = table.io().newOutputFile( + table.locationProvider().newDataLocation(UUID.randomUUID().toString())); + EncryptedOutputFile encryptedOutputFile = EncryptedFiles.encryptedOutput( + outputFile, EncryptionKeyMetadata.EMPTY); + + GenericAppenderFactory appenderFactory = new GenericAppenderFactory( + table.schema(), table.spec(), null, null, null); + PositionDeleteWriter posDeleteWriter = appenderFactory + .set(TableProperties.DEFAULT_WRITE_METRICS_MODE, "full") + .newPosDeleteWriter(encryptedOutputFile, FileFormat.PARQUET, partition); + + posDeleteWriter.delete(path, rowPosition); + try { + posDeleteWriter.close(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + + results.add(posDeleteWriter.toDeleteFile()); + rowPosition++; + } + + return results; + } + + private ActionsProvider actions() { + return SparkActions.get(); + } + + private Set cacheContents(Table table) { + return ImmutableSet.builder() + .addAll(manager.fetchSetIDs(table)) + .addAll(coordinator.fetchSetIDs(table)) + .build(); + } + + private double percentFilesRequired(Table table, String col, String value) { + return percentFilesRequired(table, new String[]{col}, new String[]{value}); + } + + private double percentFilesRequired(Table table, String[] cols, String[] values) { + Preconditions.checkArgument(cols.length == values.length); + Expression restriction = Expressions.alwaysTrue(); + for (int i = 0; i < cols.length; i++) { + restriction = Expressions.and(restriction, Expressions.equal(cols[i], values[i])); + } + int totalFiles = Iterables.size(table.newScan().planFiles()); + int filteredFiles = Iterables.size(table.newScan().filter(restriction).planFiles()); + return (double) filteredFiles / (double) totalFiles; + } + + class GroupInfoMatcher implements ArgumentMatcher { + private final Set groupIDs; + + GroupInfoMatcher(Integer... globalIndex) { + this.groupIDs = ImmutableSet.copyOf(globalIndex); + } + + @Override + public boolean matches(RewriteFileGroup argument) { + return groupIDs.contains(argument.info().globalIndex()); + } + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteManifestsAction.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteManifestsAction.java new file mode 100644 index 000000000000..f30251e74001 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteManifestsAction.java @@ -0,0 +1,493 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.actions; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.RewriteManifests; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.spark.source.ThreeColumnRecord; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +@RunWith(Parameterized.class) +public class TestRewriteManifestsAction extends SparkTestBase { + + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + private static final Schema SCHEMA = new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get()) + ); + + @Parameterized.Parameters(name = "snapshotIdInheritanceEnabled = {0}") + public static Object[] parameters() { + return new Object[] { "true", "false" }; + } + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private final String snapshotIdInheritanceEnabled; + private String tableLocation = null; + + public TestRewriteManifestsAction(String snapshotIdInheritanceEnabled) { + this.snapshotIdInheritanceEnabled = snapshotIdInheritanceEnabled; + } + + @Before + public void setupTableLocation() throws Exception { + File tableDir = temp.newFolder(); + this.tableLocation = tableDir.toURI().toString(); + } + + @Test + public void testRewriteManifestsEmptyTable() throws IOException { + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + Assert.assertNull("Table must be empty", table.currentSnapshot()); + + SparkActions actions = SparkActions.get(); + + actions.rewriteManifests(table) + 
.rewriteIf(manifest -> true) + .stagingLocation(temp.newFolder().toString()) + .execute(); + + Assert.assertNull("Table must stay empty", table.currentSnapshot()); + } + + @Test + public void testRewriteSmallManifestsNonPartitionedTable() { + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + List records1 = Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), + new ThreeColumnRecord(1, "BBBBBBBBBB", "BBBB") + ); + writeRecords(records1); + + List records2 = Lists.newArrayList( + new ThreeColumnRecord(2, "CCCCCCCCCC", "CCCC"), + new ThreeColumnRecord(2, "DDDDDDDDDD", "DDDD") + ); + writeRecords(records2); + + table.refresh(); + + List manifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 2 manifests before rewrite", 2, manifests.size()); + + SparkActions actions = SparkActions.get(); + + RewriteManifests.Result result = actions.rewriteManifests(table) + .rewriteIf(manifest -> true) + .execute(); + + Assert.assertEquals("Action should rewrite 2 manifests", 2, Iterables.size(result.rewrittenManifests())); + Assert.assertEquals("Action should add 1 manifests", 1, Iterables.size(result.addedManifests())); + + table.refresh(); + + List newManifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 1 manifests after rewrite", 1, newManifests.size()); + + Assert.assertEquals(4, (long) newManifests.get(0).existingFilesCount()); + Assert.assertFalse(newManifests.get(0).hasAddedFiles()); + Assert.assertFalse(newManifests.get(0).hasDeletedFiles()); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(records1); + expectedRecords.addAll(records2); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = resultDF.sort("c1", "c2") + .as(Encoders.bean(ThreeColumnRecord.class)) + .collectAsList(); + + Assert.assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testRewriteManifestsWithCommitStateUnknownException() { + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + List records1 = Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), + new ThreeColumnRecord(1, "BBBBBBBBBB", "BBBB") + ); + writeRecords(records1); + + List records2 = Lists.newArrayList( + new ThreeColumnRecord(2, "CCCCCCCCCC", "CCCC"), + new ThreeColumnRecord(2, "DDDDDDDDDD", "DDDD") + ); + writeRecords(records2); + + table.refresh(); + + List manifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 2 manifests before rewrite", 2, manifests.size()); + + SparkActions actions = SparkActions.get(); + + // create a spy which would throw a CommitStateUnknownException after successful commit. 
+ org.apache.iceberg.RewriteManifests newRewriteManifests = table.rewriteManifests(); + org.apache.iceberg.RewriteManifests spyNewRewriteManifests = spy(newRewriteManifests); + doAnswer(invocation -> { + newRewriteManifests.commit(); + throw new CommitStateUnknownException(new RuntimeException("Datacenter on Fire")); + }).when(spyNewRewriteManifests).commit(); + + Table spyTable = spy(table); + when(spyTable.rewriteManifests()).thenReturn(spyNewRewriteManifests); + + AssertHelpers.assertThrowsCause("Should throw a Commit State Unknown Exception", + RuntimeException.class, + "Datacenter on Fire", + () -> actions.rewriteManifests(spyTable).rewriteIf(manifest -> true).execute()); + + table.refresh(); + + // table should reflect the changes, since the commit was successful + List newManifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 1 manifests after rewrite", 1, newManifests.size()); + + Assert.assertEquals(4, (long) newManifests.get(0).existingFilesCount()); + Assert.assertFalse(newManifests.get(0).hasAddedFiles()); + Assert.assertFalse(newManifests.get(0).hasDeletedFiles()); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(records1); + expectedRecords.addAll(records2); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = resultDF.sort("c1", "c2") + .as(Encoders.bean(ThreeColumnRecord.class)) + .collectAsList(); + + Assert.assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testRewriteSmallManifestsPartitionedTable() { + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA) + .identity("c1") + .truncate("c2", 2) + .build(); + Map options = Maps.newHashMap(); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + List records1 = Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), + new ThreeColumnRecord(1, "BBBBBBBBBB", "BBBB") + ); + writeRecords(records1); + + List records2 = Lists.newArrayList( + new ThreeColumnRecord(2, "CCCCCCCCCC", "CCCC"), + new ThreeColumnRecord(2, "DDDDDDDDDD", "DDDD") + ); + writeRecords(records2); + + List records3 = Lists.newArrayList( + new ThreeColumnRecord(3, "EEEEEEEEEE", "EEEE"), + new ThreeColumnRecord(3, "FFFFFFFFFF", "FFFF") + ); + writeRecords(records3); + + List records4 = Lists.newArrayList( + new ThreeColumnRecord(4, "GGGGGGGGGG", "GGGG"), + new ThreeColumnRecord(4, "HHHHHHHHHG", "HHHH") + ); + writeRecords(records4); + + table.refresh(); + + List manifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 4 manifests before rewrite", 4, manifests.size()); + + SparkActions actions = SparkActions.get(); + + // we will expect to have 2 manifests with 4 entries in each after rewrite + long manifestEntrySizeBytes = computeManifestEntrySizeBytes(manifests); + long targetManifestSizeBytes = (long) (1.05 * 4 * manifestEntrySizeBytes); + + table.updateProperties() + .set(TableProperties.MANIFEST_TARGET_SIZE_BYTES, String.valueOf(targetManifestSizeBytes)) + .commit(); + + RewriteManifests.Result result = actions.rewriteManifests(table) + .rewriteIf(manifest -> true) + .execute(); + + Assert.assertEquals("Action should rewrite 4 manifests", 4, Iterables.size(result.rewrittenManifests())); + Assert.assertEquals("Action should add 2 manifests", 2, Iterables.size(result.addedManifests())); + + table.refresh(); + + List newManifests = 
table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 2 manifests after rewrite", 2, newManifests.size()); + + Assert.assertEquals(4, (long) newManifests.get(0).existingFilesCount()); + Assert.assertFalse(newManifests.get(0).hasAddedFiles()); + Assert.assertFalse(newManifests.get(0).hasDeletedFiles()); + + Assert.assertEquals(4, (long) newManifests.get(1).existingFilesCount()); + Assert.assertFalse(newManifests.get(1).hasAddedFiles()); + Assert.assertFalse(newManifests.get(1).hasDeletedFiles()); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(records1); + expectedRecords.addAll(records2); + expectedRecords.addAll(records3); + expectedRecords.addAll(records4); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = resultDF.sort("c1", "c2") + .as(Encoders.bean(ThreeColumnRecord.class)) + .collectAsList(); + + Assert.assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testRewriteImportedManifests() throws IOException { + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA) + .identity("c3") + .build(); + Map options = Maps.newHashMap(); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + List records = Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), + new ThreeColumnRecord(1, "BBBBBBBBBB", "BBBB") + ); + File parquetTableDir = temp.newFolder("parquet_table"); + String parquetTableLocation = parquetTableDir.toURI().toString(); + + try { + Dataset inputDF = spark.createDataFrame(records, ThreeColumnRecord.class); + inputDF.select("c1", "c2", "c3") + .write() + .format("parquet") + .mode("overwrite") + .option("path", parquetTableLocation) + .partitionBy("c3") + .saveAsTable("parquet_table"); + + File stagingDir = temp.newFolder("staging-dir"); + SparkTableUtil.importSparkTable(spark, new TableIdentifier("parquet_table"), table, stagingDir.toString()); + + Snapshot snapshot = table.currentSnapshot(); + + SparkActions actions = SparkActions.get(); + + RewriteManifests.Result result = actions.rewriteManifests(table) + .rewriteIf(manifest -> true) + .stagingLocation(temp.newFolder().toString()) + .execute(); + + Assert.assertEquals("Action should rewrite all manifests", + snapshot.allManifests(table.io()), result.rewrittenManifests()); + Assert.assertEquals("Action should add 1 manifest", 1, Iterables.size(result.addedManifests())); + + } finally { + spark.sql("DROP TABLE parquet_table"); + } + } + + @Test + public void testRewriteLargeManifestsPartitionedTable() throws IOException { + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA) + .identity("c3") + .build(); + Map options = Maps.newHashMap(); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + // all records belong to the same partition + List records = Lists.newArrayList(); + for (int i = 0; i < 50; i++) { + records.add(new ThreeColumnRecord(i, String.valueOf(i), "0")); + } + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + // repartition to create separate files + writeDF(df.repartition(50, df.col("c1"))); + + table.refresh(); + + List manifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 1 manifests before rewrite", 1, manifests.size()); + + // set the target manifest size to a small value to 
force splitting records into multiple files + table.updateProperties() + .set(TableProperties.MANIFEST_TARGET_SIZE_BYTES, String.valueOf(manifests.get(0).length() / 2)) + .commit(); + + SparkActions actions = SparkActions.get(); + + RewriteManifests.Result result = actions.rewriteManifests(table) + .rewriteIf(manifest -> true) + .stagingLocation(temp.newFolder().toString()) + .execute(); + + Assert.assertEquals("Action should rewrite 1 manifest", 1, Iterables.size(result.rewrittenManifests())); + Assert.assertEquals("Action should add 2 manifests", 2, Iterables.size(result.addedManifests())); + + table.refresh(); + + List newManifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 2 manifests after rewrite", 2, newManifests.size()); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = resultDF.sort("c1", "c2") + .as(Encoders.bean(ThreeColumnRecord.class)) + .collectAsList(); + + Assert.assertEquals("Rows must match", records, actualRecords); + } + + @Test + public void testRewriteManifestsWithPredicate() throws IOException { + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA) + .identity("c1") + .truncate("c2", 2) + .build(); + Map options = Maps.newHashMap(); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + List records1 = Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), + new ThreeColumnRecord(1, "BBBBBBBBBB", "BBBB") + ); + writeRecords(records1); + + List records2 = Lists.newArrayList( + new ThreeColumnRecord(2, "CCCCCCCCCC", "CCCC"), + new ThreeColumnRecord(2, "DDDDDDDDDD", "DDDD") + ); + writeRecords(records2); + + table.refresh(); + + List manifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 2 manifests before rewrite", 2, manifests.size()); + + SparkActions actions = SparkActions.get(); + + // rewrite only the first manifest without caching + RewriteManifests.Result result = actions.rewriteManifests(table) + .rewriteIf(manifest -> manifest.path().equals(manifests.get(0).path())) + .stagingLocation(temp.newFolder().toString()) + .option("use-caching", "false") + .execute(); + + Assert.assertEquals("Action should rewrite 1 manifest", 1, Iterables.size(result.rewrittenManifests())); + Assert.assertEquals("Action should add 1 manifests", 1, Iterables.size(result.addedManifests())); + + table.refresh(); + + List newManifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 2 manifests after rewrite", 2, newManifests.size()); + + Assert.assertFalse("First manifest must be rewritten", newManifests.contains(manifests.get(0))); + Assert.assertTrue("Second manifest must not be rewritten", newManifests.contains(manifests.get(1))); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(records1); + expectedRecords.addAll(records2); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = resultDF.sort("c1", "c2") + .as(Encoders.bean(ThreeColumnRecord.class)) + .collectAsList(); + + Assert.assertEquals("Rows must match", expectedRecords, actualRecords); + } + + private void writeRecords(List records) { + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + writeDF(df); + } + + private void writeDF(Dataset df) { + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + } + + 
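// estimate the average serialized size of one manifest entry, used above to derive a target manifest size for splitting +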
private long computeManifestEntrySizeBytes(List manifests) { + long totalSize = 0L; + int numEntries = 0; + + for (ManifestFile manifest : manifests) { + totalSize += manifest.length(); + numEntries += manifest.addedFilesCount() + manifest.existingFilesCount() + manifest.deletedFilesCount(); + } + + return totalSize / numEntries; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java new file mode 100644 index 000000000000..5f902f6e6828 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.types.Types.ListType; +import org.apache.iceberg.types.Types.LongType; +import org.apache.iceberg.types.Types.MapType; +import org.apache.iceberg.types.Types.StructType; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +public abstract class AvroDataTest { + + protected abstract void writeAndValidate(Schema schema) throws IOException; + + protected static final StructType SUPPORTED_PRIMITIVES = StructType.of( + required(100, "id", LongType.get()), + optional(101, "data", Types.StringType.get()), + required(102, "b", Types.BooleanType.get()), + optional(103, "i", Types.IntegerType.get()), + required(104, "l", LongType.get()), + optional(105, "f", Types.FloatType.get()), + required(106, "d", Types.DoubleType.get()), + optional(107, "date", Types.DateType.get()), + required(108, "ts", Types.TimestampType.withZone()), + required(110, "s", Types.StringType.get()), + // required(111, "uuid", Types.UUIDType.get()), + required(112, "fixed", Types.FixedType.ofLength(7)), + optional(113, "bytes", Types.BinaryType.get()), + required(114, "dec_9_0", Types.DecimalType.of(9, 0)), + required(115, "dec_11_2", Types.DecimalType.of(11, 2)), + required(116, "dec_38_10", Types.DecimalType.of(38, 10)) // spark's maximum precision + ); + + @Rule + public 
TemporaryFolder temp = new TemporaryFolder(); + + @Test + public void testSimpleStruct() throws IOException { + writeAndValidate(TypeUtil.assignIncreasingFreshIds(new Schema(SUPPORTED_PRIMITIVES.fields()))); + } + + @Test + public void testStructWithRequiredFields() throws IOException { + writeAndValidate(TypeUtil.assignIncreasingFreshIds(new Schema( + Lists.transform(SUPPORTED_PRIMITIVES.fields(), Types.NestedField::asRequired)))); + } + + @Test + public void testStructWithOptionalFields() throws IOException { + writeAndValidate(TypeUtil.assignIncreasingFreshIds(new Schema( + Lists.transform(SUPPORTED_PRIMITIVES.fields(), Types.NestedField::asOptional)))); + } + + @Test + public void testNestedStruct() throws IOException { + writeAndValidate(TypeUtil.assignIncreasingFreshIds(new Schema(required(1, "struct", SUPPORTED_PRIMITIVES)))); + } + + @Test + public void testArray() throws IOException { + Schema schema = new Schema( + required(0, "id", LongType.get()), + optional(1, "data", ListType.ofOptional(2, Types.StringType.get()))); + + writeAndValidate(schema); + } + + @Test + public void testArrayOfStructs() throws IOException { + Schema schema = TypeUtil.assignIncreasingFreshIds(new Schema( + required(0, "id", LongType.get()), + optional(1, "data", ListType.ofOptional(2, SUPPORTED_PRIMITIVES)))); + + writeAndValidate(schema); + } + + @Test + public void testMap() throws IOException { + Schema schema = new Schema( + required(0, "id", LongType.get()), + optional(1, "data", MapType.ofOptional(2, 3, + Types.StringType.get(), + Types.StringType.get()))); + + writeAndValidate(schema); + } + + @Test + public void testNumericMapKey() throws IOException { + Schema schema = new Schema( + required(0, "id", LongType.get()), + optional(1, "data", MapType.ofOptional(2, 3, + Types.LongType.get(), + Types.StringType.get()))); + + writeAndValidate(schema); + } + + @Test + public void testComplexMapKey() throws IOException { + Schema schema = new Schema( + required(0, "id", LongType.get()), + optional(1, "data", MapType.ofOptional(2, 3, + Types.StructType.of( + required(4, "i", Types.IntegerType.get()), + optional(5, "s", Types.StringType.get())), + Types.StringType.get()))); + + writeAndValidate(schema); + } + + @Test + public void testMapOfStructs() throws IOException { + Schema schema = TypeUtil.assignIncreasingFreshIds(new Schema( + required(0, "id", LongType.get()), + optional(1, "data", MapType.ofOptional(2, 3, + Types.StringType.get(), + SUPPORTED_PRIMITIVES)))); + + writeAndValidate(schema); + } + + @Test + public void testMixedTypes() throws IOException { + StructType structType = StructType.of( + required(0, "id", LongType.get()), + optional(1, "list_of_maps", + ListType.ofOptional(2, MapType.ofOptional(3, 4, + Types.StringType.get(), + SUPPORTED_PRIMITIVES))), + optional(5, "map_of_lists", + MapType.ofOptional(6, 7, + Types.StringType.get(), + ListType.ofOptional(8, SUPPORTED_PRIMITIVES))), + required(9, "list_of_lists", + ListType.ofOptional(10, ListType.ofOptional(11, SUPPORTED_PRIMITIVES))), + required(12, "map_of_maps", + MapType.ofOptional(13, 14, + Types.StringType.get(), + MapType.ofOptional(15, 16, + Types.StringType.get(), + SUPPORTED_PRIMITIVES))), + required(17, "list_of_struct_of_nested_types", ListType.ofOptional(19, StructType.of( + Types.NestedField.required(20, "m1", MapType.ofOptional(21, 22, + Types.StringType.get(), + SUPPORTED_PRIMITIVES)), + Types.NestedField.optional(23, "l1", ListType.ofRequired(24, SUPPORTED_PRIMITIVES)), + Types.NestedField.required(25, "l2", 
ListType.ofRequired(26, SUPPORTED_PRIMITIVES)), + Types.NestedField.optional(27, "m2", MapType.ofOptional(28, 29, + Types.StringType.get(), + SUPPORTED_PRIMITIVES)) + ))) + ); + + Schema schema = new Schema(TypeUtil.assignFreshIds(structType, new AtomicInteger(0)::incrementAndGet) + .asStructType().fields()); + + writeAndValidate(schema); + } + + @Test + public void testTimestampWithoutZone() throws IOException { + withSQLConf(ImmutableMap.of(SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "true"), () -> { + Schema schema = TypeUtil.assignIncreasingFreshIds(new Schema( + required(0, "id", LongType.get()), + optional(1, "ts_without_zone", Types.TimestampType.withoutZone()))); + + writeAndValidate(schema); + }); + } + + protected void withSQLConf(Map conf, Action action) throws IOException { + SQLConf sqlConf = SQLConf.get(); + + Map currentConfValues = Maps.newHashMap(); + conf.keySet().forEach(confKey -> { + if (sqlConf.contains(confKey)) { + String currentConfValue = sqlConf.getConfString(confKey); + currentConfValues.put(confKey, currentConfValue); + } + }); + + conf.forEach((confKey, confValue) -> { + if (SQLConf.isStaticConfigKey(confKey)) { + throw new RuntimeException("Cannot modify the value of a static config: " + confKey); + } + sqlConf.setConfString(confKey, confValue); + }); + + try { + action.invoke(); + } finally { + conf.forEach((confKey, confValue) -> { + if (currentConfValues.containsKey(confKey)) { + sqlConf.setConfString(confKey, currentConfValues.get(confKey)); + } else { + sqlConf.unsetConf(confKey); + } + }); + } + } + + @FunctionalInterface + protected interface Action { + void invoke() throws IOException; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/GenericsHelpers.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/GenericsHelpers.java new file mode 100644 index 000000000000..46c95cef112d --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/GenericsHelpers.java @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.data; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.ZoneId; +import java.time.ZoneOffset; +import java.time.temporal.ChronoUnit; +import java.util.Collection; +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import scala.collection.Seq; + +import static org.apache.iceberg.spark.SparkSchemaUtil.convert; +import static scala.collection.JavaConverters.mapAsJavaMapConverter; +import static scala.collection.JavaConverters.seqAsJavaListConverter; + +public class GenericsHelpers { + private GenericsHelpers() { + } + + private static final OffsetDateTime EPOCH = Instant.ofEpochMilli(0L).atOffset(ZoneOffset.UTC); + private static final LocalDate EPOCH_DAY = EPOCH.toLocalDate(); + + public static void assertEqualsSafe(Types.StructType struct, Record expected, Row actual) { + List fields = struct.fields(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i).type(); + + Object expectedValue = expected.get(i); + Object actualValue = actual.get(i); + + assertEqualsSafe(fieldType, expectedValue, actualValue); + } + } + + private static void assertEqualsSafe(Types.ListType list, Collection expected, List actual) { + Type elementType = list.elementType(); + List expectedElements = Lists.newArrayList(expected); + for (int i = 0; i < expectedElements.size(); i += 1) { + Object expectedValue = expectedElements.get(i); + Object actualValue = actual.get(i); + + assertEqualsSafe(elementType, expectedValue, actualValue); + } + } + + private static void assertEqualsSafe(Types.MapType map, + Map expected, Map actual) { + Type keyType = map.keyType(); + Type valueType = map.valueType(); + Assert.assertEquals("Should have the same number of keys", expected.keySet().size(), actual.keySet().size()); + + for (Object expectedKey : expected.keySet()) { + Object matchingKey = null; + for (Object actualKey : actual.keySet()) { + try { + assertEqualsSafe(keyType, expectedKey, actualKey); + matchingKey = actualKey; + break; + } catch (AssertionError e) { + // failed + } + } + + Assert.assertNotNull("Should have a matching key", matchingKey); + assertEqualsSafe(valueType, expected.get(expectedKey), actual.get(matchingKey)); + } + } + + @SuppressWarnings("unchecked") + private static void assertEqualsSafe(Type type, Object expected, Object actual) { + if (expected == null && actual == null) { + return; + } + + switch (type.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case FLOAT: + case DOUBLE: + Assert.assertEquals("Primitive value should be equal to expected", expected, actual); + break; + case DATE: + Assertions.assertThat(expected).as("Should expect a LocalDate").isInstanceOf(LocalDate.class); + Assertions.assertThat(actual).as("Should be a Date").isInstanceOf(Date.class); + 
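// the expected LocalDate and the returned Date are different types, so compare their ISO-8601 string forms +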
Assert.assertEquals("ISO-8601 date should be equal", expected.toString(), actual.toString()); + break; + case TIMESTAMP: + Assertions.assertThat(actual).as("Should be a Timestamp").isInstanceOf(Timestamp.class); + Timestamp ts = (Timestamp) actual; + // milliseconds from nanos has already been added by getTime + OffsetDateTime actualTs = EPOCH.plusNanos( + (ts.getTime() * 1_000_000) + (ts.getNanos() % 1_000_000)); + Types.TimestampType timestampType = (Types.TimestampType) type; + if (timestampType.shouldAdjustToUTC()) { + Assertions.assertThat(expected).as("Should expect an OffsetDateTime").isInstanceOf(OffsetDateTime.class); + Assert.assertEquals("Timestamp should be equal", expected, actualTs); + } else { + Assertions.assertThat(expected).as("Should expect an LocalDateTime").isInstanceOf(LocalDateTime.class); + Assert.assertEquals("Timestamp should be equal", expected, actualTs.toLocalDateTime()); + } + break; + case STRING: + Assertions.assertThat(actual).as("Should be a String").isInstanceOf(String.class); + Assert.assertEquals("Strings should be equal", String.valueOf(expected), actual); + break; + case UUID: + Assertions.assertThat(expected).as("Should expect a UUID").isInstanceOf(UUID.class); + Assertions.assertThat(actual).as("Should be a String").isInstanceOf(String.class); + Assert.assertEquals("UUID string representation should match", + expected.toString(), actual); + break; + case FIXED: + Assertions.assertThat(expected).as("Should expect a byte[]").isInstanceOf(byte[].class); + Assertions.assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + Assert.assertArrayEquals("Bytes should match", + (byte[]) expected, (byte[]) actual); + break; + case BINARY: + Assertions.assertThat(expected).as("Should expect a ByteBuffer").isInstanceOf(ByteBuffer.class); + Assertions.assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + Assert.assertArrayEquals("Bytes should match", + ((ByteBuffer) expected).array(), (byte[]) actual); + break; + case DECIMAL: + Assertions.assertThat(expected).as("Should expect a BigDecimal").isInstanceOf(BigDecimal.class); + Assertions.assertThat(actual).as("Should be a BigDecimal").isInstanceOf(BigDecimal.class); + Assert.assertEquals("BigDecimals should be equal", expected, actual); + break; + case STRUCT: + Assertions.assertThat(expected).as("Should expect a Record").isInstanceOf(Record.class); + Assertions.assertThat(actual).as("Should be a Row").isInstanceOf(Row.class); + assertEqualsSafe(type.asNestedType().asStructType(), (Record) expected, (Row) actual); + break; + case LIST: + Assertions.assertThat(expected).as("Should expect a Collection").isInstanceOf(Collection.class); + Assertions.assertThat(actual).as("Should be a Seq").isInstanceOf(Seq.class); + List asList = seqAsJavaListConverter((Seq) actual).asJava(); + assertEqualsSafe(type.asNestedType().asListType(), (Collection) expected, asList); + break; + case MAP: + Assertions.assertThat(expected).as("Should expect a Collection").isInstanceOf(Map.class); + Assertions.assertThat(actual).as("Should be a Map").isInstanceOf(scala.collection.Map.class); + Map asMap = mapAsJavaMapConverter( + (scala.collection.Map) actual).asJava(); + assertEqualsSafe(type.asNestedType().asMapType(), (Map) expected, asMap); + break; + case TIME: + default: + throw new IllegalArgumentException("Not a supported type: " + type); + } + } + + public static void assertEqualsUnsafe(Types.StructType struct, Record expected, InternalRow actual) { + List fields = struct.fields(); + for (int i = 0; i 
< fields.size(); i += 1) { + Type fieldType = fields.get(i).type(); + + Object expectedValue = expected.get(i); + Object actualValue = actual.get(i, convert(fieldType)); + + assertEqualsUnsafe(fieldType, expectedValue, actualValue); + } + } + + private static void assertEqualsUnsafe(Types.ListType list, Collection expected, ArrayData actual) { + Type elementType = list.elementType(); + List expectedElements = Lists.newArrayList(expected); + for (int i = 0; i < expectedElements.size(); i += 1) { + Object expectedValue = expectedElements.get(i); + Object actualValue = actual.get(i, convert(elementType)); + + assertEqualsUnsafe(elementType, expectedValue, actualValue); + } + } + + private static void assertEqualsUnsafe(Types.MapType map, Map expected, MapData actual) { + Type keyType = map.keyType(); + Type valueType = map.valueType(); + + List> expectedElements = Lists.newArrayList(expected.entrySet()); + ArrayData actualKeys = actual.keyArray(); + ArrayData actualValues = actual.valueArray(); + + for (int i = 0; i < expectedElements.size(); i += 1) { + Map.Entry expectedPair = expectedElements.get(i); + Object actualKey = actualKeys.get(i, convert(keyType)); + Object actualValue = actualValues.get(i, convert(keyType)); + + assertEqualsUnsafe(keyType, expectedPair.getKey(), actualKey); + assertEqualsUnsafe(valueType, expectedPair.getValue(), actualValue); + } + } + + private static void assertEqualsUnsafe(Type type, Object expected, Object actual) { + if (expected == null && actual == null) { + return; + } + + switch (type.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case FLOAT: + case DOUBLE: + Assert.assertEquals("Primitive value should be equal to expected", expected, actual); + break; + case DATE: + Assertions.assertThat(expected).as("Should expect a LocalDate").isInstanceOf(LocalDate.class); + int expectedDays = (int) ChronoUnit.DAYS.between(EPOCH_DAY, (LocalDate) expected); + Assert.assertEquals("Primitive value should be equal to expected", expectedDays, actual); + break; + case TIMESTAMP: + Types.TimestampType timestampType = (Types.TimestampType) type; + if (timestampType.shouldAdjustToUTC()) { + Assertions.assertThat(expected).as("Should expect an OffsetDateTime").isInstanceOf(OffsetDateTime.class); + long expectedMicros = ChronoUnit.MICROS.between(EPOCH, (OffsetDateTime) expected); + Assert.assertEquals("Primitive value should be equal to expected", expectedMicros, actual); + } else { + Assertions.assertThat(expected).as("Should expect an LocalDateTime").isInstanceOf(LocalDateTime.class); + long expectedMicros = ChronoUnit.MICROS.between(EPOCH, ((LocalDateTime) expected).atZone(ZoneId.of("UTC"))); + Assert.assertEquals("Primitive value should be equal to expected", expectedMicros, actual); + } + break; + case STRING: + Assertions.assertThat(actual).as("Should be a UTF8String").isInstanceOf(UTF8String.class); + Assert.assertEquals("Strings should be equal", expected, actual.toString()); + break; + case UUID: + Assertions.assertThat(expected).as("Should expect a UUID").isInstanceOf(UUID.class); + Assertions.assertThat(actual).as("Should be a UTF8String").isInstanceOf(UTF8String.class); + Assert.assertEquals("UUID string representation should match", + expected.toString(), actual.toString()); + break; + case FIXED: + Assertions.assertThat(expected).as("Should expect a byte[]").isInstanceOf(byte[].class); + Assertions.assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + Assert.assertArrayEquals("Bytes should match", (byte[]) expected, (byte[]) 
actual); + break; + case BINARY: + Assertions.assertThat(expected).as("Should expect a ByteBuffer").isInstanceOf(ByteBuffer.class); + Assertions.assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + Assert.assertArrayEquals("Bytes should match", + ((ByteBuffer) expected).array(), (byte[]) actual); + break; + case DECIMAL: + Assertions.assertThat(expected).as("Should expect a BigDecimal").isInstanceOf(BigDecimal.class); + Assertions.assertThat(actual).as("Should be a Decimal").isInstanceOf(Decimal.class); + Assert.assertEquals("BigDecimals should be equal", + expected, ((Decimal) actual).toJavaBigDecimal()); + break; + case STRUCT: + Assertions.assertThat(expected).as("Should expect a Record").isInstanceOf(Record.class); + Assertions.assertThat(actual).as("Should be an InternalRow").isInstanceOf(InternalRow.class); + assertEqualsUnsafe(type.asNestedType().asStructType(), (Record) expected, (InternalRow) actual); + break; + case LIST: + Assertions.assertThat(expected).as("Should expect a Collection").isInstanceOf(Collection.class); + Assertions.assertThat(actual).as("Should be an ArrayData").isInstanceOf(ArrayData.class); + assertEqualsUnsafe(type.asNestedType().asListType(), (Collection) expected, (ArrayData) actual); + break; + case MAP: + Assertions.assertThat(expected).as("Should expect a Map").isInstanceOf(Map.class); + Assertions.assertThat(actual).as("Should be an ArrayBasedMapData").isInstanceOf(MapData.class); + assertEqualsUnsafe(type.asNestedType().asMapType(), (Map) expected, (MapData) actual); + break; + case TIME: + default: + throw new IllegalArgumentException("Not a supported type: " + type); + } + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java new file mode 100644 index 000000000000..d3bffb75eb5c --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.data; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.function.Supplier; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericData.Record; +import org.apache.iceberg.Schema; +import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.RandomUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; + +public class RandomData { + + // Default percentage of number of values that are null for optional fields + public static final float DEFAULT_NULL_PERCENTAGE = 0.05f; + + private RandomData() { + } + + public static List generateList(Schema schema, int numRecords, long seed) { + RandomDataGenerator generator = new RandomDataGenerator(schema, seed, DEFAULT_NULL_PERCENTAGE); + List records = Lists.newArrayListWithExpectedSize(numRecords); + for (int i = 0; i < numRecords; i += 1) { + records.add((Record) TypeUtil.visit(schema, generator)); + } + + return records; + } + + public static Iterable generateSpark(Schema schema, int numRecords, long seed) { + return () -> new Iterator() { + private SparkRandomDataGenerator generator = new SparkRandomDataGenerator(seed); + private int count = 0; + + @Override + public boolean hasNext() { + return count < numRecords; + } + + @Override + public InternalRow next() { + if (count >= numRecords) { + throw new NoSuchElementException(); + } + count += 1; + return (InternalRow) TypeUtil.visit(schema, generator); + } + }; + } + + public static Iterable generate(Schema schema, int numRecords, long seed) { + return newIterable(() -> new RandomDataGenerator(schema, seed, DEFAULT_NULL_PERCENTAGE), schema, numRecords); + } + + public static Iterable generate(Schema schema, int numRecords, long seed, float nullPercentage) { + return newIterable(() -> new RandomDataGenerator(schema, seed, nullPercentage), schema, numRecords); + } + + public static Iterable generateFallbackData(Schema schema, int numRecords, long seed, long numDictRecords) { + return newIterable(() -> new FallbackDataGenerator(schema, seed, numDictRecords), schema, numRecords); + } + + public static Iterable generateDictionaryEncodableData( + Schema schema, int numRecords, long seed, float nullPercentage) { + return newIterable(() -> new DictionaryEncodedDataGenerator(schema, seed, nullPercentage), schema, numRecords); + } + + private static Iterable newIterable(Supplier newGenerator, + Schema schema, int numRecords) { + return () -> new Iterator() { + private int count = 0; + private RandomDataGenerator generator = newGenerator.get(); + + @Override + public boolean hasNext() { + return count < numRecords; + } + + 
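// Note: newGenerator is a Supplier, so each iterator() call below builds a freshly seeded generator and iteration over the returned Iterable is repeatable for a given seed. +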
@Override + public Record next() { + if (count >= numRecords) { + throw new NoSuchElementException(); + } + count += 1; + return (Record) TypeUtil.visit(schema, generator); + } + }; + } + + private static class RandomDataGenerator extends TypeUtil.CustomOrderSchemaVisitor { + private final Map typeToSchema; + private final Random random; + // Percentage of number of values that are null for optional fields + private final float nullPercentage; + + private RandomDataGenerator(Schema schema, long seed, float nullPercentage) { + Preconditions.checkArgument( + 0.0f <= nullPercentage && nullPercentage <= 1.0f, + "Percentage needs to be in the range (0.0, 1.0)"); + this.nullPercentage = nullPercentage; + this.typeToSchema = AvroSchemaUtil.convertTypes(schema.asStruct(), "test"); + this.random = new Random(seed); + } + + @Override + public Record schema(Schema schema, Supplier structResult) { + return (Record) structResult.get(); + } + + @Override + public Record struct(Types.StructType struct, Iterable fieldResults) { + Record rec = new Record(typeToSchema.get(struct)); + + List values = Lists.newArrayList(fieldResults); + for (int i = 0; i < values.size(); i += 1) { + rec.put(i, values.get(i)); + } + + return rec; + } + + @Override + public Object field(Types.NestedField field, Supplier fieldResult) { + if (field.isOptional() && isNull()) { + return null; + } + return fieldResult.get(); + } + + private boolean isNull() { + return random.nextFloat() < nullPercentage; + } + + @Override + public Object list(Types.ListType list, Supplier elementResult) { + int numElements = random.nextInt(20); + + List result = Lists.newArrayListWithExpectedSize(numElements); + for (int i = 0; i < numElements; i += 1) { + if (list.isElementOptional() && isNull()) { + result.add(null); + } else { + result.add(elementResult.get()); + } + } + + return result; + } + + @Override + public Object map(Types.MapType map, Supplier keyResult, Supplier valueResult) { + int numEntries = random.nextInt(20); + + Map result = Maps.newLinkedHashMap(); + Set keySet = Sets.newHashSet(); + for (int i = 0; i < numEntries; i += 1) { + Object key = keyResult.get(); + // ensure no collisions + while (keySet.contains(key)) { + key = keyResult.get(); + } + + keySet.add(key); + + if (map.isValueOptional() && isNull()) { + result.put(key, null); + } else { + result.put(key, valueResult.get()); + } + } + + return result; + } + + @Override + public Object primitive(Type.PrimitiveType primitive) { + Object result = randomValue(primitive, random); + // For the primitives that Avro needs a different type than Spark, fix + // them here. 
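+ // RandomUtil produces raw byte[] values for FIXED, BINARY and UUID; wrap them in the representations the Avro record expects.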
+ switch (primitive.typeId()) { + case FIXED: + return new GenericData.Fixed(typeToSchema.get(primitive), + (byte[]) result); + case BINARY: + return ByteBuffer.wrap((byte[]) result); + case UUID: + return UUID.nameUUIDFromBytes((byte[]) result); + default: + return result; + } + } + + protected Object randomValue(Type.PrimitiveType primitive, Random rand) { + return RandomUtil.generatePrimitive(primitive, random); + } + } + + private static class SparkRandomDataGenerator extends TypeUtil.CustomOrderSchemaVisitor { + private final Random random; + + private SparkRandomDataGenerator(long seed) { + this.random = new Random(seed); + } + + @Override + public InternalRow schema(Schema schema, Supplier structResult) { + return (InternalRow) structResult.get(); + } + + @Override + public InternalRow struct(Types.StructType struct, Iterable fieldResults) { + List values = Lists.newArrayList(fieldResults); + GenericInternalRow row = new GenericInternalRow(values.size()); + for (int i = 0; i < values.size(); i += 1) { + row.update(i, values.get(i)); + } + + return row; + } + + @Override + public Object field(Types.NestedField field, Supplier fieldResult) { + // return null 5% of the time when the value is optional + if (field.isOptional() && random.nextInt(20) == 1) { + return null; + } + return fieldResult.get(); + } + + @Override + public GenericArrayData list(Types.ListType list, Supplier elementResult) { + int numElements = random.nextInt(20); + Object[] arr = new Object[numElements]; + GenericArrayData result = new GenericArrayData(arr); + + for (int i = 0; i < numElements; i += 1) { + // return null 5% of the time when the value is optional + if (list.isElementOptional() && random.nextInt(20) == 1) { + arr[i] = null; + } else { + arr[i] = elementResult.get(); + } + } + + return result; + } + + @Override + public Object map(Types.MapType map, Supplier keyResult, Supplier valueResult) { + int numEntries = random.nextInt(20); + + Object[] keysArr = new Object[numEntries]; + Object[] valuesArr = new Object[numEntries]; + GenericArrayData keys = new GenericArrayData(keysArr); + GenericArrayData values = new GenericArrayData(valuesArr); + ArrayBasedMapData result = new ArrayBasedMapData(keys, values); + + Set keySet = Sets.newHashSet(); + for (int i = 0; i < numEntries; i += 1) { + Object key = keyResult.get(); + // ensure no collisions + while (keySet.contains(key)) { + key = keyResult.get(); + } + + keySet.add(key); + + keysArr[i] = key; + // return null 5% of the time when the value is optional + if (map.isValueOptional() && random.nextInt(20) == 1) { + valuesArr[i] = null; + } else { + valuesArr[i] = valueResult.get(); + } + } + + return result; + } + + @Override + public Object primitive(Type.PrimitiveType primitive) { + Object obj = RandomUtil.generatePrimitive(primitive, random); + switch (primitive.typeId()) { + case STRING: + return UTF8String.fromString((String) obj); + case DECIMAL: + return Decimal.apply((BigDecimal) obj); + default: + return obj; + } + } + } + + private static class DictionaryEncodedDataGenerator extends RandomDataGenerator { + private DictionaryEncodedDataGenerator(Schema schema, long seed, float nullPercentage) { + super(schema, seed, nullPercentage); + } + + @Override + protected Object randomValue(Type.PrimitiveType primitive, Random random) { + return RandomUtil.generateDictionaryEncodablePrimitive(primitive, random); + } + } + + private static class FallbackDataGenerator extends RandomDataGenerator { + private final long dictionaryEncodedRows; + private long 
rowCount = 0; + + private FallbackDataGenerator(Schema schema, long seed, long numDictionaryEncoded) { + super(schema, seed, DEFAULT_NULL_PERCENTAGE); + this.dictionaryEncodedRows = numDictionaryEncoded; + } + + @Override + protected Object randomValue(Type.PrimitiveType primitive, Random rand) { + this.rowCount += 1; + if (rowCount > dictionaryEncodedRows) { + return RandomUtil.generatePrimitive(primitive, rand); + } else { + return RandomUtil.generateDictionaryEncodablePrimitive(primitive, rand); + } + } + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java new file mode 100644 index 000000000000..4aed78d1e155 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java @@ -0,0 +1,736 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.time.OffsetDateTime; +import java.time.ZoneOffset; +import java.time.temporal.ChronoUnit; +import java.util.Collection; +import java.util.Date; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import org.apache.arrow.vector.ValueVector; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericData.Record; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.relocated.com.google.common.collect.Streams; +import org.apache.iceberg.spark.data.vectorized.IcebergArrowColumnVector; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.orc.storage.serde2.io.DateWritable; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.MapType; +import 
org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.unsafe.types.UTF8String; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import scala.collection.Seq; + +import static org.apache.iceberg.spark.SparkSchemaUtil.convert; +import static scala.collection.JavaConverters.mapAsJavaMapConverter; +import static scala.collection.JavaConverters.seqAsJavaListConverter; + +public class TestHelpers { + + private TestHelpers() { + } + + public static void assertEqualsSafe(Types.StructType struct, List recs, List rows) { + Streams.forEachPair(recs.stream(), rows.stream(), (rec, row) -> assertEqualsSafe(struct, rec, row)); + } + + public static void assertEqualsSafe(Types.StructType struct, Record rec, Row row) { + List fields = struct.fields(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i).type(); + + Object expectedValue = rec.get(i); + Object actualValue = row.get(i); + + assertEqualsSafe(fieldType, expectedValue, actualValue); + } + } + + public static void assertEqualsBatch(Types.StructType struct, Iterator expected, ColumnarBatch batch, + boolean checkArrowValidityVector) { + for (int rowId = 0; rowId < batch.numRows(); rowId++) { + List fields = struct.fields(); + InternalRow row = batch.getRow(rowId); + Record rec = expected.next(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i).type(); + Object expectedValue = rec.get(i); + Object actualValue = row.isNullAt(i) ? null : row.get(i, convert(fieldType)); + assertEqualsUnsafe(fieldType, expectedValue, actualValue); + + if (checkArrowValidityVector) { + ColumnVector columnVector = batch.column(i); + ValueVector arrowVector = ((IcebergArrowColumnVector) columnVector).vectorAccessor().getVector(); + Assert.assertFalse("Nullability doesn't match of " + columnVector.dataType(), + expectedValue == null ^ arrowVector.isNull(rowId)); + } + } + } + } + + + private static void assertEqualsSafe(Types.ListType list, Collection expected, List actual) { + Type elementType = list.elementType(); + List expectedElements = Lists.newArrayList(expected); + for (int i = 0; i < expectedElements.size(); i += 1) { + Object expectedValue = expectedElements.get(i); + Object actualValue = actual.get(i); + + assertEqualsSafe(elementType, expectedValue, actualValue); + } + } + + private static void assertEqualsSafe(Types.MapType map, + Map expected, Map actual) { + Type keyType = map.keyType(); + Type valueType = map.valueType(); + + for (Object expectedKey : expected.keySet()) { + Object matchingKey = null; + for (Object actualKey : actual.keySet()) { + try { + assertEqualsSafe(keyType, expectedKey, actualKey); + matchingKey = actualKey; + } catch (AssertionError e) { + // failed + } + } + + Assert.assertNotNull("Should have a matching key", matchingKey); + assertEqualsSafe(valueType, expected.get(expectedKey), actual.get(matchingKey)); + } + } + + private static final OffsetDateTime EPOCH = Instant.ofEpochMilli(0L).atOffset(ZoneOffset.UTC); + private static final LocalDate EPOCH_DAY = EPOCH.toLocalDate(); + + @SuppressWarnings("unchecked") + private static void assertEqualsSafe(Type type, Object expected, Object actual) { + if (expected == null && actual == null) { + return; + } + + switch (type.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case FLOAT: + case DOUBLE: + Assert.assertEquals("Primitive value should 
be equal to expected", expected, actual); + break; + case DATE: + Assertions.assertThat(expected).as("Should be an int").isInstanceOf(Integer.class); + Assertions.assertThat(actual).as("Should be a Date").isInstanceOf(Date.class); + int daysFromEpoch = (Integer) expected; + LocalDate date = ChronoUnit.DAYS.addTo(EPOCH_DAY, daysFromEpoch); + Assert.assertEquals("ISO-8601 date should be equal", date.toString(), actual.toString()); + break; + case TIMESTAMP: + Assertions.assertThat(expected).as("Should be a long").isInstanceOf(Long.class); + Assertions.assertThat(actual).as("Should be a Timestamp").isInstanceOf(Timestamp.class); + Timestamp ts = (Timestamp) actual; + // milliseconds from nanos has already been added by getTime + long tsMicros = (ts.getTime() * 1000) + ((ts.getNanos() / 1000) % 1000); + Assert.assertEquals("Timestamp micros should be equal", expected, tsMicros); + break; + case STRING: + Assertions.assertThat(actual).as("Should be a String").isInstanceOf(String.class); + Assert.assertEquals("Strings should be equal", String.valueOf(expected), actual); + break; + case UUID: + Assertions.assertThat(expected).as("Should expect a UUID").isInstanceOf(UUID.class); + Assertions.assertThat(actual).as("Should be a String").isInstanceOf(String.class); + Assert.assertEquals("UUID string representation should match", + expected.toString(), actual); + break; + case FIXED: + Assertions.assertThat(expected).as("Should expect a Fixed").isInstanceOf(GenericData.Fixed.class); + Assertions.assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + Assert.assertArrayEquals("Bytes should match", + ((GenericData.Fixed) expected).bytes(), (byte[]) actual); + break; + case BINARY: + Assertions.assertThat(expected).as("Should expect a ByteBuffer").isInstanceOf(ByteBuffer.class); + Assertions.assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + Assert.assertArrayEquals("Bytes should match", + ((ByteBuffer) expected).array(), (byte[]) actual); + break; + case DECIMAL: + Assertions.assertThat(expected).as("Should expect a BigDecimal").isInstanceOf(BigDecimal.class); + Assertions.assertThat(actual).as("Should be a BigDecimal").isInstanceOf(BigDecimal.class); + Assert.assertEquals("BigDecimals should be equal", expected, actual); + break; + case STRUCT: + Assertions.assertThat(expected).as("Should expect a Record").isInstanceOf(Record.class); + Assertions.assertThat(actual).as("Should be a Row").isInstanceOf(Row.class); + assertEqualsSafe(type.asNestedType().asStructType(), (Record) expected, (Row) actual); + break; + case LIST: + Assertions.assertThat(expected).as("Should expect a Collection").isInstanceOf(Collection.class); + Assertions.assertThat(actual).as("Should be a Seq").isInstanceOf(Seq.class); + List asList = seqAsJavaListConverter((Seq) actual).asJava(); + assertEqualsSafe(type.asNestedType().asListType(), (Collection) expected, asList); + break; + case MAP: + Assertions.assertThat(expected).as("Should expect a Collection").isInstanceOf(Map.class); + Assertions.assertThat(actual).as("Should be a Map").isInstanceOf(scala.collection.Map.class); + Map asMap = mapAsJavaMapConverter( + (scala.collection.Map) actual).asJava(); + assertEqualsSafe(type.asNestedType().asMapType(), (Map) expected, asMap); + break; + case TIME: + default: + throw new IllegalArgumentException("Not a supported type: " + type); + } + } + + public static void assertEqualsUnsafe(Types.StructType struct, Record rec, InternalRow row) { + List fields = struct.fields(); + for (int i = 0; i < 
fields.size(); i += 1) { + Type fieldType = fields.get(i).type(); + + Object expectedValue = rec.get(i); + Object actualValue = row.isNullAt(i) ? null : row.get(i, convert(fieldType)); + + assertEqualsUnsafe(fieldType, expectedValue, actualValue); + } + } + + private static void assertEqualsUnsafe(Types.ListType list, Collection expected, ArrayData actual) { + Type elementType = list.elementType(); + List expectedElements = Lists.newArrayList(expected); + for (int i = 0; i < expectedElements.size(); i += 1) { + Object expectedValue = expectedElements.get(i); + Object actualValue = actual.get(i, convert(elementType)); + + assertEqualsUnsafe(elementType, expectedValue, actualValue); + } + } + + private static void assertEqualsUnsafe(Types.MapType map, Map expected, MapData actual) { + Type keyType = map.keyType(); + Type valueType = map.valueType(); + + List> expectedElements = Lists.newArrayList(expected.entrySet()); + ArrayData actualKeys = actual.keyArray(); + ArrayData actualValues = actual.valueArray(); + + for (int i = 0; i < expectedElements.size(); i += 1) { + Map.Entry expectedPair = expectedElements.get(i); + Object actualKey = actualKeys.get(i, convert(keyType)); + Object actualValue = actualValues.get(i, convert(keyType)); + + assertEqualsUnsafe(keyType, expectedPair.getKey(), actualKey); + assertEqualsUnsafe(valueType, expectedPair.getValue(), actualValue); + } + } + + private static void assertEqualsUnsafe(Type type, Object expected, Object actual) { + if (expected == null && actual == null) { + return; + } + + switch (type.typeId()) { + case LONG: + Assertions.assertThat(actual).as("Should be a long").isInstanceOf(Long.class); + if (expected instanceof Integer) { + Assert.assertEquals("Values didn't match", ((Number) expected).longValue(), actual); + } else { + Assert.assertEquals("Primitive value should be equal to expected", expected, actual); + } + break; + case DOUBLE: + Assertions.assertThat(actual).as("Should be a double").isInstanceOf(Double.class); + if (expected instanceof Float) { + Assert.assertEquals("Values didn't match", Double.doubleToLongBits(((Number) expected).doubleValue()), + Double.doubleToLongBits((double) actual)); + } else { + Assert.assertEquals("Primitive value should be equal to expected", expected, actual); + } + break; + case INTEGER: + case FLOAT: + case BOOLEAN: + case DATE: + case TIMESTAMP: + Assert.assertEquals("Primitive value should be equal to expected", expected, actual); + break; + case STRING: + Assertions.assertThat(actual).as("Should be a UTF8String").isInstanceOf(UTF8String.class); + Assert.assertEquals("Strings should be equal", expected, actual.toString()); + break; + case UUID: + Assertions.assertThat(expected).as("Should expect a UUID").isInstanceOf(UUID.class); + Assertions.assertThat(actual).as("Should be a UTF8String").isInstanceOf(UTF8String.class); + Assert.assertEquals("UUID string representation should match", + expected.toString(), actual.toString()); + break; + case FIXED: + Assertions.assertThat(expected).as("Should expect a Fixed").isInstanceOf(GenericData.Fixed.class); + Assertions.assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + Assert.assertArrayEquals("Bytes should match", + ((GenericData.Fixed) expected).bytes(), (byte[]) actual); + break; + case BINARY: + Assertions.assertThat(expected).as("Should expect a ByteBuffer").isInstanceOf(ByteBuffer.class); + Assertions.assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + Assert.assertArrayEquals("Bytes should match", + 
((ByteBuffer) expected).array(), (byte[]) actual); + break; + case DECIMAL: + Assertions.assertThat(expected).as("Should expect a BigDecimal").isInstanceOf(BigDecimal.class); + Assertions.assertThat(actual).as("Should be a Decimal").isInstanceOf(Decimal.class); + Assert.assertEquals("BigDecimals should be equal", + expected, ((Decimal) actual).toJavaBigDecimal()); + break; + case STRUCT: + Assertions.assertThat(expected).as("Should expect a Record").isInstanceOf(Record.class); + Assertions.assertThat(actual).as("Should be an InternalRow").isInstanceOf(InternalRow.class); + assertEqualsUnsafe(type.asNestedType().asStructType(), (Record) expected, (InternalRow) actual); + break; + case LIST: + Assertions.assertThat(expected).as("Should expect a Collection").isInstanceOf(Collection.class); + Assertions.assertThat(actual).as("Should be an ArrayData").isInstanceOf(ArrayData.class); + assertEqualsUnsafe(type.asNestedType().asListType(), (Collection) expected, (ArrayData) actual); + break; + case MAP: + Assertions.assertThat(expected).as("Should expect a Map").isInstanceOf(Map.class); + Assertions.assertThat(actual).as("Should be an ArrayBasedMapData").isInstanceOf(MapData.class); + assertEqualsUnsafe(type.asNestedType().asMapType(), (Map) expected, (MapData) actual); + break; + case TIME: + default: + throw new IllegalArgumentException("Not a supported type: " + type); + } + } + + /** + * Check that the given InternalRow is equivalent to the Row. + * @param prefix context for error messages + * @param type the type of the row + * @param expected the expected value of the row + * @param actual the actual value of the row + */ + public static void assertEquals(String prefix, Types.StructType type, + InternalRow expected, Row actual) { + if (expected == null || actual == null) { + Assert.assertEquals(prefix, expected, actual); + } else { + List fields = type.fields(); + for (int c = 0; c < fields.size(); ++c) { + String fieldName = fields.get(c).name(); + Type childType = fields.get(c).type(); + switch (childType.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case FLOAT: + case DOUBLE: + case STRING: + case DECIMAL: + case DATE: + case TIMESTAMP: + Assert.assertEquals(prefix + "." + fieldName + " - " + childType, + getValue(expected, c, childType), + getPrimitiveValue(actual, c, childType)); + break; + case UUID: + case FIXED: + case BINARY: + assertEqualBytes(prefix + "." + fieldName, + (byte[]) getValue(expected, c, childType), + (byte[]) actual.get(c)); + break; + case STRUCT: { + Types.StructType st = (Types.StructType) childType; + assertEquals(prefix + "." + fieldName, st, + expected.getStruct(c, st.fields().size()), actual.getStruct(c)); + break; + } + case LIST: + assertEqualsLists(prefix + "." + fieldName, childType.asListType(), + expected.getArray(c), + toList((Seq) actual.get(c))); + break; + case MAP: + assertEqualsMaps(prefix + "." 
+ fieldName, childType.asMapType(), expected.getMap(c), + toJavaMap((scala.collection.Map) actual.getMap(c))); + break; + default: + throw new IllegalArgumentException("Unhandled type " + childType); + } + } + } + } + + private static void assertEqualsLists(String prefix, Types.ListType type, + ArrayData expected, List actual) { + if (expected == null || actual == null) { + Assert.assertEquals(prefix, expected, actual); + } else { + Assert.assertEquals(prefix + " length", expected.numElements(), actual.size()); + Type childType = type.elementType(); + for (int e = 0; e < expected.numElements(); ++e) { + switch (childType.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case FLOAT: + case DOUBLE: + case STRING: + case DECIMAL: + case DATE: + case TIMESTAMP: + Assert.assertEquals(prefix + ".elem " + e + " - " + childType, + getValue(expected, e, childType), + actual.get(e)); + break; + case UUID: + case FIXED: + case BINARY: + assertEqualBytes(prefix + ".elem " + e, + (byte[]) getValue(expected, e, childType), + (byte[]) actual.get(e)); + break; + case STRUCT: { + Types.StructType st = (Types.StructType) childType; + assertEquals(prefix + ".elem " + e, st, + expected.getStruct(e, st.fields().size()), (Row) actual.get(e)); + break; + } + case LIST: + assertEqualsLists(prefix + ".elem " + e, childType.asListType(), + expected.getArray(e), + toList((Seq) actual.get(e))); + break; + case MAP: + assertEqualsMaps(prefix + ".elem " + e, childType.asMapType(), + expected.getMap(e), toJavaMap((scala.collection.Map) actual.get(e))); + break; + default: + throw new IllegalArgumentException("Unhandled type " + childType); + } + } + } + } + + private static void assertEqualsMaps(String prefix, Types.MapType type, + MapData expected, Map actual) { + if (expected == null || actual == null) { + Assert.assertEquals(prefix, expected, actual); + } else { + Type keyType = type.keyType(); + Type valueType = type.valueType(); + ArrayData expectedKeyArray = expected.keyArray(); + ArrayData expectedValueArray = expected.valueArray(); + Assert.assertEquals(prefix + " length", expected.numElements(), actual.size()); + for (int e = 0; e < expected.numElements(); ++e) { + Object expectedKey = getValue(expectedKeyArray, e, keyType); + Object actualValue = actual.get(expectedKey); + if (actualValue == null) { + Assert.assertEquals(prefix + ".key=" + expectedKey + " has null", true, + expected.valueArray().isNullAt(e)); + } else { + switch (valueType.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case FLOAT: + case DOUBLE: + case STRING: + case DECIMAL: + case DATE: + case TIMESTAMP: + Assert.assertEquals(prefix + ".key=" + expectedKey + " - " + valueType, + getValue(expectedValueArray, e, valueType), + actual.get(expectedKey)); + break; + case UUID: + case FIXED: + case BINARY: + assertEqualBytes(prefix + ".key=" + expectedKey, + (byte[]) getValue(expectedValueArray, e, valueType), + (byte[]) actual.get(expectedKey)); + break; + case STRUCT: { + Types.StructType st = (Types.StructType) valueType; + assertEquals(prefix + ".key=" + expectedKey, st, + expectedValueArray.getStruct(e, st.fields().size()), + (Row) actual.get(expectedKey)); + break; + } + case LIST: + assertEqualsLists(prefix + ".key=" + expectedKey, + valueType.asListType(), + expectedValueArray.getArray(e), + toList((Seq) actual.get(expectedKey))); + break; + case MAP: + assertEqualsMaps(prefix + ".key=" + expectedKey, valueType.asMapType(), + expectedValueArray.getMap(e), + toJavaMap((scala.collection.Map) actual.get(expectedKey))); + 
break; + default: + throw new IllegalArgumentException("Unhandled type " + valueType); + } + } + } + } + } + + private static Object getValue(SpecializedGetters container, int ord, + Type type) { + if (container.isNullAt(ord)) { + return null; + } + switch (type.typeId()) { + case BOOLEAN: + return container.getBoolean(ord); + case INTEGER: + return container.getInt(ord); + case LONG: + return container.getLong(ord); + case FLOAT: + return container.getFloat(ord); + case DOUBLE: + return container.getDouble(ord); + case STRING: + return container.getUTF8String(ord).toString(); + case BINARY: + case FIXED: + case UUID: + return container.getBinary(ord); + case DATE: + return new DateWritable(container.getInt(ord)).get(); + case TIMESTAMP: + return DateTimeUtils.toJavaTimestamp(container.getLong(ord)); + case DECIMAL: { + Types.DecimalType dt = (Types.DecimalType) type; + return container.getDecimal(ord, dt.precision(), dt.scale()).toJavaBigDecimal(); + } + case STRUCT: + Types.StructType struct = type.asStructType(); + InternalRow internalRow = container.getStruct(ord, struct.fields().size()); + Object[] data = new Object[struct.fields().size()]; + for (int i = 0; i < data.length; i += 1) { + if (internalRow.isNullAt(i)) { + data[i] = null; + } else { + data[i] = getValue(internalRow, i, struct.fields().get(i).type()); + } + } + return new GenericRow(data); + default: + throw new IllegalArgumentException("Unhandled type " + type); + } + } + + private static Object getPrimitiveValue(Row row, int ord, Type type) { + if (row.isNullAt(ord)) { + return null; + } + switch (type.typeId()) { + case BOOLEAN: + return row.getBoolean(ord); + case INTEGER: + return row.getInt(ord); + case LONG: + return row.getLong(ord); + case FLOAT: + return row.getFloat(ord); + case DOUBLE: + return row.getDouble(ord); + case STRING: + return row.getString(ord); + case BINARY: + case FIXED: + case UUID: + return row.get(ord); + case DATE: + return row.getDate(ord); + case TIMESTAMP: + return row.getTimestamp(ord); + case DECIMAL: + return row.getDecimal(ord); + default: + throw new IllegalArgumentException("Unhandled type " + type); + } + } + + private static Map toJavaMap(scala.collection.Map map) { + return map == null ? null : mapAsJavaMapConverter(map).asJava(); + } + + private static List toList(Seq val) { + return val == null ? 
null : seqAsJavaListConverter(val).asJava(); + } + + private static void assertEqualBytes(String context, byte[] expected, + byte[] actual) { + if (expected == null || actual == null) { + Assert.assertEquals(context, expected, actual); + } else { + Assert.assertArrayEquals(context, expected, actual); + } + } + + static void assertEquals(Schema schema, Object expected, Object actual) { + assertEquals("schema", convert(schema), expected, actual); + } + + private static void assertEquals(String context, DataType type, Object expected, Object actual) { + if (expected == null && actual == null) { + return; + } + + if (type instanceof StructType) { + Assertions.assertThat(expected).as("Expected should be an InternalRow: " + context) + .isInstanceOf(InternalRow.class); + Assertions.assertThat(actual).as("Actual should be an InternalRow: " + context) + .isInstanceOf(InternalRow.class); + assertEquals(context, (StructType) type, (InternalRow) expected, (InternalRow) actual); + + } else if (type instanceof ArrayType) { + Assertions.assertThat(expected).as("Expected should be an ArrayData: " + context) + .isInstanceOf(ArrayData.class); + Assertions.assertThat(actual).as("Actual should be an ArrayData: " + context) + .isInstanceOf(ArrayData.class); + assertEquals(context, (ArrayType) type, (ArrayData) expected, (ArrayData) actual); + + } else if (type instanceof MapType) { + Assertions.assertThat(expected).as("Expected should be a MapData: " + context) + .isInstanceOf(MapData.class); + Assertions.assertThat(actual).as("Actual should be a MapData: " + context) + .isInstanceOf(MapData.class); + assertEquals(context, (MapType) type, (MapData) expected, (MapData) actual); + + } else if (type instanceof BinaryType) { + assertEqualBytes(context, (byte[]) expected, (byte[]) actual); + } else { + Assert.assertEquals("Value should match expected: " + context, expected, actual); + } + } + + private static void assertEquals(String context, StructType struct, + InternalRow expected, InternalRow actual) { + Assert.assertEquals("Should have correct number of fields", struct.size(), actual.numFields()); + for (int i = 0; i < actual.numFields(); i += 1) { + StructField field = struct.fields()[i]; + DataType type = field.dataType(); + assertEquals(context + "." + field.name(), type, + expected.isNullAt(i) ? null : expected.get(i, type), + actual.isNullAt(i) ? null : actual.get(i, type)); + } + } + + private static void assertEquals(String context, ArrayType array, ArrayData expected, ArrayData actual) { + Assert.assertEquals("Should have the same number of elements", + expected.numElements(), actual.numElements()); + DataType type = array.elementType(); + for (int i = 0; i < actual.numElements(); i += 1) { + assertEquals(context + ".element", type, + expected.isNullAt(i) ? null : expected.get(i, type), + actual.isNullAt(i) ? null : actual.get(i, type)); + } + } + + private static void assertEquals(String context, MapType map, MapData expected, MapData actual) { + Assert.assertEquals("Should have the same number of elements", + expected.numElements(), actual.numElements()); + + DataType keyType = map.keyType(); + ArrayData expectedKeys = expected.keyArray(); + ArrayData expectedValues = expected.valueArray(); + + DataType valueType = map.valueType(); + ArrayData actualKeys = actual.keyArray(); + ArrayData actualValues = actual.valueArray(); + + for (int i = 0; i < actual.numElements(); i += 1) { + assertEquals(context + ".key", keyType, + expectedKeys.isNullAt(i) ? 
null : expectedKeys.get(i, keyType), + actualKeys.isNullAt(i) ? null : actualKeys.get(i, keyType)); + assertEquals(context + ".value", valueType, + expectedValues.isNullAt(i) ? null : expectedValues.get(i, valueType), + actualValues.isNullAt(i) ? null : actualValues.get(i, valueType)); + } + } + + public static List dataManifests(Table table) { + return table.currentSnapshot().dataManifests(table.io()); + } + + public static List deleteManifests(Table table) { + return table.currentSnapshot().deleteManifests(table.io()); + } + + public static Set dataFiles(Table table) { + Set dataFiles = Sets.newHashSet(); + + for (FileScanTask task : table.newScan().planFiles()) { + dataFiles.add(task.file()); + } + + return dataFiles; + } + + public static Set deleteFiles(Table table) { + Set deleteFiles = Sets.newHashSet(); + + for (FileScanTask task : table.newScan().planFiles()) { + deleteFiles.addAll(task.deletes()); + } + + return deleteFiles; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestOrcWrite.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestOrcWrite.java new file mode 100644 index 000000000000..7cf9b9c736c6 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestOrcWrite.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.data; + +import java.io.File; +import java.io.IOException; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.types.Types.NestedField.optional; + +public class TestOrcWrite { + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private static final Schema SCHEMA = new Schema( + optional(1, "id", Types.IntegerType.get()), + optional(2, "data", Types.StringType.get()) + ); + + @Test + public void splitOffsets() throws IOException { + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + Iterable rows = RandomData.generateSpark(SCHEMA, 1, 0L); + FileAppender writer = ORC.write(Files.localOutput(testFile)) + .createWriterFunc(SparkOrcWriter::new) + .schema(SCHEMA) + .build(); + + writer.addAll(rows); + writer.close(); + Assert.assertNotNull("Split offsets not present", writer.splitOffsets()); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroReader.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroReader.java new file mode 100644 index 000000000000..464e3165583c --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroReader.java @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.data; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import org.apache.avro.generic.GenericData.Record; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.parquet.ParquetAvroValueReaders; +import org.apache.iceberg.parquet.ParquetSchemaUtil; +import org.apache.iceberg.types.Types; +import org.apache.parquet.schema.MessageType; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +public class TestParquetAvroReader { + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private static final Schema COMPLEX_SCHEMA = new Schema( + required(1, "roots", Types.LongType.get()), + optional(3, "lime", Types.ListType.ofRequired(4, Types.DoubleType.get())), + required(5, "strict", Types.StructType.of( + required(9, "tangerine", Types.StringType.get()), + optional(6, "hopeful", Types.StructType.of( + required(7, "steel", Types.FloatType.get()), + required(8, "lantern", Types.DateType.get()) + )), + optional(10, "vehement", Types.LongType.get()) + )), + optional(11, "metamorphosis", Types.MapType.ofRequired(12, 13, + Types.StringType.get(), Types.TimestampType.withoutZone())), + required(14, "winter", Types.ListType.ofOptional(15, Types.StructType.of( + optional(16, "beet", Types.DoubleType.get()), + required(17, "stamp", Types.TimeType.get()), + optional(18, "wheeze", Types.StringType.get()) + ))), + optional(19, "renovate", Types.MapType.ofRequired(20, 21, + Types.StringType.get(), Types.StructType.of( + optional(22, "jumpy", Types.DoubleType.get()), + required(23, "koala", Types.TimeType.get()), + required(24, "couch rope", Types.IntegerType.get()) + ))), + optional(2, "slide", Types.StringType.get()) + ); + + @Ignore + public void testStructSchema() throws IOException { + Schema structSchema = new Schema( + required(1, "circumvent", Types.LongType.get()), + optional(2, "antarctica", Types.StringType.get()), + optional(3, "fluent", Types.DoubleType.get()), + required(4, "quell", Types.StructType.of( + required(5, "operator", Types.BooleanType.get()), + optional(6, "fanta", Types.IntegerType.get()), + optional(7, "cable", Types.FloatType.get()) + )), + required(8, "chimney", Types.TimestampType.withZone()), + required(9, "wool", Types.DateType.get()) + ); + + File testFile = writeTestData(structSchema, 5_000_000, 1059); + // RandomData uses the root record name "test", which must match for records to be equal + MessageType readSchema = ParquetSchemaUtil.convert(structSchema, "test"); + + long sum = 0; + long sumSq = 0; + int warmups = 2; + int trials = 10; + + for (int i = 0; i < warmups + trials; i += 1) { + // clean up as much memory as possible to avoid a large GC during the timed run + System.gc(); + + try (CloseableIterable reader = Parquet.read(Files.localInput(testFile)) + .project(structSchema) + .createReaderFunc( + fileSchema -> ParquetAvroValueReaders.buildReader(structSchema, readSchema)) + .build()) { + long start = System.currentTimeMillis(); + long val = 0; + long count = 0; + for (Record record : reader) { + // access something to ensure the compiler doesn't optimize this away + val ^= (Long) 
record.get(0); + count += 1; + } + long end = System.currentTimeMillis(); + long duration = end - start; + + if (i >= warmups) { + sum += duration; + sumSq += duration * duration; + } + } + } + + double mean = ((double) sum) / trials; + double stddev = Math.sqrt((((double) sumSq) / trials) - (mean * mean)); + } + + @Ignore + public void testWithOldReadPath() throws IOException { + File testFile = writeTestData(COMPLEX_SCHEMA, 500_000, 1985); + // RandomData uses the root record name "test", which must match for records to be equal + MessageType readSchema = ParquetSchemaUtil.convert(COMPLEX_SCHEMA, "test"); + + for (int i = 0; i < 5; i += 1) { + // clean up as much memory as possible to avoid a large GC during the timed run + System.gc(); + + try (CloseableIterable reader = Parquet.read(Files.localInput(testFile)) + .project(COMPLEX_SCHEMA) + .build()) { + long start = System.currentTimeMillis(); + long val = 0; + long count = 0; + for (Record record : reader) { + // access something to ensure the compiler doesn't optimize this away + val ^= (Long) record.get(0); + count += 1; + } + long end = System.currentTimeMillis(); + } + + // clean up as much memory as possible to avoid a large GC during the timed run + System.gc(); + + try (CloseableIterable reader = Parquet.read(Files.localInput(testFile)) + .project(COMPLEX_SCHEMA) + .createReaderFunc( + fileSchema -> ParquetAvroValueReaders.buildReader(COMPLEX_SCHEMA, readSchema)) + .build()) { + long start = System.currentTimeMillis(); + long val = 0; + long count = 0; + for (Record record : reader) { + // access something to ensure the compiler doesn't optimize this away + val ^= (Long) record.get(0); + count += 1; + } + long end = System.currentTimeMillis(); + } + } + } + + @Test + public void testCorrectness() throws IOException { + Iterable records = RandomData.generate(COMPLEX_SCHEMA, 50_000, 34139); + + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (FileAppender writer = Parquet.write(Files.localOutput(testFile)) + .schema(COMPLEX_SCHEMA) + .build()) { + writer.addAll(records); + } + + // RandomData uses the root record name "test", which must match for records to be equal + MessageType readSchema = ParquetSchemaUtil.convert(COMPLEX_SCHEMA, "test"); + + // verify that the new read path is correct + try (CloseableIterable reader = Parquet.read(Files.localInput(testFile)) + .project(COMPLEX_SCHEMA) + .createReaderFunc( + fileSchema -> ParquetAvroValueReaders.buildReader(COMPLEX_SCHEMA, readSchema)) + .reuseContainers() + .build()) { + int recordNum = 0; + Iterator iter = records.iterator(); + for (Record actual : reader) { + Record expected = iter.next(); + Assert.assertEquals("Record " + recordNum + " should match expected", expected, actual); + recordNum += 1; + } + } + } + + private File writeTestData(Schema schema, int numRecords, int seed) throws IOException { + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (FileAppender writer = Parquet.write(Files.localOutput(testFile)) + .schema(schema) + .build()) { + writer.addAll(RandomData.generate(schema, numRecords, seed)); + } + + return testFile; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroWriter.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroWriter.java new file mode 100644 index 000000000000..dcfc873a5a67 --- /dev/null +++ 
b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroWriter.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import org.apache.avro.generic.GenericData.Record; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.parquet.ParquetAvroValueReaders; +import org.apache.iceberg.parquet.ParquetAvroWriter; +import org.apache.iceberg.parquet.ParquetSchemaUtil; +import org.apache.iceberg.types.Types; +import org.apache.parquet.schema.MessageType; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +public class TestParquetAvroWriter { + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private static final Schema COMPLEX_SCHEMA = new Schema( + required(1, "roots", Types.LongType.get()), + optional(3, "lime", Types.ListType.ofRequired(4, Types.DoubleType.get())), + required(5, "strict", Types.StructType.of( + required(9, "tangerine", Types.StringType.get()), + optional(6, "hopeful", Types.StructType.of( + required(7, "steel", Types.FloatType.get()), + required(8, "lantern", Types.DateType.get()) + )), + optional(10, "vehement", Types.LongType.get()) + )), + optional(11, "metamorphosis", Types.MapType.ofRequired(12, 13, + Types.StringType.get(), Types.TimestampType.withoutZone())), + required(14, "winter", Types.ListType.ofOptional(15, Types.StructType.of( + optional(16, "beet", Types.DoubleType.get()), + required(17, "stamp", Types.TimeType.get()), + optional(18, "wheeze", Types.StringType.get()) + ))), + optional(19, "renovate", Types.MapType.ofRequired(20, 21, + Types.StringType.get(), Types.StructType.of( + optional(22, "jumpy", Types.DoubleType.get()), + required(23, "koala", Types.TimeType.get()), + required(24, "couch rope", Types.IntegerType.get()) + ))), + optional(2, "slide", Types.StringType.get()) + ); + + @Test + public void testCorrectness() throws IOException { + Iterable records = RandomData.generate(COMPLEX_SCHEMA, 50_000, 34139); + + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (FileAppender writer = Parquet.write(Files.localOutput(testFile)) + .schema(COMPLEX_SCHEMA) + .createWriterFunc(ParquetAvroWriter::buildWriter) + .build()) { + writer.addAll(records); + } + + // RandomData uses the root record name 
"test", which must match for records to be equal + MessageType readSchema = ParquetSchemaUtil.convert(COMPLEX_SCHEMA, "test"); + + // verify that the new read path is correct + try (CloseableIterable reader = Parquet.read(Files.localInput(testFile)) + .project(COMPLEX_SCHEMA) + .createReaderFunc( + fileSchema -> ParquetAvroValueReaders.buildReader(COMPLEX_SCHEMA, readSchema)) + .build()) { + int recordNum = 0; + Iterator iter = records.iterator(); + for (Record actual : reader) { + Record expected = iter.next(); + Assert.assertEquals("Record " + recordNum + " should match expected", expected, actual); + recordNum += 1; + } + } + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroEnums.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroEnums.java new file mode 100644 index 000000000000..3517c32ffebb --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroEnums.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.data; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericData.Record; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.avro.AvroIterable; +import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestSparkAvroEnums { + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + @Test + public void writeAndValidateEnums() throws IOException { + org.apache.avro.Schema avroSchema = SchemaBuilder.record("root") + .fields() + .name("enumCol") + .type() + .nullable() + .enumeration("testEnum") + .symbols("SYMB1", "SYMB2") + .enumDefault("SYMB2") + .endRecord(); + + org.apache.avro.Schema enumSchema = avroSchema.getField("enumCol").schema().getTypes().get(0); + Record enumRecord1 = new GenericData.Record(avroSchema); + enumRecord1.put("enumCol", new GenericData.EnumSymbol(enumSchema, "SYMB1")); + Record enumRecord2 = new GenericData.Record(avroSchema); + enumRecord2.put("enumCol", new GenericData.EnumSymbol(enumSchema, "SYMB2")); + Record enumRecord3 = new GenericData.Record(avroSchema); // null enum + List expected = ImmutableList.of(enumRecord1, enumRecord2, enumRecord3); + + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (DataFileWriter writer = new DataFileWriter<>(new GenericDatumWriter<>())) { + writer.create(avroSchema, testFile); + writer.append(enumRecord1); + writer.append(enumRecord2); + writer.append(enumRecord3); + } + + Schema schema = new Schema(AvroSchemaUtil.convert(avroSchema).asStructType().fields()); + List rows; + try (AvroIterable reader = Avro.read(Files.localInput(testFile)) + .createReaderFunc(SparkAvroReader::new) + .project(schema) + .build()) { + rows = Lists.newArrayList(reader); + } + + // Iceberg will return enums as strings, so we compare string values for the enum field + for (int i = 0; i < expected.size(); i += 1) { + String expectedEnumString = + expected.get(i).get("enumCol") == null ? null : expected.get(i).get("enumCol").toString(); + String sparkString = rows.get(i).getUTF8String(0) == null ? null : rows.get(i).getUTF8String(0).toString(); + Assert.assertEquals(expectedEnumString, sparkString); + } + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroReader.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroReader.java new file mode 100644 index 000000000000..e4398df39cc8 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroReader.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import org.apache.avro.generic.GenericData.Record; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.avro.AvroIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.Assert; + +import static org.apache.iceberg.spark.data.TestHelpers.assertEqualsUnsafe; + +public class TestSparkAvroReader extends AvroDataTest { + @Override + protected void writeAndValidate(Schema schema) throws IOException { + List expected = RandomData.generateList(schema, 100, 0L); + + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (FileAppender writer = Avro.write(Files.localOutput(testFile)) + .schema(schema) + .named("test") + .build()) { + for (Record rec : expected) { + writer.add(rec); + } + } + + List rows; + try (AvroIterable reader = Avro.read(Files.localInput(testFile)) + .createReaderFunc(SparkAvroReader::new) + .project(schema) + .build()) { + rows = Lists.newArrayList(reader); + } + + for (int i = 0; i < expected.size(); i += 1) { + assertEqualsUnsafe(schema.asStruct(), expected.get(i), rows.get(i)); + } + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkDateTimes.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkDateTimes.java new file mode 100644 index 000000000000..b67e57310b4c --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkDateTimes.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.data; + +import java.time.ZoneId; +import java.util.TimeZone; +import org.apache.iceberg.expressions.Literal; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.catalyst.util.TimestampFormatter; +import org.junit.Assert; +import org.junit.Test; + +public class TestSparkDateTimes { + @Test + public void testSparkDate() { + // checkSparkDate("1582-10-14"); // -141428 + checkSparkDate("1582-10-15"); // first day of the gregorian calendar + checkSparkDate("1601-08-12"); + checkSparkDate("1801-07-04"); + checkSparkDate("1901-08-12"); + checkSparkDate("1969-12-31"); + checkSparkDate("1970-01-01"); + checkSparkDate("2017-12-25"); + checkSparkDate("2043-08-11"); + checkSparkDate("2111-05-03"); + checkSparkDate("2224-02-29"); + checkSparkDate("3224-10-05"); + } + + public void checkSparkDate(String dateString) { + Literal date = Literal.of(dateString).to(Types.DateType.get()); + String sparkDate = DateTimeUtils.toJavaDate(date.value()).toString(); + Assert.assertEquals("Should be the same date (" + date.value() + ")", dateString, sparkDate); + } + + @Test + public void testSparkTimestamp() { + TimeZone currentTz = TimeZone.getDefault(); + try { + TimeZone.setDefault(TimeZone.getTimeZone("UTC")); + checkSparkTimestamp("1582-10-15T15:51:08.440219+00:00", "1582-10-15 15:51:08.440219"); + checkSparkTimestamp("1970-01-01T00:00:00.000000+00:00", "1970-01-01 00:00:00"); + checkSparkTimestamp("2043-08-11T12:30:01.000001+00:00", "2043-08-11 12:30:01.000001"); + } finally { + TimeZone.setDefault(currentTz); + } + } + + public void checkSparkTimestamp(String timestampString, String sparkRepr) { + Literal ts = Literal.of(timestampString).to(Types.TimestampType.withZone()); + ZoneId zoneId = DateTimeUtils.getZoneId("UTC"); + TimestampFormatter formatter = TimestampFormatter.getFractionFormatter(zoneId); + String sparkTimestamp = formatter.format(ts.value()); + Assert.assertEquals("Should be the same timestamp (" + ts.value() + ")", + sparkRepr, sparkTimestamp); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReadMetadataColumns.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReadMetadataColumns.java new file mode 100644 index 000000000000..b8ee56370edf --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReadMetadataColumns.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
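The date and timestamp checks above rely on Iceberg and Spark sharing the same internal representation: DATE values are days since 1970-01-01 and TIMESTAMPTZ values are microseconds since the epoch, so an Iceberg literal value can be handed to Spark's DateTimeUtils directly. A small illustrative sketch; the class name and the worked values are illustrative and not part of this change.

import org.apache.iceberg.expressions.Literal;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.catalyst.util.DateTimeUtils;

public class EpochRepresentationSketch {
  public static void main(String[] args) {
    // Iceberg date literals are integer days from the Unix epoch
    Literal<Integer> date = Literal.of("2017-12-25").to(Types.DateType.get());
    System.out.println(date.value());                           // 17525 (days since 1970-01-01)
    System.out.println(DateTimeUtils.toJavaDate(date.value())); // 2017-12-25

    // Iceberg timestamptz literals are long microseconds from the Unix epoch
    Literal<Long> ts = Literal.of("1970-01-01T00:00:01.000001+00:00").to(Types.TimestampType.withZone());
    System.out.println(ts.value());                             // 1000001
  }
}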
+ */ + +package org.apache.iceberg.spark.data; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.Files; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.exceptions.RuntimeIOException; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkOrcReaders; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.orc.OrcConf; +import org.apache.orc.OrcFile; +import org.apache.orc.Reader; +import org.apache.orc.StripeInformation; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.unsafe.types.UTF8String; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.apache.iceberg.types.Types.NestedField.required; + +@RunWith(Parameterized.class) +public class TestSparkOrcReadMetadataColumns { + private static final Schema DATA_SCHEMA = new Schema( + required(100, "id", Types.LongType.get()), + required(101, "data", Types.StringType.get()) + ); + + private static final Schema PROJECTION_SCHEMA = new Schema( + required(100, "id", Types.LongType.get()), + required(101, "data", Types.StringType.get()), + MetadataColumns.ROW_POSITION, + MetadataColumns.IS_DELETED + ); + + private static final int NUM_ROWS = 1000; + private static final List DATA_ROWS; + private static final List EXPECTED_ROWS; + + static { + DATA_ROWS = Lists.newArrayListWithCapacity(NUM_ROWS); + for (long i = 0; i < NUM_ROWS; i++) { + InternalRow row = new GenericInternalRow(DATA_SCHEMA.columns().size()); + row.update(0, i); + row.update(1, UTF8String.fromString("str" + i)); + DATA_ROWS.add(row); + } + + EXPECTED_ROWS = Lists.newArrayListWithCapacity(NUM_ROWS); + for (long i = 0; i < NUM_ROWS; i++) { + InternalRow row = new GenericInternalRow(PROJECTION_SCHEMA.columns().size()); + row.update(0, i); + row.update(1, UTF8String.fromString("str" + i)); + row.update(2, i); + row.update(3, false); + EXPECTED_ROWS.add(row); + } + } + + @Parameterized.Parameters(name = "vectorized = {0}") + public static Object[] parameters() { + return new Object[] { false, true }; + } + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private boolean vectorized; + private File testFile; + + public TestSparkOrcReadMetadataColumns(boolean vectorized) { + this.vectorized = vectorized; + } + + @Before + public void writeFile() throws IOException { + testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (FileAppender writer = ORC.write(Files.localOutput(testFile)) + .createWriterFunc(SparkOrcWriter::new) + .schema(DATA_SCHEMA) + // 
write in such a way that the file contains 10 stripes each with 100 rows + .set("iceberg.orc.vectorbatch.size", "100") + .set(OrcConf.ROWS_BETWEEN_CHECKS.getAttribute(), "100") + .set(OrcConf.STRIPE_SIZE.getAttribute(), "1") + .build()) { + writer.addAll(DATA_ROWS); + } + } + + @Test + public void testReadRowNumbers() throws IOException { + readAndValidate(null, null, null, EXPECTED_ROWS); + } + + @Test + public void testReadRowNumbersWithFilter() throws IOException { + readAndValidate(Expressions.greaterThanOrEqual("id", 500), null, null, EXPECTED_ROWS.subList(500, 1000)); + } + + @Test + public void testReadRowNumbersWithSplits() throws IOException { + Reader reader; + try { + OrcFile.ReaderOptions readerOptions = OrcFile.readerOptions(new Configuration()).useUTCTimestamp(true); + reader = OrcFile.createReader(new Path(testFile.toString()), readerOptions); + } catch (IOException ioe) { + throw new RuntimeIOException(ioe, "Failed to open file: %s", testFile); + } + List splitOffsets = reader.getStripes().stream().map(StripeInformation::getOffset) + .collect(Collectors.toList()); + List splitLengths = reader.getStripes().stream().map(StripeInformation::getLength) + .collect(Collectors.toList()); + + for (int i = 0; i < 10; i++) { + readAndValidate(null, splitOffsets.get(i), splitLengths.get(i), EXPECTED_ROWS.subList(i * 100, (i + 1) * 100)); + } + } + + private void readAndValidate(Expression filter, Long splitStart, Long splitLength, List expected) + throws IOException { + Schema projectionWithoutMetadataFields = TypeUtil.selectNot(PROJECTION_SCHEMA, MetadataColumns.metadataFieldIds()); + CloseableIterable reader = null; + try { + ORC.ReadBuilder builder = ORC.read(Files.localInput(testFile)) + .project(projectionWithoutMetadataFields); + + if (vectorized) { + builder = builder.createBatchedReaderFunc(readOrcSchema -> + VectorizedSparkOrcReaders.buildReader(PROJECTION_SCHEMA, readOrcSchema, ImmutableMap.of())); + } else { + builder = builder.createReaderFunc(readOrcSchema -> new SparkOrcReader(PROJECTION_SCHEMA, readOrcSchema)); + } + + if (filter != null) { + builder = builder.filter(filter); + } + + if (splitStart != null && splitLength != null) { + builder = builder.split(splitStart, splitLength); + } + + if (vectorized) { + reader = batchesToRows(builder.build()); + } else { + reader = builder.build(); + } + + final Iterator actualRows = reader.iterator(); + final Iterator expectedRows = expected.iterator(); + while (expectedRows.hasNext()) { + Assert.assertTrue("Should have expected number of rows", actualRows.hasNext()); + TestHelpers.assertEquals(PROJECTION_SCHEMA, expectedRows.next(), actualRows.next()); + } + Assert.assertFalse("Should not have extra rows", actualRows.hasNext()); + } finally { + if (reader != null) { + reader.close(); + } + } + } + + private CloseableIterable batchesToRows(CloseableIterable batches) { + return CloseableIterable.combine( + Iterables.concat(Iterables.transform(batches, b -> (Iterable) b::rowIterator)), + batches); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java new file mode 100644 index 000000000000..5042d1cc1338 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data; + +import java.io.File; +import java.io.IOException; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterators; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkOrcReaders; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.junit.Assert; +import org.junit.Test; + +import static org.apache.iceberg.spark.data.TestHelpers.assertEquals; +import static org.apache.iceberg.types.Types.NestedField.required; + +public class TestSparkOrcReader extends AvroDataTest { + @Override + protected void writeAndValidate(Schema schema) throws IOException { + final Iterable expected = RandomData + .generateSpark(schema, 100, 0L); + + writeAndValidateRecords(schema, expected); + } + + @Test + public void writeAndValidateRepeatingRecords() throws IOException { + Schema structSchema = new Schema( + required(100, "id", Types.LongType.get()), + required(101, "data", Types.StringType.get()) + ); + List expectedRepeating = Collections.nCopies(100, + RandomData.generateSpark(structSchema, 1, 0L).iterator().next()); + + writeAndValidateRecords(structSchema, expectedRepeating); + } + + private void writeAndValidateRecords(Schema schema, Iterable expected) throws IOException { + final File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (FileAppender writer = ORC.write(Files.localOutput(testFile)) + .createWriterFunc(SparkOrcWriter::new) + .schema(schema) + .build()) { + writer.addAll(expected); + } + + try (CloseableIterable reader = ORC.read(Files.localInput(testFile)) + .project(schema) + .createReaderFunc(readOrcSchema -> new SparkOrcReader(schema, readOrcSchema)) + .build()) { + final Iterator actualRows = reader.iterator(); + final Iterator expectedRows = expected.iterator(); + while (expectedRows.hasNext()) { + Assert.assertTrue("Should have expected number of rows", actualRows.hasNext()); + assertEquals(schema, expectedRows.next(), actualRows.next()); + } + Assert.assertFalse("Should not have extra rows", actualRows.hasNext()); + } + + try (CloseableIterable reader = ORC.read(Files.localInput(testFile)) + .project(schema) + .createBatchedReaderFunc(readOrcSchema -> + VectorizedSparkOrcReaders.buildReader(schema, readOrcSchema, ImmutableMap.of())) + .build()) { + final Iterator actualRows = batchesToRows(reader.iterator()); + final Iterator 
expectedRows = expected.iterator(); + while (expectedRows.hasNext()) { + Assert.assertTrue("Should have expected number of rows", actualRows.hasNext()); + assertEquals(schema, expectedRows.next(), actualRows.next()); + } + Assert.assertFalse("Should not have extra rows", actualRows.hasNext()); + } + } + + private Iterator batchesToRows(Iterator batches) { + return Iterators.concat(Iterators.transform(batches, ColumnarBatch::rowIterator)); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReadMetadataColumns.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReadMetadataColumns.java new file mode 100644 index 000000000000..256753f4ccc6 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReadMetadataColumns.java @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import java.util.List; +import java.util.Set; +import org.apache.arrow.vector.NullCheckingForGet; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.Files; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.DeleteFilter; +import org.apache.iceberg.deletes.PositionDeleteIndex; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.parquet.ParquetSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkParquetReaders; +import org.apache.iceberg.types.Types; +import org.apache.parquet.ParquetReadOptions; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.ParquetFileWriter; +import org.apache.parquet.hadoop.metadata.BlockMetaData; +import org.apache.parquet.hadoop.util.HadoopInputFile; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.unsafe.types.UTF8String; +import org.junit.Assert; +import org.junit.Before; +import 
org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +@RunWith(Parameterized.class) +public class TestSparkParquetReadMetadataColumns { + private static final Schema DATA_SCHEMA = new Schema( + required(100, "id", Types.LongType.get()), + required(101, "data", Types.StringType.get()) + ); + + private static final Schema PROJECTION_SCHEMA = new Schema( + required(100, "id", Types.LongType.get()), + required(101, "data", Types.StringType.get()), + MetadataColumns.ROW_POSITION, + MetadataColumns.IS_DELETED + ); + + private static final int NUM_ROWS = 1000; + private static final List DATA_ROWS; + private static final List EXPECTED_ROWS; + private static final int NUM_ROW_GROUPS = 10; + private static final int ROWS_PER_SPLIT = NUM_ROWS / NUM_ROW_GROUPS; + private static final int RECORDS_PER_BATCH = ROWS_PER_SPLIT / 10; + + static { + DATA_ROWS = Lists.newArrayListWithCapacity(NUM_ROWS); + for (long i = 0; i < NUM_ROWS; i += 1) { + InternalRow row = new GenericInternalRow(DATA_SCHEMA.columns().size()); + if (i >= NUM_ROWS / 2) { + row.update(0, 2 * i); + } else { + row.update(0, i); + } + row.update(1, UTF8String.fromString("str" + i)); + DATA_ROWS.add(row); + } + + EXPECTED_ROWS = Lists.newArrayListWithCapacity(NUM_ROWS); + for (long i = 0; i < NUM_ROWS; i += 1) { + InternalRow row = new GenericInternalRow(PROJECTION_SCHEMA.columns().size()); + if (i >= NUM_ROWS / 2) { + row.update(0, 2 * i); + } else { + row.update(0, i); + } + row.update(1, UTF8String.fromString("str" + i)); + row.update(2, i); + row.update(3, false); + EXPECTED_ROWS.add(row); + } + } + + @Parameterized.Parameters(name = "vectorized = {0}") + public static Object[][] parameters() { + return new Object[][] { + new Object[] { false }, + new Object[] { true } + }; + } + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private final boolean vectorized; + private File testFile; + + public TestSparkParquetReadMetadataColumns(boolean vectorized) { + this.vectorized = vectorized; + } + + @Before + public void writeFile() throws IOException { + List fileSplits = Lists.newArrayList(); + StructType struct = SparkSchemaUtil.convert(DATA_SCHEMA); + Configuration conf = new Configuration(); + + testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + ParquetFileWriter parquetFileWriter = new ParquetFileWriter( + conf, + ParquetSchemaUtil.convert(DATA_SCHEMA, "testSchema"), + new Path(testFile.getAbsolutePath()) + ); + + parquetFileWriter.start(); + for (int i = 0; i < NUM_ROW_GROUPS; i += 1) { + File split = temp.newFile(); + Assert.assertTrue("Delete should succeed", split.delete()); + fileSplits.add(new Path(split.getAbsolutePath())); + try (FileAppender writer = Parquet.write(Files.localOutput(split)) + .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(struct, msgType)) + .schema(DATA_SCHEMA) + .overwrite() + .build()) { + writer.addAll(DATA_ROWS.subList(i * ROWS_PER_SPLIT, (i + 1) * ROWS_PER_SPLIT)); + } + parquetFileWriter.appendFile(HadoopInputFile.fromPath(new Path(split.getAbsolutePath()), conf)); + } + parquetFileWriter + .end(ParquetFileWriter.mergeMetadataFiles(fileSplits, conf).getFileMetaData().getKeyValueMetaData()); + } + + @Test + public void testReadRowNumbers() throws IOException { + 
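      // With no filter and no split, every row comes back in order; each row must also carry
      // its original file position in the _pos metadata column and false in _deleted.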
readAndValidate(null, null, null, EXPECTED_ROWS); + } + + @Test + public void testReadRowNumbersWithDelete() throws IOException { + if (vectorized) { + List expectedRowsAfterDelete = Lists.newArrayList(EXPECTED_ROWS); + // remove row at position 98, 99, 100, 101, 102, this crosses two row groups [0, 100) and [100, 200) + for (int i = 1; i <= 5; i++) { + expectedRowsAfterDelete.remove(98); + } + + Parquet.ReadBuilder builder = Parquet.read(Files.localInput(testFile)).project(PROJECTION_SCHEMA); + + DeleteFilter deleteFilter = mock(DeleteFilter.class); + when(deleteFilter.hasPosDeletes()).thenReturn(true); + PositionDeleteIndex deletedRowPos = new CustomizedPositionDeleteIndex(); + deletedRowPos.delete(98, 103); + when(deleteFilter.deletedRowPositions()).thenReturn(deletedRowPos); + + builder.createBatchedReaderFunc(fileSchema -> VectorizedSparkParquetReaders.buildReader(PROJECTION_SCHEMA, + fileSchema, NullCheckingForGet.NULL_CHECKING_ENABLED, Maps.newHashMap(), deleteFilter)); + builder.recordsPerBatch(RECORDS_PER_BATCH); + + validate(expectedRowsAfterDelete, builder); + } + } + + private class CustomizedPositionDeleteIndex implements PositionDeleteIndex { + private final Set deleteIndex; + + private CustomizedPositionDeleteIndex() { + deleteIndex = Sets.newHashSet(); + } + + @Override + public void delete(long position) { + deleteIndex.add(position); + } + + @Override + public void delete(long posStart, long posEnd) { + for (long l = posStart; l < posEnd; l++) { + delete(l); + } + } + + @Override + public boolean isDeleted(long position) { + return deleteIndex.contains(position); + } + + @Override + public boolean isEmpty() { + return deleteIndex.isEmpty(); + } + } + + @Test + public void testReadRowNumbersWithFilter() throws IOException { + // current iceberg supports row group filter. 
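      // Each predicate combines id < NUM_ROWS / 2 with id >= i * ROWS_PER_SPLIT, so row groups
      // whose id min/max statistics fall entirely outside that range are skipped, and the rows
      // that survive must still report their original file positions in the _pos column.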
+ for (int i = 1; i < 5; i += 1) { + readAndValidate( + Expressions.and(Expressions.lessThan("id", NUM_ROWS / 2), + Expressions.greaterThanOrEqual("id", i * ROWS_PER_SPLIT)), + null, + null, + EXPECTED_ROWS.subList(i * ROWS_PER_SPLIT, NUM_ROWS / 2)); + } + } + + @Test + public void testReadRowNumbersWithSplits() throws IOException { + ParquetFileReader fileReader = new ParquetFileReader( + HadoopInputFile.fromPath(new Path(testFile.getAbsolutePath()), new Configuration()), + ParquetReadOptions.builder().build()); + List rowGroups = fileReader.getRowGroups(); + for (int i = 0; i < NUM_ROW_GROUPS; i += 1) { + readAndValidate(null, + rowGroups.get(i).getColumns().get(0).getStartingPos(), + rowGroups.get(i).getCompressedSize(), + EXPECTED_ROWS.subList(i * ROWS_PER_SPLIT, (i + 1) * ROWS_PER_SPLIT)); + } + } + + private void readAndValidate(Expression filter, Long splitStart, Long splitLength, List expected) + throws IOException { + Parquet.ReadBuilder builder = Parquet.read(Files.localInput(testFile)).project(PROJECTION_SCHEMA); + + if (vectorized) { + builder.createBatchedReaderFunc(fileSchema -> VectorizedSparkParquetReaders.buildReader(PROJECTION_SCHEMA, + fileSchema, NullCheckingForGet.NULL_CHECKING_ENABLED)); + builder.recordsPerBatch(RECORDS_PER_BATCH); + } else { + builder = builder.createReaderFunc(msgType -> SparkParquetReaders.buildReader(PROJECTION_SCHEMA, msgType)); + } + + if (filter != null) { + builder = builder.filter(filter); + } + + if (splitStart != null && splitLength != null) { + builder = builder.split(splitStart, splitLength); + } + + validate(expected, builder); + } + + private void validate(List expected, Parquet.ReadBuilder builder) throws IOException { + try (CloseableIterable reader = vectorized ? batchesToRows(builder.build()) : builder.build()) { + final Iterator actualRows = reader.iterator(); + + for (InternalRow internalRow : expected) { + Assert.assertTrue("Should have expected number of rows", actualRows.hasNext()); + TestHelpers.assertEquals(PROJECTION_SCHEMA, internalRow, actualRows.next()); + } + + Assert.assertFalse("Should not have extra rows", actualRows.hasNext()); + } + } + + private CloseableIterable batchesToRows(CloseableIterable batches) { + return CloseableIterable.combine( + Iterables.concat(Iterables.transform(batches, b -> (Iterable) b::rowIterator)), + batches); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReader.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReader.java new file mode 100644 index 000000000000..6895523711b9 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReader.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.data; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import org.apache.avro.generic.GenericData; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Files; +import org.apache.iceberg.MetricsConfig; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.data.IcebergGenerics; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.parquet.ParquetUtil; +import org.apache.iceberg.parquet.ParquetWriteAdapter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.parquet.hadoop.ParquetWriter; +import org.apache.parquet.hadoop.api.WriteSupport; +import org.apache.parquet.hadoop.util.HadoopOutputFile; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Test; + +import static org.apache.iceberg.spark.data.TestHelpers.assertEqualsUnsafe; +import static org.apache.iceberg.types.Types.NestedField.required; + +public class TestSparkParquetReader extends AvroDataTest { + @Override + protected void writeAndValidate(Schema schema) throws IOException { + Assume.assumeTrue("Parquet Avro cannot write non-string map keys", null == TypeUtil.find(schema, + type -> type.isMapType() && type.asMapType().keyType() != Types.StringType.get())); + + List expected = RandomData.generateList(schema, 100, 0L); + + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (FileAppender writer = Parquet.write(Files.localOutput(testFile)) + .schema(schema) + .named("test") + .build()) { + writer.addAll(expected); + } + + try (CloseableIterable reader = Parquet.read(Files.localInput(testFile)) + .project(schema) + .createReaderFunc(type -> SparkParquetReaders.buildReader(schema, type)) + .build()) { + Iterator rows = reader.iterator(); + for (int i = 0; i < expected.size(); i += 1) { + Assert.assertTrue("Should have expected number of rows", rows.hasNext()); + assertEqualsUnsafe(schema.asStruct(), expected.get(i), rows.next()); + } + Assert.assertFalse("Should not have extra rows", rows.hasNext()); + } + } + + protected List rowsFromFile(InputFile inputFile, Schema schema) throws IOException { + try (CloseableIterable reader = + Parquet.read(inputFile) + .project(schema) + .createReaderFunc(type -> SparkParquetReaders.buildReader(schema, type)) + .build()) { + return Lists.newArrayList(reader); + } + } + + protected Table tableFromInputFile(InputFile inputFile, Schema schema) throws IOException { + HadoopTables tables = new HadoopTables(); + Table table = + tables.create( + schema, + 
PartitionSpec.unpartitioned(), + ImmutableMap.of(), + temp.newFolder().getCanonicalPath()); + + table + .newAppend() + .appendFile( + DataFiles.builder(PartitionSpec.unpartitioned()) + .withFormat(FileFormat.PARQUET) + .withInputFile(inputFile) + .withMetrics(ParquetUtil.fileMetrics(inputFile, MetricsConfig.getDefault())) + .withFileSizeInBytes(inputFile.getLength()) + .build()) + .commit(); + + return table; + } + + @Test + public void testInt96TimestampProducedBySparkIsReadCorrectly() throws IOException { + String outputFilePath = String.format("%s/%s", temp.getRoot().getAbsolutePath(), "parquet_int96.parquet"); + HadoopOutputFile outputFile = + HadoopOutputFile.fromPath( + new org.apache.hadoop.fs.Path(outputFilePath), new Configuration()); + Schema schema = new Schema(required(1, "ts", Types.TimestampType.withZone())); + StructType sparkSchema = + new StructType( + new StructField[] { + new StructField("ts", DataTypes.TimestampType, true, Metadata.empty()) + }); + List rows = Lists.newArrayList(RandomData.generateSpark(schema, 10, 0L)); + + try (FileAppender writer = + new ParquetWriteAdapter<>( + new NativeSparkWriterBuilder(outputFile) + .set("org.apache.spark.sql.parquet.row.attributes", sparkSchema.json()) + .set("spark.sql.parquet.writeLegacyFormat", "false") + .set("spark.sql.parquet.outputTimestampType", "INT96") + .set("spark.sql.parquet.fieldId.write.enabled", "true") + .build(), + MetricsConfig.getDefault())) { + writer.addAll(rows); + } + + InputFile parquetInputFile = Files.localInput(outputFilePath); + List readRows = rowsFromFile(parquetInputFile, schema); + Assert.assertEquals(rows.size(), readRows.size()); + Assert.assertThat(readRows, CoreMatchers.is(rows)); + + // Now we try to import that file as an Iceberg table to make sure Iceberg can read + // Int96 end to end. + Table int96Table = tableFromInputFile(parquetInputFile, schema); + List tableRecords = Lists.newArrayList(IcebergGenerics.read(int96Table).build()); + + Assert.assertEquals(rows.size(), tableRecords.size()); + + for (int i = 0; i < tableRecords.size(); i++) { + GenericsHelpers.assertEqualsUnsafe(schema.asStruct(), tableRecords.get(i), rows.get(i)); + } + } + + /** + * Native Spark ParquetWriter.Builder implementation so that we can write timestamps using Spark's native + * ParquetWriteSupport. + */ + private static class NativeSparkWriterBuilder + extends ParquetWriter.Builder { + private final Map config = Maps.newHashMap(); + + NativeSparkWriterBuilder(org.apache.parquet.io.OutputFile path) { + super(path); + } + + public NativeSparkWriterBuilder set(String property, String value) { + this.config.put(property, value); + return self(); + } + + @Override + protected NativeSparkWriterBuilder self() { + return this; + } + + @Override + protected WriteSupport getWriteSupport(Configuration configuration) { + for (Map.Entry entry : config.entrySet()) { + configuration.set(entry.getKey(), entry.getValue()); + } + + return new org.apache.spark.sql.execution.datasources.parquet.ParquetWriteSupport(); + } + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java new file mode 100644 index 000000000000..c75a87abc45c --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +public class TestSparkParquetWriter { + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private static final Schema COMPLEX_SCHEMA = new Schema( + required(1, "roots", Types.LongType.get()), + optional(3, "lime", Types.ListType.ofRequired(4, Types.DoubleType.get())), + required(5, "strict", Types.StructType.of( + required(9, "tangerine", Types.StringType.get()), + optional(6, "hopeful", Types.StructType.of( + required(7, "steel", Types.FloatType.get()), + required(8, "lantern", Types.DateType.get()) + )), + optional(10, "vehement", Types.LongType.get()) + )), + optional(11, "metamorphosis", Types.MapType.ofRequired(12, 13, + Types.StringType.get(), Types.TimestampType.withZone())), + required(14, "winter", Types.ListType.ofOptional(15, Types.StructType.of( + optional(16, "beet", Types.DoubleType.get()), + required(17, "stamp", Types.FloatType.get()), + optional(18, "wheeze", Types.StringType.get()) + ))), + optional(19, "renovate", Types.MapType.ofRequired(20, 21, + Types.StringType.get(), Types.StructType.of( + optional(22, "jumpy", Types.DoubleType.get()), + required(23, "koala", Types.IntegerType.get()), + required(24, "couch rope", Types.IntegerType.get()) + ))), + optional(2, "slide", Types.StringType.get()) + ); + + @Test + public void testCorrectness() throws IOException { + int numRows = 50_000; + Iterable records = RandomData.generateSpark(COMPLEX_SCHEMA, numRows, 19981); + + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (FileAppender writer = Parquet.write(Files.localOutput(testFile)) + .schema(COMPLEX_SCHEMA) + .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(SparkSchemaUtil.convert(COMPLEX_SCHEMA), msgType)) + .build()) { + writer.addAll(records); + } + + try (CloseableIterable reader = Parquet.read(Files.localInput(testFile)) + .project(COMPLEX_SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(COMPLEX_SCHEMA, type)) + .build()) { + Iterator expected = records.iterator(); + Iterator rows = reader.iterator(); + for (int i = 0; i < numRows; i += 1) { + 
Assert.assertTrue("Should have expected number of rows", rows.hasNext()); + TestHelpers.assertEquals(COMPLEX_SCHEMA, expected.next(), rows.next()); + } + Assert.assertFalse("Should not have extra rows", rows.hasNext()); + } + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkRecordOrcReaderWriter.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkRecordOrcReaderWriter.java new file mode 100644 index 000000000000..1e7430d16df7 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkRecordOrcReaderWriter.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data; + +import java.io.File; +import java.io.IOException; +import java.math.BigDecimal; +import java.util.Iterator; +import java.util.List; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.RandomGenericData; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.data.orc.GenericOrcReader; +import org.apache.iceberg.data.orc.GenericOrcWriter; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.Assert; +import org.junit.Test; + +import static org.apache.iceberg.types.Types.NestedField.required; + +public class TestSparkRecordOrcReaderWriter extends AvroDataTest { + private static final int NUM_RECORDS = 200; + + private void writeAndValidate(Schema schema, List expectedRecords) throws IOException { + final File originalFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", originalFile.delete()); + + // Write few generic records into the original test file. + try (FileAppender writer = ORC.write(Files.localOutput(originalFile)) + .createWriterFunc(GenericOrcWriter::buildWriter) + .schema(schema) + .build()) { + writer.addAll(expectedRecords); + } + + // Read into spark InternalRow from the original test file. + List internalRows = Lists.newArrayList(); + try (CloseableIterable reader = ORC.read(Files.localInput(originalFile)) + .project(schema) + .createReaderFunc(readOrcSchema -> new SparkOrcReader(schema, readOrcSchema)) + .build()) { + reader.forEach(internalRows::add); + assertEqualsUnsafe(schema.asStruct(), expectedRecords, reader, expectedRecords.size()); + } + + final File anotherFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", anotherFile.delete()); + + // Write those spark InternalRows into a new file again. 
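      // Round trip: the generic Records were written with GenericOrcWriter and read back as
      // Spark InternalRows above; writing those rows with SparkOrcWriter and re-reading them
      // with both the Spark and generic ORC readers checks the writer and readers against
      // each other.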
+ try (FileAppender writer = ORC.write(Files.localOutput(anotherFile)) + .createWriterFunc(SparkOrcWriter::new) + .schema(schema) + .build()) { + writer.addAll(internalRows); + } + + // Check whether the InternalRows are expected records. + try (CloseableIterable reader = ORC.read(Files.localInput(anotherFile)) + .project(schema) + .createReaderFunc(readOrcSchema -> new SparkOrcReader(schema, readOrcSchema)) + .build()) { + assertEqualsUnsafe(schema.asStruct(), expectedRecords, reader, expectedRecords.size()); + } + + // Read into iceberg GenericRecord and check again. + try (CloseableIterable reader = ORC.read(Files.localInput(anotherFile)) + .createReaderFunc(typeDesc -> GenericOrcReader.buildReader(schema, typeDesc)) + .project(schema) + .build()) { + assertRecordEquals(expectedRecords, reader, expectedRecords.size()); + } + } + + @Override + protected void writeAndValidate(Schema schema) throws IOException { + List expectedRecords = RandomGenericData.generate(schema, NUM_RECORDS, 1992L); + writeAndValidate(schema, expectedRecords); + } + + @Test + public void testDecimalWithTrailingZero() throws IOException { + Schema schema = new Schema( + required(1, "d1", Types.DecimalType.of(10, 2)), + required(2, "d2", Types.DecimalType.of(20, 5)), + required(3, "d3", Types.DecimalType.of(38, 20)) + ); + + List expected = Lists.newArrayList(); + + GenericRecord record = GenericRecord.create(schema); + record.set(0, new BigDecimal("101.00")); + record.set(1, new BigDecimal("10.00E-3")); + record.set(2, new BigDecimal("1001.0000E-16")); + + expected.add(record.copy()); + + writeAndValidate(schema, expected); + } + + private static void assertRecordEquals(Iterable expected, Iterable actual, int size) { + Iterator expectedIter = expected.iterator(); + Iterator actualIter = actual.iterator(); + for (int i = 0; i < size; i += 1) { + Assert.assertTrue("Expected iterator should have more rows", expectedIter.hasNext()); + Assert.assertTrue("Actual iterator should have more rows", actualIter.hasNext()); + Assert.assertEquals("Should have same rows.", expectedIter.next(), actualIter.next()); + } + Assert.assertFalse("Expected iterator should not have any extra rows.", expectedIter.hasNext()); + Assert.assertFalse("Actual iterator should not have any extra rows.", actualIter.hasNext()); + } + + private static void assertEqualsUnsafe(Types.StructType struct, Iterable expected, + Iterable actual, int size) { + Iterator expectedIter = expected.iterator(); + Iterator actualIter = actual.iterator(); + for (int i = 0; i < size; i += 1) { + Assert.assertTrue("Expected iterator should have more rows", expectedIter.hasNext()); + Assert.assertTrue("Actual iterator should have more rows", actualIter.hasNext()); + GenericsHelpers.assertEqualsUnsafe(struct, expectedIter.next(), actualIter.next()); + } + Assert.assertFalse("Expected iterator should not have any extra rows.", expectedIter.hasNext()); + Assert.assertFalse("Actual iterator should not have any extra rows.", actualIter.hasNext()); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryEncodedVectorizedReads.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryEncodedVectorizedReads.java new file mode 100644 index 000000000000..f292df0c3bf8 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryEncodedVectorizedReads.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data.parquet.vectorized; + +import java.io.File; +import java.io.IOException; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.base.Function; +import org.apache.iceberg.relocated.com.google.common.collect.FluentIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.data.RandomData; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Test; + +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT; + +public class TestParquetDictionaryEncodedVectorizedReads extends TestParquetVectorizedReads { + + @Override + Iterable generateData(Schema schema, int numRecords, long seed, float nullPercentage, + Function transform) { + Iterable data = RandomData.generateDictionaryEncodableData(schema, numRecords, seed, nullPercentage); + return transform == IDENTITY ? 
data : Iterables.transform(data, transform); + } + + @Test + @Override + @Ignore // Ignored since this code path is already tested in TestParquetVectorizedReads + public void testVectorizedReadsWithNewContainers() throws IOException { + + } + + @Test + public void testMixedDictionaryNonDictionaryReads() throws IOException { + Schema schema = new Schema(SUPPORTED_PRIMITIVES.fields()); + File dictionaryEncodedFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", dictionaryEncodedFile.delete()); + Iterable dictionaryEncodableData = RandomData.generateDictionaryEncodableData( + schema, + 10000, + 0L, + RandomData.DEFAULT_NULL_PERCENTAGE); + try (FileAppender writer = getParquetWriter(schema, dictionaryEncodedFile)) { + writer.addAll(dictionaryEncodableData); + } + + File plainEncodingFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", plainEncodingFile.delete()); + Iterable nonDictionaryData = RandomData.generate(schema, 10000, 0L, + RandomData.DEFAULT_NULL_PERCENTAGE); + try (FileAppender writer = getParquetWriter(schema, plainEncodingFile)) { + writer.addAll(nonDictionaryData); + } + + int rowGroupSize = PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT; + File mixedFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", mixedFile.delete()); + Parquet.concat(ImmutableList.of(dictionaryEncodedFile, plainEncodingFile, dictionaryEncodedFile), + mixedFile, rowGroupSize, schema, ImmutableMap.of()); + assertRecordsMatch( + schema, + 30000, + FluentIterable.concat(dictionaryEncodableData, nonDictionaryData, dictionaryEncodableData), + mixedFile, + false, + true, + BATCH_SIZE); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryFallbackToPlainEncodingVectorizedReads.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryFallbackToPlainEncodingVectorizedReads.java new file mode 100644 index 000000000000..5ceac3fdb76e --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryFallbackToPlainEncodingVectorizedReads.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.data.parquet.vectorized; + +import java.io.File; +import java.io.IOException; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.base.Function; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.data.RandomData; +import org.junit.Ignore; +import org.junit.Test; + +public class TestParquetDictionaryFallbackToPlainEncodingVectorizedReads extends TestParquetVectorizedReads { + private static final int NUM_ROWS = 1_000_000; + + @Override + protected int getNumRows() { + return NUM_ROWS; + } + + @Override + Iterable generateData(Schema schema, int numRecords, long seed, float nullPercentage, + Function transform) { + // TODO: take into account nullPercentage when generating fallback encoding data + Iterable data = RandomData.generateFallbackData(schema, numRecords, seed, numRecords / 20); + return transform == IDENTITY ? data : Iterables.transform(data, transform); + } + + @Override + FileAppender getParquetWriter(Schema schema, File testFile) throws IOException { + return Parquet.write(Files.localOutput(testFile)) + .schema(schema) + .named("test") + .set(TableProperties.PARQUET_DICT_SIZE_BYTES, "512000") + .build(); + } + + @Test + @Override + @Ignore // Fallback encoding not triggered when data is mostly null + public void testMostlyNullsForOptionalFields() { + + } + + @Test + @Override + @Ignore // Ignored since this code path is already tested in TestParquetVectorizedReads + public void testVectorizedReadsWithNewContainers() throws IOException { + + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetVectorizedReads.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetVectorizedReads.java new file mode 100644 index 000000000000..19ffd023a061 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetVectorizedReads.java @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.data.parquet.vectorized; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.base.Function; +import org.apache.iceberg.relocated.com.google.common.base.Strings; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.data.AvroDataTest; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkParquetReaders; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.parquet.column.ParquetProperties; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.Type; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Ignore; +import org.junit.Test; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +public class TestParquetVectorizedReads extends AvroDataTest { + private static final int NUM_ROWS = 200_000; + static final int BATCH_SIZE = 10_000; + + static final Function IDENTITY = record -> record; + + @Override + protected void writeAndValidate(Schema schema) throws IOException { + writeAndValidate(schema, getNumRows(), 0L, RandomData.DEFAULT_NULL_PERCENTAGE, false, true); + } + + private void writeAndValidate( + Schema schema, int numRecords, long seed, float nullPercentage, + boolean setAndCheckArrowValidityVector, boolean reuseContainers) + throws IOException { + writeAndValidate(schema, numRecords, seed, nullPercentage, + setAndCheckArrowValidityVector, reuseContainers, BATCH_SIZE, IDENTITY); + } + + private void writeAndValidate( + Schema schema, int numRecords, long seed, float nullPercentage, + boolean setAndCheckArrowValidityVector, boolean reuseContainers, int batchSize, + Function transform) + throws IOException { + // Write test data + Assume.assumeTrue("Parquet Avro cannot write non-string map keys", null == TypeUtil.find( + schema, + type -> type.isMapType() && type.asMapType().keyType() != Types.StringType.get())); + + Iterable expected = generateData(schema, numRecords, seed, nullPercentage, transform); + + // write a test parquet file using iceberg writer + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (FileAppender writer = getParquetWriter(schema, testFile)) { + writer.addAll(expected); + } + assertRecordsMatch(schema, numRecords, expected, testFile, setAndCheckArrowValidityVector, + reuseContainers, batchSize); + } + + protected int getNumRows() { + return NUM_ROWS; + } + + Iterable generateData(Schema schema, int numRecords, long seed, float nullPercentage, + Function transform) { + Iterable data = RandomData.generate(schema, numRecords, seed, nullPercentage); + return transform == IDENTITY ? 
data : Iterables.transform(data, transform); + } + + FileAppender getParquetWriter(Schema schema, File testFile) throws IOException { + return Parquet.write(Files.localOutput(testFile)) + .schema(schema) + .named("test") + .build(); + } + + FileAppender getParquetV2Writer(Schema schema, File testFile) throws IOException { + return Parquet.write(Files.localOutput(testFile)) + .schema(schema) + .named("test") + .writerVersion(ParquetProperties.WriterVersion.PARQUET_2_0) + .build(); + } + + void assertRecordsMatch( + Schema schema, int expectedSize, Iterable expected, File testFile, + boolean setAndCheckArrowValidityBuffer, boolean reuseContainers, int batchSize) + throws IOException { + Parquet.ReadBuilder readBuilder = Parquet.read(Files.localInput(testFile)) + .project(schema) + .recordsPerBatch(batchSize) + .createBatchedReaderFunc(type -> VectorizedSparkParquetReaders.buildReader( + schema, + type, + setAndCheckArrowValidityBuffer)); + if (reuseContainers) { + readBuilder.reuseContainers(); + } + try (CloseableIterable batchReader = + readBuilder.build()) { + Iterator expectedIter = expected.iterator(); + Iterator batches = batchReader.iterator(); + int numRowsRead = 0; + while (batches.hasNext()) { + ColumnarBatch batch = batches.next(); + numRowsRead += batch.numRows(); + TestHelpers.assertEqualsBatch(schema.asStruct(), expectedIter, batch, setAndCheckArrowValidityBuffer); + } + Assert.assertEquals(expectedSize, numRowsRead); + } + } + + @Override + @Test + @Ignore + public void testArray() { + } + + @Override + @Test + @Ignore + public void testArrayOfStructs() { + } + + @Override + @Test + @Ignore + public void testMap() { + } + + @Override + @Test + @Ignore + public void testNumericMapKey() { + } + + @Override + @Test + @Ignore + public void testComplexMapKey() { + } + + @Override + @Test + @Ignore + public void testMapOfStructs() { + } + + @Override + @Test + @Ignore + public void testMixedTypes() { + } + + @Test + @Override + public void testNestedStruct() { + AssertHelpers.assertThrows( + "Vectorized reads are not supported yet for struct fields", + UnsupportedOperationException.class, + "Vectorized reads are not supported yet for struct fields", + () -> VectorizedSparkParquetReaders.buildReader( + TypeUtil.assignIncreasingFreshIds(new Schema(required( + 1, + "struct", + SUPPORTED_PRIMITIVES))), + new MessageType("struct", new GroupType(Type.Repetition.OPTIONAL, "struct").withId(1)), + false)); + } + + @Test + public void testMostlyNullsForOptionalFields() throws IOException { + writeAndValidate( + TypeUtil.assignIncreasingFreshIds(new Schema(SUPPORTED_PRIMITIVES.fields())), + getNumRows(), + 0L, + 0.99f, + false, + true); + } + + @Test + public void testSettingArrowValidityVector() throws IOException { + writeAndValidate(new Schema( + Lists.transform(SUPPORTED_PRIMITIVES.fields(), Types.NestedField::asOptional)), + getNumRows(), 0L, RandomData.DEFAULT_NULL_PERCENTAGE, true, true); + } + + @Test + public void testVectorizedReadsWithNewContainers() throws IOException { + writeAndValidate(TypeUtil.assignIncreasingFreshIds(new Schema(SUPPORTED_PRIMITIVES.fields())), + getNumRows(), 0L, RandomData.DEFAULT_NULL_PERCENTAGE, true, false); + } + + @Test + public void testVectorizedReadsWithReallocatedArrowBuffers() throws IOException { + // With a batch size of 2, 256 bytes are allocated in the VarCharVector. By adding strings of + // length 512, the vector will need to be reallocated for storing the batch. 
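      // The transform below pads every non-null "data" value to 512 characters and replaces
      // nulls with a 512-character string, so each batch is guaranteed to outgrow the initial
      // buffer allocation and exercise the reallocation path.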
+ writeAndValidate(new Schema( + Lists.newArrayList( + SUPPORTED_PRIMITIVES.field("id"), + SUPPORTED_PRIMITIVES.field("data"))), + 10, 0L, RandomData.DEFAULT_NULL_PERCENTAGE, + true, true, 2, + record -> { + if (record.get("data") != null) { + record.put("data", Strings.padEnd((String) record.get("data"), 512, 'a')); + } else { + record.put("data", Strings.padEnd("", 512, 'a')); + } + return record; + }); + } + + @Test + public void testReadsForTypePromotedColumns() throws Exception { + Schema writeSchema = new Schema( + required(100, "id", Types.LongType.get()), + optional(101, "int_data", Types.IntegerType.get()), + optional(102, "float_data", Types.FloatType.get()), + optional(103, "decimal_data", Types.DecimalType.of(10, 5)) + ); + + File dataFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", dataFile.delete()); + Iterable data = generateData(writeSchema, 30000, 0L, + RandomData.DEFAULT_NULL_PERCENTAGE, IDENTITY); + try (FileAppender writer = getParquetWriter(writeSchema, dataFile)) { + writer.addAll(data); + } + + Schema readSchema = new Schema( + required(100, "id", Types.LongType.get()), + optional(101, "int_data", Types.LongType.get()), + optional(102, "float_data", Types.DoubleType.get()), + optional(103, "decimal_data", Types.DecimalType.of(25, 5)) + ); + + assertRecordsMatch(readSchema, 30000, data, dataFile, false, + true, BATCH_SIZE); + } + + @Test + public void testSupportedReadsForParquetV2() throws Exception { + // Only float and double column types are written using plain encoding with Parquet V2 + Schema schema = new Schema( + optional(102, "float_data", Types.FloatType.get()), + optional(103, "double_data", Types.DoubleType.get())); + + File dataFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", dataFile.delete()); + Iterable data = generateData(schema, 30000, 0L, + RandomData.DEFAULT_NULL_PERCENTAGE, IDENTITY); + try (FileAppender writer = getParquetV2Writer(schema, dataFile)) { + writer.addAll(data); + } + assertRecordsMatch(schema, 30000, data, dataFile, false, + true, BATCH_SIZE); + } + + @Test + public void testUnsupportedReadsForParquetV2() throws Exception { + // Long, int, string and other column types use delta encoding, which is not supported for vectorized reads + Schema schema = new Schema(SUPPORTED_PRIMITIVES.fields()); + File dataFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", dataFile.delete()); + Iterable data = generateData(schema, 30000, 0L, + RandomData.DEFAULT_NULL_PERCENTAGE, IDENTITY); + try (FileAppender writer = getParquetV2Writer(schema, dataFile)) { + writer.addAll(data); + } + AssertHelpers.assertThrows("Vectorized reads not supported", + UnsupportedOperationException.class, "Cannot support vectorized reads for column", () -> { + assertRecordsMatch(schema, 30000, data, dataFile, false, + true, BATCH_SIZE); + return null; + }); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/FilePathLastModifiedRecord.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/FilePathLastModifiedRecord.java new file mode 100644 index 000000000000..275e3a520db5 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/FilePathLastModifiedRecord.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.sql.Timestamp; +import java.util.Objects; + +public class FilePathLastModifiedRecord { + private String filePath; + private Timestamp lastModified; + + public FilePathLastModifiedRecord() { + } + + public FilePathLastModifiedRecord(String filePath, Timestamp lastModified) { + this.filePath = filePath; + this.lastModified = lastModified; + } + + public String getFilePath() { + return filePath; + } + + public void setFilePath(String filePath) { + this.filePath = filePath; + } + + public Timestamp getLastModified() { + return lastModified; + } + + public void setLastModified(Timestamp lastModified) { + this.lastModified = lastModified; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FilePathLastModifiedRecord that = (FilePathLastModifiedRecord) o; + return Objects.equals(filePath, that.filePath) && + Objects.equals(lastModified, that.lastModified); + } + + @Override + public int hashCode() { + return Objects.hash(filePath, lastModified); + } + + @Override + public String toString() { + return "FilePathLastModifiedRecord{" + + "filePath='" + filePath + '\'' + + ", lastModified='" + lastModified + '\'' + + '}'; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/LogMessage.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/LogMessage.java new file mode 100644 index 000000000000..5e22daeb0841 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/LogMessage.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.time.Instant; +import java.util.concurrent.atomic.AtomicInteger; + +public class LogMessage { + private static AtomicInteger idCounter = new AtomicInteger(0); + + static LogMessage debug(String date, String message) { + return new LogMessage(idCounter.getAndIncrement(), date, "DEBUG", message); + } + + static LogMessage debug(String date, String message, Instant timestamp) { + return new LogMessage(idCounter.getAndIncrement(), date, "DEBUG", message, timestamp); + } + + static LogMessage info(String date, String message) { + return new LogMessage(idCounter.getAndIncrement(), date, "INFO", message); + } + + static LogMessage info(String date, String message, Instant timestamp) { + return new LogMessage(idCounter.getAndIncrement(), date, "INFO", message, timestamp); + } + + static LogMessage error(String date, String message) { + return new LogMessage(idCounter.getAndIncrement(), date, "ERROR", message); + } + + static LogMessage error(String date, String message, Instant timestamp) { + return new LogMessage(idCounter.getAndIncrement(), date, "ERROR", message, timestamp); + } + + static LogMessage warn(String date, String message) { + return new LogMessage(idCounter.getAndIncrement(), date, "WARN", message); + } + + static LogMessage warn(String date, String message, Instant timestamp) { + return new LogMessage(idCounter.getAndIncrement(), date, "WARN", message, timestamp); + } + + private int id; + private String date; + private String level; + private String message; + private Instant timestamp; + + private LogMessage(int id, String date, String level, String message) { + this.id = id; + this.date = date; + this.level = level; + this.message = message; + } + + private LogMessage(int id, String date, String level, String message, Instant timestamp) { + this.id = id; + this.date = date; + this.level = level; + this.message = message; + this.timestamp = timestamp; + } + + public int getId() { + return id; + } + + public void setId(int id) { + this.id = id; + } + + public String getDate() { + return date; + } + + public void setDate(String date) { + this.date = date; + } + + public String getLevel() { + return level; + } + + public void setLevel(String level) { + this.level = level; + } + + public String getMessage() { + return message; + } + + public void setMessage(String message) { + this.message = message; + } + + public Instant getTimestamp() { + return timestamp; + } + + public void setTimestamp(Instant timestamp) { + this.timestamp = timestamp; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/ManualSource.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/ManualSource.java new file mode 100644 index 000000000000..0f8c8b3b65c6 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/ManualSource.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableProvider; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.sources.DataSourceRegister; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +public class ManualSource implements TableProvider, DataSourceRegister { + public static final String SHORT_NAME = "manual_source"; + public static final String TABLE_NAME = "TABLE_NAME"; + private static final Map tableMap = Maps.newHashMap(); + + public static void setTable(String name, Table table) { + Preconditions.checkArgument(!tableMap.containsKey(name), "Cannot set " + name + ". It is already set"); + tableMap.put(name, table); + } + + public static void clearTables() { + tableMap.clear(); + } + + @Override + public String shortName() { + return SHORT_NAME; + } + + @Override + public StructType inferSchema(CaseInsensitiveStringMap options) { + return getTable(null, null, options).schema(); + } + + @Override + public Transform[] inferPartitioning(CaseInsensitiveStringMap options) { + return getTable(null, null, options).partitioning(); + } + + @Override + public org.apache.spark.sql.connector.catalog.Table getTable( + StructType schema, Transform[] partitioning, Map properties) { + Preconditions.checkArgument(properties.containsKey(TABLE_NAME), "Missing property " + TABLE_NAME); + String tableName = properties.get(TABLE_NAME); + Preconditions.checkArgument(tableMap.containsKey(tableName), "Table missing " + tableName); + return tableMap.get(tableName); + } + + @Override + public boolean supportsExternalMetadata() { + return false; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/SimpleRecord.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/SimpleRecord.java new file mode 100644 index 000000000000..c8b7a31b3ba0 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/SimpleRecord.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.relocated.com.google.common.base.Objects; + +public class SimpleRecord { + private Integer id; + private String data; + + public SimpleRecord() { + } + + public SimpleRecord(Integer id, String data) { + this.id = id; + this.data = data; + } + + public Integer getId() { + return id; + } + + public void setId(Integer id) { + this.id = id; + } + + public String getData() { + return data; + } + + public void setData(String data) { + this.data = data; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + SimpleRecord record = (SimpleRecord) o; + return Objects.equal(id, record.id) && Objects.equal(data, record.data); + } + + @Override + public int hashCode() { + return Objects.hashCode(id, data); + } + + @Override + public String toString() { + StringBuilder buffer = new StringBuilder(); + buffer.append("{\"id\"="); + buffer.append(id); + buffer.append(",\"data\"=\""); + buffer.append(data); + buffer.append("\"}"); + return buffer.toString(); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestAvroScan.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestAvroScan.java new file mode 100644 index 000000000000..4d2e12229813 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestAvroScan.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.UUID; +import org.apache.avro.generic.GenericData.Record; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.spark.data.AvroDataTest; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.Files.localOutput; + +public class TestAvroScan extends AvroDataTest { + private static final Configuration CONF = new Configuration(); + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private static SparkSession spark = null; + + @BeforeClass + public static void startSpark() { + TestAvroScan.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestAvroScan.spark; + TestAvroScan.spark = null; + currentSpark.stop(); + } + + @Override + protected void writeAndValidate(Schema schema) throws IOException { + File parent = temp.newFolder("avro"); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + dataFolder.mkdirs(); + + File avroFile = new File(dataFolder, + FileFormat.AVRO.addExtension(UUID.randomUUID().toString())); + + HadoopTables tables = new HadoopTables(CONF); + Table table = tables.create(schema, PartitionSpec.unpartitioned(), location.toString()); + + // Important: use the table's schema for the rest of the test + // When tables are created, the column ids are reassigned. + Schema tableSchema = table.schema(); + + List expected = RandomData.generateList(tableSchema, 100, 1L); + + try (FileAppender writer = Avro.write(localOutput(avroFile)) + .schema(tableSchema) + .build()) { + writer.addAll(expected); + } + + DataFile file = DataFiles.builder(PartitionSpec.unpartitioned()) + .withRecordCount(100) + .withFileSizeInBytes(avroFile.length()) + .withPath(avroFile.toString()) + .build(); + + table.newAppend().appendFile(file).commit(); + + Dataset df = spark.read() + .format("iceberg") + .load(location.toString()); + + List rows = df.collectAsList(); + Assert.assertEquals("Should contain 100 rows", 100, rows.size()); + + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe(tableSchema.asStruct(), expected.get(i), rows.get(i)); + } + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2.java new file mode 100644 index 000000000000..fdee5911994e --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkTestBaseWithCatalog; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class TestDataFrameWriterV2 extends SparkTestBaseWithCatalog { + @Before + public void createTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testMergeSchemaFailsWithoutWriterOption() throws Exception { + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s'='true')", tableName, TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA); + + Dataset twoColDF = jsonToDF( + "id bigint, data string", + "{ \"id\": 1, \"data\": \"a\" }", + "{ \"id\": 2, \"data\": \"b\" }"); + + twoColDF.writeTo(tableName).append(); + + assertEquals("Should have initial 2-column rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("select * from %s order by id", tableName)); + + Dataset threeColDF = jsonToDF( + "id bigint, data string, new_col float", + "{ \"id\": 3, \"data\": \"c\", \"new_col\": 12.06 }", + "{ \"id\": 4, \"data\": \"d\", \"new_col\": 14.41 }"); + + // this has a different error message than the case without accept-any-schema because it uses Iceberg checks + AssertHelpers.assertThrows("Should fail when merge-schema is not enabled on the writer", + IllegalArgumentException.class, "Field new_col not found in source schema", + () -> { + try { + threeColDF.writeTo(tableName).append(); + } catch (NoSuchTableException e) { + // needed because append has checked exceptions + throw new RuntimeException(e); + } + }); + } + + @Test + public void testMergeSchemaWithoutAcceptAnySchema() throws Exception { + Dataset twoColDF = jsonToDF( + "id bigint, data string", + "{ \"id\": 1, \"data\": \"a\" }", + "{ \"id\": 2, \"data\": \"b\" }"); + + twoColDF.writeTo(tableName).append(); + + assertEquals("Should have initial 2-column rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("select * from %s order by id", tableName)); + + Dataset threeColDF = jsonToDF( + "id bigint, data string, new_col float", + "{ \"id\": 3, \"data\": \"c\", \"new_col\": 12.06 }", + "{ \"id\": 4, \"data\": \"d\", \"new_col\": 14.41 }"); + + AssertHelpers.assertThrows("Should fail when accept-any-schema is not enabled on the table", + AnalysisException.class, "too many data columns", + () -> { + try { + threeColDF.writeTo(tableName).option("merge-schema", 
"true").append(); + } catch (NoSuchTableException e) { + // needed because append has checked exceptions + throw new RuntimeException(e); + } + }); + } + + @Test + public void testMergeSchemaSparkProperty() throws Exception { + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s'='true')", tableName, TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA); + + Dataset twoColDF = jsonToDF( + "id bigint, data string", + "{ \"id\": 1, \"data\": \"a\" }", + "{ \"id\": 2, \"data\": \"b\" }"); + + twoColDF.writeTo(tableName).append(); + + assertEquals("Should have initial 2-column rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("select * from %s order by id", tableName)); + + Dataset threeColDF = jsonToDF( + "id bigint, data string, new_col float", + "{ \"id\": 3, \"data\": \"c\", \"new_col\": 12.06 }", + "{ \"id\": 4, \"data\": \"d\", \"new_col\": 14.41 }"); + + threeColDF.writeTo(tableName).option("mergeSchema", "true").append(); + + assertEquals("Should have 3-column rows", + ImmutableList.of(row(1L, "a", null), row(2L, "b", null), row(3L, "c", 12.06F), row(4L, "d", 14.41F)), + sql("select * from %s order by id", tableName)); + } + + @Test + public void testMergeSchemaIcebergProperty() throws Exception { + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s'='true')", tableName, TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA); + + Dataset twoColDF = jsonToDF( + "id bigint, data string", + "{ \"id\": 1, \"data\": \"a\" }", + "{ \"id\": 2, \"data\": \"b\" }"); + + twoColDF.writeTo(tableName).append(); + + assertEquals("Should have initial 2-column rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("select * from %s order by id", tableName)); + + Dataset threeColDF = jsonToDF( + "id bigint, data string, new_col float", + "{ \"id\": 3, \"data\": \"c\", \"new_col\": 12.06 }", + "{ \"id\": 4, \"data\": \"d\", \"new_col\": 14.41 }"); + + threeColDF.writeTo(tableName).option("merge-schema", "true").append(); + + assertEquals("Should have 3-column rows", + ImmutableList.of(row(1L, "a", null), row(2L, "b", null), row(3L, "c", 12.06F), row(4L, "d", 14.41F)), + sql("select * from %s order by id", tableName)); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWrites.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWrites.java new file mode 100644 index 000000000000..f6d292f89f8b --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWrites.java @@ -0,0 +1,385 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.io.IOException; +import java.net.URI; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Random; +import org.apache.avro.generic.GenericData.Record; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.Files; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.avro.AvroIterable; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.data.AvroDataTest; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkAvroReader; +import org.apache.iceberg.types.Types; +import org.apache.spark.SparkException; +import org.apache.spark.TaskContext; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.MapPartitionsFunction; +import org.apache.spark.sql.DataFrameWriter; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.encoders.RowEncoder; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.apache.iceberg.spark.SparkSchemaUtil.convert; +import static org.apache.iceberg.spark.data.TestHelpers.assertEqualsSafe; +import static org.apache.iceberg.spark.data.TestHelpers.assertEqualsUnsafe; + +@RunWith(Parameterized.class) +public class TestDataFrameWrites extends AvroDataTest { + private static final Configuration CONF = new Configuration(); + + private final String format; + + @Parameterized.Parameters(name = "format = {0}") + public static Object[] parameters() { + return new Object[] { "parquet", "avro", "orc" }; + } + + public TestDataFrameWrites(String format) { + this.format = format; + } + + private static SparkSession spark = null; + private static JavaSparkContext sc = null; + + private Map tableProperties; + + private org.apache.spark.sql.types.StructType sparkSchema = new org.apache.spark.sql.types.StructType( + new org.apache.spark.sql.types.StructField[] { + new org.apache.spark.sql.types.StructField( + "optionalField", + org.apache.spark.sql.types.DataTypes.StringType, + true, + org.apache.spark.sql.types.Metadata.empty()), + new org.apache.spark.sql.types.StructField( + "requiredField", + org.apache.spark.sql.types.DataTypes.StringType, + false, + org.apache.spark.sql.types.Metadata.empty()) + }); + + private Schema icebergSchema = new Schema( + Types.NestedField.optional(1, "optionalField", Types.StringType.get()), + Types.NestedField.required(2, "requiredField", Types.StringType.get())); + + private List data0 = Arrays.asList( + "{\"optionalField\": \"a1\", 
\"requiredField\": \"bid_001\"}", + "{\"optionalField\": \"a2\", \"requiredField\": \"bid_002\"}"); + private List data1 = Arrays.asList( + "{\"optionalField\": \"d1\", \"requiredField\": \"bid_101\"}", + "{\"optionalField\": \"d2\", \"requiredField\": \"bid_102\"}", + "{\"optionalField\": \"d3\", \"requiredField\": \"bid_103\"}", + "{\"optionalField\": \"d4\", \"requiredField\": \"bid_104\"}"); + + @BeforeClass + public static void startSpark() { + TestDataFrameWrites.spark = SparkSession.builder().master("local[2]").getOrCreate(); + TestDataFrameWrites.sc = JavaSparkContext.fromSparkContext(spark.sparkContext()); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestDataFrameWrites.spark; + TestDataFrameWrites.spark = null; + TestDataFrameWrites.sc = null; + currentSpark.stop(); + } + + @Override + protected void writeAndValidate(Schema schema) throws IOException { + File location = createTableFolder(); + Table table = createTable(schema, location); + writeAndValidateWithLocations(table, location, new File(location, "data")); + } + + @Test + public void testWriteWithCustomDataLocation() throws IOException { + File location = createTableFolder(); + File tablePropertyDataLocation = temp.newFolder("test-table-property-data-dir"); + Table table = createTable(new Schema(SUPPORTED_PRIMITIVES.fields()), location); + table.updateProperties().set( + TableProperties.WRITE_DATA_LOCATION, tablePropertyDataLocation.getAbsolutePath()).commit(); + writeAndValidateWithLocations(table, location, tablePropertyDataLocation); + } + + private File createTableFolder() throws IOException { + File parent = temp.newFolder("parquet"); + File location = new File(parent, "test"); + Assert.assertTrue("Mkdir should succeed", location.mkdirs()); + return location; + } + + private Table createTable(Schema schema, File location) { + HadoopTables tables = new HadoopTables(CONF); + return tables.create(schema, PartitionSpec.unpartitioned(), location.toString()); + } + + private void writeAndValidateWithLocations(Table table, File location, File expectedDataDir) throws IOException { + Schema tableSchema = table.schema(); // use the table schema because ids are reassigned + + table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit(); + + Iterable expected = RandomData.generate(tableSchema, 100, 0L); + writeData(expected, tableSchema, location.toString()); + + table.refresh(); + + List actual = readTable(location.toString()); + + Iterator expectedIter = expected.iterator(); + Iterator actualIter = actual.iterator(); + while (expectedIter.hasNext() && actualIter.hasNext()) { + assertEqualsSafe(tableSchema.asStruct(), expectedIter.next(), actualIter.next()); + } + Assert.assertEquals("Both iterators should be exhausted", expectedIter.hasNext(), actualIter.hasNext()); + + table.currentSnapshot().addedFiles(table.io()).forEach(dataFile -> + Assert.assertTrue( + String.format( + "File should have the parent directory %s, but has: %s.", + expectedDataDir.getAbsolutePath(), + dataFile.path()), + URI.create(dataFile.path().toString()).getPath().startsWith(expectedDataDir.getAbsolutePath()))); + } + + private List readTable(String location) { + Dataset result = spark.read() + .format("iceberg") + .load(location); + + return result.collectAsList(); + } + + private void writeData(Iterable records, Schema schema, String location) throws IOException { + Dataset df = createDataset(records, schema); + DataFrameWriter writer = df.write().format("iceberg").mode("append"); + 
writer.save(location); + } + + private void writeDataWithFailOnPartition(Iterable records, Schema schema, String location) + throws IOException, SparkException { + final int numPartitions = 10; + final int partitionToFail = new Random().nextInt(numPartitions); + MapPartitionsFunction failOnFirstPartitionFunc = (MapPartitionsFunction) input -> { + int partitionId = TaskContext.getPartitionId(); + + if (partitionId == partitionToFail) { + throw new SparkException(String.format("Intended exception in partition %d !", partitionId)); + } + return input; + }; + + Dataset df = createDataset(records, schema) + .repartition(numPartitions) + .mapPartitions(failOnFirstPartitionFunc, RowEncoder.apply(convert(schema))); + // This trick is needed because Spark 3 handles decimal overflow in RowEncoder which "changes" + // nullability of the column to "true" regardless of original nullability. + // Setting "check-nullability" option to "false" doesn't help as it fails at Spark analyzer. + Dataset convertedDf = df.sqlContext().createDataFrame(df.rdd(), convert(schema)); + DataFrameWriter writer = convertedDf.write().format("iceberg").mode("append"); + writer.save(location); + } + + private Dataset createDataset(Iterable records, Schema schema) throws IOException { + // this uses the SparkAvroReader to create a DataFrame from the list of records + // it assumes that SparkAvroReader is correct + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (FileAppender writer = Avro.write(Files.localOutput(testFile)) + .schema(schema) + .named("test") + .build()) { + for (Record rec : records) { + writer.add(rec); + } + } + + // make sure the dataframe matches the records before moving on + List rows = Lists.newArrayList(); + try (AvroIterable reader = Avro.read(Files.localInput(testFile)) + .createReaderFunc(SparkAvroReader::new) + .project(schema) + .build()) { + + Iterator recordIter = records.iterator(); + Iterator readIter = reader.iterator(); + while (recordIter.hasNext() && readIter.hasNext()) { + InternalRow row = readIter.next(); + assertEqualsUnsafe(schema.asStruct(), recordIter.next(), row); + rows.add(row); + } + Assert.assertEquals("Both iterators should be exhausted", recordIter.hasNext(), readIter.hasNext()); + } + + JavaRDD rdd = sc.parallelize(rows); + return spark.internalCreateDataFrame(JavaRDD.toRDD(rdd), convert(schema), false); + } + + @Test + public void testNullableWithWriteOption() throws IOException { + Assume.assumeTrue("Spark 3.0 rejects writing nulls to a required column", spark.version().startsWith("2")); + + File location = new File(temp.newFolder("parquet"), "test"); + String sourcePath = String.format("%s/nullable_poc/sourceFolder/", location.toString()); + String targetPath = String.format("%s/nullable_poc/targetFolder/", location.toString()); + + tableProperties = ImmutableMap.of(TableProperties.WRITE_DATA_LOCATION, targetPath); + + // read this and append to iceberg dataset + spark + .read().schema(sparkSchema).json( + JavaSparkContext.fromSparkContext(spark.sparkContext()).parallelize(data1)) + .write().parquet(sourcePath); + + // this is our iceberg dataset to which we will append data + new HadoopTables(spark.sessionState().newHadoopConf()) + .create( + icebergSchema, + PartitionSpec.builderFor(icebergSchema).identity("requiredField").build(), + tableProperties, + targetPath); + + // this is the initial data inside the iceberg dataset + spark + .read().schema(sparkSchema).json( + 
JavaSparkContext.fromSparkContext(spark.sparkContext()).parallelize(data0)) + .write().format("iceberg").mode(SaveMode.Append).save(targetPath); + + // read from parquet and append to iceberg w/ nullability check disabled + spark + .read().schema(SparkSchemaUtil.convert(icebergSchema)).parquet(sourcePath) + .write().format("iceberg").option(SparkWriteOptions.CHECK_NULLABILITY, false) + .mode(SaveMode.Append).save(targetPath); + + // read all data + List rows = spark.read().format("iceberg").load(targetPath).collectAsList(); + Assert.assertEquals("Should contain 6 rows", 6, rows.size()); + } + + @Test + public void testNullableWithSparkSqlOption() throws IOException { + Assume.assumeTrue("Spark 3.0 rejects writing nulls to a required column", spark.version().startsWith("2")); + + File location = new File(temp.newFolder("parquet"), "test"); + String sourcePath = String.format("%s/nullable_poc/sourceFolder/", location.toString()); + String targetPath = String.format("%s/nullable_poc/targetFolder/", location.toString()); + + tableProperties = ImmutableMap.of(TableProperties.WRITE_DATA_LOCATION, targetPath); + + // read this and append to iceberg dataset + spark + .read().schema(sparkSchema).json( + JavaSparkContext.fromSparkContext(spark.sparkContext()).parallelize(data1)) + .write().parquet(sourcePath); + + SparkSession newSparkSession = SparkSession.builder() + .master("local[2]") + .appName("NullableTest") + .config(SparkSQLProperties.CHECK_NULLABILITY, false) + .getOrCreate(); + + // this is our iceberg dataset to which we will append data + new HadoopTables(newSparkSession.sessionState().newHadoopConf()) + .create( + icebergSchema, + PartitionSpec.builderFor(icebergSchema).identity("requiredField").build(), + tableProperties, + targetPath); + + // this is the initial data inside the iceberg dataset + newSparkSession + .read().schema(sparkSchema).json( + JavaSparkContext.fromSparkContext(spark.sparkContext()).parallelize(data0)) + .write().format("iceberg").mode(SaveMode.Append).save(targetPath); + + // read from parquet and append to iceberg + newSparkSession + .read().schema(SparkSchemaUtil.convert(icebergSchema)).parquet(sourcePath) + .write().format("iceberg").mode(SaveMode.Append).save(targetPath); + + // read all data + List rows = newSparkSession.read().format("iceberg").load(targetPath).collectAsList(); + Assert.assertEquals("Should contain 6 rows", 6, rows.size()); + + } + + @Test + public void testFaultToleranceOnWrite() throws IOException { + File location = createTableFolder(); + Schema schema = new Schema(SUPPORTED_PRIMITIVES.fields()); + Table table = createTable(schema, location); + + Iterable records = RandomData.generate(schema, 100, 0L); + writeData(records, schema, location.toString()); + + table.refresh(); + + Snapshot snapshotBeforeFailingWrite = table.currentSnapshot(); + List resultBeforeFailingWrite = readTable(location.toString()); + + try { + Iterable records2 = RandomData.generate(schema, 100, 0L); + writeDataWithFailOnPartition(records2, schema, location.toString()); + Assert.fail("The query should fail"); + } catch (SparkException e) { + // no-op + } + + table.refresh(); + + Snapshot snapshotAfterFailingWrite = table.currentSnapshot(); + List resultAfterFailingWrite = readTable(location.toString()); + + Assert.assertEquals(snapshotAfterFailingWrite, snapshotBeforeFailingWrite); + Assert.assertEquals(resultAfterFailingWrite, resultBeforeFailingWrite); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java 
b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java new file mode 100644 index 000000000000..00a11c55a9d1 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java @@ -0,0 +1,448 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.math.RoundingMode; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.relocated.com.google.common.math.LongMath; +import org.apache.iceberg.spark.CommitMetadata; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.types.Types.NestedField.optional; + +public class TestDataSourceOptions { + + private static final Configuration CONF = new Configuration(); + private static final Schema SCHEMA = new Schema( + optional(1, "id", Types.IntegerType.get()), + optional(2, "data", Types.StringType.get()) + ); + private static SparkSession spark = null; + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + @BeforeClass + public static void startSpark() { + TestDataSourceOptions.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestDataSourceOptions.spark; + TestDataSourceOptions.spark = null; + currentSpark.stop(); + } + + @Test + public void testWriteFormatOptionOverridesTableProperties() throws IOException { + String 
tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + options.put(TableProperties.DEFAULT_FILE_FORMAT, "avro"); + Table table = tables.create(SCHEMA, spec, options, tableLocation); + + List expectedRecords = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + Dataset df = spark.createDataFrame(expectedRecords, SimpleRecord.class); + df.select("id", "data").write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, "parquet") + .mode(SaveMode.Append) + .save(tableLocation); + + try (CloseableIterable tasks = table.newScan().planFiles()) { + tasks.forEach(task -> { + FileFormat fileFormat = FileFormat.fromFileName(task.file().path()); + Assert.assertEquals(FileFormat.PARQUET, fileFormat); + }); + } + } + + @Test + public void testNoWriteFormatOption() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + options.put(TableProperties.DEFAULT_FILE_FORMAT, "avro"); + Table table = tables.create(SCHEMA, spec, options, tableLocation); + + List expectedRecords = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + Dataset df = spark.createDataFrame(expectedRecords, SimpleRecord.class); + df.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + try (CloseableIterable tasks = table.newScan().planFiles()) { + tasks.forEach(task -> { + FileFormat fileFormat = FileFormat.fromFileName(task.file().path()); + Assert.assertEquals(FileFormat.AVRO, fileFormat); + }); + } + } + + @Test + public void testHadoopOptions() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + Configuration sparkHadoopConf = spark.sessionState().newHadoopConf(); + String originalDefaultFS = sparkHadoopConf.get("fs.default.name"); + + try { + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + tables.create(SCHEMA, spec, options, tableLocation); + + // set an invalid value for 'fs.default.name' in Spark Hadoop config + // to verify that 'hadoop.' 
data source options are propagated correctly + sparkHadoopConf.set("fs.default.name", "hdfs://localhost:9000"); + + List expectedRecords = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b") + ); + Dataset originalDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + originalDf.select("id", "data").write() + .format("iceberg") + .mode("append") + .option("hadoop.fs.default.name", "file:///") + .save(tableLocation); + + Dataset resultDf = spark.read() + .format("iceberg") + .option("hadoop.fs.default.name", "file:///") + .load(tableLocation); + List resultRecords = resultDf.orderBy("id") + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + + Assert.assertEquals("Records should match", expectedRecords, resultRecords); + } finally { + sparkHadoopConf.set("fs.default.name", originalDefaultFS); + } + } + + @Test + public void testSplitOptionsOverridesTableProperties() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + options.put(TableProperties.SPLIT_SIZE, String.valueOf(128L * 1024 * 1024)); // 128Mb + options.put(TableProperties.DEFAULT_FILE_FORMAT, String.valueOf(FileFormat.AVRO)); // Arbitrarily splittable + Table icebergTable = tables.create(SCHEMA, spec, options, tableLocation); + + List expectedRecords = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b") + ); + Dataset originalDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + originalDf.select("id", "data") + .repartition(1) + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + List files = Lists.newArrayList(icebergTable.currentSnapshot().addedFiles(icebergTable.io())); + Assert.assertEquals("Should have written 1 file", 1, files.size()); + + long fileSize = files.get(0).fileSizeInBytes(); + long splitSize = LongMath.divide(fileSize, 2, RoundingMode.CEILING); + + Dataset resultDf = spark.read() + .format("iceberg") + .option(SparkReadOptions.SPLIT_SIZE, String.valueOf(splitSize)) + .load(tableLocation); + + Assert.assertEquals("Spark partitions should match", 2, resultDf.javaRDD().getNumPartitions()); + } + + @Test + public void testIncrementalScanOptions() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + Table table = tables.create(SCHEMA, spec, options, tableLocation); + + List expectedRecords = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "d") + ); + for (SimpleRecord record : expectedRecords) { + Dataset originalDf = spark.createDataFrame(Lists.newArrayList(record), SimpleRecord.class); + originalDf.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(tableLocation); + } + List snapshotIds = SnapshotUtil.currentAncestorIds(table); + + // start-snapshot-id and snapshot-id are both configured. 
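+ // snapshotIds is ordered newest-first: index 0 is the current snapshot and index 3 is the first snapshot written above, which the incremental scans below rely on.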
+ AssertHelpers.assertThrows( + "Check both start-snapshot-id and snapshot-id are configured", + IllegalArgumentException.class, + "Cannot set start-snapshot-id and end-snapshot-id for incremental scans", + () -> { + spark.read() + .format("iceberg") + .option("snapshot-id", snapshotIds.get(3).toString()) + .option("start-snapshot-id", snapshotIds.get(3).toString()) + .load(tableLocation).explain(); + }); + + // end-snapshot-id and as-of-timestamp are both configured. + AssertHelpers.assertThrows( + "Check both start-snapshot-id and snapshot-id are configured", + IllegalArgumentException.class, + "Cannot set start-snapshot-id and end-snapshot-id for incremental scans", + () -> { + spark.read() + .format("iceberg") + .option(SparkReadOptions.AS_OF_TIMESTAMP, + Long.toString(table.snapshot(snapshotIds.get(3)).timestampMillis())) + .option("end-snapshot-id", snapshotIds.get(2).toString()) + .load(tableLocation).explain(); + }); + + // only end-snapshot-id is configured. + AssertHelpers.assertThrows( + "Check both start-snapshot-id and snapshot-id are configured", + IllegalArgumentException.class, + "Cannot set only end-snapshot-id for incremental scans", + () -> { + spark.read() + .format("iceberg") + .option("end-snapshot-id", snapshotIds.get(2).toString()) + .load(tableLocation).explain(); + }); + + // test (1st snapshot, current snapshot] incremental scan. + List result = spark.read() + .format("iceberg") + .option("start-snapshot-id", snapshotIds.get(3).toString()) + .load(tableLocation) + .orderBy("id") + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + Assert.assertEquals("Records should match", expectedRecords.subList(1, 4), result); + + // test (2nd snapshot, 3rd snapshot] incremental scan. + List result1 = spark.read() + .format("iceberg") + .option("start-snapshot-id", snapshotIds.get(2).toString()) + .option("end-snapshot-id", snapshotIds.get(1).toString()) + .load(tableLocation) + .orderBy("id") + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + Assert.assertEquals("Records should match", expectedRecords.subList(2, 3), result1); + } + + @Test + public void testMetadataSplitSizeOptionOverrideTableProperties() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + Table table = tables.create(SCHEMA, spec, options, tableLocation); + + List expectedRecords = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b") + ); + Dataset originalDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + // produce 1st manifest + originalDf.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(tableLocation); + // produce 2nd manifest + originalDf.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + List manifests = table.currentSnapshot().allManifests(table.io()); + + Assert.assertEquals("Must be 2 manifests", 2, manifests.size()); + + // set the target metadata split size so each manifest ends up in a separate split + table.updateProperties() + .set(TableProperties.METADATA_SPLIT_SIZE, String.valueOf(manifests.get(0).length())) + .commit(); + + Dataset entriesDf = spark.read() + .format("iceberg") + .load(tableLocation + "#entries"); + Assert.assertEquals("Num partitions must match", 2, entriesDf.javaRDD().getNumPartitions()); + + // override the table property using options + entriesDf = spark.read() + 
.format("iceberg") + .option(SparkReadOptions.SPLIT_SIZE, String.valueOf(128 * 1024 * 1024)) + .load(tableLocation + "#entries"); + Assert.assertEquals("Num partitions must match", 1, entriesDf.javaRDD().getNumPartitions()); + } + + @Test + public void testDefaultMetadataSplitSize() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + Table icebergTable = tables.create(SCHEMA, spec, options, tableLocation); + + List expectedRecords = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b") + ); + Dataset originalDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + originalDf.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + int splitSize = (int) TableProperties.METADATA_SPLIT_SIZE_DEFAULT; // 32MB split size + + int expectedSplits = ((int) tables.load(tableLocation + "#entries") + .currentSnapshot().allManifests(icebergTable.io()).get(0).length() + splitSize - 1) / splitSize; + + Dataset metadataDf = spark.read() + .format("iceberg") + .load(tableLocation + "#entries"); + + int partitionNum = metadataDf.javaRDD().getNumPartitions(); + Assert.assertEquals("Spark partitions should match", expectedSplits, partitionNum); + } + + @Test + public void testExtraSnapshotMetadata() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + HadoopTables tables = new HadoopTables(CONF); + tables.create(SCHEMA, PartitionSpec.unpartitioned(), Maps.newHashMap(), tableLocation); + + List expectedRecords = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b") + ); + Dataset originalDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + originalDf.select("id", "data").write() + .format("iceberg") + .mode("append") + .option(SparkWriteOptions.SNAPSHOT_PROPERTY_PREFIX + ".extra-key", "someValue") + .option(SparkWriteOptions.SNAPSHOT_PROPERTY_PREFIX + ".another-key", "anotherValue") + .save(tableLocation); + + Table table = tables.load(tableLocation); + + Assert.assertTrue(table.currentSnapshot().summary().get("extra-key").equals("someValue")); + Assert.assertTrue(table.currentSnapshot().summary().get("another-key").equals("anotherValue")); + } + + @Test + public void testExtraSnapshotMetadataWithSQL() throws InterruptedException, IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + HadoopTables tables = new HadoopTables(CONF); + + Table table = tables.create(SCHEMA, PartitionSpec.unpartitioned(), Maps.newHashMap(), tableLocation); + + List expectedRecords = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b") + ); + Dataset originalDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + originalDf.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(tableLocation); + spark.read().format("iceberg").load(tableLocation).createOrReplaceTempView("target"); + Thread writerThread = new Thread(() -> { + Map properties = Maps.newHashMap(); + properties.put("writer-thread", String.valueOf(Thread.currentThread().getName())); + CommitMetadata.withCommitProperties(properties, () -> { + spark.sql("INSERT INTO target VALUES (3, 'c'), (4, 'd')"); + return 0; + }, RuntimeException.class); + }); + writerThread.setName("test-extra-commit-message-writer-thread"); + writerThread.start(); + writerThread.join(); + Set threadNames 
= Sets.newHashSet(); + for (Snapshot snapshot : table.snapshots()) { + threadNames.add(snapshot.summary().get("writer-thread")); + } + Assert.assertEquals(2, threadNames.size()); + Assert.assertTrue(threadNames.contains(null)); + Assert.assertTrue(threadNames.contains("test-extra-commit-message-writer-thread")); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java new file mode 100644 index 000000000000..d51fd3c4e8eb --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java @@ -0,0 +1,620 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.io.IOException; +import java.sql.Timestamp; +import java.time.OffsetDateTime; +import java.util.List; +import java.util.Locale; +import java.util.UUID; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.data.GenericAppenderFactory; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.data.GenericsHelpers; +import org.apache.iceberg.transforms.Transform; +import org.apache.iceberg.transforms.Transforms; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.api.java.UDF1; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.read.SupportsPushDownFilters; +import org.apache.spark.sql.sources.And; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.Not; +import 
org.apache.spark.sql.sources.StringStartsWith; +import org.apache.spark.sql.types.IntegerType$; +import org.apache.spark.sql.types.LongType$; +import org.apache.spark.sql.types.StringType$; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.assertj.core.api.Assertions; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.apache.iceberg.Files.localOutput; +import static org.apache.spark.sql.catalyst.util.DateTimeUtils.fromJavaTimestamp; +import static org.apache.spark.sql.functions.callUDF; +import static org.apache.spark.sql.functions.column; + +@RunWith(Parameterized.class) +public class TestFilteredScan { + private static final Configuration CONF = new Configuration(); + private static final HadoopTables TABLES = new HadoopTables(CONF); + + private static final Schema SCHEMA = new Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "ts", Types.TimestampType.withZone()), + Types.NestedField.optional(3, "data", Types.StringType.get()) + ); + + private static final PartitionSpec BUCKET_BY_ID = PartitionSpec.builderFor(SCHEMA) + .bucket("id", 4) + .build(); + + private static final PartitionSpec PARTITION_BY_DAY = PartitionSpec.builderFor(SCHEMA) + .day("ts") + .build(); + + private static final PartitionSpec PARTITION_BY_HOUR = PartitionSpec.builderFor(SCHEMA) + .hour("ts") + .build(); + + private static final PartitionSpec PARTITION_BY_DATA = PartitionSpec.builderFor(SCHEMA) + .identity("data") + .build(); + + private static final PartitionSpec PARTITION_BY_ID = PartitionSpec.builderFor(SCHEMA) + .identity("id") + .build(); + + private static SparkSession spark = null; + + @BeforeClass + public static void startSpark() { + TestFilteredScan.spark = SparkSession.builder().master("local[2]").getOrCreate(); + + // define UDFs used by partition tests + Transform bucket4 = Transforms.bucket(Types.LongType.get(), 4); + spark.udf().register("bucket4", (UDF1) bucket4::apply, IntegerType$.MODULE$); + + Transform day = Transforms.day(Types.TimestampType.withZone()); + spark.udf().register("ts_day", + (UDF1) timestamp -> day.apply((Long) fromJavaTimestamp(timestamp)), + IntegerType$.MODULE$); + + Transform hour = Transforms.hour(Types.TimestampType.withZone()); + spark.udf().register("ts_hour", + (UDF1) timestamp -> hour.apply((Long) fromJavaTimestamp(timestamp)), + IntegerType$.MODULE$); + + spark.udf().register("data_ident", (UDF1) data -> data, StringType$.MODULE$); + spark.udf().register("id_ident", (UDF1) id -> id, LongType$.MODULE$); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestFilteredScan.spark; + TestFilteredScan.spark = null; + currentSpark.stop(); + } + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private final String format; + private final boolean vectorized; + + @Parameterized.Parameters(name = "format = {0}, vectorized = {1}") + public static Object[][] parameters() { + return new Object[][] { + { "parquet", false }, + { "parquet", true }, + { "avro", false }, + { "orc", false }, + { "orc", true } + }; + } + + public TestFilteredScan(String format, boolean vectorized) { + this.format = format; + this.vectorized = vectorized; + } + + private File parent = null; + private File unpartitioned = null; + private List 
records = null; + + @Before + public void writeUnpartitionedTable() throws IOException { + this.parent = temp.newFolder("TestFilteredScan"); + this.unpartitioned = new File(parent, "unpartitioned"); + File dataFolder = new File(unpartitioned, "data"); + Assert.assertTrue("Mkdir should succeed", dataFolder.mkdirs()); + + Table table = TABLES.create(SCHEMA, PartitionSpec.unpartitioned(), unpartitioned.toString()); + Schema tableSchema = table.schema(); // use the table schema because ids are reassigned + + FileFormat fileFormat = FileFormat.valueOf(format.toUpperCase(Locale.ENGLISH)); + + File testFile = new File(dataFolder, fileFormat.addExtension(UUID.randomUUID().toString())); + + this.records = testRecords(tableSchema); + + try (FileAppender writer = new GenericAppenderFactory(tableSchema).newAppender( + localOutput(testFile), fileFormat)) { + writer.addAll(records); + } + + DataFile file = DataFiles.builder(PartitionSpec.unpartitioned()) + .withRecordCount(records.size()) + .withFileSizeInBytes(testFile.length()) + .withPath(testFile.toString()) + .build(); + + table.newAppend().appendFile(file).commit(); + } + + @Test + public void testUnpartitionedIDFilters() { + CaseInsensitiveStringMap options = new CaseInsensitiveStringMap(ImmutableMap.of( + "path", unpartitioned.toString()) + ); + SparkScanBuilder builder = new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + for (int i = 0; i < 10; i += 1) { + pushFilters(builder, EqualTo.apply("id", i)); + Batch scan = builder.build().toBatch(); + + InputPartition[] partitions = scan.planInputPartitions(); + Assert.assertEquals("Should only create one task for a small file", 1, partitions.length); + + // validate row filtering + assertEqualsSafe(SCHEMA.asStruct(), expected(i), + read(unpartitioned.toString(), vectorized, "id = " + i)); + } + } + + @Test + public void testUnpartitionedCaseInsensitiveIDFilters() { + CaseInsensitiveStringMap options = new CaseInsensitiveStringMap(ImmutableMap.of( + "path", unpartitioned.toString()) + ); + + // set spark.sql.caseSensitive to false + String caseSensitivityBeforeTest = TestFilteredScan.spark.conf().get("spark.sql.caseSensitive"); + TestFilteredScan.spark.conf().set("spark.sql.caseSensitive", "false"); + + try { + + for (int i = 0; i < 10; i += 1) { + SparkScanBuilder builder = new SparkScanBuilder(spark, TABLES.load(options.get("path")), options) + .caseSensitive(false); + + pushFilters(builder, EqualTo.apply("ID", i)); // note lower(ID) == lower(id), so there must be a match + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + Assert.assertEquals("Should only create one task for a small file", 1, tasks.length); + + // validate row filtering + assertEqualsSafe(SCHEMA.asStruct(), expected(i), + read(unpartitioned.toString(), vectorized, "id = " + i)); + } + } finally { + // return global conf to previous state + TestFilteredScan.spark.conf().set("spark.sql.caseSensitive", caseSensitivityBeforeTest); + } + } + + @Test + public void testUnpartitionedTimestampFilter() { + CaseInsensitiveStringMap options = new CaseInsensitiveStringMap(ImmutableMap.of( + "path", unpartitioned.toString()) + ); + + SparkScanBuilder builder = new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, LessThan.apply("ts", "2017-12-22T00:00:00+00:00")); + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + Assert.assertEquals("Should only create one task for a small file", 
1, tasks.length); + + assertEqualsSafe(SCHEMA.asStruct(), expected(5, 6, 7, 8, 9), + read(unpartitioned.toString(), vectorized, "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)")); + } + + @Test + public void testBucketPartitionedIDFilters() { + Table table = buildPartitionedTable("bucketed_by_id", BUCKET_BY_ID, "bucket4", "id"); + CaseInsensitiveStringMap options = new CaseInsensitiveStringMap(ImmutableMap.of("path", table.location())); + + Batch unfiltered = new SparkScanBuilder(spark, TABLES.load(options.get("path")), options).build().toBatch(); + Assert.assertEquals("Unfiltered table should create 4 read tasks", + 4, unfiltered.planInputPartitions().length); + + for (int i = 0; i < 10; i += 1) { + SparkScanBuilder builder = new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, EqualTo.apply("id", i)); + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + + // validate predicate push-down + Assert.assertEquals("Should create one task for a single bucket", 1, tasks.length); + + // validate row filtering + assertEqualsSafe(SCHEMA.asStruct(), expected(i), read(table.location(), vectorized, "id = " + i)); + } + } + + @SuppressWarnings("checkstyle:AvoidNestedBlocks") + @Test + public void testDayPartitionedTimestampFilters() { + Table table = buildPartitionedTable("partitioned_by_day", PARTITION_BY_DAY, "ts_day", "ts"); + CaseInsensitiveStringMap options = new CaseInsensitiveStringMap(ImmutableMap.of("path", table.location())); + Batch unfiltered = new SparkScanBuilder(spark, TABLES.load(options.get("path")), options).build().toBatch(); + + Assert.assertEquals("Unfiltered table should create 2 read tasks", + 2, unfiltered.planInputPartitions().length); + + { + SparkScanBuilder builder = new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, LessThan.apply("ts", "2017-12-22T00:00:00+00:00")); + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + Assert.assertEquals("Should create one task for 2017-12-21", 1, tasks.length); + + assertEqualsSafe(SCHEMA.asStruct(), expected(5, 6, 7, 8, 9), + read(table.location(), vectorized, "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)")); + } + + { + SparkScanBuilder builder = new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, And.apply( + GreaterThan.apply("ts", "2017-12-22T06:00:00+00:00"), + LessThan.apply("ts", "2017-12-22T08:00:00+00:00"))); + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + Assert.assertEquals("Should create one task for 2017-12-22", 1, tasks.length); + + assertEqualsSafe(SCHEMA.asStruct(), expected(1, 2), read(table.location(), vectorized, + "ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " + + "ts < cast('2017-12-22 08:00:00+00:00' as timestamp)")); + } + } + + @SuppressWarnings("checkstyle:AvoidNestedBlocks") + @Test + public void testHourPartitionedTimestampFilters() { + Table table = buildPartitionedTable("partitioned_by_hour", PARTITION_BY_HOUR, "ts_hour", "ts"); + + CaseInsensitiveStringMap options = new CaseInsensitiveStringMap(ImmutableMap.of("path", table.location())); + Batch unfiltered = new SparkScanBuilder(spark, TABLES.load(options.get("path")), options).build().toBatch(); + + Assert.assertEquals("Unfiltered table should create 9 read tasks", + 9, unfiltered.planInputPartitions().length); + + { + SparkScanBuilder builder = new
SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, LessThan.apply("ts", "2017-12-22T00:00:00+00:00")); + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + Assert.assertEquals("Should create 4 tasks for 2017-12-21: 15, 17, 21, 22", 4, tasks.length); + + assertEqualsSafe(SCHEMA.asStruct(), expected(8, 9, 7, 6, 5), + read(table.location(), vectorized, "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)")); + } + + { + SparkScanBuilder builder = new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, And.apply( + GreaterThan.apply("ts", "2017-12-22T06:00:00+00:00"), + LessThan.apply("ts", "2017-12-22T08:00:00+00:00"))); + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + Assert.assertEquals("Should create 2 tasks for 2017-12-22: 6, 7", 2, tasks.length); + + assertEqualsSafe(SCHEMA.asStruct(), expected(2, 1), read(table.location(), vectorized, + "ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " + + "ts < cast('2017-12-22 08:00:00+00:00' as timestamp)")); + } + } + + @SuppressWarnings("checkstyle:AvoidNestedBlocks") + @Test + public void testFilterByNonProjectedColumn() { + { + Schema actualProjection = SCHEMA.select("id", "data"); + List expected = Lists.newArrayList(); + for (Record rec : expected(5, 6, 7, 8, 9)) { + expected.add(projectFlat(actualProjection, rec)); + } + + assertEqualsSafe(actualProjection.asStruct(), expected, read( + unpartitioned.toString(), vectorized, + "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)", + "id", "data")); + } + + { + // only project id: ts will be projected because of the filter, but data will not be included + + Schema actualProjection = SCHEMA.select("id"); + List expected = Lists.newArrayList(); + for (Record rec : expected(1, 2)) { + expected.add(projectFlat(actualProjection, rec)); + } + + assertEqualsSafe(actualProjection.asStruct(), expected, read( + unpartitioned.toString(), vectorized, + "ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " + + "ts < cast('2017-12-22 08:00:00+00:00' as timestamp)", + "id")); + } + } + + @Test + public void testPartitionedByDataStartsWithFilter() { + Table table = buildPartitionedTable("partitioned_by_data", PARTITION_BY_DATA, "data_ident", "data"); + CaseInsensitiveStringMap options = new CaseInsensitiveStringMap(ImmutableMap.of("path", table.location())); + + SparkScanBuilder builder = new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, new StringStartsWith("data", "junc")); + Batch scan = builder.build().toBatch(); + + Assert.assertEquals(1, scan.planInputPartitions().length); + } + + @Test + public void testPartitionedByDataNotStartsWithFilter() { + Table table = buildPartitionedTable("partitioned_by_data", PARTITION_BY_DATA, "data_ident", "data"); + CaseInsensitiveStringMap options = new CaseInsensitiveStringMap(ImmutableMap.of("path", table.location())); + + SparkScanBuilder builder = new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, new Not(new StringStartsWith("data", "junc"))); + Batch scan = builder.build().toBatch(); + + Assert.assertEquals(9, scan.planInputPartitions().length); + } + + @Test + public void testPartitionedByIdStartsWith() { + Table table = buildPartitionedTable("partitioned_by_id", PARTITION_BY_ID, "id_ident", "id"); + + CaseInsensitiveStringMap options = new CaseInsensitiveStringMap(ImmutableMap.of( 
+ "path", table.location()) + ); + + SparkScanBuilder builder = new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, new StringStartsWith("data", "junc")); + Batch scan = builder.build().toBatch(); + + Assert.assertEquals(1, scan.planInputPartitions().length); + } + + @Test + public void testPartitionedByIdNotStartsWith() { + Table table = buildPartitionedTable("partitioned_by_id", PARTITION_BY_ID, "id_ident", "id"); + + CaseInsensitiveStringMap options = new CaseInsensitiveStringMap(ImmutableMap.of( + "path", table.location()) + ); + + SparkScanBuilder builder = new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, new Not(new StringStartsWith("data", "junc"))); + Batch scan = builder.build().toBatch(); + + Assert.assertEquals(9, scan.planInputPartitions().length); + } + + @Test + public void testUnpartitionedStartsWith() { + Dataset df = spark.read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(unpartitioned.toString()); + + List matchedData = df.select("data") + .where("data LIKE 'jun%'") + .as(Encoders.STRING()) + .collectAsList(); + + Assert.assertEquals(1, matchedData.size()); + Assert.assertEquals("junction", matchedData.get(0)); + } + + @Test + public void testUnpartitionedNotStartsWith() { + Dataset df = spark.read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(unpartitioned.toString()); + + List matchedData = df.select("data") + .where("data NOT LIKE 'jun%'") + .as(Encoders.STRING()) + .collectAsList(); + + List expected = testRecords(SCHEMA).stream() + .map(r -> r.getField("data").toString()) + .filter(d -> !d.startsWith("jun")) + .collect(Collectors.toList()); + + Assert.assertEquals(9, matchedData.size()); + Assert.assertEquals(Sets.newHashSet(expected), Sets.newHashSet(matchedData)); + } + + private static Record projectFlat(Schema projection, Record record) { + Record result = GenericRecord.create(projection); + List fields = projection.asStruct().fields(); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + result.set(i, record.getField(field.name())); + } + return result; + } + + public static void assertEqualsUnsafe(Types.StructType struct, + List expected, List actual) { + // TODO: match records by ID + int numRecords = Math.min(expected.size(), actual.size()); + for (int i = 0; i < numRecords; i += 1) { + GenericsHelpers.assertEqualsUnsafe(struct, expected.get(i), actual.get(i)); + } + Assert.assertEquals("Number of results should match expected", expected.size(), actual.size()); + } + + public static void assertEqualsSafe(Types.StructType struct, + List expected, List actual) { + // TODO: match records by ID + int numRecords = Math.min(expected.size(), actual.size()); + for (int i = 0; i < numRecords; i += 1) { + GenericsHelpers.assertEqualsSafe(struct, expected.get(i), actual.get(i)); + } + Assert.assertEquals("Number of results should match expected", expected.size(), actual.size()); + } + + private List expected(int... ordinals) { + List expected = Lists.newArrayListWithExpectedSize(ordinals.length); + for (int ord : ordinals) { + expected.add(records.get(ord)); + } + return expected; + } + + private void pushFilters(ScanBuilder scan, Filter... 
filters) { + Assertions.assertThat(scan).isInstanceOf(SupportsPushDownFilters.class); + SupportsPushDownFilters filterable = (SupportsPushDownFilters) scan; + filterable.pushFilters(filters); + } + + private Table buildPartitionedTable(String desc, PartitionSpec spec, String udf, String partitionColumn) { + File location = new File(parent, desc); + Table table = TABLES.create(SCHEMA, spec, location.toString()); + + // Do not combine or split files because the tests expect a split per partition. + // A target split size of 2048 helps us achieve that. + table.updateProperties().set("read.split.target-size", "2048").commit(); + + // copy the unpartitioned table into the partitioned table to produce the partitioned data + Dataset allRows = spark.read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(unpartitioned.toString()); + + allRows + .coalesce(1) // ensure only 1 file per partition is written + .withColumn("part", callUDF(udf, column(partitionColumn))) + .sortWithinPartitions("part") + .drop("part") + .write() + .format("iceberg") + .mode("append") + .save(table.location()); + + table.refresh(); + + return table; + } + + private List testRecords(Schema schema) { + return Lists.newArrayList( + record(schema, 0L, parse("2017-12-22T09:20:44.294658+00:00"), "junction"), + record(schema, 1L, parse("2017-12-22T07:15:34.582910+00:00"), "alligator"), + record(schema, 2L, parse("2017-12-22T06:02:09.243857+00:00"), ""), + record(schema, 3L, parse("2017-12-22T03:10:11.134509+00:00"), "clapping"), + record(schema, 4L, parse("2017-12-22T00:34:00.184671+00:00"), "brush"), + record(schema, 5L, parse("2017-12-21T22:20:08.935889+00:00"), "trap"), + record(schema, 6L, parse("2017-12-21T21:55:30.589712+00:00"), "element"), + record(schema, 7L, parse("2017-12-21T17:31:14.532797+00:00"), "limited"), + record(schema, 8L, parse("2017-12-21T15:21:51.237521+00:00"), "global"), + record(schema, 9L, parse("2017-12-21T15:02:15.230570+00:00"), "goldfish") + ); + } + + private static List read(String table, boolean vectorized, String expr) { + return read(table, vectorized, expr, "*"); + } + + private static List read(String table, boolean vectorized, String expr, String select0, String... selectN) { + Dataset dataset = spark.read().format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(table).filter(expr) + .select(select0, selectN); + return dataset.collectAsList(); + } + + private static OffsetDateTime parse(String timestamp) { + return OffsetDateTime.parse(timestamp); + } + + private static Record record(Schema schema, Object... values) { + Record rec = GenericRecord.create(schema); + for (int i = 0; i < values.length; i += 1) { + rec.set(i, values[i]); + } + return rec; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestForwardCompatibility.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestForwardCompatibility.java new file mode 100644 index 000000000000..74203fd20f9c --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestForwardCompatibility.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.TimeoutException; +import org.apache.avro.generic.GenericData; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.ManifestWriter; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.PartitionSpecParser; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.execution.streaming.MemoryStream; +import org.apache.spark.sql.streaming.StreamingQuery; +import org.apache.spark.sql.streaming.StreamingQueryException; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import scala.Option; +import scala.collection.JavaConverters; + +import static org.apache.iceberg.Files.localInput; +import static org.apache.iceberg.Files.localOutput; + +public class TestForwardCompatibility { + private static final Configuration CONF = new Configuration(); + + private static final Schema SCHEMA = new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get())); + + // create a spec for the schema that uses a "zero" transform that produces all 0s + private static final PartitionSpec UNKNOWN_SPEC = PartitionSpecParser.fromJson(SCHEMA, + "{ \"spec-id\": 0, \"fields\": [ { \"name\": \"id_zero\", \"transform\": \"zero\", \"source-id\": 1 } ] }"); + // create a fake spec to use to write table metadata + private static final PartitionSpec FAKE_SPEC = PartitionSpecParser.fromJson(SCHEMA, + "{ \"spec-id\": 0, \"fields\": [ { \"name\": \"id_zero\", \"transform\": \"identity\", \"source-id\": 1 } ] }"); + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private static SparkSession spark = null; + + @BeforeClass + public static void startSpark() { + TestForwardCompatibility.spark = 
SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestForwardCompatibility.spark; + TestForwardCompatibility.spark = null; + currentSpark.stop(); + } + + @Test + public void testSparkWriteFailsUnknownTransform() throws IOException { + File parent = temp.newFolder("avro"); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + dataFolder.mkdirs(); + + HadoopTables tables = new HadoopTables(CONF); + tables.create(SCHEMA, UNKNOWN_SPEC, location.toString()); + + List expected = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + AssertHelpers.assertThrows("Should reject write with unsupported transform", + UnsupportedOperationException.class, "Cannot write using unsupported transforms: zero", + () -> df.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(location.toString())); + } + + @Test + public void testSparkStreamingWriteFailsUnknownTransform() throws IOException, TimeoutException { + File parent = temp.newFolder("avro"); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + dataFolder.mkdirs(); + File checkpoint = new File(parent, "checkpoint"); + checkpoint.mkdirs(); + + HadoopTables tables = new HadoopTables(CONF); + tables.create(SCHEMA, UNKNOWN_SPEC, location.toString()); + + MemoryStream inputStream = newMemoryStream(1, spark.sqlContext(), Encoders.INT()); + StreamingQuery query = inputStream.toDF() + .selectExpr("value AS id", "CAST (value AS STRING) AS data") + .writeStream() + .outputMode("append") + .format("iceberg") + .option("checkpointLocation", checkpoint.toString()) + .option("path", location.toString()) + .start(); + + List batch1 = Lists.newArrayList(1, 2); + send(batch1, inputStream); + + AssertHelpers.assertThrows("Should reject streaming write with unsupported transform", + StreamingQueryException.class, "Cannot write using unsupported transforms: zero", + query::processAllAvailable); + } + + @Test + public void testSparkCanReadUnknownTransform() throws IOException { + File parent = temp.newFolder("avro"); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + dataFolder.mkdirs(); + + HadoopTables tables = new HadoopTables(CONF); + Table table = tables.create(SCHEMA, UNKNOWN_SPEC, location.toString()); + + // enable snapshot inheritance to avoid rewriting the manifest with an unknown transform + table.updateProperties().set(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, "true").commit(); + + List expected = RandomData.generateList(table.schema(), 100, 1L); + + File parquetFile = new File(dataFolder, + FileFormat.PARQUET.addExtension(UUID.randomUUID().toString())); + FileAppender writer = Parquet.write(localOutput(parquetFile)) + .schema(table.schema()) + .build(); + try { + writer.addAll(expected); + } finally { + writer.close(); + } + + DataFile file = DataFiles.builder(FAKE_SPEC) + .withInputFile(localInput(parquetFile)) + .withMetrics(writer.metrics()) + .withPartitionPath("id_zero=0") + .build(); + + OutputFile manifestFile = localOutput(FileFormat.AVRO.addExtension(temp.newFile().toString())); + ManifestWriter manifestWriter = ManifestFiles.write(FAKE_SPEC, manifestFile); + try { + manifestWriter.add(file); + } finally { + manifestWriter.close(); + } + + 
table.newFastAppend().appendManifest(manifestWriter.toManifestFile()).commit(); + + Dataset df = spark.read() + .format("iceberg") + .load(location.toString()); + + List rows = df.collectAsList(); + Assert.assertEquals("Should contain 100 rows", 100, rows.size()); + + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe(table.schema().asStruct(), expected.get(i), rows.get(i)); + } + } + + private MemoryStream newMemoryStream(int id, SQLContext sqlContext, Encoder encoder) { + return new MemoryStream<>(id, sqlContext, Option.empty(), encoder); + } + + private void send(List records, MemoryStream stream) { + stream.addData(JavaConverters.asScalaBuffer(records)); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSource.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSource.java new file mode 100644 index 000000000000..42f53d585601 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSource.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +public class TestIcebergSource extends IcebergSource { + @Override + public String shortName() { + return "iceberg-test"; + } + + @Override + public Identifier extractIdentifier(CaseInsensitiveStringMap options) { + TableIdentifier ti = TableIdentifier.parse(options.get("iceberg.table.name")); + return Identifier.of(ti.namespace().levels(), ti.name()); + } + + @Override + public String extractCatalog(CaseInsensitiveStringMap options) { + return SparkSession.active().sessionState().catalogManager().currentCatalog().name(); + } + +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceHadoopTables.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceHadoopTables.java new file mode 100644 index 000000000000..f1cfc7a72e17 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceHadoopTables.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.io.File; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.hadoop.HadoopTables; +import org.junit.Before; + +public class TestIcebergSourceHadoopTables extends TestIcebergSourceTablesBase { + + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + + File tableDir = null; + String tableLocation = null; + + @Before + public void setupTable() throws Exception { + this.tableDir = temp.newFolder(); + tableDir.delete(); // created by table create + + this.tableLocation = tableDir.toURI().toString(); + } + + @Override + public Table createTable(TableIdentifier ident, Schema schema, PartitionSpec spec) { + if (spec.equals(PartitionSpec.unpartitioned())) { + return TABLES.create(schema, tableLocation); + } + return TABLES.create(schema, spec, tableLocation); + } + + @Override + public Table loadTable(TableIdentifier ident, String entriesSuffix) { + return TABLES.load(loadLocation(ident, entriesSuffix)); + } + + @Override + public String loadLocation(TableIdentifier ident, String entriesSuffix) { + return String.format("%s#%s", loadLocation(ident), entriesSuffix); + } + + @Override + public String loadLocation(TableIdentifier ident) { + return tableLocation; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceHiveTables.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceHiveTables.java new file mode 100644 index 000000000000..a51f9ee85e2f --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceHiveTables.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.junit.After; +import org.junit.BeforeClass; + +public class TestIcebergSourceHiveTables extends TestIcebergSourceTablesBase { + + private static TableIdentifier currentIdentifier; + + @BeforeClass + public static void start() { + Namespace db = Namespace.of("db"); + if (!catalog.namespaceExists(db)) { + catalog.createNamespace(db); + } + } + + @After + public void dropTable() throws IOException { + Table table = catalog.loadTable(currentIdentifier); + Path tablePath = new Path(table.location()); + FileSystem fs = tablePath.getFileSystem(spark.sessionState().newHadoopConf()); + fs.delete(tablePath, true); + catalog.dropTable(currentIdentifier, false); + } + + @Override + public Table createTable(TableIdentifier ident, Schema schema, PartitionSpec spec) { + TestIcebergSourceHiveTables.currentIdentifier = ident; + return TestIcebergSourceHiveTables.catalog.createTable(ident, schema, spec); + } + + @Override + public Table loadTable(TableIdentifier ident, String entriesSuffix) { + TableIdentifier identifier = TableIdentifier.of(ident.namespace().level(0), ident.name(), entriesSuffix); + return TestIcebergSourceHiveTables.catalog.loadTable(identifier); + } + + @Override + public String loadLocation(TableIdentifier ident, String entriesSuffix) { + return String.format("%s.%s", loadLocation(ident), entriesSuffix); + } + + @Override + public String loadLocation(TableIdentifier ident) { + return ident.toString(); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceTablesBase.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceTablesBase.java new file mode 100644 index 000000000000..d055c70cea1f --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceTablesBase.java @@ -0,0 +1,1544 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericRecordBuilder; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileContent; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.DeleteOrphanFiles; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.deletes.PositionDeleteWriter; +import org.apache.iceberg.encryption.EncryptedOutputFile; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.types.Types; +import org.apache.spark.SparkException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.ManifestContent.DATA; +import static org.apache.iceberg.ManifestContent.DELETES; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +public abstract class TestIcebergSourceTablesBase extends SparkTestBase { + + private static final Schema SCHEMA = new Schema( + optional(1, "id", Types.IntegerType.get()), + optional(2, "data", Types.StringType.get()) + ); + + private static final Schema SCHEMA2 = new Schema( + optional(1, "id", Types.IntegerType.get()), + optional(2, "data", Types.StringType.get()), + optional(3, "category", Types.StringType.get()) + ); + + private static final Schema SCHEMA3 = new Schema( + optional(1, "id", Types.IntegerType.get()), + optional(3, "category", Types.StringType.get()) + ); + + private static final PartitionSpec SPEC = PartitionSpec.builderFor(SCHEMA).identity("id").build(); + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + public abstract Table createTable(TableIdentifier ident, Schema schema, PartitionSpec spec); + + public abstract Table loadTable(TableIdentifier ident, String 
entriesSuffix); + + public abstract String loadLocation(TableIdentifier ident, String entriesSuffix); + + public abstract String loadLocation(TableIdentifier ident); + + @Test + public synchronized void testTablesSupport() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "table"); + createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List expectedRecords = Lists.newArrayList( + new SimpleRecord(1, "1"), + new SimpleRecord(2, "2"), + new SimpleRecord(3, "3")); + + Dataset inputDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + inputDf.select("id", "data").write() + .format("iceberg") + .mode(SaveMode.Append) + .save(loadLocation(tableIdentifier)); + + Dataset resultDf = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier)); + List actualRecords = resultDf.orderBy("id") + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + + Assert.assertEquals("Records should match", expectedRecords, actualRecords); + } + + @Test + public void testEntriesTable() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + Table entriesTable = loadTable(tableIdentifier, "entries"); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + inputDf.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + List actual = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "entries")) + .collectAsList(); + + Snapshot snapshot = table.currentSnapshot(); + + Assert.assertEquals("Should only contain one manifest", 1, snapshot.allManifests(table.io()).size()); + + InputFile manifest = table.io().newInputFile(snapshot.allManifests(table.io()).get(0).path()); + List expected = Lists.newArrayList(); + try (CloseableIterable rows = Avro.read(manifest).project(entriesTable.schema()).build()) { + // each row must inherit snapshot_id and sequence_number + rows.forEach(row -> { + row.put(2, 0L); + GenericData.Record file = (GenericData.Record) row.get("data_file"); + asMetadataRecord(file); + expected.add(row); + }); + } + + Assert.assertEquals("Entries table should have one row", 1, expected.size()); + Assert.assertEquals("Actual results should have one row", 1, actual.size()); + TestHelpers.assertEqualsSafe(entriesTable.schema().asStruct(), expected.get(0), actual.get(0)); + } + + @Test + public void testEntriesTablePartitionedPrune() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + inputDf.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + List actual = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "entries")) + .select("status") + .collectAsList(); + + Assert.assertEquals("Results should contain only one status", 1, actual.size()); + Assert.assertEquals("That status should be Added (1)", 1, actual.get(0).getInt(0)); + } + + @Test + public void testEntriesTableDataFilePrune() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_test"); + Table table = 
createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + inputDf.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + DataFile file = table.currentSnapshot().addedFiles(table.io()).iterator().next(); + + List singleActual = rowsToJava(spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "entries")) + .select("data_file.file_path") + .collectAsList()); + + List singleExpected = ImmutableList.of(row(file.path())); + + assertEquals("Should prune a single element from a nested struct", singleExpected, singleActual); + } + + @Test + public void testEntriesTableDataFilePruneMulti() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + inputDf.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + DataFile file = table.currentSnapshot().addedFiles(table.io()).iterator().next(); + + List multiActual = rowsToJava(spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "entries")) + .select("data_file.file_path", "data_file.value_counts", "data_file.record_count", "data_file.column_sizes") + .collectAsList()); + + List multiExpected = ImmutableList.of( + row(file.path(), file.valueCounts(), file.recordCount(), file.columnSizes())); + + assertEquals("Should prune a single element from a nested struct", multiExpected, multiActual); + } + + @Test + public void testFilesSelectMap() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + inputDf.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + DataFile file = table.currentSnapshot().addedFiles(table.io()).iterator().next(); + + List multiActual = rowsToJava(spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "files")) + .select("file_path", "value_counts", "record_count", "column_sizes") + .collectAsList()); + + List multiExpected = ImmutableList.of( + row(file.path(), file.valueCounts(), file.recordCount(), file.columnSizes())); + + assertEquals("Should prune a single element from a row", multiExpected, multiActual); + } + + @Test + public void testAllEntriesTable() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + Table entriesTable = loadTable(tableIdentifier, "all_entries"); + + Dataset df1 = spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + Dataset df2 = spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "b")), SimpleRecord.class); + + df1.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // delete the first file to test that not only live files 
are listed + table.newDelete().deleteFromRowFilter(Expressions.equal("id", 1)).commit(); + + // add a second file + df2.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // ensure table data isn't stale + table.refresh(); + + List actual = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "all_entries")) + .orderBy("snapshot_id") + .collectAsList(); + + List expected = Lists.newArrayList(); + for (ManifestFile manifest : Iterables.concat( + Iterables.transform(table.snapshots(), s -> s.allManifests(table.io())))) { + InputFile in = table.io().newInputFile(manifest.path()); + try (CloseableIterable rows = Avro.read(in).project(entriesTable.schema()).build()) { + // each row must inherit snapshot_id and sequence_number + rows.forEach(row -> { + row.put(2, 0L); + GenericData.Record file = (GenericData.Record) row.get("data_file"); + asMetadataRecord(file); + expected.add(row); + }); + } + } + + expected.sort(Comparator.comparing(o -> (Long) o.get("snapshot_id"))); + + Assert.assertEquals("Entries table should have 3 rows", 3, expected.size()); + Assert.assertEquals("Actual results should have 3 rows", 3, actual.size()); + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe(entriesTable.schema().asStruct(), expected.get(i), actual.get(i)); + } + } + + @Test + public void testCountEntriesTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "count_entries_test"); + createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + // init load + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + inputDf.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + final int expectedEntryCount = 1; + + // count entries + Assert.assertEquals("Count should return " + expectedEntryCount, + expectedEntryCount, spark.read().format("iceberg").load(loadLocation(tableIdentifier, "entries")).count()); + + // count all_entries + Assert.assertEquals("Count should return " + expectedEntryCount, + expectedEntryCount, spark.read().format("iceberg").load(loadLocation(tableIdentifier, "all_entries")).count()); + } + + @Test + public void testFilesTable() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "files_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + Table entriesTable = loadTable(tableIdentifier, "entries"); + Table filesTable = loadTable(tableIdentifier, "files"); + + Dataset df1 = spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + Dataset df2 = spark.createDataFrame(Lists.newArrayList(new SimpleRecord(2, "b")), SimpleRecord.class); + + df1.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // add a second file + df2.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // delete the first file to test that only live files are listed + table.newDelete().deleteFromRowFilter(Expressions.equal("id", 1)).commit(); + + List actual = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "files")) + .collectAsList(); + + List expected = Lists.newArrayList(); + for (ManifestFile manifest : table.currentSnapshot().dataManifests(table.io())) { + InputFile in = table.io().newInputFile(manifest.path()); + try (CloseableIterable 
rows = Avro.read(in).project(entriesTable.schema()).build()) { + for (GenericData.Record record : rows) { + if ((Integer) record.get("status") < 2 /* added or existing */) { + GenericData.Record file = (GenericData.Record) record.get("data_file"); + asMetadataRecord(file); + expected.add(file); + } + } + } + } + + Assert.assertEquals("Files table should have one row", 1, expected.size()); + Assert.assertEquals("Actual results should have one row", 1, actual.size()); + TestHelpers.assertEqualsSafe(filesTable.schema().asStruct(), expected.get(0), actual.get(0)); + } + + @Test + public void testFilesTableWithSnapshotIdInheritance() throws Exception { + spark.sql("DROP TABLE IF EXISTS parquet_table"); + + TableIdentifier tableIdentifier = TableIdentifier.of("db", "files_inheritance_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + table.updateProperties() + .set(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, "true") + .commit(); + Table entriesTable = loadTable(tableIdentifier, "entries"); + Table filesTable = loadTable(tableIdentifier, "files"); + + spark.sql(String.format( + "CREATE TABLE parquet_table (data string, id int) " + + "USING parquet PARTITIONED BY (id) LOCATION '%s'", + temp.newFolder())); + + List<SimpleRecord> records = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b") + ); + + Dataset<Row> inputDF = spark.createDataFrame(records, SimpleRecord.class); + inputDF.select("data", "id").write() + .mode("overwrite") + .insertInto("parquet_table"); + + try { + String stagingLocation = table.location() + "/metadata"; + SparkTableUtil.importSparkTable(spark, + new org.apache.spark.sql.catalyst.TableIdentifier("parquet_table"), + table, stagingLocation); + + List<Row> actual = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "files")) + .collectAsList(); + + List<GenericData.Record> expected = Lists.newArrayList(); + for (ManifestFile manifest : table.currentSnapshot().dataManifests(table.io())) { + InputFile in = table.io().newInputFile(manifest.path()); + try (CloseableIterable<GenericData.Record> rows = Avro.read(in).project(entriesTable.schema()).build()) { + for (GenericData.Record record : rows) { + GenericData.Record file = (GenericData.Record) record.get("data_file"); + asMetadataRecord(file); + expected.add(file); + } + } + } + + Assert.assertEquals("Files table should have two rows", 2, expected.size()); + Assert.assertEquals("Actual results should have two rows", 2, actual.size()); + TestHelpers.assertEqualsSafe(filesTable.schema().asStruct(), expected.get(0), actual.get(0)); + TestHelpers.assertEqualsSafe(filesTable.schema().asStruct(), expected.get(1), actual.get(1)); + } finally { + spark.sql("DROP TABLE parquet_table"); + } + + } + + @Test + public void testEntriesTableWithSnapshotIdInheritance() throws Exception { + spark.sql("DROP TABLE IF EXISTS parquet_table"); + + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_inheritance_test"); + PartitionSpec spec = SPEC; + Table table = createTable(tableIdentifier, SCHEMA, spec); + + table.updateProperties() + .set(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, "true") + .commit(); + + spark.sql(String.format( + "CREATE TABLE parquet_table (data string, id int) " + + "USING parquet PARTITIONED BY (id) LOCATION '%s'", + temp.newFolder())); + + List<SimpleRecord> records = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b") + ); + + Dataset<Row> inputDF = spark.createDataFrame(records, SimpleRecord.class); + inputDF.select("data", "id").write() + .mode("overwrite") + .insertInto("parquet_table"); +
try { + String stagingLocation = table.location() + "/metadata"; + SparkTableUtil.importSparkTable( + spark, new org.apache.spark.sql.catalyst.TableIdentifier("parquet_table"), table, stagingLocation); + + List actual = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "entries")) + .select("sequence_number", "snapshot_id", "data_file") + .collectAsList(); + + table.refresh(); + + long snapshotId = table.currentSnapshot().snapshotId(); + + Assert.assertEquals("Entries table should have 2 rows", 2, actual.size()); + Assert.assertEquals("Sequence number must match", 0, actual.get(0).getLong(0)); + Assert.assertEquals("Snapshot id must match", snapshotId, actual.get(0).getLong(1)); + Assert.assertEquals("Sequence number must match", 0, actual.get(1).getLong(0)); + Assert.assertEquals("Snapshot id must match", snapshotId, actual.get(1).getLong(1)); + } finally { + spark.sql("DROP TABLE parquet_table"); + } + } + + @Test + public void testFilesUnpartitionedTable() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "unpartitioned_files_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + Table entriesTable = loadTable(tableIdentifier, "entries"); + Table filesTable = loadTable(tableIdentifier, "files"); + + Dataset df1 = spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + Dataset df2 = spark.createDataFrame(Lists.newArrayList(new SimpleRecord(2, "b")), SimpleRecord.class); + + df1.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + DataFile toDelete = Iterables.getOnlyElement(table.currentSnapshot().addedFiles(table.io())); + + // add a second file + df2.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // delete the first file to test that only live files are listed + table.newDelete().deleteFile(toDelete).commit(); + + List actual = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "files")) + .collectAsList(); + + List expected = Lists.newArrayList(); + for (ManifestFile manifest : table.currentSnapshot().dataManifests(table.io())) { + InputFile in = table.io().newInputFile(manifest.path()); + try (CloseableIterable rows = Avro.read(in).project(entriesTable.schema()).build()) { + for (GenericData.Record record : rows) { + if ((Integer) record.get("status") < 2 /* added or existing */) { + GenericData.Record file = (GenericData.Record) record.get("data_file"); + asMetadataRecord(file); + expected.add(file); + } + } + } + } + + Assert.assertEquals("Files table should have one row", 1, expected.size()); + Assert.assertEquals("Actual results should have one row", 1, actual.size()); + TestHelpers.assertEqualsSafe(filesTable.schema().asStruct(), expected.get(0), actual.get(0)); + } + + @Test + public void testAllMetadataTablesWithStagedCommits() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "stage_aggregate_table_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + + table.updateProperties().set(TableProperties.WRITE_AUDIT_PUBLISH_ENABLED, "true").commit(); + spark.conf().set("spark.wap.id", "1234567"); + Dataset df1 = spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + Dataset df2 = spark.createDataFrame(Lists.newArrayList(new SimpleRecord(2, "b")), SimpleRecord.class); + + df1.select("id", "data").write() + 
.format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // add a second file + df2.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + List actualAllData = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "all_data_files")) + .collectAsList(); + + List actualAllManifests = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "all_manifests")) + .collectAsList(); + + List actualAllEntries = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "all_entries")) + .collectAsList(); + + Assert.assertTrue("Stage table should have some snapshots", table.snapshots().iterator().hasNext()); + Assert.assertEquals("Stage table should have null currentSnapshot", + null, table.currentSnapshot()); + Assert.assertEquals("Actual results should have two rows", 2, actualAllData.size()); + Assert.assertEquals("Actual results should have two rows", 2, actualAllManifests.size()); + Assert.assertEquals("Actual results should have two rows", 2, actualAllEntries.size()); + } + + @Test + public void testAllDataFilesTable() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "files_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + Table entriesTable = loadTable(tableIdentifier, "entries"); + Table filesTable = loadTable(tableIdentifier, "all_data_files"); + + Dataset df1 = spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + Dataset df2 = spark.createDataFrame(Lists.newArrayList(new SimpleRecord(2, "b")), SimpleRecord.class); + + df1.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // delete the first file to test that not only live files are listed + table.newDelete().deleteFromRowFilter(Expressions.equal("id", 1)).commit(); + + // add a second file + df2.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // ensure table data isn't stale + table.refresh(); + + List actual = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "all_data_files")) + .orderBy("file_path") + .collectAsList(); + actual.sort(Comparator.comparing(o -> o.getString(1))); + + List expected = Lists.newArrayList(); + Iterable dataManifests = Iterables.concat(Iterables.transform(table.snapshots(), + snapshot -> snapshot.dataManifests(table.io()))); + for (ManifestFile manifest : dataManifests) { + InputFile in = table.io().newInputFile(manifest.path()); + try (CloseableIterable rows = Avro.read(in).project(entriesTable.schema()).build()) { + for (GenericData.Record record : rows) { + if ((Integer) record.get("status") < 2 /* added or existing */) { + GenericData.Record file = (GenericData.Record) record.get("data_file"); + asMetadataRecord(file); + expected.add(file); + } + } + } + } + + expected.sort(Comparator.comparing(o -> o.get("file_path").toString())); + + Assert.assertEquals("Files table should have two rows", 2, expected.size()); + Assert.assertEquals("Actual results should have two rows", 2, actual.size()); + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe(filesTable.schema().asStruct(), expected.get(i), actual.get(i)); + } + } + + @Test + public void testHistoryTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "history_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + Table 
historyTable = loadTable(tableIdentifier, "history"); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + + inputDf.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long firstSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + long firstSnapshotId = table.currentSnapshot().snapshotId(); + + inputDf.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long secondSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + long secondSnapshotId = table.currentSnapshot().snapshotId(); + + // rollback the table state to the first snapshot + table.rollback().toSnapshotId(firstSnapshotId).commit(); + long rollbackTimestamp = Iterables.getLast(table.history()).timestampMillis(); + + inputDf.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long thirdSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + long thirdSnapshotId = table.currentSnapshot().snapshotId(); + + List actual = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "history")) + .collectAsList(); + + GenericRecordBuilder builder = new GenericRecordBuilder(AvroSchemaUtil.convert(historyTable.schema(), "history")); + List expected = Lists.newArrayList( + builder.set("made_current_at", firstSnapshotTimestamp * 1000) + .set("snapshot_id", firstSnapshotId) + .set("parent_id", null) + .set("is_current_ancestor", true) + .build(), + builder.set("made_current_at", secondSnapshotTimestamp * 1000) + .set("snapshot_id", secondSnapshotId) + .set("parent_id", firstSnapshotId) + .set("is_current_ancestor", false) // commit rolled back, not an ancestor of the current table state + .build(), + builder.set("made_current_at", rollbackTimestamp * 1000) + .set("snapshot_id", firstSnapshotId) + .set("parent_id", null) + .set("is_current_ancestor", true) + .build(), + builder.set("made_current_at", thirdSnapshotTimestamp * 1000) + .set("snapshot_id", thirdSnapshotId) + .set("parent_id", firstSnapshotId) + .set("is_current_ancestor", true) + .build() + ); + + Assert.assertEquals("History table should have a row for each commit", 4, actual.size()); + TestHelpers.assertEqualsSafe(historyTable.schema().asStruct(), expected.get(0), actual.get(0)); + TestHelpers.assertEqualsSafe(historyTable.schema().asStruct(), expected.get(1), actual.get(1)); + TestHelpers.assertEqualsSafe(historyTable.schema().asStruct(), expected.get(2), actual.get(2)); + TestHelpers.assertEqualsSafe(historyTable.schema().asStruct(), expected.get(3), actual.get(3)); + } + + @Test + public void testSnapshotsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "snapshots_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + Table snapTable = loadTable(tableIdentifier, "snapshots"); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + + inputDf.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long firstSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + long firstSnapshotId = table.currentSnapshot().snapshotId(); + String firstManifestList = 
table.currentSnapshot().manifestListLocation(); + + table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit(); + + long secondSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + long secondSnapshotId = table.currentSnapshot().snapshotId(); + String secondManifestList = table.currentSnapshot().manifestListLocation(); + + // rollback the table state to the first snapshot + table.rollback().toSnapshotId(firstSnapshotId).commit(); + + List actual = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "snapshots")) + .collectAsList(); + + GenericRecordBuilder builder = new GenericRecordBuilder(AvroSchemaUtil.convert(snapTable.schema(), "snapshots")); + List expected = Lists.newArrayList( + builder.set("committed_at", firstSnapshotTimestamp * 1000) + .set("snapshot_id", firstSnapshotId) + .set("parent_id", null) + .set("operation", "append") + .set("manifest_list", firstManifestList) + .set("summary", ImmutableMap.of( + "added-records", "1", + "added-data-files", "1", + "changed-partition-count", "1", + "total-data-files", "1", + "total-records", "1" + )) + .build(), + builder.set("committed_at", secondSnapshotTimestamp * 1000) + .set("snapshot_id", secondSnapshotId) + .set("parent_id", firstSnapshotId) + .set("operation", "delete") + .set("manifest_list", secondManifestList) + .set("summary", ImmutableMap.of( + "deleted-records", "1", + "deleted-data-files", "1", + "changed-partition-count", "1", + "total-records", "0", + "total-data-files", "0" + )) + .build() + ); + + Assert.assertEquals("Snapshots table should have a row for each snapshot", 2, actual.size()); + TestHelpers.assertEqualsSafe(snapTable.schema().asStruct(), expected.get(0), actual.get(0)); + TestHelpers.assertEqualsSafe(snapTable.schema().asStruct(), expected.get(1), actual.get(1)); + } + + @Test + public void testPrunedSnapshotsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "snapshots_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + + inputDf.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long firstSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + long firstSnapshotId = table.currentSnapshot().snapshotId(); + + table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit(); + + long secondSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + + // rollback the table state to the first snapshot + table.rollback().toSnapshotId(firstSnapshotId).commit(); + + Dataset actualDf = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "snapshots")) + .select("operation", "committed_at", "summary", "parent_id"); + + Schema projectedSchema = SparkSchemaUtil.convert(actualDf.schema()); + + List actual = actualDf.collectAsList(); + + GenericRecordBuilder builder = new GenericRecordBuilder(AvroSchemaUtil.convert(projectedSchema, "snapshots")); + List expected = Lists.newArrayList( + builder.set("committed_at", firstSnapshotTimestamp * 1000) + .set("parent_id", null) + .set("operation", "append") + .set("summary", ImmutableMap.of( + "added-records", "1", + "added-data-files", "1", + "changed-partition-count", "1", + "total-data-files", "1", + "total-records", "1" + )) + .build(), + builder.set("committed_at", secondSnapshotTimestamp * 1000) + 
.set("parent_id", firstSnapshotId) + .set("operation", "delete") + .set("summary", ImmutableMap.of( + "deleted-records", "1", + "deleted-data-files", "1", + "changed-partition-count", "1", + "total-records", "0", + "total-data-files", "0" + )) + .build() + ); + + Assert.assertEquals("Snapshots table should have a row for each snapshot", 2, actual.size()); + TestHelpers.assertEqualsSafe(projectedSchema.asStruct(), expected.get(0), actual.get(0)); + TestHelpers.assertEqualsSafe(projectedSchema.asStruct(), expected.get(1), actual.get(1)); + } + + @Test + public void testManifestsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "manifests_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + Table manifestTable = loadTable(tableIdentifier, "manifests"); + Dataset df1 = spark.createDataFrame( + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(null, "b")), SimpleRecord.class); + + df1.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.updateProperties() + .set(TableProperties.FORMAT_VERSION, "2") + .commit(); + + DataFile dataFile = Iterables.getFirst(table.currentSnapshot().addedFiles(table.io()), null); + PartitionSpec dataFileSpec = table.specs().get(dataFile.specId()); + StructLike dataFilePartition = dataFile.partition(); + + PositionDelete delete = PositionDelete.create(); + delete.set(dataFile.path(), 0L, null); + + DeleteFile deleteFile = writePositionDeletes(table, dataFileSpec, dataFilePartition, ImmutableList.of(delete)); + + table.newRowDelta() + .addDeletes(deleteFile) + .commit(); + + List actual = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "manifests")) + .collectAsList(); + + table.refresh(); + + GenericRecordBuilder builder = new GenericRecordBuilder(AvroSchemaUtil.convert( + manifestTable.schema(), "manifests")); + GenericRecordBuilder summaryBuilder = new GenericRecordBuilder(AvroSchemaUtil.convert( + manifestTable.schema().findType("partition_summaries.element").asStructType(), "partition_summary")); + List expected = Lists.transform(table.currentSnapshot().allManifests(table.io()), manifest -> + builder + .set("content", manifest.content().id()) + .set("path", manifest.path()) + .set("length", manifest.length()) + .set("partition_spec_id", manifest.partitionSpecId()) + .set("added_snapshot_id", manifest.snapshotId()) + .set("added_data_files_count", manifest.content() == DATA ? manifest.addedFilesCount() : 0) + .set("existing_data_files_count", manifest.content() == DATA ? manifest.existingFilesCount() : 0) + .set("deleted_data_files_count", manifest.content() == DATA ? manifest.deletedFilesCount() : 0) + .set("added_delete_files_count", manifest.content() == DELETES ? manifest.addedFilesCount() : 0) + .set("existing_delete_files_count", manifest.content() == DELETES ? manifest.existingFilesCount() : 0) + .set("deleted_delete_files_count", manifest.content() == DELETES ? 
manifest.deletedFilesCount() : 0) + .set("partition_summaries", Lists.transform(manifest.partitions(), partition -> + summaryBuilder + .set("contains_null", manifest.content() == DATA) + .set("contains_nan", false) + .set("lower_bound", "1") + .set("upper_bound", "1") + .build() + )) + .build() + ); + + Assert.assertEquals("Manifests table should have two manifest rows", 2, actual.size()); + TestHelpers.assertEqualsSafe(manifestTable.schema().asStruct(), expected.get(0), actual.get(0)); + TestHelpers.assertEqualsSafe(manifestTable.schema().asStruct(), expected.get(1), actual.get(1)); + } + + @Test + public void testPruneManifestsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "manifests_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + Table manifestTable = loadTable(tableIdentifier, "manifests"); + Dataset df1 = spark.createDataFrame( + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(null, "b")), SimpleRecord.class); + + df1.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + if (!spark.version().startsWith("2")) { + // Spark 2 isn't able to actually push down nested struct projections so this will not break + AssertHelpers.assertThrows("Can't prune struct inside list", SparkException.class, + "Cannot project a partial list element struct", + () -> spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "manifests")) + .select("partition_spec_id", "path", "partition_summaries.contains_null") + .collectAsList()); + } + + Dataset actualDf = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "manifests")) + .select("partition_spec_id", "path", "partition_summaries"); + + Schema projectedSchema = SparkSchemaUtil.convert(actualDf.schema()); + + List actual = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "manifests")) + .select("partition_spec_id", "path", "partition_summaries") + .collectAsList(); + + table.refresh(); + + GenericRecordBuilder builder = new GenericRecordBuilder(AvroSchemaUtil.convert(projectedSchema.asStruct())); + GenericRecordBuilder summaryBuilder = new GenericRecordBuilder(AvroSchemaUtil.convert( + projectedSchema.findType("partition_summaries.element").asStructType(), "partition_summary")); + List expected = Lists.transform(table.currentSnapshot().allManifests(table.io()), manifest -> + builder.set("partition_spec_id", manifest.partitionSpecId()) + .set("path", manifest.path()) + .set("partition_summaries", Lists.transform(manifest.partitions(), partition -> + summaryBuilder + .set("contains_null", true) + .set("contains_nan", false) + .set("lower_bound", "1") + .set("upper_bound", "1") + .build() + )) + .build() + ); + + Assert.assertEquals("Manifests table should have one manifest row", 1, actual.size()); + TestHelpers.assertEqualsSafe(projectedSchema.asStruct(), expected.get(0), actual.get(0)); + } + + @Test + public void testAllManifestsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "manifests_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + Table manifestTable = loadTable(tableIdentifier, "all_manifests"); + Dataset df1 = spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + + List manifests = Lists.newArrayList(); + + df1.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.updateProperties() + .set(TableProperties.FORMAT_VERSION, "2") + 
.commit(); + + DataFile dataFile = Iterables.getFirst(table.currentSnapshot().addedFiles(table.io()), null); + PartitionSpec dataFileSpec = table.specs().get(dataFile.specId()); + StructLike dataFilePartition = dataFile.partition(); + + PositionDelete delete = PositionDelete.create(); + delete.set(dataFile.path(), 0L, null); + + DeleteFile deleteFile = writePositionDeletes(table, dataFileSpec, dataFilePartition, ImmutableList.of(delete)); + + table.newRowDelta() + .addDeletes(deleteFile) + .commit(); + + manifests.addAll(table.currentSnapshot().allManifests(table.io())); + + table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit(); + + manifests.addAll(table.currentSnapshot().allManifests(table.io())); + + List actual = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "all_manifests")) + .dropDuplicates("path") + .orderBy("path") + .collectAsList(); + + table.refresh(); + + GenericRecordBuilder builder = new GenericRecordBuilder(AvroSchemaUtil.convert( + manifestTable.schema(), "manifests")); + GenericRecordBuilder summaryBuilder = new GenericRecordBuilder(AvroSchemaUtil.convert( + manifestTable.schema().findType("partition_summaries.element").asStructType(), "partition_summary")); + List expected = Lists.newArrayList(Iterables.transform(manifests, manifest -> + builder + .set("content", manifest.content().id()) + .set("path", manifest.path()) + .set("length", manifest.length()) + .set("partition_spec_id", manifest.partitionSpecId()) + .set("added_snapshot_id", manifest.snapshotId()) + .set("added_data_files_count", manifest.content() == DATA ? manifest.addedFilesCount() : 0) + .set("existing_data_files_count", manifest.content() == DATA ? manifest.existingFilesCount() : 0) + .set("deleted_data_files_count", manifest.content() == DATA ? manifest.deletedFilesCount() : 0) + .set("added_delete_files_count", manifest.content() == DELETES ? manifest.addedFilesCount() : 0) + .set("existing_delete_files_count", manifest.content() == DELETES ? manifest.existingFilesCount() : 0) + .set("deleted_delete_files_count", manifest.content() == DELETES ? 
manifest.deletedFilesCount() : 0) + .set("partition_summaries", Lists.transform(manifest.partitions(), partition -> + summaryBuilder + .set("contains_null", false) + .set("contains_nan", false) + .set("lower_bound", "1") + .set("upper_bound", "1") + .build() + )) + .build() + )); + + expected.sort(Comparator.comparing(o -> o.get("path").toString())); + + Assert.assertEquals("Manifests table should have 4 manifest rows", 4, actual.size()); + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe(manifestTable.schema().asStruct(), expected.get(i), actual.get(i)); + } + } + + @Test + public void testUnpartitionedPartitionsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "unpartitioned_partitions_test"); + createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + Dataset df = spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + + df.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + Types.StructType expectedSchema = Types.StructType.of( + required(2, "record_count", Types.LongType.get()), + required(3, "file_count", Types.IntegerType.get())); + + Table partitionsTable = loadTable(tableIdentifier, "partitions"); + + Assert.assertEquals("Schema should not have partition field", + expectedSchema, partitionsTable.schema().asStruct()); + + GenericRecordBuilder builder = new GenericRecordBuilder(AvroSchemaUtil.convert( + partitionsTable.schema(), "partitions")); + GenericData.Record expectedRow = builder + .set("record_count", 1L) + .set("file_count", 1) + .build(); + + List actual = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "partitions")) + .collectAsList(); + + Assert.assertEquals("Unpartitioned partitions table should have one row", 1, actual.size()); + TestHelpers.assertEqualsSafe(expectedSchema, expectedRow, actual.get(0)); + } + + @Test + public void testPartitionsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "partitions_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + Table partitionsTable = loadTable(tableIdentifier, "partitions"); + Dataset df1 = spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + Dataset df2 = spark.createDataFrame(Lists.newArrayList(new SimpleRecord(2, "b")), SimpleRecord.class); + + df1.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long firstCommitId = table.currentSnapshot().snapshotId(); + + // add a second file + df2.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + List actual = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "partitions")) + .orderBy("partition.id") + .collectAsList(); + + GenericRecordBuilder builder = new GenericRecordBuilder(AvroSchemaUtil.convert( + partitionsTable.schema(), "partitions")); + GenericRecordBuilder partitionBuilder = new GenericRecordBuilder(AvroSchemaUtil.convert( + partitionsTable.schema().findType("partition").asStructType(), "partition")); + List expected = Lists.newArrayList(); + expected.add(builder + .set("partition", partitionBuilder.set("id", 1).build()) + .set("record_count", 1L) + .set("file_count", 1) + .set("spec_id", 0) + .build()); + expected.add(builder + .set("partition", partitionBuilder.set("id", 2).build()) + .set("record_count", 1L) + .set("file_count", 1) + .set("spec_id", 
0)
+        .build());
+
+    Assert.assertEquals("Partitions table should have two rows", 2, expected.size());
+    Assert.assertEquals("Actual results should have two rows", 2, actual.size());
+    for (int i = 0; i < 2; i += 1) {
+      TestHelpers.assertEqualsSafe(partitionsTable.schema().asStruct(), expected.get(i), actual.get(i));
+    }
+
+    // check time travel
+    List<Row> actualAfterFirstCommit = spark.read()
+        .format("iceberg")
+        .option(SparkReadOptions.SNAPSHOT_ID, String.valueOf(firstCommitId))
+        .load(loadLocation(tableIdentifier, "partitions"))
+        .orderBy("partition.id")
+        .collectAsList();
+
+    Assert.assertEquals("Actual results should have one row", 1, actualAfterFirstCommit.size());
+    TestHelpers.assertEqualsSafe(partitionsTable.schema().asStruct(), expected.get(0), actualAfterFirstCommit.get(0));
+
+    // check predicate push down
+    List<Row> filtered = spark.read()
+        .format("iceberg")
+        .load(loadLocation(tableIdentifier, "partitions"))
+        .filter("partition.id < 2")
+        .collectAsList();
+    Assert.assertEquals("Actual results should have one row", 1, filtered.size());
+    TestHelpers.assertEqualsSafe(partitionsTable.schema().asStruct(), expected.get(0), filtered.get(0));
+
+    List<Row> nonFiltered = spark.read()
+        .format("iceberg")
+        .load(loadLocation(tableIdentifier, "partitions"))
+        .filter("partition.id < 2 or record_count=1")
+        .collectAsList();
+    Assert.assertEquals("Actual results should have two rows", 2, nonFiltered.size());
+    for (int i = 0; i < 2; i += 1) {
+      TestHelpers.assertEqualsSafe(partitionsTable.schema().asStruct(), expected.get(i), actual.get(i));
+    }
+  }
+
+  @Test
+  public synchronized void testSnapshotReadAfterAddColumn() {
+    TableIdentifier tableIdentifier = TableIdentifier.of("db", "table");
+    Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned());
+
+    List<Row> originalRecords = Lists.newArrayList(
+        RowFactory.create(1, "x"),
+        RowFactory.create(2, "y"),
+        RowFactory.create(3, "z"));
+
+    StructType originalSparkSchema = SparkSchemaUtil.convert(SCHEMA);
+    Dataset<Row> inputDf = spark.createDataFrame(originalRecords, originalSparkSchema);
+    inputDf.select("id", "data").write()
+        .format("iceberg")
+        .mode(SaveMode.Append)
+        .save(loadLocation(tableIdentifier));
+
+    table.refresh();
+
+    Dataset<Row> resultDf = spark.read()
+        .format("iceberg")
+        .load(loadLocation(tableIdentifier));
+    Assert.assertEquals("Records should match", originalRecords,
+        resultDf.orderBy("id").collectAsList());
+
+    Snapshot snapshotBeforeAddColumn = table.currentSnapshot();
+
+    table.updateSchema().addColumn("category", Types.StringType.get()).commit();
+
+    List<Row> newRecords = Lists.newArrayList(
+        RowFactory.create(4, "xy", "B"),
+        RowFactory.create(5, "xyz", "C"));
+
+    StructType newSparkSchema = SparkSchemaUtil.convert(SCHEMA2);
+    Dataset<Row> inputDf2 = spark.createDataFrame(newRecords, newSparkSchema);
+    inputDf2.select("id", "data", "category").write()
+        .format("iceberg")
+        .mode(SaveMode.Append)
+        .save(loadLocation(tableIdentifier));
+
+    table.refresh();
+
+    List<Row> updatedRecords = Lists.newArrayList(
+        RowFactory.create(1, "x", null),
+        RowFactory.create(2, "y", null),
+        RowFactory.create(3, "z", null),
+        RowFactory.create(4, "xy", "B"),
+        RowFactory.create(5, "xyz", "C"));
+
+    Dataset<Row> resultDf2 = spark.read()
+        .format("iceberg")
+        .load(loadLocation(tableIdentifier));
+    Assert.assertEquals("Records should match", updatedRecords,
+        resultDf2.orderBy("id").collectAsList());
+
+    Dataset<Row> resultDf3 = spark.read()
+        .format("iceberg")
+        .option(SparkReadOptions.SNAPSHOT_ID,
snapshotBeforeAddColumn.snapshotId()) + .load(loadLocation(tableIdentifier)); + Assert.assertEquals("Records should match", originalRecords, + resultDf3.orderBy("id").collectAsList()); + Assert.assertEquals("Schemas should match", originalSparkSchema, resultDf3.schema()); + } + + @Test + public synchronized void testSnapshotReadAfterDropColumn() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "table"); + Table table = createTable(tableIdentifier, SCHEMA2, PartitionSpec.unpartitioned()); + + List originalRecords = Lists.newArrayList( + RowFactory.create(1, "x", "A"), + RowFactory.create(2, "y", "A"), + RowFactory.create(3, "z", "B")); + + StructType originalSparkSchema = SparkSchemaUtil.convert(SCHEMA2); + Dataset inputDf = spark.createDataFrame(originalRecords, originalSparkSchema); + inputDf.select("id", "data", "category").write() + .format("iceberg") + .mode(SaveMode.Append) + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + Dataset resultDf = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier)); + Assert.assertEquals("Records should match", originalRecords, + resultDf.orderBy("id").collectAsList()); + + long tsBeforeDropColumn = waitUntilAfter(System.currentTimeMillis()); + table.updateSchema().deleteColumn("data").commit(); + long tsAfterDropColumn = waitUntilAfter(System.currentTimeMillis()); + + List newRecords = Lists.newArrayList( + RowFactory.create(4, "B"), + RowFactory.create(5, "C")); + + StructType newSparkSchema = SparkSchemaUtil.convert(SCHEMA3); + Dataset inputDf2 = spark.createDataFrame(newRecords, newSparkSchema); + inputDf2.select("id", "category").write() + .format("iceberg") + .mode(SaveMode.Append) + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + List updatedRecords = Lists.newArrayList( + RowFactory.create(1, "A"), + RowFactory.create(2, "A"), + RowFactory.create(3, "B"), + RowFactory.create(4, "B"), + RowFactory.create(5, "C")); + + Dataset resultDf2 = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier)); + Assert.assertEquals("Records should match", updatedRecords, + resultDf2.orderBy("id").collectAsList()); + + Dataset resultDf3 = spark.read() + .format("iceberg") + .option(SparkReadOptions.AS_OF_TIMESTAMP, tsBeforeDropColumn) + .load(loadLocation(tableIdentifier)); + Assert.assertEquals("Records should match", originalRecords, + resultDf3.orderBy("id").collectAsList()); + Assert.assertEquals("Schemas should match", originalSparkSchema, resultDf3.schema()); + + // At tsAfterDropColumn, there has been a schema change, but no new snapshot, + // so the snapshot as of tsAfterDropColumn is the same as that as of tsBeforeDropColumn. 
+ Dataset resultDf4 = spark.read() + .format("iceberg") + .option(SparkReadOptions.AS_OF_TIMESTAMP, tsAfterDropColumn) + .load(loadLocation(tableIdentifier)); + Assert.assertEquals("Records should match", originalRecords, + resultDf4.orderBy("id").collectAsList()); + Assert.assertEquals("Schemas should match", originalSparkSchema, resultDf4.schema()); + } + + @Test + public synchronized void testSnapshotReadAfterAddAndDropColumn() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "table"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List originalRecords = Lists.newArrayList( + RowFactory.create(1, "x"), + RowFactory.create(2, "y"), + RowFactory.create(3, "z")); + + StructType originalSparkSchema = SparkSchemaUtil.convert(SCHEMA); + Dataset inputDf = spark.createDataFrame(originalRecords, originalSparkSchema); + inputDf.select("id", "data").write() + .format("iceberg") + .mode(SaveMode.Append) + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + Dataset resultDf = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier)); + Assert.assertEquals("Records should match", originalRecords, + resultDf.orderBy("id").collectAsList()); + + Snapshot snapshotBeforeAddColumn = table.currentSnapshot(); + + table.updateSchema().addColumn("category", Types.StringType.get()).commit(); + + List newRecords = Lists.newArrayList( + RowFactory.create(4, "xy", "B"), + RowFactory.create(5, "xyz", "C")); + + StructType sparkSchemaAfterAddColumn = SparkSchemaUtil.convert(SCHEMA2); + Dataset inputDf2 = spark.createDataFrame(newRecords, sparkSchemaAfterAddColumn); + inputDf2.select("id", "data", "category").write() + .format("iceberg") + .mode(SaveMode.Append) + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + List updatedRecords = Lists.newArrayList( + RowFactory.create(1, "x", null), + RowFactory.create(2, "y", null), + RowFactory.create(3, "z", null), + RowFactory.create(4, "xy", "B"), + RowFactory.create(5, "xyz", "C")); + + Dataset resultDf2 = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier)); + Assert.assertEquals("Records should match", updatedRecords, + resultDf2.orderBy("id").collectAsList()); + + table.updateSchema().deleteColumn("data").commit(); + + List recordsAfterDropColumn = Lists.newArrayList( + RowFactory.create(1, null), + RowFactory.create(2, null), + RowFactory.create(3, null), + RowFactory.create(4, "B"), + RowFactory.create(5, "C")); + + Dataset resultDf3 = spark.read() + .format("iceberg") + .load(loadLocation(tableIdentifier)); + Assert.assertEquals("Records should match", recordsAfterDropColumn, + resultDf3.orderBy("id").collectAsList()); + + Dataset resultDf4 = spark.read() + .format("iceberg") + .option(SparkReadOptions.SNAPSHOT_ID, snapshotBeforeAddColumn.snapshotId()) + .load(loadLocation(tableIdentifier)); + Assert.assertEquals("Records should match", originalRecords, + resultDf4.orderBy("id").collectAsList()); + Assert.assertEquals("Schemas should match", originalSparkSchema, resultDf4.schema()); + } + + @Test + public void testRemoveOrphanFilesActionSupport() throws InterruptedException { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "table"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List records = Lists.newArrayList( + new SimpleRecord(1, "1") + ); + + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + + df.select("id", "data").write() + .format("iceberg") + .mode("append") + 
.save(loadLocation(tableIdentifier));
+
+    df.write().mode("append").parquet(table.location() + "/data");
+
+    // sleep for 1 second to ensure files will be old enough
+    Thread.sleep(1000);
+
+    SparkActions actions = SparkActions.get();
+
+    DeleteOrphanFiles.Result result1 = actions.deleteOrphanFiles(table)
+        .location(table.location() + "/metadata")
+        .olderThan(System.currentTimeMillis())
+        .execute();
+    Assert.assertTrue("Should not delete any metadata files", Iterables.isEmpty(result1.orphanFileLocations()));
+
+    DeleteOrphanFiles.Result result2 = actions.deleteOrphanFiles(table)
+        .olderThan(System.currentTimeMillis())
+        .execute();
+    Assert.assertEquals("Should delete 1 data file", 1, Iterables.size(result2.orphanFileLocations()));
+
+    Dataset<Row> resultDF = spark.read().format("iceberg").load(loadLocation(tableIdentifier));
+    List<SimpleRecord> actualRecords = resultDF
+        .as(Encoders.bean(SimpleRecord.class))
+        .collectAsList();
+
+    Assert.assertEquals("Rows must match", records, actualRecords);
+  }
+
+
+  @Test
+  public void testFilesTablePartitionId() throws Exception {
+    TableIdentifier tableIdentifier = TableIdentifier.of("db", "files_test");
+    Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.builderFor(SCHEMA).identity("id").build());
+    int spec0 = table.spec().specId();
+
+    Dataset<Row> df1 = spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class);
+    Dataset<Row> df2 = spark.createDataFrame(Lists.newArrayList(new SimpleRecord(2, "b")), SimpleRecord.class);
+
+    df1.select("id", "data").write()
+        .format("iceberg")
+        .mode("append")
+        .save(loadLocation(tableIdentifier));
+
+    // change partition spec
+    table.refresh();
+    table.updateSpec().removeField("id").commit();
+    int spec1 = table.spec().specId();
+
+    // add a second file
+    df2.select("id", "data").write()
+        .format("iceberg")
+        .mode("append")
+        .save(loadLocation(tableIdentifier));
+
+    List<Integer> actual = spark.read()
+        .format("iceberg")
+        .load(loadLocation(tableIdentifier, "files"))
+        .sort(DataFile.SPEC_ID.name())
+        .collectAsList()
+        .stream().map(r -> (Integer) r.getAs(DataFile.SPEC_ID.name())).collect(Collectors.toList());
+
+    Assert.assertEquals("Should have two partition specs", ImmutableList.of(spec0, spec1), actual);
+  }
+
+  private void asMetadataRecord(GenericData.Record file) {
+    file.put(0, FileContent.DATA.id());
+    file.put(3, 0); // specId
+  }
+
+  private PositionDeleteWriter<InternalRow> newPositionDeleteWriter(Table table, PartitionSpec spec,
+                                                                    StructLike partition) {
+    OutputFileFactory fileFactory = OutputFileFactory.builderFor(table, 0, 0).build();
+    EncryptedOutputFile outputFile = fileFactory.newOutputFile(spec, partition);
+
+    SparkFileWriterFactory fileWriterFactory = SparkFileWriterFactory.builderFor(table).build();
+    return fileWriterFactory.newPositionDeleteWriter(outputFile, spec, partition);
+  }
+
+  private DeleteFile writePositionDeletes(Table table, PartitionSpec spec, StructLike partition,
+                                          Iterable<PositionDelete<InternalRow>> deletes) {
+    PositionDeleteWriter<InternalRow> positionDeleteWriter = newPositionDeleteWriter(table, spec, partition);
+
+    try (PositionDeleteWriter<InternalRow> writer = positionDeleteWriter) {
+      for (PositionDelete<InternalRow> delete : deletes) {
+        writer.write(delete);
+      }
+    } catch (IOException e) {
+      throw new UncheckedIOException(e);
+    }
+
+    return positionDeleteWriter.toDeleteFile();
+  }
+}
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java
new file mode 100644
index
000000000000..1a99697a09f9 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.sql.Date; +import java.sql.Timestamp; +import java.util.List; +import org.apache.iceberg.spark.IcebergSpark; +import org.apache.iceberg.transforms.Transforms; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.types.CharType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.VarcharType; +import org.assertj.core.api.Assertions; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestIcebergSpark { + + private static SparkSession spark = null; + + @BeforeClass + public static void startSpark() { + TestIcebergSpark.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestIcebergSpark.spark; + TestIcebergSpark.spark = null; + currentSpark.stop(); + } + + @Test + public void testRegisterIntegerBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_int_16", DataTypes.IntegerType, 16); + List results = spark.sql("SELECT iceberg_bucket_int_16(1)").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterShortBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_short_16", DataTypes.ShortType, 16); + List results = spark.sql("SELECT iceberg_bucket_short_16(1S)").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterByteBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_byte_16", DataTypes.ByteType, 16); + List results = spark.sql("SELECT iceberg_bucket_byte_16(1Y)").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterLongBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_long_16", DataTypes.LongType, 16); + List results = spark.sql("SELECT iceberg_bucket_long_16(1L)").collectAsList(); + 
Assert.assertEquals(1, results.size()); + Assert.assertEquals((int) Transforms.bucket(Types.LongType.get(), 16).apply(1L), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterStringBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_string_16", DataTypes.StringType, 16); + List results = spark.sql("SELECT iceberg_bucket_string_16('hello')").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterCharBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_char_16", new CharType(5), 16); + List results = spark.sql("SELECT iceberg_bucket_char_16('hello')").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterVarCharBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_varchar_16", new VarcharType(5), 16); + List results = spark.sql("SELECT iceberg_bucket_varchar_16('hello')").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterDateBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_date_16", DataTypes.DateType, 16); + List results = spark.sql("SELECT iceberg_bucket_date_16(DATE '2021-06-30')").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals((int) Transforms.bucket(Types.DateType.get(), 16) + .apply(DateTimeUtils.fromJavaDate(Date.valueOf("2021-06-30"))), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterTimestampBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_timestamp_16", DataTypes.TimestampType, 16); + List results = + spark.sql("SELECT iceberg_bucket_timestamp_16(TIMESTAMP '2021-06-30 00:00:00.000')").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals((int) Transforms.bucket(Types.TimestampType.withZone(), 16) + .apply(DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2021-06-30 00:00:00.000"))), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterBinaryBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_binary_16", DataTypes.BinaryType, 16); + List results = + spark.sql("SELECT iceberg_bucket_binary_16(X'0020001F')").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals((int) Transforms.bucket(Types.BinaryType.get(), 16) + .apply(ByteBuffer.wrap(new byte[]{0x00, 0x20, 0x00, 0x1F})), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterDecimalBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_decimal_16", new DecimalType(4, 2), 16); + List results = + spark.sql("SELECT iceberg_bucket_decimal_16(11.11)").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals((int) Transforms.bucket(Types.DecimalType.of(4, 2), 16) + .apply(new BigDecimal("11.11")), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterBooleanBucketUDF() { + Assertions.assertThatThrownBy(() -> + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_boolean_16", DataTypes.BooleanType, 16)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot bucket by type: boolean"); + } + + @Test + public void 
testRegisterDoubleBucketUDF() { + Assertions.assertThatThrownBy(() -> + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_double_16", DataTypes.DoubleType, 16)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot bucket by type: double"); + } + + @Test + public void testRegisterFloatBucketUDF() { + Assertions.assertThatThrownBy(() -> + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_float_16", DataTypes.FloatType, 16)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot bucket by type: float"); + } + + @Test + public void testRegisterIntegerTruncateUDF() { + IcebergSpark.registerTruncateUDF(spark, "iceberg_truncate_int_4", DataTypes.IntegerType, 4); + List results = spark.sql("SELECT iceberg_truncate_int_4(1)").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals(Transforms.truncate(Types.IntegerType.get(), 4).apply(1), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterLongTruncateUDF() { + IcebergSpark.registerTruncateUDF(spark, "iceberg_truncate_long_4", DataTypes.LongType, 4); + List results = spark.sql("SELECT iceberg_truncate_long_4(1L)").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals(Transforms.truncate(Types.LongType.get(), 4).apply(1L), + results.get(0).getLong(0)); + } + + @Test + public void testRegisterDecimalTruncateUDF() { + IcebergSpark.registerTruncateUDF(spark, "iceberg_truncate_decimal_4", new DecimalType(4, 2), 4); + List results = + spark.sql("SELECT iceberg_truncate_decimal_4(11.11)").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals(Transforms.truncate(Types.DecimalType.of(4, 2), 4) + .apply(new BigDecimal("11.11")), results.get(0).getDecimal(0)); + } + + @Test + public void testRegisterStringTruncateUDF() { + IcebergSpark.registerTruncateUDF(spark, "iceberg_truncate_string_4", DataTypes.StringType, 4); + List results = spark.sql("SELECT iceberg_truncate_string_4('hello')").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals(Transforms.truncate(Types.StringType.get(), 4).apply("hello"), + results.get(0).getString(0)); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData.java new file mode 100644 index 000000000000..e07798301db8 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class TestIdentityPartitionData extends SparkTestBase { + private static final Configuration CONF = new Configuration(); + private static final HadoopTables TABLES = new HadoopTables(CONF); + + @Parameterized.Parameters(name = "format = {0}, vectorized = {1}") + public static Object[][] parameters() { + return new Object[][] { + { "parquet", false }, + { "parquet", true }, + { "avro", false }, + { "orc", false }, + { "orc", true }, + }; + } + + private final String format; + private final boolean vectorized; + + public TestIdentityPartitionData(String format, boolean vectorized) { + this.format = format; + this.vectorized = vectorized; + } + + private static final Schema LOG_SCHEMA = new Schema( + Types.NestedField.optional(1, "id", Types.IntegerType.get()), + Types.NestedField.optional(2, "date", Types.StringType.get()), + Types.NestedField.optional(3, "level", Types.StringType.get()), + Types.NestedField.optional(4, "message", Types.StringType.get()) + ); + + private static final List LOGS = ImmutableList.of( + LogMessage.debug("2020-02-02", "debug event 1"), + LogMessage.info("2020-02-02", "info event 1"), + LogMessage.debug("2020-02-02", "debug event 2"), + LogMessage.info("2020-02-03", "info event 2"), + LogMessage.debug("2020-02-03", "debug event 3"), + LogMessage.info("2020-02-03", "info event 3"), + LogMessage.error("2020-02-03", "error event 1"), + LogMessage.debug("2020-02-04", "debug event 4"), + LogMessage.warn("2020-02-04", "warn event 1"), + LogMessage.debug("2020-02-04", "debug event 5") + ); + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private PartitionSpec spec = PartitionSpec.builderFor(LOG_SCHEMA).identity("date").identity("level").build(); + private Table table = null; + private Dataset logs = null; + + /** + * Use the Hive Based table to make Identity Partition Columns with no duplication of the data in the underlying + * parquet files. This makes sure that if the identity mapping fails, the test will also fail. 
+ */ + private void setupParquet() throws Exception { + File location = temp.newFolder("logs"); + File hiveLocation = temp.newFolder("hive"); + String hiveTable = "hivetable"; + Assert.assertTrue("Temp folder should exist", location.exists()); + + Map properties = ImmutableMap.of(TableProperties.DEFAULT_FILE_FORMAT, format); + this.logs = spark.createDataFrame(LOGS, LogMessage.class).select("id", "date", "level", "message"); + spark.sql(String.format("DROP TABLE IF EXISTS %s", hiveTable)); + logs.orderBy("date", "level", "id").write().partitionBy("date", "level").format("parquet") + .option("path", hiveLocation.toString()).saveAsTable(hiveTable); + + this.table = TABLES.create(SparkSchemaUtil.schemaForTable(spark, hiveTable), + SparkSchemaUtil.specForTable(spark, hiveTable), properties, location.toString()); + + SparkTableUtil.importSparkTable(spark, new TableIdentifier(hiveTable), table, location.toString()); + } + + @Before + public void setupTable() throws Exception { + if (format.equals("parquet")) { + setupParquet(); + } else { + File location = temp.newFolder("logs"); + Assert.assertTrue("Temp folder should exist", location.exists()); + + Map properties = ImmutableMap.of(TableProperties.DEFAULT_FILE_FORMAT, format); + this.table = TABLES.create(LOG_SCHEMA, spec, properties, location.toString()); + this.logs = spark.createDataFrame(LOGS, LogMessage.class).select("id", "date", "level", "message"); + + logs.orderBy("date", "level", "id").write().format("iceberg").mode("append").save(location.toString()); + } + } + + @Test + public void testFullProjection() { + List expected = logs.orderBy("id").collectAsList(); + List actual = spark.read().format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(table.location()).orderBy("id") + .select("id", "date", "level", "message") + .collectAsList(); + Assert.assertEquals("Rows should match", expected, actual); + } + + @Test + public void testProjections() { + String[][] cases = new String[][] { + // individual fields + new String[] { "date" }, + new String[] { "level" }, + new String[] { "message" }, + // field pairs + new String[] { "date", "message" }, + new String[] { "level", "message" }, + new String[] { "date", "level" }, + // out-of-order pairs + new String[] { "message", "date" }, + new String[] { "message", "level" }, + new String[] { "level", "date" }, + // full projection, different orderings + new String[] { "date", "level", "message" }, + new String[] { "level", "date", "message" }, + new String[] { "date", "message", "level" }, + new String[] { "level", "message", "date" }, + new String[] { "message", "date", "level" }, + new String[] { "message", "level", "date" } + }; + + for (String[] ordering : cases) { + List expected = logs.select("id", ordering).orderBy("id").collectAsList(); + List actual = spark.read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(table.location()) + .select("id", ordering).orderBy("id") + .collectAsList(); + Assert.assertEquals("Rows should match for ordering: " + Arrays.toString(ordering), expected, actual); + } + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestInternalRowWrapper.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestInternalRowWrapper.java new file mode 100644 index 000000000000..4ab01044046f --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestInternalRowWrapper.java @@ -0,0 +1,78 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.util.Iterator; +import org.apache.iceberg.RecordWrapperTest; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.data.InternalRecordWrapper; +import org.apache.iceberg.data.RandomGenericData; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.util.StructLikeWrapper; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.Assert; +import org.junit.Ignore; + +public class TestInternalRowWrapper extends RecordWrapperTest { + + @Ignore + @Override + public void testTimestampWithoutZone() { + // Spark does not support timestamp without zone. + } + + @Ignore + @Override + public void testTime() { + // Spark does not support time fields. + } + + @Override + protected void generateAndValidate(Schema schema, AssertMethod assertMethod) { + int numRecords = 100; + Iterable recordList = RandomGenericData.generate(schema, numRecords, 101L); + Iterable rowList = RandomData.generateSpark(schema, numRecords, 101L); + + InternalRecordWrapper recordWrapper = new InternalRecordWrapper(schema.asStruct()); + InternalRowWrapper rowWrapper = new InternalRowWrapper(SparkSchemaUtil.convert(schema)); + + Iterator actual = recordList.iterator(); + Iterator expected = rowList.iterator(); + + StructLikeWrapper actualWrapper = StructLikeWrapper.forType(schema.asStruct()); + StructLikeWrapper expectedWrapper = StructLikeWrapper.forType(schema.asStruct()); + for (int i = 0; i < numRecords; i++) { + Assert.assertTrue("Should have more records", actual.hasNext()); + Assert.assertTrue("Should have more InternalRow", expected.hasNext()); + + StructLike recordStructLike = recordWrapper.wrap(actual.next()); + StructLike rowStructLike = rowWrapper.wrap(expected.next()); + + assertMethod.assertEquals("Should have expected StructLike values", + actualWrapper.set(recordStructLike), expectedWrapper.set(rowStructLike)); + } + + Assert.assertFalse("Shouldn't have more record", actual.hasNext()); + Assert.assertFalse("Shouldn't have more InternalRow", expected.hasNext()); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestMetadataTablesWithPartitionEvolution.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestMetadataTablesWithPartitionEvolution.java new file mode 100644 index 000000000000..691e9f6f5481 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestMetadataTablesWithPartitionEvolution.java @@ -0,0 +1,806 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license 
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.PartitionSpecParser; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.StructType; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +import static org.apache.iceberg.FileFormat.AVRO; +import static org.apache.iceberg.FileFormat.ORC; +import static org.apache.iceberg.FileFormat.PARQUET; +import static org.apache.iceberg.MetadataTableType.ALL_DATA_FILES; +import static org.apache.iceberg.MetadataTableType.ALL_ENTRIES; +import static org.apache.iceberg.MetadataTableType.ENTRIES; +import static org.apache.iceberg.MetadataTableType.FILES; +import static org.apache.iceberg.MetadataTableType.PARTITIONS; +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.FORMAT_VERSION; + +@RunWith(Parameterized.class) +public class TestMetadataTablesWithPartitionEvolution extends SparkCatalogTestBase { + + @Parameters(name = "catalog = {0}, impl = {1}, conf = {2}, fileFormat = {3}, formatVersion = {4}") + public static Object[][] parameters() { + return new Object[][] { + { "testhive", SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default" + ), + ORC, + 1 + }, + { "testhive", SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default" + ), + ORC, + 2 + }, + { "testhadoop", SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hadoop" + ), + PARQUET, + 1 + }, + { "testhadoop", SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hadoop" + ), + PARQUET, + 2 + }, + { "spark_catalog", SparkSessionCatalog.class.getName(), 
+ ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "clients", "1", + "parquet-enabled", "false", + "cache-enabled", "false" // Spark will delete tables using v1, leaving the cache out of sync + ), + AVRO, + 1 + }, + { "spark_catalog", SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "clients", "1", + "parquet-enabled", "false", + "cache-enabled", "false" // Spark will delete tables using v1, leaving the cache out of sync + ), + AVRO, + 2 + } + }; + } + + private final FileFormat fileFormat; + private final int formatVersion; + + public TestMetadataTablesWithPartitionEvolution(String catalogName, String implementation, Map config, + FileFormat fileFormat, int formatVersion) { + super(catalogName, implementation, config); + this.fileFormat = fileFormat; + this.formatVersion = formatVersion; + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testFilesMetadataTable() throws ParseException { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg", tableName); + initTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables while the current spec is still unpartitioned + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + Dataset df = loadMetadataTable(tableType); + Assert.assertTrue("Partition must be skipped", df.schema().getFieldIndex("partition").isEmpty()); + } + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec() + .addField("data") + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after adding the first partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(new Object[]{null}), row("b1")), + "STRUCT", + tableType); + } + + table.updateSpec() + .addField(Expressions.bucket("category", 8)) + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after adding the second partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(null, null), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + + table.updateSpec() + .removeField("data") + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after dropping the first partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(null, null), row(null, 2), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + + table.updateSpec() + .renameField("category_bucket_8", "category_bucket_8_another_name") + .commit(); + sql("REFRESH TABLE %s", tableName); + + // verify the metadata tables after renaming the second partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(null, null), row(null, 2), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + } + + @Test + public void testFilesMetadataTableFilter() throws ParseException { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg " + + "TBLPROPERTIES 
('commit.manifest-merge.enabled' 'false')", tableName); + initTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // verify the metadata tables while the current spec is still unpartitioned + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + Dataset df = loadMetadataTable(tableType); + Assert.assertTrue("Partition must be skipped", df.schema().getFieldIndex("partition").isEmpty()); + } + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec() + .addField("data") + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // verify the metadata tables after adding the first partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row("d2")), + "STRUCT", + tableType, + "partition.data = 'd2'"); + } + + table.updateSpec() + .addField("category") + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // verify the metadata tables after adding the second partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions(ImmutableList.of(row("d2", null), row("d2", "c2")), + "STRUCT", + tableType, + "partition.data = 'd2'"); + } + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row("d2", "c2")), + "STRUCT", + tableType, + "partition.category = 'c2'"); + } + + table.updateSpec() + .removeField("data") + .commit(); + sql("REFRESH TABLE %s", tableName); + + // Verify new partitions do not show up for removed 'partition.data=d2' query + sql("INSERT INTO TABLE %s VALUES (3, 'c3', 'd2')", tableName); + sql("INSERT INTO TABLE %s VALUES (4, 'c4', 'd2')", tableName); + + // Verify new partitions do show up for 'partition.category=c2' query + sql("INSERT INTO TABLE %s VALUES (5, 'c2', 'd5')", tableName); + + // no new partition should show up for 'data' partition query as partition field has been removed + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row("d2", null), row("d2", "c2")), + "STRUCT", + tableType, + "partition.data = 'd2'"); + } + // new partition shows up from 'category' partition field query + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(null, "c2"), row("d2", "c2")), + "STRUCT", + tableType, + "partition.category = 'c2'"); + } + + table.updateSpec() + .renameField("category", "category_another_name") + .commit(); + sql("REFRESH TABLE %s", tableName); + + // Verify new partitions do show up for 'category=c2' query + sql("INSERT INTO TABLE %s VALUES (6, 'c2', 'd6')", tableName); + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(null, "c2"), row(null, "c2"), row("d2", "c2")), + "STRUCT", + tableType, + "partition.category_another_name = 'c2'"); + } + } + + @Test + public void testEntriesMetadataTable() throws ParseException { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg", tableName); + initTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + 
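+ // Unlike the files/all_data_files tables above, the entries tables nest the partition struct inside the
+ // data_file column, so the checks below inspect data_file's struct type rather than a top-level
+ // partition column.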
// verify the metadata tables while the current spec is still unpartitioned + for (MetadataTableType tableType : Arrays.asList(ENTRIES, ALL_ENTRIES)) { + Dataset df = loadMetadataTable(tableType); + StructType dataFileType = (StructType) df.schema().apply("data_file").dataType(); + Assert.assertTrue("Partition must be skipped", dataFileType.getFieldIndex("").isEmpty()); + } + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec() + .addField("data") + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after adding the first partition column + for (MetadataTableType tableType : Arrays.asList(ENTRIES, ALL_ENTRIES)) { + assertPartitions( + ImmutableList.of(row(new Object[]{null}), row("b1")), + "STRUCT", + tableType); + } + + table.updateSpec() + .addField(Expressions.bucket("category", 8)) + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after adding the second partition column + for (MetadataTableType tableType : Arrays.asList(ENTRIES, ALL_ENTRIES)) { + assertPartitions( + ImmutableList.of(row(null, null), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + + table.updateSpec() + .removeField("data") + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after dropping the first partition column + for (MetadataTableType tableType : Arrays.asList(ENTRIES, ALL_ENTRIES)) { + assertPartitions( + ImmutableList.of(row(null, null), row(null, 2), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + + table.updateSpec() + .renameField("category_bucket_8", "category_bucket_8_another_name") + .commit(); + sql("REFRESH TABLE %s", tableName); + + // verify the metadata tables after renaming the second partition column + for (MetadataTableType tableType : Arrays.asList(ENTRIES, ALL_ENTRIES)) { + assertPartitions( + ImmutableList.of(row(null, null), row(null, 2), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + } + @Test + public void testPartitionsTableAddRemoveFields() throws ParseException { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg ", tableName); + initTable(); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // verify the metadata tables while the current spec is still unpartitioned + Dataset df = loadMetadataTable(PARTITIONS); + Assert.assertTrue("Partition must be skipped", df.schema().getFieldIndex("partition").isEmpty()); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec() + .addField("data") + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // verify the metadata tables after adding the first partition column + assertPartitions( + ImmutableList.of(row(new Object[]{null}), row("d1"), row("d2")), + "STRUCT", + PARTITIONS); + + table.updateSpec() + .addField("category") + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // verify the metadata tables after adding the second partition column + assertPartitions(ImmutableList.of( + 
row(null, null), + row("d1", null), + row("d1", "c1"), + row("d2", null), + row("d2", "c2")), + "STRUCT", + PARTITIONS); + + // verify the metadata tables after removing the first partition column + table.updateSpec() + .removeField("data") + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of( + row(null, null), + row(null, "c1"), + row(null, "c2"), + row("d1", null), + row("d1", "c1"), + row("d2", null), + row("d2", "c2")), + "STRUCT", + PARTITIONS); + } + + @Test + public void testPartitionsTableRenameFields() throws ParseException { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg", tableName); + initTable(); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec() + .addField("data") + .addField("category") + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions(ImmutableList.of( + row("d1", "c1"), + row("d2", "c2")), + "STRUCT", + PARTITIONS); + + table.updateSpec() + .renameField("category", "category_another_name") + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of( + row("d1", "c1"), + row("d2", "c2")), + "STRUCT", + PARTITIONS); + } + + @Test + public void testPartitionsTableSwitchFields() throws Exception { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg", tableName); + initTable(); + Table table = validationCatalog.loadTable(tableIdent); + + // verify the metadata tables after re-adding the first dropped column in the second location + table.updateSpec() + .addField("data") + .addField("category") + .commit(); + sql("REFRESH TABLE %s", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions(ImmutableList.of( + row("d1", "c1"), + row("d2", "c2")), + "STRUCT", + PARTITIONS); + + table.updateSpec() + .removeField("data") + .commit(); + sql("REFRESH TABLE %s", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of( + row(null, "c1"), + row(null, "c2"), + row("d1", "c1"), + row("d2", "c2")), + "STRUCT", + PARTITIONS); + + table.updateSpec() + .addField("data") + .commit(); + sql("REFRESH TABLE %s", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + sql("INSERT INTO TABLE %s VALUES (3, 'c3', 'd3')", tableName); + + if (formatVersion == 1) { + assertPartitions( + ImmutableList.of( + row(null, "c1", null), + row(null, "c1", "d1"), + row(null, "c2", null), + row(null, "c2", "d2"), + row(null, "c3", "d3"), + row("d1", "c1", null), + row("d2", "c2", null)), + "STRUCT", + PARTITIONS); + } else { + // In V2 re-adding a former partition field that was part of an older spec will not change its name or its + // field ID either, thus values will be collapsed into a single common column (as opposed to V1 where any new + // partition field addition will result in a new column in this 
metadata table) + assertPartitions( + ImmutableList.of( + row(null, "c1"), + row(null, "c2"), + row("d1", "c1"), + row("d2", "c2"), + row("d3", "c3")), + "STRUCT", + PARTITIONS); + } + } + + @Test + public void testPartitionTableFilterAddRemoveFields() throws ParseException { + // Create un-partitioned table + sql("CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg", tableName); + initTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // Partition Table with one partition column + Table table = validationCatalog.loadTable(tableIdent); + table.updateSpec() + .addField("data") + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of(row("d2")), + "STRUCT", + PARTITIONS, + "partition.data = 'd2'"); + + // Partition Table with two partition column + table.updateSpec() + .addField("category") + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions(ImmutableList.of(row("d2", null), row("d2", "c2")), + "STRUCT", + PARTITIONS, + "partition.data = 'd2'"); + assertPartitions( + ImmutableList.of(row("d2", "c2")), + "STRUCT", + PARTITIONS, + "partition.category = 'c2'"); + + // Partition Table with first partition column removed + table.updateSpec() + .removeField("data") + .commit(); + sql("REFRESH TABLE %s", tableName); + + sql("INSERT INTO TABLE %s VALUES (3, 'c3', 'd2')", tableName); + sql("INSERT INTO TABLE %s VALUES (4, 'c4', 'd2')", tableName); + sql("INSERT INTO TABLE %s VALUES (5, 'c2', 'd5')", tableName); + assertPartitions( + ImmutableList.of(row("d2", null), row("d2", "c2")), + "STRUCT", + PARTITIONS, + "partition.data = 'd2'"); + assertPartitions( + ImmutableList.of(row(null, "c2"), row("d2", "c2")), + "STRUCT", + PARTITIONS, + "partition.category = 'c2'"); + } + + @Test + public void testPartitionTableFilterSwitchFields() throws Exception { + // Re-added partition fields currently not re-associated: https://github.com/apache/iceberg/issues/4292 + // In V1, dropped partition fields show separately when field is re-added + // In V2, re-added field currently conflicts with its deleted form + Assume.assumeTrue(formatVersion == 1); + + sql("CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg", tableName); + initTable(); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + + // Two partition columns + table.updateSpec() + .addField("data") + .addField("category") + .commit(); + sql("REFRESH TABLE %s", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // Drop first partition column + table.updateSpec() + .removeField("data") + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // Re-add first partition column at the end + table.updateSpec() + .addField("data") + .commit(); + sql("REFRESH TABLE %s", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + 
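+ // Per the V1-only assumption above, re-adding 'data' creates a brand new partition field with a
+ // fresh field ID, so the expected partition structs below gain a third column instead of reusing
+ // the dropped 'data' column (see testPartitionsTableSwitchFields for the V1 vs V2 behavior).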
sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of( + row(null, "c2", null), + row(null, "c2", "d2"), + row("d2", "c2", null)), + "STRUCT", + PARTITIONS, + "partition.category = 'c2'"); + + assertPartitions( + ImmutableList.of(row(null, "c1", "d1")), + "STRUCT", + PARTITIONS, + "partition.data = 'd1'"); + } + + @Test + public void testPartitionsTableFilterRenameFields() throws ParseException { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg", tableName); + initTable(); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec() + .addField("data") + .addField("category") + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + table.updateSpec() + .renameField("category", "category_another_name") + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of(row("d1", "c1")), + "STRUCT", + PARTITIONS, + "partition.category_another_name = 'c1'"); + } + + @Test + public void testMetadataTablesWithUnknownTransforms() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg", tableName); + initTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + PartitionSpec unknownSpec = PartitionSpecParser.fromJson(table.schema(), + "{ \"spec-id\": 1, \"fields\": [ { \"name\": \"id_zero\", \"transform\": \"zero\", \"source-id\": 1 } ] }"); + + // replace the table spec to include an unknown transform + TableOperations ops = ((HasTableOperations) table).operations(); + TableMetadata base = ops.current(); + ops.commit(base, base.updatePartitionSpec(unknownSpec)); + + sql("REFRESH TABLE %s", tableName); + + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES, ENTRIES, ALL_ENTRIES)) { + AssertHelpers.assertThrows("Should complain about the partition type", + ValidationException.class, "Cannot build table partition type, unknown transforms", + () -> loadMetadataTable(tableType)); + } + } + + @Test + public void testPartitionColumnNamedPartition() { + sql("CREATE TABLE %s (id int, partition int) USING iceberg PARTITIONED BY (partition)", tableName); + sql("INSERT INTO %s VALUES (1, 1), (2, 1), (3, 2), (2, 2)", tableName); + List expected = ImmutableList.of( + row(1, 1), row(2, 1), row(3, 2), row(2, 2)); + assertEquals("Should return all expected rows", expected, sql("SELECT * FROM %s", tableName)); + Assert.assertEquals(2, sql("SELECT * FROM %s.files", tableName).size()); + } + + private void assertPartitions(List expectedPartitions, String expectedTypeAsString, + MetadataTableType tableType) throws ParseException { + assertPartitions(expectedPartitions, expectedTypeAsString, tableType, null); + } + + private void assertPartitions(List expectedPartitions, String expectedTypeAsString, + MetadataTableType tableType, String filter) throws ParseException { + Dataset df = loadMetadataTable(tableType); + if (filter != null) { + df = df.filter(filter); + } + + DataType expectedType = spark.sessionState().sqlParser().parseDataType(expectedTypeAsString); + switch (tableType) { + case PARTITIONS: + case FILES: + case ALL_DATA_FILES: + DataType actualFilesType = 
df.schema().apply("partition").dataType(); + Assert.assertEquals("Partition type must match", expectedType, actualFilesType); + break; + + case ENTRIES: + case ALL_ENTRIES: + StructType dataFileType = (StructType) df.schema().apply("data_file").dataType(); + DataType actualEntriesType = dataFileType.apply("partition").dataType(); + Assert.assertEquals("Partition type must match", expectedType, actualEntriesType); + break; + + default: + throw new UnsupportedOperationException("Unsupported metadata table type: " + tableType); + } + + switch (tableType) { + case PARTITIONS: + case FILES: + case ALL_DATA_FILES: + List actualFilesPartitions = df.orderBy("partition") + .select("partition.*") + .collectAsList(); + assertEquals("Partitions must match", expectedPartitions, rowsToJava(actualFilesPartitions)); + break; + + case ENTRIES: + case ALL_ENTRIES: + List actualEntriesPartitions = df.orderBy("data_file.partition") + .select("data_file.partition.*") + .collectAsList(); + assertEquals("Partitions must match", expectedPartitions, rowsToJava(actualEntriesPartitions)); + break; + + default: + throw new UnsupportedOperationException("Unsupported metadata table type: " + tableType); + } + } + + private Dataset loadMetadataTable(MetadataTableType tableType) { + return spark.read().format("iceberg").load(tableName + "." + tableType.name()); + } + + private void initTable() { + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, DEFAULT_FILE_FORMAT, fileFormat.name()); + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", tableName, FORMAT_VERSION, formatVersion); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetScan.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetScan.java new file mode 100644 index 000000000000..adfe8c7d3649 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetScan.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.UUID; +import org.apache.avro.generic.GenericData; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.spark.data.AvroDataTest; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.apache.iceberg.Files.localOutput; + +@RunWith(Parameterized.class) +public class TestParquetScan extends AvroDataTest { + private static final Configuration CONF = new Configuration(); + + private static SparkSession spark = null; + + @BeforeClass + public static void startSpark() { + TestParquetScan.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestParquetScan.spark; + TestParquetScan.spark = null; + currentSpark.stop(); + } + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + @Parameterized.Parameters(name = "vectorized = {0}") + public static Object[] parameters() { + return new Object[] { false, true }; + } + + private final boolean vectorized; + + public TestParquetScan(boolean vectorized) { + this.vectorized = vectorized; + } + + @Override + protected void writeAndValidate(Schema schema) throws IOException { + Assume.assumeTrue("Cannot handle non-string map keys in parquet-avro", + null == TypeUtil.find( + schema, + type -> type.isMapType() && type.asMapType().keyType() != Types.StringType.get())); + + File parent = temp.newFolder("parquet"); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + dataFolder.mkdirs(); + + File parquetFile = new File(dataFolder, + FileFormat.PARQUET.addExtension(UUID.randomUUID().toString())); + + HadoopTables tables = new HadoopTables(CONF); + Table table = tables.create(schema, PartitionSpec.unpartitioned(), location.toString()); + + // Important: use the table's schema for the rest of the test + // When tables are created, the column ids are reassigned. 
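+ // In other words, the IDs in 'schema' and 'tableSchema' may differ, so the data generated and
+ // asserted below is built from tableSchema to match the IDs the created table actually uses.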
+ Schema tableSchema = table.schema(); + + List expected = RandomData.generateList(tableSchema, 100, 1L); + + try (FileAppender writer = Parquet.write(localOutput(parquetFile)) + .schema(tableSchema) + .build()) { + writer.addAll(expected); + } + + DataFile file = DataFiles.builder(PartitionSpec.unpartitioned()) + .withFileSizeInBytes(parquetFile.length()) + .withPath(parquetFile.toString()) + .withRecordCount(100) + .build(); + + table.newAppend().appendFile(file).commit(); + table.updateProperties().set(TableProperties.PARQUET_VECTORIZATION_ENABLED, String.valueOf(vectorized)).commit(); + + Dataset df = spark.read() + .format("iceberg") + .load(location.toString()); + + List rows = df.collectAsList(); + Assert.assertEquals("Should contain 100 rows", 100, rows.size()); + + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe(tableSchema.asStruct(), expected.get(i), rows.get(i)); + } + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionPruning.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionPruning.java new file mode 100644 index 000000000000..24f7b69e1dc5 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionPruning.java @@ -0,0 +1,414 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.io.IOException; +import java.net.URI; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.fs.RawLocalFileSystem; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.expressions.Literal; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.transforms.Transform; +import org.apache.iceberg.transforms.Transforms; +import org.apache.iceberg.types.Types; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.unsafe.types.UTF8String; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class TestPartitionPruning { + + private static final Configuration CONF = new Configuration(); + private static final HadoopTables TABLES = new HadoopTables(CONF); + + @Parameterized.Parameters(name = "format = {0}, vectorized = {1}") + public static Object[][] parameters() { + return new Object[][] { + { "parquet", false }, + { "parquet", true }, + { "avro", false }, + { "orc", false }, + { "orc", true } + }; + } + + private final String format; + private final boolean vectorized; + + public TestPartitionPruning(String format, boolean vectorized) { + this.format = format; + this.vectorized = vectorized; + } + + private static SparkSession spark = null; + private static JavaSparkContext sparkContext = null; + + private static Transform bucketTransform = Transforms.bucket(Types.IntegerType.get(), 3); + private static Transform truncateTransform = Transforms.truncate(Types.StringType.get(), 5); + private static Transform hourTransform = Transforms.hour(Types.TimestampType.withoutZone()); + + @BeforeClass + public static void startSpark() { + TestPartitionPruning.spark = SparkSession.builder().master("local[2]").getOrCreate(); + TestPartitionPruning.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + + String optionKey = String.format("fs.%s.impl", CountOpenLocalFileSystem.scheme); + CONF.set(optionKey, CountOpenLocalFileSystem.class.getName()); + spark.conf().set(optionKey, 
CountOpenLocalFileSystem.class.getName()); + spark.conf().set("spark.sql.session.timeZone", "UTC"); + spark.udf().register("bucket3", (Integer num) -> bucketTransform.apply(num), DataTypes.IntegerType); + spark.udf().register("truncate5", (String str) -> truncateTransform.apply(str), DataTypes.StringType); + // NOTE: date transforms take the type long, not Timestamp + spark.udf().register("hour", (Timestamp ts) -> hourTransform.apply( + org.apache.spark.sql.catalyst.util.DateTimeUtils.fromJavaTimestamp(ts)), + DataTypes.IntegerType); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestPartitionPruning.spark; + TestPartitionPruning.spark = null; + currentSpark.stop(); + } + + private static final Schema LOG_SCHEMA = new Schema( + Types.NestedField.optional(1, "id", Types.IntegerType.get()), + Types.NestedField.optional(2, "date", Types.StringType.get()), + Types.NestedField.optional(3, "level", Types.StringType.get()), + Types.NestedField.optional(4, "message", Types.StringType.get()), + Types.NestedField.optional(5, "timestamp", Types.TimestampType.withZone()) + ); + + private static final List LOGS = ImmutableList.of( + LogMessage.debug("2020-02-02", "debug event 1", getInstant("2020-02-02T00:00:00")), + LogMessage.info("2020-02-02", "info event 1", getInstant("2020-02-02T01:00:00")), + LogMessage.debug("2020-02-02", "debug event 2", getInstant("2020-02-02T02:00:00")), + LogMessage.info("2020-02-03", "info event 2", getInstant("2020-02-03T00:00:00")), + LogMessage.debug("2020-02-03", "debug event 3", getInstant("2020-02-03T01:00:00")), + LogMessage.info("2020-02-03", "info event 3", getInstant("2020-02-03T02:00:00")), + LogMessage.error("2020-02-03", "error event 1", getInstant("2020-02-03T03:00:00")), + LogMessage.debug("2020-02-04", "debug event 4", getInstant("2020-02-04T01:00:00")), + LogMessage.warn("2020-02-04", "warn event 1", getInstant("2020-02-04T02:00:00")), + LogMessage.debug("2020-02-04", "debug event 5", getInstant("2020-02-04T03:00:00")) + ); + + private static Instant getInstant(String timestampWithoutZone) { + Long epochMicros = (Long) Literal.of(timestampWithoutZone).to(Types.TimestampType.withoutZone()).value(); + return Instant.ofEpochMilli(TimeUnit.MICROSECONDS.toMillis(epochMicros)); + } + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private PartitionSpec spec = PartitionSpec.builderFor(LOG_SCHEMA) + .identity("date") + .identity("level") + .bucket("id", 3) + .truncate("message", 5) + .hour("timestamp") + .build(); + + @Test + public void testPartitionPruningIdentityString() { + String filterCond = "date >= '2020-02-03' AND level = 'DEBUG'"; + Predicate partCondition = (Row r) -> { + String date = r.getString(0); + String level = r.getString(1); + return date.compareTo("2020-02-03") >= 0 && level.equals("DEBUG"); + }; + + runTest(filterCond, partCondition); + } + + @Test + public void testPartitionPruningBucketingInteger() { + final int[] ids = new int[]{ + LOGS.get(3).getId(), + LOGS.get(7).getId() + }; + String condForIds = Arrays.stream(ids).mapToObj(String::valueOf) + .collect(Collectors.joining(",", "(", ")")); + String filterCond = "id in " + condForIds; + Predicate partCondition = (Row r) -> { + int bucketId = r.getInt(2); + Set buckets = Arrays.stream(ids).map(bucketTransform::apply) + .boxed().collect(Collectors.toSet()); + return buckets.contains(bucketId); + }; + + runTest(filterCond, partCondition); + } + + @Test + public void testPartitionPruningTruncatedString() { + String filterCond = "message 
like 'info event%'"; + Predicate partCondition = (Row r) -> { + String truncatedMessage = r.getString(3); + return truncatedMessage.equals("info "); + }; + + runTest(filterCond, partCondition); + } + + @Test + public void testPartitionPruningTruncatedStringComparingValueShorterThanPartitionValue() { + String filterCond = "message like 'inf%'"; + Predicate partCondition = (Row r) -> { + String truncatedMessage = r.getString(3); + return truncatedMessage.startsWith("inf"); + }; + + runTest(filterCond, partCondition); + } + + @Test + public void testPartitionPruningHourlyPartition() { + String filterCond; + if (spark.version().startsWith("2")) { + // Looks like from Spark 2 we need to compare timestamp with timestamp to push down the filter. + filterCond = "timestamp >= to_timestamp('2020-02-03T01:00:00')"; + } else { + filterCond = "timestamp >= '2020-02-03T01:00:00'"; + } + Predicate partCondition = (Row r) -> { + int hourValue = r.getInt(4); + Instant instant = getInstant("2020-02-03T01:00:00"); + Integer hourValueToFilter = hourTransform.apply(TimeUnit.MILLISECONDS.toMicros(instant.toEpochMilli())); + return hourValue >= hourValueToFilter; + }; + + runTest(filterCond, partCondition); + } + + private void runTest(String filterCond, Predicate partCondition) { + File originTableLocation = createTempDir(); + Assert.assertTrue("Temp folder should exist", originTableLocation.exists()); + + Table table = createTable(originTableLocation); + Dataset logs = createTestDataset(); + saveTestDatasetToTable(logs, table); + + List expected = logs + .select("id", "date", "level", "message", "timestamp") + .filter(filterCond) + .orderBy("id") + .collectAsList(); + Assert.assertFalse("Expected rows should be not empty", expected.isEmpty()); + + // remove records which may be recorded during storing to table + CountOpenLocalFileSystem.resetRecordsInPathPrefix(originTableLocation.getAbsolutePath()); + + List actual = spark.read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(table.location()) + .select("id", "date", "level", "message", "timestamp") + .filter(filterCond) + .orderBy("id") + .collectAsList(); + Assert.assertFalse("Actual rows should not be empty", actual.isEmpty()); + + Assert.assertEquals("Rows should match", expected, actual); + + assertAccessOnDataFiles(originTableLocation, table, partCondition); + } + + private File createTempDir() { + try { + return temp.newFolder(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private Table createTable(File originTableLocation) { + String trackedTableLocation = CountOpenLocalFileSystem.convertPath(originTableLocation); + Map properties = ImmutableMap.of(TableProperties.DEFAULT_FILE_FORMAT, format); + return TABLES.create(LOG_SCHEMA, spec, properties, trackedTableLocation); + } + + private Dataset createTestDataset() { + List rows = LOGS.stream().map(logMessage -> { + Object[] underlying = new Object[] { + logMessage.getId(), + UTF8String.fromString(logMessage.getDate()), + UTF8String.fromString(logMessage.getLevel()), + UTF8String.fromString(logMessage.getMessage()), + // discard the nanoseconds part to simplify + TimeUnit.MILLISECONDS.toMicros(logMessage.getTimestamp().toEpochMilli()) + }; + return new GenericInternalRow(underlying); + }).collect(Collectors.toList()); + + JavaRDD rdd = sparkContext.parallelize(rows); + Dataset df = spark.internalCreateDataFrame(JavaRDD.toRDD(rdd), SparkSchemaUtil.convert(LOG_SCHEMA), false); + + return df + .selectExpr("id", "date", 
"level", "message", "timestamp") + .selectExpr("id", "date", "level", "message", "timestamp", "bucket3(id) AS bucket_id", + "truncate5(message) AS truncated_message", "hour(timestamp) AS ts_hour"); + } + + private void saveTestDatasetToTable(Dataset logs, Table table) { + logs.orderBy("date", "level", "bucket_id", "truncated_message", "ts_hour") + .select("id", "date", "level", "message", "timestamp") + .write().format("iceberg").mode("append").save(table.location()); + } + + private void assertAccessOnDataFiles(File originTableLocation, Table table, Predicate partCondition) { + // only use files in current table location to avoid side-effects on concurrent test runs + Set readFilesInQuery = CountOpenLocalFileSystem.pathToNumOpenCalled.keySet() + .stream().filter(path -> path.startsWith(originTableLocation.getAbsolutePath())) + .collect(Collectors.toSet()); + + List files = spark.read().format("iceberg").load(table.location() + "#files").collectAsList(); + + Set filesToRead = extractFilePathsMatchingConditionOnPartition(files, partCondition); + Set filesToNotRead = extractFilePathsNotIn(files, filesToRead); + + // Just to be sure, they should be mutually exclusive. + Assert.assertTrue(Sets.intersection(filesToRead, filesToNotRead).isEmpty()); + + Assert.assertFalse("The query should prune some data files.", filesToNotRead.isEmpty()); + + // We don't check "all" data files bound to the condition are being read, as data files can be pruned on + // other conditions like lower/upper bound of columns. + Assert.assertFalse("Some of data files in partition range should be read. " + + "Read files in query: " + readFilesInQuery + " / data files in partition range: " + filesToRead, + Sets.intersection(filesToRead, readFilesInQuery).isEmpty()); + + // Data files which aren't bound to the condition shouldn't be read. + Assert.assertTrue("Data files outside of partition range should not be read. 
" + + "Read files in query: " + readFilesInQuery + " / data files outside of partition range: " + filesToNotRead, + Sets.intersection(filesToNotRead, readFilesInQuery).isEmpty()); + } + + private Set extractFilePathsMatchingConditionOnPartition(List files, Predicate condition) { + // idx 1: file_path, idx 3: partition + return files.stream() + .filter(r -> { + Row partition = r.getStruct(4); + return condition.test(partition); + }).map(r -> CountOpenLocalFileSystem.stripScheme(r.getString(1))) + .collect(Collectors.toSet()); + } + + private Set extractFilePathsNotIn(List files, Set filePaths) { + Set allFilePaths = files.stream().map(r -> CountOpenLocalFileSystem.stripScheme(r.getString(1))) + .collect(Collectors.toSet()); + return Sets.newHashSet(Sets.symmetricDifference(allFilePaths, filePaths)); + } + + public static class CountOpenLocalFileSystem extends RawLocalFileSystem { + public static String scheme = String.format("TestIdentityPartitionData%dfs", + new Random().nextInt()); + public static Map pathToNumOpenCalled = Maps.newConcurrentMap(); + + public static String convertPath(String absPath) { + return scheme + "://" + absPath; + } + + public static String convertPath(File file) { + return convertPath(file.getAbsolutePath()); + } + + public static String stripScheme(String pathWithScheme) { + if (!pathWithScheme.startsWith(scheme + ":")) { + throw new IllegalArgumentException("Received unexpected path: " + pathWithScheme); + } + + int idxToCut = scheme.length() + 1; + while (pathWithScheme.charAt(idxToCut) == '/') { + idxToCut++; + } + + // leave the last '/' + idxToCut--; + + return pathWithScheme.substring(idxToCut); + } + + public static void resetRecordsInPathPrefix(String pathPrefix) { + pathToNumOpenCalled.keySet().stream() + .filter(p -> p.startsWith(pathPrefix)) + .forEach(key -> pathToNumOpenCalled.remove(key)); + } + + @Override + public URI getUri() { + return URI.create(scheme + ":///"); + } + + @Override + public String getScheme() { + return scheme; + } + + @Override + public FSDataInputStream open(Path f, int bufferSize) throws IOException { + String path = f.toUri().getPath(); + pathToNumOpenCalled.compute(path, (ignored, v) -> { + if (v == null) { + return 1L; + } else { + return v + 1; + } + }); + return super.open(f, bufferSize); + } + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues.java new file mode 100644 index 000000000000..9d1b49f1aa8e --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues.java @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.util.List; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.Files; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +@RunWith(Parameterized.class) +public class TestPartitionValues { + @Parameterized.Parameters(name = "format = {0}, vectorized = {1}") + public static Object[][] parameters() { + return new Object[][] { + { "parquet", false }, + { "parquet", true }, + { "avro", false }, + { "orc", false }, + { "orc", true } + }; + } + + private static final Schema SUPPORTED_PRIMITIVES = new Schema( + required(100, "id", Types.LongType.get()), + required(101, "data", Types.StringType.get()), + required(102, "b", Types.BooleanType.get()), + required(103, "i", Types.IntegerType.get()), + required(104, "l", Types.LongType.get()), + required(105, "f", Types.FloatType.get()), + required(106, "d", Types.DoubleType.get()), + required(107, "date", Types.DateType.get()), + required(108, "ts", Types.TimestampType.withZone()), + required(110, "s", Types.StringType.get()), + required(113, "bytes", Types.BinaryType.get()), + required(114, "dec_9_0", Types.DecimalType.of(9, 0)), + required(115, "dec_11_2", Types.DecimalType.of(11, 2)), + required(116, "dec_38_10", Types.DecimalType.of(38, 10)) // spark's maximum precision + ); + + private static final Schema SIMPLE_SCHEMA = new Schema( + optional(1, "id", Types.IntegerType.get()), + optional(2, "data", Types.StringType.get())); + + private static final PartitionSpec SPEC = PartitionSpec.builderFor(SIMPLE_SCHEMA) + .identity("data") + .build(); + + private static SparkSession spark = null; + + @BeforeClass + public static void startSpark() { + TestPartitionValues.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestPartitionValues.spark; + TestPartitionValues.spark = null; + currentSpark.stop(); + } + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private final String format; + private final boolean vectorized; + + public TestPartitionValues(String format, 
boolean vectorized) { + this.format = format; + this.vectorized = vectorized; + } + + @Test + public void testNullPartitionValue() throws Exception { + String desc = "null_part"; + File parent = temp.newFolder(desc); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + Assert.assertTrue("mkdirs should succeed", dataFolder.mkdirs()); + + HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf()); + Table table = tables.create(SIMPLE_SCHEMA, SPEC, location.toString()); + table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit(); + + List expected = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, null) + ); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + df.select("id", "data").write() + .format("iceberg") + .mode(SaveMode.Append) + .save(location.toString()); + + Dataset result = spark.read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(location.toString()); + + List actual = result + .orderBy("id") + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + } + + @Test + public void testReorderedColumns() throws Exception { + String desc = "reorder_columns"; + File parent = temp.newFolder(desc); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + Assert.assertTrue("mkdirs should succeed", dataFolder.mkdirs()); + + HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf()); + Table table = tables.create(SIMPLE_SCHEMA, SPEC, location.toString()); + table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit(); + + List expected = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + df.select("data", "id").write() + .format("iceberg") + .mode(SaveMode.Append) + .option(SparkWriteOptions.CHECK_ORDERING, "false") + .save(location.toString()); + + Dataset result = spark.read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(location.toString()); + + List actual = result + .orderBy("id") + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + } + + @Test + public void testReorderedColumnsNoNullability() throws Exception { + String desc = "reorder_columns_no_nullability"; + File parent = temp.newFolder(desc); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + Assert.assertTrue("mkdirs should succeed", dataFolder.mkdirs()); + + HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf()); + Table table = tables.create(SIMPLE_SCHEMA, SPEC, location.toString()); + table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit(); + + List expected = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + df.select("data", "id").write() + .format("iceberg") + 
.mode(SaveMode.Append) + .option(SparkWriteOptions.CHECK_ORDERING, "false") + .option(SparkWriteOptions.CHECK_NULLABILITY, "false") + .save(location.toString()); + + Dataset result = spark.read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(location.toString()); + + List actual = result + .orderBy("id") + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + } + + @Test + public void testPartitionValueTypes() throws Exception { + String[] columnNames = new String[] { + "b", "i", "l", "f", "d", "date", "ts", "s", "bytes", "dec_9_0", "dec_11_2", "dec_38_10" + }; + + HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf()); + + // create a table around the source data + String sourceLocation = temp.newFolder("source_table").toString(); + Table source = tables.create(SUPPORTED_PRIMITIVES, sourceLocation); + + // write out an Avro data file with all of the data types for source data + List expected = RandomData.generateList(source.schema(), 2, 128735L); + File avroData = temp.newFile("data.avro"); + Assert.assertTrue(avroData.delete()); + try (FileAppender appender = Avro.write(Files.localOutput(avroData)) + .schema(source.schema()) + .build()) { + appender.addAll(expected); + } + + // add the Avro data file to the source table + source.newAppend() + .appendFile(DataFiles.builder(PartitionSpec.unpartitioned()) + .withRecordCount(10) + .withInputFile(Files.localInput(avroData)) + .build()) + .commit(); + + Dataset sourceDF = spark.read().format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(sourceLocation); + + for (String column : columnNames) { + String desc = "partition_by_" + SUPPORTED_PRIMITIVES.findType(column).toString(); + + File parent = temp.newFolder(desc); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + Assert.assertTrue("mkdirs should succeed", dataFolder.mkdirs()); + + PartitionSpec spec = PartitionSpec.builderFor(SUPPORTED_PRIMITIVES).identity(column).build(); + + Table table = tables.create(SUPPORTED_PRIMITIVES, spec, location.toString()); + table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit(); + + sourceDF.write() + .format("iceberg") + .mode(SaveMode.Append) + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false") + .save(location.toString()); + + List actual = spark.read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(location.toString()) + .collectAsList(); + + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe( + SUPPORTED_PRIMITIVES.asStruct(), expected.get(i), actual.get(i)); + } + } + } + + @Test + public void testNestedPartitionValues() throws Exception { + String[] columnNames = new String[] { + "b", "i", "l", "f", "d", "date", "ts", "s", "bytes", "dec_9_0", "dec_11_2", "dec_38_10" + }; + + HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf()); + Schema nestedSchema = new Schema(optional(1, "nested", SUPPORTED_PRIMITIVES.asStruct())); + + // create a table around the source data + String sourceLocation = temp.newFolder("source_table").toString(); + Table source = tables.create(nestedSchema, 
sourceLocation); + + // write out an Avro data file with all of the data types for source data + List expected = RandomData.generateList(source.schema(), 2, 128735L); + File avroData = temp.newFile("data.avro"); + Assert.assertTrue(avroData.delete()); + try (FileAppender appender = Avro.write(Files.localOutput(avroData)) + .schema(source.schema()) + .build()) { + appender.addAll(expected); + } + + // add the Avro data file to the source table + source.newAppend() + .appendFile(DataFiles.builder(PartitionSpec.unpartitioned()) + .withRecordCount(10) + .withInputFile(Files.localInput(avroData)) + .build()) + .commit(); + + Dataset sourceDF = spark.read().format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(sourceLocation); + + for (String column : columnNames) { + String desc = "partition_by_" + SUPPORTED_PRIMITIVES.findType(column).toString(); + + File parent = temp.newFolder(desc); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + Assert.assertTrue("mkdirs should succeed", dataFolder.mkdirs()); + + PartitionSpec spec = PartitionSpec.builderFor(nestedSchema).identity("nested." + column).build(); + + Table table = tables.create(nestedSchema, spec, location.toString()); + table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit(); + + sourceDF.write() + .format("iceberg") + .mode(SaveMode.Append) + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false") + .save(location.toString()); + + List actual = spark.read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(location.toString()) + .collectAsList(); + + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe( + nestedSchema.asStruct(), expected.get(i), actual.get(i)); + } + } + } + + /** + * To verify if WrappedPositionAccessor is generated against a string field within a nested field, + * rather than a Position2Accessor. 
+ * Or when building the partition path, a ClassCastException is thrown with the message like: + * Cannot cast org.apache.spark.unsafe.types.UTF8String to java.lang.CharSequence + */ + @Test + public void testPartitionedByNestedString() throws Exception { + // schema and partition spec + Schema nestedSchema = new Schema( + Types.NestedField.required(1, "struct", + Types.StructType.of(Types.NestedField.required(2, "string", Types.StringType.get())) + ) + ); + PartitionSpec spec = PartitionSpec.builderFor(nestedSchema).identity("struct.string").build(); + + // create table + HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf()); + String baseLocation = temp.newFolder("partition_by_nested_string").toString(); + tables.create(nestedSchema, spec, baseLocation); + + // input data frame + StructField[] structFields = { + new StructField("struct", + DataTypes.createStructType( + new StructField[] { + new StructField("string", DataTypes.StringType, false, Metadata.empty()) + } + ), + false, Metadata.empty() + ) + }; + + List rows = Lists.newArrayList(); + rows.add(RowFactory.create(RowFactory.create("nested_string_value"))); + Dataset sourceDF = spark.createDataFrame(rows, new StructType(structFields)); + + // write into iceberg + sourceDF.write() + .format("iceberg") + .mode(SaveMode.Append) + .save(baseLocation); + + // verify + List actual = spark.read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(baseLocation) + .collectAsList(); + + Assert.assertEquals("Number of rows should match", rows.size(), actual.size()); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestPathIdentifier.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestPathIdentifier.java new file mode 100644 index 000000000000..ff4fe22a7a8a --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestPathIdentifier.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.io.IOException; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.Schema; +import org.apache.iceberg.hadoop.HadoopTableOperations; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.PathIdentifier; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.assertj.core.api.Assertions; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.types.Types.NestedField.required; + +public class TestPathIdentifier extends SparkTestBase { + + private static final Schema SCHEMA = new Schema( + required(1, "id", Types.LongType.get()), + required(2, "data", Types.StringType.get())); + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + private File tableLocation; + private PathIdentifier identifier; + private SparkCatalog sparkCatalog; + + @Before + public void before() throws IOException { + tableLocation = temp.newFolder(); + identifier = new PathIdentifier(tableLocation.getAbsolutePath()); + sparkCatalog = new SparkCatalog(); + sparkCatalog.initialize("test", new CaseInsensitiveStringMap(ImmutableMap.of())); + } + + @After + public void after() { + tableLocation.delete(); + sparkCatalog = null; + } + + @Test + public void testPathIdentifier() throws TableAlreadyExistsException, NoSuchTableException { + SparkTable table = sparkCatalog.createTable(identifier, + SparkSchemaUtil.convert(SCHEMA), + new Transform[0], + ImmutableMap.of()); + + Assert.assertEquals(table.table().location(), tableLocation.getAbsolutePath()); + Assertions.assertThat(table.table()).isInstanceOf(BaseTable.class); + Assertions.assertThat(((BaseTable) table.table()).operations()).isInstanceOf(HadoopTableOperations.class); + + Assert.assertEquals(sparkCatalog.loadTable(identifier), table); + Assert.assertTrue(sparkCatalog.dropTable(identifier)); + } +} + diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestReadProjection.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestReadProjection.java new file mode 100644 index 000000000000..8d65b64cab6d --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestReadProjection.java @@ -0,0 +1,578 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.types.Comparators; +import org.apache.iceberg.types.Types; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.avro.Schema.Type.UNION; + +public abstract class TestReadProjection { + final String format; + + TestReadProjection(String format) { + this.format = format; + } + + protected abstract Record writeAndRead(String desc, + Schema writeSchema, + Schema readSchema, + Record record) throws IOException; + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + @Test + public void testFullProjection() throws Exception { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "data", Types.StringType.get()) + ); + + Record record = GenericRecord.create(schema); + record.setField("id", 34L); + record.setField("data", "test"); + + Record projected = writeAndRead("full_projection", schema, schema, record); + + Assert.assertEquals("Should contain the correct id value", 34L, (long) projected.getField("id")); + + int cmp = Comparators.charSequences() + .compare("test", (CharSequence) projected.getField("data")); + Assert.assertEquals("Should contain the correct data value", 0, cmp); + } + + @Test + public void testReorderedFullProjection() throws Exception { +// Assume.assumeTrue( +// "Spark's Parquet read support does not support reordered columns", +// !format.equalsIgnoreCase("parquet")); + + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "data", Types.StringType.get()) + ); + + Record record = GenericRecord.create(schema); + record.setField("id", 34L); + record.setField("data", "test"); + + Schema reordered = new Schema( + Types.NestedField.optional(1, "data", Types.StringType.get()), + Types.NestedField.required(0, "id", Types.LongType.get()) + ); + + Record projected = writeAndRead("reordered_full_projection", schema, reordered, record); + + Assert.assertEquals("Should contain the correct 0 value", "test", projected.get(0).toString()); + Assert.assertEquals("Should contain the correct 1 value", 34L, projected.get(1)); + } + + @Test + public void testReorderedProjection() throws Exception { +// Assume.assumeTrue( +// "Spark's Parquet read support does not support reordered columns", +// !format.equalsIgnoreCase("parquet")); + + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "data", Types.StringType.get()) + ); + + Record record = GenericRecord.create(schema); + record.setField("id", 34L); + record.setField("data", "test"); + + Schema reordered = new Schema( + Types.NestedField.optional(2, "missing_1", Types.StringType.get()), + 
Types.NestedField.optional(1, "data", Types.StringType.get()), + Types.NestedField.optional(3, "missing_2", Types.LongType.get()) + ); + + Record projected = writeAndRead("reordered_projection", schema, reordered, record); + + Assert.assertNull("Should contain the correct 0 value", projected.get(0)); + Assert.assertEquals("Should contain the correct 1 value", "test", projected.get(1).toString()); + Assert.assertNull("Should contain the correct 2 value", projected.get(2)); + } + + @Test + public void testEmptyProjection() throws Exception { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "data", Types.StringType.get()) + ); + + Record record = GenericRecord.create(schema); + record.setField("id", 34L); + record.setField("data", "test"); + + Record projected = writeAndRead("empty_projection", schema, schema.select(), record); + + Assert.assertNotNull("Should read a non-null record", projected); + try { + projected.get(0); + Assert.fail("Should not retrieve value with ordinal 0"); + } catch (ArrayIndexOutOfBoundsException e) { + // this is expected because there are no values + } + } + + @Test + public void testBasicProjection() throws Exception { + Schema writeSchema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "data", Types.StringType.get()) + ); + + Record record = GenericRecord.create(writeSchema); + record.setField("id", 34L); + record.setField("data", "test"); + + Schema idOnly = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()) + ); + + Record projected = writeAndRead("basic_projection_id", writeSchema, idOnly, record); + Assert.assertNull("Should not project data", projected.getField("data")); + Assert.assertEquals("Should contain the correct id value", 34L, (long) projected.getField("id")); + + Schema dataOnly = new Schema( + Types.NestedField.optional(1, "data", Types.StringType.get()) + ); + + projected = writeAndRead("basic_projection_data", writeSchema, dataOnly, record); + + Assert.assertNull("Should not project id", projected.getField("id")); + int cmp = Comparators.charSequences() + .compare("test", (CharSequence) projected.getField("data")); + Assert.assertEquals("Should contain the correct data value", 0, cmp); + } + + @Test + public void testRename() throws Exception { + Schema writeSchema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "data", Types.StringType.get()) + ); + + Record record = GenericRecord.create(writeSchema); + record.setField("id", 34L); + record.setField("data", "test"); + + Schema readSchema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "renamed", Types.StringType.get()) + ); + + Record projected = writeAndRead("project_and_rename", writeSchema, readSchema, record); + + Assert.assertEquals("Should contain the correct id value", 34L, (long) projected.getField("id")); + int cmp = Comparators.charSequences() + .compare("test", (CharSequence) projected.getField("renamed")); + Assert.assertEquals("Should contain the correct data/renamed value", 0, cmp); + } + + @Test + public void testNestedStructProjection() throws Exception { + Schema writeSchema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(3, "location", Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()), + Types.NestedField.required(2, 
"long", Types.FloatType.get()) + )) + ); + + Record record = GenericRecord.create(writeSchema); + record.setField("id", 34L); + Record location = GenericRecord.create(writeSchema.findType("location").asStructType()); + location.setField("lat", 52.995143f); + location.setField("long", -1.539054f); + record.setField("location", location); + + Schema idOnly = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()) + ); + + Record projected = writeAndRead("id_only", writeSchema, idOnly, record); + Record projectedLocation = (Record) projected.getField("location"); + Assert.assertEquals("Should contain the correct id value", 34L, (long) projected.getField("id")); + Assert.assertNull("Should not project location", projectedLocation); + + Schema latOnly = new Schema( + Types.NestedField.optional(3, "location", Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()) + )) + ); + + projected = writeAndRead("latitude_only", writeSchema, latOnly, record); + projectedLocation = (Record) projected.getField("location"); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertNotNull("Should project location", projected.getField("location")); + Assert.assertNull("Should not project longitude", projectedLocation.getField("long")); + Assert.assertEquals("Should project latitude", + 52.995143f, (float) projectedLocation.getField("lat"), 0.000001f); + + Schema longOnly = new Schema( + Types.NestedField.optional(3, "location", Types.StructType.of( + Types.NestedField.required(2, "long", Types.FloatType.get()) + )) + ); + + projected = writeAndRead("longitude_only", writeSchema, longOnly, record); + projectedLocation = (Record) projected.getField("location"); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertNotNull("Should project location", projected.getField("location")); + Assert.assertNull("Should not project latitutde", projectedLocation.getField("lat")); + Assert.assertEquals("Should project longitude", + -1.539054f, (float) projectedLocation.getField("long"), 0.000001f); + + Schema locationOnly = writeSchema.select("location"); + projected = writeAndRead("location_only", writeSchema, locationOnly, record); + projectedLocation = (Record) projected.getField("location"); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertNotNull("Should project location", projected.getField("location")); + Assert.assertEquals("Should project latitude", + 52.995143f, (float) projectedLocation.getField("lat"), 0.000001f); + Assert.assertEquals("Should project longitude", + -1.539054f, (float) projectedLocation.getField("long"), 0.000001f); + } + + @Test + public void testMapProjection() throws IOException { + Schema writeSchema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(5, "properties", + Types.MapType.ofOptional(6, 7, Types.StringType.get(), Types.StringType.get())) + ); + + Map properties = ImmutableMap.of("a", "A", "b", "B"); + + Record record = GenericRecord.create(writeSchema); + record.setField("id", 34L); + record.setField("properties", properties); + + Schema idOnly = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()) + ); + + Record projected = writeAndRead("id_only", writeSchema, idOnly, record); + Assert.assertEquals("Should contain the correct id value", 34L, (long) projected.getField("id")); + Assert.assertNull("Should not project properties map", projected.getField("properties")); + + 
Schema keyOnly = writeSchema.select("properties.key"); + projected = writeAndRead("key_only", writeSchema, keyOnly, record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertEquals("Should project entire map", + properties, toStringMap((Map) projected.getField("properties"))); + + Schema valueOnly = writeSchema.select("properties.value"); + projected = writeAndRead("value_only", writeSchema, valueOnly, record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertEquals("Should project entire map", + properties, toStringMap((Map) projected.getField("properties"))); + + Schema mapOnly = writeSchema.select("properties"); + projected = writeAndRead("map_only", writeSchema, mapOnly, record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertEquals("Should project entire map", + properties, toStringMap((Map) projected.getField("properties"))); + } + + private Map toStringMap(Map map) { + Map stringMap = Maps.newHashMap(); + for (Map.Entry entry : map.entrySet()) { + if (entry.getValue() instanceof CharSequence) { + stringMap.put(entry.getKey().toString(), entry.getValue().toString()); + } else { + stringMap.put(entry.getKey().toString(), entry.getValue()); + } + } + return stringMap; + } + + @Test + public void testMapOfStructsProjection() throws IOException { + Schema writeSchema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(5, "locations", Types.MapType.ofOptional(6, 7, + Types.StringType.get(), + Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()), + Types.NestedField.required(2, "long", Types.FloatType.get()) + ) + )) + ); + + Record record = GenericRecord.create(writeSchema); + record.setField("id", 34L); + Record l1 = GenericRecord.create(writeSchema.findType("locations.value").asStructType()); + l1.setField("lat", 53.992811f); + l1.setField("long", -1.542616f); + Record l2 = GenericRecord.create(l1.struct()); + l2.setField("lat", 52.995143f); + l2.setField("long", -1.539054f); + record.setField("locations", ImmutableMap.of("L1", l1, "L2", l2)); + + Schema idOnly = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()) + ); + + Record projected = writeAndRead("id_only", writeSchema, idOnly, record); + Assert.assertEquals("Should contain the correct id value", 34L, (long) projected.getField("id")); + Assert.assertNull("Should not project locations map", projected.getField("locations")); + + projected = writeAndRead("all_locations", writeSchema, writeSchema.select("locations"), record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertEquals("Should project locations map", + record.getField("locations"), toStringMap((Map) projected.getField("locations"))); + + projected = writeAndRead("lat_only", + writeSchema, writeSchema.select("locations.lat"), record); + Assert.assertNull("Should not project id", projected.getField("id")); + Map locations = toStringMap((Map) projected.getField("locations")); + Assert.assertNotNull("Should project locations map", locations); + Assert.assertEquals("Should contain L1 and L2", + Sets.newHashSet("L1", "L2"), locations.keySet()); + Record projectedL1 = (Record) locations.get("L1"); + Assert.assertNotNull("L1 should not be null", projectedL1); + Assert.assertEquals("L1 should contain lat", + 53.992811f, (float) projectedL1.getField("lat"), 0.000001); + Assert.assertNull("L1 should not contain long", 
projectedL1.getField("long")); + Record projectedL2 = (Record) locations.get("L2"); + Assert.assertNotNull("L2 should not be null", projectedL2); + Assert.assertEquals("L2 should contain lat", + 52.995143f, (float) projectedL2.getField("lat"), 0.000001); + Assert.assertNull("L2 should not contain long", projectedL2.getField("long")); + + projected = writeAndRead("long_only", + writeSchema, writeSchema.select("locations.long"), record); + Assert.assertNull("Should not project id", projected.getField("id")); + locations = toStringMap((Map) projected.getField("locations")); + Assert.assertNotNull("Should project locations map", locations); + Assert.assertEquals("Should contain L1 and L2", + Sets.newHashSet("L1", "L2"), locations.keySet()); + projectedL1 = (Record) locations.get("L1"); + Assert.assertNotNull("L1 should not be null", projectedL1); + Assert.assertNull("L1 should not contain lat", projectedL1.getField("lat")); + Assert.assertEquals("L1 should contain long", + -1.542616f, (float) projectedL1.getField("long"), 0.000001); + projectedL2 = (Record) locations.get("L2"); + Assert.assertNotNull("L2 should not be null", projectedL2); + Assert.assertNull("L2 should not contain lat", projectedL2.getField("lat")); + Assert.assertEquals("L2 should contain long", + -1.539054f, (float) projectedL2.getField("long"), 0.000001); + + Schema latitiudeRenamed = new Schema( + Types.NestedField.optional(5, "locations", Types.MapType.ofOptional(6, 7, + Types.StringType.get(), + Types.StructType.of( + Types.NestedField.required(1, "latitude", Types.FloatType.get()) + ) + )) + ); + + projected = writeAndRead("latitude_renamed", writeSchema, latitiudeRenamed, record); + Assert.assertNull("Should not project id", projected.getField("id")); + locations = toStringMap((Map) projected.getField("locations")); + Assert.assertNotNull("Should project locations map", locations); + Assert.assertEquals("Should contain L1 and L2", + Sets.newHashSet("L1", "L2"), locations.keySet()); + projectedL1 = (Record) locations.get("L1"); + Assert.assertNotNull("L1 should not be null", projectedL1); + Assert.assertEquals("L1 should contain latitude", + 53.992811f, (float) projectedL1.getField("latitude"), 0.000001); + Assert.assertNull("L1 should not contain lat", projectedL1.getField("lat")); + Assert.assertNull("L1 should not contain long", projectedL1.getField("long")); + projectedL2 = (Record) locations.get("L2"); + Assert.assertNotNull("L2 should not be null", projectedL2); + Assert.assertEquals("L2 should contain latitude", + 52.995143f, (float) projectedL2.getField("latitude"), 0.000001); + Assert.assertNull("L2 should not contain lat", projectedL2.getField("lat")); + Assert.assertNull("L2 should not contain long", projectedL2.getField("long")); + } + + @Test + public void testListProjection() throws IOException { + Schema writeSchema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(10, "values", + Types.ListType.ofOptional(11, Types.LongType.get())) + ); + + List values = ImmutableList.of(56L, 57L, 58L); + + Record record = GenericRecord.create(writeSchema); + record.setField("id", 34L); + record.setField("values", values); + + Schema idOnly = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()) + ); + + Record projected = writeAndRead("id_only", writeSchema, idOnly, record); + Assert.assertEquals("Should contain the correct id value", 34L, (long) projected.getField("id")); + Assert.assertNull("Should not project values list", 
projected.getField("values")); + + Schema elementOnly = writeSchema.select("values.element"); + projected = writeAndRead("element_only", writeSchema, elementOnly, record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertEquals("Should project entire list", values, projected.getField("values")); + + Schema listOnly = writeSchema.select("values"); + projected = writeAndRead("list_only", writeSchema, listOnly, record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertEquals("Should project entire list", values, projected.getField("values")); + } + + @Test + @SuppressWarnings("unchecked") + public void testListOfStructsProjection() throws IOException { + Schema writeSchema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(22, "points", + Types.ListType.ofOptional(21, Types.StructType.of( + Types.NestedField.required(19, "x", Types.IntegerType.get()), + Types.NestedField.optional(18, "y", Types.IntegerType.get()) + )) + ) + ); + + Record record = GenericRecord.create(writeSchema); + record.setField("id", 34L); + Record p1 = GenericRecord.create(writeSchema.findType("points.element").asStructType()); + p1.setField("x", 1); + p1.setField("y", 2); + Record p2 = GenericRecord.create(p1.struct()); + p2.setField("x", 3); + p2.setField("y", null); + record.setField("points", ImmutableList.of(p1, p2)); + + Schema idOnly = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()) + ); + + Record projected = writeAndRead("id_only", writeSchema, idOnly, record); + Assert.assertEquals("Should contain the correct id value", 34L, (long) projected.getField("id")); + Assert.assertNull("Should not project points list", projected.getField("points")); + + projected = writeAndRead("all_points", writeSchema, writeSchema.select("points"), record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertEquals("Should project points list", + record.getField("points"), projected.getField("points")); + + projected = writeAndRead("x_only", writeSchema, writeSchema.select("points.x"), record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertNotNull("Should project points list", projected.getField("points")); + List points = (List) projected.getField("points"); + Assert.assertEquals("Should read 2 points", 2, points.size()); + Record projectedP1 = points.get(0); + Assert.assertEquals("Should project x", 1, (int) projectedP1.getField("x")); + Assert.assertNull("Should not project y", projectedP1.getField("y")); + Record projectedP2 = points.get(1); + Assert.assertEquals("Should project x", 3, (int) projectedP2.getField("x")); + Assert.assertNull("Should not project y", projectedP2.getField("y")); + + projected = writeAndRead("y_only", writeSchema, writeSchema.select("points.y"), record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertNotNull("Should project points list", projected.getField("points")); + points = (List) projected.getField("points"); + Assert.assertEquals("Should read 2 points", 2, points.size()); + projectedP1 = points.get(0); + Assert.assertNull("Should not project x", projectedP1.getField("x")); + Assert.assertEquals("Should project y", 2, (int) projectedP1.getField("y")); + projectedP2 = points.get(1); + Assert.assertNull("Should not project x", projectedP2.getField("x")); + Assert.assertNull("Should project null y", projectedP2.getField("y")); + + 
Schema yRenamed = new Schema( + Types.NestedField.optional(22, "points", + Types.ListType.ofOptional(21, Types.StructType.of( + Types.NestedField.optional(18, "z", Types.IntegerType.get()) + )) + ) + ); + + projected = writeAndRead("y_renamed", writeSchema, yRenamed, record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertNotNull("Should project points list", projected.getField("points")); + points = (List) projected.getField("points"); + Assert.assertEquals("Should read 2 points", 2, points.size()); + projectedP1 = points.get(0); + Assert.assertNull("Should not project x", projectedP1.getField("x")); + Assert.assertNull("Should not project y", projectedP1.getField("y")); + Assert.assertEquals("Should project z", 2, (int) projectedP1.getField("z")); + projectedP2 = points.get(1); + Assert.assertNull("Should not project x", projectedP2.getField("x")); + Assert.assertNull("Should not project y", projectedP2.getField("y")); + Assert.assertNull("Should project null z", projectedP2.getField("z")); + + Schema zAdded = new Schema( + Types.NestedField.optional(22, "points", + Types.ListType.ofOptional(21, Types.StructType.of( + Types.NestedField.required(19, "x", Types.IntegerType.get()), + Types.NestedField.optional(18, "y", Types.IntegerType.get()), + Types.NestedField.optional(20, "z", Types.IntegerType.get()) + )) + ) + ); + + projected = writeAndRead("z_added", writeSchema, zAdded, record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertNotNull("Should project points list", projected.getField("points")); + points = (List) projected.getField("points"); + Assert.assertEquals("Should read 2 points", 2, points.size()); + projectedP1 = points.get(0); + Assert.assertEquals("Should project x", 1, (int) projectedP1.getField("x")); + Assert.assertEquals("Should project y", 2, (int) projectedP1.getField("y")); + Assert.assertNull("Should contain null z", projectedP1.getField("z")); + projectedP2 = points.get(1); + Assert.assertEquals("Should project x", 3, (int) projectedP2.getField("x")); + Assert.assertNull("Should project null y", projectedP2.getField("y")); + Assert.assertNull("Should contain null z", projectedP2.getField("z")); + } + + private static org.apache.avro.Schema fromOption(org.apache.avro.Schema schema) { + Preconditions.checkArgument(schema.getType() == UNION, + "Expected union schema but was passed: %s", schema); + Preconditions.checkArgument(schema.getTypes().size() == 2, + "Expected optional schema, but was passed: %s", schema); + if (schema.getTypes().get(0).getType() == org.apache.avro.Schema.Type.NULL) { + return schema.getTypes().get(1); + } else { + return schema.getTypes().get(0); + } + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestRequiredDistributionAndOrdering.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestRequiredDistributionAndOrdering.java new file mode 100644 index 000000000000..2c64bc18c35c --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestRequiredDistributionAndOrdering.java @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.spark.SparkException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.After; +import org.junit.Test; + +public class TestRequiredDistributionAndOrdering extends SparkCatalogTestBase { + + public TestRequiredDistributionAndOrdering(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void dropTestTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testDefaultLocalSort() throws NoSuchTableException { + sql("CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (c3)", tableName); + + List data = ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A") + ); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should insert a local sort by partition columns by default + inputDF.writeTo(tableName).append(); + + assertEquals("Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @Test + public void testPartitionColumnsArePrependedForRangeDistribution() throws NoSuchTableException { + sql("CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (c3)", tableName); + + List data = ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A") + ); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + Table table = validationCatalog.loadTable(tableIdent); + + // should automatically prepend partition columns to the ordering + table.updateProperties() + .set(TableProperties.WRITE_DISTRIBUTION_MODE, TableProperties.WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + table.replaceSortOrder() + .asc("c1") + .asc("c2") + .commit(); + inputDF.writeTo(tableName).append(); + + assertEquals("Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @Test + public void 
testSortOrderIncludesPartitionColumns() throws NoSuchTableException { + sql("CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (c3)", tableName); + + List data = ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A") + ); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + Table table = validationCatalog.loadTable(tableIdent); + + // should succeed with a correct sort order + table.replaceSortOrder() + .asc("c3") + .asc("c1") + .asc("c2") + .commit(); + inputDF.writeTo(tableName).append(); + + assertEquals("Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @Test + public void testDisabledDistributionAndOrdering() { + sql("CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (c3)", tableName); + + List data = ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A") + ); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should fail if ordering is disabled + AssertHelpers.assertThrows("Should reject writes without ordering", + SparkException.class, "Writing job aborted", + () -> { + try { + inputDF.writeTo(tableName) + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false") + .append(); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + public void testHashDistribution() throws NoSuchTableException { + sql("CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (c3)", tableName); + + List data = ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A") + ); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + Table table = validationCatalog.loadTable(tableIdent); + + // should automatically prepend partition columns to the local ordering after hash distribution + table.updateProperties() + .set(TableProperties.WRITE_DISTRIBUTION_MODE, TableProperties.WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + table.replaceSortOrder() + .asc("c1") + .asc("c2") + .commit(); + inputDF.writeTo(tableName).append(); + + assertEquals("Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @Test + public void testNoSortBucketTransformsWithoutExtensions() throws NoSuchTableException { + sql("CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c1))", tableName); + + List data = ImmutableList.of( + new 
ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBB", "B"), + new ThreeColumnRecord(3, "BBBB", "B"), + new ThreeColumnRecord(4, "BBBB", "B") + ); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should fail by default as extensions are disabled + AssertHelpers.assertThrows("Should reject writes without ordering", + SparkException.class, "Writing job aborted", + () -> { + try { + inputDF.writeTo(tableName).append(); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + inputDF.writeTo(tableName) + .option(SparkWriteOptions.FANOUT_ENABLED, "true") + .append(); + + List expected = ImmutableList.of( + row(1, null, "A"), + row(2, "BBBB", "B"), + row(3, "BBBB", "B"), + row(4, "BBBB", "B") + ); + + assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName)); + } + + @Test + public void testRangeDistributionWithQuotedColumnsNames() throws NoSuchTableException { + sql("CREATE TABLE %s (c1 INT, c2 STRING, `c.3` STRING) " + + "USING iceberg " + + "PARTITIONED BY (`c.3`)", tableName); + + List data = ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A") + ); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.selectExpr("c1", "c2", "c3 as `c.3`").coalesce(1).sortWithinPartitions("c1"); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(TableProperties.WRITE_DISTRIBUTION_MODE, TableProperties.WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + table.replaceSortOrder() + .asc("c1") + .asc("c2") + .commit(); + inputDF.writeTo(tableName).append(); + + assertEquals("Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @Test + public void testHashDistributionWithQuotedColumnsNames() throws NoSuchTableException { + sql("CREATE TABLE %s (c1 INT, c2 STRING, `c``3` STRING) " + + "USING iceberg " + + "PARTITIONED BY (`c``3`)", tableName); + + List data = ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A") + ); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.selectExpr("c1", "c2", "c3 as `c``3`").coalesce(1).sortWithinPartitions("c1"); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties() + .set(TableProperties.WRITE_DISTRIBUTION_MODE, TableProperties.WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + table.replaceSortOrder() + .asc("c1") + .asc("c2") + .commit(); + inputDF.writeTo(tableName).append(); + + assertEquals("Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestRuntimeFiltering.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestRuntimeFiltering.java new file mode 100644 index 000000000000..c5127d0c636d --- /dev/null +++ 
b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestRuntimeFiltering.java @@ -0,0 +1,415 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; +import java.util.Set; +import org.apache.commons.lang3.StringUtils; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkTestBaseWithCatalog; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +public class TestRuntimeFiltering extends SparkTestBaseWithCatalog { + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS dim"); + } + + @Test + public void testIdentityPartitionedTable() throws NoSuchTableException { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", tableName); + + Dataset df = spark.range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = spark.range(1, 10) + .withColumn("date", expr("DATE '1970-01-02'")) + .select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.date = d.date AND d.id = 1 ORDER BY id", + tableName); + + assertQueryContainsRuntimeFilter(query); + + deleteNotMatchingFiles(Expressions.equal("date", 1), 3); + + assertEquals("Should have expected rows", + sql("SELECT * FROM %s WHERE date = DATE '1970-01-02' ORDER BY id", tableName), + sql(query)); + } + + @Test + public void testBucketedTable() throws NoSuchTableException { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (bucket(8, id))", tableName); + + Dataset df = 
spark.range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = spark.range(1, 2) + .withColumn("date", expr("DATE '1970-01-02'")) + .select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.id = d.id AND d.date = DATE '1970-01-02' ORDER BY date", + tableName); + + assertQueryContainsRuntimeFilter(query); + + deleteNotMatchingFiles(Expressions.equal("id", 1), 7); + + assertEquals("Should have expected rows", + sql("SELECT * FROM %s WHERE id = 1 ORDER BY date", tableName), + sql(query)); + } + + @Test + public void testRenamedSourceColumnTable() throws NoSuchTableException { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (bucket(8, id))", tableName); + + Dataset df = spark.range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = spark.range(1, 2) + .withColumn("date", expr("DATE '1970-01-02'")) + .select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + sql("ALTER TABLE %s RENAME COLUMN id TO row_id", tableName); + + String query = String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.row_id = d.id AND d.date = DATE '1970-01-02' ORDER BY date", + tableName); + + assertQueryContainsRuntimeFilter(query); + + deleteNotMatchingFiles(Expressions.equal("row_id", 1), 7); + + assertEquals("Should have expected rows", + sql("SELECT * FROM %s WHERE row_id = 1 ORDER BY date", tableName), + sql(query)); + } + + @Test + public void testMultipleRuntimeFilters() throws NoSuchTableException { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (data, bucket(8, id))", tableName); + + Dataset df = spark.range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE, data STRING) USING parquet"); + Dataset dimDF = spark.range(1, 2) + .withColumn("date", expr("DATE '1970-01-02'")) + .withColumn("data", expr("'1970-01-02'")) + .select("id", "date", "data"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.id = d.id AND f.data = d.data AND d.date = DATE '1970-01-02'", + tableName); + + assertQueryContainsRuntimeFilters(query, 2, "Query should have 2 runtime filters"); + + deleteNotMatchingFiles(Expressions.equal("id", 1), 31); + + assertEquals("Should have expected rows", + sql("SELECT * FROM %s WHERE id = 1 AND data = '1970-01-02'", 
tableName), + sql(query)); + } + + @Test + public void testCaseSensitivityOfRuntimeFilters() throws NoSuchTableException { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (data, bucket(8, id))", tableName); + + Dataset df = spark.range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE, data STRING) USING parquet"); + Dataset dimDF = spark.range(1, 2) + .withColumn("date", expr("DATE '1970-01-02'")) + .withColumn("data", expr("'1970-01-02'")) + .select("id", "date", "data"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String caseInsensitiveQuery = String.format( + "select f.* from %s F join dim d ON f.Id = d.iD and f.DaTa = d.dAtA and d.dAtE = date '1970-01-02'", + tableName); + + assertQueryContainsRuntimeFilters(caseInsensitiveQuery, 2, "Query should have 2 runtime filters"); + + deleteNotMatchingFiles(Expressions.equal("id", 1), 31); + + assertEquals("Should have expected rows", + sql("SELECT * FROM %s WHERE id = 1 AND data = '1970-01-02'", tableName), + sql(caseInsensitiveQuery)); + } + + @Test + public void testBucketedTableWithMultipleSpecs() throws NoSuchTableException { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) USING iceberg", tableName); + + Dataset df1 = spark.range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 2 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df1.coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + table.updateSpec() + .addField(Expressions.bucket("id", 8)) + .commit(); + + sql("REFRESH TABLE %s", tableName); + + Dataset df2 = spark.range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df2.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = spark.range(1, 2) + .withColumn("date", expr("DATE '1970-01-02'")) + .select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.id = d.id AND d.date = DATE '1970-01-02' ORDER BY date", + tableName); + + assertQueryContainsRuntimeFilter(query); + + deleteNotMatchingFiles(Expressions.equal("id", 1), 7); + + assertEquals("Should have expected rows", + sql("SELECT * FROM %s WHERE id = 1 ORDER BY date", tableName), + sql(query)); + } + + @Test + public void testSourceColumnWithDots() throws NoSuchTableException { + sql("CREATE TABLE %s (`i.d` BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (bucket(8, `i.d`))", tableName); + + Dataset df = spark.range(1, 100) + .withColumnRenamed("id", "i.d") + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(`i.d` % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", 
expr("CAST(date AS STRING)")) + .select("`i.d`", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("SELECT * FROM %s WHERE `i.d` = 1", tableName); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = spark.range(1, 2) + .withColumn("date", expr("DATE '1970-01-02'")) + .select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.`i.d` = d.id AND d.date = DATE '1970-01-02' ORDER BY date", + tableName); + + assertQueryContainsRuntimeFilter(query); + + deleteNotMatchingFiles(Expressions.equal("i.d", 1), 7); + + sql(query); + + assertEquals("Should have expected rows", + sql("SELECT * FROM %s WHERE `i.d` = 1 ORDER BY date", tableName), + sql(query)); + } + + @Test + public void testSourceColumnWithBackticks() throws NoSuchTableException { + sql("CREATE TABLE %s (`i``d` BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (bucket(8, `i``d`))", tableName); + + Dataset df = spark.range(1, 100) + .withColumnRenamed("id", "i`d") + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(`i``d` % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("`i``d`", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = spark.range(1, 2) + .withColumn("date", expr("DATE '1970-01-02'")) + .select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.`i``d` = d.id AND d.date = DATE '1970-01-02' ORDER BY date", + tableName); + + assertQueryContainsRuntimeFilter(query); + + deleteNotMatchingFiles(Expressions.equal("i`d", 1), 7); + + assertEquals("Should have expected rows", + sql("SELECT * FROM %s WHERE `i``d` = 1 ORDER BY date", tableName), + sql(query)); + } + + @Test + public void testUnpartitionedTable() throws NoSuchTableException { + sql("CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) USING iceberg", tableName); + + Dataset df = spark.range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = spark.range(1, 2) + .withColumn("date", expr("DATE '1970-01-02'")) + .select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.id = d.id AND d.date = DATE '1970-01-02' ORDER BY date", + tableName); + + assertQueryContainsNoRuntimeFilter(query); + + assertEquals("Should have expected rows", + sql("SELECT * FROM %s WHERE id = 1 ORDER BY date", tableName), + sql(query)); + } + + private void assertQueryContainsRuntimeFilter(String query) { + assertQueryContainsRuntimeFilters(query, 1, "Query should have 1 runtime filter"); + } + + private void assertQueryContainsNoRuntimeFilter(String query) { + assertQueryContainsRuntimeFilters(query, 0, "Query should have no runtime filters"); + } + + private void 
assertQueryContainsRuntimeFilters(String query, int expectedFilterCount, String errorMessage) { + List output = spark.sql("EXPLAIN EXTENDED " + query).collectAsList(); + String plan = output.get(0).getString(0); + int actualFilterCount = StringUtils.countMatches(plan, "dynamicpruningexpression"); + Assert.assertEquals(errorMessage, expectedFilterCount, actualFilterCount); + } + + // delete files that don't match the filter to ensure dynamic filtering works and only required files are read + private void deleteNotMatchingFiles(Expression filter, int expectedDeletedFileCount) { + Table table = validationCatalog.loadTable(tableIdent); + FileIO io = table.io(); + + Set matchingFileLocations = Sets.newHashSet(); + try (CloseableIterable files = table.newScan().filter(filter).planFiles()) { + for (FileScanTask file : files) { + String path = file.file().path().toString(); + matchingFileLocations.add(path); + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + + Set deletedFileLocations = Sets.newHashSet(); + try (CloseableIterable files = table.newScan().planFiles()) { + for (FileScanTask file : files) { + String path = file.file().path().toString(); + if (!matchingFileLocations.contains(path)) { + io.deleteFile(path); + deletedFileLocations.add(path); + } + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + + Assert.assertEquals("Deleted unexpected number of files", expectedDeletedFileCount, deletedFileLocations.size()); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSnapshotSelection.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSnapshotSelection.java new file mode 100644 index 000000000000..6da76b0b4211 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSnapshotSelection.java @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.List; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.types.Types.NestedField.optional; + +public class TestSnapshotSelection { + + private static final Configuration CONF = new Configuration(); + private static final Schema SCHEMA = new Schema( + optional(1, "id", Types.IntegerType.get()), + optional(2, "data", Types.StringType.get()) + ); + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private static SparkSession spark = null; + + @BeforeClass + public static void startSpark() { + TestSnapshotSelection.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestSnapshotSelection.spark; + TestSnapshotSelection.spark = null; + currentSpark.stop(); + } + + @Test + public void testSnapshotSelectionById() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, tableLocation); + + // produce the first snapshot + List firstBatchRecords = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + // produce the second snapshot + List secondBatchRecords = Lists.newArrayList( + new SimpleRecord(4, "d"), + new SimpleRecord(5, "e"), + new SimpleRecord(6, "f") + ); + Dataset secondDf = spark.createDataFrame(secondBatchRecords, SimpleRecord.class); + secondDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + Assert.assertEquals("Expected 2 snapshots", 2, Iterables.size(table.snapshots())); + + // verify records in the current snapshot + Dataset currentSnapshotResult = spark.read() + .format("iceberg") + .load(tableLocation); + List currentSnapshotRecords = currentSnapshotResult.orderBy("id") + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(firstBatchRecords); + expectedRecords.addAll(secondBatchRecords); + Assert.assertEquals("Current snapshot rows should match", expectedRecords, currentSnapshotRecords); + + // verify records in the previous snapshot + Snapshot currentSnapshot = table.currentSnapshot(); + Long parentSnapshotId = currentSnapshot.parentId(); + Dataset previousSnapshotResult = spark.read() + .format("iceberg") + .option("snapshot-id", parentSnapshotId) + .load(tableLocation); + 
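// reading with snapshot-id set to the parent snapshot should return only the first batch of records +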
List previousSnapshotRecords = previousSnapshotResult.orderBy("id") + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + Assert.assertEquals("Previous snapshot rows should match", firstBatchRecords, previousSnapshotRecords); + } + + @Test + public void testSnapshotSelectionByTimestamp() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, tableLocation); + + // produce the first snapshot + List firstBatchRecords = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + // remember the time when the first snapshot was valid + long firstSnapshotTimestamp = System.currentTimeMillis(); + + // produce the second snapshot + List secondBatchRecords = Lists.newArrayList( + new SimpleRecord(4, "d"), + new SimpleRecord(5, "e"), + new SimpleRecord(6, "f") + ); + Dataset secondDf = spark.createDataFrame(secondBatchRecords, SimpleRecord.class); + secondDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + Assert.assertEquals("Expected 2 snapshots", 2, Iterables.size(table.snapshots())); + + // verify records in the current snapshot + Dataset currentSnapshotResult = spark.read() + .format("iceberg") + .load(tableLocation); + List currentSnapshotRecords = currentSnapshotResult.orderBy("id") + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(firstBatchRecords); + expectedRecords.addAll(secondBatchRecords); + Assert.assertEquals("Current snapshot rows should match", expectedRecords, currentSnapshotRecords); + + // verify records in the previous snapshot + Dataset previousSnapshotResult = spark.read() + .format("iceberg") + .option(SparkReadOptions.AS_OF_TIMESTAMP, firstSnapshotTimestamp) + .load(tableLocation); + List previousSnapshotRecords = previousSnapshotResult.orderBy("id") + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + Assert.assertEquals("Previous snapshot rows should match", firstBatchRecords, previousSnapshotRecords); + } + + @Test(expected = IllegalArgumentException.class) + public void testSnapshotSelectionByInvalidSnapshotId() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + tables.create(SCHEMA, spec, tableLocation); + + Dataset df = spark.read() + .format("iceberg") + .option("snapshot-id", -10) + .load(tableLocation); + + df.collectAsList(); + } + + @Test(expected = IllegalArgumentException.class) + public void testSnapshotSelectionByInvalidTimestamp() throws IOException { + long timestamp = System.currentTimeMillis(); + + String tableLocation = temp.newFolder("iceberg-table").toString(); + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + tables.create(SCHEMA, spec, tableLocation); + + Dataset df = spark.read() + .format("iceberg") + .option(SparkReadOptions.AS_OF_TIMESTAMP, timestamp) + .load(tableLocation); + + df.collectAsList(); + } + + @Test(expected = IllegalArgumentException.class) + public void 
testSnapshotSelectionBySnapshotIdAndTimestamp() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, tableLocation); + + List firstBatchRecords = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + long timestamp = System.currentTimeMillis(); + long snapshotId = table.currentSnapshot().snapshotId(); + Dataset df = spark.read() + .format("iceberg") + .option(SparkReadOptions.SNAPSHOT_ID, snapshotId) + .option(SparkReadOptions.AS_OF_TIMESTAMP, timestamp) + .load(tableLocation); + + df.collectAsList(); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAppenderFactory.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAppenderFactory.java new file mode 100644 index 000000000000..bda525780d8b --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAppenderFactory.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.util.List; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileAppenderFactory; +import org.apache.iceberg.io.TestAppenderFactory; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.util.ArrayUtil; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +public class TestSparkAppenderFactory extends TestAppenderFactory { + + private final StructType sparkType; + + public TestSparkAppenderFactory(String fileFormat, boolean partitioned) { + super(fileFormat, partitioned); + this.sparkType = SparkSchemaUtil.convert(SCHEMA); + } + + @Override + protected FileAppenderFactory createAppenderFactory(List equalityFieldIds, + Schema eqDeleteSchema, + Schema posDeleteRowSchema) { + return SparkAppenderFactory.builderFor(table, table.schema(), sparkType) + .equalityFieldIds(ArrayUtil.toIntArray(equalityFieldIds)) + .eqDeleteRowSchema(eqDeleteSchema) + .posDelRowSchema(posDeleteRowSchema).build(); + } + + @Override + protected InternalRow createRow(Integer id, String data) { + InternalRow row = new GenericInternalRow(2); + row.update(0, id); + row.update(1, UTF8String.fromString(data)); + return row; + } + + @Override + protected StructLikeSet expectedRowSet(Iterable rows) { + StructLikeSet set = StructLikeSet.create(table.schema().asStruct()); + for (InternalRow row : rows) { + InternalRowWrapper wrapper = new InternalRowWrapper(sparkType); + set.add(wrapper.wrap(row)); + } + return set; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkBaseDataReader.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkBaseDataReader.java new file mode 100644 index 000000000000..870be890da90 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkBaseDataReader.java @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.StreamSupport; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.BaseCombinedScanTask; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.types.Types; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.FileFormat.PARQUET; +import static org.apache.iceberg.Files.localOutput; + +public class TestSparkBaseDataReader { + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private Table table; + + // Simulates the closeable iterator of data to be read + private static class CloseableIntegerRange implements CloseableIterator { + boolean closed; + Iterator iter; + + CloseableIntegerRange(long range) { + this.closed = false; + this.iter = IntStream.range(0, (int) range).iterator(); + } + + @Override + public void close() { + this.closed = true; + } + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public Integer next() { + return iter.next(); + } + } + + // Main reader class to test base class iteration logic. + // Keeps track of iterator closure. 
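+ // Each open() call returns a CloseableIntegerRange and registers it under the task's file path, + // so tests can check whether BaseDataReader closed the iterator for a given task.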
+ private static class ClosureTrackingReader extends BaseDataReader { + private Map tracker = Maps.newHashMap(); + + ClosureTrackingReader(Table table, List tasks) { + super(table, new BaseCombinedScanTask(tasks)); + } + + @Override + CloseableIterator open(FileScanTask task) { + CloseableIntegerRange intRange = new CloseableIntegerRange(task.file().recordCount()); + tracker.put(getKey(task), intRange); + return intRange; + } + + public Boolean isIteratorClosed(FileScanTask task) { + return tracker.get(getKey(task)).closed; + } + + public Boolean hasIterator(FileScanTask task) { + return tracker.containsKey(getKey(task)); + } + + private String getKey(FileScanTask task) { + return task.file().path().toString(); + } + } + + @Test + public void testClosureOnDataExhaustion() throws IOException { + Integer totalTasks = 10; + Integer recordPerTask = 10; + List tasks = createFileScanTasks(totalTasks, recordPerTask); + + ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks); + + int countRecords = 0; + while (reader.next()) { + countRecords += 1; + Assert.assertNotNull("Reader should return non-null value", reader.get()); + } + + Assert.assertEquals("Reader returned incorrect number of records", + totalTasks * recordPerTask, + countRecords + ); + tasks.forEach(t -> + Assert.assertTrue("All iterators should be closed after read exhaustion", + reader.isIteratorClosed(t)) + ); + } + + @Test + public void testClosureDuringIteration() throws IOException { + Integer totalTasks = 2; + Integer recordPerTask = 1; + List tasks = createFileScanTasks(totalTasks, recordPerTask); + Assert.assertEquals(2, tasks.size()); + FileScanTask firstTask = tasks.get(0); + FileScanTask secondTask = tasks.get(1); + + ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks); + + // Total of 2 elements + Assert.assertTrue(reader.next()); + Assert.assertFalse("First iter should not be closed on its last element", + reader.isIteratorClosed(firstTask)); + + Assert.assertTrue(reader.next()); + Assert.assertTrue("First iter should be closed after moving to second iter", + reader.isIteratorClosed(firstTask)); + Assert.assertFalse("Second iter should not be closed on its last element", + reader.isIteratorClosed(secondTask)); + + Assert.assertFalse(reader.next()); + Assert.assertTrue(reader.isIteratorClosed(firstTask)); + Assert.assertTrue(reader.isIteratorClosed(secondTask)); + } + + @Test + public void testClosureWithoutAnyRead() throws IOException { + Integer totalTasks = 10; + Integer recordPerTask = 10; + List tasks = createFileScanTasks(totalTasks, recordPerTask); + + ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks); + + reader.close(); + + tasks.forEach(t -> + Assert.assertFalse("Iterator should not be created eagerly for unread tasks", + reader.hasIterator(t)) + ); + } + + @Test + public void testExplicitClosure() throws IOException { + Integer totalTasks = 10; + Integer recordPerTask = 10; + List tasks = createFileScanTasks(totalTasks, recordPerTask); + + ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks); + + Integer halfDataSize = (totalTasks * recordPerTask) / 2; + for (int i = 0; i < halfDataSize; i++) { + Assert.assertTrue("Reader should have some element", reader.next()); + Assert.assertNotNull("Reader should return non-null value", reader.get()); + } + + reader.close(); + + // Some tasks might not have been opened yet, so there is no corresponding tracker for them. + // But all iterators that have been created must be closed.
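+ // hasIterator(t) filters out tasks whose iterators were never created because close() was called halfway through the data.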
+ tasks.forEach(t -> { + if (reader.hasIterator(t)) { + Assert.assertTrue("Iterator should be closed after read exhaustion", + reader.isIteratorClosed(t)); + } + }); + } + + @Test + public void testIdempotentExplicitClosure() throws IOException { + Integer totalTasks = 10; + Integer recordPerTask = 10; + List tasks = createFileScanTasks(totalTasks, recordPerTask); + + ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks); + + // Total of 100 elements; only 5 iterators have been created + for (int i = 0; i < 45; i++) { + Assert.assertTrue("Reader should have some element", reader.next()); + Assert.assertNotNull("Reader should return non-null value", reader.get()); + } + + for (int closeAttempt = 0; closeAttempt < 5; closeAttempt++) { + reader.close(); + for (int i = 0; i < 5; i++) { + Assert.assertTrue("Iterator should be closed after read exhaustion", + reader.isIteratorClosed(tasks.get(i))); + } + for (int i = 5; i < 10; i++) { + Assert.assertFalse("Iterator should not be created eagerly for unread tasks", + reader.hasIterator(tasks.get(i))); + } + } + } + + private List createFileScanTasks(Integer totalTasks, Integer recordPerTask) throws IOException { + String desc = "make_scan_tasks"; + File parent = temp.newFolder(desc); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + Assert.assertTrue("mkdirs should succeed", dataFolder.mkdirs()); + + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()) + ); + + try { + this.table = TestTables.create(location, desc, schema, PartitionSpec.unpartitioned()); + // Important: use the table's schema for the rest of the test + // When tables are created, the column ids are reassigned. + Schema tableSchema = table.schema(); + List expected = RandomData.generateList(tableSchema, recordPerTask, 1L); + + AppendFiles appendFiles = table.newAppend(); + for (int i = 0; i < totalTasks; i++) { + File parquetFile = new File(dataFolder, PARQUET.addExtension(UUID.randomUUID().toString())); + try (FileAppender writer = Parquet.write(localOutput(parquetFile)) + .schema(tableSchema) + .build()) { + writer.addAll(expected); + } + DataFile file = DataFiles.builder(PartitionSpec.unpartitioned()) + .withFileSizeInBytes(parquetFile.length()) + .withPath(parquetFile.toString()) + .withRecordCount(recordPerTask) + .build(); + appendFiles.appendFile(file); + } + appendFiles.commit(); + + return StreamSupport + .stream(table.newScan().planFiles().spliterator(), false) + .collect(Collectors.toList()); + } finally { + TestTables.clearTables(); + } + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalog.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalog.java new file mode 100644 index 000000000000..ef7b80a4ba6f --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalog.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.SupportsNamespaces; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCatalog; + +public class TestSparkCatalog extends SparkSessionCatalog { + + private static final Map tableMap = Maps.newHashMap(); + + public static void setTable(Identifier ident, Table table) { + Preconditions.checkArgument(!tableMap.containsKey(ident), "Cannot set " + ident + ". It is already set"); + tableMap.put(ident, table); + } + + @Override + public Table loadTable(Identifier ident) throws NoSuchTableException { + if (tableMap.containsKey(ident)) { + return tableMap.get(ident); + } + + TableIdentifier tableIdentifier = Spark3Util.identifierToTableIdentifier(ident); + Namespace namespace = tableIdentifier.namespace(); + + TestTables.TestTable table = TestTables.load(tableIdentifier.toString()); + if (table == null && namespace.equals(Namespace.of("default"))) { + table = TestTables.load(tableIdentifier.name()); + } + + return new SparkTable(table, false); + } + + public static void clearTables() { + tableMap.clear(); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalogCacheExpiration.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalogCacheExpiration.java new file mode 100644 index 000000000000..96aeed65bfa7 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalogCacheExpiration.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.iceberg.CachingCatalog; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.SparkTestBaseWithCatalog; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.assertj.core.api.Assertions; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestSparkCatalogCacheExpiration extends SparkTestBaseWithCatalog { + + private static final String sessionCatalogName = "spark_catalog"; + private static final String sessionCatalogImpl = SparkSessionCatalog.class.getName(); + private static final Map sessionCatalogConfig = ImmutableMap.of( + "type", "hadoop", + "default-namespace", "default", + CatalogProperties.CACHE_ENABLED, "true", + CatalogProperties.CACHE_EXPIRATION_INTERVAL_MS, "3000" + ); + + private static String asSqlConfCatalogKeyFor(String catalog, String configKey) { + // configKey is empty when the catalog's class is being defined + if (configKey.isEmpty()) { + return String.format("spark.sql.catalog.%s", catalog); + } else { + return String.format("spark.sql.catalog.%s.%s", catalog, configKey); + } + } + + // Add more catalogs to the spark session, so we only need to start spark one time for multiple + // different catalog configuration tests. + @BeforeClass + public static void beforeClass() { + // Catalog - expiration_disabled: Catalog with caching on and expiration disabled. + ImmutableMap.of( + "", "org.apache.iceberg.spark.SparkCatalog", + "type", "hive", + CatalogProperties.CACHE_ENABLED, "true", + CatalogProperties.CACHE_EXPIRATION_INTERVAL_MS, "-1" + ).forEach((k, v) -> spark.conf().set(asSqlConfCatalogKeyFor("expiration_disabled", k), v)); + + // Catalog - cache_disabled_implicitly: Catalog that does not cache, as the cache expiration interval is 0. 
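+ // An expiration interval of 0 means the catalog cache is disabled outright, which testCacheDisabledImplicitly asserts below.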
+ ImmutableMap.of( + "", "org.apache.iceberg.spark.SparkCatalog", + "type", "hive", + CatalogProperties.CACHE_ENABLED, "true", + CatalogProperties.CACHE_EXPIRATION_INTERVAL_MS, "0" + ).forEach((k, v) -> spark.conf().set(asSqlConfCatalogKeyFor("cache_disabled_implicitly", k), v)); + } + + public TestSparkCatalogCacheExpiration() { + super(sessionCatalogName, sessionCatalogImpl, sessionCatalogConfig); + } + + @Test + public void testSparkSessionCatalogWithExpirationEnabled() { + SparkSessionCatalog sparkCatalog = sparkSessionCatalog(); + Assertions.assertThat(sparkCatalog) + .extracting("icebergCatalog") + .extracting("cacheEnabled") + .isEqualTo(true); + + Assertions + .assertThat(sparkCatalog) + .extracting("icebergCatalog") + .extracting("icebergCatalog") + .isInstanceOfSatisfying(Catalog.class, icebergCatalog -> { + Assertions.assertThat(icebergCatalog) + .isExactlyInstanceOf(CachingCatalog.class) + .extracting("expirationIntervalMillis") + .isEqualTo(3000L); + }); + } + + @Test + public void testCacheEnabledAndExpirationDisabled() { + SparkCatalog sparkCatalog = getSparkCatalog("expiration_disabled"); + Assertions.assertThat(sparkCatalog) + .extracting("cacheEnabled") + .isEqualTo(true); + + Assertions + .assertThat(sparkCatalog) + .extracting("icebergCatalog") + .isInstanceOfSatisfying(CachingCatalog.class, icebergCatalog -> { + Assertions.assertThat(icebergCatalog) + .extracting("expirationIntervalMillis") + .isEqualTo(-1L); + }); + } + + @Test + public void testCacheDisabledImplicitly() { + SparkCatalog sparkCatalog = getSparkCatalog("cache_disabled_implicitly"); + Assertions.assertThat(sparkCatalog) + .extracting("cacheEnabled") + .isEqualTo(false); + + Assertions + .assertThat(sparkCatalog) + .extracting("icebergCatalog") + .isInstanceOfSatisfying( + Catalog.class, + icebergCatalog -> Assertions.assertThat(icebergCatalog).isNotInstanceOf(CachingCatalog.class)); + } + + private SparkSessionCatalog sparkSessionCatalog() { + TableCatalog catalog = (TableCatalog) spark.sessionState().catalogManager().catalog("spark_catalog"); + return (SparkSessionCatalog) catalog; + } + + private SparkCatalog getSparkCatalog(String catalog) { + return (SparkCatalog) spark.sessionState().catalogManager().catalog(catalog); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalogHadoopOverrides.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalogHadoopOverrides.java new file mode 100644 index 000000000000..fba6c2686096 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalogHadoopOverrides.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.hadoop.conf.Configurable; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.KryoHelpers; +import org.apache.iceberg.SerializableTable; +import org.apache.iceberg.Table; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runners.Parameterized; + + +public class TestSparkCatalogHadoopOverrides extends SparkCatalogTestBase { + + private static final String configToOverride = "fs.s3a.buffer.dir"; + // prepend "hadoop." so that the test base formats SQLConf correctly + // as `spark.sql.catalogs..hadoop. + private static final String hadoopPrefixedConfigToOverride = "hadoop." + configToOverride; + private static final String configOverrideValue = "/tmp-overridden"; + + @Parameterized.Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][] { + { "testhive", SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + hadoopPrefixedConfigToOverride, configOverrideValue + ) }, + { "testhadoop", SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hadoop", + hadoopPrefixedConfigToOverride, configOverrideValue + ) }, + { "spark_catalog", SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + hadoopPrefixedConfigToOverride, configOverrideValue + ) } + }; + } + + public TestSparkCatalogHadoopOverrides(String catalogName, + String implementation, + Map config) { + super(catalogName, implementation, config); + } + + @Before + public void createTable() { + sql("CREATE TABLE IF NOT EXISTS %s (id bigint) USING iceberg", tableName(tableIdent.name())); + } + + @After + public void dropTable() { + sql("DROP TABLE IF EXISTS %s", tableName(tableIdent.name())); + } + + @Test + public void testTableFromCatalogHasOverrides() throws Exception { + Table table = getIcebergTableFromSparkCatalog(); + Configuration conf = ((Configurable) table.io()).getConf(); + String actualCatalogOverride = conf.get(configToOverride, "/whammies"); + Assert.assertEquals( + "Iceberg tables from spark should have the overridden hadoop configurations from the spark config", + configOverrideValue, actualCatalogOverride); + } + + @Test + public void ensureRoundTripSerializedTableRetainsHadoopConfig() throws Exception { + Table table = getIcebergTableFromSparkCatalog(); + Configuration originalConf = ((Configurable) table.io()).getConf(); + String actualCatalogOverride = originalConf.get(configToOverride, "/whammies"); + Assert.assertEquals( + "Iceberg tables from spark should have the overridden hadoop configurations from the spark config", + configOverrideValue, actualCatalogOverride); + + // Now convert to SerializableTable and ensure overridden property is still present. 
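+ // Both Kryo and plain Java serialization round trips are checked below against the same overridden value.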
+ Table serializableTable = SerializableTable.copyOf(table); + Table kryoSerializedTable = KryoHelpers.roundTripSerialize(SerializableTable.copyOf(table)); + Configuration configFromKryoSerde = ((Configurable) kryoSerializedTable.io()).getConf(); + String kryoSerializedCatalogOverride = configFromKryoSerde.get(configToOverride, "/whammies"); + Assert.assertEquals( + "Tables serialized with Kryo serialization should retain overridden hadoop configuration properties", + configOverrideValue, kryoSerializedCatalogOverride); + + // Do the same for Java based serde + Table javaSerializedTable = TestHelpers.roundTripSerialize(serializableTable); + Configuration configFromJavaSerde = ((Configurable) javaSerializedTable.io()).getConf(); + String javaSerializedCatalogOverride = configFromJavaSerde.get(configToOverride, "/whammies"); + Assert.assertEquals( + "Tables serialized with Java serialization should retain overridden hadoop configuration properties", + configOverrideValue, javaSerializedCatalogOverride); + } + + @SuppressWarnings("ThrowSpecificity") + private Table getIcebergTableFromSparkCatalog() throws Exception { + Identifier identifier = Identifier.of(tableIdent.namespace().levels(), tableIdent.name()); + TableCatalog catalog = (TableCatalog) spark.sessionState().catalogManager().catalog(catalogName); + SparkTable sparkTable = (SparkTable) catalog.loadTable(identifier); + return sparkTable.table(); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataFile.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataFile.java new file mode 100644 index 000000000000..cd1404766d46 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataFile.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.ManifestReader; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkDataFile; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.types.Types; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.ColumnName; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +public class TestSparkDataFile { + + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + private static final Schema SCHEMA = new Schema( + required(100, "id", Types.LongType.get()), + optional(101, "data", Types.StringType.get()), + required(102, "b", Types.BooleanType.get()), + optional(103, "i", Types.IntegerType.get()), + required(104, "l", Types.LongType.get()), + optional(105, "f", Types.FloatType.get()), + required(106, "d", Types.DoubleType.get()), + optional(107, "date", Types.DateType.get()), + required(108, "ts", Types.TimestampType.withZone()), + required(110, "s", Types.StringType.get()), + optional(113, "bytes", Types.BinaryType.get()), + required(114, "dec_9_0", Types.DecimalType.of(9, 0)), + required(115, "dec_11_2", Types.DecimalType.of(11, 2)), + required(116, "dec_38_10", Types.DecimalType.of(38, 10)) // maximum precision + ); + private static final PartitionSpec SPEC = PartitionSpec.builderFor(SCHEMA) + .identity("b") + .bucket("i", 2) + .identity("l") + .identity("f") + .identity("d") + .identity("date") + .hour("ts") + .identity("ts") + .truncate("s", 2) + .identity("bytes") + .bucket("dec_9_0", 2) + .bucket("dec_11_2", 2) + .bucket("dec_38_10", 2) + .build(); + + private static SparkSession spark; + private static JavaSparkContext sparkContext = null; + + @BeforeClass + public static void startSpark() { + TestSparkDataFile.spark = SparkSession.builder().master("local[2]").getOrCreate(); + TestSparkDataFile.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestSparkDataFile.spark; + TestSparkDataFile.spark 
= null; + TestSparkDataFile.sparkContext = null; + currentSpark.stop(); + } + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + private String tableLocation = null; + + @Before + public void setupTableLocation() throws Exception { + File tableDir = temp.newFolder(); + this.tableLocation = tableDir.toURI().toString(); + } + + @Test + public void testValueConversion() throws IOException { + Table table = TABLES.create(SCHEMA, PartitionSpec.unpartitioned(), Maps.newHashMap(), tableLocation); + checkSparkDataFile(table); + } + + @Test + public void testValueConversionPartitionedTable() throws IOException { + Table table = TABLES.create(SCHEMA, SPEC, Maps.newHashMap(), tableLocation); + checkSparkDataFile(table); + } + + @Test + public void testValueConversionWithEmptyStats() throws IOException { + Map props = Maps.newHashMap(); + props.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "none"); + Table table = TABLES.create(SCHEMA, SPEC, props, tableLocation); + checkSparkDataFile(table); + } + + private void checkSparkDataFile(Table table) throws IOException { + Iterable rows = RandomData.generateSpark(table.schema(), 200, 0); + JavaRDD rdd = sparkContext.parallelize(Lists.newArrayList(rows)); + Dataset df = spark.internalCreateDataFrame(JavaRDD.toRDD(rdd), SparkSchemaUtil.convert(table.schema()), false); + + df.write().format("iceberg").mode("append").save(tableLocation); + + table.refresh(); + + List manifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 1 manifest", 1, manifests.size()); + + List dataFiles = Lists.newArrayList(); + try (ManifestReader reader = ManifestFiles.read(manifests.get(0), table.io())) { + for (DataFile dataFile : reader) { + checkDataFile(dataFile.copy(), DataFiles.builder(table.spec()).copy(dataFile).build()); + dataFiles.add(dataFile.copy()); + } + } + + Dataset dataFileDF = spark.read().format("iceberg").load(tableLocation + "#files"); + + // reorder columns to test arbitrary projections + List columns = Arrays.stream(dataFileDF.columns()) + .map(ColumnName::new) + .collect(Collectors.toList()); + Collections.shuffle(columns); + + List sparkDataFiles = dataFileDF + .select(Iterables.toArray(columns, Column.class)) + .collectAsList(); + + Assert.assertEquals("The number of files should match", dataFiles.size(), sparkDataFiles.size()); + + Types.StructType dataFileType = DataFile.getType(table.spec().partitionType()); + StructType sparkDataFileType = sparkDataFiles.get(0).schema(); + SparkDataFile wrapper = new SparkDataFile(dataFileType, sparkDataFileType); + + for (int i = 0; i < dataFiles.size(); i++) { + checkDataFile(dataFiles.get(i), wrapper.wrap(sparkDataFiles.get(i))); + } + } + + private void checkDataFile(DataFile expected, DataFile actual) { + Assert.assertEquals("Path must match", expected.path(), actual.path()); + Assert.assertEquals("Format must match", expected.format(), actual.format()); + Assert.assertEquals("Record count must match", expected.recordCount(), actual.recordCount()); + Assert.assertEquals("Size must match", expected.fileSizeInBytes(), actual.fileSizeInBytes()); + Assert.assertEquals("Record value counts must match", expected.valueCounts(), actual.valueCounts()); + Assert.assertEquals("Record null value counts must match", expected.nullValueCounts(), actual.nullValueCounts()); + Assert.assertEquals("Record nan value counts must match", expected.nanValueCounts(), actual.nanValueCounts()); + Assert.assertEquals("Lower bounds must match", expected.lowerBounds(), actual.lowerBounds()); 
+ Assert.assertEquals("Upper bounds must match", expected.upperBounds(), actual.upperBounds()); + Assert.assertEquals("Key metadata must match", expected.keyMetadata(), actual.keyMetadata()); + Assert.assertEquals("Split offsets must match", expected.splitOffsets(), actual.splitOffsets()); + Assert.assertEquals("Sort order id must match", expected.sortOrderId(), actual.sortOrderId()); + + checkStructLike(expected.partition(), actual.partition()); + } + + private void checkStructLike(StructLike expected, StructLike actual) { + Assert.assertEquals("Struct size should match", expected.size(), actual.size()); + for (int i = 0; i < expected.size(); i++) { + Assert.assertEquals("Struct values must match", expected.get(i, Object.class), actual.get(i, Object.class)); + } + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataWrite.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataWrite.java new file mode 100644 index 000000000000..5b158c518ae4 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataWrite.java @@ -0,0 +1,658 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.Locale; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.types.Types; +import org.apache.spark.SparkException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.apache.iceberg.TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +@RunWith(Parameterized.class) +public class TestSparkDataWrite { + private static final Configuration CONF = new Configuration(); + private final FileFormat format; + private static SparkSession spark = null; + private static final Schema SCHEMA = new Schema( + optional(1, "id", Types.IntegerType.get()), + optional(2, "data", Types.StringType.get()) + ); + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + @Parameterized.Parameters(name = "format = {0}") + public static Object[] parameters() { + return new Object[] { "parquet", "avro", "orc" }; + } + + @BeforeClass + public static void startSpark() { + TestSparkDataWrite.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @Parameterized.AfterParam + public static void clearSourceCache() { + ManualSource.clearTables(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestSparkDataWrite.spark; + TestSparkDataWrite.spark = null; + currentSpark.stop(); + } + + public TestSparkDataWrite(String format) { + this.format = FileFormat.valueOf(format.toUpperCase(Locale.ENGLISH)); + } + + @Test + public void testBasicWrite() throws IOException { + File parent = temp.newFolder(format.toString()); + File location = new File(parent, "test"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List expected = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + // TODO: incoming columns must be ordered according to the table's schema + df.select("id", "data").write() + .format("iceberg") + 
.option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + table.refresh(); + + Dataset result = spark.read() + .format("iceberg") + .load(location.toString()); + + List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + for (ManifestFile manifest : table.currentSnapshot().allManifests(table.io())) { + for (DataFile file : ManifestFiles.read(manifest, table.io())) { + // TODO: avro not support split + if (!format.equals(FileFormat.AVRO)) { + Assert.assertNotNull("Split offsets not present", file.splitOffsets()); + } + Assert.assertEquals("Should have reported record count as 1", 1, file.recordCount()); + // TODO: append more metric info + if (format.equals(FileFormat.PARQUET)) { + Assert.assertNotNull("Column sizes metric not present", file.columnSizes()); + Assert.assertNotNull("Counts metric not present", file.valueCounts()); + Assert.assertNotNull("Null value counts metric not present", file.nullValueCounts()); + Assert.assertNotNull("Lower bounds metric not present", file.lowerBounds()); + Assert.assertNotNull("Upper bounds metric not present", file.upperBounds()); + } + } + } + } + + @Test + public void testAppend() throws IOException { + File parent = temp.newFolder(format.toString()); + File location = new File(parent, "test"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List records = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + + List expected = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "a"), + new SimpleRecord(5, "b"), + new SimpleRecord(6, "c") + ); + + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + + df.select("id", "data").write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + df.withColumn("id", df.col("id").plus(3)).select("id", "data").write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + table.refresh(); + + Dataset result = spark.read() + .format("iceberg") + .load(location.toString()); + + List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + } + + @Test + public void testEmptyOverwrite() throws IOException { + File parent = temp.newFolder(format.toString()); + File location = new File(parent, "test"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("id").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List records = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + + List expected = records; + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + + df.select("id", "data").write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + 
.mode(SaveMode.Append) + .save(location.toString()); + + Dataset empty = spark.createDataFrame(ImmutableList.of(), SimpleRecord.class); + empty.select("id", "data").write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Overwrite) + .option("overwrite-mode", "dynamic") + .save(location.toString()); + + table.refresh(); + + Dataset result = spark.read() + .format("iceberg") + .load(location.toString()); + + List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + } + + @Test + public void testOverwrite() throws IOException { + File parent = temp.newFolder(format.toString()); + File location = new File(parent, "test"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("id").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List records = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + + List expected = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "a"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "b"), + new SimpleRecord(6, "c") + ); + + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + + df.select("id", "data").write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + // overwrite with 2*id to replace record 2, append 4 and 6 + df.withColumn("id", df.col("id").multiply(2)).select("id", "data").write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Overwrite) + .option("overwrite-mode", "dynamic") + .save(location.toString()); + + table.refresh(); + + Dataset result = spark.read() + .format("iceberg") + .load(location.toString()); + + List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + } + + @Test + public void testUnpartitionedOverwrite() throws IOException { + File parent = temp.newFolder(format.toString()); + File location = new File(parent, "test"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List expected = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + df.select("id", "data").write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + // overwrite with the same data; should not produce two copies + df.select("id", "data").write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Overwrite) + .save(location.toString()); + + table.refresh(); + + Dataset result = spark.read() + .format("iceberg") + .load(location.toString()); + + List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + 
Assert.assertEquals("Result rows should match", expected, actual); + } + + @Test + public void testUnpartitionedCreateWithTargetFileSizeViaTableProperties() throws IOException { + File parent = temp.newFolder(format.toString()); + File location = new File(parent, "test"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + table.updateProperties() + .set(TableProperties.WRITE_TARGET_FILE_SIZE_BYTES, "4") // ~4 bytes; low enough to trigger + .commit(); + + List expected = Lists.newArrayListWithCapacity(4000); + for (int i = 0; i < 4000; i++) { + expected.add(new SimpleRecord(i, "a")); + } + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + df.select("id", "data").write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + table.refresh(); + + Dataset result = spark.read() + .format("iceberg") + .load(location.toString()); + + List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + + List files = Lists.newArrayList(); + for (ManifestFile manifest : table.currentSnapshot().allManifests(table.io())) { + for (DataFile file : ManifestFiles.read(manifest, table.io())) { + files.add(file); + } + } + + Assert.assertEquals("Should have 4 DataFiles", 4, files.size()); + Assert.assertTrue("All DataFiles contain 1000 rows", files.stream().allMatch(d -> d.recordCount() == 1000)); + } + + @Test + public void testPartitionedCreateWithTargetFileSizeViaOption() throws IOException { + partitionedCreateWithTargetFileSizeViaOption(IcebergOptionsType.NONE); + } + + @Test + public void testPartitionedFanoutCreateWithTargetFileSizeViaOption() throws IOException { + partitionedCreateWithTargetFileSizeViaOption(IcebergOptionsType.TABLE); + } + + @Test + public void testPartitionedFanoutCreateWithTargetFileSizeViaOption2() throws IOException { + partitionedCreateWithTargetFileSizeViaOption(IcebergOptionsType.JOB); + } + + @Test + public void testWriteProjection() throws IOException { + Assume.assumeTrue( + "Not supported in Spark 3.0; analysis requires all columns are present", + spark.version().startsWith("2")); + + File parent = temp.newFolder(format.toString()); + File location = new File(parent, "test"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List expected = Lists.newArrayList( + new SimpleRecord(1, null), + new SimpleRecord(2, null), + new SimpleRecord(3, null) + ); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + df.select("id").write() // select only id column + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + table.refresh(); + + Dataset result = spark.read() + .format("iceberg") + .load(location.toString()); + + List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + } + + @Test + public void testWriteProjectionWithMiddle() throws IOException { + 
Assume.assumeTrue( + "Not supported in Spark 3.0; analysis requires all columns are present", + spark.version().startsWith("2")); + + File parent = temp.newFolder(format.toString()); + File location = new File(parent, "test"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Schema schema = new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get()) + ); + Table table = tables.create(schema, spec, location.toString()); + + List expected = Lists.newArrayList( + new ThreeColumnRecord(1, null, "hello"), + new ThreeColumnRecord(2, null, "world"), + new ThreeColumnRecord(3, null, null) + ); + + Dataset df = spark.createDataFrame(expected, ThreeColumnRecord.class); + + df.select("c1", "c3").write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + table.refresh(); + + Dataset result = spark.read() + .format("iceberg") + .load(location.toString()); + + List actual = result.orderBy("c1").as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + } + + @Test + public void testViewsReturnRecentResults() throws IOException { + File parent = temp.newFolder(format.toString()); + File location = new File(parent, "test"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + tables.create(SCHEMA, spec, location.toString()); + + List records = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + + df.select("id", "data").write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + Dataset query = spark.read() + .format("iceberg") + .load(location.toString()) + .where("id = 1"); + query.createOrReplaceTempView("tmp"); + + List actual1 = spark.table("tmp").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + List expected1 = Lists.newArrayList( + new SimpleRecord(1, "a") + ); + Assert.assertEquals("Number of rows should match", expected1.size(), actual1.size()); + Assert.assertEquals("Result rows should match", expected1, actual1); + + df.select("id", "data").write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + List actual2 = spark.table("tmp").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + List expected2 = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(1, "a") + ); + Assert.assertEquals("Number of rows should match", expected2.size(), actual2.size()); + Assert.assertEquals("Result rows should match", expected2, actual2); + } + + public void partitionedCreateWithTargetFileSizeViaOption(IcebergOptionsType option) + throws IOException { + File parent = temp.newFolder(format.toString()); + File location = new File(parent, "test"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List expected = Lists.newArrayListWithCapacity(8000); + for (int i = 0; i < 2000; 
i++) { + expected.add(new SimpleRecord(i, "a")); + expected.add(new SimpleRecord(i, "b")); + expected.add(new SimpleRecord(i, "c")); + expected.add(new SimpleRecord(i, "d")); + } + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + switch (option) { + case NONE: + df.select("id", "data").sort("data").write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES, 4) // ~4 bytes; low enough to trigger + .save(location.toString()); + break; + case TABLE: + table.updateProperties().set(SPARK_WRITE_PARTITIONED_FANOUT_ENABLED, "true").commit(); + df.select("id", "data").write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES, 4) // ~4 bytes; low enough to trigger + .save(location.toString()); + break; + case JOB: + df.select("id", "data").write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES, 4) // ~4 bytes; low enough to trigger + .option(SparkWriteOptions.FANOUT_ENABLED, true) + .save(location.toString()); + break; + default: + break; + } + + table.refresh(); + + Dataset result = spark.read() + .format("iceberg") + .load(location.toString()); + + List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + + List files = Lists.newArrayList(); + for (ManifestFile manifest : table.currentSnapshot().allManifests(table.io())) { + for (DataFile file : ManifestFiles.read(manifest, table.io())) { + files.add(file); + } + } + + Assert.assertEquals("Should have 8 DataFiles", 8, files.size()); + Assert.assertTrue("All DataFiles contain 1000 rows", files.stream().allMatch(d -> d.recordCount() == 1000)); + } + + @Test + public void testCommitUnknownException() throws IOException { + File parent = temp.newFolder(format.toString()); + File location = new File(parent, "commitunknown"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List records = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + + AppendFiles append = table.newFastAppend(); + AppendFiles spyAppend = spy(append); + doAnswer(invocation -> { + append.commit(); + throw new CommitStateUnknownException(new RuntimeException("Datacenter on Fire")); + }).when(spyAppend).commit(); + + Table spyTable = spy(table); + when(spyTable.newAppend()).thenReturn(spyAppend); + SparkTable sparkTable = new SparkTable(spyTable, false); + + String manualTableName = "unknown_exception"; + ManualSource.setTable(manualTableName, sparkTable); + + // Although an exception is thrown here, write and commit have succeeded + AssertHelpers.assertThrowsWithCause("Should throw a Commit State Unknown Exception", + SparkException.class, + "Writing job aborted", + CommitStateUnknownException.class, + "Datacenter on Fire", + () -> df.select("id", "data").sort("data").write() + .format("org.apache.iceberg.spark.source.ManualSource") + .option(ManualSource.TABLE_NAME, 
manualTableName) + .mode(SaveMode.Append) + .save(location.toString())); + + // Since write and commit succeeded, the rows should be readable + Dataset result = spark.read().format("iceberg").load(location.toString()); + List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", records.size(), actual.size()); + Assert.assertEquals("Result rows should match", records, actual); + } + + public enum IcebergOptionsType { + NONE, + TABLE, + JOB + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkFileWriterFactory.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkFileWriterFactory.java new file mode 100644 index 000000000000..702e8ab98990 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkFileWriterFactory.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.util.List; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileWriterFactory; +import org.apache.iceberg.io.TestFileWriterFactory; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.util.ArrayUtil; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +public class TestSparkFileWriterFactory extends TestFileWriterFactory { + + public TestSparkFileWriterFactory(FileFormat fileFormat, boolean partitioned) { + super(fileFormat, partitioned); + } + + @Override + protected FileWriterFactory newWriterFactory(Schema dataSchema, List equalityFieldIds, + Schema equalityDeleteRowSchema, + Schema positionDeleteRowSchema) { + return SparkFileWriterFactory.builderFor(table) + .dataSchema(table.schema()) + .dataFileFormat(format()) + .deleteFileFormat(format()) + .equalityFieldIds(ArrayUtil.toIntArray(equalityFieldIds)) + .equalityDeleteRowSchema(equalityDeleteRowSchema) + .positionDeleteRowSchema(positionDeleteRowSchema) + .build(); + } + + @Override + protected InternalRow toRow(Integer id, String data) { + InternalRow row = new GenericInternalRow(2); + row.update(0, id); + row.update(1, UTF8String.fromString(data)); + return row; + } + + @Override + protected StructLikeSet toSet(Iterable rows) { + StructLikeSet set = StructLikeSet.create(table.schema().asStruct()); + StructType sparkType = SparkSchemaUtil.convert(table.schema()); + for (InternalRow row : rows) { + InternalRowWrapper wrapper = new InternalRowWrapper(sparkType); + 
+      set.add(wrapper.wrap(row));
+    }
+    return set;
+  }
+}
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkFilesScan.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkFilesScan.java
new file mode 100644
index 000000000000..63195cfd3967
--- /dev/null
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkFilesScan.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.source;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import org.apache.iceberg.FileScanTask;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.io.CloseableIterable;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
+import org.apache.iceberg.spark.FileScanTaskSetManager;
+import org.apache.iceberg.spark.SparkCatalogTestBase;
+import org.apache.iceberg.spark.SparkReadOptions;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestSparkFilesScan extends SparkCatalogTestBase {
+
+  public TestSparkFilesScan(String catalogName, String implementation, Map<String, String> config) {
+    super(catalogName, implementation, config);
+  }
+
+  @After
+  public void removeTables() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+  }
+
+  @Test
+  public void testTaskSetLoading() throws NoSuchTableException, IOException {
+    sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName);
+
+    List<SimpleRecord> records = ImmutableList.of(
+        new SimpleRecord(1, "a"),
+        new SimpleRecord(2, "b")
+    );
+    Dataset<Row> df = spark.createDataFrame(records, SimpleRecord.class);
+    df.writeTo(tableName).append();
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    Assert.assertEquals("Should produce 1 snapshot", 1, Iterables.size(table.snapshots()));
+
+    try (CloseableIterable<FileScanTask> fileScanTasks = table.newScan().planFiles()) {
+      FileScanTaskSetManager taskSetManager = FileScanTaskSetManager.get();
+      String setID = UUID.randomUUID().toString();
+      taskSetManager.stageTasks(table, setID, ImmutableList.copyOf(fileScanTasks));
+
+      // load the staged file set
+      Dataset<Row> scanDF = spark.read().format("iceberg")
+          .option(SparkReadOptions.FILE_SCAN_TASK_SET_ID, setID)
+          .load(tableName);
+
+      // write the records back essentially duplicating data
+      scanDF.writeTo(tableName).append();
+    }
+
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(1, "a"), row(1, "a"), row(2, "b"), row(2, "b")),
+        sql("SELECT * FROM %s ORDER BY id", tableName));
+  }
+
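testTaskSetLoading above and testTaskSetPlanning below exercise the same staged-scan handshake: file scan tasks planned from a table are registered under an id with FileScanTaskSetManager, and a later read passes that id through SparkReadOptions.FILE_SCAN_TASK_SET_ID so the scan serves the staged tasks instead of planning a fresh scan. A minimal sketch of that flow, assuming the fixtures these tests already provide (spark, table, tableName) and using only API names that appear in this file; the variable names here are illustrative only:

    // plan once, stage the tasks under a random id, then read the staged set back as an ordinary DataFrame
    FileScanTaskSetManager taskSetManager = FileScanTaskSetManager.get();
    String setID = UUID.randomUUID().toString();
    try (CloseableIterable<FileScanTask> fileScanTasks = table.newScan().planFiles()) {
      taskSetManager.stageTasks(table, setID, ImmutableList.copyOf(fileScanTasks));
    }
    Dataset<Row> staged = spark.read().format("iceberg")
        .option(SparkReadOptions.FILE_SCAN_TASK_SET_ID, setID)
        .load(tableName);
    staged.writeTo(tableName).append(); // e.g. write the staged rows back, as testTaskSetLoading does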
+  @Test
+  public void testTaskSetPlanning() throws NoSuchTableException, IOException {
+    sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName);
+
+    List<SimpleRecord> records = ImmutableList.of(
+        new SimpleRecord(1, "a"),
+        new SimpleRecord(2, "b")
+    );
+    Dataset<Row> df = spark.createDataFrame(records, SimpleRecord.class);
+    df.coalesce(1).writeTo(tableName).append();
+    df.coalesce(1).writeTo(tableName).append();
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    Assert.assertEquals("Should produce 2 snapshots", 2, Iterables.size(table.snapshots()));
+
+    try (CloseableIterable<FileScanTask> fileScanTasks = table.newScan().planFiles()) {
+      FileScanTaskSetManager taskSetManager = FileScanTaskSetManager.get();
+      String setID = UUID.randomUUID().toString();
+      List<FileScanTask> tasks = ImmutableList.copyOf(fileScanTasks);
+      taskSetManager.stageTasks(table, setID, tasks);
+
+      // load the staged file set and make sure each file is in a separate split
+      Dataset<Row> scanDF = spark.read().format("iceberg")
+          .option(SparkReadOptions.FILE_SCAN_TASK_SET_ID, setID)
+          .option(SparkReadOptions.SPLIT_SIZE, tasks.get(0).file().fileSizeInBytes())
+          .load(tableName);
+      Assert.assertEquals("Num partitions should match", 2, scanDF.javaRDD().getNumPartitions());
+
+      // load the staged file set and make sure we combine both files into a single split
+      scanDF = spark.read().format("iceberg")
+          .option(SparkReadOptions.FILE_SCAN_TASK_SET_ID, setID)
+          .option(SparkReadOptions.SPLIT_SIZE, Long.MAX_VALUE)
+          .load(tableName);
+      Assert.assertEquals("Num partitions should match", 1, scanDF.javaRDD().getNumPartitions());
+    }
+  }
+}
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMergingMetrics.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMergingMetrics.java
new file mode 100644
index 000000000000..be74d1c5a33b
--- /dev/null
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMergingMetrics.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.TestMergingMetrics; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.spark.sql.catalyst.InternalRow; + +public class TestSparkMergingMetrics extends TestMergingMetrics { + + public TestSparkMergingMetrics(FileFormat fileFormat) { + super(fileFormat); + } + + @Override + protected FileAppender writeAndGetAppender(List records) throws IOException { + Table testTable = new BaseTable(null, "dummy") { + @Override + public Map properties() { + return Collections.emptyMap(); + } + @Override + public SortOrder sortOrder() { + return SortOrder.unsorted(); + } + @Override + public PartitionSpec spec() { + return PartitionSpec.unpartitioned(); + } + }; + + FileAppender appender = + SparkAppenderFactory.builderFor(testTable, SCHEMA, SparkSchemaUtil.convert(SCHEMA)).build() + .newAppender(org.apache.iceberg.Files.localOutput(temp.newFile()), fileFormat); + try (FileAppender fileAppender = appender) { + records.stream().map(r -> new StructInternalRow(SCHEMA.asStruct()).setStruct(r)).forEach(fileAppender::add); + } + return appender; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMetadataColumns.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMetadataColumns.java new file mode 100644 index 000000000000..9477bc985295 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMetadataColumns.java @@ -0,0 +1,288 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.PartitionSpecParser; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.UpdateProperties; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.FORMAT_VERSION; +import static org.apache.iceberg.TableProperties.ORC_VECTORIZATION_ENABLED; +import static org.apache.iceberg.TableProperties.PARQUET_BATCH_SIZE; +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; +import static org.apache.iceberg.TableProperties.PARQUET_VECTORIZATION_ENABLED; +import static org.apache.spark.sql.functions.lit; + +@RunWith(Parameterized.class) +public class TestSparkMetadataColumns extends SparkTestBase { + + private static final String TABLE_NAME = "test_table"; + private static final Schema SCHEMA = new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "category", Types.StringType.get()), + Types.NestedField.optional(3, "data", Types.StringType.get()) + ); + private static final PartitionSpec UNKNOWN_SPEC = PartitionSpecParser.fromJson(SCHEMA, + "{ \"spec-id\": 1, \"fields\": [ { \"name\": \"id_zero\", \"transform\": \"zero\", \"source-id\": 1 } ] }"); + + @Parameterized.Parameters(name = "fileFormat = {0}, vectorized = {1}, formatVersion = {2}") + public static Object[][] parameters() { + return new Object[][] { + { FileFormat.PARQUET, false, 1}, + { FileFormat.PARQUET, true, 1}, + { FileFormat.PARQUET, false, 2}, + { FileFormat.PARQUET, true, 2}, + { FileFormat.AVRO, false, 1}, + { FileFormat.AVRO, false, 2}, + { FileFormat.ORC, false, 1}, + { FileFormat.ORC, true, 1}, + { FileFormat.ORC, false, 2}, + { FileFormat.ORC, true, 2}, + }; + } + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private final FileFormat fileFormat; + private final boolean vectorized; + private final int formatVersion; + + private Table table = null; + + public TestSparkMetadataColumns(FileFormat fileFormat, boolean vectorized, int formatVersion) { + this.fileFormat = fileFormat; + this.vectorized = vectorized; + 
this.formatVersion = formatVersion; + } + + @BeforeClass + public static void setupSpark() { + ImmutableMap config = ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "cache-enabled", "true" + ); + spark.conf().set("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.source.TestSparkCatalog"); + config.forEach((key, value) -> spark.conf().set("spark.sql.catalog.spark_catalog." + key, value)); + } + + @Before + public void setupTable() throws IOException { + createAndInitTable(); + } + + @After + public void dropTable() { + TestTables.clearTables(); + } + + @Test + public void testSpecAndPartitionMetadataColumns() { + // TODO: support metadata structs in vectorized ORC reads + Assume.assumeFalse(fileFormat == FileFormat.ORC && vectorized); + + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", TABLE_NAME); + + table.refresh(); + table.updateSpec() + .addField("data") + .commit(); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", TABLE_NAME); + + table.refresh(); + table.updateSpec() + .addField(Expressions.bucket("category", 8)) + .commit(); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", TABLE_NAME); + + table.refresh(); + table.updateSpec() + .removeField("data") + .commit(); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", TABLE_NAME); + + table.refresh(); + table.updateSpec() + .renameField("category_bucket_8", "category_bucket_8_another_name") + .commit(); + + List expected = ImmutableList.of( + row(0, row(null, null)), + row(1, row("b1", null)), + row(2, row("b1", 2)), + row(3, row(null, 2)) + ); + assertEquals("Rows must match", expected, + sql("SELECT _spec_id, _partition FROM %s ORDER BY _spec_id", TABLE_NAME)); + } + + @Test + public void testPositionMetadataColumnWithMultipleRowGroups() throws NoSuchTableException { + Assume.assumeTrue(fileFormat == FileFormat.PARQUET); + + table.updateProperties() + .set(PARQUET_ROW_GROUP_SIZE_BYTES, "100") + .commit(); + + List ids = Lists.newArrayList(); + for (long id = 0L; id < 200L; id++) { + ids.add(id); + } + Dataset df = spark.createDataset(ids, Encoders.LONG()) + .withColumnRenamed("value", "id") + .withColumn("category", lit("hr")) + .withColumn("data", lit("ABCDEF")); + df.coalesce(1).writeTo(TABLE_NAME).append(); + + Assert.assertEquals(200, spark.table(TABLE_NAME).count()); + + List expectedRows = ids.stream() + .map(this::row) + .collect(Collectors.toList()); + assertEquals("Rows must match", + expectedRows, + sql("SELECT _pos FROM %s", TABLE_NAME)); + } + + @Test + public void testPositionMetadataColumnWithMultipleBatches() throws NoSuchTableException { + Assume.assumeTrue(fileFormat == FileFormat.PARQUET); + + table.updateProperties() + .set(PARQUET_BATCH_SIZE, "1000") + .commit(); + + List ids = Lists.newArrayList(); + for (long id = 0L; id < 7500L; id++) { + ids.add(id); + } + Dataset df = spark.createDataset(ids, Encoders.LONG()) + .withColumnRenamed("value", "id") + .withColumn("category", lit("hr")) + .withColumn("data", lit("ABCDEF")); + df.coalesce(1).writeTo(TABLE_NAME).append(); + + Assert.assertEquals(7500, spark.table(TABLE_NAME).count()); + + List expectedRows = ids.stream() + .map(this::row) + .collect(Collectors.toList()); + assertEquals("Rows must match", + expectedRows, + sql("SELECT _pos FROM %s", TABLE_NAME)); + } + + @Test + public void testPartitionMetadataColumnWithUnknownTransforms() { + // replace the table spec to include an unknown transform + TableOperations ops = ((HasTableOperations) table).operations(); + TableMetadata base = ops.current(); + ops.commit(base, 
base.updatePartitionSpec(UNKNOWN_SPEC)); + + AssertHelpers.assertThrows("Should fail to query the partition metadata column", + ValidationException.class, "Cannot build table partition type, unknown transforms", + () -> sql("SELECT _partition FROM %s", TABLE_NAME)); + } + + @Test + public void testConflictingColumns() { + table.updateSchema() + .addColumn(MetadataColumns.SPEC_ID.name(), Types.IntegerType.get()) + .addColumn(MetadataColumns.FILE_PATH.name(), Types.StringType.get()) + .commit(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1', -1, 'path/to/file')", TABLE_NAME); + + assertEquals("Rows must match", + ImmutableList.of(row(1L, "a1")), + sql("SELECT id, category FROM %s", TABLE_NAME)); + + AssertHelpers.assertThrows("Should fail to query conflicting columns", + ValidationException.class, "column names conflict", + () -> sql("SELECT * FROM %s", TABLE_NAME)); + + table.refresh(); + + table.updateSchema() + .renameColumn(MetadataColumns.SPEC_ID.name(), "_renamed" + MetadataColumns.SPEC_ID.name()) + .renameColumn(MetadataColumns.FILE_PATH.name(), "_renamed" + MetadataColumns.FILE_PATH.name()) + .commit(); + + assertEquals("Rows must match", + ImmutableList.of(row(0, null, -1)), + sql("SELECT _spec_id, _partition, _renamed_spec_id FROM %s", TABLE_NAME)); + } + + private void createAndInitTable() throws IOException { + this.table = TestTables.create(temp.newFolder(), TABLE_NAME, SCHEMA, PartitionSpec.unpartitioned()); + + UpdateProperties updateProperties = table.updateProperties(); + updateProperties.set(FORMAT_VERSION, String.valueOf(formatVersion)); + updateProperties.set(DEFAULT_FILE_FORMAT, fileFormat.name()); + + switch (fileFormat) { + case PARQUET: + updateProperties.set(PARQUET_VECTORIZATION_ENABLED, String.valueOf(vectorized)); + break; + case ORC: + updateProperties.set(ORC_VECTORIZATION_ENABLED, String.valueOf(vectorized)); + break; + default: + Preconditions.checkState(!vectorized, "File format %s does not support vectorized reads", fileFormat); + } + + updateProperties.commit(); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPartitioningWriters.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPartitioningWriters.java new file mode 100644 index 000000000000..4d07cfbe86ea --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPartitioningWriters.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.util.List; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileWriterFactory; +import org.apache.iceberg.io.TestPartitioningWriters; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.util.ArrayUtil; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +public class TestSparkPartitioningWriters extends TestPartitioningWriters { + + public TestSparkPartitioningWriters(FileFormat fileFormat) { + super(fileFormat); + } + + @Override + protected FileWriterFactory newWriterFactory(Schema dataSchema, List equalityFieldIds, + Schema equalityDeleteRowSchema, + Schema positionDeleteRowSchema) { + return SparkFileWriterFactory.builderFor(table) + .dataSchema(table.schema()) + .dataFileFormat(format()) + .deleteFileFormat(format()) + .equalityFieldIds(ArrayUtil.toIntArray(equalityFieldIds)) + .equalityDeleteRowSchema(equalityDeleteRowSchema) + .positionDeleteRowSchema(positionDeleteRowSchema) + .build(); + } + + @Override + protected InternalRow toRow(Integer id, String data) { + InternalRow row = new GenericInternalRow(2); + row.update(0, id); + row.update(1, UTF8String.fromString(data)); + return row; + } + + @Override + protected StructLikeSet toSet(Iterable rows) { + StructLikeSet set = StructLikeSet.create(table.schema().asStruct()); + StructType sparkType = SparkSchemaUtil.convert(table.schema()); + for (InternalRow row : rows) { + InternalRowWrapper wrapper = new InternalRowWrapper(sparkType); + set.add(wrapper.wrap(row)); + } + return set; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPositionDeltaWriters.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPositionDeltaWriters.java new file mode 100644 index 000000000000..480448e13a8f --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPositionDeltaWriters.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.util.List; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileWriterFactory; +import org.apache.iceberg.io.TestPositionDeltaWriters; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.util.ArrayUtil; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +public class TestSparkPositionDeltaWriters extends TestPositionDeltaWriters { + + public TestSparkPositionDeltaWriters(FileFormat fileFormat) { + super(fileFormat); + } + + @Override + protected FileWriterFactory newWriterFactory(Schema dataSchema, List equalityFieldIds, + Schema equalityDeleteRowSchema, + Schema positionDeleteRowSchema) { + return SparkFileWriterFactory.builderFor(table) + .dataSchema(table.schema()) + .dataFileFormat(format()) + .deleteFileFormat(format()) + .equalityFieldIds(ArrayUtil.toIntArray(equalityFieldIds)) + .equalityDeleteRowSchema(equalityDeleteRowSchema) + .positionDeleteRowSchema(positionDeleteRowSchema) + .build(); + } + + @Override + protected InternalRow toRow(Integer id, String data) { + InternalRow row = new GenericInternalRow(2); + row.update(0, id); + row.update(1, UTF8String.fromString(data)); + return row; + } + + @Override + protected StructLikeSet toSet(Iterable rows) { + StructLikeSet set = StructLikeSet.create(table.schema().asStruct()); + StructType sparkType = SparkSchemaUtil.convert(table.schema()); + for (InternalRow row : rows) { + InternalRowWrapper wrapper = new InternalRowWrapper(sparkType); + set.add(wrapper.wrap(row)); + } + return set; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java new file mode 100644 index 000000000000..f42b48d0e30d --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.UUID; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.data.GenericAppenderFactory; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkValueConverter; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.apache.iceberg.Files.localOutput; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +@RunWith(Parameterized.class) +public class TestSparkReadProjection extends TestReadProjection { + + private static SparkSession spark = null; + + @Parameterized.Parameters(name = "format = {0}, vectorized = {1}") + public static Object[][] parameters() { + return new Object[][] { + { "parquet", false }, + { "parquet", true }, + { "avro", false }, + { "orc", false }, + { "orc", true } + }; + } + + private final FileFormat format; + private final boolean vectorized; + + public TestSparkReadProjection(String format, boolean vectorized) { + super(format); + this.format = FileFormat.valueOf(format.toUpperCase(Locale.ROOT)); + this.vectorized = vectorized; + } + + @BeforeClass + public static void startSpark() { + TestSparkReadProjection.spark = SparkSession.builder().master("local[2]").getOrCreate(); + ImmutableMap config = ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "parquet-enabled", "true", + "cache-enabled", "false" + ); + spark.conf().set("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.source.TestSparkCatalog"); + config.forEach((key, value) -> spark.conf().set("spark.sql.catalog.spark_catalog." + key, value)); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestSparkReadProjection.spark; + TestSparkReadProjection.spark = null; + currentSpark.stop(); + } + + @Override + protected Record writeAndRead(String desc, Schema writeSchema, Schema readSchema, + Record record) throws IOException { + File parent = temp.newFolder(desc); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + Assert.assertTrue("mkdirs should succeed", dataFolder.mkdirs()); + + File testFile = new File(dataFolder, format.addExtension(UUID.randomUUID().toString())); + + Table table = TestTables.create(location, desc, writeSchema, PartitionSpec.unpartitioned()); + try { + // Important: use the table's schema for the rest of the test + // When tables are created, the column ids are reassigned. 
+ Schema tableSchema = table.schema(); + + try (FileAppender writer = new GenericAppenderFactory(tableSchema).newAppender( + localOutput(testFile), format)) { + writer.add(record); + } + + DataFile file = DataFiles.builder(PartitionSpec.unpartitioned()) + .withRecordCount(100) + .withFileSizeInBytes(testFile.length()) + .withPath(testFile.toString()) + .build(); + + table.newAppend().appendFile(file).commit(); + + // rewrite the read schema for the table's reassigned ids + Map idMapping = Maps.newHashMap(); + for (int id : allIds(writeSchema)) { + // translate each id to the original schema's column name, then to the new schema's id + String originalName = writeSchema.findColumnName(id); + idMapping.put(id, tableSchema.findField(originalName).fieldId()); + } + Schema expectedSchema = reassignIds(readSchema, idMapping); + + // Set the schema to the expected schema directly to simulate the table schema evolving + TestTables.replaceMetadata(desc, + TestTables.readMetadata(desc).updateSchema(expectedSchema, 100)); + + Dataset df = spark.read() + .format("org.apache.iceberg.spark.source.TestIcebergSource") + .option("iceberg.table.name", desc) + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(); + + return SparkValueConverter.convert(readSchema, df.collectAsList().get(0)); + + } finally { + TestTables.clearTables(); + } + } + + private List allIds(Schema schema) { + List ids = Lists.newArrayList(); + TypeUtil.visit(schema, new TypeUtil.SchemaVisitor() { + @Override + public Void field(Types.NestedField field, Void fieldResult) { + ids.add(field.fieldId()); + return null; + } + + @Override + public Void list(Types.ListType list, Void elementResult) { + ids.add(list.elementId()); + return null; + } + + @Override + public Void map(Types.MapType map, Void keyResult, Void valueResult) { + ids.add(map.keyId()); + ids.add(map.valueId()); + return null; + } + }); + return ids; + } + + private Schema reassignIds(Schema schema, Map idMapping) { + return new Schema(TypeUtil.visit(schema, new TypeUtil.SchemaVisitor() { + private int mapId(int id) { + if (idMapping.containsKey(id)) { + return idMapping.get(id); + } + return 1000 + id; // make sure the new IDs don't conflict with reassignment + } + + @Override + public Type schema(Schema schema, Type structResult) { + return structResult; + } + + @Override + public Type struct(Types.StructType struct, List fieldResults) { + List newFields = Lists.newArrayListWithExpectedSize(fieldResults.size()); + List fields = struct.fields(); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + if (field.isOptional()) { + newFields.add(optional(mapId(field.fieldId()), field.name(), fieldResults.get(i))); + } else { + newFields.add(required(mapId(field.fieldId()), field.name(), fieldResults.get(i))); + } + } + return Types.StructType.of(newFields); + } + + @Override + public Type field(Types.NestedField field, Type fieldResult) { + return fieldResult; + } + + @Override + public Type list(Types.ListType list, Type elementResult) { + if (list.isElementOptional()) { + return Types.ListType.ofOptional(mapId(list.elementId()), elementResult); + } else { + return Types.ListType.ofRequired(mapId(list.elementId()), elementResult); + } + } + + @Override + public Type map(Types.MapType map, Type keyResult, Type valueResult) { + if (map.isValueOptional()) { + return Types.MapType.ofOptional( + mapId(map.keyId()), mapId(map.valueId()), keyResult, valueResult); + } else { + return Types.MapType.ofRequired( + 
mapId(map.keyId()), mapId(map.valueId()), keyResult, valueResult); + } + } + + @Override + public Type primitive(Type.PrimitiveType primitive) { + return primitive; + } + }).asNestedType().asStructType().fields()); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReaderDeletes.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReaderDeletes.java new file mode 100644 index 000000000000..aa3e193ac3af --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReaderDeletes.java @@ -0,0 +1,481 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.List; +import java.util.Set; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.Files; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.data.DeleteReadTests; +import org.apache.iceberg.data.FileHelpers; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.InternalRecordWrapper; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.hive.HiveCatalog; +import org.apache.iceberg.hive.TestHiveMetastore; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkStructLike; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.ArrayUtil; +import org.apache.iceberg.util.CharSequenceSet; +import org.apache.iceberg.util.Pair; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.iceberg.util.TableScanUtil; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.SQLConf; +import org.jetbrains.annotations.NotNull; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.BeforeClass; 
+import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS; +import static org.apache.iceberg.types.Types.NestedField.required; + +@RunWith(Parameterized.class) +public class TestSparkReaderDeletes extends DeleteReadTests { + + private static TestHiveMetastore metastore = null; + protected static SparkSession spark = null; + protected static HiveCatalog catalog = null; + private final boolean vectorized; + + public TestSparkReaderDeletes(boolean vectorized) { + this.vectorized = vectorized; + } + + @Parameterized.Parameters(name = "vectorized = {0}") + public static Object[] parameters() { + return new Object[] {false, true}; + } + + @BeforeClass + public static void startMetastoreAndSpark() { + metastore = new TestHiveMetastore(); + metastore.start(); + HiveConf hiveConf = metastore.hiveConf(); + + spark = SparkSession.builder() + .master("local[2]") + .config(SQLConf.PARTITION_OVERWRITE_MODE().key(), "dynamic") + .config("spark.hadoop." + METASTOREURIS.varname, hiveConf.get(METASTOREURIS.varname)) + .enableHiveSupport() + .getOrCreate(); + + catalog = (HiveCatalog) + CatalogUtil.loadCatalog(HiveCatalog.class.getName(), "hive", ImmutableMap.of(), hiveConf); + + try { + catalog.createNamespace(Namespace.of("default")); + } catch (AlreadyExistsException ignored) { + // the default namespace already exists. ignore the create error + } + } + + @AfterClass + public static void stopMetastoreAndSpark() throws Exception { + catalog = null; + metastore.stop(); + metastore = null; + spark.stop(); + spark = null; + } + + @Override + protected Table createTable(String name, Schema schema, PartitionSpec spec) { + Table table = catalog.createTable(TableIdentifier.of("default", name), schema); + TableOperations ops = ((BaseTable) table).operations(); + TableMetadata meta = ops.current(); + ops.commit(meta, meta.upgradeToFormatVersion(2)); + if (vectorized) { + table.updateProperties() + .set(TableProperties.PARQUET_VECTORIZATION_ENABLED, "true") + .set(TableProperties.PARQUET_BATCH_SIZE, "4") // split 7 records to two batches to cover more code paths + .commit(); + } else { + table.updateProperties() + .set(TableProperties.PARQUET_VECTORIZATION_ENABLED, "false") + .commit(); + } + return table; + } + + @Override + protected void dropTable(String name) { + catalog.dropTable(TableIdentifier.of("default", name)); + } + + @Override + public StructLikeSet rowSet(String name, Table table, String... columns) { + return rowSet(name, table.schema().select(columns).asStruct(), columns); + } + + public StructLikeSet rowSet(String name, Types.StructType projection, String... 
columns) { + Dataset df = spark.read() + .format("iceberg") + .load(TableIdentifier.of("default", name).toString()) + .selectExpr(columns); + + StructLikeSet set = StructLikeSet.create(projection); + df.collectAsList().forEach(row -> { + SparkStructLike rowWrapper = new SparkStructLike(projection); + set.add(rowWrapper.wrap(row)); + }); + + return set; + } + + @Test + public void testEqualityDeleteWithFilter() throws IOException { + String tableName = table.name().substring(table.name().lastIndexOf(".") + 1); + Schema deleteRowSchema = table.schema().select("data"); + Record dataDelete = GenericRecord.create(deleteRowSchema); + List dataDeletes = Lists.newArrayList( + dataDelete.copy("data", "a"), // id = 29 + dataDelete.copy("data", "d"), // id = 89 + dataDelete.copy("data", "g") // id = 122 + ); + + DeleteFile eqDeletes = FileHelpers.writeDeleteFile( + table, Files.localOutput(temp.newFile()), TestHelpers.Row.of(0), dataDeletes, deleteRowSchema); + + table.newRowDelta() + .addDeletes(eqDeletes) + .commit(); + + Types.StructType projection = table.schema().select("*").asStruct(); + Dataset df = spark.read() + .format("iceberg") + .load(TableIdentifier.of("default", tableName).toString()) + .filter("data = 'a'") // select a deleted row + .selectExpr("*"); + + StructLikeSet actual = StructLikeSet.create(projection); + df.collectAsList().forEach(row -> { + SparkStructLike rowWrapper = new SparkStructLike(projection); + actual.add(rowWrapper.wrap(row)); + }); + + Assert.assertEquals("Table should contain no rows", 0, actual.size()); + } + + @Test + public void testReadEqualityDeleteRows() throws IOException { + Schema deleteSchema1 = table.schema().select("data"); + Record dataDelete = GenericRecord.create(deleteSchema1); + List dataDeletes = Lists.newArrayList( + dataDelete.copy("data", "a"), // id = 29 + dataDelete.copy("data", "d") // id = 89 + ); + + Schema deleteSchema2 = table.schema().select("id"); + Record idDelete = GenericRecord.create(deleteSchema2); + List idDeletes = Lists.newArrayList( + idDelete.copy("id", 121), // id = 121 + idDelete.copy("id", 122) // id = 122 + ); + + DeleteFile eqDelete1 = FileHelpers.writeDeleteFile( + table, Files.localOutput(temp.newFile()), TestHelpers.Row.of(0), dataDeletes, deleteSchema1); + + DeleteFile eqDelete2 = FileHelpers.writeDeleteFile( + table, Files.localOutput(temp.newFile()), TestHelpers.Row.of(0), idDeletes, deleteSchema2); + + table.newRowDelta() + .addDeletes(eqDelete1) + .addDeletes(eqDelete2) + .commit(); + + StructLikeSet expectedRowSet = rowSetWithIds(29, 89, 121, 122); + + Types.StructType type = table.schema().asStruct(); + StructLikeSet actualRowSet = StructLikeSet.create(type); + + CloseableIterable tasks = TableScanUtil.planTasks( + table.newScan().planFiles(), + TableProperties.METADATA_SPLIT_SIZE_DEFAULT, + TableProperties.SPLIT_LOOKBACK_DEFAULT, + TableProperties.SPLIT_OPEN_FILE_COST_DEFAULT); + + for (CombinedScanTask task : tasks) { + try (EqualityDeleteRowReader reader = new EqualityDeleteRowReader(task, table, table.schema(), false)) { + while (reader.next()) { + actualRowSet.add(new InternalRowWrapper(SparkSchemaUtil.convert(table.schema())).wrap(reader.get().copy())); + } + } + } + + Assert.assertEquals("should include 4 deleted row", 4, actualRowSet.size()); + Assert.assertEquals("deleted row should be matched", expectedRowSet, actualRowSet); + } + + @Test + public void testPosDeletesAllRowsInBatch() throws IOException { + // read.parquet.vectorization.batch-size is set to 4, so the 4 rows in the first batch are all 
deleted. + List> deletes = Lists.newArrayList( + Pair.of(dataFile.path(), 0L), // id = 29 + Pair.of(dataFile.path(), 1L), // id = 43 + Pair.of(dataFile.path(), 2L), // id = 61 + Pair.of(dataFile.path(), 3L) // id = 89 + ); + + Pair posDeletes = FileHelpers.writeDeleteFile( + table, Files.localOutput(temp.newFile()), TestHelpers.Row.of(0), deletes); + + table.newRowDelta() + .addDeletes(posDeletes.first()) + .validateDataFilesExist(posDeletes.second()) + .commit(); + + StructLikeSet expected = rowSetWithoutIds(table, records, 29, 43, 61, 89); + StructLikeSet actual = rowSet(tableName, table, "*"); + + Assert.assertEquals("Table should contain expected rows", expected, actual); + } + + @Test + public void testPosDeletesWithDeletedColumn() throws IOException { + Assume.assumeFalse(vectorized); + + // read.parquet.vectorization.batch-size is set to 4, so the 4 rows in the first batch are all deleted. + List> deletes = Lists.newArrayList( + Pair.of(dataFile.path(), 0L), // id = 29 + Pair.of(dataFile.path(), 1L), // id = 43 + Pair.of(dataFile.path(), 2L), // id = 61 + Pair.of(dataFile.path(), 3L) // id = 89 + ); + + Pair posDeletes = FileHelpers.writeDeleteFile( + table, Files.localOutput(temp.newFile()), TestHelpers.Row.of(0), deletes); + + table.newRowDelta() + .addDeletes(posDeletes.first()) + .validateDataFilesExist(posDeletes.second()) + .commit(); + + StructLikeSet expected = expectedRowSet(29, 43, 61, 89); + StructLikeSet actual = rowSet(tableName, PROJECTION_SCHEMA.asStruct(), "id", "data", "_deleted"); + + Assert.assertEquals("Table should contain expected row", expected, actual); + } + + @Test + public void testEqualityDeleteWithDeletedColumn() throws IOException { + Assume.assumeFalse(vectorized); + + String tableName = table.name().substring(table.name().lastIndexOf(".") + 1); + Schema deleteRowSchema = table.schema().select("data"); + Record dataDelete = GenericRecord.create(deleteRowSchema); + List dataDeletes = Lists.newArrayList( + dataDelete.copy("data", "a"), // id = 29 + dataDelete.copy("data", "d"), // id = 89 + dataDelete.copy("data", "g") // id = 122 + ); + + DeleteFile eqDeletes = FileHelpers.writeDeleteFile( + table, Files.localOutput(temp.newFile()), TestHelpers.Row.of(0), dataDeletes, deleteRowSchema); + + table.newRowDelta() + .addDeletes(eqDeletes) + .commit(); + + StructLikeSet expected = expectedRowSet(29, 89, 122); + StructLikeSet actual = rowSet(tableName, PROJECTION_SCHEMA.asStruct(), "id", "data", "_deleted"); + + Assert.assertEquals("Table should contain expected row", expected, actual); + } + + @Test + public void testMixedPosAndEqDeletesWithDeletedColumn() throws IOException { + Assume.assumeFalse(vectorized); + + Schema dataSchema = table.schema().select("data"); + Record dataDelete = GenericRecord.create(dataSchema); + List dataDeletes = Lists.newArrayList( + dataDelete.copy("data", "a"), // id = 29 + dataDelete.copy("data", "d"), // id = 89 + dataDelete.copy("data", "g") // id = 122 + ); + + DeleteFile eqDeletes = FileHelpers.writeDeleteFile( + table, Files.localOutput(temp.newFile()), TestHelpers.Row.of(0), dataDeletes, dataSchema); + + List> deletes = Lists.newArrayList( + Pair.of(dataFile.path(), 3L), // id = 89 + Pair.of(dataFile.path(), 5L) // id = 121 + ); + + Pair posDeletes = FileHelpers.writeDeleteFile( + table, Files.localOutput(temp.newFile()), TestHelpers.Row.of(0), deletes); + + table.newRowDelta() + .addDeletes(eqDeletes) + .addDeletes(posDeletes.first()) + .validateDataFilesExist(posDeletes.second()) + .commit(); + + StructLikeSet expected = 
expectedRowSet(29, 89, 121, 122); + StructLikeSet actual = rowSet(tableName, PROJECTION_SCHEMA.asStruct(), "id", "data", "_deleted"); + + Assert.assertEquals("Table should contain expected row", expected, actual); + } + + @Test + public void testFilterOnDeletedMetadataColumn() throws IOException { + Assume.assumeFalse(vectorized); + + List> deletes = Lists.newArrayList( + Pair.of(dataFile.path(), 0L), // id = 29 + Pair.of(dataFile.path(), 1L), // id = 43 + Pair.of(dataFile.path(), 2L), // id = 61 + Pair.of(dataFile.path(), 3L) // id = 89 + ); + + Pair posDeletes = FileHelpers.writeDeleteFile( + table, Files.localOutput(temp.newFile()), TestHelpers.Row.of(0), deletes); + + table.newRowDelta() + .addDeletes(posDeletes.first()) + .validateDataFilesExist(posDeletes.second()) + .commit(); + + StructLikeSet expected = expectedRowSetWithNonDeletesOnly(29, 43, 61, 89); + + // get non-deleted rows + Dataset df = spark.read() + .format("iceberg") + .load(TableIdentifier.of("default", tableName).toString()) + .select("id", "data", "_deleted") + .filter("_deleted = false"); + + Types.StructType projection = PROJECTION_SCHEMA.asStruct(); + StructLikeSet actual = StructLikeSet.create(projection); + df.collectAsList().forEach(row -> { + SparkStructLike rowWrapper = new SparkStructLike(projection); + actual.add(rowWrapper.wrap(row)); + }); + + Assert.assertEquals("Table should contain expected row", expected, actual); + + StructLikeSet expectedDeleted = expectedRowSetWithDeletesOnly(29, 43, 61, 89); + + // get deleted rows + df = spark.read() + .format("iceberg") + .load(TableIdentifier.of("default", tableName).toString()) + .select("id", "data", "_deleted") + .filter("_deleted = true"); + + StructLikeSet actualDeleted = StructLikeSet.create(projection); + df.collectAsList().forEach(row -> { + SparkStructLike rowWrapper = new SparkStructLike(projection); + actualDeleted.add(rowWrapper.wrap(row)); + }); + + Assert.assertEquals("Table should contain expected row", expectedDeleted, actualDeleted); + } + + @Test + public void testIsDeletedColumnWithoutDeleteFile() { + Assume.assumeFalse(vectorized); + + StructLikeSet expected = expectedRowSet(); + StructLikeSet actual = rowSet(tableName, PROJECTION_SCHEMA.asStruct(), "id", "data", "_deleted"); + Assert.assertEquals("Table should contain expected row", expected, actual); + } + + private static final Schema PROJECTION_SCHEMA = new Schema( + required(1, "id", Types.IntegerType.get()), + required(2, "data", Types.StringType.get()), + MetadataColumns.IS_DELETED + ); + + private static StructLikeSet expectedRowSet(int... idsToRemove) { + return expectedRowSet(false, false, idsToRemove); + } + + private static StructLikeSet expectedRowSetWithDeletesOnly(int... idsToRemove) { + return expectedRowSet(false, true, idsToRemove); + } + + private static StructLikeSet expectedRowSetWithNonDeletesOnly(int... idsToRemove) { + return expectedRowSet(true, false, idsToRemove); + } + + private static StructLikeSet expectedRowSet(boolean removeDeleted, boolean removeNonDeleted, int... 
idsToRemove) { + Set deletedIds = Sets.newHashSet(ArrayUtil.toIntList(idsToRemove)); + List records = recordsWithDeletedColumn(); + // mark rows deleted + records.forEach(record -> { + if (deletedIds.contains(record.getField("id"))) { + record.setField(MetadataColumns.IS_DELETED.name(), true); + } + }); + + records.removeIf(record -> deletedIds.contains(record.getField("id")) && removeDeleted); + records.removeIf(record -> !deletedIds.contains(record.getField("id")) && removeNonDeleted); + + StructLikeSet set = StructLikeSet.create(PROJECTION_SCHEMA.asStruct()); + records.forEach(record -> set.add(new InternalRecordWrapper(PROJECTION_SCHEMA.asStruct()).wrap(record))); + + return set; + } + + @NotNull + private static List recordsWithDeletedColumn() { + List records = Lists.newArrayList(); + + // records all use IDs that are in bucket id_bucket=0 + GenericRecord record = GenericRecord.create(PROJECTION_SCHEMA); + records.add(record.copy("id", 29, "data", "a", "_deleted", false)); + records.add(record.copy("id", 43, "data", "b", "_deleted", false)); + records.add(record.copy("id", 61, "data", "c", "_deleted", false)); + records.add(record.copy("id", 89, "data", "d", "_deleted", false)); + records.add(record.copy("id", 100, "data", "e", "_deleted", false)); + records.add(record.copy("id", 121, "data", "f", "_deleted", false)); + records.add(record.copy("id", 122, "data", "g", "_deleted", false)); + return records; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkRollingFileWriters.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkRollingFileWriters.java new file mode 100644 index 000000000000..9023195dcc6a --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkRollingFileWriters.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.util.List; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileWriterFactory; +import org.apache.iceberg.io.TestRollingFileWriters; +import org.apache.iceberg.util.ArrayUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.unsafe.types.UTF8String; + +public class TestSparkRollingFileWriters extends TestRollingFileWriters { + + public TestSparkRollingFileWriters(FileFormat fileFormat, boolean partitioned) { + super(fileFormat, partitioned); + } + + @Override + protected FileWriterFactory newWriterFactory(Schema dataSchema, List equalityFieldIds, + Schema equalityDeleteRowSchema, + Schema positionDeleteRowSchema) { + return SparkFileWriterFactory.builderFor(table) + .dataSchema(table.schema()) + .dataFileFormat(format()) + .deleteFileFormat(format()) + .equalityFieldIds(ArrayUtil.toIntArray(equalityFieldIds)) + .equalityDeleteRowSchema(equalityDeleteRowSchema) + .positionDeleteRowSchema(positionDeleteRowSchema) + .build(); + } + + @Override + protected InternalRow toRow(Integer id, String data) { + InternalRow row = new GenericInternalRow(2); + row.update(0, id); + row.update(1, UTF8String.fromString(data)); + return row; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkTable.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkTable.java new file mode 100644 index 000000000000..1b4fb5f8ce58 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkTable.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.connector.catalog.CatalogManager; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestSparkTable extends SparkCatalogTestBase { + + public TestSparkTable(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Before + public void createTable() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testTableEquality() throws NoSuchTableException { + CatalogManager catalogManager = spark.sessionState().catalogManager(); + TableCatalog catalog = (TableCatalog) catalogManager.catalog(catalogName); + Identifier identifier = Identifier.of(tableIdent.namespace().levels(), tableIdent.name()); + SparkTable table1 = (SparkTable) catalog.loadTable(identifier); + SparkTable table2 = (SparkTable) catalog.loadTable(identifier); + + // different instances pointing to the same table must be equivalent + Assert.assertNotSame("References must be different", table1, table2); + Assert.assertEquals("Tables must be equivalent", table1, table2); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkWriterMetrics.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkWriterMetrics.java new file mode 100644 index 000000000000..967f394faa74 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkWriterMetrics.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.FileWriterFactory; +import org.apache.iceberg.io.TestWriterMetrics; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.unsafe.types.UTF8String; + +public class TestSparkWriterMetrics extends TestWriterMetrics { + + public TestSparkWriterMetrics(FileFormat fileFormat) { + super(fileFormat); + } + + @Override + protected FileWriterFactory newWriterFactory(Table sourceTable) { + return SparkFileWriterFactory.builderFor(sourceTable) + .dataSchema(sourceTable.schema()) + .dataFileFormat(fileFormat) + .deleteFileFormat(fileFormat) + .positionDeleteRowSchema(sourceTable.schema()) + .build(); + } + + @Override + protected InternalRow toRow(Integer id, String data, boolean boolValue, Long longValue) { + InternalRow row = new GenericInternalRow(3); + row.update(0, id); + row.update(1, UTF8String.fromString(data)); + + InternalRow nested = new GenericInternalRow(2); + nested.update(0, boolValue); + nested.update(1, longValue); + + row.update(2, nested); + return row; + } + + @Override + protected InternalRow toGenericRow(int value, int repeated) { + InternalRow row = new GenericInternalRow(repeated); + for (int i = 0; i < repeated; i++) { + row.update(i, value); + } + return row; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestStreamingOffset.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestStreamingOffset.java new file mode 100644 index 000000000000..69302e9d24d7 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestStreamingOffset.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import com.fasterxml.jackson.databind.node.ObjectNode; +import java.util.Arrays; +import org.apache.iceberg.util.JsonUtil; +import org.junit.Assert; +import org.junit.Test; + +public class TestStreamingOffset { + + @Test + public void testJsonConversion() { + StreamingOffset[] expected = new StreamingOffset[]{ + new StreamingOffset(System.currentTimeMillis(), 1L, false), + new StreamingOffset(System.currentTimeMillis(), 2L, false), + new StreamingOffset(System.currentTimeMillis(), 3L, false), + new StreamingOffset(System.currentTimeMillis(), 4L, true) + }; + Assert.assertArrayEquals("StreamingOffsets should match", expected, + Arrays.stream(expected).map(elem -> StreamingOffset.fromJson(elem.json())).toArray()); + } + + @Test + public void testToJson() throws Exception { + StreamingOffset expected = new StreamingOffset(System.currentTimeMillis(), 1L, false); + ObjectNode actual = JsonUtil.mapper().createObjectNode(); + actual.put("version", 1); + actual.put("snapshot_id", expected.snapshotId()); + actual.put("position", 1L); + actual.put("scan_all_files", false); + String expectedJson = expected.json(); + String actualJson = JsonUtil.mapper().writeValueAsString(actual); + Assert.assertEquals("Json should match", expectedJson, actualJson); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreaming.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreaming.java new file mode 100644 index 000000000000..e931f554cf5a --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreaming.java @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.util.List; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.execution.streaming.MemoryStream; +import org.apache.spark.sql.streaming.DataStreamWriter; +import org.apache.spark.sql.streaming.StreamingQuery; +import org.apache.spark.sql.streaming.StreamingQueryException; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import scala.Option; +import scala.collection.JavaConverters; + +import static org.apache.iceberg.types.Types.NestedField.optional; + +public class TestStructuredStreaming { + + private static final Configuration CONF = new Configuration(); + private static final Schema SCHEMA = new Schema( + optional(1, "id", Types.IntegerType.get()), + optional(2, "data", Types.StringType.get()) + ); + private static SparkSession spark = null; + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @BeforeClass + public static void startSpark() { + TestStructuredStreaming.spark = SparkSession.builder() + .master("local[2]") + .config("spark.sql.shuffle.partitions", 4) + .getOrCreate(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestStructuredStreaming.spark; + TestStructuredStreaming.spark = null; + currentSpark.stop(); + } + + @Test + public void testStreamingWriteAppendMode() throws Exception { + File parent = temp.newFolder("parquet"); + File location = new File(parent, "test-table"); + File checkpoint = new File(parent, "checkpoint"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List expected = Lists.newArrayList( + new SimpleRecord(1, "1"), + new SimpleRecord(2, "2"), + new SimpleRecord(3, "3"), + new SimpleRecord(4, "4") + ); + + MemoryStream inputStream = newMemoryStream(1, spark.sqlContext(), Encoders.INT()); + DataStreamWriter streamWriter = inputStream.toDF() + .selectExpr("value AS id", "CAST (value AS STRING) AS data") + .writeStream() + .outputMode("append") + .format("iceberg") + .option("checkpointLocation", checkpoint.toString()) + .option("path", location.toString()); + + try { + // start the original query with checkpointing + StreamingQuery query = streamWriter.start(); + List batch1 = Lists.newArrayList(1, 2); + send(batch1, inputStream); + query.processAllAvailable(); + List batch2 = Lists.newArrayList(3, 4); + send(batch2, inputStream); + query.processAllAvailable(); + query.stop(); + + // remove the last commit to force Spark to reprocess batch #1 + File lastCommitFile = new File(checkpoint.toString() + "/commits/1"); + Assert.assertTrue("The 
commit file must be deleted", lastCommitFile.delete()); + + // restart the query from the checkpoint + StreamingQuery restartedQuery = streamWriter.start(); + restartedQuery.processAllAvailable(); + + // ensure the write was idempotent + Dataset result = spark.read() + .format("iceberg") + .load(location.toString()); + List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + Assert.assertEquals("Number of snapshots should match", 2, Iterables.size(table.snapshots())); + } finally { + for (StreamingQuery query : spark.streams().active()) { + query.stop(); + } + } + } + + @Test + public void testStreamingWriteCompleteMode() throws Exception { + File parent = temp.newFolder("parquet"); + File location = new File(parent, "test-table"); + File checkpoint = new File(parent, "checkpoint"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List expected = Lists.newArrayList( + new SimpleRecord(2, "1"), + new SimpleRecord(3, "2"), + new SimpleRecord(1, "3") + ); + + MemoryStream inputStream = newMemoryStream(1, spark.sqlContext(), Encoders.INT()); + DataStreamWriter streamWriter = inputStream.toDF() + .groupBy("value") + .count() + .selectExpr("CAST(count AS INT) AS id", "CAST (value AS STRING) AS data") + .writeStream() + .outputMode("complete") + .format("iceberg") + .option("checkpointLocation", checkpoint.toString()) + .option("path", location.toString()); + + try { + // start the original query with checkpointing + StreamingQuery query = streamWriter.start(); + List batch1 = Lists.newArrayList(1, 2); + send(batch1, inputStream); + query.processAllAvailable(); + List batch2 = Lists.newArrayList(1, 2, 2, 3); + send(batch2, inputStream); + query.processAllAvailable(); + query.stop(); + + // remove the last commit to force Spark to reprocess batch #1 + File lastCommitFile = new File(checkpoint.toString() + "/commits/1"); + Assert.assertTrue("The commit file must be deleted", lastCommitFile.delete()); + + // restart the query from the checkpoint + StreamingQuery restartedQuery = streamWriter.start(); + restartedQuery.processAllAvailable(); + + // ensure the write was idempotent + Dataset result = spark.read() + .format("iceberg") + .load(location.toString()); + List actual = result.orderBy("data").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + Assert.assertEquals("Number of snapshots should match", 2, Iterables.size(table.snapshots())); + } finally { + for (StreamingQuery query : spark.streams().active()) { + query.stop(); + } + } + } + + @Test + public void testStreamingWriteCompleteModeWithProjection() throws Exception { + File parent = temp.newFolder("parquet"); + File location = new File(parent, "test-table"); + File checkpoint = new File(parent, "checkpoint"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List expected = Lists.newArrayList( + new SimpleRecord(1, null), + new SimpleRecord(2, null), + new SimpleRecord(3, null) + ); + + MemoryStream inputStream = 
newMemoryStream(1, spark.sqlContext(), Encoders.INT()); + DataStreamWriter streamWriter = inputStream.toDF() + .groupBy("value") + .count() + .selectExpr("CAST(count AS INT) AS id") // select only id column + .writeStream() + .outputMode("complete") + .format("iceberg") + .option("checkpointLocation", checkpoint.toString()) + .option("path", location.toString()); + + try { + // start the original query with checkpointing + StreamingQuery query = streamWriter.start(); + List batch1 = Lists.newArrayList(1, 2); + send(batch1, inputStream); + query.processAllAvailable(); + List batch2 = Lists.newArrayList(1, 2, 2, 3); + send(batch2, inputStream); + query.processAllAvailable(); + query.stop(); + + // remove the last commit to force Spark to reprocess batch #1 + File lastCommitFile = new File(checkpoint.toString() + "/commits/1"); + Assert.assertTrue("The commit file must be deleted", lastCommitFile.delete()); + + // restart the query from the checkpoint + StreamingQuery restartedQuery = streamWriter.start(); + restartedQuery.processAllAvailable(); + + // ensure the write was idempotent + Dataset result = spark.read() + .format("iceberg") + .load(location.toString()); + List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + Assert.assertEquals("Number of snapshots should match", 2, Iterables.size(table.snapshots())); + } finally { + for (StreamingQuery query : spark.streams().active()) { + query.stop(); + } + } + } + + @Test + public void testStreamingWriteUpdateMode() throws Exception { + exceptionRule.expect(StreamingQueryException.class); + + File parent = temp.newFolder("parquet"); + File location = new File(parent, "test-table"); + File checkpoint = new File(parent, "checkpoint"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + tables.create(SCHEMA, spec, location.toString()); + + MemoryStream inputStream = newMemoryStream(1, spark.sqlContext(), Encoders.INT()); + DataStreamWriter streamWriter = inputStream.toDF() + .selectExpr("value AS id", "CAST (value AS STRING) AS data") + .writeStream() + .outputMode("update") + .format("iceberg") + .option("checkpointLocation", checkpoint.toString()) + .option("path", location.toString()); + + try { + StreamingQuery query = streamWriter.start(); + List batch1 = Lists.newArrayList(1, 2); + send(batch1, inputStream); + query.processAllAvailable(); + } finally { + for (StreamingQuery query : spark.streams().active()) { + query.stop(); + } + } + } + + private MemoryStream newMemoryStream(int id, SQLContext sqlContext, Encoder encoder) { + return new MemoryStream<>(id, sqlContext, Option.empty(), encoder); + } + + private void send(List records, MemoryStream stream) { + stream.addData(JavaConverters.asScalaBuffer(records)); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreamingRead3.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreamingRead3.java new file mode 100644 index 000000000000..e609412f8be0 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreamingRead3.java @@ -0,0 +1,547 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeoutException; +import java.util.stream.IntStream; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.DataOperations; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.data.FileHelpers; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.streaming.DataStreamWriter; +import org.apache.spark.sql.streaming.OutputMode; +import org.apache.spark.sql.streaming.StreamingQuery; +import org.assertj.core.api.Assertions; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.apache.iceberg.expressions.Expressions.ref; + +@RunWith(Parameterized.class) +public final class TestStructuredStreamingRead3 extends SparkCatalogTestBase { + public TestStructuredStreamingRead3( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + private Table table; + + /** + * test data to be used by multiple writes + * each write creates a snapshot and writes a list of records + */ + private static final List> TEST_DATA_MULTIPLE_SNAPSHOTS = Lists.newArrayList( + Lists.newArrayList( + new SimpleRecord(1, "one"), + new SimpleRecord(2, "two"), + new SimpleRecord(3, "three")), + Lists.newArrayList( + new SimpleRecord(4, "four"), + new SimpleRecord(5, "five")), + Lists.newArrayList( + new SimpleRecord(6, "six"), + new SimpleRecord(7, "seven"))); + + /** + * test data - to be used for multiple write batches + * each batch inturn will have multiple snapshots + */ + private static final List>> TEST_DATA_MULTIPLE_WRITES_MULTIPLE_SNAPSHOTS = Lists.newArrayList( + Lists.newArrayList( + Lists.newArrayList( + new SimpleRecord(1, "one"), + new SimpleRecord(2, "two"), + new SimpleRecord(3, "three")), + Lists.newArrayList( + new 
SimpleRecord(4, "four"), + new SimpleRecord(5, "five"))), + Lists.newArrayList( + Lists.newArrayList( + new SimpleRecord(6, "six"), + new SimpleRecord(7, "seven")), + Lists.newArrayList( + new SimpleRecord(8, "eight"), + new SimpleRecord(9, "nine"))), + Lists.newArrayList( + Lists.newArrayList( + new SimpleRecord(10, "ten"), + new SimpleRecord(11, "eleven"), + new SimpleRecord(12, "twelve")), + Lists.newArrayList( + new SimpleRecord(13, "thirteen"), + new SimpleRecord(14, "fourteen")), + Lists.newArrayList( + new SimpleRecord(15, "fifteen"), + new SimpleRecord(16, "sixteen")))); + + @Before + public void setupTable() { + sql("CREATE TABLE %s " + + "(id INT, data STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(3, id))", tableName); + this.table = validationCatalog.loadTable(tableIdent); + } + + @After + public void stopStreams() throws TimeoutException { + for (StreamingQuery query : spark.streams().active()) { + query.stop(); + } + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testReadStreamOnIcebergTableWithMultipleSnapshots() throws Exception { + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(expected); + + StreamingQuery query = startStream(); + + List actual = rowsAvailable(query); + Assertions.assertThat(actual).containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + + @Test + public void testReadStreamOnIcebergThenAddData() throws Exception { + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + + StreamingQuery query = startStream(); + + appendDataAsMultipleSnapshots(expected); + + List actual = rowsAvailable(query); + Assertions.assertThat(actual).containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + + @Test + public void testReadingStreamFromTimestamp() throws Exception { + List dataBeforeTimestamp = Lists.newArrayList( + new SimpleRecord(-2, "minustwo"), + new SimpleRecord(-1, "minusone"), + new SimpleRecord(0, "zero")); + + appendData(dataBeforeTimestamp); + + table.refresh(); + long streamStartTimestamp = table.currentSnapshot().timestampMillis() + 1; + + StreamingQuery query = startStream(SparkReadOptions.STREAM_FROM_TIMESTAMP, Long.toString(streamStartTimestamp)); + + List empty = rowsAvailable(query); + Assertions.assertThat(empty.isEmpty()).isTrue(); + + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(expected); + + List actual = rowsAvailable(query); + + Assertions.assertThat(actual).containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + + @Test + public void testReadingStreamFromFutureTimestamp() throws Exception { + long futureTimestamp = System.currentTimeMillis() + 10000; + + StreamingQuery query = startStream(SparkReadOptions.STREAM_FROM_TIMESTAMP, Long.toString(futureTimestamp)); + + List actual = rowsAvailable(query); + Assertions.assertThat(actual.isEmpty()).isTrue(); + + List data = Lists.newArrayList( + new SimpleRecord(-2, "minustwo"), + new SimpleRecord(-1, "minusone"), + new SimpleRecord(0, "zero")); + + // Perform several inserts that should not show up because the fromTimestamp has not elapsed + IntStream.range(0, 3).forEach(x -> { + appendData(data); + Assertions.assertThat(rowsAvailable(query).isEmpty()).isTrue(); + }); + + waitUntilAfter(futureTimestamp); + + // Data appended after the timestamp should appear + appendData(data); + actual = rowsAvailable(query); + Assertions.assertThat(actual).containsExactlyInAnyOrderElementsOf(data); + } + + @Test + public void 
testReadingStreamFromTimestampFutureWithExistingSnapshots() throws Exception { + List dataBeforeTimestamp = Lists.newArrayList( + new SimpleRecord(1, "one"), + new SimpleRecord(2, "two"), + new SimpleRecord(3, "three")); + appendData(dataBeforeTimestamp); + + long streamStartTimestamp = System.currentTimeMillis() + 2000; + + // Start the stream with a future timestamp after the current snapshot + StreamingQuery query = startStream(SparkReadOptions.STREAM_FROM_TIMESTAMP, Long.toString(streamStartTimestamp)); + List actual = rowsAvailable(query); + Assert.assertEquals(Collections.emptyList(), actual); + + // Stream should contain data added after the timestamp elapses + waitUntilAfter(streamStartTimestamp); + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(expected); + Assertions.assertThat(rowsAvailable(query)).containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + + @Test + public void testReadingStreamFromTimestampOfExistingSnapshot() throws Exception { + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + + // Create an existing snapshot with some data + appendData(expected.get(0)); + table.refresh(); + long firstSnapshotTime = table.currentSnapshot().timestampMillis(); + + // Start stream giving the first Snapshot's time as the start point + StreamingQuery stream = startStream(SparkReadOptions.STREAM_FROM_TIMESTAMP, Long.toString(firstSnapshotTime)); + + // Append rest of expected data + for (int i = 1; i < expected.size(); i++) { + appendData(expected.get(i)); + } + + List actual = rowsAvailable(stream); + Assertions.assertThat(actual).containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + + @Test + public void testReadingStreamWithExpiredSnapshotFromTimestamp() throws TimeoutException { + List firstSnapshotRecordList = Lists.newArrayList( + new SimpleRecord(1, "one")); + + List secondSnapshotRecordList = Lists.newArrayList( + new SimpleRecord(2, "two")); + + List thirdSnapshotRecordList = Lists.newArrayList( + new SimpleRecord(3, "three")); + + List expectedRecordList = Lists.newArrayList(); + expectedRecordList.addAll(secondSnapshotRecordList); + expectedRecordList.addAll(thirdSnapshotRecordList); + + appendData(firstSnapshotRecordList); + table.refresh(); + long firstSnapshotid = table.currentSnapshot().snapshotId(); + long firstSnapshotCommitTime = table.currentSnapshot().timestampMillis(); + + + appendData(secondSnapshotRecordList); + appendData(thirdSnapshotRecordList); + + table.expireSnapshots().expireSnapshotId(firstSnapshotid).commit(); + + StreamingQuery query = startStream(SparkReadOptions.STREAM_FROM_TIMESTAMP, String.valueOf(firstSnapshotCommitTime)); + List actual = rowsAvailable(query); + Assertions.assertThat(actual).containsExactlyInAnyOrderElementsOf(expectedRecordList); + } + + @Test + public void testResumingStreamReadFromCheckpoint() throws Exception { + File writerCheckpointFolder = temp.newFolder("writer-checkpoint-folder"); + File writerCheckpoint = new File(writerCheckpointFolder, "writer-checkpoint"); + File output = temp.newFolder(); + + DataStreamWriter querySource = spark.readStream() + .format("iceberg") + .load(tableName) + .writeStream() + .option("checkpointLocation", writerCheckpoint.toString()) + .format("parquet") + .queryName("checkpoint_test") + .option("path", output.getPath()); + + StreamingQuery startQuery = querySource.start(); + startQuery.processAllAvailable(); + startQuery.stop(); + + List expected = Lists.newArrayList(); + for (List> expectedCheckpoint : 
TEST_DATA_MULTIPLE_WRITES_MULTIPLE_SNAPSHOTS) { + // New data was added while the stream was down + appendDataAsMultipleSnapshots(expectedCheckpoint); + expected.addAll(Lists.newArrayList(Iterables.concat(Iterables.concat(expectedCheckpoint)))); + + // Stream starts up again from checkpoint read the newly added data and shut down + StreamingQuery restartedQuery = querySource.start(); + restartedQuery.processAllAvailable(); + restartedQuery.stop(); + + // Read data added by the stream + List actual = spark.read() + .load(output.getPath()) + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + Assertions.assertThat(actual).containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + } + + @Test + public void testParquetOrcAvroDataInOneTable() throws Exception { + List parquetFileRecords = Lists.newArrayList( + new SimpleRecord(1, "one"), + new SimpleRecord(2, "two"), + new SimpleRecord(3, "three")); + + List orcFileRecords = Lists.newArrayList( + new SimpleRecord(4, "four"), + new SimpleRecord(5, "five")); + + List avroFileRecords = Lists.newArrayList( + new SimpleRecord(6, "six"), + new SimpleRecord(7, "seven")); + + appendData(parquetFileRecords); + appendData(orcFileRecords, "orc"); + appendData(avroFileRecords, "avro"); + + StreamingQuery query = startStream(); + Assertions.assertThat(rowsAvailable(query)) + .containsExactlyInAnyOrderElementsOf(Iterables.concat(parquetFileRecords, orcFileRecords, avroFileRecords)); + } + + @Test + public void testReadStreamFromEmptyTable() throws Exception { + StreamingQuery stream = startStream(); + List actual = rowsAvailable(stream); + Assert.assertEquals(Collections.emptyList(), actual); + } + + @Test + public void testReadStreamWithSnapshotTypeOverwriteErrorsOut() throws Exception { + // upgrade table to version 2 - to facilitate creation of Snapshot of type OVERWRITE. + TableOperations ops = ((BaseTable) table).operations(); + TableMetadata meta = ops.current(); + ops.commit(meta, meta.upgradeToFormatVersion(2)); + + // fill table with some initial data + List> dataAcrossSnapshots = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(dataAcrossSnapshots); + + Schema deleteRowSchema = table.schema().select("data"); + Record dataDelete = GenericRecord.create(deleteRowSchema); + List dataDeletes = Lists.newArrayList( + dataDelete.copy("data", "one") // id = 1 + ); + + DeleteFile eqDeletes = FileHelpers.writeDeleteFile( + table, Files.localOutput(temp.newFile()), TestHelpers.Row.of(0), dataDeletes, deleteRowSchema); + + table.newRowDelta() + .addDeletes(eqDeletes) + .commit(); + + // check pre-condition - that the above Delete file write - actually resulted in snapshot of type OVERWRITE + Assert.assertEquals(DataOperations.OVERWRITE, table.currentSnapshot().operation()); + + StreamingQuery query = startStream(); + + AssertHelpers.assertThrowsCause( + "Streaming should fail with IllegalStateException, as the snapshot is not of type APPEND", + IllegalStateException.class, + "Cannot process overwrite snapshot", + () -> query.processAllAvailable() + ); + } + + @Test + public void testReadStreamWithSnapshotTypeReplaceIgnoresReplace() throws Exception { + // fill table with some data + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(expected); + + // this should create a snapshot with type Replace. 
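+    // rewriteManifests() only rewrites manifest metadata; no data files are added or removed,
+    // so the resulting commit uses operation REPLACE. The streaming read is expected to skip
+    // such snapshots, which is why the assertion below still returns every originally appended row.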
+ table.rewriteManifests() + .clusterBy(f -> 1) + .commit(); + + // check pre-condition + Assert.assertEquals(DataOperations.REPLACE, table.currentSnapshot().operation()); + + StreamingQuery query = startStream(); + List actual = rowsAvailable(query); + Assertions.assertThat(actual).containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + + @Test + public void testReadStreamWithSnapshotTypeDeleteErrorsOut() throws Exception { + table.updateSpec() + .removeField("id_bucket") + .addField(ref("id")) + .commit(); + + // fill table with some data + List> dataAcrossSnapshots = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(dataAcrossSnapshots); + + // this should create a snapshot with type delete. + table.newDelete() + .deleteFromRowFilter(Expressions.equal("id", 4)) + .commit(); + + // check pre-condition - that the above delete operation on table resulted in Snapshot of Type DELETE. + Assert.assertEquals(DataOperations.DELETE, table.currentSnapshot().operation()); + + StreamingQuery query = startStream(); + + AssertHelpers.assertThrowsCause( + "Streaming should fail with IllegalStateException, as the snapshot is not of type APPEND", + IllegalStateException.class, + "Cannot process delete snapshot", + () -> query.processAllAvailable() + ); + } + + @Test + public void testReadStreamWithSnapshotTypeDeleteAndSkipDeleteOption() throws Exception { + table.updateSpec() + .removeField("id_bucket") + .addField(ref("id")) + .commit(); + + // fill table with some data + List> dataAcrossSnapshots = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(dataAcrossSnapshots); + + // this should create a snapshot with type delete. + table.newDelete() + .deleteFromRowFilter(Expressions.equal("id", 4)) + .commit(); + + // check pre-condition - that the above delete operation on table resulted in Snapshot of Type DELETE. + Assert.assertEquals(DataOperations.DELETE, table.currentSnapshot().operation()); + + StreamingQuery query = startStream(SparkReadOptions.STREAMING_SKIP_DELETE_SNAPSHOTS, "true"); + Assertions.assertThat(rowsAvailable(query)) + .containsExactlyInAnyOrderElementsOf(Iterables.concat(dataAcrossSnapshots)); + } + + @Test + public void testReadStreamWithSnapshotTypeDeleteAndSkipOverwriteOption() throws Exception { + table.updateSpec() + .removeField("id_bucket") + .addField(ref("id")) + .commit(); + + // fill table with some data + List> dataAcrossSnapshots = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(dataAcrossSnapshots); + + // this should create a snapshot with type overwrite. + table.newOverwrite() + .overwriteByRowFilter(Expressions.greaterThan("id", 4)) + .commit(); + + // check pre-condition - that the above delete operation on table resulted in Snapshot of Type OVERWRITE. + Assert.assertEquals(DataOperations.OVERWRITE, table.currentSnapshot().operation()); + + StreamingQuery query = startStream(SparkReadOptions.STREAMING_SKIP_OVERWRITE_SNAPSHOTS, "true"); + Assertions.assertThat(rowsAvailable(query)) + .containsExactlyInAnyOrderElementsOf(Iterables.concat(dataAcrossSnapshots)); + } + + /** + * appends each list as a Snapshot on the iceberg table at the given location. + * accepts a list of lists - each list representing data per snapshot. 
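+   * Each inner list is written through the DataFrame writer (format "iceberg", mode "append"),
+   * so every inner list results in exactly one new append snapshot on the table.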
+ */ + private void appendDataAsMultipleSnapshots(List> data) { + for (List l : data) { + appendData(l); + } + } + + private void appendData(List data) { + appendData(data, "parquet"); + } + + private void appendData(List data, String format) { + Dataset df = spark.createDataFrame(data, SimpleRecord.class); + df.select("id", "data").write() + .format("iceberg") + .option("write-format", format) + .mode("append") + .save(tableName); + } + + private static final String MEMORY_TABLE = "_stream_view_mem"; + + private StreamingQuery startStream(Map options) throws TimeoutException { + return spark.readStream() + .options(options) + .format("iceberg") + .load(tableName) + .writeStream() + .options(options) + .format("memory") + .queryName(MEMORY_TABLE) + .outputMode(OutputMode.Append()) + .start(); + } + + private StreamingQuery startStream() throws TimeoutException { + return startStream(Collections.emptyMap()); + } + + private StreamingQuery startStream(String key, String value) throws TimeoutException { + return startStream(ImmutableMap.of(key, value)); + } + + private List rowsAvailable(StreamingQuery query) { + query.processAllAvailable(); + return spark.sql("select * from " + MEMORY_TABLE) + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + } + +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestTables.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestTables.java new file mode 100644 index 000000000000..ae3b7aa7a785 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestTables.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.util.Map; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.Files; +import org.apache.iceberg.LocationProviders; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.exceptions.CommitFailedException; +import org.apache.iceberg.exceptions.RuntimeIOException; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.io.LocationProvider; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; + +// TODO: Use the copy of this from core. 
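+// In-memory table harness: table metadata is kept in a static map guarded by synchronized blocks,
+// TestTableOperations can inject commit failures via failCommits(n), and LocalFileIO resolves paths
+// against the local filesystem. Illustrative usage from a test (names such as temp and SCHEMA come
+// from the calling test, not from this class):
+//
+//   TestTables.TestTable table = TestTables.create(temp.newFolder(), "t", SCHEMA, PartitionSpec.unpartitioned());
+//   table.operations().failCommits(1);   // the next commit throws CommitFailedException("Injected failure")
+//   TestTables.clearTables();            // reset the shared static state between tests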
+class TestTables { + private TestTables() { + } + + static TestTable create(File temp, String name, Schema schema, PartitionSpec spec) { + TestTableOperations ops = new TestTableOperations(name); + if (ops.current() != null) { + throw new AlreadyExistsException("Table %s already exists at location: %s", name, temp); + } + ops.commit(null, TableMetadata.newTableMetadata(schema, spec, temp.toString(), ImmutableMap.of())); + return new TestTable(ops, name); + } + + static TestTable load(String name) { + TestTableOperations ops = new TestTableOperations(name); + if (ops.current() == null) { + return null; + } + return new TestTable(ops, name); + } + + static boolean drop(String name) { + synchronized (METADATA) { + return METADATA.remove(name) != null; + } + } + + static class TestTable extends BaseTable { + private final TestTableOperations ops; + + private TestTable(TestTableOperations ops, String name) { + super(ops, name); + this.ops = ops; + } + + @Override + public TestTableOperations operations() { + return ops; + } + } + + private static final Map METADATA = Maps.newHashMap(); + + static void clearTables() { + synchronized (METADATA) { + METADATA.clear(); + } + } + + static TableMetadata readMetadata(String tableName) { + synchronized (METADATA) { + return METADATA.get(tableName); + } + } + + static void replaceMetadata(String tableName, TableMetadata metadata) { + synchronized (METADATA) { + METADATA.put(tableName, metadata); + } + } + + static class TestTableOperations implements TableOperations { + + private final String tableName; + private TableMetadata current = null; + private long lastSnapshotId = 0; + private int failCommits = 0; + + TestTableOperations(String tableName) { + this.tableName = tableName; + refresh(); + if (current != null) { + for (Snapshot snap : current.snapshots()) { + this.lastSnapshotId = Math.max(lastSnapshotId, snap.snapshotId()); + } + } else { + this.lastSnapshotId = 0; + } + } + + void failCommits(int numFailures) { + this.failCommits = numFailures; + } + + @Override + public TableMetadata current() { + return current; + } + + @Override + public TableMetadata refresh() { + synchronized (METADATA) { + this.current = METADATA.get(tableName); + } + return current; + } + + @Override + public void commit(TableMetadata base, TableMetadata metadata) { + if (base != current) { + throw new CommitFailedException("Cannot commit changes based on stale metadata"); + } + synchronized (METADATA) { + refresh(); + if (base == current) { + if (failCommits > 0) { + this.failCommits -= 1; + throw new CommitFailedException("Injected failure"); + } + METADATA.put(tableName, metadata); + this.current = metadata; + } else { + throw new CommitFailedException( + "Commit failed: table was updated at %d", base.lastUpdatedMillis()); + } + } + } + + @Override + public FileIO io() { + return new LocalFileIO(); + } + + @Override + public LocationProvider locationProvider() { + Preconditions.checkNotNull(current, + "Current metadata should not be null when locationProvider is called"); + return LocationProviders.locationsFor(current.location(), current.properties()); + } + + @Override + public String metadataFileLocation(String fileName) { + return new File(new File(current.location(), "metadata"), fileName).getAbsolutePath(); + } + + @Override + public long newSnapshotId() { + long nextSnapshotId = lastSnapshotId + 1; + this.lastSnapshotId = nextSnapshotId; + return nextSnapshotId; + } + } + + static class LocalFileIO implements FileIO { + + @Override + public InputFile 
newInputFile(String path) { + return Files.localInput(path); + } + + @Override + public OutputFile newOutputFile(String path) { + return Files.localOutput(new File(path)); + } + + @Override + public void deleteFile(String path) { + if (!new File(path).delete()) { + throw new RuntimeIOException("Failed to delete file: " + path); + } + } + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestTimestampWithoutZone.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestTimestampWithoutZone.java new file mode 100644 index 000000000000..20509eef7471 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestTimestampWithoutZone.java @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.io.IOException; +import java.time.LocalDateTime; +import java.util.List; +import java.util.Locale; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.data.GenericAppenderFactory; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.data.GenericsHelpers; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.apache.iceberg.Files.localOutput; + +@RunWith(Parameterized.class) +public class TestTimestampWithoutZone extends SparkTestBase { + private static final Configuration CONF = new Configuration(); + 
private static final HadoopTables TABLES = new HadoopTables(CONF); + + private static final Schema SCHEMA = new Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "ts", Types.TimestampType.withoutZone()), + Types.NestedField.optional(3, "data", Types.StringType.get()) + ); + + private static SparkSession spark = null; + + @BeforeClass + public static void startSpark() { + TestTimestampWithoutZone.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestTimestampWithoutZone.spark; + TestTimestampWithoutZone.spark = null; + currentSpark.stop(); + } + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private final String format; + private final boolean vectorized; + + @Parameterized.Parameters(name = "format = {0}, vectorized = {1}") + public static Object[][] parameters() { + return new Object[][] { + { "parquet", false }, + { "parquet", true }, + { "avro", false } + }; + } + + public TestTimestampWithoutZone(String format, boolean vectorized) { + this.format = format; + this.vectorized = vectorized; + } + + private File parent = null; + private File unpartitioned = null; + private List records = null; + + @Before + public void writeUnpartitionedTable() throws IOException { + this.parent = temp.newFolder("TestTimestampWithoutZone"); + this.unpartitioned = new File(parent, "unpartitioned"); + File dataFolder = new File(unpartitioned, "data"); + Assert.assertTrue("Mkdir should succeed", dataFolder.mkdirs()); + + Table table = TABLES.create(SCHEMA, PartitionSpec.unpartitioned(), unpartitioned.toString()); + Schema tableSchema = table.schema(); // use the table schema because ids are reassigned + + FileFormat fileFormat = FileFormat.valueOf(format.toUpperCase(Locale.ENGLISH)); + + File testFile = new File(dataFolder, fileFormat.addExtension(UUID.randomUUID().toString())); + + // create records using the table's schema + this.records = testRecords(tableSchema); + + try (FileAppender writer = new GenericAppenderFactory(tableSchema).newAppender( + localOutput(testFile), fileFormat)) { + writer.addAll(records); + } + + DataFile file = DataFiles.builder(PartitionSpec.unpartitioned()) + .withRecordCount(records.size()) + .withFileSizeInBytes(testFile.length()) + .withPath(testFile.toString()) + .build(); + + table.newAppend().appendFile(file).commit(); + } + + @Test + public void testUnpartitionedTimestampWithoutZone() { + assertEqualsSafe(SCHEMA.asStruct(), records, read(unpartitioned.toString(), vectorized)); + } + + @Test + public void testUnpartitionedTimestampWithoutZoneProjection() { + Schema projection = SCHEMA.select("id", "ts"); + assertEqualsSafe(projection.asStruct(), + records.stream().map(r -> projectFlat(projection, r)).collect(Collectors.toList()), + read(unpartitioned.toString(), vectorized, "id", "ts")); + } + + @Rule + public ExpectedException exception = ExpectedException.none(); + + @Test + public void testUnpartitionedTimestampWithoutZoneError() { + AssertHelpers.assertThrows(String.format("Read operation performed on a timestamp without timezone field while " + + "'%s' set to false should throw exception", + SparkReadOptions.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE), + IllegalArgumentException.class, + SparkUtil.TIMESTAMP_WITHOUT_TIMEZONE_ERROR, + () -> spark.read().format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .option(SparkReadOptions.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, 
"false") + .load(unpartitioned.toString()) + .collectAsList()); + } + + @Test + public void testUnpartitionedTimestampWithoutZoneAppend() { + spark.read().format("iceberg") + .option(SparkReadOptions.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "true") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(unpartitioned.toString()) + .write() + .format("iceberg") + .option(SparkWriteOptions.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "true") + .mode(SaveMode.Append) + .save(unpartitioned.toString()); + + assertEqualsSafe(SCHEMA.asStruct(), + Stream.concat(records.stream(), records.stream()).collect(Collectors.toList()), + read(unpartitioned.toString(), vectorized)); + } + + @Test + public void testUnpartitionedTimestampWithoutZoneWriteError() { + String errorMessage = String.format("Write operation performed on a timestamp without timezone field while " + + "'%s' set to false should throw exception", + SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE); + Runnable writeOperation = () -> spark.read().format("iceberg") + .option(SparkReadOptions.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "true") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(unpartitioned.toString()) + .write() + .format("iceberg") + .option(SparkWriteOptions.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "false") + .mode(SaveMode.Append) + .save(unpartitioned.toString()); + + AssertHelpers.assertThrows(errorMessage, IllegalArgumentException.class, + SparkUtil.TIMESTAMP_WITHOUT_TIMEZONE_ERROR, writeOperation); + + } + + @Test + public void testUnpartitionedTimestampWithoutZoneSessionProperties() { + withSQLConf(ImmutableMap.of(SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "true"), () -> { + spark.read().format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(unpartitioned.toString()) + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(unpartitioned.toString()); + + assertEqualsSafe(SCHEMA.asStruct(), + Stream.concat(records.stream(), records.stream()).collect(Collectors.toList()), + read(unpartitioned.toString(), vectorized)); + }); + } + + private static Record projectFlat(Schema projection, Record record) { + Record result = GenericRecord.create(projection); + List fields = projection.asStruct().fields(); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + result.set(i, record.getField(field.name())); + } + return result; + } + + public static void assertEqualsSafe(Types.StructType struct, + List expected, List actual) { + Assert.assertEquals("Number of results should match expected", expected.size(), actual.size()); + for (int i = 0; i < expected.size(); i += 1) { + GenericsHelpers.assertEqualsSafe(struct, expected.get(i), actual.get(i)); + } + } + + private List testRecords(Schema schema) { + return Lists.newArrayList( + record(schema, 0L, parseToLocal("2017-12-22T09:20:44.294658"), "junction"), + record(schema, 1L, parseToLocal("2017-12-22T07:15:34.582910"), "alligator"), + record(schema, 2L, parseToLocal("2017-12-22T06:02:09.243857"), "forrest"), + record(schema, 3L, parseToLocal("2017-12-22T03:10:11.134509"), "clapping"), + record(schema, 4L, parseToLocal("2017-12-22T00:34:00.184671"), "brush"), + record(schema, 5L, parseToLocal("2017-12-21T22:20:08.935889"), "trap"), + record(schema, 6L, parseToLocal("2017-12-21T21:55:30.589712"), "element"), + record(schema, 7L, parseToLocal("2017-12-21T17:31:14.532797"), "limited"), + record(schema, 8L, 
parseToLocal("2017-12-21T15:21:51.237521"), "global"), + record(schema, 9L, parseToLocal("2017-12-21T15:02:15.230570"), "goldfish") + ); + } + + private static List read(String table, boolean vectorized) { + return read(table, vectorized, "*"); + } + + private static List read(String table, boolean vectorized, String select0, String... selectN) { + Dataset dataset = spark.read().format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .option(SparkReadOptions.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "true") + .load(table) + .select(select0, selectN); + return dataset.collectAsList(); + } + + private static LocalDateTime parseToLocal(String timestamp) { + return LocalDateTime.parse(timestamp); + } + + private static Record record(Schema schema, Object... values) { + Record rec = GenericRecord.create(schema); + for (int i = 0; i < values.length; i += 1) { + rec.set(i, values[i]); + } + return rec; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestWriteMetricsConfig.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestWriteMetricsConfig.java new file mode 100644 index 000000000000..7ed71031f3f2 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestWriteMetricsConfig.java @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.ByteBuffers; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.spark.SparkSchemaUtil.convert; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +public class TestWriteMetricsConfig { + + private static final Configuration CONF = new Configuration(); + private static final Schema SIMPLE_SCHEMA = new Schema( + optional(1, "id", Types.IntegerType.get()), + optional(2, "data", Types.StringType.get()) + ); + private static final Schema COMPLEX_SCHEMA = new Schema( + required(1, "longCol", Types.IntegerType.get()), + optional(2, "strCol", Types.StringType.get()), + required(3, "record", Types.StructType.of( + required(4, "id", Types.IntegerType.get()), + required(5, "data", Types.StringType.get()) + )) + ); + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private static SparkSession spark = null; + private static JavaSparkContext sc = null; + + @BeforeClass + public static void startSpark() { + TestWriteMetricsConfig.spark = SparkSession.builder().master("local[2]").getOrCreate(); + TestWriteMetricsConfig.sc = JavaSparkContext.fromSparkContext(spark.sparkContext()); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestWriteMetricsConfig.spark; + TestWriteMetricsConfig.spark = null; + TestWriteMetricsConfig.sc = null; + currentSpark.stop(); + } + + @Test + public void testFullMetricsCollectionForParquet() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "full"); + Table table = tables.create(SIMPLE_SCHEMA, spec, properties, tableLocation); + + List expectedRecords = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + Dataset df = spark.createDataFrame(expectedRecords, SimpleRecord.class); + df.select("id", "data") + .coalesce(1) + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, "parquet") + 
.mode(SaveMode.Append) + .save(tableLocation); + + for (FileScanTask task : table.newScan().includeColumnStats().planFiles()) { + DataFile file = task.file(); + Assert.assertEquals(2, file.nullValueCounts().size()); + Assert.assertEquals(2, file.valueCounts().size()); + Assert.assertEquals(2, file.lowerBounds().size()); + Assert.assertEquals(2, file.upperBounds().size()); + } + } + + @Test + public void testCountMetricsCollectionForParquet() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "counts"); + Table table = tables.create(SIMPLE_SCHEMA, spec, properties, tableLocation); + + List expectedRecords = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + Dataset df = spark.createDataFrame(expectedRecords, SimpleRecord.class); + df.select("id", "data") + .coalesce(1) + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, "parquet") + .mode(SaveMode.Append) + .save(tableLocation); + + for (FileScanTask task : table.newScan().includeColumnStats().planFiles()) { + DataFile file = task.file(); + Assert.assertEquals(2, file.nullValueCounts().size()); + Assert.assertEquals(2, file.valueCounts().size()); + Assert.assertTrue(file.lowerBounds().isEmpty()); + Assert.assertTrue(file.upperBounds().isEmpty()); + } + } + + @Test + public void testNoMetricsCollectionForParquet() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "none"); + Table table = tables.create(SIMPLE_SCHEMA, spec, properties, tableLocation); + + List expectedRecords = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + Dataset df = spark.createDataFrame(expectedRecords, SimpleRecord.class); + df.select("id", "data") + .coalesce(1) + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, "parquet") + .mode(SaveMode.Append) + .save(tableLocation); + + for (FileScanTask task : table.newScan().includeColumnStats().planFiles()) { + DataFile file = task.file(); + Assert.assertTrue(file.nullValueCounts().isEmpty()); + Assert.assertTrue(file.valueCounts().isEmpty()); + Assert.assertTrue(file.lowerBounds().isEmpty()); + Assert.assertTrue(file.upperBounds().isEmpty()); + } + } + + @Test + public void testCustomMetricCollectionForParquet() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "counts"); + properties.put("write.metadata.metrics.column.id", "full"); + Table table = tables.create(SIMPLE_SCHEMA, spec, properties, tableLocation); + + List expectedRecords = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + Dataset df = spark.createDataFrame(expectedRecords, SimpleRecord.class); + df.select("id", "data") + .coalesce(1) + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, "parquet") + 
.mode(SaveMode.Append) + .save(tableLocation); + + Schema schema = table.schema(); + Types.NestedField id = schema.findField("id"); + for (FileScanTask task : table.newScan().includeColumnStats().planFiles()) { + DataFile file = task.file(); + Assert.assertEquals(2, file.nullValueCounts().size()); + Assert.assertEquals(2, file.valueCounts().size()); + Assert.assertEquals(1, file.lowerBounds().size()); + Assert.assertTrue(file.lowerBounds().containsKey(id.fieldId())); + Assert.assertEquals(1, file.upperBounds().size()); + Assert.assertTrue(file.upperBounds().containsKey(id.fieldId())); + } + } + + @Test + public void testBadCustomMetricCollectionForParquet() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "counts"); + properties.put("write.metadata.metrics.column.ids", "full"); + + AssertHelpers.assertThrows("Creating a table with invalid metrics should fail", + ValidationException.class, + null, + () -> tables.create(SIMPLE_SCHEMA, spec, properties, tableLocation)); + } + + @Test + public void testCustomMetricCollectionForNestedParquet() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(COMPLEX_SCHEMA) + .identity("strCol") + .build(); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "none"); + properties.put("write.metadata.metrics.column.longCol", "counts"); + properties.put("write.metadata.metrics.column.record.id", "full"); + properties.put("write.metadata.metrics.column.record.data", "truncate(2)"); + Table table = tables.create(COMPLEX_SCHEMA, spec, properties, tableLocation); + + Iterable rows = RandomData.generateSpark(COMPLEX_SCHEMA, 10, 0); + JavaRDD rdd = sc.parallelize(Lists.newArrayList(rows)); + Dataset df = spark.internalCreateDataFrame(JavaRDD.toRDD(rdd), convert(COMPLEX_SCHEMA), false); + + df.coalesce(1).write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, "parquet") + .mode(SaveMode.Append) + .save(tableLocation); + + Schema schema = table.schema(); + Types.NestedField longCol = schema.findField("longCol"); + Types.NestedField recordId = schema.findField("record.id"); + Types.NestedField recordData = schema.findField("record.data"); + for (FileScanTask task : table.newScan().includeColumnStats().planFiles()) { + DataFile file = task.file(); + + Map nullValueCounts = file.nullValueCounts(); + Assert.assertEquals(3, nullValueCounts.size()); + Assert.assertTrue(nullValueCounts.containsKey(longCol.fieldId())); + Assert.assertTrue(nullValueCounts.containsKey(recordId.fieldId())); + Assert.assertTrue(nullValueCounts.containsKey(recordData.fieldId())); + + Map valueCounts = file.valueCounts(); + Assert.assertEquals(3, valueCounts.size()); + Assert.assertTrue(valueCounts.containsKey(longCol.fieldId())); + Assert.assertTrue(valueCounts.containsKey(recordId.fieldId())); + Assert.assertTrue(valueCounts.containsKey(recordData.fieldId())); + + Map lowerBounds = file.lowerBounds(); + Assert.assertEquals(2, lowerBounds.size()); + Assert.assertTrue(lowerBounds.containsKey(recordId.fieldId())); + ByteBuffer recordDataLowerBound = lowerBounds.get(recordData.fieldId()); + Assert.assertEquals(2, 
ByteBuffers.toByteArray(recordDataLowerBound).length); + + Map upperBounds = file.upperBounds(); + Assert.assertEquals(2, upperBounds.size()); + Assert.assertTrue(upperBounds.containsKey(recordId.fieldId())); + ByteBuffer recordDataUpperBound = upperBounds.get(recordData.fieldId()); + Assert.assertEquals(2, ByteBuffers.toByteArray(recordDataUpperBound).length); + } + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/ThreeColumnRecord.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/ThreeColumnRecord.java new file mode 100644 index 000000000000..684dfbb255c7 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/ThreeColumnRecord.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import java.util.Objects; + +public class ThreeColumnRecord { + private Integer c1; + private String c2; + private String c3; + + public ThreeColumnRecord() { + } + + public ThreeColumnRecord(Integer c1, String c2, String c3) { + this.c1 = c1; + this.c2 = c2; + this.c3 = c3; + } + + public Integer getC1() { + return c1; + } + + public void setC1(Integer c1) { + this.c1 = c1; + } + + public String getC2() { + return c2; + } + + public void setC2(String c2) { + this.c2 = c2; + } + + public String getC3() { + return c3; + } + + public void setC3(String c3) { + this.c3 = c3; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ThreeColumnRecord that = (ThreeColumnRecord) o; + return Objects.equals(c1, that.c1) && + Objects.equals(c2, that.c2) && + Objects.equals(c3, that.c3); + } + + @Override + public int hashCode() { + return Objects.hash(c1, c2, c3); + } + + @Override + public String toString() { + return "ThreeColumnRecord{" + + "c1=" + c1 + + ", c2='" + c2 + '\'' + + ", c3='" + c3 + '\'' + + '}'; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestAlterTable.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestAlterTable.java new file mode 100644 index 000000000000..6172bd1fd0fe --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestAlterTable.java @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.sql;
+
+import java.util.Map;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.catalog.Namespace;
+import org.apache.iceberg.catalog.TableIdentifier;
+import org.apache.iceberg.hadoop.HadoopCatalog;
+import org.apache.iceberg.spark.SparkCatalogTestBase;
+import org.apache.iceberg.types.Types;
+import org.apache.iceberg.types.Types.NestedField;
+import org.apache.spark.SparkException;
+import org.apache.spark.sql.AnalysisException;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Assume;
+import org.junit.Before;
+import org.junit.Test;
+
+public class TestAlterTable extends SparkCatalogTestBase {
+  private final TableIdentifier renamedIdent = TableIdentifier.of(Namespace.of("default"), "table2");
+
+  public TestAlterTable(String catalogName, String implementation, Map<String, String> config) {
+    super(catalogName, implementation, config);
+  }
+
+  @Before
+  public void createTable() {
+    sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+  }
+
+  @After
+  public void removeTable() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+    sql("DROP TABLE IF EXISTS %s2", tableName);
+  }
+
+  @Test
+  public void testAddColumnNotNull() {
+    AssertHelpers.assertThrows("Should reject adding NOT NULL column",
+        SparkException.class, "Incompatible change: cannot add required column",
+        () -> sql("ALTER TABLE %s ADD COLUMN c3 INT NOT NULL", tableName));
+  }
+
+  @Test
+  public void testAddColumn() {
+    sql("ALTER TABLE %s ADD COLUMN point struct<x: double NOT NULL, y: double NOT NULL> AFTER id", tableName);
+
+    Types.StructType expectedSchema = Types.StructType.of(
+        NestedField.required(1, "id", Types.LongType.get()),
+        NestedField.optional(3, "point", Types.StructType.of(
+            NestedField.required(4, "x", Types.DoubleType.get()),
+            NestedField.required(5, "y", Types.DoubleType.get())
+        )),
+        NestedField.optional(2, "data", Types.StringType.get()));
+
+    Assert.assertEquals("Schema should match expected",
+        expectedSchema, validationCatalog.loadTable(tableIdent).schema().asStruct());
+
+    sql("ALTER TABLE %s ADD COLUMN point.z double COMMENT 'May be null' FIRST", tableName);
+
+    Types.StructType expectedSchema2 = Types.StructType.of(
+        NestedField.required(1, "id", Types.LongType.get()),
+        NestedField.optional(3, "point", Types.StructType.of(
+            NestedField.optional(6, "z", Types.DoubleType.get(), "May be null"),
+            NestedField.required(4, "x", Types.DoubleType.get()),
+            NestedField.required(5, "y", Types.DoubleType.get())
+        )),
+        NestedField.optional(2, "data", Types.StringType.get()));
+
+    Assert.assertEquals("Schema should match expected",
+        expectedSchema2, validationCatalog.loadTable(tableIdent).schema().asStruct());
+  }
+
+  @Test
+  public void testAddColumnWithArray() {
+    sql("ALTER TABLE %s ADD COLUMN data2 array<struct<a:INT,b:INT,c:INT>>", tableName);
+    // use the implicit column name 'element' to access member of array and add column d to struct.
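+    // The new nested field is appended with the next unused field ID (8 in the expected schema
+    // below). Iceberg never reuses field IDs, which is what keeps previously written data files
+    // readable after this kind of schema evolution.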
+ sql("ALTER TABLE %s ADD COLUMN data2.element.d int", tableName); + Types.StructType expectedSchema = Types.StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get()), + NestedField.optional(3, "data2", Types.ListType.ofOptional( + 4, + Types.StructType.of( + NestedField.optional(5, "a", Types.IntegerType.get()), + NestedField.optional(6, "b", Types.IntegerType.get()), + NestedField.optional(7, "c", Types.IntegerType.get()), + NestedField.optional(8, "d", Types.IntegerType.get())) + ))); + Assert.assertEquals("Schema should match expected", + expectedSchema, validationCatalog.loadTable(tableIdent).schema().asStruct()); + } + + @Test + public void testAddColumnWithMap() { + sql("ALTER TABLE %s ADD COLUMN data2 map, struct>", tableName); + // use the implicit column name 'key' and 'value' to access member of map. + // add column to value struct column + sql("ALTER TABLE %s ADD COLUMN data2.value.c int", tableName); + Types.StructType expectedSchema = Types.StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get()), + NestedField.optional(3, "data2", Types.MapType.ofOptional( + 4, + 5, + Types.StructType.of( + NestedField.optional(6, "x", Types.IntegerType.get())), + Types.StructType.of( + NestedField.optional(7, "a", Types.IntegerType.get()), + NestedField.optional(8, "b", Types.IntegerType.get()), + NestedField.optional(9, "c", Types.IntegerType.get())) + ))); + Assert.assertEquals("Schema should match expected", + expectedSchema, validationCatalog.loadTable(tableIdent).schema().asStruct()); + + // should not allow changing map key column + AssertHelpers.assertThrows("Should reject changing key of the map column", + SparkException.class, "Unsupported table change: Cannot add fields to map keys:", + () -> sql("ALTER TABLE %s ADD COLUMN data2.key.y int", tableName)); + } + + @Test + public void testDropColumn() { + sql("ALTER TABLE %s DROP COLUMN data", tableName); + + Types.StructType expectedSchema = Types.StructType.of( + NestedField.required(1, "id", Types.LongType.get())); + + Assert.assertEquals("Schema should match expected", + expectedSchema, validationCatalog.loadTable(tableIdent).schema().asStruct()); + } + + @Test + public void testRenameColumn() { + sql("ALTER TABLE %s RENAME COLUMN id TO row_id", tableName); + + Types.StructType expectedSchema = Types.StructType.of( + NestedField.required(1, "row_id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + + Assert.assertEquals("Schema should match expected", + expectedSchema, validationCatalog.loadTable(tableIdent).schema().asStruct()); + } + + @Test + public void testAlterColumnComment() { + sql("ALTER TABLE %s ALTER COLUMN id COMMENT 'Record id'", tableName); + + Types.StructType expectedSchema = Types.StructType.of( + NestedField.required(1, "id", Types.LongType.get(), "Record id"), + NestedField.optional(2, "data", Types.StringType.get())); + + Assert.assertEquals("Schema should match expected", + expectedSchema, validationCatalog.loadTable(tableIdent).schema().asStruct()); + } + + @Test + public void testAlterColumnType() { + sql("ALTER TABLE %s ADD COLUMN count int", tableName); + sql("ALTER TABLE %s ALTER COLUMN count TYPE bigint", tableName); + + Types.StructType expectedSchema = Types.StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get()), + NestedField.optional(3, "count", 
Types.LongType.get())); + + Assert.assertEquals("Schema should match expected", + expectedSchema, validationCatalog.loadTable(tableIdent).schema().asStruct()); + } + + @Test + public void testAlterColumnDropNotNull() { + sql("ALTER TABLE %s ALTER COLUMN id DROP NOT NULL", tableName); + + Types.StructType expectedSchema = Types.StructType.of( + NestedField.optional(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + + Assert.assertEquals("Schema should match expected", + expectedSchema, validationCatalog.loadTable(tableIdent).schema().asStruct()); + } + + @Test + public void testAlterColumnSetNotNull() { + // no-op changes are allowed + sql("ALTER TABLE %s ALTER COLUMN id SET NOT NULL", tableName); + + Types.StructType expectedSchema = Types.StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + + Assert.assertEquals("Schema should match expected", + expectedSchema, validationCatalog.loadTable(tableIdent).schema().asStruct()); + + AssertHelpers.assertThrows("Should reject adding NOT NULL constraint to an optional column", + AnalysisException.class, "Cannot change nullable column to non-nullable: data", + () -> sql("ALTER TABLE %s ALTER COLUMN data SET NOT NULL", tableName)); + } + + @Test + public void testAlterColumnPositionAfter() { + sql("ALTER TABLE %s ADD COLUMN count int", tableName); + sql("ALTER TABLE %s ALTER COLUMN count AFTER id", tableName); + + Types.StructType expectedSchema = Types.StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(3, "count", Types.IntegerType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + + Assert.assertEquals("Schema should match expected", + expectedSchema, validationCatalog.loadTable(tableIdent).schema().asStruct()); + } + + @Test + public void testAlterColumnPositionFirst() { + sql("ALTER TABLE %s ADD COLUMN count int", tableName); + sql("ALTER TABLE %s ALTER COLUMN count FIRST", tableName); + + Types.StructType expectedSchema = Types.StructType.of( + NestedField.optional(3, "count", Types.IntegerType.get()), + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + + Assert.assertEquals("Schema should match expected", + expectedSchema, validationCatalog.loadTable(tableIdent).schema().asStruct()); + } + + @Test + public void testTableRename() { + Assume.assumeFalse("Hadoop catalog does not support rename", validationCatalog instanceof HadoopCatalog); + + Assert.assertTrue("Initial name should exist", validationCatalog.tableExists(tableIdent)); + Assert.assertFalse("New name should not exist", validationCatalog.tableExists(renamedIdent)); + + sql("ALTER TABLE %s RENAME TO %s2", tableName, tableName); + + Assert.assertFalse("Initial name should not exist", validationCatalog.tableExists(tableIdent)); + Assert.assertTrue("New name should exist", validationCatalog.tableExists(renamedIdent)); + } + + @Test + public void testSetTableProperties() { + sql("ALTER TABLE %s SET TBLPROPERTIES ('prop'='value')", tableName); + + Assert.assertEquals("Should have the new table property", + "value", validationCatalog.loadTable(tableIdent).properties().get("prop")); + + sql("ALTER TABLE %s UNSET TBLPROPERTIES ('prop')", tableName); + + Assert.assertNull("Should not have the removed table property", + validationCatalog.loadTable(tableIdent).properties().get("prop")); + + AssertHelpers.assertThrows("Cannot specify the 'sort-order' 
because it's a reserved table property", + UnsupportedOperationException.class, + () -> sql("ALTER TABLE %s SET TBLPROPERTIES ('sort-order'='value')", tableName)); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestCreateTable.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestCreateTable.java new file mode 100644 index 000000000000..0a4c9368cb96 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestCreateTable.java @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.sql; + +import java.io.File; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopCatalog; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.types.Types.StructType; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Test; + +public class TestCreateTable extends SparkCatalogTestBase { + public TestCreateTable(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void dropTestTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testCreateTable() { + Assert.assertFalse("Table should not already exist", validationCatalog.tableExists(tableIdent)); + + sql("CREATE TABLE %s (id BIGINT NOT NULL, data STRING) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertNotNull("Should load the new table", table); + + StructType expectedSchema = StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + Assert.assertEquals("Should have the expected schema", expectedSchema, table.schema().asStruct()); + Assert.assertEquals("Should not be partitioned", 0, table.spec().fields().size()); + Assert.assertNull("Should not have the default format set", + table.properties().get(TableProperties.DEFAULT_FILE_FORMAT)); + } + + @Test + public void testCreateTableInRootNamespace() { + Assume.assumeTrue("Hadoop has no default namespace configured", "testhadoop".equals(catalogName)); + + try { + sql("CREATE TABLE %s.table (id bigint) USING iceberg", catalogName); + } finally { + sql("DROP TABLE IF EXISTS %s.table", catalogName); + } + } + + 
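+  // For context: the (catalogName, implementation, config) constructor parameters above are turned
+  // into Spark catalog settings by the test base class. Roughly (an illustrative sketch only; the
+  // exact values live in SparkCatalogTestBase):
+  //
+  //   spark.sql.catalog.testhadoop          = org.apache.iceberg.spark.SparkCatalog
+  //   spark.sql.catalog.testhadoop.type     = hadoop
+  //   spark.sql.catalog.spark_catalog       = org.apache.iceberg.spark.SparkSessionCatalog
+  //   spark.sql.catalog.spark_catalog.type  = hive
+  //
+  // which is why tests in this class can branch on catalogName ("testhadoop", "spark_catalog") to
+  // skip catalog-specific behavior.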
@Test + public void testCreateTableUsingParquet() { + Assume.assumeTrue( + "Not working with session catalog because Spark will not use v2 for a Parquet table", + !"spark_catalog".equals(catalogName)); + + Assert.assertFalse("Table should not already exist", validationCatalog.tableExists(tableIdent)); + + sql("CREATE TABLE %s (id BIGINT NOT NULL, data STRING) USING parquet", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertNotNull("Should load the new table", table); + + StructType expectedSchema = StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + Assert.assertEquals("Should have the expected schema", expectedSchema, table.schema().asStruct()); + Assert.assertEquals("Should not be partitioned", 0, table.spec().fields().size()); + Assert.assertEquals("Should not have default format parquet", + "parquet", + table.properties().get(TableProperties.DEFAULT_FILE_FORMAT)); + + AssertHelpers.assertThrows("Should reject unsupported format names", + IllegalArgumentException.class, "Unsupported format in USING: crocodile", + () -> sql("CREATE TABLE %s.default.fail (id BIGINT NOT NULL, data STRING) USING crocodile", catalogName)); + } + + @Test + public void testCreateTablePartitionedBy() { + Assert.assertFalse("Table should not already exist", validationCatalog.tableExists(tableIdent)); + + sql("CREATE TABLE %s " + + "(id BIGINT NOT NULL, created_at TIMESTAMP, category STRING, data STRING) " + + "USING iceberg " + + "PARTITIONED BY (category, bucket(8, id), days(created_at))", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertNotNull("Should load the new table", table); + + StructType expectedSchema = StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "created_at", Types.TimestampType.withZone()), + NestedField.optional(3, "category", Types.StringType.get()), + NestedField.optional(4, "data", Types.StringType.get())); + Assert.assertEquals("Should have the expected schema", expectedSchema, table.schema().asStruct()); + + PartitionSpec expectedSpec = PartitionSpec.builderFor(new Schema(expectedSchema.fields())) + .identity("category") + .bucket("id", 8) + .day("created_at") + .build(); + Assert.assertEquals("Should be partitioned correctly", expectedSpec, table.spec()); + + Assert.assertNull("Should not have the default format set", + table.properties().get(TableProperties.DEFAULT_FILE_FORMAT)); + } + + @Test + public void testCreateTableColumnComments() { + Assert.assertFalse("Table should not already exist", validationCatalog.tableExists(tableIdent)); + + sql("CREATE TABLE %s " + + "(id BIGINT NOT NULL COMMENT 'Unique identifier', data STRING COMMENT 'Data value') " + + "USING iceberg", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertNotNull("Should load the new table", table); + + StructType expectedSchema = StructType.of( + NestedField.required(1, "id", Types.LongType.get(), "Unique identifier"), + NestedField.optional(2, "data", Types.StringType.get(), "Data value")); + Assert.assertEquals("Should have the expected schema", expectedSchema, table.schema().asStruct()); + Assert.assertEquals("Should not be partitioned", 0, table.spec().fields().size()); + Assert.assertNull("Should not have the default format set", + table.properties().get(TableProperties.DEFAULT_FILE_FORMAT)); + } + + @Test + public void testCreateTableComment() { + Assert.assertFalse("Table 
should not already exist", validationCatalog.tableExists(tableIdent)); + + sql("CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "COMMENT 'Table doc'", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertNotNull("Should load the new table", table); + + StructType expectedSchema = StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + Assert.assertEquals("Should have the expected schema", expectedSchema, table.schema().asStruct()); + Assert.assertEquals("Should not be partitioned", 0, table.spec().fields().size()); + Assert.assertNull("Should not have the default format set", + table.properties().get(TableProperties.DEFAULT_FILE_FORMAT)); + Assert.assertEquals("Should have the table comment set in properties", + "Table doc", table.properties().get(TableCatalog.PROP_COMMENT)); + } + + @Test + public void testCreateTableLocation() throws Exception { + Assume.assumeTrue( + "Cannot set custom locations for Hadoop catalog tables", + !(validationCatalog instanceof HadoopCatalog)); + + Assert.assertFalse("Table should not already exist", validationCatalog.tableExists(tableIdent)); + + File tableLocation = temp.newFolder(); + Assert.assertTrue(tableLocation.delete()); + + String location = "file:" + tableLocation.toString(); + + sql("CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "LOCATION '%s'", + tableName, location); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertNotNull("Should load the new table", table); + + StructType expectedSchema = StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + Assert.assertEquals("Should have the expected schema", expectedSchema, table.schema().asStruct()); + Assert.assertEquals("Should not be partitioned", 0, table.spec().fields().size()); + Assert.assertNull("Should not have the default format set", + table.properties().get(TableProperties.DEFAULT_FILE_FORMAT)); + Assert.assertEquals("Should have a custom table location", + location, table.location()); + } + + @Test + public void testCreateTableProperties() { + Assert.assertFalse("Table should not already exist", validationCatalog.tableExists(tableIdent)); + + sql("CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "TBLPROPERTIES (p1=2, p2='x')", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertNotNull("Should load the new table", table); + + StructType expectedSchema = StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + Assert.assertEquals("Should have the expected schema", expectedSchema, table.schema().asStruct()); + Assert.assertEquals("Should not be partitioned", 0, table.spec().fields().size()); + Assert.assertEquals("Should have property p1", "2", table.properties().get("p1")); + Assert.assertEquals("Should have property p2", "x", table.properties().get("p2")); + } + + @Test + public void testCreateTableWithFormatV2ThroughTableProperty() { + Assert.assertFalse("Table should not already exist", validationCatalog.tableExists(tableIdent)); + + sql("CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "TBLPROPERTIES ('format-version'='2')", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("should 
create table using format v2", + 2, ((BaseTable) table).operations().current().formatVersion()); + } + + @Test + public void testUpgradeTableWithFormatV2ThroughTableProperty() { + Assert.assertFalse("Table should not already exist", validationCatalog.tableExists(tableIdent)); + + sql("CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "TBLPROPERTIES ('format-version'='1')", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + TableOperations ops = ((BaseTable) table).operations(); + Assert.assertEquals("should create table using format v1", + 1, ops.refresh().formatVersion()); + + sql("ALTER TABLE %s SET TBLPROPERTIES ('format-version'='2')", tableName); + Assert.assertEquals("should update table to use format v2", + 2, ops.refresh().formatVersion()); + } + + @Test + public void testDowngradeTableToFormatV1ThroughTablePropertyFails() { + Assert.assertFalse("Table should not already exist", validationCatalog.tableExists(tableIdent)); + + sql("CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "TBLPROPERTIES ('format-version'='2')", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + TableOperations ops = ((BaseTable) table).operations(); + Assert.assertEquals("should create table using format v2", + 2, ops.refresh().formatVersion()); + + AssertHelpers.assertThrowsCause("should fail to downgrade to v1", + IllegalArgumentException.class, + "Cannot downgrade v2 table to v1", + () -> sql("ALTER TABLE %s SET TBLPROPERTIES ('format-version'='1')", tableName)); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestCreateTableAsSelect.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestCreateTableAsSelect.java new file mode 100644 index 000000000000..4a70327e21a1 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestCreateTableAsSelect.java @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.sql; + +import java.util.Map; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.types.Types; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.when; + +public class TestCreateTableAsSelect extends SparkCatalogTestBase { + + private final String sourceName; + + public TestCreateTableAsSelect(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + this.sourceName = tableName("source"); + + sql("CREATE TABLE IF NOT EXISTS %s (id bigint NOT NULL, data string) " + + "USING iceberg PARTITIONED BY (truncate(id, 3))", sourceName); + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", sourceName); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testUnpartitionedCTAS() { + sql("CREATE TABLE %s USING iceberg AS SELECT * FROM %s", tableName, sourceName); + + Schema expectedSchema = new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get()) + ); + + Table ctasTable = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals("Should have expected nullable schema", + expectedSchema.asStruct(), ctasTable.schema().asStruct()); + Assert.assertEquals("Should be an unpartitioned table", + 0, ctasTable.spec().fields().size()); + assertEquals("Should have rows matching the source table", + sql("SELECT * FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testPartitionedCTAS() { + sql("CREATE TABLE %s USING iceberg PARTITIONED BY (id) AS SELECT * FROM %s ORDER BY id", tableName, sourceName); + + Schema expectedSchema = new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get()) + ); + + PartitionSpec expectedSpec = PartitionSpec.builderFor(expectedSchema) + .identity("id") + .build(); + + Table ctasTable = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals("Should have expected nullable schema", + expectedSchema.asStruct(), ctasTable.schema().asStruct()); + Assert.assertEquals("Should be partitioned by id", + expectedSpec, ctasTable.spec()); + assertEquals("Should have rows matching the source table", + sql("SELECT * FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testRTAS() { + sql("CREATE TABLE %s USING iceberg TBLPROPERTIES ('prop1'='val1', 'prop2'='val2')" + + "AS SELECT * FROM %s", tableName, sourceName); + + assertEquals("Should have rows matching the source table", + sql("SELECT * FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("REPLACE TABLE %s USING iceberg PARTITIONED BY (part) TBLPROPERTIES ('prop1'='newval1', 'prop3'='val3') AS " + + "SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY 3, 1", tableName, sourceName); + + Schema expectedSchema = new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", 
Types.StringType.get()), + Types.NestedField.optional(3, "part", Types.StringType.get()) + ); + + PartitionSpec expectedSpec = PartitionSpec.builderFor(expectedSchema) + .identity("part") + .withSpecId(1) + .build(); + + Table rtasTable = validationCatalog.loadTable(tableIdent); + + // the replacement table has a different schema and partition spec than the original + Assert.assertEquals("Should have expected nullable schema", + expectedSchema.asStruct(), rtasTable.schema().asStruct()); + Assert.assertEquals("Should be partitioned by part", + expectedSpec, rtasTable.spec()); + + assertEquals("Should have rows matching the source table", + sql("SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + Assert.assertEquals("Table should have expected snapshots", + 2, Iterables.size(rtasTable.snapshots())); + + Assert.assertEquals("Should have updated table property", + "newval1", rtasTable.properties().get("prop1")); + Assert.assertEquals("Should have preserved table property", + "val2", rtasTable.properties().get("prop2")); + Assert.assertEquals("Should have new table property", + "val3", rtasTable.properties().get("prop3")); + } + + @Test + public void testCreateRTAS() { + sql("CREATE OR REPLACE TABLE %s USING iceberg PARTITIONED BY (part) AS " + + "SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY 3, 1", tableName, sourceName); + + assertEquals("Should have rows matching the source table", + sql("SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("CREATE OR REPLACE TABLE %s USING iceberg PARTITIONED BY (part) AS " + + "SELECT 2 * id as id, data, CASE WHEN ((2 * id) %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY 3, 1", tableName, sourceName); + + Schema expectedSchema = new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get()), + Types.NestedField.optional(3, "part", Types.StringType.get()) + ); + + PartitionSpec expectedSpec = PartitionSpec.builderFor(expectedSchema) + .identity("part") + .withSpecId(0) // the spec is identical and should be reused + .build(); + + Table rtasTable = validationCatalog.loadTable(tableIdent); + + // the replacement table has a different schema and partition spec than the original + Assert.assertEquals("Should have expected nullable schema", + expectedSchema.asStruct(), rtasTable.schema().asStruct()); + Assert.assertEquals("Should be partitioned by part", + expectedSpec, rtasTable.spec()); + + assertEquals("Should have rows matching the source table", + sql("SELECT 2 * id, data, CASE WHEN ((2 * id) %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + Assert.assertEquals("Table should have expected snapshots", + 2, Iterables.size(rtasTable.snapshots())); + } + + @Test + public void testDataFrameV2Create() throws Exception { + spark.table(sourceName).writeTo(tableName).using("iceberg").create(); + + Schema expectedSchema = new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get()) + ); + + Table ctasTable = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals("Should have expected nullable schema", + 
expectedSchema.asStruct(), ctasTable.schema().asStruct()); + Assert.assertEquals("Should be an unpartitioned table", + 0, ctasTable.spec().fields().size()); + assertEquals("Should have rows matching the source table", + sql("SELECT * FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testDataFrameV2Replace() throws Exception { + spark.table(sourceName).writeTo(tableName).using("iceberg").create(); + + assertEquals("Should have rows matching the source table", + sql("SELECT * FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + spark.table(sourceName) + .select( + col("id"), + col("data"), + when(col("id").mod(lit(2)).equalTo(lit(0)), lit("even")).otherwise("odd").as("part")) + .orderBy("part", "id") + .writeTo(tableName) + .partitionedBy(col("part")) + .using("iceberg") + .replace(); + + Schema expectedSchema = new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get()), + Types.NestedField.optional(3, "part", Types.StringType.get()) + ); + + PartitionSpec expectedSpec = PartitionSpec.builderFor(expectedSchema) + .identity("part") + .withSpecId(1) + .build(); + + Table rtasTable = validationCatalog.loadTable(tableIdent); + + // the replacement table has a different schema and partition spec than the original + Assert.assertEquals("Should have expected nullable schema", + expectedSchema.asStruct(), rtasTable.schema().asStruct()); + Assert.assertEquals("Should be partitioned by part", + expectedSpec, rtasTable.spec()); + + assertEquals("Should have rows matching the source table", + sql("SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + Assert.assertEquals("Table should have expected snapshots", + 2, Iterables.size(rtasTable.snapshots())); + } + + @Test + public void testDataFrameV2CreateOrReplace() { + spark.table(sourceName) + .select( + col("id"), + col("data"), + when(col("id").mod(lit(2)).equalTo(lit(0)), lit("even")).otherwise("odd").as("part")) + .orderBy("part", "id") + .writeTo(tableName) + .partitionedBy(col("part")) + .using("iceberg") + .createOrReplace(); + + assertEquals("Should have rows matching the source table", + sql("SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + spark.table(sourceName) + .select(col("id").multiply(lit(2)).as("id"), col("data")) + .select( + col("id"), + col("data"), + when(col("id").mod(lit(2)).equalTo(lit(0)), lit("even")).otherwise("odd").as("part")) + .orderBy("part", "id") + .writeTo(tableName) + .partitionedBy(col("part")) + .using("iceberg") + .createOrReplace(); + + Schema expectedSchema = new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get()), + Types.NestedField.optional(3, "part", Types.StringType.get()) + ); + + PartitionSpec expectedSpec = PartitionSpec.builderFor(expectedSchema) + .identity("part") + .withSpecId(0) // the spec is identical and should be reused + .build(); + + Table rtasTable = validationCatalog.loadTable(tableIdent); + + // the replacement table has a different schema and partition spec than the original + Assert.assertEquals("Should have expected nullable schema", + expectedSchema.asStruct(), rtasTable.schema().asStruct()); + 
Assert.assertEquals("Should be partitioned by part", + expectedSpec, rtasTable.spec()); + + assertEquals("Should have rows matching the source table", + sql("SELECT 2 * id, data, CASE WHEN ((2 * id) %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + Assert.assertEquals("Table should have expected snapshots", + 2, Iterables.size(rtasTable.snapshots())); + } + + @Test + public void testCreateRTASWithPartitionSpecChanging() { + sql("CREATE OR REPLACE TABLE %s USING iceberg PARTITIONED BY (part) AS " + + "SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY 3, 1", tableName, sourceName); + + Table rtasTable = validationCatalog.loadTable(tableIdent); + + assertEquals("Should have rows matching the source table", + sql("SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + // Change the partitioning of the table + rtasTable.updateSpec().removeField("part").commit(); // Spec 1 + + sql("CREATE OR REPLACE TABLE %s USING iceberg PARTITIONED BY (part, id) AS " + + "SELECT 2 * id as id, data, CASE WHEN ((2 * id) %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY 3, 1", tableName, sourceName); + + Schema expectedSchema = new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get()), + Types.NestedField.optional(3, "part", Types.StringType.get()) + ); + + PartitionSpec expectedSpec = PartitionSpec.builderFor(expectedSchema) + .alwaysNull("part", "part_1000") + .identity("part") + .identity("id") + .withSpecId(2) // The Spec is new + .build(); + + Assert.assertEquals("Should be partitioned by part and id", + expectedSpec, rtasTable.spec()); + + // the replacement table has a different schema and partition spec than the original + Assert.assertEquals("Should have expected nullable schema", + expectedSchema.asStruct(), rtasTable.schema().asStruct()); + + assertEquals("Should have rows matching the source table", + sql("SELECT 2 * id, data, CASE WHEN ((2 * id) %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + Assert.assertEquals("Table should have expected snapshots", + 2, Iterables.size(rtasTable.snapshots())); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestDeleteFrom.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestDeleteFrom.java new file mode 100644 index 000000000000..1ba001c9a030 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestDeleteFrom.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.sql; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestDeleteFrom extends SparkCatalogTestBase { + public TestDeleteFrom(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testDeleteFromUnpartitionedTable() throws NoSuchTableException { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + List records = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.coalesce(1).writeTo(tableName).append(); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("DELETE FROM %s WHERE id < 2", tableName); + + assertEquals("Should have no rows after successful delete", + ImmutableList.of(row(2L, "b"), row(3L, "c")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("DELETE FROM %s WHERE id < 4", tableName); + + assertEquals("Should have no rows after successful delete", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testDeleteFromTableAtSnapshot() throws NoSuchTableException { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + List records = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.coalesce(1).writeTo(tableName).append(); + + long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + String prefix = "snapshot_id_"; + AssertHelpers.assertThrows("Should not be able to delete from a table at a specific snapshot", + IllegalArgumentException.class, "Cannot delete from table at a specific snapshot", + () -> sql("DELETE FROM %s.%s WHERE id < 4", tableName, prefix + snapshotId)); + } + + @Test + public void testDeleteFromPartitionedTable() throws NoSuchTableException { + sql("CREATE TABLE %s (id bigint, data string) " + + "USING iceberg " + + "PARTITIONED BY (truncate(id, 2))", tableName); + + List records = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.coalesce(1).writeTo(tableName).append(); + + assertEquals("Should have 3 rows in 2 partitions", + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("DELETE FROM %s WHERE id > 2", tableName); + 
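+    // Note on the delete above: with truncate(id, 2) partitioning, id = 2 and id = 3 truncate to
+    // the same partition value (2), so this predicate cannot be satisfied by simply dropping whole
+    // data files; it likely goes through a copy-on-write rewrite of the affected file instead
+    // (copy-on-write being Iceberg's default delete mode).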
assertEquals("Should have two rows in the second partition", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("DELETE FROM %s WHERE id < 2", tableName); + + assertEquals("Should have two rows in the second partition", + ImmutableList.of(row(2L, "b")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testDeleteFromWhereFalse() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 1 snapshot", 1, Iterables.size(table.snapshots())); + + sql("DELETE FROM %s WHERE false", tableName); + + table.refresh(); + + Assert.assertEquals("Delete should not produce a new snapshot", 1, Iterables.size(table.snapshots())); + } + + @Test + public void testTruncate() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 1 snapshot", 1, Iterables.size(table.snapshots())); + + sql("TRUNCATE TABLE %s", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestDropTable.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestDropTable.java new file mode 100644 index 000000000000..535cd3926a1a --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestDropTable.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.sql; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Streams; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestDropTable extends SparkCatalogTestBase { + + public TestDropTable(String catalogName, String implementation, Map<String, String> config) { + super(catalogName, implementation, config); + } + + @Before + public void createTable() { + sql("CREATE TABLE %s (id INT, name STRING) USING iceberg", tableName); + sql("INSERT INTO %s VALUES (1, 'test')", tableName); + } + + @After + public void removeTable() throws IOException { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testDropTable() throws IOException { + dropTableInternal(); + } + + @Test + public void testDropTableGCDisabled() throws IOException { + sql("ALTER TABLE %s SET TBLPROPERTIES (gc.enabled = false)", tableName); + dropTableInternal(); + } + + private void dropTableInternal() throws IOException { + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "test")), sql("SELECT * FROM %s", tableName)); + + List<String> manifestAndFiles = manifestsAndFiles(); + Assert.assertEquals("There should be 2 files for manifests and files", 2, manifestAndFiles.size()); + Assert.assertTrue("All files should exist", checkFilesExist(manifestAndFiles, true)); + + sql("DROP TABLE %s", tableName); + Assert.assertFalse("Table should not exist", validationCatalog.tableExists(tableIdent)); + + if (catalogName.equals("testhadoop")) { + // Dropping a table from HadoopCatalog without purge deletes the base table location.
+ Assert.assertTrue("All files should be deleted", checkFilesExist(manifestAndFiles, false)); + } else { + Assert.assertTrue("All files should not be deleted", checkFilesExist(manifestAndFiles, true)); + } + } + + @Test + public void testPurgeTable() throws IOException { + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "test")), sql("SELECT * FROM %s", tableName)); + + List manifestAndFiles = manifestsAndFiles(); + Assert.assertEquals("There should be 2 files for manifests and files", 2, manifestAndFiles.size()); + Assert.assertTrue("All files should exist", checkFilesExist(manifestAndFiles, true)); + + sql("DROP TABLE %s PURGE", tableName); + Assert.assertFalse("Table should not exist", validationCatalog.tableExists(tableIdent)); + Assert.assertTrue("All files should be deleted", checkFilesExist(manifestAndFiles, false)); + } + + @Test + public void testPurgeTableGCDisabled() throws IOException { + sql("ALTER TABLE %s SET TBLPROPERTIES (gc.enabled = false)", tableName); + + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "test")), sql("SELECT * FROM %s", tableName)); + + List manifestAndFiles = manifestsAndFiles(); + Assert.assertEquals("There totally should have 2 files for manifests and files", 2, manifestAndFiles.size()); + Assert.assertTrue("All files should be existed", checkFilesExist(manifestAndFiles, true)); + + AssertHelpers.assertThrows("Purge table is not allowed when GC is disabled", ValidationException.class, + "Cannot purge table: GC is disabled (deleting files may corrupt other tables", + () -> sql("DROP TABLE %s PURGE", tableName)); + + Assert.assertTrue("Table should not been dropped", validationCatalog.tableExists(tableIdent)); + Assert.assertTrue("All files should not be deleted", checkFilesExist(manifestAndFiles, true)); + } + + private List manifestsAndFiles() { + List files = sql("SELECT file_path FROM %s.%s", tableName, MetadataTableType.FILES); + List manifests = sql("SELECT path FROM %s.%s", tableName, MetadataTableType.MANIFESTS); + return Streams.concat(files.stream(), manifests.stream()).map(row -> (String) row[0]).collect(Collectors.toList()); + } + + private boolean checkFilesExist(List files, boolean shouldExist) throws IOException { + boolean mask = !shouldExist; + if (files.isEmpty()) { + return mask; + } + + FileSystem fs = new Path(files.get(0)).getFileSystem(hiveConf); + return files.stream().allMatch(file -> { + try { + return fs.exists(new Path(file)) ^ mask; + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestNamespaceSQL.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestNamespaceSQL.java new file mode 100644 index 000000000000..a76e4d624ba1 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestNamespaceSQL.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.sql; + +import java.io.File; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.exceptions.NamespaceNotEmptyException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Test; + +public class TestNamespaceSQL extends SparkCatalogTestBase { + private static final Namespace NS = Namespace.of("db"); + + private final String fullNamespace; + private final boolean isHadoopCatalog; + + public TestNamespaceSQL(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + this.fullNamespace = ("spark_catalog".equals(catalogName) ? "" : catalogName + ".") + NS; + this.isHadoopCatalog = "testhadoop".equals(catalogName); + } + + @After + public void cleanNamespaces() { + sql("DROP TABLE IF EXISTS %s.table", fullNamespace); + sql("DROP NAMESPACE IF EXISTS %s", fullNamespace); + } + + @Test + public void testCreateNamespace() { + Assert.assertFalse("Namespace should not already exist", validationNamespaceCatalog.namespaceExists(NS)); + + sql("CREATE NAMESPACE %s", fullNamespace); + + Assert.assertTrue("Namespace should exist", validationNamespaceCatalog.namespaceExists(NS)); + } + + @Test + public void testDefaultNamespace() { + Assume.assumeFalse("Hadoop has no default namespace configured", isHadoopCatalog); + + sql("USE %s", catalogName); + + Object[] current = Iterables.getOnlyElement(sql("SHOW CURRENT NAMESPACE")); + Assert.assertEquals("Should use the current catalog", current[0], catalogName); + Assert.assertEquals("Should use the configured default namespace", current[1], "default"); + } + + @Test + public void testDropEmptyNamespace() { + Assert.assertFalse("Namespace should not already exist", validationNamespaceCatalog.namespaceExists(NS)); + + sql("CREATE NAMESPACE %s", fullNamespace); + + Assert.assertTrue("Namespace should exist", validationNamespaceCatalog.namespaceExists(NS)); + + sql("DROP NAMESPACE %s", fullNamespace); + + Assert.assertFalse("Namespace should have been dropped", validationNamespaceCatalog.namespaceExists(NS)); + } + + @Test + public void testDropNonEmptyNamespace() { + Assume.assumeFalse("Session catalog has flaky behavior", "spark_catalog".equals(catalogName)); + + Assert.assertFalse("Namespace should not already exist", validationNamespaceCatalog.namespaceExists(NS)); + + sql("CREATE NAMESPACE %s", fullNamespace); + sql("CREATE TABLE %s.table (id bigint) USING iceberg", fullNamespace); + + Assert.assertTrue("Namespace should exist", validationNamespaceCatalog.namespaceExists(NS)); + Assert.assertTrue("Table should exist", validationCatalog.tableExists(TableIdentifier.of(NS, "table"))); + + 
AssertHelpers.assertThrows("Should fail if trying to delete a non-empty namespace", + NamespaceNotEmptyException.class, "Namespace db is not empty.", + () -> sql("DROP NAMESPACE %s", fullNamespace)); + + sql("DROP TABLE %s.table", fullNamespace); + } + + @Test + public void testListTables() { + Assert.assertFalse("Namespace should not already exist", validationNamespaceCatalog.namespaceExists(NS)); + + sql("CREATE NAMESPACE %s", fullNamespace); + + Assert.assertTrue("Namespace should exist", validationNamespaceCatalog.namespaceExists(NS)); + + List rows = sql("SHOW TABLES IN %s", fullNamespace); + Assert.assertEquals("Should not list any tables", 0, rows.size()); + + sql("CREATE TABLE %s.table (id bigint) USING iceberg", fullNamespace); + + Object[] row = Iterables.getOnlyElement(sql("SHOW TABLES IN %s", fullNamespace)); + Assert.assertEquals("Namespace should match", "db", row[0]); + Assert.assertEquals("Table name should match", "table", row[1]); + } + + @Test + public void testListNamespace() { + Assert.assertFalse("Namespace should not already exist", validationNamespaceCatalog.namespaceExists(NS)); + + sql("CREATE NAMESPACE %s", fullNamespace); + + Assert.assertTrue("Namespace should exist", validationNamespaceCatalog.namespaceExists(NS)); + + List namespaces = sql("SHOW NAMESPACES IN %s", catalogName); + + if (isHadoopCatalog) { + Assert.assertEquals("Should have 1 namespace", 1, namespaces.size()); + Set namespaceNames = namespaces.stream().map(arr -> arr[0].toString()).collect(Collectors.toSet()); + Assert.assertEquals("Should have only db namespace", ImmutableSet.of("db"), namespaceNames); + } else { + Assert.assertEquals("Should have 2 namespaces", 2, namespaces.size()); + Set namespaceNames = namespaces.stream().map(arr -> arr[0].toString()).collect(Collectors.toSet()); + Assert.assertEquals("Should have default and db namespaces", ImmutableSet.of("default", "db"), namespaceNames); + } + + List nestedNamespaces = sql("SHOW NAMESPACES IN %s", fullNamespace); + + Set nestedNames = nestedNamespaces.stream().map(arr -> arr[0].toString()).collect(Collectors.toSet()); + Assert.assertEquals("Should not have nested namespaces", ImmutableSet.of(), nestedNames); + } + + @Test + public void testCreateNamespaceWithMetadata() { + Assume.assumeFalse("HadoopCatalog does not support namespace metadata", isHadoopCatalog); + + Assert.assertFalse("Namespace should not already exist", validationNamespaceCatalog.namespaceExists(NS)); + + sql("CREATE NAMESPACE %s WITH PROPERTIES ('prop'='value')", fullNamespace); + + Assert.assertTrue("Namespace should exist", validationNamespaceCatalog.namespaceExists(NS)); + + Map nsMetadata = validationNamespaceCatalog.loadNamespaceMetadata(NS); + + Assert.assertEquals("Namespace should have expected prop value", "value", nsMetadata.get("prop")); + } + + @Test + public void testCreateNamespaceWithComment() { + Assume.assumeFalse("HadoopCatalog does not support namespace metadata", isHadoopCatalog); + + Assert.assertFalse("Namespace should not already exist", validationNamespaceCatalog.namespaceExists(NS)); + + sql("CREATE NAMESPACE %s COMMENT 'namespace doc'", fullNamespace); + + Assert.assertTrue("Namespace should exist", validationNamespaceCatalog.namespaceExists(NS)); + + Map nsMetadata = validationNamespaceCatalog.loadNamespaceMetadata(NS); + + Assert.assertEquals("Namespace should have expected comment", "namespace doc", nsMetadata.get("comment")); + } + + @Test + public void testCreateNamespaceWithLocation() throws Exception { + 
Assume.assumeFalse("HadoopCatalog does not support namespace locations", isHadoopCatalog); + + Assert.assertFalse("Namespace should not already exist", validationNamespaceCatalog.namespaceExists(NS)); + + File location = temp.newFile(); + Assert.assertTrue(location.delete()); + + sql("CREATE NAMESPACE %s LOCATION '%s'", fullNamespace, location); + + Assert.assertTrue("Namespace should exist", validationNamespaceCatalog.namespaceExists(NS)); + + Map nsMetadata = validationNamespaceCatalog.loadNamespaceMetadata(NS); + + Assert.assertEquals("Namespace should have expected location", + "file:" + location.getPath(), nsMetadata.get("location")); + } + + @Test + public void testSetProperties() { + Assume.assumeFalse("HadoopCatalog does not support namespace metadata", isHadoopCatalog); + + Assert.assertFalse("Namespace should not already exist", validationNamespaceCatalog.namespaceExists(NS)); + + sql("CREATE NAMESPACE %s", fullNamespace); + + Assert.assertTrue("Namespace should exist", validationNamespaceCatalog.namespaceExists(NS)); + + Map defaultMetadata = validationNamespaceCatalog.loadNamespaceMetadata(NS); + Assert.assertFalse("Default metadata should not have custom property", defaultMetadata.containsKey("prop")); + + sql("ALTER NAMESPACE %s SET PROPERTIES ('prop'='value')", fullNamespace); + + Map nsMetadata = validationNamespaceCatalog.loadNamespaceMetadata(NS); + + Assert.assertEquals("Namespace should have expected prop value", "value", nsMetadata.get("prop")); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java new file mode 100644 index 000000000000..9223797ada32 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.sql; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.functions; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestPartitionedWrites extends SparkCatalogTestBase { + public TestPartitionedWrites(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Before + public void createTables() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg PARTITIONED BY (truncate(id, 3))", tableName); + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testInsertAppend() { + Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); + + sql("INSERT INTO %s VALUES (4, 'd'), (5, 'e')", tableName); + + Assert.assertEquals("Should have 5 rows after insert", 5L, scalarSql("SELECT count(*) FROM %s", tableName)); + + List expected = ImmutableList.of( + row(1L, "a"), + row(2L, "b"), + row(3L, "c"), + row(4L, "d"), + row(5L, "e") + ); + + assertEquals("Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testInsertOverwrite() { + Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); + + // 4 and 5 replace 3 in the partition (id - (id % 3)) = 3 + sql("INSERT OVERWRITE %s VALUES (4, 'd'), (5, 'e')", tableName); + + Assert.assertEquals("Should have 4 rows after overwrite", 4L, scalarSql("SELECT count(*) FROM %s", tableName)); + + List expected = ImmutableList.of( + row(1L, "a"), + row(2L, "b"), + row(4L, "d"), + row(5L, "e") + ); + + assertEquals("Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testDataFrameV2Append() throws NoSuchTableException { + Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); + + List data = ImmutableList.of( + new SimpleRecord(4, "d"), + new SimpleRecord(5, "e") + ); + Dataset ds = spark.createDataFrame(data, SimpleRecord.class); + + ds.writeTo(tableName).append(); + + Assert.assertEquals("Should have 5 rows after insert", 5L, scalarSql("SELECT count(*) FROM %s", tableName)); + + List expected = ImmutableList.of( + row(1L, "a"), + row(2L, "b"), + row(3L, "c"), + row(4L, "d"), + row(5L, "e") + ); + + assertEquals("Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testDataFrameV2DynamicOverwrite() throws NoSuchTableException { + Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); + + List data = ImmutableList.of( + new SimpleRecord(4, "d"), + new SimpleRecord(5, "e") + ); + Dataset ds = spark.createDataFrame(data, SimpleRecord.class); + + ds.writeTo(tableName).overwritePartitions(); + + Assert.assertEquals("Should have 4 rows after overwrite", 4L, scalarSql("SELECT count(*) FROM %s", tableName)); + + List expected = ImmutableList.of( + row(1L, "a"), + row(2L, "b"), 
+ row(4L, "d"), + row(5L, "e") + ); + + assertEquals("Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testDataFrameV2Overwrite() throws NoSuchTableException { + Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); + + List data = ImmutableList.of( + new SimpleRecord(4, "d"), + new SimpleRecord(5, "e") + ); + Dataset ds = spark.createDataFrame(data, SimpleRecord.class); + + ds.writeTo(tableName).overwrite(functions.col("id").$less(3)); + + Assert.assertEquals("Should have 3 rows after overwrite", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); + + List expected = ImmutableList.of( + row(3L, "c"), + row(4L, "d"), + row(5L, "e") + ); + + assertEquals("Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testViewsReturnRecentResults() { + Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); + + Dataset query = spark.sql("SELECT * FROM " + tableName + " WHERE id = 1"); + query.createOrReplaceTempView("tmp"); + + assertEquals("View should have expected rows", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM tmp")); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals("View should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM tmp")); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesAsSelect.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesAsSelect.java new file mode 100644 index 000000000000..e38d545d8df9 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesAsSelect.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.sql; + +import java.util.List; +import java.util.stream.IntStream; +import org.apache.iceberg.spark.IcebergSpark; +import org.apache.iceberg.spark.SparkTestBaseWithCatalog; +import org.apache.spark.sql.types.DataTypes; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestPartitionedWritesAsSelect extends SparkTestBaseWithCatalog { + + private final String targetTable = tableName("target_table"); + + @Before + public void createTables() { + sql("CREATE TABLE %s (id bigint, data string, category string, ts timestamp) USING iceberg", tableName); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s", targetTable); + } + + @Test + public void testInsertAsSelectAppend() { + insertData(3); + List expected = currentData(); + + sql("CREATE TABLE %s (id bigint, data string, category string, ts timestamp)" + + "USING iceberg PARTITIONED BY (days(ts), category)", targetTable); + + sql("INSERT INTO %s SELECT id, data, category, ts FROM %s ORDER BY ts,category", targetTable, tableName); + Assert.assertEquals("Should have 15 rows after insert", + 3 * 5L, scalarSql("SELECT count(*) FROM %s", targetTable)); + + assertEquals("Row data should match expected", + expected, sql("SELECT * FROM %s ORDER BY id", targetTable)); + } + + @Test + public void testInsertAsSelectWithBucket() { + insertData(3); + List expected = currentData(); + + sql("CREATE TABLE %s (id bigint, data string, category string, ts timestamp)" + + "USING iceberg PARTITIONED BY (bucket(8, data))", targetTable); + + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket8", DataTypes.StringType, 8); + sql("INSERT INTO %s SELECT id, data, category, ts FROM %s ORDER BY iceberg_bucket8(data)", targetTable, tableName); + Assert.assertEquals("Should have 15 rows after insert", + 3 * 5L, scalarSql("SELECT count(*) FROM %s", targetTable)); + + assertEquals("Row data should match expected", + expected, sql("SELECT * FROM %s ORDER BY id", targetTable)); + } + + @Test + public void testInsertAsSelectWithTruncate() { + insertData(3); + List expected = currentData(); + + sql("CREATE TABLE %s (id bigint, data string, category string, ts timestamp)" + + "USING iceberg PARTITIONED BY (truncate(data, 4), truncate(id, 4))", targetTable); + + IcebergSpark.registerTruncateUDF(spark, "iceberg_truncate_string4", DataTypes.StringType, 4); + IcebergSpark.registerTruncateUDF(spark, "iceberg_truncate_long4", DataTypes.LongType, 4); + sql("INSERT INTO %s SELECT id, data, category, ts FROM %s " + + "ORDER BY iceberg_truncate_string4(data),iceberg_truncate_long4(id)", targetTable, tableName); + Assert.assertEquals("Should have 15 rows after insert", + 3 * 5L, scalarSql("SELECT count(*) FROM %s", targetTable)); + + assertEquals("Row data should match expected", + expected, sql("SELECT * FROM %s ORDER BY id", targetTable)); + } + + private void insertData(int repeatCounter) { + IntStream.range(0, repeatCounter).forEach(i -> { + sql("INSERT INTO %s VALUES (13, '1', 'bgd16', timestamp('2021-11-10 11:20:10'))," + + "(21, '2', 'bgd13', timestamp('2021-11-10 11:20:10')), " + + "(12, '3', 'bgd14', timestamp('2021-11-10 11:20:10'))," + + "(222, '3', 'bgd15', timestamp('2021-11-10 11:20:10'))," + + "(45, '4', 'bgd16', timestamp('2021-11-10 11:20:10'))", tableName); + }); + } + + private List currentData() { + return rowsToJava(spark.sql("SELECT * FROM " + tableName + " order by id").collectAsList()); + } +} 
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestRefreshTable.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestRefreshTable.java new file mode 100644 index 000000000000..187dd4470a05 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestRefreshTable.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.iceberg.spark.sql; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class TestRefreshTable extends SparkCatalogTestBase { + + public TestRefreshTable(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Before + public void createTables() { + sql("CREATE TABLE %s (key int, value int) USING iceberg", tableName); + sql("INSERT INTO %s VALUES (1,1)", tableName); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testRefreshCommand() { + // We are not allowed to change the session catalog after it has been initialized, so build a new one + if (catalogName.equals("spark_catalog")) { + spark.conf().set("spark.sql.catalog." 
+ catalogName + ".cache-enabled", true); + spark = spark.cloneSession(); + } + + List originalExpected = ImmutableList.of(row(1, 1)); + List originalActual = sql("SELECT * FROM %s", tableName); + assertEquals("Table should start as expected", originalExpected, originalActual); + + // Modify table outside of spark, it should be cached so Spark should see the same value after mutation + Table table = validationCatalog.loadTable(tableIdent); + DataFile file = table.currentSnapshot().addedFiles(table.io()).iterator().next(); + table.newDelete().deleteFile(file).commit(); + + List cachedActual = sql("SELECT * FROM %s", tableName); + assertEquals("Cached table should be unchanged", originalExpected, cachedActual); + + // Refresh the Spark catalog, should be empty + sql("REFRESH TABLE %s", tableName); + List refreshedExpected = ImmutableList.of(); + List refreshedActual = sql("SELECT * FROM %s", tableName); + assertEquals("Refreshed table should be empty", refreshedExpected, refreshedActual); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestSelect.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestSelect.java new file mode 100644 index 000000000000..e9ac21d2df39 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestSelect.java @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.sql; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.events.Listeners; +import org.apache.iceberg.events.ScanEvent; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Before; +import org.junit.Test; + +public class TestSelect extends SparkCatalogTestBase { + private int scanEventCount = 0; + private ScanEvent lastScanEvent = null; + private String binaryTableName = tableName("binary_table"); + + public TestSelect(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + + // register a scan event listener to validate pushdown + Listeners.register(event -> { + scanEventCount += 1; + lastScanEvent = event; + }, ScanEvent.class); + } + + @Before + public void createTables() { + sql("CREATE TABLE %s (id bigint, data string, float float) USING iceberg", tableName); + sql("INSERT INTO %s VALUES (1, 'a', 1.0), (2, 'b', 2.0), (3, 'c', float('NaN'))", tableName); + + this.scanEventCount = 0; + this.lastScanEvent = null; + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s", binaryTableName); + } + + @Test + public void testSelect() { + List expected = ImmutableList.of( + row(1L, "a", 1.0F), row(2L, "b", 2.0F), row(3L, "c", Float.NaN)); + + assertEquals("Should return all expected rows", expected, sql("SELECT * FROM %s", tableName)); + } + + @Test + public void testSelectRewrite() { + List expected = ImmutableList.of(row(3L, "c", Float.NaN)); + + assertEquals("Should return all expected rows", expected, + sql("SELECT * FROM %s where float = float('NaN')", tableName)); + + Assert.assertEquals("Should create only one scan", 1, scanEventCount); + Assert.assertEquals("Should push down expected filter", + "(float IS NOT NULL AND is_nan(float))", + Spark3Util.describe(lastScanEvent.filter())); + } + + @Test + public void testProjection() { + List expected = ImmutableList.of(row(1L), row(2L), row(3L)); + + assertEquals("Should return all expected rows", expected, sql("SELECT id FROM %s", tableName)); + + Assert.assertEquals("Should create only one scan", 1, scanEventCount); + Assert.assertEquals("Should not push down a filter", Expressions.alwaysTrue(), lastScanEvent.filter()); + Assert.assertEquals("Should project only the id column", + validationCatalog.loadTable(tableIdent).schema().select("id").asStruct(), + lastScanEvent.projection().asStruct()); + } + + @Test + public void testExpressionPushdown() { + List expected = ImmutableList.of(row("b")); + + assertEquals("Should return all expected rows", expected, sql("SELECT data FROM %s WHERE id = 2", tableName)); + + Assert.assertEquals("Should create only one scan", 1, scanEventCount); + Assert.assertEquals("Should push down expected filter", + "(id IS NOT NULL AND id = 2)", + Spark3Util.describe(lastScanEvent.filter())); + Assert.assertEquals("Should project only id and data columns", + validationCatalog.loadTable(tableIdent).schema().select("id", "data").asStruct(), + lastScanEvent.projection().asStruct()); + } + + @Test + public void 
testMetadataTables() { + Assume.assumeFalse( + "Spark session catalog does not support metadata tables", + "spark_catalog".equals(catalogName)); + + assertEquals("Snapshot metadata table", + ImmutableList.of(row(ANY, ANY, null, "append", ANY, ANY)), + sql("SELECT * FROM %s.snapshots", tableName)); + } + + @Test + public void testSnapshotInTableName() { + Assume.assumeFalse( + "Spark session catalog does not support extended table names", + "spark_catalog".equals(catalogName)); + + // get the snapshot ID of the last write and get the current row set as expected + long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + List expected = sql("SELECT * FROM %s", tableName); + + // create a second snapshot + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + + String prefix = "snapshot_id_"; + // read the table at the snapshot + List actual = sql("SELECT * FROM %s.%s", tableName, prefix + snapshotId); + assertEquals("Snapshot at specific ID, prefix " + prefix, expected, actual); + + // read the table using DataFrameReader option + Dataset df = spark.read() + .format("iceberg") + .option(SparkReadOptions.SNAPSHOT_ID, snapshotId) + .load(tableName); + List fromDF = rowsToJava(df.collectAsList()); + assertEquals("Snapshot at specific ID " + snapshotId, expected, fromDF); + } + + @Test + public void testTimestampInTableName() { + Assume.assumeFalse( + "Spark session catalog does not support extended table names", + "spark_catalog".equals(catalogName)); + + // get a timestamp just after the last write and get the current row set as expected + long snapshotTs = validationCatalog.loadTable(tableIdent).currentSnapshot().timestampMillis(); + long timestamp = waitUntilAfter(snapshotTs + 2); + List expected = sql("SELECT * FROM %s", tableName); + + // create a second snapshot + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + + String prefix = "at_timestamp_"; + // read the table at the snapshot + List actual = sql("SELECT * FROM %s.%s", tableName, prefix + timestamp); + assertEquals("Snapshot at timestamp, prefix " + prefix, expected, actual); + + // read the table using DataFrameReader option + Dataset df = spark.read() + .format("iceberg") + .option(SparkReadOptions.AS_OF_TIMESTAMP, timestamp) + .load(tableName); + List fromDF = rowsToJava(df.collectAsList()); + assertEquals("Snapshot at timestamp " + timestamp, expected, fromDF); + } + + @Test + public void testSpecifySnapshotAndTimestamp() { + // get the snapshot ID of the last write + long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + // get a timestamp just after the last write + long timestamp = validationCatalog.loadTable(tableIdent).currentSnapshot().timestampMillis() + 2; + + // create a second snapshot + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + + AssertHelpers.assertThrows("Should not be able to specify both snapshot id and timestamp", + IllegalArgumentException.class, + String.format("Cannot specify both snapshot-id (%s) and as-of-timestamp (%s)", + snapshotId, timestamp), + () -> { + spark.read() + .format("iceberg") + .option(SparkReadOptions.SNAPSHOT_ID, snapshotId) + .option(SparkReadOptions.AS_OF_TIMESTAMP, timestamp) + .load(tableName) + .collectAsList(); + }); + } + + @Test + public void testBinaryInFilter() { + sql("CREATE TABLE %s (id bigint, binary binary) USING iceberg", binaryTableName); + sql("INSERT INTO %s VALUES (1, X''), (2, X'1111'), (3, X'11')", binaryTableName); + List 
expected = ImmutableList.of(row(2L, new byte[]{0x11, 0x11})); + + assertEquals("Should return all expected rows", expected, + sql("SELECT id, binary FROM %s where binary > X'11'", binaryTableName)); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestTimestampWithoutZone.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestTimestampWithoutZone.java new file mode 100644 index 000000000000..ddaac5256e10 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestTimestampWithoutZone.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.sql; + +import java.sql.Timestamp; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.joda.time.DateTime; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runners.Parameterized; + +public class TestTimestampWithoutZone extends SparkCatalogTestBase { + + private static final String newTableName = "created_table"; + private final Map config; + + private static final Schema schema = new Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.required(2, "ts", Types.TimestampType.withoutZone()), + Types.NestedField.required(3, "tsz", Types.TimestampType.withZone()) + ); + + private final List values = ImmutableList.of( + row(1L, toTimestamp("2021-01-01T00:00:00.0"), toTimestamp("2021-02-01T00:00:00.0")), + row(2L, toTimestamp("2021-01-01T00:00:00.0"), toTimestamp("2021-02-01T00:00:00.0")), + row(3L, toTimestamp("2021-01-01T00:00:00.0"), toTimestamp("2021-02-01T00:00:00.0")) + ); + + @Parameterized.Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][]{{"spark_catalog", + SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "parquet-enabled", "true", + 
"cache-enabled", "false" + )} + }; + } + + public TestTimestampWithoutZone(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + this.config = config; + } + + @Before + public void createTables() { + validationCatalog.createTable(tableIdent, schema); + } + + @After + public void removeTables() { + validationCatalog.dropTable(tableIdent, true); + sql("DROP TABLE IF EXISTS %s", newTableName); + } + + @Test + public void testWriteTimestampWithoutZoneError() { + AssertHelpers.assertThrows( + String.format("Write operation performed on a timestamp without timezone field while " + + "'%s' set to false should throw exception", SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE), + IllegalArgumentException.class, + SparkUtil.TIMESTAMP_WITHOUT_TIMEZONE_ERROR, + () -> sql("INSERT INTO %s VALUES %s", tableName, rowToSqlValues(values))); + } + + @Test + public void testAppendTimestampWithoutZone() { + withSQLConf(ImmutableMap.of(SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "true"), () -> { + sql("INSERT INTO %s VALUES %s", tableName, rowToSqlValues(values)); + + Assert.assertEquals("Should have " + values.size() + " row", + (long) values.size(), scalarSql("SELECT count(*) FROM %s", tableName)); + + assertEquals("Row data should match expected", + values, sql("SELECT * FROM %s ORDER BY id", tableName)); + }); + } + + @Test + public void testCreateAsSelectWithTimestampWithoutZone() { + withSQLConf(ImmutableMap.of(SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "true"), () -> { + sql("INSERT INTO %s VALUES %s", tableName, rowToSqlValues(values)); + + sql("CREATE TABLE %s USING iceberg AS SELECT * FROM %s", newTableName, tableName); + + Assert.assertEquals("Should have " + values.size() + " row", (long) values.size(), + scalarSql("SELECT count(*) FROM %s", newTableName)); + + assertEquals("Row data should match expected", + sql("SELECT * FROM %s ORDER BY id", tableName), + sql("SELECT * FROM %s ORDER BY id", newTableName)); + }); + } + + @Test + public void testCreateNewTableShouldHaveTimestampWithZoneIcebergType() { + withSQLConf(ImmutableMap.of(SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "true"), () -> { + sql("INSERT INTO %s VALUES %s", tableName, rowToSqlValues(values)); + + sql("CREATE TABLE %s USING iceberg AS SELECT * FROM %s", newTableName, tableName); + + Assert.assertEquals("Should have " + values.size() + " row", (long) values.size(), + scalarSql("SELECT count(*) FROM %s", newTableName)); + + assertEquals("Data from created table should match data from base table", + sql("SELECT * FROM %s ORDER BY id", tableName), + sql("SELECT * FROM %s ORDER BY id", newTableName)); + + Table createdTable = validationCatalog.loadTable(TableIdentifier.of("default", newTableName)); + assertFieldsType(createdTable.schema(), Types.TimestampType.withZone(), "ts", "tsz"); + }); + } + + @Test + public void testCreateNewTableShouldHaveTimestampWithoutZoneIcebergType() { + withSQLConf(ImmutableMap.of( + SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "true", + SparkSQLProperties.USE_TIMESTAMP_WITHOUT_TIME_ZONE_IN_NEW_TABLES, "true"), () -> { + spark.sessionState().catalogManager().currentCatalog() + .initialize(catalog.name(), new CaseInsensitiveStringMap(config)); + sql("INSERT INTO %s VALUES %s", tableName, rowToSqlValues(values)); + + sql("CREATE TABLE %s USING iceberg AS SELECT * FROM %s", newTableName, tableName); + + Assert.assertEquals("Should have " + values.size() + " row", (long) values.size(), + scalarSql("SELECT count(*) 
FROM %s", newTableName)); + + assertEquals("Row data should match expected", + sql("SELECT * FROM %s ORDER BY id", tableName), + sql("SELECT * FROM %s ORDER BY id", newTableName)); + Table createdTable = validationCatalog.loadTable(TableIdentifier.of("default", newTableName)); + assertFieldsType(createdTable.schema(), Types.TimestampType.withoutZone(), "ts", "tsz"); + }); + } + + private Timestamp toTimestamp(String value) { + return new Timestamp(DateTime.parse(value).getMillis()); + } + + private String rowToSqlValues(List rows) { + List rowValues = rows.stream().map(row -> { + List columns = Arrays.stream(row).map(value -> { + if (value instanceof Long) { + return value.toString(); + } else if (value instanceof Timestamp) { + return String.format("timestamp '%s'", value); + } + throw new RuntimeException("Type is not supported"); + }).collect(Collectors.toList()); + return "(" + Joiner.on(",").join(columns) + ")"; + }).collect(Collectors.toList()); + return Joiner.on(",").join(rowValues); + } + + private void assertFieldsType(Schema actual, Type.PrimitiveType expected, String... fields) { + actual.select(fields).asStruct().fields().forEach(field -> Assert.assertEquals(expected, field.type())); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWrites.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWrites.java new file mode 100644 index 000000000000..bd6fb5abf2c6 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWrites.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.sql; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.functions; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestUnpartitionedWrites extends SparkCatalogTestBase { + public TestUnpartitionedWrites(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Before + public void createTables() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testInsertAppend() { + Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); + + sql("INSERT INTO %s VALUES (4, 'd'), (5, 'e')", tableName); + + Assert.assertEquals("Should have 5 rows after insert", 5L, scalarSql("SELECT count(*) FROM %s", tableName)); + + List expected = ImmutableList.of( + row(1L, "a"), + row(2L, "b"), + row(3L, "c"), + row(4L, "d"), + row(5L, "e") + ); + + assertEquals("Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testInsertOverwrite() { + Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); + + sql("INSERT OVERWRITE %s VALUES (4, 'd'), (5, 'e')", tableName); + + Assert.assertEquals("Should have 2 rows after overwrite", 2L, scalarSql("SELECT count(*) FROM %s", tableName)); + + List expected = ImmutableList.of( + row(4L, "d"), + row(5L, "e") + ); + + assertEquals("Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testInsertAppendAtSnapshot() { + long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + String prefix = "snapshot_id_"; + AssertHelpers.assertThrows("Should not be able to insert into a table at a specific snapshot", + IllegalArgumentException.class, "Cannot write to table at a specific snapshot", + () -> sql("INSERT INTO %s.%s VALUES (4, 'd'), (5, 'e')", tableName, prefix + snapshotId)); + } + + @Test + public void testInsertOverwriteAtSnapshot() { + long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + String prefix = "snapshot_id_"; + AssertHelpers.assertThrows("Should not be able to insert into a table at a specific snapshot", + IllegalArgumentException.class, "Cannot write to table at a specific snapshot", + () -> sql("INSERT OVERWRITE %s.%s VALUES (4, 'd'), (5, 'e')", tableName, prefix + snapshotId)); + } + + @Test + public void testDataFrameV2Append() throws NoSuchTableException { + Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); + + List data = ImmutableList.of( + new SimpleRecord(4, "d"), + new SimpleRecord(5, "e") + ); + Dataset ds = spark.createDataFrame(data, SimpleRecord.class); + + ds.writeTo(tableName).append(); + + Assert.assertEquals("Should have 5 rows after insert", 5L, scalarSql("SELECT 
count(*) FROM %s", tableName)); + + List expected = ImmutableList.of( + row(1L, "a"), + row(2L, "b"), + row(3L, "c"), + row(4L, "d"), + row(5L, "e") + ); + + assertEquals("Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testDataFrameV2DynamicOverwrite() throws NoSuchTableException { + Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); + + List data = ImmutableList.of( + new SimpleRecord(4, "d"), + new SimpleRecord(5, "e") + ); + Dataset ds = spark.createDataFrame(data, SimpleRecord.class); + + ds.writeTo(tableName).overwritePartitions(); + + Assert.assertEquals("Should have 2 rows after overwrite", 2L, scalarSql("SELECT count(*) FROM %s", tableName)); + + List expected = ImmutableList.of( + row(4L, "d"), + row(5L, "e") + ); + + assertEquals("Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testDataFrameV2Overwrite() throws NoSuchTableException { + Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); + + List data = ImmutableList.of( + new SimpleRecord(4, "d"), + new SimpleRecord(5, "e") + ); + Dataset ds = spark.createDataFrame(data, SimpleRecord.class); + + ds.writeTo(tableName).overwrite(functions.col("id").$less$eq(3)); + + Assert.assertEquals("Should have 2 rows after overwrite", 2L, scalarSql("SELECT count(*) FROM %s", tableName)); + + List expected = ImmutableList.of( + row(4L, "d"), + row(5L, "e") + ); + + assertEquals("Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName)); + } +}