diff --git a/.github/workflows/jmh-benchmarks.yml b/.github/workflows/jmh-benchmarks.yml index eaa3d03e386e..2e95baaeb917 100644 --- a/.github/workflows/jmh-benchmarks.yml +++ b/.github/workflows/jmh-benchmarks.yml @@ -28,8 +28,8 @@ on: description: 'The branch name' required: true spark_version: - description: 'The spark project version to use, such as iceberg-spark-3.3' - default: 'iceberg-spark-3.3' + description: 'The spark project version to use, such as iceberg-spark-3.4' + default: 'iceberg-spark-3.4' required: true benchmarks: description: 'A list of comma-separated double-quoted Benchmark names, such as "IcebergSourceFlatParquetDataReadBenchmark", "IcebergSourceFlatParquetDataFilterBenchmark"' diff --git a/.github/workflows/publish-snapshot.yml b/.github/workflows/publish-snapshot.yml index e849ebf1bee5..eb8b79f1a8f6 100644 --- a/.github/workflows/publish-snapshot.yml +++ b/.github/workflows/publish-snapshot.yml @@ -41,4 +41,4 @@ jobs: - run: | ./gradlew printVersion ./gradlew -DallVersions publishApachePublicationToMavenRepository -PmavenUser=${{ secrets.NEXUS_USER }} -PmavenPassword=${{ secrets.NEXUS_PW }} - ./gradlew -DflinkVersions= -DsparkVersions=3.2,3.3 -DscalaVersion=2.13 -DhiveVersions= publishApachePublicationToMavenRepository -PmavenUser=${{ secrets.NEXUS_USER }} -PmavenPassword=${{ secrets.NEXUS_PW }} + ./gradlew -DflinkVersions= -DsparkVersions=3.2,3.3,3.4 -DscalaVersion=2.13 -DhiveVersions= publishApachePublicationToMavenRepository -PmavenUser=${{ secrets.NEXUS_USER }} -PmavenPassword=${{ secrets.NEXUS_PW }} diff --git a/.github/workflows/recurring-jmh-benchmarks.yml b/.github/workflows/recurring-jmh-benchmarks.yml index 92cd924c4561..c6a1e80fe7ea 100644 --- a/.github/workflows/recurring-jmh-benchmarks.yml +++ b/.github/workflows/recurring-jmh-benchmarks.yml @@ -40,7 +40,7 @@ jobs: "IcebergSourceNestedParquetDataReadBenchmark", "IcebergSourceNestedParquetDataWriteBenchmark", "IcebergSourceParquetEqDeleteBenchmark", "IcebergSourceParquetMultiDeleteFileBenchmark", "IcebergSourceParquetPosDeleteBenchmark", "IcebergSourceParquetWithUnrelatedDeleteBenchmark"] - spark_version: ['iceberg-spark-3.3'] + spark_version: ['iceberg-spark-3.4'] env: SPARK_LOCAL_IP: localhost steps: diff --git a/.github/workflows/spark-ci.yml b/.github/workflows/spark-ci.yml index 2a22a2df8ccf..794d845ad635 100644 --- a/.github/workflows/spark-ci.yml +++ b/.github/workflows/spark-ci.yml @@ -86,7 +86,7 @@ jobs: strategy: matrix: jvm: [8, 11] - spark: ['3.1', '3.2', '3.3'] + spark: ['3.1', '3.2', '3.3', '3.4'] env: SPARK_LOCAL_IP: localhost steps: @@ -116,7 +116,7 @@ jobs: strategy: matrix: jvm: [8, 11] - spark: ['3.2','3.3'] + spark: ['3.2','3.3', '3.4'] env: SPARK_LOCAL_IP: localhost steps: diff --git a/.gitignore b/.gitignore index a259f712379c..54e4275accdb 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,7 @@ spark/v2.4/spark/benchmark/* spark/v3.1/spark/benchmark/* spark/v3.2/spark/benchmark/* spark/v3.3/spark/benchmark/* +spark/v3.4/spark/benchmark/* data/benchmark/* __pycache__/ diff --git a/dev/stage-binaries.sh b/dev/stage-binaries.sh index 7421993ade66..4fc0c514c695 100755 --- a/dev/stage-binaries.sh +++ b/dev/stage-binaries.sh @@ -20,7 +20,7 @@ SCALA_VERSION=2.12 FLINK_VERSIONS=1.15,1.16,1.17 -SPARK_VERSIONS=2.4,3.1,3.2,3.3 +SPARK_VERSIONS=2.4,3.1,3.2,3.3,3.4 HIVE_VERSIONS=2,3 ./gradlew -Prelease -DscalaVersion=$SCALA_VERSION -DflinkVersions=$FLINK_VERSIONS -DsparkVersions=$SPARK_VERSIONS -DhiveVersions=$HIVE_VERSIONS publishApachePublicationToMavenRepository @@ -29,4 +29,5 @@ 
HIVE_VERSIONS=2,3 # Flink does not yet support 2.13 (and is largely dropping a user-facing dependency on Scala). Hive doesn't need a Scala specification. ./gradlew -Prelease -DscalaVersion=2.13 -DsparkVersions=3.2 :iceberg-spark:iceberg-spark-3.2_2.13:publishApachePublicationToMavenRepository :iceberg-spark:iceberg-spark-extensions-3.2_2.13:publishApachePublicationToMavenRepository :iceberg-spark:iceberg-spark-runtime-3.2_2.13:publishApachePublicationToMavenRepository ./gradlew -Prelease -DscalaVersion=2.13 -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.13:publishApachePublicationToMavenRepository :iceberg-spark:iceberg-spark-extensions-3.3_2.13:publishApachePublicationToMavenRepository :iceberg-spark:iceberg-spark-runtime-3.3_2.13:publishApachePublicationToMavenRepository +./gradlew -Prelease -DscalaVersion=2.13 -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.13:publishApachePublicationToMavenRepository :iceberg-spark:iceberg-spark-extensions-3.4_2.13:publishApachePublicationToMavenRepository :iceberg-spark:iceberg-spark-runtime-3.4_2.13:publishApachePublicationToMavenRepository diff --git a/gradle.properties b/gradle.properties index 71f4e4118274..8d8a9f0dc021 100644 --- a/gradle.properties +++ b/gradle.properties @@ -20,8 +20,8 @@ systemProp.defaultFlinkVersions=1.17 systemProp.knownFlinkVersions=1.15,1.16,1.17 systemProp.defaultHiveVersions=2 systemProp.knownHiveVersions=2,3 -systemProp.defaultSparkVersions=3.3 -systemProp.knownSparkVersions=2.4,3.1,3.2,3.3 +systemProp.defaultSparkVersions=3.4 +systemProp.knownSparkVersions=2.4,3.1,3.2,3.3,3.4 systemProp.defaultScalaVersion=2.12 systemProp.knownScalaVersions=2.12,2.13 org.gradle.parallel=true diff --git a/jmh.gradle b/jmh.gradle index 41529d18926b..e560365931b8 100644 --- a/jmh.gradle +++ b/jmh.gradle @@ -41,6 +41,10 @@ if (sparkVersions.contains("3.3")) { jmhProjects.add(project(":iceberg-spark:iceberg-spark-3.3_${scalaVersion}")) } +if (sparkVersions.contains("3.4")) { + jmhProjects.add(project(":iceberg-spark:iceberg-spark-3.4_${scalaVersion}")) +} + jmhProjects.add(project(":iceberg-data")) configure(jmhProjects) { diff --git a/settings.gradle b/settings.gradle index 37fdbd98ee4e..40db31aa1282 100644 --- a/settings.gradle +++ b/settings.gradle @@ -161,6 +161,18 @@ if (sparkVersions.contains("3.3")) { project(":iceberg-spark:spark-runtime-3.3_${scalaVersion}").name = "iceberg-spark-runtime-3.3_${scalaVersion}" } +if (sparkVersions.contains("3.4")) { + include ":iceberg-spark:spark-3.4_${scalaVersion}" + include ":iceberg-spark:spark-extensions-3.4_${scalaVersion}" + include ":iceberg-spark:spark-runtime-3.4_${scalaVersion}" + project(":iceberg-spark:spark-3.4_${scalaVersion}").projectDir = file('spark/v3.4/spark') + project(":iceberg-spark:spark-3.4_${scalaVersion}").name = "iceberg-spark-3.4_${scalaVersion}" + project(":iceberg-spark:spark-extensions-3.4_${scalaVersion}").projectDir = file('spark/v3.4/spark-extensions') + project(":iceberg-spark:spark-extensions-3.4_${scalaVersion}").name = "iceberg-spark-extensions-3.4_${scalaVersion}" + project(":iceberg-spark:spark-runtime-3.4_${scalaVersion}").projectDir = file('spark/v3.4/spark-runtime') + project(":iceberg-spark:spark-runtime-3.4_${scalaVersion}").name = "iceberg-spark-runtime-3.4_${scalaVersion}" +} + // hive 3 depends on hive 2, so always add hive 2 if hive3 is enabled if (hiveVersions.contains("2") || hiveVersions.contains("3")) { include 'mr' diff --git a/spark/build.gradle b/spark/build.gradle index d1a8549df04d..f9947e34a034 100644 --- 
a/spark/build.gradle +++ b/spark/build.gradle @@ -34,4 +34,8 @@ if (sparkVersions.contains("3.2")) { if (sparkVersions.contains("3.3")) { apply from: file("$projectDir/v3.3/build.gradle") -} \ No newline at end of file +} + +if (sparkVersions.contains("3.4")) { + apply from: file("$projectDir/v3.4/build.gradle") +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java index 6e03dd69a850..317a7863eff7 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java @@ -212,16 +212,7 @@ private List planFiles(StreamingOffset startOffset, StreamingOffse currentOffset = new StreamingOffset(snapshotAfter.snapshotId(), 0L, false); } - Snapshot snapshot = table.snapshot(currentOffset.snapshotId()); - - if (snapshot == null) { - throw new IllegalStateException( - String.format( - "Cannot load current offset at snapshot %d, the snapshot was expired or removed", - currentOffset.snapshotId())); - } - - if (!shouldProcess(snapshot)) { + if (!shouldProcess(table.snapshot(currentOffset.snapshotId()))) { LOG.debug("Skipping snapshot: {} of table {}", currentOffset.snapshotId(), table.name()); continue; } diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreamingRead3.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreamingRead3.java index dd456f22371e..23fdfb09cb83 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreamingRead3.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreamingRead3.java @@ -19,10 +19,8 @@ package org.apache.iceberg.spark.source; import static org.apache.iceberg.expressions.Expressions.ref; -import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import java.io.File; -import java.io.IOException; import java.util.Collections; import java.util.List; import java.util.Map; @@ -325,46 +323,6 @@ public void testResumingStreamReadFromCheckpoint() throws Exception { } } - @Test - public void testFailReadingCheckpointInvalidSnapshot() throws IOException, TimeoutException { - File writerCheckpointFolder = temp.newFolder("writer-checkpoint-folder"); - File writerCheckpoint = new File(writerCheckpointFolder, "writer-checkpoint"); - File output = temp.newFolder(); - - DataStreamWriter querySource = - spark - .readStream() - .format("iceberg") - .load(tableName) - .writeStream() - .option("checkpointLocation", writerCheckpoint.toString()) - .format("parquet") - .queryName("checkpoint_test") - .option("path", output.getPath()); - - List firstSnapshotRecordList = Lists.newArrayList(new SimpleRecord(1, "one")); - List secondSnapshotRecordList = Lists.newArrayList(new SimpleRecord(2, "two")); - StreamingQuery startQuery = querySource.start(); - - appendData(firstSnapshotRecordList); - table.refresh(); - long firstSnapshotid = table.currentSnapshot().snapshotId(); - startQuery.processAllAvailable(); - startQuery.stop(); - - appendData(secondSnapshotRecordList); - - table.expireSnapshots().expireSnapshotId(firstSnapshotid).commit(); - - StreamingQuery restartedQuery = querySource.start(); - assertThatThrownBy(restartedQuery::processAllAvailable) - .hasCauseInstanceOf(IllegalStateException.class) - .hasMessageContaining( - String.format( - "Cannot load 
current offset at snapshot %d, the snapshot was expired or removed", - firstSnapshotid)); - } - @Test public void testParquetOrcAvroDataInOneTable() throws Exception { List parquetFileRecords = diff --git a/spark/v3.4/build.gradle b/spark/v3.4/build.gradle new file mode 100644 index 000000000000..5cf131098742 --- /dev/null +++ b/spark/v3.4/build.gradle @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +String sparkVersion = '3.4.0' +String sparkMajorVersion = '3.4' +String scalaVersion = System.getProperty("scalaVersion") != null ? System.getProperty("scalaVersion") : System.getProperty("defaultScalaVersion") + +def sparkProjects = [ + project(":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}"), + project(":iceberg-spark:iceberg-spark-extensions-${sparkMajorVersion}_${scalaVersion}"), + project(":iceberg-spark:iceberg-spark-runtime-${sparkMajorVersion}_${scalaVersion}"), +] + +configure(sparkProjects) { + configurations { + all { + resolutionStrategy { + force "com.fasterxml.jackson.module:jackson-module-scala_${scalaVersion}:2.13.4" + force 'com.fasterxml.jackson.core:jackson-databind:2.13.4.2' + force 'com.fasterxml.jackson.core:jackson-core:2.13.4' + } + } + } +} + +project(":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}") { + apply plugin: 'scala' + apply plugin: 'com.github.alisiikh.scalastyle' + + sourceSets { + main { + scala.srcDirs = ['src/main/scala', 'src/main/java'] + java.srcDirs = [] + } + } + + dependencies { + implementation project(path: ':iceberg-bundled-guava', configuration: 'shadow') + api project(':iceberg-api') + implementation project(':iceberg-common') + implementation project(':iceberg-core') + implementation project(':iceberg-data') + implementation project(':iceberg-orc') + implementation project(':iceberg-parquet') + implementation project(':iceberg-arrow') + implementation("org.scala-lang.modules:scala-collection-compat_${scalaVersion}") + + compileOnly "com.google.errorprone:error_prone_annotations" + compileOnly "org.apache.avro:avro" + compileOnly("org.apache.spark:spark-hive_${scalaVersion}:${sparkVersion}") { + exclude group: 'org.apache.avro', module: 'avro' + exclude group: 'org.apache.arrow' + exclude group: 'org.apache.parquet' + // to make sure netty libs only come from project(':iceberg-arrow') + exclude group: 'io.netty', module: 'netty-buffer' + exclude group: 'io.netty', module: 'netty-common' + exclude group: 'org.roaringbitmap' + } + + implementation("org.apache.parquet:parquet-column") + implementation("org.apache.parquet:parquet-hadoop") + + implementation("org.apache.orc:orc-core::nohive") { + exclude group: 'org.apache.hadoop' + exclude group: 'commons-lang' + // These artifacts are shaded and included in the orc-core fat jar + 
exclude group: 'com.google.protobuf', module: 'protobuf-java' + exclude group: 'org.apache.hive', module: 'hive-storage-api' + } + + implementation("org.apache.arrow:arrow-vector") { + exclude group: 'io.netty', module: 'netty-buffer' + exclude group: 'io.netty', module: 'netty-common' + exclude group: 'com.google.code.findbugs', module: 'jsr305' + } + + testImplementation("org.apache.hadoop:hadoop-minicluster") { + exclude group: 'org.apache.avro', module: 'avro' + // to make sure netty libs only come from project(':iceberg-arrow') + exclude group: 'io.netty', module: 'netty-buffer' + exclude group: 'io.netty', module: 'netty-common' + } + testImplementation project(path: ':iceberg-hive-metastore') + testImplementation project(path: ':iceberg-hive-metastore', configuration: 'testArtifacts') + testImplementation project(path: ':iceberg-api', configuration: 'testArtifacts') + testImplementation project(path: ':iceberg-core', configuration: 'testArtifacts') + testImplementation project(path: ':iceberg-data', configuration: 'testArtifacts') + testImplementation "org.xerial:sqlite-jdbc" + } + + test { + useJUnitPlatform() + } + + tasks.withType(Test) { + // Vectorized reads need more memory + maxHeapSize '2560m' + } +} + +project(":iceberg-spark:iceberg-spark-extensions-${sparkMajorVersion}_${scalaVersion}") { + apply plugin: 'java-library' + apply plugin: 'scala' + apply plugin: 'com.github.alisiikh.scalastyle' + apply plugin: 'antlr' + + configurations { + /* + The Gradle Antlr plugin erroneously adds both antlr-build and runtime dependencies to the runtime path. This + bug https://github.com/gradle/gradle/issues/820 exists because older versions of Antlr do not have separate + runtime and implementation dependencies and they do not want to break backwards compatibility. So to only end up with + the runtime dependency on the runtime classpath we remove the dependencies added by the plugin here. Then add + the runtime dependency back to only the runtime configuration manually. 
+ */ + implementation { + extendsFrom = extendsFrom.findAll { it != configurations.antlr } + } + } + + dependencies { + implementation("org.scala-lang.modules:scala-collection-compat_${scalaVersion}") + + compileOnly "org.scala-lang:scala-library:${scalaVersion}" + compileOnly project(path: ':iceberg-bundled-guava', configuration: 'shadow') + compileOnly project(':iceberg-api') + compileOnly project(':iceberg-core') + compileOnly project(':iceberg-common') + compileOnly project(":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}") + compileOnly("org.apache.spark:spark-hive_${scalaVersion}:${sparkVersion}") { + exclude group: 'org.apache.avro', module: 'avro' + exclude group: 'org.apache.arrow' + exclude group: 'org.apache.parquet' + // to make sure netty libs only come from project(':iceberg-arrow') + exclude group: 'io.netty', module: 'netty-buffer' + exclude group: 'io.netty', module: 'netty-common' + exclude group: 'org.roaringbitmap' + } + + testImplementation project(path: ':iceberg-data') + testImplementation project(path: ':iceberg-parquet') + testImplementation project(path: ':iceberg-hive-metastore') + testImplementation project(path: ':iceberg-api', configuration: 'testArtifacts') + testImplementation project(path: ':iceberg-hive-metastore', configuration: 'testArtifacts') + testImplementation project(path: ":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}", configuration: 'testArtifacts') + + testImplementation "org.apache.avro:avro" + testImplementation "org.apache.parquet:parquet-hadoop" + + // Required because we remove antlr plugin dependencies from the compile configuration, see note above + runtimeOnly "org.antlr:antlr4-runtime:4.8" + antlr "org.antlr:antlr4:4.8" + } + + generateGrammarSource { + maxHeapSize = "64m" + arguments += ['-visitor', '-package', 'org.apache.spark.sql.catalyst.parser.extensions'] + } +} + +project(":iceberg-spark:iceberg-spark-runtime-${sparkMajorVersion}_${scalaVersion}") { + apply plugin: 'com.github.johnrengelman.shadow' + + tasks.jar.dependsOn tasks.shadowJar + + sourceSets { + integration { + java.srcDir "$projectDir/src/integration/java" + resources.srcDir "$projectDir/src/integration/resources" + } + } + + configurations { + implementation { + exclude group: 'org.apache.spark' + // included in Spark + exclude group: 'org.slf4j' + exclude group: 'org.apache.commons' + exclude group: 'commons-pool' + exclude group: 'commons-codec' + exclude group: 'org.xerial.snappy' + exclude group: 'javax.xml.bind' + exclude group: 'javax.annotation' + exclude group: 'com.github.luben' + exclude group: 'com.ibm.icu' + exclude group: 'org.glassfish' + exclude group: 'org.abego.treelayout' + exclude group: 'org.antlr' + exclude group: 'org.scala-lang' + exclude group: 'org.scala-lang.modules' + } + } + + dependencies { + api project(':iceberg-api') + implementation project(":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}") + implementation project(":iceberg-spark:iceberg-spark-extensions-${sparkMajorVersion}_${scalaVersion}") + implementation project(':iceberg-aws') + implementation(project(':iceberg-aliyun')) { + exclude group: 'edu.umd.cs.findbugs', module: 'findbugs' + exclude group: 'org.apache.httpcomponents', module: 'httpclient' + exclude group: 'commons-logging', module: 'commons-logging' + } + implementation project(':iceberg-hive-metastore') + implementation(project(':iceberg-nessie')) { + exclude group: 'com.google.code.findbugs', module: 'jsr305' + } + implementation (project(':iceberg-snowflake')) 
{ + exclude group: 'net.snowflake' , module: 'snowflake-jdbc' + } + + integrationImplementation "org.scala-lang.modules:scala-collection-compat_${scalaVersion}" + integrationImplementation "org.apache.spark:spark-hive_${scalaVersion}:${sparkVersion}" + integrationImplementation 'org.junit.vintage:junit-vintage-engine' + integrationImplementation 'org.slf4j:slf4j-simple' + integrationImplementation project(path: ':iceberg-api', configuration: 'testArtifacts') + integrationImplementation project(path: ':iceberg-hive-metastore', configuration: 'testArtifacts') + integrationImplementation project(path: ":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}", configuration: 'testArtifacts') + integrationImplementation project(path: ":iceberg-spark:iceberg-spark-extensions-${sparkMajorVersion}_${scalaVersion}", configuration: 'testArtifacts') + // Not allowed on our classpath, only the runtime jar is allowed + integrationCompileOnly project(":iceberg-spark:iceberg-spark-extensions-${sparkMajorVersion}_${scalaVersion}") + integrationCompileOnly project(":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}") + integrationCompileOnly project(':iceberg-api') + } + + shadowJar { + configurations = [project.configurations.runtimeClasspath] + + zip64 true + + // include the LICENSE and NOTICE files for the shaded Jar + from(projectDir) { + include 'LICENSE' + include 'NOTICE' + } + + // Relocate dependencies to avoid conflicts + relocate 'com.google', 'org.apache.iceberg.shaded.com.google' + relocate 'com.fasterxml', 'org.apache.iceberg.shaded.com.fasterxml' + relocate 'com.github.benmanes', 'org.apache.iceberg.shaded.com.github.benmanes' + relocate 'org.checkerframework', 'org.apache.iceberg.shaded.org.checkerframework' + relocate 'org.apache.avro', 'org.apache.iceberg.shaded.org.apache.avro' + relocate 'avro.shaded', 'org.apache.iceberg.shaded.org.apache.avro.shaded' + relocate 'com.thoughtworks.paranamer', 'org.apache.iceberg.shaded.com.thoughtworks.paranamer' + relocate 'org.apache.parquet', 'org.apache.iceberg.shaded.org.apache.parquet' + relocate 'shaded.parquet', 'org.apache.iceberg.shaded.org.apache.parquet.shaded' + relocate 'org.apache.orc', 'org.apache.iceberg.shaded.org.apache.orc' + relocate 'io.airlift', 'org.apache.iceberg.shaded.io.airlift' + relocate 'org.apache.httpcomponents.client5', 'org.apache.iceberg.shaded.org.apache.httpcomponents.client5' + // relocate Arrow and related deps to shade Iceberg specific version + relocate 'io.netty', 'org.apache.iceberg.shaded.io.netty' + relocate 'org.apache.arrow', 'org.apache.iceberg.shaded.org.apache.arrow' + relocate 'com.carrotsearch', 'org.apache.iceberg.shaded.com.carrotsearch' + relocate 'org.threeten.extra', 'org.apache.iceberg.shaded.org.threeten.extra' + relocate 'org.roaringbitmap', 'org.apache.iceberg.shaded.org.roaringbitmap' + + archiveClassifier.set(null) + } + + task integrationTest(type: Test) { + description = "Test Spark3 Runtime Jar against Spark ${sparkMajorVersion}" + group = "verification" + testClassesDirs = sourceSets.integration.output.classesDirs + classpath = sourceSets.integration.runtimeClasspath + files(shadowJar.archiveFile.get().asFile.path) + inputs.file(shadowJar.archiveFile.get().asFile.path) + } + integrationTest.dependsOn shadowJar + check.dependsOn integrationTest + + jar { + enabled = false + } +} + diff --git a/spark/v3.4/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4 
b/spark/v3.4/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4 new file mode 100644 index 000000000000..b962699d9b47 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4 @@ -0,0 +1,374 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * This file is an adaptation of Presto's and Spark's grammar files. + */ + +grammar IcebergSqlExtensions; + +@lexer::members { + /** + * Verify whether current token is a valid decimal token (which contains dot). + * Returns true if the character that follows the token is not a digit or letter or underscore. + * + * For example: + * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. + * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. + * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed + * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' + * which is not a digit or letter or underscore. + */ + public boolean isValidDecimal() { + int nextChar = _input.LA(1); + if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' || + nextChar == '_') { + return false; + } else { + return true; + } + } + + /** + * This method will be called when we see '/*' and try to match it as a bracketed comment. + * If the next character is '+', it should be parsed as hint later, and we cannot match + * it as a bracketed comment. + * + * Returns true if the next character is '+'. + */ + public boolean isHint() { + int nextChar = _input.LA(1); + if (nextChar == '+') { + return true; + } else { + return false; + } + } +} + +singleStatement + : statement EOF + ; + +statement + : CALL multipartIdentifier '(' (callArgument (',' callArgument)*)? ')' #call + | ALTER TABLE multipartIdentifier ADD PARTITION FIELD transform (AS name=identifier)? #addPartitionField + | ALTER TABLE multipartIdentifier DROP PARTITION FIELD transform #dropPartitionField + | ALTER TABLE multipartIdentifier REPLACE PARTITION FIELD transform WITH transform (AS name=identifier)? 
#replacePartitionField + | ALTER TABLE multipartIdentifier WRITE writeSpec #setWriteDistributionAndOrdering + | ALTER TABLE multipartIdentifier SET IDENTIFIER_KW FIELDS fieldList #setIdentifierFields + | ALTER TABLE multipartIdentifier DROP IDENTIFIER_KW FIELDS fieldList #dropIdentifierFields + | ALTER TABLE multipartIdentifier createReplaceBranchClause #createOrReplaceBranch + | ALTER TABLE multipartIdentifier createReplaceTagClause #createOrReplaceTag + | ALTER TABLE multipartIdentifier DROP BRANCH (IF EXISTS)? identifier #dropBranch + | ALTER TABLE multipartIdentifier DROP TAG (IF EXISTS)? identifier #dropTag + ; + +createReplaceTagClause + : (CREATE OR)? REPLACE TAG identifier tagOptions + | CREATE TAG (IF NOT EXISTS)? identifier tagOptions + ; + +createReplaceBranchClause + : (CREATE OR)? REPLACE BRANCH identifier branchOptions + | CREATE BRANCH (IF NOT EXISTS)? identifier branchOptions + ; + +tagOptions + : (AS OF VERSION snapshotId)? (refRetain)? + ; + +branchOptions + : (AS OF VERSION snapshotId)? (refRetain)? (snapshotRetention)? + ; + +snapshotRetention + : WITH SNAPSHOT RETENTION minSnapshotsToKeep + | WITH SNAPSHOT RETENTION maxSnapshotAge + | WITH SNAPSHOT RETENTION minSnapshotsToKeep maxSnapshotAge + ; + +refRetain + : RETAIN number timeUnit + ; + +maxSnapshotAge + : number timeUnit + ; + +minSnapshotsToKeep + : number SNAPSHOTS + ; + +writeSpec + : (writeDistributionSpec | writeOrderingSpec)* + ; + +writeDistributionSpec + : DISTRIBUTED BY PARTITION + ; + +writeOrderingSpec + : LOCALLY? ORDERED BY order + | UNORDERED + ; + +callArgument + : expression #positionalArgument + | identifier '=>' expression #namedArgument + ; + +singleOrder + : order EOF + ; + +order + : fields+=orderField (',' fields+=orderField)* + | '(' fields+=orderField (',' fields+=orderField)* ')' + ; + +orderField + : transform direction=(ASC | DESC)? (NULLS nullOrder=(FIRST | LAST))? + ; + +transform + : multipartIdentifier #identityTransform + | transformName=identifier + '(' arguments+=transformArgument (',' arguments+=transformArgument)* ')' #applyTransform + ; + +transformArgument + : multipartIdentifier + | constant + ; + +expression + : constant + | stringMap + | stringArray + ; + +constant + : number #numericLiteral + | booleanValue #booleanLiteral + | STRING+ #stringLiteral + | identifier STRING #typeConstructor + ; + +stringMap + : MAP '(' constant (',' constant)* ')' + ; + +stringArray + : ARRAY '(' constant (',' constant)* ')' + ; + +booleanValue + : TRUE | FALSE + ; + +number + : MINUS? EXPONENT_VALUE #exponentLiteral + | MINUS? DECIMAL_VALUE #decimalLiteral + | MINUS? INTEGER_VALUE #integerLiteral + | MINUS? BIGINT_LITERAL #bigIntLiteral + | MINUS? SMALLINT_LITERAL #smallIntLiteral + | MINUS? TINYINT_LITERAL #tinyIntLiteral + | MINUS? DOUBLE_LITERAL #doubleLiteral + | MINUS? FLOAT_LITERAL #floatLiteral + | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral + ; + +multipartIdentifier + : parts+=identifier ('.' 
parts+=identifier)* + ; + +identifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | nonReserved #unquotedIdentifier + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +fieldList + : fields+=multipartIdentifier (',' fields+=multipartIdentifier)* + ; + +nonReserved + : ADD | ALTER | AS | ASC | BRANCH | BY | CALL | CREATE | DAYS | DESC | DROP | EXISTS | FIELD | FIRST | HOURS | IF | LAST | NOT | NULLS | OF | OR | ORDERED | PARTITION | TABLE | WRITE + | DISTRIBUTED | LOCALLY | MINUTES | MONTHS | UNORDERED | REPLACE | RETAIN | VERSION | WITH | IDENTIFIER_KW | FIELDS | SET | SNAPSHOT | SNAPSHOTS + | TAG | TRUE | FALSE + | MAP + ; + +snapshotId + : number + ; + +numSnapshots + : number + ; + +timeUnit + : DAYS + | HOURS + | MINUTES + ; + +ADD: 'ADD'; +ALTER: 'ALTER'; +AS: 'AS'; +ASC: 'ASC'; +BRANCH: 'BRANCH'; +BY: 'BY'; +CALL: 'CALL'; +DAYS: 'DAYS'; +DESC: 'DESC'; +DISTRIBUTED: 'DISTRIBUTED'; +DROP: 'DROP'; +EXISTS: 'EXISTS'; +FIELD: 'FIELD'; +FIELDS: 'FIELDS'; +FIRST: 'FIRST'; +HOURS: 'HOURS'; +IF : 'IF'; +LAST: 'LAST'; +LOCALLY: 'LOCALLY'; +MINUTES: 'MINUTES'; +MONTHS: 'MONTHS'; +CREATE: 'CREATE'; +NOT: 'NOT'; +NULLS: 'NULLS'; +OF: 'OF'; +OR: 'OR'; +ORDERED: 'ORDERED'; +PARTITION: 'PARTITION'; +REPLACE: 'REPLACE'; +RETAIN: 'RETAIN'; +RETENTION: 'RETENTION'; +IDENTIFIER_KW: 'IDENTIFIER'; +SET: 'SET'; +SNAPSHOT: 'SNAPSHOT'; +SNAPSHOTS: 'SNAPSHOTS'; +TABLE: 'TABLE'; +TAG: 'TAG'; +UNORDERED: 'UNORDERED'; +VERSION: 'VERSION'; +WITH: 'WITH'; +WRITE: 'WRITE'; + +TRUE: 'TRUE'; +FALSE: 'FALSE'; + +MAP: 'MAP'; +ARRAY: 'ARRAY'; + +PLUS: '+'; +MINUS: '-'; + +STRING + : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | '"' ( ~('"'|'\\') | ('\\' .) )* '"' + ; + +BIGINT_LITERAL + : DIGIT+ 'L' + ; + +SMALLINT_LITERAL + : DIGIT+ 'S' + ; + +TINYINT_LITERAL + : DIGIT+ 'Y' + ; + +INTEGER_VALUE + : DIGIT+ + ; + +EXPONENT_VALUE + : DIGIT+ EXPONENT + | DECIMAL_DIGITS EXPONENT {isValidDecimal()}? + ; + +DECIMAL_VALUE + : DECIMAL_DIGITS {isValidDecimal()}? + ; + +FLOAT_LITERAL + : DIGIT+ EXPONENT? 'F' + | DECIMAL_DIGITS EXPONENT? 'F' {isValidDecimal()}? + ; + +DOUBLE_LITERAL + : DIGIT+ EXPONENT? 'D' + | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}? + ; + +BIGDECIMAL_LITERAL + : DIGIT+ EXPONENT? 'BD' + | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}? + ; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment DECIMAL_DIGITS + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + +fragment EXPONENT + : 'E' [+-]? DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +SIMPLE_COMMENT + : '--' ('\\\n' | ~[\r\n])* '\r'? '\n'? -> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' {!isHint()}? (BRACKETED_COMMENT|.)*? '*/' -> channel(HIDDEN) + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. +// We use this to be able to ignore and recover all the text +// when splitting statements with DelimiterLexer +UNRECOGNIZED + : . 
+ ; diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala new file mode 100644 index 000000000000..4fb9a48a3e00 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.extensions + +import org.apache.spark.sql.SparkSessionExtensions +import org.apache.spark.sql.catalyst.analysis.AlignedRowLevelIcebergCommandCheck +import org.apache.spark.sql.catalyst.analysis.AlignRowLevelCommandAssignments +import org.apache.spark.sql.catalyst.analysis.CheckMergeIntoTableConditions +import org.apache.spark.sql.catalyst.analysis.MergeIntoIcebergTableResolutionCheck +import org.apache.spark.sql.catalyst.analysis.ProcedureArgumentCoercion +import org.apache.spark.sql.catalyst.analysis.ResolveMergeIntoTableReferences +import org.apache.spark.sql.catalyst.analysis.ResolveProcedures +import org.apache.spark.sql.catalyst.analysis.RewriteDeleteFromIcebergTable +import org.apache.spark.sql.catalyst.analysis.RewriteMergeIntoTable +import org.apache.spark.sql.catalyst.analysis.RewriteUpdateTable +import org.apache.spark.sql.catalyst.optimizer.ExtendedReplaceNullWithFalseInPredicate +import org.apache.spark.sql.catalyst.optimizer.ExtendedSimplifyConditionalsInPredicate +import org.apache.spark.sql.catalyst.parser.extensions.IcebergSparkSqlExtensionsParser +import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Strategy +import org.apache.spark.sql.execution.datasources.v2.ExtendedV2Writes +import org.apache.spark.sql.execution.datasources.v2.OptimizeMetadataOnlyDeleteFromIcebergTable +import org.apache.spark.sql.execution.datasources.v2.ReplaceRewrittenRowLevelCommand +import org.apache.spark.sql.execution.datasources.v2.RowLevelCommandScanRelationPushDown +import org.apache.spark.sql.execution.dynamicpruning.RowLevelCommandDynamicPruning + +class IcebergSparkSessionExtensions extends (SparkSessionExtensions => Unit) { + + override def apply(extensions: SparkSessionExtensions): Unit = { + // parser extensions + extensions.injectParser { case (_, parser) => new IcebergSparkSqlExtensionsParser(parser) } + + // analyzer extensions + extensions.injectResolutionRule { spark => ResolveProcedures(spark) } + extensions.injectResolutionRule { spark => ResolveMergeIntoTableReferences(spark) } + extensions.injectResolutionRule { _ => CheckMergeIntoTableConditions } + extensions.injectResolutionRule { _ => ProcedureArgumentCoercion } + extensions.injectResolutionRule { _ => 
AlignRowLevelCommandAssignments } + extensions.injectResolutionRule { _ => RewriteDeleteFromIcebergTable } + extensions.injectResolutionRule { _ => RewriteUpdateTable } + extensions.injectResolutionRule { _ => RewriteMergeIntoTable } + extensions.injectCheckRule { _ => MergeIntoIcebergTableResolutionCheck } + extensions.injectCheckRule { _ => AlignedRowLevelIcebergCommandCheck } + + // optimizer extensions + extensions.injectOptimizerRule { _ => ExtendedSimplifyConditionalsInPredicate } + extensions.injectOptimizerRule { _ => ExtendedReplaceNullWithFalseInPredicate } + // pre-CBO rules run only once and the order of the rules is important + // - metadata deletes have to be attempted immediately after the operator optimization + // - dynamic filters should be added before replacing commands with rewrite plans + // - scans must be planned before building writes + extensions.injectPreCBORule { _ => OptimizeMetadataOnlyDeleteFromIcebergTable } + extensions.injectPreCBORule { _ => RowLevelCommandScanRelationPushDown } + extensions.injectPreCBORule { _ => ExtendedV2Writes } + extensions.injectPreCBORule { spark => RowLevelCommandDynamicPruning(spark) } + extensions.injectPreCBORule { _ => ReplaceRewrittenRowLevelCommand } + + // planner extensions + extensions.injectPlannerStrategy { spark => ExtendedDataSourceV2Strategy(spark) } + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlignRowLevelCommandAssignments.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlignRowLevelCommandAssignments.scala new file mode 100644 index 000000000000..23ba50bdfd06 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlignRowLevelCommandAssignments.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.AssignmentUtils +import org.apache.spark.sql.catalyst.plans.logical.Assignment +import org.apache.spark.sql.catalyst.plans.logical.DeleteAction +import org.apache.spark.sql.catalyst.plans.logical.InsertAction +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.UpdateAction +import org.apache.spark.sql.catalyst.plans.logical.UpdateIcebergTable +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * A rule that aligns assignments in UPDATE and MERGE operations. + * + * Note that this rule must be run before rewriting row-level commands. 
+ */ +object AlignRowLevelCommandAssignments + extends Rule[LogicalPlan] with AssignmentAlignmentSupport { + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case u: UpdateIcebergTable if u.resolved && !u.aligned => + u.copy(assignments = alignAssignments(u.table, u.assignments)) + + case m: MergeIntoIcebergTable if m.resolved && !m.aligned => + val alignedMatchedActions = m.matchedActions.map { + case u @ UpdateAction(_, assignments) => + u.copy(assignments = alignAssignments(m.targetTable, assignments)) + case d: DeleteAction => + d + case _ => + throw new AnalysisException("Matched actions can only contain UPDATE or DELETE") + } + + val alignedNotMatchedActions = m.notMatchedActions.map { + case i @ InsertAction(_, assignments) => + // check no nested columns are present + val refs = assignments.map(_.key).map(AssignmentUtils.toAssignmentRef) + refs.foreach { ref => + if (ref.size > 1) { + throw new AnalysisException( + "Nested fields are not supported inside INSERT clauses of MERGE operations: " + + s"${ref.mkString("`", "`.`", "`")}") + } + } + + val colNames = refs.map(_.head) + + // check there are no duplicates + val duplicateColNames = colNames.groupBy(identity).collect { + case (name, matchingNames) if matchingNames.size > 1 => name + } + + if (duplicateColNames.nonEmpty) { + throw new AnalysisException( + s"Duplicate column names inside INSERT clause: ${duplicateColNames.mkString(", ")}") + } + + // reorder assignments by the target table column order + val assignmentMap = colNames.zip(assignments).toMap + i.copy(assignments = alignInsertActionAssignments(m.targetTable, assignmentMap)) + + case _ => + throw new AnalysisException("Not matched actions can only contain INSERT") + } + + m.copy(matchedActions = alignedMatchedActions, notMatchedActions = alignedNotMatchedActions) + } + + private def alignInsertActionAssignments( + targetTable: LogicalPlan, + assignmentMap: Map[String, Assignment]): Seq[Assignment] = { + + val resolver = conf.resolver + + targetTable.output.map { targetAttr => + val assignment = assignmentMap + .find { case (name, _) => resolver(name, targetAttr.name) } + .map { case (_, assignment) => assignment } + + if (assignment.isEmpty) { + throw new AnalysisException( + s"Cannot find column '${targetAttr.name}' of the target table among " + + s"the INSERT columns: ${assignmentMap.keys.mkString(", ")}. " + + "INSERT clauses must provide values for all columns of the target table.") + } + + val key = assignment.get.key + val value = castIfNeeded(targetAttr, assignment.get.value, resolver, Seq(targetAttr.name)) + AssignmentUtils.handleCharVarcharLimits(Assignment(key, value)) + } + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlignedRowLevelIcebergCommandCheck.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlignedRowLevelIcebergCommandCheck.scala new file mode 100644 index 000000000000..d915e4f10949 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlignedRowLevelIcebergCommandCheck.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.UpdateIcebergTable + +object AlignedRowLevelIcebergCommandCheck extends (LogicalPlan => Unit) { + + override def apply(plan: LogicalPlan): Unit = { + plan foreach { + case m: MergeIntoIcebergTable if !m.aligned => + throw new AnalysisException(s"Could not align Iceberg MERGE INTO: $m") + case u: UpdateIcebergTable if !u.aligned => + throw new AnalysisException(s"Could not align Iceberg UPDATE: $u") + case _ => // OK + } + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentAlignmentSupport.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentAlignmentSupport.scala new file mode 100644 index 000000000000..76aec46a23b5 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentAlignmentSupport.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.AssignmentUtils._ +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.GetStructField +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.plans.logical.Assignment +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ +import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import scala.collection.compat.immutable.ArraySeq +import scala.collection.mutable + +trait AssignmentAlignmentSupport extends CastSupport { + + self: SQLConfHelper => + + private case class ColumnUpdate(ref: Seq[String], expr: Expression) + + /** + * Aligns assignments to match table columns. + *
<p>
+ * This method processes and reorders given assignments so that each target column gets + * an expression it should be set to. If a column does not have a matching assignment, + * it will be set to its current value. For example, if one passes a table with columns c1, c2 + * and an assignment c2 = 1, this method will return c1 = c1, c2 = 1. + *
<p>
+ * This method also handles updates to nested columns. If there is an assignment to a particular + * nested field, this method will construct a new struct with one field updated + * preserving other fields that have not been modified. For example, if one passes a table with + * columns c1, c2 where c2 is a struct with fields n1 and n2 and an assignment c2.n2 = 1, + * this method will return c1 = c1, c2 = struct(c2.n1, 1). + * + * @param table a target table + * @param assignments assignments to align + * @return aligned assignments that match table columns + */ + protected def alignAssignments( + table: LogicalPlan, + assignments: Seq[Assignment]): Seq[Assignment] = { + + val columnUpdates = assignments.map(a => ColumnUpdate(toAssignmentRef(a.key), a.value)) + val outputExprs = applyUpdates(table.output, columnUpdates) + outputExprs.zip(table.output).map { + case (expr, attr) => handleCharVarcharLimits(Assignment(attr, expr)) + } + } + + private def applyUpdates( + cols: Seq[NamedExpression], + updates: Seq[ColumnUpdate], + resolver: Resolver = conf.resolver, + namePrefix: Seq[String] = Nil): Seq[Expression] = { + + // iterate through columns at the current level and find which column updates match + cols.map { col => + // find matches for this column or any of its children + val prefixMatchedUpdates = updates.filter(a => resolver(a.ref.head, col.name)) + prefixMatchedUpdates match { + // if there is no exact match and no match for children, return the column as is + case updates if updates.isEmpty => + col + + // if there is an exact match, return the assigned expression + case Seq(update) if isExactMatch(update, col, resolver) => + castIfNeeded(col, update.expr, resolver, namePrefix :+ col.name) + + // if there are matches only for children + case updates if !hasExactMatch(updates, col, resolver) => + col.dataType match { + case StructType(fields) => + // build field expressions + val fieldExprs = fields.zipWithIndex.map { case (field, ordinal) => + Alias(GetStructField(col, ordinal, Some(field.name)), field.name)() + } + + // recursively apply this method on nested fields + val newUpdates = updates.map(u => u.copy(ref = u.ref.tail)) + val updatedFieldExprs = applyUpdates( + ArraySeq.unsafeWrapArray(fieldExprs), + newUpdates, + resolver, + namePrefix :+ col.name) + + // construct a new struct with updated field expressions + toNamedStruct(ArraySeq.unsafeWrapArray(fields), updatedFieldExprs) + + case otherType => + val colName = (namePrefix :+ col.name).mkString(".") + throw new AnalysisException( + "Updating nested fields is only supported for StructType " + + s"but $colName is of type $otherType" + ) + } + + // if there are conflicting updates, throw an exception + // there are two illegal scenarios: + // - multiple updates to the same column + // - updates to a top-level struct and its nested fields (e.g., a.b and a.b.c) + case updates if hasExactMatch(updates, col, resolver) => + val conflictingCols = updates.map(u => (namePrefix ++ u.ref).mkString(".")) + throw new AnalysisException( + "Updates are in conflict for these columns: " + + conflictingCols.distinct.mkString(", ")) + } + } + } + + private def toNamedStruct(fields: Seq[StructField], fieldExprs: Seq[Expression]): Expression = { + val namedStructExprs = fields.zip(fieldExprs).flatMap { case (field, expr) => + Seq(Literal(field.name), expr) + } + CreateNamedStruct(namedStructExprs) + } + + private def hasExactMatch( + updates: Seq[ColumnUpdate], + col: NamedExpression, + resolver: Resolver): Boolean = { + + 
updates.exists(assignment => isExactMatch(assignment, col, resolver)) + } + + private def isExactMatch( + update: ColumnUpdate, + col: NamedExpression, + resolver: Resolver): Boolean = { + + update.ref match { + case Seq(namePart) if resolver(namePart, col.name) => true + case _ => false + } + } + + protected def castIfNeeded( + tableAttr: NamedExpression, + expr: Expression, + resolver: Resolver, + colPath: Seq[String]): Expression = { + + val storeAssignmentPolicy = conf.storeAssignmentPolicy + + // run the type check and catch type errors + storeAssignmentPolicy match { + case StoreAssignmentPolicy.STRICT | StoreAssignmentPolicy.ANSI => + if (expr.nullable && !tableAttr.nullable) { + throw new AnalysisException( + s"Cannot write nullable values to non-null column '${tableAttr.name}'") + } + + // use byName = true to catch cases when struct field names don't match + // e.g. a struct with fields (a, b) is assigned as a struct with fields (a, c) or (b, a) + val errors = new mutable.ArrayBuffer[String]() + val canWrite = DataType.canWrite( + expr.dataType, tableAttr.dataType, byName = true, resolver, tableAttr.name, + storeAssignmentPolicy, err => errors += err) + + if (!canWrite) { + throw new AnalysisException( + s"Cannot write incompatible data:\n- ${errors.mkString("\n- ")}") + } + + case _ => // OK + } + + storeAssignmentPolicy match { + case _ if tableAttr.dataType.sameType(expr.dataType) => + expr + case StoreAssignmentPolicy.ANSI => + val cast = Cast(expr, tableAttr.dataType, Option(conf.sessionLocalTimeZone), ansiEnabled = true) + cast.setTagValue(Cast.BY_TABLE_INSERTION, ()) + TableOutputResolver.checkCastOverflowInTableInsert(cast, colPath.quoted) + case _ => + Cast(expr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)) + } + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckMergeIntoTableConditions.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckMergeIntoTableConditions.scala new file mode 100644 index 000000000000..70f6694af60b --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckMergeIntoTableConditions.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.logical.DeleteAction +import org.apache.spark.sql.catalyst.plans.logical.InsertAction +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.UpdateAction +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * A rule that checks MERGE operations contain only supported conditions. + * + * Note that this rule must be run in the resolution batch before Spark executes CheckAnalysis. + * Otherwise, CheckAnalysis will throw a less descriptive error. + */ +object CheckMergeIntoTableConditions extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case m: MergeIntoIcebergTable if m.resolved => + checkMergeIntoCondition("SEARCH", m.mergeCondition) + + val actions = m.matchedActions ++ m.notMatchedActions + actions.foreach { + case DeleteAction(Some(cond)) => checkMergeIntoCondition("DELETE", cond) + case UpdateAction(Some(cond), _) => checkMergeIntoCondition("UPDATE", cond) + case InsertAction(Some(cond), _) => checkMergeIntoCondition("INSERT", cond) + case _ => // OK + } + + m + } + + private def checkMergeIntoCondition(condName: String, cond: Expression): Unit = { + if (!cond.deterministic) { + throw new AnalysisException( + s"Non-deterministic functions are not supported in $condName conditions of " + + s"MERGE operations: ${cond.sql}") + } + + if (SubqueryExpression.hasSubquery(cond)) { + throw new AnalysisException( + s"Subqueries are not supported in conditions of MERGE operations. " + + s"Found a subquery in the $condName condition: ${cond.sql}") + } + + if (cond.find(_.isInstanceOf[AggregateExpression]).isDefined) { + throw new AnalysisException( + s"Agg functions are not supported in $condName conditions of MERGE operations: " + {cond.sql}) + } + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/MergeIntoIcebergTableResolutionCheck.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/MergeIntoIcebergTableResolutionCheck.scala new file mode 100644 index 000000000000..b3a9bda280d2 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/MergeIntoIcebergTableResolutionCheck.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.UnresolvedMergeIntoIcebergTable + +object MergeIntoIcebergTableResolutionCheck extends (LogicalPlan => Unit) { + + override def apply(plan: LogicalPlan): Unit = { + plan foreach { + case m: UnresolvedMergeIntoIcebergTable => + throw new AnalysisException(s"Could not resolve Iceberg MERGE INTO statement: $m") + case _ => // OK + } + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ProcedureArgumentCoercion.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ProcedureArgumentCoercion.scala new file mode 100644 index 000000000000..7f0ca8fadded --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ProcedureArgumentCoercion.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.plans.logical.Call +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +object ProcedureArgumentCoercion extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case c @ Call(procedure, args) if c.resolved => + val params = procedure.parameters + + val newArgs = args.zipWithIndex.map { case (arg, index) => + val param = params(index) + val paramType = param.dataType + val argType = arg.dataType + + if (paramType != argType && !Cast.canUpCast(argType, paramType)) { + throw new AnalysisException( + s"Wrong arg type for ${param.name}: cannot cast $argType to $paramType") + } + + if (paramType != argType) { + Cast(arg, paramType) + } else { + arg + } + } + + if (newArgs != args) { + c.copy(args = newArgs) + } else { + c + } + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoTableReferences.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoTableReferences.scala new file mode 100644 index 000000000000..bb270391a170 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoTableReferences.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.logical.Assignment +import org.apache.spark.sql.catalyst.plans.logical.DeleteAction +import org.apache.spark.sql.catalyst.plans.logical.InsertAction +import org.apache.spark.sql.catalyst.plans.logical.InsertStarAction +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.UnresolvedMergeIntoIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.UpdateAction +import org.apache.spark.sql.catalyst.plans.logical.UpdateStarAction +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * A resolution rule similar to ResolveReferences in Spark but handles Iceberg MERGE operations. + */ +case class ResolveMergeIntoTableReferences(spark: SparkSession) extends Rule[LogicalPlan] { + + private lazy val analyzer: Analyzer = spark.sessionState.analyzer + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { + case m @ UnresolvedMergeIntoIcebergTable(targetTable, sourceTable, context) + if targetTable.resolved && sourceTable.resolved && m.duplicateResolved => + + val resolvedMatchedActions = context.matchedActions.map { + case DeleteAction(cond) => + val resolvedCond = cond.map(resolveCond("DELETE", _, m)) + DeleteAction(resolvedCond) + + case UpdateAction(cond, assignments) => + val resolvedCond = cond.map(resolveCond("UPDATE", _, m)) + // the update action can access columns from both target and source tables + val resolvedAssignments = resolveAssignments(assignments, m, resolveValuesWithSourceOnly = false) + UpdateAction(resolvedCond, resolvedAssignments) + + case UpdateStarAction(updateCondition) => + val resolvedUpdateCondition = updateCondition.map(resolveCond("UPDATE", _, m)) + val assignments = targetTable.output.map { attr => + Assignment(attr, UnresolvedAttribute(Seq(attr.name))) + } + // for UPDATE *, the value must be from the source table + val resolvedAssignments = resolveAssignments(assignments, m, resolveValuesWithSourceOnly = true) + UpdateAction(resolvedUpdateCondition, resolvedAssignments) + + case _ => + throw new AnalysisException("Matched actions can only contain UPDATE or DELETE") + } + + val resolvedNotMatchedActions = context.notMatchedActions.map { + case InsertAction(cond, assignments) => + // the insert action is used when not matched, so its condition and value can only + // access columns from the source table + val resolvedCond = cond.map(resolveCond("INSERT", _, Project(Nil, m.sourceTable))) + val 
resolvedAssignments = resolveAssignments(assignments, m, resolveValuesWithSourceOnly = true) + InsertAction(resolvedCond, resolvedAssignments) + + case InsertStarAction(cond) => + // the insert action is used when not matched, so its condition and value can only + // access columns from the source table + val resolvedCond = cond.map(resolveCond("INSERT", _, Project(Nil, m.sourceTable))) + val assignments = targetTable.output.map { attr => + Assignment(attr, UnresolvedAttribute(Seq(attr.name))) + } + val resolvedAssignments = resolveAssignments(assignments, m, resolveValuesWithSourceOnly = true) + InsertAction(resolvedCond, resolvedAssignments) + + case _ => + throw new AnalysisException("Not matched actions can only contain INSERT") + } + + val resolvedMergeCondition = resolveCond("SEARCH", context.mergeCondition, m) + + MergeIntoIcebergTable( + targetTable, + sourceTable, + mergeCondition = resolvedMergeCondition, + matchedActions = resolvedMatchedActions, + notMatchedActions = resolvedNotMatchedActions) + } + + private def resolveCond(condName: String, cond: Expression, plan: LogicalPlan): Expression = { + val resolvedCond = analyzer.resolveExpressionByPlanChildren(cond, plan) + + val unresolvedAttrs = resolvedCond.references.filter(!_.resolved) + if (unresolvedAttrs.nonEmpty) { + throw new AnalysisException( + s"Cannot resolve ${unresolvedAttrs.map(_.sql).mkString("[", ",", "]")} in $condName condition " + + s"of MERGE operation given input columns: ${plan.inputSet.toSeq.map(_.sql).mkString("[", ",", "]")}") + } + + resolvedCond + } + + // copied from ResolveReferences in Spark + private def resolveAssignments( + assignments: Seq[Assignment], + mergeInto: UnresolvedMergeIntoIcebergTable, + resolveValuesWithSourceOnly: Boolean): Seq[Assignment] = { + assignments.map { assign => + val resolvedKey = assign.key match { + case c if !c.resolved => + resolveMergeExprOrFail(c, Project(Nil, mergeInto.targetTable)) + case o => o + } + val resolvedValue = assign.value match { + // The update values may contain target and/or source references. + case c if !c.resolved => + if (resolveValuesWithSourceOnly) { + resolveMergeExprOrFail(c, Project(Nil, mergeInto.sourceTable)) + } else { + resolveMergeExprOrFail(c, mergeInto) + } + case o => o + } + Assignment(resolvedKey, resolvedValue) + } + } + + // copied from ResolveReferences in Spark + private def resolveMergeExprOrFail(e: Expression, p: LogicalPlan): Expression = { + val resolved = analyzer.resolveExpressionByPlanChildren(e, p) + resolved.references.filter(!_.resolved).foreach { a => + // Note: This will throw error only on unresolved attribute issues, + // not other resolution errors like mismatched data types. + val cols = p.inputSet.toSeq.map(_.sql).mkString(", ") + throw new AnalysisException(s"cannot resolve ${a.sql} in MERGE command given columns [$cols]") + } + resolved + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveProcedures.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveProcedures.scala new file mode 100644 index 000000000000..ee69b5e344f0 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveProcedures.scala @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import java.util.Locale +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans.logical.Call +import org.apache.spark.sql.catalyst.plans.logical.CallArgument +import org.apache.spark.sql.catalyst.plans.logical.CallStatement +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.NamedArgument +import org.apache.spark.sql.catalyst.plans.logical.PositionalArgument +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.connector.catalog.CatalogPlugin +import org.apache.spark.sql.connector.catalog.LookupCatalog +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureCatalog +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter +import scala.collection.Seq + +case class ResolveProcedures(spark: SparkSession) extends Rule[LogicalPlan] with LookupCatalog { + + protected lazy val catalogManager: CatalogManager = spark.sessionState.catalogManager + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case CallStatement(CatalogAndIdentifier(catalog, ident), args) => + val procedure = catalog.asProcedureCatalog.loadProcedure(ident) + + val params = procedure.parameters + val normalizedParams = normalizeParams(params) + validateParams(normalizedParams) + + val normalizedArgs = normalizeArgs(args) + Call(procedure, args = buildArgExprs(normalizedParams, normalizedArgs).toSeq) + } + + private def validateParams(params: Seq[ProcedureParameter]): Unit = { + // should not be any duplicate param names + val duplicateParamNames = params.groupBy(_.name).collect { + case (name, matchingParams) if matchingParams.length > 1 => name + } + + if (duplicateParamNames.nonEmpty) { + throw new AnalysisException(s"Duplicate parameter names: ${duplicateParamNames.mkString("[", ",", "]")}") + } + + // optional params should be at the end + params.sliding(2).foreach { + case Seq(previousParam, currentParam) if !previousParam.required && currentParam.required => + throw new AnalysisException( + s"Optional parameters must be after required ones but $currentParam is after $previousParam") + case _ => + } + } + + private def buildArgExprs( + params: Seq[ProcedureParameter], + args: Seq[CallArgument]): Seq[Expression] = { + + // build a map of declared parameter names to their positions + val nameToPositionMap = params.map(_.name).zipWithIndex.toMap + + // build a map of parameter names to args + val nameToArgMap = buildNameToArgMap(params, args, nameToPositionMap) + + // verify all required parameters are provided + val missingParamNames = params.filter(_.required).collect { + case param if 
!nameToArgMap.contains(param.name) => param.name + } + + if (missingParamNames.nonEmpty) { + throw new AnalysisException(s"Missing required parameters: ${missingParamNames.mkString("[", ",", "]")}") + } + + val argExprs = new Array[Expression](params.size) + + nameToArgMap.foreach { case (name, arg) => + val position = nameToPositionMap(name) + argExprs(position) = arg.expr + } + + // assign nulls to optional params that were not set + params.foreach { + case p if !p.required && !nameToArgMap.contains(p.name) => + val position = nameToPositionMap(p.name) + argExprs(position) = Literal.create(null, p.dataType) + case _ => + } + + argExprs + } + + private def buildNameToArgMap( + params: Seq[ProcedureParameter], + args: Seq[CallArgument], + nameToPositionMap: Map[String, Int]): Map[String, CallArgument] = { + + val containsNamedArg = args.exists(_.isInstanceOf[NamedArgument]) + val containsPositionalArg = args.exists(_.isInstanceOf[PositionalArgument]) + + if (containsNamedArg && containsPositionalArg) { + throw new AnalysisException("Named and positional arguments cannot be mixed") + } + + if (containsNamedArg) { + buildNameToArgMapUsingNames(args, nameToPositionMap) + } else { + buildNameToArgMapUsingPositions(args, params) + } + } + + private def buildNameToArgMapUsingNames( + args: Seq[CallArgument], + nameToPositionMap: Map[String, Int]): Map[String, CallArgument] = { + + val namedArgs = args.asInstanceOf[Seq[NamedArgument]] + + val validationErrors = namedArgs.groupBy(_.name).collect { + case (name, matchingArgs) if matchingArgs.size > 1 => s"Duplicate procedure argument: $name" + case (name, _) if !nameToPositionMap.contains(name) => s"Unknown argument: $name" + } + + if (validationErrors.nonEmpty) { + throw new AnalysisException(s"Could not build name to arg map: ${validationErrors.mkString(", ")}") + } + + namedArgs.map(arg => arg.name -> arg).toMap + } + + private def buildNameToArgMapUsingPositions( + args: Seq[CallArgument], + params: Seq[ProcedureParameter]): Map[String, CallArgument] = { + + if (args.size > params.size) { + throw new AnalysisException("Too many arguments for procedure") + } + + args.zipWithIndex.map { case (arg, position) => + val param = params(position) + param.name -> arg + }.toMap + } + + private def normalizeParams(params: Seq[ProcedureParameter]): Seq[ProcedureParameter] = { + params.map { + case param if param.required => + val normalizedName = param.name.toLowerCase(Locale.ROOT) + ProcedureParameter.required(normalizedName, param.dataType) + case param => + val normalizedName = param.name.toLowerCase(Locale.ROOT) + ProcedureParameter.optional(normalizedName, param.dataType) + } + } + + private def normalizeArgs(args: Seq[CallArgument]): Seq[CallArgument] = { + args.map { + case a @ NamedArgument(name, _) => a.copy(name = name.toLowerCase(Locale.ROOT)) + case other => other + } + } + + implicit class CatalogHelper(plugin: CatalogPlugin) { + def asProcedureCatalog: ProcedureCatalog = plugin match { + case procedureCatalog: ProcedureCatalog => + procedureCatalog + case _ => + throw new AnalysisException(s"Cannot use catalog ${plugin.name}: not a ProcedureCatalog") + } + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromIcebergTable.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromIcebergTable.scala new file mode 100644 index 000000000000..d97a921250b1 --- /dev/null +++ 
b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromIcebergTable.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.EqualNullSafe +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.Not +import org.apache.spark.sql.catalyst.plans.logical.DeleteFromIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.ReplaceIcebergData +import org.apache.spark.sql.catalyst.plans.logical.WriteIcebergDelta +import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ +import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations +import org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE +import org.apache.spark.sql.connector.write.RowLevelOperationTable +import org.apache.spark.sql.connector.write.SupportsDelta +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Assigns a rewrite plan for v2 tables that support rewriting data to handle DELETE statements. + * + * If a table implements SupportsDelete and SupportsRowLevelOperations, this rule assigns a rewrite + * plan but the optimizer will check whether this particular DELETE statement can be handled + * by simply passing delete filters to the connector. If yes, the optimizer will then discard + * the rewrite plan. 
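+ *
+ * For illustration only (table and column names here are hypothetical), a statement such as
+ *
+ * {{{
+ *   spark.sql("DELETE FROM db.events WHERE id < 100")
+ * }}}
+ *
+ * gets a rewrite plan assigned by this rule. If the connector can later handle the
+ * condition purely through delete filters, the optimizer discards the rewrite plan;
+ * otherwise the assigned plan (ReplaceIcebergData or WriteIcebergDelta, depending on
+ * whether the operation supports deltas) is executed.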
+ */ +object RewriteDeleteFromIcebergTable extends RewriteRowLevelIcebergCommand { + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case d @ DeleteFromIcebergTable(aliasedTable, Some(cond), None) if d.resolved => + EliminateSubqueryAliases(aliasedTable) match { + case r @ DataSourceV2Relation(tbl: SupportsRowLevelOperations, _, _, _, _) => + val table = buildOperationTable(tbl, DELETE, CaseInsensitiveStringMap.empty()) + val rewritePlan = table.operation match { + case _: SupportsDelta => + buildWriteDeltaPlan(r, table, cond) + case _ => + buildReplaceDataPlan(r, table, cond) + } + // keep the original relation in DELETE to try deleting using filters + DeleteFromIcebergTable(r, Some(cond), Some(rewritePlan)) + + case p => + throw new AnalysisException(s"$p is not an Iceberg table") + } + } + + // build a rewrite plan for sources that support replacing groups of data (e.g. files, partitions) + private def buildReplaceDataPlan( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + cond: Expression): ReplaceIcebergData = { + + // resolve all needed attrs (e.g. metadata attrs for grouping data on write) + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) + + // construct a read relation and include all required metadata columns + val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) + + // construct a plan that contains unmatched rows in matched groups that must be carried over + // such rows do not match the condition but have to be copied over as the source can replace + // only groups of rows + val remainingRowsFilter = Not(EqualNullSafe(cond, Literal.TrueLiteral)) + val remainingRowsPlan = Filter(remainingRowsFilter, readRelation) + + // build a plan to replace read groups in the table + val writeRelation = relation.copy(table = operationTable) + ReplaceIcebergData(writeRelation, remainingRowsPlan, relation) + } + + // build a rewrite plan for sources that support row deltas + private def buildWriteDeltaPlan( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + cond: Expression): WriteIcebergDelta = { + + // resolve all needed attrs (e.g. 
row ID and any required metadata attrs) + val rowIdAttrs = resolveRowIdAttrs(relation, operationTable.operation) + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) + + // construct a read relation and include all required metadata columns + val readRelation = buildRelationWithAttrs(relation, operationTable, rowIdAttrs ++ metadataAttrs) + + // construct a plan that only contains records to delete + val deletedRowsPlan = Filter(cond, readRelation) + val operationType = Alias(Literal(DELETE_OPERATION), OPERATION_COLUMN)() + val requiredWriteAttrs = dedupAttrs(rowIdAttrs ++ metadataAttrs) + val project = Project(operationType +: requiredWriteAttrs, deletedRowsPlan) + + // build a plan to write deletes to the table + val writeRelation = relation.copy(table = operationTable) + val projections = buildWriteDeltaProjections(project, Nil, rowIdAttrs, metadataAttrs) + WriteIcebergDelta(writeRelation, project, relation, projections) + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala new file mode 100644 index 000000000000..c01306ccf5b9 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala @@ -0,0 +1,459 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.ProjectingInternalRow +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.And +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.ExtendedV2ExpressionUtils +import org.apache.spark.sql.catalyst.expressions.IsNotNull +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral +import org.apache.spark.sql.catalyst.expressions.MonotonicallyIncreasingID +import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.catalyst.plans.FullOuter +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.LeftAnti +import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.catalyst.plans.RightOuter +import org.apache.spark.sql.catalyst.plans.logical.AppendData +import org.apache.spark.sql.catalyst.plans.logical.DeleteAction +import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.catalyst.plans.logical.HintInfo +import org.apache.spark.sql.catalyst.plans.logical.InsertAction +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.plans.logical.JoinHint +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeAction +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.MergeRows +import org.apache.spark.sql.catalyst.plans.logical.NO_BROADCAST_HASH +import org.apache.spark.sql.catalyst.plans.logical.NoStatsUnaryNode +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.ReplaceIcebergData +import org.apache.spark.sql.catalyst.plans.logical.UpdateAction +import org.apache.spark.sql.catalyst.plans.logical.WriteIcebergDelta +import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ +import org.apache.spark.sql.catalyst.util.WriteDeltaProjections +import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE +import org.apache.spark.sql.connector.write.RowLevelOperationTable +import org.apache.spark.sql.connector.write.SupportsDelta +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Assigns a rewrite plan for v2 tables that support rewriting data to handle MERGE statements. + * + * This rule assumes the commands have been fully resolved and all assignments have been aligned. + * That's why it must be run after AlignRowLevelCommandAssignments. 
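+ *
+ * For illustration only (table and column names here are hypothetical), a statement such as
+ *
+ * {{{
+ *   spark.sql(
+ *     """MERGE INTO db.target t USING db.updates s ON t.id = s.id
+ *       |WHEN MATCHED THEN UPDATE SET t.value = s.value
+ *       |WHEN NOT MATCHED THEN INSERT *""".stripMargin)
+ * }}}
+ *
+ * is handled by the general case below, which assigns a ReplaceIcebergData or
+ * WriteIcebergDelta rewrite plan. A MERGE with only NOT MATCHED actions is instead
+ * rewritten into an anti-join followed by a plain append (the first two cases below).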
+ */ +object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand with PredicateHelper { + + private final val ROW_FROM_SOURCE = "__row_from_source" + private final val ROW_FROM_TARGET = "__row_from_target" + private final val ROW_ID = "__row_id" + + private final val ROW_FROM_SOURCE_REF = FieldReference(ROW_FROM_SOURCE) + private final val ROW_FROM_TARGET_REF = FieldReference(ROW_FROM_TARGET) + private final val ROW_ID_REF = FieldReference(ROW_ID) + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case m @ MergeIntoIcebergTable(aliasedTable, source, cond, matchedActions, notMatchedActions, None) + if m.resolved && m.aligned && matchedActions.isEmpty && notMatchedActions.size == 1 => + + EliminateSubqueryAliases(aliasedTable) match { + case r: DataSourceV2Relation => + // NOT MATCHED conditions may only refer to columns in source so they can be pushed down + val insertAction = notMatchedActions.head.asInstanceOf[InsertAction] + val filteredSource = insertAction.condition match { + case Some(insertCond) => Filter(insertCond, source) + case None => source + } + + // when there are no MATCHED actions, use a left anti join to remove any matching rows + // and switch to using a regular append instead of a row-level merge + // only unmatched source rows that match the condition are appended to the table + val joinPlan = Join(filteredSource, r, LeftAnti, Some(cond), JoinHint.NONE) + + val outputExprs = insertAction.assignments.map(_.value) + val outputColNames = r.output.map(_.name) + val outputCols = outputExprs.zip(outputColNames).map { case (expr, name) => + Alias(expr, name)() + } + val project = Project(outputCols, joinPlan) + + AppendData.byPosition(r, project) + + case p => + throw new AnalysisException(s"$p is not an Iceberg table") + } + + case m @ MergeIntoIcebergTable(aliasedTable, source, cond, matchedActions, notMatchedActions, None) + if m.resolved && m.aligned && matchedActions.isEmpty => + + EliminateSubqueryAliases(aliasedTable) match { + case r: DataSourceV2Relation => + // when there are no MATCHED actions, use a left anti join to remove any matching rows + // and switch to using a regular append instead of a row-level merge + // only unmatched source rows that match action conditions are appended to the table + val joinPlan = Join(source, r, LeftAnti, Some(cond), JoinHint.NONE) + + val notMatchedConditions = notMatchedActions.map(actionCondition) + val notMatchedOutputs = notMatchedActions.map(actionOutput(_, Nil)) + + // merge rows as there are multiple not matched actions + val mergeRows = MergeRows( + isSourceRowPresent = TrueLiteral, + isTargetRowPresent = FalseLiteral, + matchedConditions = Nil, + matchedOutputs = Nil, + notMatchedConditions = notMatchedConditions, + notMatchedOutputs = notMatchedOutputs, + targetOutput = Nil, + rowIdAttrs = Nil, + performCardinalityCheck = false, + emitNotMatchedTargetRows = false, + output = buildMergeRowsOutput(Nil, notMatchedOutputs, r.output), + joinPlan) + + AppendData.byPosition(r, mergeRows) + + case p => + throw new AnalysisException(s"$p is not an Iceberg table") + } + + case m @ MergeIntoIcebergTable(aliasedTable, source, cond, matchedActions, notMatchedActions, None) + if m.resolved && m.aligned => + + EliminateSubqueryAliases(aliasedTable) match { + case r @ DataSourceV2Relation(tbl: SupportsRowLevelOperations, _, _, _, _) => + val table = buildOperationTable(tbl, MERGE, CaseInsensitiveStringMap.empty()) + val rewritePlan = table.operation match { + case _: SupportsDelta => + 
buildWriteDeltaPlan(r, table, source, cond, matchedActions, notMatchedActions) + case _ => + buildReplaceDataPlan(r, table, source, cond, matchedActions, notMatchedActions) + } + + m.copy(rewritePlan = Some(rewritePlan)) + + case p => + throw new AnalysisException(s"$p is not an Iceberg table") + } + } + + // build a rewrite plan for sources that support replacing groups of data (e.g. files, partitions) + private def buildReplaceDataPlan( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + source: LogicalPlan, + cond: Expression, + matchedActions: Seq[MergeAction], + notMatchedActions: Seq[MergeAction]): ReplaceIcebergData = { + + // resolve all needed attrs (e.g. metadata attrs for grouping data on write) + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) + + // construct a scan relation and include all required metadata columns + val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) + val readAttrs = readRelation.output + + // project an extra column to check if a target row exists after the join + // project a synthetic row ID to perform the cardinality check + val rowFromTarget = Alias(TrueLiteral, ROW_FROM_TARGET)() + val rowId = Alias(MonotonicallyIncreasingID(), ROW_ID)() + val targetTableProjExprs = readAttrs ++ Seq(rowFromTarget, rowId) + val targetTableProj = Project(targetTableProjExprs, readRelation) + + // project an extra column to check if a source row exists after the join + val rowFromSource = Alias(TrueLiteral, ROW_FROM_SOURCE)() + val sourceTableProjExprs = source.output :+ rowFromSource + val sourceTableProj = Project(sourceTableProjExprs, source) + + // use left outer join if there is no NOT MATCHED action, unmatched source rows can be discarded + // use full outer join in all other cases, unmatched source rows may be needed + // disable broadcasts for the target table to perform the cardinality check + val joinType = if (notMatchedActions.isEmpty) LeftOuter else FullOuter + val joinHint = JoinHint(leftHint = Some(HintInfo(Some(NO_BROADCAST_HASH))), rightHint = None) + val joinPlan = Join(NoStatsUnaryNode(targetTableProj), sourceTableProj, joinType, Some(cond), joinHint) + + // add an extra matched action to output the original row if none of the actual actions matched + // this is needed to keep target rows that should be copied over + val matchedConditions = matchedActions.map(actionCondition) :+ TrueLiteral + val matchedOutputs = matchedActions.map(actionOutput(_, metadataAttrs)) :+ readAttrs + + val notMatchedConditions = notMatchedActions.map(actionCondition) + val notMatchedOutputs = notMatchedActions.map(actionOutput(_, metadataAttrs)) + + val rowIdAttr = resolveAttrRef(ROW_ID_REF, joinPlan) + val rowFromSourceAttr = resolveAttrRef(ROW_FROM_SOURCE_REF, joinPlan) + val rowFromTargetAttr = resolveAttrRef(ROW_FROM_TARGET_REF, joinPlan) + + val mergeRows = MergeRows( + isSourceRowPresent = IsNotNull(rowFromSourceAttr), + isTargetRowPresent = if (notMatchedActions.isEmpty) TrueLiteral else IsNotNull(rowFromTargetAttr), + matchedConditions = matchedConditions, + matchedOutputs = matchedOutputs, + notMatchedConditions = notMatchedConditions, + notMatchedOutputs = notMatchedOutputs, + targetOutput = readAttrs, + rowIdAttrs = Seq(rowIdAttr), + performCardinalityCheck = isCardinalityCheckNeeded(matchedActions), + emitNotMatchedTargetRows = true, + output = buildMergeRowsOutput(matchedOutputs, notMatchedOutputs, readAttrs), + joinPlan) + + // build a plan to replace read groups in 
the table + val writeRelation = relation.copy(table = operationTable) + ReplaceIcebergData(writeRelation, mergeRows, relation) + } + + // build a rewrite plan for sources that support row deltas + private def buildWriteDeltaPlan( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + source: LogicalPlan, + cond: Expression, + matchedActions: Seq[MergeAction], + notMatchedActions: Seq[MergeAction]): WriteIcebergDelta = { + + // resolve all needed attrs (e.g. row ID and any required metadata attrs) + val rowAttrs = relation.output + val rowIdAttrs = resolveRowIdAttrs(relation, operationTable.operation) + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) + + // construct a scan relation and include all required metadata columns + val readRelation = buildRelationWithAttrs(relation, operationTable, rowIdAttrs ++ metadataAttrs) + val readAttrs = readRelation.output + + val (targetCond, joinCond) = splitMergeCond(cond, readRelation) + + // project an extra column to check if a target row exists after the join + val targetTableProjExprs = readAttrs :+ Alias(TrueLiteral, ROW_FROM_TARGET)() + val targetTableProj = Project(targetTableProjExprs, Filter(targetCond, readRelation)) + + // project an extra column to check if a source row exists after the join + val sourceTableProjExprs = source.output :+ Alias(TrueLiteral, ROW_FROM_SOURCE)() + val sourceTableProj = Project(sourceTableProjExprs, source) + + // use inner join if there is no NOT MATCHED action, unmatched source rows can be discarded + // use right outer join in all other cases, unmatched source rows may be needed + // also disable broadcasts for the target table to perform the cardinality check + val joinType = if (notMatchedActions.isEmpty) Inner else RightOuter + val joinHint = JoinHint(leftHint = Some(HintInfo(Some(NO_BROADCAST_HASH))), rightHint = None) + val joinPlan = Join(NoStatsUnaryNode(targetTableProj), sourceTableProj, joinType, Some(joinCond), joinHint) + + val deleteRowValues = buildDeltaDeleteRowValues(rowAttrs, rowIdAttrs) + val metadataReadAttrs = readAttrs.filterNot(relation.outputSet.contains) + + val matchedConditions = matchedActions.map(actionCondition) + val matchedOutputs = matchedActions.map(deltaActionOutput(_, deleteRowValues, metadataReadAttrs)) + + val notMatchedConditions = notMatchedActions.map(actionCondition) + val notMatchedOutputs = notMatchedActions.map(deltaActionOutput(_, deleteRowValues, metadataReadAttrs)) + + val operationTypeAttr = AttributeReference(OPERATION_COLUMN, IntegerType, nullable = false)() + val rowFromSourceAttr = resolveAttrRef(ROW_FROM_SOURCE_REF, joinPlan) + val rowFromTargetAttr = resolveAttrRef(ROW_FROM_TARGET_REF, joinPlan) + + // merged rows must contain values for the operation type and all read attrs + val mergeRowsOutput = buildMergeRowsOutput(matchedOutputs, notMatchedOutputs, operationTypeAttr +: readAttrs) + + val mergeRows = MergeRows( + isSourceRowPresent = IsNotNull(rowFromSourceAttr), + isTargetRowPresent = if (notMatchedActions.isEmpty) TrueLiteral else IsNotNull(rowFromTargetAttr), + matchedConditions = matchedConditions, + matchedOutputs = matchedOutputs, + notMatchedConditions = notMatchedConditions, + notMatchedOutputs = notMatchedOutputs, + // only needed if emitting unmatched target rows + targetOutput = Nil, + rowIdAttrs = rowIdAttrs, + performCardinalityCheck = isCardinalityCheckNeeded(matchedActions), + emitNotMatchedTargetRows = false, + output = mergeRowsOutput, + joinPlan) + + // build a plan to write 
the row delta to the table + val writeRelation = relation.copy(table = operationTable) + val projections = buildMergeDeltaProjections(mergeRows, rowAttrs, rowIdAttrs, metadataAttrs) + WriteIcebergDelta(writeRelation, mergeRows, relation, projections) + } + + private def actionCondition(action: MergeAction): Expression = { + action.condition.getOrElse(TrueLiteral) + } + + private def actionOutput( + clause: MergeAction, + metadataAttrs: Seq[Attribute]): Seq[Expression] = { + + clause match { + case u: UpdateAction => + u.assignments.map(_.value) ++ metadataAttrs + + case _: DeleteAction => + Nil + + case i: InsertAction => + i.assignments.map(_.value) ++ metadataAttrs.map(attr => Literal(null, attr.dataType)) + + case other => + throw new AnalysisException(s"Unexpected action: $other") + } + } + + private def deltaActionOutput( + action: MergeAction, + deleteRowValues: Seq[Expression], + metadataAttrs: Seq[Attribute]): Seq[Expression] = { + + action match { + case u: UpdateAction => + Seq(Literal(UPDATE_OPERATION)) ++ u.assignments.map(_.value) ++ metadataAttrs + + case _: DeleteAction => + Seq(Literal(DELETE_OPERATION)) ++ deleteRowValues ++ metadataAttrs + + case i: InsertAction => + val metadataAttrValues = metadataAttrs.map(attr => Literal(null, attr.dataType)) + Seq(Literal(INSERT_OPERATION)) ++ i.assignments.map(_.value) ++ metadataAttrValues + + case other => + throw new AnalysisException(s"Unexpected action: $other") + } + } + + private def buildMergeRowsOutput( + matchedOutputs: Seq[Seq[Expression]], + notMatchedOutputs: Seq[Seq[Expression]], + attrs: Seq[Attribute]): Seq[Attribute] = { + + // collect all outputs from matched and not matched actions (ignoring DELETEs) + val outputs = matchedOutputs.filter(_.nonEmpty) ++ notMatchedOutputs.filter(_.nonEmpty) + + // build a correct nullability map for output attributes + // an attribute is nullable if at least one matched or not matched action may produce null + val nullabilityMap = attrs.indices.map { index => + index -> outputs.exists(output => output(index).nullable) + }.toMap + + attrs.zipWithIndex.map { case (attr, index) => + AttributeReference(attr.name, attr.dataType, nullabilityMap(index), attr.metadata)() + } + } + + private def isCardinalityCheckNeeded(actions: Seq[MergeAction]): Boolean = actions match { + case Seq(DeleteAction(None)) => false + case _ => true + } + + private def buildDeltaDeleteRowValues( + rowAttrs: Seq[Attribute], + rowIdAttrs: Seq[Attribute]): Seq[Expression] = { + + // nullify all row attrs that are not part of the row ID + val rowIdAttSet = AttributeSet(rowIdAttrs) + rowAttrs.map { + case attr if rowIdAttSet.contains(attr) => attr + case attr => Literal(null, attr.dataType) + } + } + + private def resolveAttrRef(ref: NamedReference, plan: LogicalPlan): AttributeReference = { + ExtendedV2ExpressionUtils.resolveRef[AttributeReference](ref, plan) + } + + private def buildMergeDeltaProjections( + mergeRows: MergeRows, + rowAttrs: Seq[Attribute], + rowIdAttrs: Seq[Attribute], + metadataAttrs: Seq[Attribute]): WriteDeltaProjections = { + + val outputAttrs = mergeRows.output + + val outputs = mergeRows.matchedOutputs ++ mergeRows.notMatchedOutputs + val insertAndUpdateOutputs = outputs.filterNot(_.head == Literal(DELETE_OPERATION)) + val updateAndDeleteOutputs = outputs.filterNot(_.head == Literal(INSERT_OPERATION)) + + val rowProjection = if (rowAttrs.nonEmpty) { + Some(newLazyProjection(insertAndUpdateOutputs, outputAttrs, rowAttrs)) + } else { + None + } + + val rowIdProjection = 
newLazyProjection(updateAndDeleteOutputs, outputAttrs, rowIdAttrs) + + val metadataProjection = if (metadataAttrs.nonEmpty) { + Some(newLazyProjection(updateAndDeleteOutputs, outputAttrs, metadataAttrs)) + } else { + None + } + + WriteDeltaProjections(rowProjection, rowIdProjection, metadataProjection) + } + + // the projection is done by name, ignoring expr IDs + private def newLazyProjection( + outputs: Seq[Seq[Expression]], + outputAttrs: Seq[Attribute], + projectedAttrs: Seq[Attribute]): ProjectingInternalRow = { + + val projectedOrdinals = projectedAttrs.map(attr => outputAttrs.indexWhere(_.name == attr.name)) + + val structFields = projectedAttrs.zip(projectedOrdinals).map { case (attr, ordinal) => + // output attr is nullable if at least one action may produce null for that attr + // but row ID and metadata attrs are projected only in update/delete actions and + // row attrs are projected only in insert/update actions + // that's why the projection schema must rely only on relevant action outputs + // instead of blindly inheriting the output attr nullability + val nullable = outputs.exists(output => output(ordinal).nullable) + StructField(attr.name, attr.dataType, nullable, attr.metadata) + } + val schema = StructType(structFields) + + ProjectingInternalRow(schema, projectedOrdinals) + } + + // splits the MERGE condition into a predicate that references columns only from the target table, + // which can be pushed down, and a predicate used as a join condition to find matches + private def splitMergeCond( + cond: Expression, + targetTable: LogicalPlan): (Expression, Expression) = { + + val (targetPredicates, joinPredicates) = splitConjunctivePredicates(cond) + .partition(_.references.subsetOf(targetTable.outputSet)) + val targetCond = targetPredicates.reduceOption(And).getOrElse(TrueLiteral) + val joinCond = joinPredicates.reduceOption(And).getOrElse(TrueLiteral) + (targetCond, joinCond) + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala new file mode 100644 index 000000000000..b460f648d28b --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.ProjectingInternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.ExtendedV2ExpressionUtils +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.WriteDeltaProjections +import org.apache.spark.sql.connector.write.RowLevelOperation +import org.apache.spark.sql.connector.write.SupportsDelta +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.types.StructType + +trait RewriteRowLevelIcebergCommand extends RewriteRowLevelCommand { + + // override as the existing Spark method does not work for UPDATE and MERGE + protected override def buildWriteDeltaProjections( + plan: LogicalPlan, + rowAttrs: Seq[Attribute], + rowIdAttrs: Seq[Attribute], + metadataAttrs: Seq[Attribute]): WriteDeltaProjections = { + + val rowProjection = if (rowAttrs.nonEmpty) { + Some(newLazyProjection(plan, rowAttrs)) + } else { + None + } + + val rowIdProjection = newLazyProjection(plan, rowIdAttrs) + + val metadataProjection = if (metadataAttrs.nonEmpty) { + Some(newLazyProjection(plan, metadataAttrs)) + } else { + None + } + + WriteDeltaProjections(rowProjection, rowIdProjection, metadataProjection) + } + + // the projection is done by name, ignoring expr IDs + private def newLazyProjection( + plan: LogicalPlan, + projectedAttrs: Seq[Attribute]): ProjectingInternalRow = { + + val projectedOrdinals = projectedAttrs.map(attr => plan.output.indexWhere(_.name == attr.name)) + val schema = StructType.fromAttributes(projectedOrdinals.map(plan.output(_))) + ProjectingInternalRow(schema, projectedOrdinals) + } + + protected def resolveRowIdAttrs( + relation: DataSourceV2Relation, + operation: RowLevelOperation): Seq[AttributeReference] = { + + operation match { + case supportsDelta: SupportsDelta => + val rowIdAttrs = ExtendedV2ExpressionUtils.resolveRefs[AttributeReference]( + supportsDelta.rowId.toSeq, + relation) + + val nullableRowIdAttrs = rowIdAttrs.filter(_.nullable) + if (nullableRowIdAttrs.nonEmpty) { + throw new AnalysisException(s"Row ID attrs cannot be nullable: $nullableRowIdAttrs") + } + + rowIdAttrs + + case other => + throw new AnalysisException(s"Operation $other does not support deltas") + } + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala new file mode 100644 index 000000000000..006040081b8f --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.EqualNullSafe +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.If +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.Not +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.plans.logical.Assignment +import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.ReplaceIcebergData +import org.apache.spark.sql.catalyst.plans.logical.Union +import org.apache.spark.sql.catalyst.plans.logical.UpdateIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.WriteIcebergDelta +import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ +import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations +import org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE +import org.apache.spark.sql.connector.write.RowLevelOperationTable +import org.apache.spark.sql.connector.write.SupportsDelta +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Assigns a rewrite plan for v2 tables that support rewriting data to handle UPDATE statements. + * + * This rule assumes the commands have been fully resolved and all assignments have been aligned. + * That's why it must be run after AlignRowLevelCommandAssignments. + * + * This rule also must be run in the same batch with DeduplicateRelations in Spark. + */ +object RewriteUpdateTable extends RewriteRowLevelIcebergCommand { + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case u @ UpdateIcebergTable(aliasedTable, assignments, cond, None) if u.resolved && u.aligned => + EliminateSubqueryAliases(aliasedTable) match { + case r @ DataSourceV2Relation(tbl: SupportsRowLevelOperations, _, _, _, _) => + val table = buildOperationTable(tbl, UPDATE, CaseInsensitiveStringMap.empty()) + val updateCond = cond.getOrElse(Literal.TrueLiteral) + val rewritePlan = table.operation match { + case _: SupportsDelta => + buildWriteDeltaPlan(r, table, assignments, updateCond) + case _ if SubqueryExpression.hasSubquery(updateCond) => + buildReplaceDataWithUnionPlan(r, table, assignments, updateCond) + case _ => + buildReplaceDataPlan(r, table, assignments, updateCond) + } + UpdateIcebergTable(r, assignments, cond, Some(rewritePlan)) + + case p => + throw new AnalysisException(s"$p is not an Iceberg table") + } + } + + // build a rewrite plan for sources that support replacing groups of data (e.g. 
files, partitions) + // if the condition does NOT contain a subquery + private def buildReplaceDataPlan( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + assignments: Seq[Assignment], + cond: Expression): ReplaceIcebergData = { + + // resolve all needed attrs (e.g. metadata attrs for grouping data on write) + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) + + // construct a read relation and include all required metadata columns + val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) + + // build a plan with updated and copied over records + val updatedAndRemainingRowsPlan = buildUpdateProjection(readRelation, assignments, cond) + + // build a plan to replace read groups in the table + val writeRelation = relation.copy(table = operationTable) + ReplaceIcebergData(writeRelation, updatedAndRemainingRowsPlan, relation) + } + + // build a rewrite plan for sources that support replacing groups of data (e.g. files, partitions) + // if the condition contains a subquery + private def buildReplaceDataWithUnionPlan( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + assignments: Seq[Assignment], + cond: Expression): ReplaceIcebergData = { + + // resolve all needed attrs (e.g. metadata attrs for grouping data on write) + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) + + // construct a read relation and include all required metadata columns + // the same read relation will be used to read records that must be updated and be copied over + // DeduplicateRelations will take care of duplicated attr IDs + val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) + + // build a plan for records that match the cond and should be updated + val matchedRowsPlan = Filter(cond, readRelation) + val updatedRowsPlan = buildUpdateProjection(matchedRowsPlan, assignments) + + // build a plan for records that did not match the cond but had to be copied over + val remainingRowFilter = Not(EqualNullSafe(cond, Literal.TrueLiteral)) + val remainingRowsPlan = Filter(remainingRowFilter, readRelation) + + // new state is a union of updated and copied over records + val updatedAndRemainingRowsPlan = Union(updatedRowsPlan, remainingRowsPlan) + + // build a plan to replace read groups in the table + val writeRelation = relation.copy(table = operationTable) + ReplaceIcebergData(writeRelation, updatedAndRemainingRowsPlan, relation) + } + + // build a rewrite plan for sources that support row deltas + private def buildWriteDeltaPlan( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + assignments: Seq[Assignment], + cond: Expression): WriteIcebergDelta = { + + // resolve all needed attrs (e.g. 
row ID and any required metadata attrs) + val rowAttrs = relation.output + val rowIdAttrs = resolveRowIdAttrs(relation, operationTable.operation) + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) + + // construct a scan relation and include all required metadata columns + val readRelation = buildRelationWithAttrs(relation, operationTable, rowIdAttrs ++ metadataAttrs) + + // build a plan for updated records that match the cond + val matchedRowsPlan = Filter(cond, readRelation) + val updatedRowsPlan = buildUpdateProjection(matchedRowsPlan, assignments) + val operationType = Alias(Literal(UPDATE_OPERATION), OPERATION_COLUMN)() + val project = Project(operationType +: updatedRowsPlan.output, updatedRowsPlan) + + // build a plan to write the row delta to the table + val writeRelation = relation.copy(table = operationTable) + val projections = buildWriteDeltaProjections(project, rowAttrs, rowIdAttrs, metadataAttrs) + WriteIcebergDelta(writeRelation, project, relation, projections) + } + + // this method assumes the assignments have been already aligned before + // the condition passed to this method may be different from the UPDATE condition + private def buildUpdateProjection( + plan: LogicalPlan, + assignments: Seq[Assignment], + cond: Expression = Literal.TrueLiteral): LogicalPlan = { + + // TODO: avoid executing the condition for each column + + // the plan output may include metadata columns that are not modified + // that's why the number of assignments may not match the number of plan output columns + + val assignedValues = assignments.map(_.value) + val updatedValues = plan.output.zipWithIndex.map { case (attr, index) => + if (index < assignments.size) { + val assignedExpr = assignedValues(index) + val updatedValue = If(cond, assignedExpr, attr) + Alias(updatedValue, attr.name)() + } else { + attr + } + } + + Project(updatedValues, plan) + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/expressions/AssignmentUtils.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/expressions/AssignmentUtils.scala new file mode 100644 index 000000000000..ce3818922c78 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/expressions/AssignmentUtils.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.plans.logical.Assignment +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.types.DataType + +object AssignmentUtils extends SQLConfHelper { + + /** + * Checks whether assignments are aligned and match table columns. + * + * @param table a target table + * @param assignments assignments to check + * @return true if the assignments are aligned + */ + def aligned(table: LogicalPlan, assignments: Seq[Assignment]): Boolean = { + val sameSize = table.output.size == assignments.size + sameSize && table.output.zip(assignments).forall { case (attr, assignment) => + val key = assignment.key + val value = assignment.value + val refsEqual = toAssignmentRef(attr).zip(toAssignmentRef(key)) + .forall{ case (attrRef, keyRef) => conf.resolver(attrRef, keyRef)} + + refsEqual && + DataType.equalsIgnoreCompatibleNullability(value.dataType, attr.dataType) && + (attr.nullable || !value.nullable) + } + } + + def toAssignmentRef(expr: Expression): Seq[String] = expr match { + case attr: AttributeReference => + Seq(attr.name) + case Alias(child, _) => + toAssignmentRef(child) + case GetStructField(child, _, Some(name)) => + toAssignmentRef(child) :+ name + case other: ExtractValue => + throw new AnalysisException(s"Updating nested fields is only supported for structs: $other") + case other => + throw new AnalysisException(s"Cannot convert to a reference, unsupported expression: $other") + } + + def handleCharVarcharLimits(assignment: Assignment): Assignment = { + val key = assignment.key + val value = assignment.value + + val rawKeyType = key.transform { + case attr: AttributeReference => + CharVarcharUtils.getRawType(attr.metadata) + .map(attr.withDataType) + .getOrElse(attr) + }.dataType + + if (CharVarcharUtils.hasCharVarchar(rawKeyType)) { + val newKey = key.transform { + case attr: AttributeReference => CharVarcharUtils.cleanAttrMetadata(attr) + } + val newValue = CharVarcharUtils.stringLengthCheck(value, rawKeyType) + Assignment(newKey, newValue) + } else { + assignment + } + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtendedV2ExpressionUtils.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtendedV2ExpressionUtils.scala new file mode 100644 index 000000000000..16ff67a70522 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtendedV2ExpressionUtils.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.expressions.{Expression => V2Expression} +import org.apache.spark.sql.connector.expressions.{SortDirection => V2SortDirection} +import org.apache.spark.sql.connector.expressions.{NullOrdering => V2NullOrdering} +import org.apache.spark.sql.connector.expressions.BucketTransform +import org.apache.spark.sql.connector.expressions.DaysTransform +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.HoursTransform +import org.apache.spark.sql.connector.expressions.IdentityTransform +import org.apache.spark.sql.connector.expressions.MonthsTransform +import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.expressions.SortValue +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.expressions.TruncateTransform +import org.apache.spark.sql.connector.expressions.YearsTransform +import org.apache.spark.sql.errors.QueryCompilationErrors + +/** + * A class that is inspired by V2ExpressionUtils in Spark but supports Iceberg transforms. + */ +object ExtendedV2ExpressionUtils extends SQLConfHelper { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + + def resolveRef[T <: NamedExpression](ref: NamedReference, plan: LogicalPlan): T = { + plan.resolve(ref.fieldNames.toSeq, conf.resolver) match { + case Some(namedExpr) => + namedExpr.asInstanceOf[T] + case None => + val name = ref.fieldNames.toSeq.quoted + val outputString = plan.output.map(_.name).mkString(",") + throw QueryCompilationErrors.cannotResolveAttributeError(name, outputString) + } + } + + def resolveRefs[T <: NamedExpression](refs: Seq[NamedReference], plan: LogicalPlan): Seq[T] = { + refs.map(ref => resolveRef[T](ref, plan)) + } + + def toCatalyst(expr: V2Expression, query: LogicalPlan): Expression = { + expr match { + case SortValue(child, direction, nullOrdering) => + val catalystChild = toCatalyst(child, query) + SortOrder(catalystChild, toCatalyst(direction), toCatalyst(nullOrdering), Seq.empty) + case IdentityTransform(ref) => + resolveRef[NamedExpression](ref, query) + case t: Transform if BucketTransform.unapply(t).isDefined => + t match { + // sort columns will be empty for bucket. 
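// Editorial note, not part of the change: a partition transform such as bucket(16, id) reaches
// this branch of toCatalyst; the referenced column is resolved against the incoming plan and then
// wrapped in IcebergBucketTransform so Spark can evaluate Iceberg's bucket hashing.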
+ case BucketTransform(numBuckets, cols, _) => + IcebergBucketTransform(numBuckets, resolveRef[NamedExpression](cols.head, query)) + case _ => t.asInstanceOf[Expression] + // do nothing + } + case TruncateTransform(length, ref) => + IcebergTruncateTransform(resolveRef[NamedExpression](ref, query), length) + case YearsTransform(ref) => + IcebergYearTransform(resolveRef[NamedExpression](ref, query)) + case MonthsTransform(ref) => + IcebergMonthTransform(resolveRef[NamedExpression](ref, query)) + case DaysTransform(ref) => + IcebergDayTransform(resolveRef[NamedExpression](ref, query)) + case HoursTransform(ref) => + IcebergHourTransform(resolveRef[NamedExpression](ref, query)) + case ref: FieldReference => + resolveRef[NamedExpression](ref, query) + case _ => + throw new AnalysisException(s"$expr is not currently supported") + } + } + + private def toCatalyst(direction: V2SortDirection): SortDirection = direction match { + case V2SortDirection.ASCENDING => Ascending + case V2SortDirection.DESCENDING => Descending + } + + private def toCatalyst(nullOrdering: V2NullOrdering): NullOrdering = nullOrdering match { + case V2NullOrdering.NULLS_FIRST => NullsFirst + case V2NullOrdering.NULLS_LAST => NullsLast + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ExtendedReplaceNullWithFalseInPredicate.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ExtendedReplaceNullWithFalseInPredicate.scala new file mode 100644 index 000000000000..d62cf6d83969 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ExtendedReplaceNullWithFalseInPredicate.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
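Editorial aside: the connector-side expressions that toCatalyst above consumes can be built with Spark's public Expressions factory. A hedged sketch (the column names and the analyzed plan are assumptions, not part of this change):

import org.apache.spark.sql.connector.expressions.Expressions

object TransformInputSketch {
  val byDay = Expressions.days("ts")          // handled by the DaysTransform branch
  val byBucket = Expressions.bucket(16, "id") // handled by the BucketTransform branch
  // toCatalyst(byDay, plan) resolves "ts" against an analyzed LogicalPlan (not shown here)
  // and returns IcebergDayTransform over the resolved attribute.
}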
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.And +import org.apache.spark.sql.catalyst.expressions.CaseWhen +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.If +import org.apache.spark.sql.catalyst.expressions.In +import org.apache.spark.sql.catalyst.expressions.InSet +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral +import org.apache.spark.sql.catalyst.expressions.Not +import org.apache.spark.sql.catalyst.expressions.Or +import org.apache.spark.sql.catalyst.plans.logical.DeleteAction +import org.apache.spark.sql.catalyst.plans.logical.DeleteFromIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.InsertAction +import org.apache.spark.sql.catalyst.plans.logical.InsertStarAction +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeAction +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.UpdateAction +import org.apache.spark.sql.catalyst.plans.logical.UpdateIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.UpdateStarAction +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.INSET +import org.apache.spark.sql.catalyst.trees.TreePattern.NULL_LITERAL +import org.apache.spark.sql.catalyst.trees.TreePattern.TRUE_OR_FALSE_LITERAL +import org.apache.spark.sql.types.BooleanType +import org.apache.spark.util.Utils + +/** + * A rule similar to ReplaceNullWithFalseInPredicate in Spark but applies to Iceberg row-level commands. + */ +object ExtendedReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( + _.containsAnyPattern(NULL_LITERAL, TRUE_OR_FALSE_LITERAL, INSET)) { + + case d @ DeleteFromIcebergTable(_, Some(cond), _) => + d.copy(condition = Some(replaceNullWithFalse(cond))) + + case u @ UpdateIcebergTable(_, _, Some(cond), _) => + u.copy(condition = Some(replaceNullWithFalse(cond))) + + case m @ MergeIntoIcebergTable(_, _, mergeCond, matchedActions, notMatchedActions, _) => + m.copy( + mergeCondition = replaceNullWithFalse(mergeCond), + matchedActions = replaceNullWithFalse(matchedActions), + notMatchedActions = replaceNullWithFalse(notMatchedActions)) + } + + /** + * Recursively traverse the Boolean-type expression to replace + * `Literal(null, BooleanType)` with `FalseLiteral`, if possible. + * + * Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit + * an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or + * `Literal(null, BooleanType)`. + */ + private def replaceNullWithFalse(e: Expression): Expression = e match { + case Literal(null, BooleanType) => + FalseLiteral + // In SQL, the `Not(IN)` expression evaluates as follows: + // `NULL not in (1)` -> NULL + // `NULL not in (1, NULL)` -> NULL + // `1 not in (1, NULL)` -> false + // `1 not in (2, NULL)` -> NULL + // In predicate, NULL is equal to false, so we can simplify them to false directly. 
+ case Not(In(value, list)) if (value +: list).exists(isNullLiteral) => + FalseLiteral + case Not(InSet(value, list)) if isNullLiteral(value) || list.contains(null) => + FalseLiteral + + case And(left, right) => + And(replaceNullWithFalse(left), replaceNullWithFalse(right)) + case Or(left, right) => + Or(replaceNullWithFalse(left), replaceNullWithFalse(right)) + case cw: CaseWhen if cw.dataType == BooleanType => + val newBranches = cw.branches.map { case (cond, value) => + replaceNullWithFalse(cond) -> replaceNullWithFalse(value) + } + val newElseValue = cw.elseValue.map(replaceNullWithFalse).getOrElse(FalseLiteral) + CaseWhen(newBranches, newElseValue) + case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType => + If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal)) + case e if e.dataType == BooleanType => + e + case e => + val message = "Expected a Boolean type expression in replaceNullWithFalse, " + + s"but got the type `${e.dataType.catalogString}` in `${e.sql}`." + if (Utils.isTesting) { + throw new IllegalArgumentException(message) + } else { + logWarning(message) + e + } + } + + private def isNullLiteral(e: Expression): Boolean = e match { + case Literal(null, _) => true + case _ => false + } + + private def replaceNullWithFalse(mergeActions: Seq[MergeAction]): Seq[MergeAction] = { + mergeActions.map { + case u @ UpdateAction(Some(cond), _) => u.copy(condition = Some(replaceNullWithFalse(cond))) + case u @ UpdateStarAction(Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond))) + case d @ DeleteAction(Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond))) + case i @ InsertAction(Some(cond), _) => i.copy(condition = Some(replaceNullWithFalse(cond))) + case i @ InsertStarAction(Some(cond)) => i.copy(condition = Some(replaceNullWithFalse(cond))) + case other => other + } + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ExtendedSimplifyConditionalsInPredicate.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ExtendedSimplifyConditionalsInPredicate.scala new file mode 100644 index 000000000000..f4df565c44df --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ExtendedSimplifyConditionalsInPredicate.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
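Editorial aside: the rule above is safe because, in a filter position, a NULL predicate rejects a row exactly like false does. A small hedged demonstration of the IN/NOT IN semantics described in the comment (the local session and view name are assumptions):

import org.apache.spark.sql.SparkSession

object NullAsFalseSketch extends App {
  val spark = SparkSession.builder().appName("null-as-false").master("local[*]").getOrCreate()
  spark.range(3).createOrReplaceTempView("t")
  spark.sql("SELECT * FROM t WHERE id IN (1, NULL)").show()     // keeps only id = 1
  spark.sql("SELECT * FROM t WHERE id NOT IN (1, NULL)").show() // keeps nothing: NULL filters like false
  spark.stop()
}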
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.And +import org.apache.spark.sql.catalyst.expressions.CaseWhen +import org.apache.spark.sql.catalyst.expressions.Coalesce +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.If +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral +import org.apache.spark.sql.catalyst.expressions.Not +import org.apache.spark.sql.catalyst.expressions.Or +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.CASE_WHEN +import org.apache.spark.sql.catalyst.trees.TreePattern.IF +import org.apache.spark.sql.types.BooleanType + +/** + * A rule similar to SimplifyConditionalsInPredicate in Spark but applies to Iceberg row-level commands. + */ +object ExtendedSimplifyConditionalsInPredicate extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( + _.containsAnyPattern(CASE_WHEN, IF)) { + + case d @ DeleteFromIcebergTable(_, Some(cond), _) => + d.copy(condition = Some(simplifyConditional(cond))) + + case u @ UpdateIcebergTable(_, _, Some(cond), _) => + u.copy(condition = Some(simplifyConditional(cond))) + + case m @ MergeIntoIcebergTable(_, _, mergeCond, matchedActions, notMatchedActions, _) => + m.copy( + mergeCondition = simplifyConditional(mergeCond), + matchedActions = simplifyConditional(matchedActions), + notMatchedActions = simplifyConditional(notMatchedActions)) + } + + private def simplifyConditional(e: Expression): Expression = e match { + case And(left, right) => And(simplifyConditional(left), simplifyConditional(right)) + case Or(left, right) => Or(simplifyConditional(left), simplifyConditional(right)) + case If(cond, trueValue, FalseLiteral) => And(cond, trueValue) + case If(cond, trueValue, TrueLiteral) => Or(Not(Coalesce(Seq(cond, FalseLiteral))), trueValue) + case If(cond, FalseLiteral, falseValue) => + And(Not(Coalesce(Seq(cond, FalseLiteral))), falseValue) + case If(cond, TrueLiteral, falseValue) => Or(cond, falseValue) + case CaseWhen(Seq((cond, trueValue)), + Some(FalseLiteral) | Some(Literal(null, BooleanType)) | None) => + And(cond, trueValue) + case CaseWhen(Seq((cond, trueValue)), Some(TrueLiteral)) => + Or(Not(Coalesce(Seq(cond, FalseLiteral))), trueValue) + case CaseWhen(Seq((cond, FalseLiteral)), Some(elseValue)) => + And(Not(Coalesce(Seq(cond, FalseLiteral))), elseValue) + case CaseWhen(Seq((cond, TrueLiteral)), Some(elseValue)) => + Or(cond, elseValue) + case e if e.dataType == BooleanType => e + case e => + assert(e.dataType != BooleanType, + "Expected a Boolean type expression in ExtendedSimplifyConditionalsInPredicate, " + + s"but got the type `${e.dataType.catalogString}` in `${e.sql}`.") + e + } + + private def simplifyConditional(mergeActions: Seq[MergeAction]): Seq[MergeAction] = { + mergeActions.map { + case u @ UpdateAction(Some(cond), _) => u.copy(condition = Some(simplifyConditional(cond))) + case u @ UpdateStarAction(Some(cond)) => u.copy(condition = Some(simplifyConditional(cond))) + case d @ DeleteAction(Some(cond)) => d.copy(condition = Some(simplifyConditional(cond))) + case i @ InsertAction(Some(cond), _) => i.copy(condition = Some(simplifyConditional(cond))) + case i @ InsertStarAction(Some(cond)) => 
i.copy(condition = Some(simplifyConditional(cond))) + case other => other + } + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala new file mode 100644 index 000000000000..5d5c9bba2714 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala @@ -0,0 +1,400 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.parser.extensions + +import java.util.Locale +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.atn.PredictionMode +import org.antlr.v4.runtime.misc.Interval +import org.antlr.v4.runtime.misc.ParseCancellationException +import org.antlr.v4.runtime.tree.TerminalNodeImpl +import org.apache.iceberg.common.DynConstructors +import org.apache.iceberg.spark.ExtendedParser +import org.apache.iceberg.spark.ExtendedParser.RawOrderField +import org.apache.iceberg.spark.Spark3Util +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser.NonReservedContext +import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser.QuotedIdentifierContext +import org.apache.spark.sql.catalyst.plans.logical.DeleteFromIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.DeleteFromTable +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoContext +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoTable +import org.apache.spark.sql.catalyst.plans.logical.UnresolvedMergeIntoIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.UpdateIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.UpdateTable +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.catalog.TableCatalog +import org.apache.spark.sql.execution.command.ExplainCommand +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.VariableSubstitution +import 
org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.StructType +import scala.jdk.CollectionConverters._ +import scala.util.Try + +class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) extends ParserInterface with ExtendedParser { + + import IcebergSparkSqlExtensionsParser._ + + private lazy val substitutor = substitutorCtor.newInstance(SQLConf.get) + private lazy val astBuilder = new IcebergSqlExtensionsAstBuilder(delegate) + + /** + * Parse a string to a DataType. + */ + override def parseDataType(sqlText: String): DataType = { + delegate.parseDataType(sqlText) + } + + /** + * Parse a string to a raw DataType without CHAR/VARCHAR replacement. + */ + def parseRawDataType(sqlText: String): DataType = throw new UnsupportedOperationException() + + /** + * Parse a string to an Expression. + */ + override def parseExpression(sqlText: String): Expression = { + delegate.parseExpression(sqlText) + } + + /** + * Parse a string to a TableIdentifier. + */ + override def parseTableIdentifier(sqlText: String): TableIdentifier = { + delegate.parseTableIdentifier(sqlText) + } + + /** + * Parse a string to a FunctionIdentifier. + */ + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = { + delegate.parseFunctionIdentifier(sqlText) + } + + /** + * Parse a string to a multi-part identifier. + */ + override def parseMultipartIdentifier(sqlText: String): Seq[String] = { + delegate.parseMultipartIdentifier(sqlText) + } + + /** + * Creates StructType for a given SQL string, which is a comma separated list of field + * definitions which will preserve the correct Hive metadata. + */ + override def parseTableSchema(sqlText: String): StructType = { + delegate.parseTableSchema(sqlText) + } + + override def parseSortOrder(sqlText: String): java.util.List[RawOrderField] = { + val fields = parse(sqlText) { parser => astBuilder.visitSingleOrder(parser.singleOrder()) } + fields.map { field => + val (term, direction, order) = field + new RawOrderField(term, direction, order) + }.asJava + } + + /** + * Parse a string to a LogicalPlan. 
+ */ + override def parsePlan(sqlText: String): LogicalPlan = { + val sqlTextAfterSubstitution = substitutor.substitute(sqlText) + if (isIcebergCommand(sqlTextAfterSubstitution)) { + parse(sqlTextAfterSubstitution) { parser => astBuilder.visit(parser.singleStatement()) }.asInstanceOf[LogicalPlan] + } else { + val parsedPlan = delegate.parsePlan(sqlText) + parsedPlan match { + case e: ExplainCommand => + e.copy(logicalPlan = replaceRowLevelCommands(e.logicalPlan)) + case p => + replaceRowLevelCommands(p) + } + } + } + + private def replaceRowLevelCommands(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { + case DeleteFromTable(UnresolvedIcebergTable(aliasedTable), condition) => + DeleteFromIcebergTable(aliasedTable, Some(condition)) + + case UpdateTable(UnresolvedIcebergTable(aliasedTable), assignments, condition) => + UpdateIcebergTable(aliasedTable, assignments, condition) + + case MergeIntoTable(UnresolvedIcebergTable(aliasedTable), source, cond, matchedActions, notMatchedActions, Nil) => + // cannot construct MergeIntoIcebergTable right away as MERGE operations require special resolution + // that's why the condition and actions must be hidden from the regular resolution rules in Spark + // see ResolveMergeIntoTableReferences for details + val context = MergeIntoContext(cond, matchedActions, notMatchedActions) + UnresolvedMergeIntoIcebergTable(aliasedTable, source, context) + + case MergeIntoTable(UnresolvedIcebergTable(_), _, _, _, _, notMatchedBySourceActions) + if notMatchedBySourceActions.nonEmpty => + throw new AnalysisException("Iceberg does not support WHEN NOT MATCHED BY SOURCE clause") + } + + object UnresolvedIcebergTable { + + def unapply(plan: LogicalPlan): Option[LogicalPlan] = { + EliminateSubqueryAliases(plan) match { + case UnresolvedRelation(multipartIdentifier, _, _) if isIcebergTable(multipartIdentifier) => + Some(plan) + case _ => + None + } + } + + private def isIcebergTable(multipartIdent: Seq[String]): Boolean = { + val catalogAndIdentifier = Spark3Util.catalogAndIdentifier(SparkSession.active, multipartIdent.asJava) + catalogAndIdentifier.catalog match { + case tableCatalog: TableCatalog => + Try(tableCatalog.loadTable(catalogAndIdentifier.identifier)) + .map(isIcebergTable) + .getOrElse(false) + + case _ => + false + } + } + + private def isIcebergTable(table: Table): Boolean = table match { + case _: SparkTable => true + case _ => false + } + } + + private def isIcebergCommand(sqlText: String): Boolean = { + val normalized = sqlText.toLowerCase(Locale.ROOT).trim() + // Strip simple SQL comments that terminate a line, e.g. comments starting with `--` . + .replaceAll("--.*?\\n", " ") + // Strip newlines. + .replaceAll("\\s+", " ") + // Strip comments of the form /* ... */. This must come after stripping newlines so that + // comments that span multiple lines are caught. 
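// Editorial example, not part of the change: an input like
//   "ALTER TABLE db.tbl  -- add a shard key\n  ADD PARTITION FIELD bucket(16, id)"
// normalizes to "alter table db.tbl add partition field bucket(16, id)", which the
// startsWith/contains checks below then recognize as an Iceberg extension command.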
+ .replaceAll("/\\*.*?\\*/", " ") + .trim() + normalized.startsWith("call") || ( + normalized.startsWith("alter table") && ( + normalized.contains("add partition field") || + normalized.contains("drop partition field") || + normalized.contains("replace partition field") || + normalized.contains("write ordered by") || + normalized.contains("write locally ordered by") || + normalized.contains("write distributed by") || + normalized.contains("write unordered") || + normalized.contains("set identifier fields") || + normalized.contains("drop identifier fields") || + isSnapshotRefDdl(normalized))) + } + + private def isSnapshotRefDdl(normalized: String): Boolean = { + normalized.contains("create branch") || + normalized.contains("replace branch") || + normalized.contains("create tag") || + normalized.contains("replace tag") || + normalized.contains("drop branch") || + normalized.contains("drop tag") + } + + protected def parse[T](command: String)(toResult: IcebergSqlExtensionsParser => T): T = { + val lexer = new IcebergSqlExtensionsLexer(new UpperCaseCharStream(CharStreams.fromString(command))) + lexer.removeErrorListeners() + lexer.addErrorListener(IcebergParseErrorListener) + + val tokenStream = new CommonTokenStream(lexer) + val parser = new IcebergSqlExtensionsParser(tokenStream) + parser.addParseListener(IcebergSqlExtensionsPostProcessor) + parser.removeErrorListeners() + parser.addErrorListener(IcebergParseErrorListener) + + try { + try { + // first, try parsing with potentially faster SLL mode + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + toResult(parser) + } + catch { + case _: ParseCancellationException => + // if we fail, parse with LL mode + tokenStream.seek(0) // rewind input stream + parser.reset() + + // Try Again. + parser.getInterpreter.setPredictionMode(PredictionMode.LL) + toResult(parser) + } + } + catch { + case e: IcebergParseException if e.command.isDefined => + throw e + case e: IcebergParseException => + throw e.withCommand(command) + case e: AnalysisException => + val position = Origin(e.line, e.startPosition) + throw new IcebergParseException(Option(command), e.message, position, position) + } + } + + override def parseQuery(sqlText: String): LogicalPlan = { + parsePlan(sqlText) + } +} + +object IcebergSparkSqlExtensionsParser { + private val substitutorCtor: DynConstructors.Ctor[VariableSubstitution] = + DynConstructors.builder() + .impl(classOf[VariableSubstitution]) + .impl(classOf[VariableSubstitution], classOf[SQLConf]) + .build() +} + +/* Copied from Apache Spark's to avoid dependency on Spark Internals */ +class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream { + override def consume(): Unit = wrapped.consume + override def getSourceName(): String = wrapped.getSourceName + override def index(): Int = wrapped.index + override def mark(): Int = wrapped.mark + override def release(marker: Int): Unit = wrapped.release(marker) + override def seek(where: Int): Unit = wrapped.seek(where) + override def size(): Int = wrapped.size + + override def getText(interval: Interval): String = wrapped.getText(interval) + + // scalastyle:off + override def LA(i: Int): Int = { + val la = wrapped.LA(i) + if (la == 0 || la == IntStream.EOF) la + else Character.toUpperCase(la) + } + // scalastyle:on +} + +/** + * The post-processor validates & cleans-up the parse tree during the parse process. + */ +case object IcebergSqlExtensionsPostProcessor extends IcebergSqlExtensionsBaseListener { + + /** Remove the back ticks from an Identifier. 
*/ + override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = { + replaceTokenByIdentifier(ctx, 1) { token => + // Remove the double back ticks in the string. + token.setText(token.getText.replace("``", "`")) + token + } + } + + /** Treat non-reserved keywords as Identifiers. */ + override def exitNonReserved(ctx: NonReservedContext): Unit = { + replaceTokenByIdentifier(ctx, 0)(identity) + } + + private def replaceTokenByIdentifier( + ctx: ParserRuleContext, + stripMargins: Int)( + f: CommonToken => CommonToken = identity): Unit = { + val parent = ctx.getParent + parent.removeLastChild() + val token = ctx.getChild(0).getPayload.asInstanceOf[Token] + val newToken = new CommonToken( + new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream), + IcebergSqlExtensionsParser.IDENTIFIER, + token.getChannel, + token.getStartIndex + stripMargins, + token.getStopIndex - stripMargins) + parent.addChild(new TerminalNodeImpl(f(newToken))) + } +} + +/* Partially copied from Apache Spark's Parser to avoid dependency on Spark Internals */ +case object IcebergParseErrorListener extends BaseErrorListener { + override def syntaxError( + recognizer: Recognizer[_, _], + offendingSymbol: scala.Any, + line: Int, + charPositionInLine: Int, + msg: String, + e: RecognitionException): Unit = { + val (start, stop) = offendingSymbol match { + case token: CommonToken => + val start = Origin(Some(line), Some(token.getCharPositionInLine)) + val length = token.getStopIndex - token.getStartIndex + 1 + val stop = Origin(Some(line), Some(token.getCharPositionInLine + length)) + (start, stop) + case _ => + val start = Origin(Some(line), Some(charPositionInLine)) + (start, start) + } + throw new IcebergParseException(None, msg, start, stop) + } +} + +/** + * Copied from Apache Spark + * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It + * contains fields and an extended error message that make reporting and diagnosing errors easier. 
+ */ +class IcebergParseException( + val command: Option[String], + message: String, + val start: Origin, + val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) { + + def this(message: String, ctx: ParserRuleContext) = { + this(Option(IcebergParserUtils.command(ctx)), + message, + IcebergParserUtils.position(ctx.getStart), + IcebergParserUtils.position(ctx.getStop)) + } + + override def getMessage: String = { + val builder = new StringBuilder + builder ++= "\n" ++= message + start match { + case Origin( + Some(l), Some(p), Some(startIndex), Some(stopIndex), Some(sqlText), Some(objectType), Some(objectName)) => + builder ++= s"(line $l, pos $p)\n" + command.foreach { cmd => + val (above, below) = cmd.split("\n").splitAt(l) + builder ++= "\n== SQL ==\n" + above.foreach(builder ++= _ += '\n') + builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n" + below.foreach(builder ++= _ += '\n') + } + case _ => + command.foreach { cmd => + builder ++= "\n== SQL ==\n" ++= cmd + } + } + builder.toString + } + + def withCommand(cmd: String): IcebergParseException = { + new IcebergParseException(Option(cmd), message, start, stop) + } +} \ No newline at end of file diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala new file mode 100644 index 000000000000..f758cb08fd3d --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala @@ -0,0 +1,381 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
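Editorial aside: a hedged sketch of how this parser typically ends up on a session, using the documented Iceberg extensions and catalog settings (the catalog name, warehouse path, and table are assumptions; an iceberg-spark-runtime jar must be on the classpath):

import org.apache.spark.sql.SparkSession

object ExtensionsSessionSketch extends App {
  val spark = SparkSession.builder()
    .appName("iceberg-extensions-sketch")
    .master("local[*]")
    .config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
    .config("spark.sql.catalog.demo", "org.apache.iceberg.spark.SparkCatalog")
    .config("spark.sql.catalog.demo.type", "hadoop")
    .config("spark.sql.catalog.demo.warehouse", "/tmp/iceberg-warehouse")
    .getOrCreate()

  // routed through isIcebergCommand/parsePlan above; assumes demo.db.events already exists
  spark.sql("ALTER TABLE demo.db.events WRITE ORDERED BY ts")
  spark.stop()
}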
+ */ + +package org.apache.spark.sql.catalyst.parser.extensions + +import java.util.Locale +import java.util.concurrent.TimeUnit +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.misc.Interval +import org.antlr.v4.runtime.tree.ParseTree +import org.antlr.v4.runtime.tree.TerminalNode +import org.apache.iceberg.DistributionMode +import org.apache.iceberg.NullOrder +import org.apache.iceberg.SortDirection +import org.apache.iceberg.expressions.Term +import org.apache.iceberg.spark.Spark3Util +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.parser.extensions.IcebergParserUtils.withOrigin +import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser._ +import org.apache.spark.sql.catalyst.plans.logical.AddPartitionField +import org.apache.spark.sql.catalyst.plans.logical.BranchOptions +import org.apache.spark.sql.catalyst.plans.logical.CallArgument +import org.apache.spark.sql.catalyst.plans.logical.CallStatement +import org.apache.spark.sql.catalyst.plans.logical.CreateOrReplaceBranch +import org.apache.spark.sql.catalyst.plans.logical.CreateOrReplaceTag +import org.apache.spark.sql.catalyst.plans.logical.DropBranch +import org.apache.spark.sql.catalyst.plans.logical.DropIdentifierFields +import org.apache.spark.sql.catalyst.plans.logical.DropPartitionField +import org.apache.spark.sql.catalyst.plans.logical.DropTag +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.NamedArgument +import org.apache.spark.sql.catalyst.plans.logical.PositionalArgument +import org.apache.spark.sql.catalyst.plans.logical.ReplacePartitionField +import org.apache.spark.sql.catalyst.plans.logical.SetIdentifierFields +import org.apache.spark.sql.catalyst.plans.logical.SetWriteDistributionAndOrdering +import org.apache.spark.sql.catalyst.plans.logical.TagOptions +import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.connector.expressions +import org.apache.spark.sql.connector.expressions.ApplyTransform +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.IdentityTransform +import org.apache.spark.sql.connector.expressions.LiteralValue +import org.apache.spark.sql.connector.expressions.Transform +import scala.jdk.CollectionConverters._ + +class IcebergSqlExtensionsAstBuilder(delegate: ParserInterface) extends IcebergSqlExtensionsBaseVisitor[AnyRef] { + + private def toBuffer[T](list: java.util.List[T]): scala.collection.mutable.Buffer[T] = list.asScala + private def toSeq[T](list: java.util.List[T]): Seq[T] = toBuffer(list).toSeq + + /** + * Create a [[CallStatement]] for a stored procedure call. + */ + override def visitCall(ctx: CallContext): CallStatement = withOrigin(ctx) { + val name = toSeq(ctx.multipartIdentifier.parts).map(_.getText) + val args = toSeq(ctx.callArgument).map(typedVisit[CallArgument]) + CallStatement(name, args) + } + + /** + * Create an ADD PARTITION FIELD logical command. 
+   */
+  override def visitAddPartitionField(ctx: AddPartitionFieldContext): AddPartitionField = withOrigin(ctx) {
+    AddPartitionField(
+      typedVisit[Seq[String]](ctx.multipartIdentifier),
+      typedVisit[Transform](ctx.transform),
+      Option(ctx.name).map(_.getText))
+  }
+
+  /**
+   * Create a DROP PARTITION FIELD logical command.
+   */
+  override def visitDropPartitionField(ctx: DropPartitionFieldContext): DropPartitionField = withOrigin(ctx) {
+    DropPartitionField(
+      typedVisit[Seq[String]](ctx.multipartIdentifier),
+      typedVisit[Transform](ctx.transform))
+  }
+
+  /**
+   * Create a CREATE OR REPLACE BRANCH logical command.
+   */
+  override def visitCreateOrReplaceBranch(ctx: CreateOrReplaceBranchContext): CreateOrReplaceBranch = withOrigin(ctx) {
+    val createOrReplaceBranchClause = ctx.createReplaceBranchClause()
+
+    val branchName = createOrReplaceBranchClause.identifier()
+    val branchOptionsContext = Option(createOrReplaceBranchClause.branchOptions())
+    val snapshotId = branchOptionsContext.flatMap(branchOptions => Option(branchOptions.snapshotId()))
+      .map(_.getText.toLong)
+    val snapshotRetention = branchOptionsContext.flatMap(branchOptions => Option(branchOptions.snapshotRetention()))
+    val minSnapshotsToKeep = snapshotRetention.flatMap(retention => Option(retention.minSnapshotsToKeep()))
+      .map(minSnapshots => minSnapshots.number().getText.toLong)
+    val maxSnapshotAgeMs = snapshotRetention
+      .flatMap(retention => Option(retention.maxSnapshotAge()))
+      .map(retention => TimeUnit.valueOf(retention.timeUnit().getText.toUpperCase(Locale.ENGLISH))
+        .toMillis(retention.number().getText.toLong))
+    val branchRetention = branchOptionsContext.flatMap(branchOptions => Option(branchOptions.refRetain()))
+    val branchRefAgeMs = branchRetention.map(retain =>
+      TimeUnit.valueOf(retain.timeUnit().getText.toUpperCase(Locale.ENGLISH)).toMillis(retain.number().getText.toLong))
+    val replace = ctx.createReplaceBranchClause().REPLACE() != null
+    val ifNotExists = createOrReplaceBranchClause.EXISTS() != null
+
+    val branchOptions = BranchOptions(
+      snapshotId,
+      minSnapshotsToKeep,
+      maxSnapshotAgeMs,
+      branchRefAgeMs
+    )
+
+    CreateOrReplaceBranch(
+      typedVisit[Seq[String]](ctx.multipartIdentifier),
+      branchName.getText,
+      branchOptions,
+      replace,
+      ifNotExists)
+  }
+
+  /**
+   * Create a CREATE OR REPLACE TAG logical command.
+   */
+  override def visitCreateOrReplaceTag(ctx: CreateOrReplaceTagContext): CreateOrReplaceTag = withOrigin(ctx) {
+    val createTagClause = ctx.createReplaceTagClause()
+
+    val tagName = createTagClause.identifier().getText
+
+    val tagOptionsContext = Option(createTagClause.tagOptions())
+    val snapshotId = tagOptionsContext.flatMap(tagOptions => Option(tagOptions.snapshotId()))
+      .map(_.getText.toLong)
+    val tagRetain = tagOptionsContext.flatMap(tagOptions => Option(tagOptions.refRetain()))
+    val tagRefAgeMs = tagRetain.map(retain =>
+      TimeUnit.valueOf(retain.timeUnit().getText.toUpperCase(Locale.ENGLISH)).toMillis(retain.number().getText.toLong))
+    val tagOptions = TagOptions(
+      snapshotId,
+      tagRefAgeMs
+    )
+
+    val replace = createTagClause.REPLACE() != null
+    val ifNotExists = createTagClause.EXISTS() != null
+
+    CreateOrReplaceTag(typedVisit[Seq[String]](ctx.multipartIdentifier),
+      tagName,
+      tagOptions,
+      replace,
+      ifNotExists)
+  }
+
+  /**
+   * Create a DROP BRANCH logical command.
+   */
+  override def visitDropBranch(ctx: DropBranchContext): DropBranch = withOrigin(ctx) {
+    DropBranch(typedVisit[Seq[String]](ctx.multipartIdentifier), ctx.identifier().getText, ctx.EXISTS() != null)
+  }
+
+  /**
+   * Create a DROP TAG logical command.
+   */
+  override def visitDropTag(ctx: DropTagContext): DropTag = withOrigin(ctx) {
+    DropTag(typedVisit[Seq[String]](ctx.multipartIdentifier), ctx.identifier().getText, ctx.EXISTS() != null)
+  }
+
+  /**
+   * Create a REPLACE PARTITION FIELD logical command.
+   */
+  override def visitReplacePartitionField(ctx: ReplacePartitionFieldContext): ReplacePartitionField = withOrigin(ctx) {
+    ReplacePartitionField(
+      typedVisit[Seq[String]](ctx.multipartIdentifier),
+      typedVisit[Transform](ctx.transform(0)),
+      typedVisit[Transform](ctx.transform(1)),
+      Option(ctx.name).map(_.getText))
+  }
+
+  /**
+   * Create a SET IDENTIFIER FIELDS logical command.
+   */
+  override def visitSetIdentifierFields(ctx: SetIdentifierFieldsContext): SetIdentifierFields = withOrigin(ctx) {
+    SetIdentifierFields(
+      typedVisit[Seq[String]](ctx.multipartIdentifier),
+      toSeq(ctx.fieldList.fields).map(_.getText))
+  }
+
+  /**
+   * Create a DROP IDENTIFIER FIELDS logical command.
+   */
+  override def visitDropIdentifierFields(ctx: DropIdentifierFieldsContext): DropIdentifierFields = withOrigin(ctx) {
+    DropIdentifierFields(
+      typedVisit[Seq[String]](ctx.multipartIdentifier),
+      toSeq(ctx.fieldList.fields).map(_.getText))
+  }
+
+  /**
+   * Create a [[SetWriteDistributionAndOrdering]] for changing the write distribution and ordering.
+   */
+  override def visitSetWriteDistributionAndOrdering(
+      ctx: SetWriteDistributionAndOrderingContext): SetWriteDistributionAndOrdering = {
+
+    val tableName = typedVisit[Seq[String]](ctx.multipartIdentifier)
+
+    val (distributionSpec, orderingSpec) = toDistributionAndOrderingSpec(ctx.writeSpec)
+
+    if (distributionSpec == null && orderingSpec == null) {
+      throw new AnalysisException(
+        "ALTER TABLE has no changes: missing both distribution and ordering clauses")
+    }
+
+    val distributionMode = if (distributionSpec != null) {
+      DistributionMode.HASH
+    } else if (orderingSpec.UNORDERED != null || orderingSpec.LOCALLY != null) {
+      DistributionMode.NONE
+    } else {
+      DistributionMode.RANGE
+    }
+
+    val ordering = if (orderingSpec != null && orderingSpec.order != null) {
+      toSeq(orderingSpec.order.fields).map(typedVisit[(Term, SortDirection, NullOrder)])
+    } else {
+      Seq.empty
+    }
+
+    SetWriteDistributionAndOrdering(tableName, distributionMode, ordering)
+  }
+
+  private def toDistributionAndOrderingSpec(
+      writeSpec: WriteSpecContext): (WriteDistributionSpecContext, WriteOrderingSpecContext) = {
+
+    if (writeSpec.writeDistributionSpec.size > 1) {
+      throw new AnalysisException("ALTER TABLE contains multiple distribution clauses")
+    }
+
+    if (writeSpec.writeOrderingSpec.size > 1) {
+      throw new AnalysisException("ALTER TABLE contains multiple ordering clauses")
+    }
+
+    val distributionSpec = toBuffer(writeSpec.writeDistributionSpec).headOption.orNull
+    val orderingSpec = toBuffer(writeSpec.writeOrderingSpec).headOption.orNull
+
+    (distributionSpec, orderingSpec)
+  }
+
+  /**
+   * Create an order field.
+ */ + override def visitOrderField(ctx: OrderFieldContext): (Term, SortDirection, NullOrder) = { + val term = Spark3Util.toIcebergTerm(typedVisit[Transform](ctx.transform)) + val direction = Option(ctx.ASC).map(_ => SortDirection.ASC) + .orElse(Option(ctx.DESC).map(_ => SortDirection.DESC)) + .getOrElse(SortDirection.ASC) + val nullOrder = Option(ctx.FIRST).map(_ => NullOrder.NULLS_FIRST) + .orElse(Option(ctx.LAST).map(_ => NullOrder.NULLS_LAST)) + .getOrElse(if (direction == SortDirection.ASC) NullOrder.NULLS_FIRST else NullOrder.NULLS_LAST) + (term, direction, nullOrder) + } + + /** + * Create an IdentityTransform for a column reference. + */ + override def visitIdentityTransform(ctx: IdentityTransformContext): Transform = withOrigin(ctx) { + IdentityTransform(FieldReference(typedVisit[Seq[String]](ctx.multipartIdentifier()))) + } + + /** + * Create a named Transform from argument expressions. + */ + override def visitApplyTransform(ctx: ApplyTransformContext): Transform = withOrigin(ctx) { + val args = toSeq(ctx.arguments).map(typedVisit[expressions.Expression]) + ApplyTransform(ctx.transformName.getText, args) + } + + /** + * Create a transform argument from a column reference or a constant. + */ + override def visitTransformArgument(ctx: TransformArgumentContext): expressions.Expression = withOrigin(ctx) { + val reference = Option(ctx.multipartIdentifier()) + .map(typedVisit[Seq[String]]) + .map(FieldReference(_)) + val literal = Option(ctx.constant) + .map(visitConstant) + .map(lit => LiteralValue(lit.value, lit.dataType)) + reference.orElse(literal) + .getOrElse(throw new IcebergParseException(s"Invalid transform argument", ctx)) + } + + /** + * Return a multi-part identifier as Seq[String]. + */ + override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] = withOrigin(ctx) { + toSeq(ctx.parts).map(_.getText) + } + + override def visitSingleOrder(ctx: SingleOrderContext): Seq[(Term, SortDirection, NullOrder)] = withOrigin(ctx) { + toSeq(ctx.order.fields).map(typedVisit[(Term, SortDirection, NullOrder)]) + } + + /** + * Create a positional argument in a stored procedure call. + */ + override def visitPositionalArgument(ctx: PositionalArgumentContext): CallArgument = withOrigin(ctx) { + val expr = typedVisit[Expression](ctx.expression) + PositionalArgument(expr) + } + + /** + * Create a named argument in a stored procedure call. 
+ */ + override def visitNamedArgument(ctx: NamedArgumentContext): CallArgument = withOrigin(ctx) { + val name = ctx.identifier.getText + val expr = typedVisit[Expression](ctx.expression) + NamedArgument(name, expr) + } + + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { + visit(ctx.statement).asInstanceOf[LogicalPlan] + } + + def visitConstant(ctx: ConstantContext): Literal = { + delegate.parseExpression(ctx.getText).asInstanceOf[Literal] + } + + override def visitExpression(ctx: ExpressionContext): Expression = { + // reconstruct the SQL string and parse it using the main Spark parser + // while we can avoid the logic to build Spark expressions, we still have to parse them + // we cannot call ctx.getText directly since it will not render spaces correctly + // that's why we need to recurse down the tree in reconstructSqlString + val sqlString = reconstructSqlString(ctx) + delegate.parseExpression(sqlString) + } + + private def reconstructSqlString(ctx: ParserRuleContext): String = { + toBuffer(ctx.children).map { + case c: ParserRuleContext => reconstructSqlString(c) + case t: TerminalNode => t.getText + }.mkString(" ") + } + + private def typedVisit[T](ctx: ParseTree): T = { + ctx.accept(this).asInstanceOf[T] + } +} + +/* Partially copied from Apache Spark's Parser to avoid dependency on Spark Internals */ +object IcebergParserUtils { + + private[sql] def withOrigin[T](ctx: ParserRuleContext)(f: => T): T = { + val current = CurrentOrigin.get + CurrentOrigin.set(position(ctx.getStart)) + try { + f + } finally { + CurrentOrigin.set(current) + } + } + + private[sql] def position(token: Token): Origin = { + val opt = Option(token) + Origin(opt.map(_.getLine), opt.map(_.getCharPositionInLine)) + } + + /** Get the command which created the token. */ + private[sql] def command(ctx: ParserRuleContext): String = { + val stream = ctx.getStart.getInputStream + stream.getText(Interval.of(0, stream.size() - 1)) + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/planning/RewrittenRowLevelCommand.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/planning/RewrittenRowLevelCommand.scala new file mode 100644 index 000000000000..b6b39fd77eab --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/planning/RewrittenRowLevelCommand.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
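Editorial aside: a few statements the AST builder above maps to the listed logical commands, reusing the SparkSession from the earlier sketch (catalog, table, and branch names are assumptions):

// CallStatement via visitCall
spark.sql("CALL demo.system.expire_snapshots(table => 'db.events', older_than => TIMESTAMP '2023-01-01 00:00:00')")
// AddPartitionField via visitAddPartitionField
spark.sql("ALTER TABLE demo.db.events ADD PARTITION FIELD bucket(16, id) AS id_bucket")
// CreateOrReplaceBranch via visitCreateOrReplaceBranch
spark.sql("ALTER TABLE demo.db.events CREATE BRANCH hotfix RETAIN 7 DAYS")
// SetWriteDistributionAndOrdering via visitSetWriteDistributionAndOrdering
spark.sql("ALTER TABLE demo.db.events WRITE DISTRIBUTED BY PARTITION LOCALLY ORDERED BY ts")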
+ */ + +package org.apache.spark.sql.catalyst.planning + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.ReplaceIcebergData +import org.apache.spark.sql.catalyst.plans.logical.RowLevelCommand +import org.apache.spark.sql.catalyst.plans.logical.UpdateIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.WriteIcebergDelta +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation + +/** + * An extractor for operations such as DELETE and MERGE that require rewriting data. + * + * This class extracts the following entities: + * - the row-level command (such as DeleteFromIcebergTable); + * - the read relation in the rewrite plan that can be either DataSourceV2Relation or + * DataSourceV2ScanRelation depending on whether the planning has already happened; + * - the current rewrite plan. + */ +object RewrittenRowLevelCommand { + type ReturnType = (RowLevelCommand, LogicalPlan, LogicalPlan) + + def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { + case c: RowLevelCommand if c.rewritePlan.nonEmpty => + val rewritePlan = c.rewritePlan.get + + // both ReplaceData and WriteDelta reference a write relation + // but the corresponding read relation should be at the bottom of the write plan + // both the write and read relations will share the same RowLevelOperationTable object + // that's why it is safe to use reference equality to find the needed read relation + + val allowScanDuplication = c match { + // group-based updates that rely on the union approach may have multiple identical scans + case _: UpdateIcebergTable if rewritePlan.isInstanceOf[ReplaceIcebergData] => true + case _ => false + } + + rewritePlan match { + case rd @ ReplaceIcebergData(DataSourceV2Relation(table, _, _, _, _), query, _, _) => + val readRelation = findReadRelation(table, query, allowScanDuplication) + readRelation.map((c, _, rd)) + case wd @ WriteIcebergDelta(DataSourceV2Relation(table, _, _, _, _), query, _, _, _) => + val readRelation = findReadRelation(table, query, allowScanDuplication) + readRelation.map((c, _, wd)) + case _ => + None + } + + case _ => + None + } + + private def findReadRelation( + table: Table, + plan: LogicalPlan, + allowScanDuplication: Boolean): Option[LogicalPlan] = { + + val readRelations = plan.collect { + case r: DataSourceV2Relation if r.table eq table => r + case r: DataSourceV2ScanRelation if r.relation.table eq table => r + } + + // in some cases, the optimizer replaces the v2 read relation with a local relation + // for example, there is no reason to query the table if the condition is always false + // that's why it is valid not to find the corresponding v2 read relation + + readRelations match { + case relations if relations.isEmpty => + None + + case Seq(relation) => + Some(relation) + + case Seq(relation1: DataSourceV2Relation, relation2: DataSourceV2Relation) + if allowScanDuplication && (relation1.table eq relation2.table) => + Some(relation1) + + case Seq(relation1: DataSourceV2ScanRelation, relation2: DataSourceV2ScanRelation) + if allowScanDuplication && (relation1.scan eq relation2.scan) => + Some(relation1) + + case Seq(relation1, relation2) if allowScanDuplication => + throw new AnalysisException(s"Row-level read relations don't match: $relation1, $relation2") + + case relations if 
allowScanDuplication => + throw new AnalysisException(s"Expected up to two row-level read relations: $relations") + + case relations => + throw new AnalysisException(s"Expected only one row-level read relation: $relations") + } + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AddPartitionField.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AddPartitionField.scala new file mode 100644 index 000000000000..e8b1b2941161 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AddPartitionField.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.expressions.Transform + +case class AddPartitionField(table: Seq[String], transform: Transform, name: Option[String]) extends LeafCommand { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"AddPartitionField ${table.quoted} ${name.map(n => s"$n=").getOrElse("")}${transform.describe}" + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BranchOptions.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BranchOptions.scala new file mode 100644 index 000000000000..4d7e0a086bda --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BranchOptions.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
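Editorial aside: a hedged sketch of how a custom rule might consume the RewrittenRowLevelCommand extractor above; the rule itself is hypothetical and only logs what it matches:

import org.apache.spark.sql.catalyst.planning.RewrittenRowLevelCommand
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

object LogRewrittenCommands extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = {
    plan.foreach {
      case RewrittenRowLevelCommand(command, readRelation, rewritePlan) =>
        logInfo(s"${command.nodeName} reads via ${readRelation.nodeName} and rewrites via ${rewritePlan.nodeName}")
      case _ => // not a rewritten row-level command
    }
    plan // inspection only, the plan is returned unchanged
  }
}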
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +case class BranchOptions (snapshotId: Option[Long], numSnapshots: Option[Long], + snapshotRetain: Option[Long], snapshotRefRetain: Option[Long]) diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Call.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Call.scala new file mode 100644 index 000000000000..9616dae5a8d3 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Call.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.connector.iceberg.catalog.Procedure + +case class Call(procedure: Procedure, args: Seq[Expression]) extends LeafCommand { + override lazy val output: Seq[Attribute] = procedure.outputType.toAttributes + + override def simpleString(maxFields: Int): String = { + s"Call${truncatedString(output.toSeq, "[", ", ", "]", maxFields)} ${procedure.description}" + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateOrReplaceBranch.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateOrReplaceBranch.scala new file mode 100644 index 000000000000..2a22484499cf --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateOrReplaceBranch.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class CreateOrReplaceBranch( + table: Seq[String], + branch: String, + branchOptions: BranchOptions, + replace: Boolean, + ifNotExists: Boolean) extends LeafCommand { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"CreateOrReplaceBranch branch: ${branch} for table: ${table.quoted}" + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateOrReplaceTag.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateOrReplaceTag.scala new file mode 100644 index 000000000000..e48f7d8ed04c --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateOrReplaceTag.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class CreateOrReplaceTag( + table: Seq[String], + tag: String, + tagOptions: TagOptions, + replace: Boolean, + ifNotExists: Boolean) extends LeafCommand { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"CreateOrReplaceTag tag: ${tag} for table: ${table.quoted}" + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DeleteFromIcebergTable.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DeleteFromIcebergTable.scala new file mode 100644 index 000000000000..d1268e416c50 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DeleteFromIcebergTable.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Expression + +case class DeleteFromIcebergTable( + table: LogicalPlan, + condition: Option[Expression], + rewritePlan: Option[LogicalPlan] = None) extends RowLevelCommand { + + override def children: Seq[LogicalPlan] = if (rewritePlan.isDefined) { + table :: rewritePlan.get :: Nil + } else { + table :: Nil + } + + override def withNewRewritePlan(newRewritePlan: LogicalPlan): RowLevelCommand = { + copy(rewritePlan = Some(newRewritePlan)) + } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): DeleteFromIcebergTable = { + if (newChildren.size == 1) { + copy(table = newChildren.head, rewritePlan = None) + } else { + require(newChildren.size == 2, "DeleteFromIcebergTable expects either one or two children") + val Seq(newTable, newRewritePlan) = newChildren.take(2) + copy(table = newTable, rewritePlan = Some(newRewritePlan)) + } + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropBranch.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropBranch.scala new file mode 100644 index 000000000000..bee0b0fae688 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropBranch.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class DropBranch(table: Seq[String], branch: String, ifExists: Boolean) extends LeafCommand { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"DropBranch branch: ${branch} for table: ${table.quoted}" + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropIdentifierFields.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropIdentifierFields.scala new file mode 100644 index 000000000000..29dd686a0fba --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropIdentifierFields.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class DropIdentifierFields( + table: Seq[String], + fields: Seq[String]) extends LeafCommand { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"DropIdentifierFields ${table.quoted} (${fields.quoted})" + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropPartitionField.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropPartitionField.scala new file mode 100644 index 000000000000..fb1451324182 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropPartitionField.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.expressions.Transform + +case class DropPartitionField(table: Seq[String], transform: Transform) extends LeafCommand { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"DropPartitionField ${table.quoted} ${transform.describe}" + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropTag.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropTag.scala new file mode 100644 index 000000000000..7e4b38e74d2f --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropTag.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class DropTag(table: Seq[String], tag: String, ifExists: Boolean) extends LeafCommand { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"DropTag tag: ${tag} for table: ${table.quoted}" + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeIntoIcebergTable.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeIntoIcebergTable.scala new file mode 100644 index 000000000000..8f84851dcda2 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeIntoIcebergTable.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.AssignmentUtils +import org.apache.spark.sql.catalyst.expressions.Expression + +case class MergeIntoIcebergTable( + targetTable: LogicalPlan, + sourceTable: LogicalPlan, + mergeCondition: Expression, + matchedActions: Seq[MergeAction], + notMatchedActions: Seq[MergeAction], + rewritePlan: Option[LogicalPlan] = None) extends RowLevelCommand { + + lazy val aligned: Boolean = { + val matchedActionsAligned = matchedActions.forall { + case UpdateAction(_, assignments) => + AssignmentUtils.aligned(targetTable, assignments) + case _: DeleteAction => + true + case _ => + false + } + + val notMatchedActionsAligned = notMatchedActions.forall { + case InsertAction(_, assignments) => + AssignmentUtils.aligned(targetTable, assignments) + case _ => + false + } + + matchedActionsAligned && notMatchedActionsAligned + } + + def condition: Option[Expression] = Some(mergeCondition) + + override def children: Seq[LogicalPlan] = if (rewritePlan.isDefined) { + targetTable :: sourceTable :: rewritePlan.get :: Nil + } else { + targetTable :: sourceTable :: Nil + } + + override def withNewRewritePlan(newRewritePlan: LogicalPlan): RowLevelCommand = { + copy(rewritePlan = Some(newRewritePlan)) + } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): MergeIntoIcebergTable = { + + newChildren match { + case Seq(newTarget, newSource) => + copy(targetTable = newTarget, sourceTable = newSource, rewritePlan = None) + case Seq(newTarget, newSource, newRewritePlan) => + copy(targetTable = newTarget, sourceTable = newSource, rewritePlan = Some(newRewritePlan)) + case _ => + throw new IllegalArgumentException("MergeIntoIcebergTable expects either two or three children") + } + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala new file mode 100644 index 000000000000..3607194fe8c8 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.util.truncatedString + +case class MergeRows( + isSourceRowPresent: Expression, + isTargetRowPresent: Expression, + matchedConditions: Seq[Expression], + matchedOutputs: Seq[Seq[Expression]], + notMatchedConditions: Seq[Expression], + notMatchedOutputs: Seq[Seq[Expression]], + targetOutput: Seq[Expression], + rowIdAttrs: Seq[Attribute], + performCardinalityCheck: Boolean, + emitNotMatchedTargetRows: Boolean, + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + + require(targetOutput.nonEmpty || !emitNotMatchedTargetRows) + + override lazy val producedAttributes: AttributeSet = { + AttributeSet(output.filterNot(attr => inputSet.contains(attr))) + } + + override lazy val references: AttributeSet = child.outputSet + + override def simpleString(maxFields: Int): String = { + s"MergeRows${truncatedString(output, "[", ", ", "]", maxFields)}" + } + + override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = { + copy(child = newChild) + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/NoStatsUnaryNode.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/NoStatsUnaryNode.scala new file mode 100644 index 000000000000..c21df71f069d --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/NoStatsUnaryNode.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class NoStatsUnaryNode(child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output + override def stats: Statistics = Statistics(Long.MaxValue) + + override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = { + copy(child = newChild) + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplaceIcebergData.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplaceIcebergData.scala new file mode 100644 index 000000000000..2b741bef121a --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplaceIcebergData.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.analysis.NamedRelation +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.write.Write +import org.apache.spark.sql.types.DataType + +/** + * Replace data in an existing table. + */ +case class ReplaceIcebergData( + table: NamedRelation, + query: LogicalPlan, + originalTable: NamedRelation, + write: Option[Write] = None) extends V2WriteCommandLike { + + override lazy val references: AttributeSet = query.outputSet + override lazy val stringArgs: Iterator[Any] = Iterator(table, query, write) + + // the incoming query may include metadata columns + lazy val dataInput: Seq[Attribute] = { + val tableAttrNames = table.output.map(_.name) + query.output.filter(attr => tableAttrNames.exists(conf.resolver(_, attr.name))) + } + + override def outputResolved: Boolean = { + assert(table.resolved && query.resolved, + "`outputResolved` can only be called when `table` and `query` are both resolved.") + + // take into account only incoming data columns and ignore metadata columns in the query + // they will be discarded after the logical write is built in the optimizer + // metadata columns may be needed to request a correct distribution or ordering + // but are not passed back to the data source during writes + + table.skipSchemaResolution || (dataInput.size == table.output.size && + dataInput.zip(table.output).forall { case (inAttr, outAttr) => + val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType) + // names and types must match, nullability must be compatible + inAttr.name == outAttr.name && + DataType.equalsIgnoreCompatibleNullability(inAttr.dataType, outType) && + (outAttr.nullable || !inAttr.nullable) + }) + } + + override protected def withNewChildInternal(newChild: LogicalPlan): ReplaceIcebergData = { + copy(query = newChild) + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplacePartitionField.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplacePartitionField.scala new file mode 100644 index 000000000000..8c660c6f37b1 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplacePartitionField.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.expressions.Transform + +case class ReplacePartitionField( + table: Seq[String], + transformFrom: Transform, + transformTo: Transform, + name: Option[String]) extends LeafCommand { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"ReplacePartitionField ${table.quoted} ${transformFrom.describe} " + + s"with ${name.map(n => s"$n=").getOrElse("")}${transformTo.describe}" + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/RowLevelCommand.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/RowLevelCommand.scala new file mode 100644 index 000000000000..837ee963bcea --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/RowLevelCommand.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Expression + +trait RowLevelCommand extends Command with SupportsSubquery { + def condition: Option[Expression] + def rewritePlan: Option[LogicalPlan] + def withNewRewritePlan(newRewritePlan: LogicalPlan): RowLevelCommand +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetIdentifierFields.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetIdentifierFields.scala new file mode 100644 index 000000000000..a5fa28a617e7 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetIdentifierFields.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.expressions.Transform + +case class SetIdentifierFields( + table: Seq[String], + fields: Seq[String]) extends LeafCommand { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + s"SetIdentifierFields ${table.quoted} (${fields.quoted})" + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TagOptions.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TagOptions.scala new file mode 100644 index 000000000000..85e3b95f4aba --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TagOptions.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +case class TagOptions(snapshotId: Option[Long], snapshotRefRetain: Option[Long]) diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/UnresolvedMergeIntoIcebergTable.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/UnresolvedMergeIntoIcebergTable.scala new file mode 100644 index 000000000000..895aa733ff20 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/UnresolvedMergeIntoIcebergTable.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Expression + +/** + * A node that hides the MERGE condition and actions from regular Spark resolution. + */ +case class UnresolvedMergeIntoIcebergTable( + targetTable: LogicalPlan, + sourceTable: LogicalPlan, + context: MergeIntoContext) extends BinaryCommand { + + def duplicateResolved: Boolean = targetTable.outputSet.intersect(sourceTable.outputSet).isEmpty + + override def left: LogicalPlan = targetTable + override def right: LogicalPlan = sourceTable + + override protected def withNewChildrenInternal(newLeft: LogicalPlan, newRight: LogicalPlan): LogicalPlan = { + copy(targetTable = newLeft, sourceTable = newRight) + } +} + +case class MergeIntoContext( + mergeCondition: Expression, + matchedActions: Seq[MergeAction], + notMatchedActions: Seq[MergeAction]) diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/UpdateIcebergTable.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/UpdateIcebergTable.scala new file mode 100644 index 000000000000..790eb9380e3d --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/UpdateIcebergTable.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.AssignmentUtils +import org.apache.spark.sql.catalyst.expressions.Expression + +case class UpdateIcebergTable( + table: LogicalPlan, + assignments: Seq[Assignment], + condition: Option[Expression], + rewritePlan: Option[LogicalPlan] = None) extends RowLevelCommand { + + lazy val aligned: Boolean = AssignmentUtils.aligned(table, assignments) + + override def children: Seq[LogicalPlan] = if (rewritePlan.isDefined) { + table :: rewritePlan.get :: Nil + } else { + table :: Nil + } + + override def withNewRewritePlan(newRewritePlan: LogicalPlan): RowLevelCommand = { + copy(rewritePlan = Some(newRewritePlan)) + } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): UpdateIcebergTable = { + if (newChildren.size == 1) { + copy(table = newChildren.head, rewritePlan = None) + } else { + require(newChildren.size == 2, "UpdateTable expects either one or two children") + val Seq(newTable, newRewritePlan) = newChildren.take(2) + copy(table = newTable, rewritePlan = Some(newRewritePlan)) + } + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/V2WriteCommandLike.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/V2WriteCommandLike.scala new file mode 100644 index 000000000000..9192d74b7caf --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/V2WriteCommandLike.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.analysis.NamedRelation +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeSet + +// a node similar to V2WriteCommand in Spark but does not extend Command +// as ReplaceData and WriteDelta that extend this trait are nested within other commands +trait V2WriteCommandLike extends UnaryNode { + def table: NamedRelation + def query: LogicalPlan + def outputResolved: Boolean + + override lazy val resolved: Boolean = table.resolved && query.resolved && outputResolved + + override def child: LogicalPlan = query + override def output: Seq[Attribute] = Seq.empty + override def producedAttributes: AttributeSet = outputSet + // Commands are eagerly executed. They will be converted to LocalRelation after the DataFrame + // is created. That said, the statistics of a command is useless. Here we just return a dummy + // statistics to avoid unnecessary statistics calculation of command's children. 
+ override def stats: Statistics = Statistics.DUMMY +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/WriteIcebergDelta.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/WriteIcebergDelta.scala new file mode 100644 index 000000000000..10db698b9b91 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/WriteIcebergDelta.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.analysis.NamedRelation +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.ExtendedV2ExpressionUtils +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.catalyst.util.RowDeltaUtils.OPERATION_COLUMN +import org.apache.spark.sql.catalyst.util.WriteDeltaProjections +import org.apache.spark.sql.connector.write.DeltaWrite +import org.apache.spark.sql.connector.write.RowLevelOperationTable +import org.apache.spark.sql.connector.write.SupportsDelta +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.StructField + +/** + * Writes a delta of rows to an existing table. 
+ */
+case class WriteIcebergDelta(
+    table: NamedRelation,
+    query: LogicalPlan,
+    originalTable: NamedRelation,
+    projections: WriteDeltaProjections,
+    write: Option[DeltaWrite] = None) extends V2WriteCommandLike {
+
+  override protected lazy val stringArgs: Iterator[Any] = Iterator(table, query, write)
+
+  private def operationResolved: Boolean = {
+    val attr = query.output.head
+    attr.name == OPERATION_COLUMN && attr.dataType == IntegerType && !attr.nullable
+  }
+
+  private def operation: SupportsDelta = {
+    EliminateSubqueryAliases(table) match {
+      case DataSourceV2Relation(RowLevelOperationTable(_, operation), _, _, _, _) =>
+        operation match {
+          case supportsDelta: SupportsDelta =>
+            supportsDelta
+          case _ =>
+            throw new AnalysisException(s"Operation $operation is not a delta operation")
+        }
+      case _ =>
+        throw new AnalysisException(s"Cannot retrieve row-level operation from $table")
+    }
+  }
+
+  private def rowAttrsResolved: Boolean = {
+    table.skipSchemaResolution || (projections.rowProjection match {
+      case Some(projection) =>
+        table.output.size == projection.schema.size &&
+          projection.schema.zip(table.output).forall { case (field, outAttr) =>
+            isCompatible(field, outAttr)
+          }
+      case None =>
+        true
+    })
+  }
+
+  private def rowIdAttrsResolved: Boolean = {
+    val rowIdAttrs = ExtendedV2ExpressionUtils.resolveRefs[AttributeReference](
+      operation.rowId.toSeq,
+      originalTable)
+
+    projections.rowIdProjection.schema.forall { field =>
+      rowIdAttrs.exists(rowIdAttr => isCompatible(field, rowIdAttr))
+    }
+  }
+
+  private def metadataAttrsResolved: Boolean = {
+    projections.metadataProjection match {
+      case Some(projection) =>
+        val metadataAttrs = ExtendedV2ExpressionUtils.resolveRefs[AttributeReference](
+          operation.requiredMetadataAttributes.toSeq,
+          originalTable)
+
+        projection.schema.forall { field =>
+          metadataAttrs.exists(metadataAttr => isCompatible(field, metadataAttr))
+        }
+      case None =>
+        true
+    }
+  }
+
+  private def isCompatible(projectionField: StructField, outAttr: NamedExpression): Boolean = {
+    // restore the raw char/varchar type of the projection field; fall back to its own data type
+    val inType = CharVarcharUtils.getRawType(projectionField.metadata).getOrElse(projectionField.dataType)
+    val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType)
+    // names and types must match, nullability must be compatible
+    projectionField.name == outAttr.name &&
+      DataType.equalsIgnoreCompatibleNullability(inType, outType) &&
+      (outAttr.nullable || !projectionField.nullable)
+  }
+
+  override def outputResolved: Boolean = {
+    assert(table.resolved && query.resolved,
+      "`outputResolved` can only be called when `table` and `query` are both resolved.")
+
+    operationResolved && rowAttrsResolved && rowIdAttrsResolved && metadataAttrsResolved
+  }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): WriteIcebergDelta = {
+    copy(query = newChild)
+  }
+}
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
new file mode 100644
index 000000000000..be15f32bc1b8
--- /dev/null
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Expression + +/** + * A CALL statement, as parsed from SQL. + */ +case class CallStatement(name: Seq[String], args: Seq[CallArgument]) extends LeafParsedStatement + +/** + * An argument in a CALL statement. + */ +sealed trait CallArgument { + def expr: Expression +} + +/** + * An argument in a CALL statement identified by name. + */ +case class NamedArgument(name: String, expr: Expression) extends CallArgument + +/** + * An argument in a CALL statement identified by position. + */ +case class PositionalArgument(expr: Expression) extends CallArgument diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/connector/expressions/TruncateTransform.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/connector/expressions/TruncateTransform.scala new file mode 100644 index 000000000000..2a3269e2db1d --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/connector/expressions/TruncateTransform.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.connector.expressions + +import org.apache.spark.sql.types.IntegerType + +private[sql] object TruncateTransform { + def unapply(expr: Expression): Option[(Int, FieldReference)] = expr match { + case transform: Transform => + transform match { + case NamedTransform("truncate", Seq(Ref(seq: Seq[String]), Lit(value: Int, IntegerType))) => + Some((value, FieldReference(seq))) + case NamedTransform("truncate", Seq(Lit(value: Int, IntegerType), Ref(seq: Seq[String]))) => + Some((value, FieldReference(seq))) + case _ => + None + } + case _ => + None + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddPartitionFieldExec.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddPartitionFieldExec.scala new file mode 100644 index 000000000000..55f327f7e45e --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddPartitionFieldExec.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.Spark3Util +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog +import org.apache.spark.sql.connector.expressions.Transform + +case class AddPartitionFieldExec( + catalog: TableCatalog, + ident: Identifier, + transform: Transform, + name: Option[String]) extends LeafV2CommandExec { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + iceberg.table.updateSpec() + .addField(name.orNull, Spark3Util.toIcebergTerm(transform)) + .commit() + + case table => + throw new UnsupportedOperationException(s"Cannot add partition field to non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"AddPartitionField ${catalog.name}.${ident.quoted} ${name.map(n => s"$n=").getOrElse("")}${transform.describe}" + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CallExec.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CallExec.scala new file mode 100644 index 000000000000..f66962a8c453 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CallExec.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.connector.iceberg.catalog.Procedure +import scala.collection.compat.immutable.ArraySeq + +case class CallExec( + output: Seq[Attribute], + procedure: Procedure, + input: InternalRow) extends LeafV2CommandExec { + + override protected def run(): Seq[InternalRow] = { + ArraySeq.unsafeWrapArray(procedure.call(input)) + } + + override def simpleString(maxFields: Int): String = { + s"CallExec${truncatedString(output, "[", ", ", "]", maxFields)} ${procedure.description}" + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateOrReplaceBranchExec.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateOrReplaceBranchExec.scala new file mode 100644 index 000000000000..08230afb5a3f --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateOrReplaceBranchExec.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.BranchOptions +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog + +case class CreateOrReplaceBranchExec( + catalog: TableCatalog, + ident: Identifier, + branch: String, + branchOptions: BranchOptions, + replace: Boolean, + ifNotExists: Boolean) extends LeafV2CommandExec { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val snapshotId = branchOptions.snapshotId.getOrElse(iceberg.table.currentSnapshot().snapshotId()) + val manageSnapshots = iceberg.table().manageSnapshots() + if (!replace) { + val ref = iceberg.table().refs().get(branch); + if (ref != null && ifNotExists) { + return Nil + } + + manageSnapshots.createBranch(branch, snapshotId) + } else { + manageSnapshots.replaceBranch(branch, snapshotId) + } + + if (branchOptions.numSnapshots.nonEmpty) { + manageSnapshots.setMinSnapshotsToKeep(branch, branchOptions.numSnapshots.get.toInt) + } + + if (branchOptions.snapshotRetain.nonEmpty) { + manageSnapshots.setMaxSnapshotAgeMs(branch, branchOptions.snapshotRetain.get) + } + + if (branchOptions.snapshotRefRetain.nonEmpty) { + manageSnapshots.setMaxRefAgeMs(branch, branchOptions.snapshotRefRetain.get) + } + + manageSnapshots.commit() + + case table => + throw new UnsupportedOperationException(s"Cannot create or replace branch on non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"CreateOrReplace branch: ${branch} for table: ${ident.quoted}" + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateOrReplaceTagExec.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateOrReplaceTagExec.scala new file mode 100644 index 000000000000..d41f9f03ff4c --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateOrReplaceTagExec.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.TagOptions +import org.apache.spark.sql.connector.catalog._ + +case class CreateOrReplaceTagExec( + catalog: TableCatalog, + ident: Identifier, + tag: String, + tagOptions: TagOptions, + replace: Boolean, + ifNotExists: Boolean) extends LeafV2CommandExec { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val snapshotId = tagOptions.snapshotId.getOrElse(iceberg.table.currentSnapshot().snapshotId()) + val manageSnapshot = iceberg.table.manageSnapshots() + if (!replace) { + val ref = iceberg.table().refs().get(tag); + if (ref != null && ifNotExists) { + return Nil + } + + manageSnapshot.createTag(tag, snapshotId) + } else { + manageSnapshot.replaceTag(tag, snapshotId) + } + + if (tagOptions.snapshotRefRetain.nonEmpty) { + manageSnapshot.setMaxRefAgeMs(tag, tagOptions.snapshotRefRetain.get) + } + + manageSnapshot.commit() + + case table => + throw new UnsupportedOperationException(s"Cannot create tag to non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"Create tag: ${tag} for table: ${ident.quoted}" + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropBranchExec.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropBranchExec.scala new file mode 100644 index 000000000000..ff8f1820099a --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropBranchExec.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
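Tags follow the same pattern; a minimal, hedged continuation of the sketch above (same hypothetical session and table), where the RETAIN clause maps to the setMaxRefAgeMs call in CreateOrReplaceTagExec:

    // AS OF VERSION <snapshot-id> pins an explicit snapshot; omitting it tags the current one.
    spark.sql("ALTER TABLE local.db.events CREATE TAG release_2023_06 RETAIN 90 DAYS")
    spark.sql("ALTER TABLE local.db.events CREATE OR REPLACE TAG nightly")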
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog + +case class DropBranchExec( + catalog: TableCatalog, + ident: Identifier, + branch: String, + ifExists: Boolean) extends LeafV2CommandExec { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val ref = iceberg.table().refs().get(branch) + if (ref != null || !ifExists) { + iceberg.table().manageSnapshots().removeBranch(branch).commit() + } + + case table => + throw new UnsupportedOperationException(s"Cannot drop branch on non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"DropBranch branch: ${branch} for table: ${ident.quoted}" + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropIdentifierFieldsExec.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropIdentifierFieldsExec.scala new file mode 100644 index 000000000000..dee778b474f9 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropIdentifierFieldsExec.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.relocated.com.google.common.base.Preconditions +import org.apache.iceberg.relocated.com.google.common.collect.Sets +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog + +case class DropIdentifierFieldsExec( + catalog: TableCatalog, + ident: Identifier, + fields: Seq[String]) extends LeafV2CommandExec { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val schema = iceberg.table.schema + val identifierFieldNames = Sets.newHashSet(schema.identifierFieldNames) + + for (name <- fields) { + Preconditions.checkArgument(schema.findField(name) != null, + "Cannot complete drop identifier fields operation: field %s not found", name) + Preconditions.checkArgument(identifierFieldNames.contains(name), + "Cannot complete drop identifier fields operation: %s is not an identifier field", name) + identifierFieldNames.remove(name) + } + + iceberg.table.updateSchema() + .setIdentifierFields(identifierFieldNames) + .commit(); + case table => + throw new UnsupportedOperationException(s"Cannot drop identifier fields in non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"DropIdentifierFields ${catalog.name}.${ident.quoted} (${fields.quoted})"; + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropPartitionFieldExec.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropPartitionFieldExec.scala new file mode 100644 index 000000000000..9a153f0c004e --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropPartitionFieldExec.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
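For reference, a hedged sketch of the identifier-field DDL that reaches DropIdentifierFieldsExec (same hypothetical session and table; the SET form is handled by SetIdentifierFieldsExec further down in this patch):

    spark.sql("ALTER TABLE local.db.events SET IDENTIFIER FIELDS id, region")
    // Removing a field that is not currently an identifier field trips the Preconditions checks above.
    spark.sql("ALTER TABLE local.db.events DROP IDENTIFIER FIELDS region")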
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.Spark3Util +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.IdentityTransform +import org.apache.spark.sql.connector.expressions.Transform + +case class DropPartitionFieldExec( + catalog: TableCatalog, + ident: Identifier, + transform: Transform) extends LeafV2CommandExec { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val schema = iceberg.table.schema + transform match { + case IdentityTransform(FieldReference(parts)) if parts.size == 1 && schema.findField(parts.head) == null => + // the name is not present in the Iceberg schema, so it must be a partition field name, not a column name + iceberg.table.updateSpec() + .removeField(parts.head) + .commit() + + case _ => + iceberg.table.updateSpec() + .removeField(Spark3Util.toIcebergTerm(transform)) + .commit() + } + + case table => + throw new UnsupportedOperationException(s"Cannot drop partition field in non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"DropPartitionField ${catalog.name}.${ident.quoted} ${transform.describe}" + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTagExec.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTagExec.scala new file mode 100644 index 000000000000..8df88765a986 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTagExec.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
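A hedged sketch of the partition-evolution DDL served by DropPartitionFieldExec (same hypothetical session and table). The first statement drops by source-column transform; the second names a partition field rather than a column, which is the IdentityTransform branch above where the reference is absent from the schema:

    spark.sql("ALTER TABLE local.db.events DROP PARTITION FIELD bucket(16, id)")
    // event_day is assumed to be a partition field name, not a column name.
    spark.sql("ALTER TABLE local.db.events DROP PARTITION FIELD event_day")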
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog + +case class DropTagExec( + catalog: TableCatalog, + ident: Identifier, + tag: String, + ifExists: Boolean) extends LeafV2CommandExec { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val ref = iceberg.table().refs().get(tag) + if (ref != null || !ifExists) { + iceberg.table().manageSnapshots().removeTag(tag).commit() + } + + case table => + throw new UnsupportedOperationException(s"Cannot drop tag on non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"DropTag tag: ${tag} for table: ${ident.quoted}" + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Implicits.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Implicits.scala new file mode 100644 index 000000000000..85bda0b08d46 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Implicits.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.write.RowLevelOperationTable + +/** + * A class similar to DataSourceV2Implicits in Spark but contains custom implicit helpers. 
+ */ +object ExtendedDataSourceV2Implicits { + implicit class TableHelper(table: Table) { + def asRowLevelOperationTable: RowLevelOperationTable = { + table match { + case rowLevelOperationTable: RowLevelOperationTable => + rowLevelOperationTable + case _ => + throw new AnalysisException(s"Table ${table.name} is not a row-level operation table") + } + } + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala new file mode 100644 index 000000000000..8efd0d7a5a53 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.Spark3Util +import org.apache.iceberg.spark.SparkCatalog +import org.apache.iceberg.spark.SparkSessionCatalog +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.Strategy +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.catalyst.plans.logical.AddPartitionField +import org.apache.spark.sql.catalyst.plans.logical.Call +import org.apache.spark.sql.catalyst.plans.logical.CreateOrReplaceBranch +import org.apache.spark.sql.catalyst.plans.logical.CreateOrReplaceTag +import org.apache.spark.sql.catalyst.plans.logical.DeleteFromIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.DropBranch +import org.apache.spark.sql.catalyst.plans.logical.DropIdentifierFields +import org.apache.spark.sql.catalyst.plans.logical.DropPartitionField +import org.apache.spark.sql.catalyst.plans.logical.DropTag +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeRows +import org.apache.spark.sql.catalyst.plans.logical.NoStatsUnaryNode +import org.apache.spark.sql.catalyst.plans.logical.ReplaceIcebergData +import org.apache.spark.sql.catalyst.plans.logical.ReplacePartitionField +import org.apache.spark.sql.catalyst.plans.logical.SetIdentifierFields +import org.apache.spark.sql.catalyst.plans.logical.SetWriteDistributionAndOrdering +import org.apache.spark.sql.catalyst.plans.logical.WriteIcebergDelta +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog +import 
org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import scala.jdk.CollectionConverters._ + +case class ExtendedDataSourceV2Strategy(spark: SparkSession) extends Strategy with PredicateHelper { + + import DataSourceV2Implicits._ + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case c @ Call(procedure, args) => + val input = buildInternalRow(args) + CallExec(c.output, procedure, input) :: Nil + + case AddPartitionField(IcebergCatalogAndIdentifier(catalog, ident), transform, name) => + AddPartitionFieldExec(catalog, ident, transform, name) :: Nil + + case CreateOrReplaceBranch( + IcebergCatalogAndIdentifier(catalog, ident), branch, branchOptions, replace, ifNotExists) => + CreateOrReplaceBranchExec(catalog, ident, branch, branchOptions, replace, ifNotExists) :: Nil + + case CreateOrReplaceTag(IcebergCatalogAndIdentifier(catalog, ident), tag, tagOptions, replace, ifNotExists) => + CreateOrReplaceTagExec(catalog, ident, tag, tagOptions, replace, ifNotExists) :: Nil + + case DropBranch(IcebergCatalogAndIdentifier(catalog, ident), branch, ifExists) => + DropBranchExec(catalog, ident, branch, ifExists) :: Nil + + case DropTag(IcebergCatalogAndIdentifier(catalog, ident), tag, ifExists) => + DropTagExec(catalog, ident, tag, ifExists) :: Nil + + case DropPartitionField(IcebergCatalogAndIdentifier(catalog, ident), transform) => + DropPartitionFieldExec(catalog, ident, transform) :: Nil + + case ReplacePartitionField(IcebergCatalogAndIdentifier(catalog, ident), transformFrom, transformTo, name) => + ReplacePartitionFieldExec(catalog, ident, transformFrom, transformTo, name) :: Nil + + case SetIdentifierFields(IcebergCatalogAndIdentifier(catalog, ident), fields) => + SetIdentifierFieldsExec(catalog, ident, fields) :: Nil + + case DropIdentifierFields(IcebergCatalogAndIdentifier(catalog, ident), fields) => + DropIdentifierFieldsExec(catalog, ident, fields) :: Nil + + case SetWriteDistributionAndOrdering( + IcebergCatalogAndIdentifier(catalog, ident), distributionMode, ordering) => + SetWriteDistributionAndOrderingExec(catalog, ident, distributionMode, ordering) :: Nil + + case ReplaceIcebergData(_: DataSourceV2Relation, query, r: DataSourceV2Relation, Some(write)) => + // refresh the cache using the original relation + ReplaceDataExec(planLater(query), refreshCache(r), write) :: Nil + + case WriteIcebergDelta(_: DataSourceV2Relation, query, r: DataSourceV2Relation, projs, Some(write)) => + // refresh the cache using the original relation + WriteDeltaExec(planLater(query), refreshCache(r), projs, write) :: Nil + + case MergeRows(isSourceRowPresent, isTargetRowPresent, matchedConditions, matchedOutputs, notMatchedConditions, + notMatchedOutputs, targetOutput, rowIdAttrs, performCardinalityCheck, emitNotMatchedTargetRows, + output, child) => + + MergeRowsExec(isSourceRowPresent, isTargetRowPresent, matchedConditions, matchedOutputs, notMatchedConditions, + notMatchedOutputs, targetOutput, rowIdAttrs, performCardinalityCheck, emitNotMatchedTargetRows, + output, planLater(child)) :: Nil + + case DeleteFromIcebergTable(DataSourceV2ScanRelation(r, _, output, _, _), condition, None) => + // the optimizer has already checked that this delete can be handled using a metadata operation + val deleteCond = condition.getOrElse(Literal.TrueLiteral) + val predicates = splitConjunctivePredicates(deleteCond) + val normalizedPredicates = 
DataSourceStrategy.normalizeExprs(predicates, output) + val filters = normalizedPredicates.flatMap { pred => + val filter = DataSourceV2Strategy.translateFilterV2(pred) + if (filter.isEmpty) { + throw QueryCompilationErrors.cannotTranslateExpressionToSourceFilterError(pred) + } + filter + }.toArray + DeleteFromTableExec(r.table.asDeletable, filters, refreshCache(r)) :: Nil + + case NoStatsUnaryNode(child) => + planLater(child) :: Nil + + case _ => Nil + } + + private def buildInternalRow(exprs: Seq[Expression]): InternalRow = { + val values = new Array[Any](exprs.size) + for (index <- exprs.indices) { + values(index) = exprs(index).eval() + } + new GenericInternalRow(values) + } + + private def refreshCache(r: DataSourceV2Relation)(): Unit = { + spark.sharedState.cacheManager.recacheByPlan(spark, r) + } + + private object IcebergCatalogAndIdentifier { + def unapply(identifier: Seq[String]): Option[(TableCatalog, Identifier)] = { + val catalogAndIdentifier = Spark3Util.catalogAndIdentifier(spark, identifier.asJava) + catalogAndIdentifier.catalog match { + case icebergCatalog: SparkCatalog => + Some((icebergCatalog, catalogAndIdentifier.identifier)) + case icebergCatalog: SparkSessionCatalog[_] => + Some((icebergCatalog, catalogAndIdentifier.identifier)) + case _ => + None + } + } + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDistributionAndOrderingUtils.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDistributionAndOrderingUtils.scala new file mode 100644 index 000000000000..8c37b1b75924 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDistributionAndOrderingUtils.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
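The Call branch at the top of the strategy folds the already-resolved procedure arguments into an InternalRow via buildInternalRow and plans a CallExec. A hedged usage sketch (same hypothetical session; rollback_to_snapshot and expire_snapshots are built-in Iceberg procedures, the identifiers are illustrative):

    // Positional and named-argument forms both resolve to a Call plan and then CallExec.
    spark.sql("CALL local.system.rollback_to_snapshot('db.events', 123456789)")
    spark.sql("CALL local.system.expire_snapshots(table => 'db.events', older_than => TIMESTAMP '2023-01-01 00:00:00')")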
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.ExtendedV2ExpressionUtils.toCatalyst +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.RepartitionByExpression +import org.apache.spark.sql.catalyst.plans.logical.Sort +import org.apache.spark.sql.connector.distributions.ClusteredDistribution +import org.apache.spark.sql.connector.distributions.OrderedDistribution +import org.apache.spark.sql.connector.distributions.UnspecifiedDistribution +import org.apache.spark.sql.connector.write.RequiresDistributionAndOrdering +import org.apache.spark.sql.connector.write.Write +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf +import scala.collection.compat.immutable.ArraySeq + +/** + * A rule that is inspired by DistributionAndOrderingUtils in Spark but supports Iceberg transforms. + * + * Note that similarly to the original rule in Spark, it does not let AQE pick the number of shuffle + * partitions. See SPARK-34230 for context. + */ +object ExtendedDistributionAndOrderingUtils { + + def prepareQuery(write: Write, query: LogicalPlan, conf: SQLConf): LogicalPlan = write match { + case write: RequiresDistributionAndOrdering => + val numPartitions = write.requiredNumPartitions() + val distribution = write.requiredDistribution match { + case d: OrderedDistribution => d.ordering.map(e => toCatalyst(e, query)) + case d: ClusteredDistribution => d.clustering.map(e => toCatalyst(e, query)) + case _: UnspecifiedDistribution => Array.empty[Expression] + } + + val queryWithDistribution = if (distribution.nonEmpty) { + val finalNumPartitions = if (numPartitions > 0) { + numPartitions + } else { + conf.numShufflePartitions + } + // the conversion to catalyst expressions above produces SortOrder expressions + // for OrderedDistribution and generic expressions for ClusteredDistribution + // this allows RepartitionByExpression to pick either range or hash partitioning + RepartitionByExpression(ArraySeq.unsafeWrapArray(distribution), query, finalNumPartitions) + } else if (numPartitions > 0) { + throw QueryCompilationErrors.numberOfPartitionsNotAllowedWithUnspecifiedDistributionError() + } else { + query + } + + val ordering = write.requiredOrdering.toSeq + .map(e => toCatalyst(e, query)) + .asInstanceOf[Seq[SortOrder]] + + val queryWithDistributionAndOrdering = if (ordering.nonEmpty) { + Sort(ordering, global = false, queryWithDistribution) + } else { + queryWithDistribution + } + + queryWithDistributionAndOrdering + + case _ => + query + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedV2Writes.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedV2Writes.scala new file mode 100644 index 000000000000..83b793925db2 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedV2Writes.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
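The rule above is what converts a write's required distribution and ordering into a RepartitionByExpression plus a local Sort over the incoming query. A hedged sketch of the table-level DDL that typically creates such a requirement (it is executed by SetWriteDistributionAndOrderingExec later in this patch; names are illustrative):

    // Hash-distribute by partition and sort within tasks: subsequent writes report a clustered
    // distribution and a required ordering, which prepareQuery satisfies with repartition + sort.
    spark.sql("ALTER TABLE local.db.events WRITE DISTRIBUTED BY PARTITION LOCALLY ORDERED BY category, id")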
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.Optional +import java.util.UUID +import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.catalyst.plans.logical.AppendData +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.OverwriteByExpression +import org.apache.spark.sql.catalyst.plans.logical.OverwritePartitionsDynamic +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.ReplaceIcebergData +import org.apache.spark.sql.catalyst.plans.logical.WriteIcebergDelta +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.WriteDeltaProjections +import org.apache.spark.sql.catalyst.utils.PlanUtils.isIcebergRelation +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.write.DeltaWriteBuilder +import org.apache.spark.sql.connector.write.LogicalWriteInfoImpl +import org.apache.spark.sql.connector.write.SupportsDynamicOverwrite +import org.apache.spark.sql.connector.write.SupportsOverwrite +import org.apache.spark.sql.connector.write.SupportsTruncate +import org.apache.spark.sql.connector.write.WriteBuilder +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources.AlwaysTrue +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType + +/** + * A rule that is inspired by V2Writes in Spark but supports Iceberg transforms. + */ +object ExtendedV2Writes extends Rule[LogicalPlan] with PredicateHelper { + + import DataSourceV2Implicits._ + + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case a @ AppendData(r: DataSourceV2Relation, query, options, _, None, _) if isIcebergRelation(r) => + val writeBuilder = newWriteBuilder(r.table, query.schema, options) + val write = writeBuilder.build() + val newQuery = ExtendedDistributionAndOrderingUtils.prepareQuery(write, query, conf) + a.copy(write = Some(write), query = newQuery) + + case o @ OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, options, _, None, _) + if isIcebergRelation(r) => + // fail if any filter cannot be converted. correctness depends on removing all matching data. 
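As a hedged aside (not part of the patch, and assuming the default static partition-overwrite mode plus the hypothetical table from the earlier sketches), the two statements below illustrate the split handled next: a whole-table INSERT OVERWRITE yields a single AlwaysTrue filter and takes the SupportsTruncate path, while a static partition overwrite yields a convertible predicate and takes the SupportsOverwrite path.

    spark.sql("INSERT OVERWRITE local.db.events SELECT * FROM staged_events")
    spark.sql("INSERT OVERWRITE local.db.events PARTITION (event_day = '2023-06-01') " +
      "SELECT id, category, ts FROM staged_events")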
+ val filters = splitConjunctivePredicates(deleteExpr).flatMap { pred => + val filter = DataSourceStrategy.translateFilter(pred, supportNestedPredicatePushdown = true) + if (filter.isEmpty) { + throw QueryCompilationErrors.cannotTranslateExpressionToSourceFilterError(pred) + } + filter + }.toArray + + val table = r.table + val writeBuilder = newWriteBuilder(table, query.schema, options) + val write = writeBuilder match { + case builder: SupportsTruncate if isTruncate(filters) => + builder.truncate().build() + case builder: SupportsOverwrite => + builder.overwrite(filters).build() + case _ => + throw QueryExecutionErrors.overwriteTableByUnsupportedExpressionError(table) + } + + val newQuery = ExtendedDistributionAndOrderingUtils.prepareQuery(write, query, conf) + o.copy(write = Some(write), query = newQuery) + + case o @ OverwritePartitionsDynamic(r: DataSourceV2Relation, query, options, _, None) + if isIcebergRelation(r) => + val table = r.table + val writeBuilder = newWriteBuilder(table, query.schema, options) + val write = writeBuilder match { + case builder: SupportsDynamicOverwrite => + builder.overwriteDynamicPartitions().build() + case _ => + throw QueryExecutionErrors.dynamicPartitionOverwriteUnsupportedByTableError(table) + } + val newQuery = ExtendedDistributionAndOrderingUtils.prepareQuery(write, query, conf) + o.copy(write = Some(write), query = newQuery) + + case rd @ ReplaceIcebergData(r: DataSourceV2Relation, query, _, None) => + val rowSchema = StructType.fromAttributes(rd.dataInput) + val writeBuilder = newWriteBuilder(r.table, rowSchema, Map.empty) + val write = writeBuilder.build() + val newQuery = ExtendedDistributionAndOrderingUtils.prepareQuery(write, query, conf) + rd.copy(write = Some(write), query = Project(rd.dataInput, newQuery)) + + case wd @ WriteIcebergDelta(r: DataSourceV2Relation, query, _, projections, None) => + val deltaWriteBuilder = newDeltaWriteBuilder(r.table, Map.empty, projections) + val deltaWrite = deltaWriteBuilder.build() + val newQuery = ExtendedDistributionAndOrderingUtils.prepareQuery(deltaWrite, query, conf) + wd.copy(write = Some(deltaWrite), query = newQuery) + } + + private def isTruncate(filters: Array[Filter]): Boolean = { + filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue] + } + + private def newWriteBuilder( + table: Table, + rowSchema: StructType, + writeOptions: Map[String, String], + queryId: String = UUID.randomUUID().toString): WriteBuilder = { + + val info = LogicalWriteInfoImpl(queryId, rowSchema, writeOptions.asOptions) + table.asWritable.newWriteBuilder(info) + } + + private def newDeltaWriteBuilder( + table: Table, + writeOptions: Map[String, String], + projections: WriteDeltaProjections, + queryId: String = UUID.randomUUID().toString): DeltaWriteBuilder = { + + val rowSchema = projections.rowProjection.map(_.schema).getOrElse(StructType(Nil)) + val rowIdSchema = projections.rowIdProjection.schema + val metadataSchema = projections.metadataProjection.map(_.schema) + + val info = LogicalWriteInfoImpl( + queryId, + rowSchema, + writeOptions.asOptions, + Optional.of(rowIdSchema), + Optional.ofNullable(metadataSchema.orNull)) + + val writeBuilder = table.asWritable.newWriteBuilder(info) + assert(writeBuilder.isInstanceOf[DeltaWriteBuilder], s"$writeBuilder must be DeltaWriteBuilder") + writeBuilder.asInstanceOf[DeltaWriteBuilder] + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala 
b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala new file mode 100644 index 000000000000..4fbf8a523a54 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.SparkException +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Ascending +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.catalyst.expressions.BasePredicate +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate +import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.UnaryExecNode + +case class MergeRowsExec( + isSourceRowPresent: Expression, + isTargetRowPresent: Expression, + matchedConditions: Seq[Expression], + matchedOutputs: Seq[Seq[Expression]], + notMatchedConditions: Seq[Expression], + notMatchedOutputs: Seq[Seq[Expression]], + targetOutput: Seq[Expression], + rowIdAttrs: Seq[Attribute], + performCardinalityCheck: Boolean, + emitNotMatchedTargetRows: Boolean, + output: Seq[Attribute], + child: SparkPlan) extends UnaryExecNode { + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + if (performCardinalityCheck) { + // request a local sort by the row ID attrs to co-locate matches for the same target row + Seq(rowIdAttrs.map(attr => SortOrder(attr, Ascending))) + } else { + Seq(Nil) + } + } + + @transient override lazy val producedAttributes: AttributeSet = { + AttributeSet(output.filterNot(attr => inputSet.contains(attr))) + } + + @transient override lazy val references: AttributeSet = child.outputSet + + override def simpleString(maxFields: Int): String = { + s"MergeRowsExec${truncatedString(output, "[", ", ", "]", maxFields)}" + } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = { + copy(child = newChild) + } + + protected override def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions(processPartition) + } + + private def createProjection(exprs: Seq[Expression], attrs: Seq[Attribute]): UnsafeProjection = { + UnsafeProjection.create(exprs, attrs) + } + + private def createPredicate(expr: Expression, attrs: Seq[Attribute]): BasePredicate = { + 
GeneratePredicate.generate(expr, attrs) + } + + private def applyProjection( + actions: Seq[(BasePredicate, Option[UnsafeProjection])], + inputRow: InternalRow): InternalRow = { + + // find the first action where the predicate evaluates to true + // if there are overlapping conditions in actions, use the first matching action + // in the example below, when id = 5, both actions match but the first one is applied + // WHEN MATCHED AND id > 1 AND id < 10 UPDATE * + // WHEN MATCHED AND id = 5 OR id = 21 DELETE + + val pair = actions.find { + case (predicate, _) => predicate.eval(inputRow) + } + + // apply the projection to produce an output row, or return null to suppress this row + pair match { + case Some((_, Some(projection))) => + projection.apply(inputRow) + case _ => + null + } + } + + private def processPartition(rowIterator: Iterator[InternalRow]): Iterator[InternalRow] = { + val inputAttrs = child.output + + val isSourceRowPresentPred = createPredicate(isSourceRowPresent, inputAttrs) + val isTargetRowPresentPred = createPredicate(isTargetRowPresent, inputAttrs) + + val matchedPreds = matchedConditions.map(createPredicate(_, inputAttrs)) + val matchedProjs = matchedOutputs.map { + case output if output.nonEmpty => Some(createProjection(output, inputAttrs)) + case _ => None + } + val matchedPairs = matchedPreds zip matchedProjs + + val notMatchedPreds = notMatchedConditions.map(createPredicate(_, inputAttrs)) + val notMatchedProjs = notMatchedOutputs.map { + case output if output.nonEmpty => Some(createProjection(output, inputAttrs)) + case _ => None + } + val nonMatchedPairs = notMatchedPreds zip notMatchedProjs + + val projectTargetCols = createProjection(targetOutput, inputAttrs) + val rowIdProj = createProjection(rowIdAttrs, inputAttrs) + + // This method is responsible for processing an input row to emit the resultant row with an + // additional column that indicates whether the row is going to be included in the final + // output of merge or not. + // 1. Found a target row for which there is no corresponding source row (join condition not met) + // - Only project the target columns if we need to output unchanged rows + // 2. Found a source row for which there is no corresponding target row (join condition not met) + // - Apply the not matched actions (i.e. INSERT actions) if the not matched conditions are met. + // 3. Found a source row for which there is a corresponding target row (join condition met) + // - Apply the matched actions (i.e. DELETE or UPDATE actions) if the matched conditions are met. + def processRow(inputRow: InternalRow): InternalRow = { + if (emitNotMatchedTargetRows && !isSourceRowPresentPred.eval(inputRow)) { + projectTargetCols.apply(inputRow) + } else if (!isTargetRowPresentPred.eval(inputRow)) { + applyProjection(nonMatchedPairs, inputRow) + } else { + applyProjection(matchedPairs, inputRow) + } + } + + var lastMatchedRowId: InternalRow = null + + def processRowWithCardinalityCheck(inputRow: InternalRow): InternalRow = { + val isSourceRowPresent = isSourceRowPresentPred.eval(inputRow) + val isTargetRowPresent = isTargetRowPresentPred.eval(inputRow) + + if (isSourceRowPresent && isTargetRowPresent) { + val currentRowId = rowIdProj.apply(inputRow) + if (currentRowId == lastMatchedRowId) { + throw new SparkException( + "The ON search condition of the MERGE statement matched a single row from " + + "the target table with multiple rows of the source table. 
This could result " + + "in the target row being operated on more than once with an update or delete " + + "operation and is not allowed.") + } + lastMatchedRowId = currentRowId.copy() + } else { + lastMatchedRowId = null + } + + if (emitNotMatchedTargetRows && !isSourceRowPresent) { + projectTargetCols.apply(inputRow) + } else if (!isTargetRowPresent) { + applyProjection(nonMatchedPairs, inputRow) + } else { + applyProjection(matchedPairs, inputRow) + } + } + + val processFunc: InternalRow => InternalRow = if (performCardinalityCheck) { + processRowWithCardinalityCheck + } else { + processRow + } + + rowIterator + .map(processFunc) + .filter(row => row != null) + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromIcebergTable.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromIcebergTable.scala new file mode 100644 index 000000000000..e81a567eb2a8 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromIcebergTable.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.plans.logical.DeleteFromIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.SupportsDelete +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources +import org.slf4j.LoggerFactory + +/** + * Checks whether a metadata delete is possible and nullifies the rewrite plan if the source can + * handle this delete without executing the rewrite plan. + * + * Note this rule must be run after expression optimization. 
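To make the rule concrete, a hedged example (hypothetical session and table from the earlier sketches, assumed to be identity-partitioned by event_day): the first DELETE is aligned with partition boundaries, all of its predicates translate to data-source filters, and canDeleteWhere can typically accept them, so the rewrite plan is nulled out and the delete is committed as a metadata-only operation; the second contains a subquery and keeps its rewrite plan.

    spark.sql("DELETE FROM local.db.events WHERE event_day = '2023-06-01'")
    spark.sql("DELETE FROM local.db.events WHERE id IN (SELECT id FROM local.db.quarantine)")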
+ */ +object OptimizeMetadataOnlyDeleteFromIcebergTable extends Rule[LogicalPlan] with PredicateHelper { + + val logger = LoggerFactory.getLogger(OptimizeMetadataOnlyDeleteFromIcebergTable.getClass) + + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case d @ DeleteFromIcebergTable(relation: DataSourceV2Relation, cond, Some(_)) => + val deleteCond = cond.getOrElse(Literal.TrueLiteral) + relation.table match { + case table: SupportsDelete if !SubqueryExpression.hasSubquery(deleteCond) => + val predicates = splitConjunctivePredicates(deleteCond) + val normalizedPredicates = DataSourceStrategy.normalizeExprs(predicates, relation.output) + val dataSourceFilters = toDataSourceFilters(normalizedPredicates) + val allPredicatesTranslated = normalizedPredicates.size == dataSourceFilters.length + if (allPredicatesTranslated && table.canDeleteWhere(dataSourceFilters)) { + logger.info(s"Optimizing delete expression: ${dataSourceFilters.mkString(",")} as metadata delete") + d.copy(rewritePlan = None) + } else { + d + } + case _ => + d + } + } + + protected def toDataSourceFilters(predicates: Seq[Expression]): Array[sources.Filter] = { + predicates.flatMap { p => + val filter = DataSourceStrategy.translateFilter(p, supportNestedPredicatePushdown = true) + if (filter.isEmpty) { + logWarning(s"Cannot translate expression to source filter: $p") + } + filter + }.toArray + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceDataExec.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceDataExec.scala new file mode 100644 index 000000000000..26c652469ac4 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceDataExec.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.connector.write.Write +import org.apache.spark.sql.execution.SparkPlan + +/** + * Physical plan node to replace data in existing tables. 
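ReplaceDataExec is the group-based (copy-on-write) write node: the child query recomputes the affected file groups and the Write swaps them in, after which the cached relation is refreshed. A hedged sketch of statements that land here when the table's row-level operation modes are copy-on-write (the property names are Iceberg table properties; the table is hypothetical):

    spark.sql("ALTER TABLE local.db.events SET TBLPROPERTIES (" +
      "'write.delete.mode' = 'copy-on-write', 'write.update.mode' = 'copy-on-write')")
    spark.sql("UPDATE local.db.events SET category = 'archived' WHERE event_day < '2023-01-01'")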
+ */ +case class ReplaceDataExec( + query: SparkPlan, + refreshCache: () => Unit, + write: Write) extends V2ExistingTableWriteExec { + + override lazy val references: AttributeSet = query.outputSet + override lazy val stringArgs: Iterator[Any] = Iterator(query, write) + + override protected def withNewChildInternal(newChild: SparkPlan): ReplaceDataExec = { + copy(query = newChild) + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplacePartitionFieldExec.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplacePartitionFieldExec.scala new file mode 100644 index 000000000000..fcae0a5defc4 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplacePartitionFieldExec.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.Spark3Util +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.IdentityTransform +import org.apache.spark.sql.connector.expressions.Transform + +case class ReplacePartitionFieldExec( + catalog: TableCatalog, + ident: Identifier, + transformFrom: Transform, + transformTo: Transform, + name: Option[String]) extends LeafV2CommandExec { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val schema = iceberg.table.schema + transformFrom match { + case IdentityTransform(FieldReference(parts)) if parts.size == 1 && schema.findField(parts.head) == null => + // the name is not present in the Iceberg schema, so it must be a partition field name, not a column name + iceberg.table.updateSpec() + .removeField(parts.head) + .addField(name.orNull, Spark3Util.toIcebergTerm(transformTo)) + .commit() + + case _ => + iceberg.table.updateSpec() + .removeField(Spark3Util.toIcebergTerm(transformFrom)) + .addField(name.orNull, Spark3Util.toIcebergTerm(transformTo)) + .commit() + } + + case table => + throw new UnsupportedOperationException(s"Cannot replace partition field in non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"ReplacePartitionField ${catalog.name}.${ident.quoted} 
${transformFrom.describe} " + + s"with ${name.map(n => s"$n=").getOrElse("")}${transformTo.describe}" + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceRewrittenRowLevelCommand.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceRewrittenRowLevelCommand.scala new file mode 100644 index 000000000000..414d4c0ec305 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceRewrittenRowLevelCommand.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.RowLevelCommand +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * Replaces operations such as DELETE and MERGE with the corresponding rewrite plans. + */ +object ReplaceRewrittenRowLevelCommand extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case c: RowLevelCommand if c.rewritePlan.isDefined => + c.rewritePlan.get + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala new file mode 100644 index 000000000000..e4cd9cdc9fb0 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
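A hedged sketch of the DDL behind ReplacePartitionFieldExec (same hypothetical table; the optional AS clause supplies the name argument above):

    // Replace an hourly partition field with a daily one, naming the new field explicitly.
    spark.sql("ALTER TABLE local.db.events REPLACE PARTITION FIELD hours(ts) WITH days(ts) AS day_of_ts")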
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.And +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.planning.RewrittenRowLevelCommand +import org.apache.spark.sql.catalyst.planning.ScanOperation +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.NoStatsUnaryNode +import org.apache.spark.sql.catalyst.plans.logical.ReplaceIcebergData +import org.apache.spark.sql.catalyst.plans.logical.RowLevelCommand +import org.apache.spark.sql.catalyst.plans.logical.WriteIcebergDelta +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType + +object RowLevelCommandScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { + import ExtendedDataSourceV2Implicits._ + + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + // use native Spark planning for delta-based plans + // unlike other commands, these plans have filters that can be pushed down directly + case RewrittenRowLevelCommand(command, _: DataSourceV2Relation, rewritePlan: WriteIcebergDelta) => + val newRewritePlan = V2ScanRelationPushDown.apply(rewritePlan) + command.withNewRewritePlan(newRewritePlan) + + // group-based MERGE operations are rewritten as joins and may be planned in a special way + // the join condition is the MERGE condition and can be pushed into the source + // this allows us to remove completely pushed down predicates from the join condition + case UnplannedGroupBasedMergeOperation(command, rd: ReplaceIcebergData, + join @ Join(_, _, _, Some(joinCond), _), relation: DataSourceV2Relation) => + + val table = relation.table.asRowLevelOperationTable + val scanBuilder = table.newScanBuilder(relation.options) + + val (pushedFilters, newJoinCond) = pushMergeFilters(joinCond, relation, scanBuilder) + val pushedFiltersStr = if (pushedFilters.isLeft) { + pushedFilters.left.get.mkString(", ") + } else { + pushedFilters.right.get.mkString(", ") + } + + val (scan, output) = PushDownUtils.pruneColumns(scanBuilder, relation, relation.output, Nil) + + logInfo( + s""" + |Pushing MERGE operators to ${relation.name} + |Pushed filters: $pushedFiltersStr + |Original JOIN condition: $joinCond + |New JOIN condition: $newJoinCond + |Output: ${output.mkString(", ")} + """.stripMargin) + + val newRewritePlan = rd transformDown { + case j: Join if j eq join => + j.copy(condition = newJoinCond) + case r: DataSourceV2Relation if r.table eq table => + DataSourceV2ScanRelation(r, scan, PushDownUtils.toOutputAttrs(scan.readSchema(), r)) + } + + command.withNewRewritePlan(newRewritePlan) + + // push down the filter from the command condition instead of the filter in the rewrite plan, + // which may be negated for copy-on-write 
DELETE and UPDATE operations + case RewrittenRowLevelCommand(command, relation: DataSourceV2Relation, rewritePlan) => + val table = relation.table.asRowLevelOperationTable + val scanBuilder = table.newScanBuilder(relation.options) + + val (pushedFilters, remainingFilters) = command.condition match { + case Some(cond) => pushFilters(cond, scanBuilder, relation.output) + case None => (Nil, Nil) + } + + val (scan, output) = PushDownUtils.pruneColumns(scanBuilder, relation, relation.output, Nil) + + logInfo( + s""" + |Pushing operators to ${relation.name} + |Pushed filters: ${pushedFilters.mkString(", ")} + |Filters that were not pushed: ${remainingFilters.mkString(",")} + |Output: ${output.mkString(", ")} + """.stripMargin) + + // replace DataSourceV2Relation with DataSourceV2ScanRelation for the row operation table + // there may be multiple read relations for UPDATEs that rely on the UNION approach + val newRewritePlan = rewritePlan transform { + case r: DataSourceV2Relation if r.table eq table => + DataSourceV2ScanRelation(r, scan, toOutputAttrs(scan.readSchema(), r)) + } + + command.withNewRewritePlan(newRewritePlan) + } + + private def pushFilters( + cond: Expression, + scanBuilder: ScanBuilder, + tableAttrs: Seq[AttributeReference]): (Seq[Filter], Seq[Predicate]) = { + + val tableAttrSet = AttributeSet(tableAttrs) + val filters = splitConjunctivePredicates(cond).filter(_.references.subsetOf(tableAttrSet)) + val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, tableAttrs) + val (_, normalizedFiltersWithoutSubquery) = + normalizedFilters.partition(SubqueryExpression.hasSubquery) + + val (pushedFilters, _) = PushDownUtils.pushFilters(scanBuilder, normalizedFiltersWithoutSubquery) + (pushedFilters.left.getOrElse(Seq.empty), pushedFilters.right.getOrElse(Seq.empty)) + } + + // splits the join condition into predicates and tries to push down each predicate into the scan + // completely pushed down predicates are removed from the join condition + // joinCond can't have subqueries as it is validated by the rule that rewrites MERGE as a join + private def pushMergeFilters( + joinCond: Expression, + relation: DataSourceV2Relation, + scanBuilder: ScanBuilder): (Either[Seq[Filter], Seq[Predicate]], Option[Expression]) = { + + val (tableFilters, commonFilters) = + splitConjunctivePredicates(joinCond).partition(_.references.subsetOf(relation.outputSet)) + val normalizedTableFilters = DataSourceStrategy.normalizeExprs(tableFilters, relation.output) + val (pushedFilters, postScanFilters) = + PushDownUtils.pushFilters(scanBuilder, normalizedTableFilters) + val newJoinCond = (commonFilters ++ postScanFilters).reduceLeftOption(And) + + (pushedFilters, newJoinCond) + } + + private def toOutputAttrs( + schema: StructType, + relation: DataSourceV2Relation): Seq[AttributeReference] = { + val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap + val cleaned = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema) + cleaned.toAttributes.map { + // keep the attribute id during transformation + a => a.withExprId(nameToAttr(a.name).exprId) + } + } +} + +object UnplannedGroupBasedMergeOperation { + type ReturnType = (RowLevelCommand, ReplaceIcebergData, Join, DataSourceV2Relation) + + def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { + case m @ MergeIntoIcebergTable(_, _, _, _, _, Some(rewritePlan)) => + rewritePlan match { + case rd @ ReplaceIcebergData(DataSourceV2Relation(table, _, _, _, _), query, _, _) => + val joinsAndRelations = query.collect { + case j @ 
Join( + NoStatsUnaryNode(ScanOperation(_, pushDownFilters, pushUpFilters, r: DataSourceV2Relation)), _, _, _, _) + if pushUpFilters.isEmpty && pushDownFilters.isEmpty && r.table.eq(table) => + j -> r + } + + joinsAndRelations match { + case Seq((join, relation)) => + Some(m, rd, join, relation) + case _ => + None + } + + case _ => + None + } + + case _ => + None + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetIdentifierFieldsExec.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetIdentifierFieldsExec.scala new file mode 100644 index 000000000000..b50550ad38ef --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetIdentifierFieldsExec.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog +import scala.jdk.CollectionConverters._ + +case class SetIdentifierFieldsExec( + catalog: TableCatalog, + ident: Identifier, + fields: Seq[String]) extends LeafV2CommandExec { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + iceberg.table.updateSchema() + .setIdentifierFields(fields.asJava) + .commit(); + case table => + throw new UnsupportedOperationException(s"Cannot set identifier fields in non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + s"SetIdentifierFields ${catalog.name}.${ident.quoted} (${fields.quoted})"; + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetWriteDistributionAndOrderingExec.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetWriteDistributionAndOrderingExec.scala new file mode 100644 index 000000000000..386485b10b05 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetWriteDistributionAndOrderingExec.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.iceberg.DistributionMode +import org.apache.iceberg.NullOrder +import org.apache.iceberg.SortDirection +import org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE +import org.apache.iceberg.expressions.Term +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.TableCatalog + +case class SetWriteDistributionAndOrderingExec( + catalog: TableCatalog, + ident: Identifier, + distributionMode: DistributionMode, + sortOrder: Seq[(Term, SortDirection, NullOrder)]) extends LeafV2CommandExec { + + import CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override protected def run(): Seq[InternalRow] = { + catalog.loadTable(ident) match { + case iceberg: SparkTable => + val txn = iceberg.table.newTransaction() + + val orderBuilder = txn.replaceSortOrder() + sortOrder.foreach { + case (term, SortDirection.ASC, nullOrder) => + orderBuilder.asc(term, nullOrder) + case (term, SortDirection.DESC, nullOrder) => + orderBuilder.desc(term, nullOrder) + } + orderBuilder.commit() + + txn.updateProperties() + .set(WRITE_DISTRIBUTION_MODE, distributionMode.modeName()) + .commit() + + txn.commitTransaction() + + case table => + throw new UnsupportedOperationException(s"Cannot set write order of non-Iceberg table: $table") + } + + Nil + } + + override def simpleString(maxFields: Int): String = { + val tableIdent = s"${catalog.name}.${ident.quoted}" + val order = sortOrder.map { + case (term, direction, nullOrder) => s"$term $direction $nullOrder" + }.mkString(", ") + s"SetWriteDistributionAndOrdering $tableIdent $distributionMode $order" + } +} diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelCommandDynamicPruning.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelCommandDynamicPruning.scala new file mode 100644 index 000000000000..f5d5affe9e92 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelCommandDynamicPruning.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.dynamicpruning + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.And +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeMap +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.DynamicPruningSubquery +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.ExtendedV2ExpressionUtils +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.planning.RewrittenRowLevelCommand +import org.apache.spark.sql.catalyst.plans.LeftSemi +import org.apache.spark.sql.catalyst.plans.logical.DeleteFromIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.plans.logical.JoinHint +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.ReplaceIcebergData +import org.apache.spark.sql.catalyst.plans.logical.RowLevelCommand +import org.apache.spark.sql.catalyst.plans.logical.Sort +import org.apache.spark.sql.catalyst.plans.logical.Subquery +import org.apache.spark.sql.catalyst.plans.logical.UpdateIcebergTable +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION +import org.apache.spark.sql.catalyst.trees.TreePattern.SORT +import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Implicits +import scala.collection.compat.immutable.ArraySeq + +/** + * A rule that adds a runtime filter for row-level commands. + * + * Note that only group-based rewrite plans (i.e. ReplaceData) are taken into account. + * Row-based rewrite plans are subject to usual runtime filtering. 
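+ * For example, a command such as DELETE FROM tbl WHERE date = '2021-01-01' is a candidate only
+ * because it carries a non-trivial condition; the rule then builds a plan over the original scan
+ * relation that selects the matching rows (a Filter for DELETE/UPDATE, a LEFT SEMI join for MERGE)
+ * and attaches one DynamicPruningSubquery per runtime filter attribute reported by the scan.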
+ */ +case class RowLevelCommandDynamicPruning(spark: SparkSession) extends Rule[LogicalPlan] with PredicateHelper { + + import ExtendedDataSourceV2Implicits._ + + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + // apply special dynamic filtering only for plans that don't support deltas + case RewrittenRowLevelCommand( + command: RowLevelCommand, + DataSourceV2ScanRelation(_, scan: SupportsRuntimeFiltering, _, _, _), + rewritePlan: ReplaceIcebergData) if conf.dynamicPartitionPruningEnabled && isCandidate(command) => + + // use reference equality to find exactly the required scan relations + val newRewritePlan = rewritePlan transformUp { + case r: DataSourceV2ScanRelation if r.scan eq scan => + // use the original table instance that was loaded for this row-level operation + // in order to leverage a regular batch scan in the group filter query + val originalTable = r.relation.table.asRowLevelOperationTable.table + val relation = r.relation.copy(table = originalTable) + val matchingRowsPlan = buildMatchingRowsPlan(relation, command) + + val filterAttrs = ArraySeq.unsafeWrapArray(scan.filterAttributes) + val buildKeys = ExtendedV2ExpressionUtils.resolveRefs[Attribute](filterAttrs, matchingRowsPlan) + val pruningKeys = ExtendedV2ExpressionUtils.resolveRefs[Attribute](filterAttrs, r) + val dynamicPruningCond = buildDynamicPruningCond(matchingRowsPlan, buildKeys, pruningKeys) + + Filter(dynamicPruningCond, r) + } + + // always optimize dynamic filtering subqueries for row-level commands as it is important + // to rewrite introduced predicates as joins because Spark recently stopped optimizing + // dynamic subqueries to facilitate broadcast reuse + command.withNewRewritePlan(optimizeSubquery(newRewritePlan)) + } + + private def isCandidate(command: RowLevelCommand): Boolean = command.condition match { + case Some(cond) if cond != Literal.TrueLiteral => true + case _ => false + } + + private def buildMatchingRowsPlan( + relation: DataSourceV2Relation, + command: RowLevelCommand): LogicalPlan = { + + // construct a filtering plan with the original scan relation + val matchingRowsPlan = command match { + case d: DeleteFromIcebergTable => + Filter(d.condition.get, relation) + + case u: UpdateIcebergTable => + // UPDATEs with subqueries are rewritten using a UNION with two identical scan relations + // the analyzer clones of them and assigns fresh expr IDs so that attributes don't collide + // this rule assigns dynamic filters to both scan relations based on the update condition + // the condition always refers to the original expr IDs and must be transformed + // see RewriteUpdateTable for more details + val attrMap = buildAttrMap(u.table.output, relation.output) + val transformedCond = u.condition.get transform { + case attr: AttributeReference if attrMap.contains(attr) => attrMap(attr) + } + Filter(transformedCond, relation) + + case m: MergeIntoIcebergTable => + Join(relation, m.sourceTable, LeftSemi, Some(m.mergeCondition), JoinHint.NONE) + } + + // clone the original relation in the filtering plan and assign new expr IDs to avoid conflicts + matchingRowsPlan transformUpWithNewOutput { + case r: DataSourceV2Relation if r eq relation => + val oldOutput = r.output + val newOutput = oldOutput.map(_.newInstance()) + r.copy(output = newOutput) -> oldOutput.zip(newOutput) + } + } + + private def buildDynamicPruningCond( + matchingRowsPlan: LogicalPlan, + buildKeys: Seq[Attribute], + pruningKeys: Seq[Attribute]): Expression = { + + val buildQuery = Project(buildKeys, 
matchingRowsPlan) + val dynamicPruningSubqueries = pruningKeys.zipWithIndex.map { case (key, index) => + DynamicPruningSubquery(key, buildQuery, buildKeys, index, onlyInBroadcast = false) + } + dynamicPruningSubqueries.reduce(And) + } + + private def buildAttrMap( + tableAttrs: Seq[Attribute], + scanAttrs: Seq[Attribute]): AttributeMap[Attribute] = { + + val resolver = conf.resolver + val attrMapping = tableAttrs.flatMap { tableAttr => + scanAttrs + .find(scanAttr => resolver(scanAttr.name, tableAttr.name)) + .map(scanAttr => tableAttr -> scanAttr) + } + AttributeMap(attrMapping) + } + + // borrowed from OptimizeSubqueries in Spark + private def optimizeSubquery(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning( + _.containsPattern(PLAN_EXPRESSION)) { + case s: SubqueryExpression => + val Subquery(newPlan, _) = spark.sessionState.optimizer.execute(Subquery.fromExpression(s)) + // At this point we have an optimized subquery plan that we are going to attach + // to this subquery expression. Here we can safely remove any top level sort + // in the plan as tuples produced by a subquery are un-ordered. + s.withNewPlan(removeTopLevelSort(newPlan)) + } + + // borrowed from OptimizeSubqueries in Spark + private def removeTopLevelSort(plan: LogicalPlan): LogicalPlan = { + if (!plan.containsPattern(SORT)) { + return plan + } + plan match { + case Sort(_, _, child) => child + case Project(fields, child) => Project(fields, removeTopLevelSort(child)) + case other => other + } + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/Employee.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/Employee.java new file mode 100644 index 000000000000..8918dfec6584 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/Employee.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import java.util.Objects; + +public class Employee { + private Integer id; + private String dep; + + public Employee() {} + + public Employee(Integer id, String dep) { + this.id = id; + this.dep = dep; + } + + public Integer getId() { + return id; + } + + public void setId(Integer id) { + this.id = id; + } + + public String getDep() { + return dep; + } + + public void setDep(String dep) { + this.dep = dep; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } else if (other == null || getClass() != other.getClass()) { + return false; + } + + Employee employee = (Employee) other; + return Objects.equals(id, employee.id) && Objects.equals(dep, employee.dep); + } + + @Override + public int hashCode() { + return Objects.hash(id, dep); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkExtensionsTestBase.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkExtensionsTestBase.java new file mode 100644 index 000000000000..4f137f5b8dac --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkExtensionsTestBase.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS; + +import java.util.Map; +import java.util.Random; +import java.util.concurrent.ThreadLocalRandom; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.hive.HiveCatalog; +import org.apache.iceberg.hive.TestHiveMetastore; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.BeforeClass; + +public abstract class SparkExtensionsTestBase extends SparkCatalogTestBase { + + private static final Random RANDOM = ThreadLocalRandom.current(); + + public SparkExtensionsTestBase( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @BeforeClass + public static void startMetastoreAndSpark() { + SparkTestBase.metastore = new TestHiveMetastore(); + metastore.start(); + SparkTestBase.hiveConf = metastore.hiveConf(); + + SparkTestBase.spark = + SparkSession.builder() + .master("local[2]") + .config("spark.testing", "true") + .config(SQLConf.PARTITION_OVERWRITE_MODE().key(), "dynamic") + .config("spark.sql.extensions", IcebergSparkSessionExtensions.class.getName()) + .config("spark.hadoop." 
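+            // "spark.hadoop." entries are copied into the session's Hadoop conf, so this points
+            // Spark at the embedded TestHiveMetastore started above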
+ METASTOREURIS.varname, hiveConf.get(METASTOREURIS.varname)) + .config("spark.sql.shuffle.partitions", "4") + .config("spark.sql.hive.metastorePartitionPruningFallbackOnException", "true") + .config("spark.sql.legacy.respectNullabilityInTextDatasetConversion", "true") + .config( + SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), String.valueOf(RANDOM.nextBoolean())) + .enableHiveSupport() + .getOrCreate(); + + SparkTestBase.catalog = + (HiveCatalog) + CatalogUtil.loadCatalog( + HiveCatalog.class.getName(), "hive", ImmutableMap.of(), hiveConf); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java new file mode 100644 index 000000000000..0ca63ae2bfa2 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java @@ -0,0 +1,343 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.DataOperations.DELETE; +import static org.apache.iceberg.DataOperations.OVERWRITE; +import static org.apache.iceberg.SnapshotSummary.ADDED_DELETE_FILES_PROP; +import static org.apache.iceberg.SnapshotSummary.ADDED_FILES_PROP; +import static org.apache.iceberg.SnapshotSummary.CHANGED_PARTITION_COUNT_PROP; +import static org.apache.iceberg.SnapshotSummary.DELETED_FILES_PROP; +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.PARQUET_VECTORIZATION_ENABLED; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_HASH; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_NONE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_RANGE; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.Files; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Table; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.parquet.GenericParquetWriter; +import org.apache.iceberg.io.DataWriter; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.Assert; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +public abstract class SparkRowLevelOperationsTestBase extends SparkExtensionsTestBase { + + private static final Random RANDOM = ThreadLocalRandom.current(); + + protected final String fileFormat; + protected final boolean vectorized; + protected final String distributionMode; + protected final String branch; + + public SparkRowLevelOperationsTestBase( + String catalogName, + String implementation, + Map config, + String fileFormat, + boolean vectorized, + String distributionMode, + String branch) { + super(catalogName, implementation, config); + this.fileFormat = fileFormat; + this.vectorized = vectorized; + this.distributionMode = distributionMode; + this.branch = branch; + } + + @Parameters( + name = + "catalogName = {0}, implementation = {1}, config = {2}," + + " format = {3}, vectorized = {4}, distributionMode = {5}, branch = {6}") + public static Object[][] parameters() { + return new Object[][] { + { + "testhive", + SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default"), + "orc", + true, + WRITE_DISTRIBUTION_MODE_NONE, + SnapshotRef.MAIN_BRANCH + }, + { + "testhive", + SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default"), + "parquet", + true, + WRITE_DISTRIBUTION_MODE_NONE, + null, + }, + { + "testhadoop", + 
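+        // Hadoop catalog, parquet format, randomly vectorized, hash distribution, no test branch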
SparkCatalog.class.getName(), + ImmutableMap.of("type", "hadoop"), + "parquet", + RANDOM.nextBoolean(), + WRITE_DISTRIBUTION_MODE_HASH, + null + }, + { + "spark_catalog", + SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "clients", "1", + "parquet-enabled", "false", + "cache-enabled", + "false" // Spark will delete tables using v1, leaving the cache out of sync + ), + "avro", + false, + WRITE_DISTRIBUTION_MODE_RANGE, + "test" + } + }; + } + + protected abstract Map extraTableProperties(); + + protected void initTable() { + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, DEFAULT_FILE_FORMAT, fileFormat); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, WRITE_DISTRIBUTION_MODE, distributionMode); + + switch (fileFormat) { + case "parquet": + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%b')", + tableName, PARQUET_VECTORIZATION_ENABLED, vectorized); + break; + case "orc": + Assert.assertTrue(vectorized); + break; + case "avro": + Assert.assertFalse(vectorized); + break; + } + + Map props = extraTableProperties(); + props.forEach( + (prop, value) -> { + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, prop, value); + }); + } + + protected void createAndInitTable(String schema) { + createAndInitTable(schema, null); + } + + protected void createAndInitTable(String schema, String jsonData) { + createAndInitTable(schema, "", jsonData); + } + + protected void createAndInitTable(String schema, String partitioning, String jsonData) { + sql("CREATE TABLE %s (%s) USING iceberg %s", tableName, schema, partitioning); + initTable(); + + if (jsonData != null) { + try { + Dataset ds = toDS(schema, jsonData); + ds.coalesce(1).writeTo(tableName).append(); + createBranchIfNeeded(); + } catch (NoSuchTableException e) { + throw new RuntimeException("Failed to write data", e); + } + } + } + + protected void append(String table, String jsonData) { + append(table, null, jsonData); + } + + protected void append(String table, String schema, String jsonData) { + try { + Dataset ds = toDS(schema, jsonData); + ds.coalesce(1).writeTo(table).append(); + } catch (NoSuchTableException e) { + throw new RuntimeException("Failed to write data", e); + } + } + + protected void createOrReplaceView(String name, String jsonData) { + createOrReplaceView(name, null, jsonData); + } + + protected void createOrReplaceView(String name, String schema, String jsonData) { + Dataset ds = toDS(schema, jsonData); + ds.createOrReplaceTempView(name); + } + + protected void createOrReplaceView(String name, List data, Encoder encoder) { + spark.createDataset(data, encoder).createOrReplaceTempView(name); + } + + private Dataset toDS(String schema, String jsonData) { + List jsonRows = + Arrays.stream(jsonData.split("\n")) + .filter(str -> str.trim().length() > 0) + .collect(Collectors.toList()); + Dataset jsonDS = spark.createDataset(jsonRows, Encoders.STRING()); + + if (schema != null) { + return spark.read().schema(schema).json(jsonDS); + } else { + return spark.read().json(jsonDS); + } + } + + protected void validateDelete( + Snapshot snapshot, String changedPartitionCount, String deletedDataFiles) { + validateSnapshot(snapshot, DELETE, changedPartitionCount, deletedDataFiles, null, null); + } + + protected void validateCopyOnWrite( + Snapshot snapshot, + String changedPartitionCount, + String deletedDataFiles, + String addedDataFiles) { + validateSnapshot( + snapshot, OVERWRITE, changedPartitionCount, deletedDataFiles, null, 
addedDataFiles); + } + + protected void validateMergeOnRead( + Snapshot snapshot, + String changedPartitionCount, + String addedDeleteFiles, + String addedDataFiles) { + validateSnapshot( + snapshot, OVERWRITE, changedPartitionCount, null, addedDeleteFiles, addedDataFiles); + } + + protected void validateSnapshot( + Snapshot snapshot, + String operation, + String changedPartitionCount, + String deletedDataFiles, + String addedDeleteFiles, + String addedDataFiles) { + Assert.assertEquals("Operation must match", operation, snapshot.operation()); + validateProperty(snapshot, CHANGED_PARTITION_COUNT_PROP, changedPartitionCount); + validateProperty(snapshot, DELETED_FILES_PROP, deletedDataFiles); + validateProperty(snapshot, ADDED_DELETE_FILES_PROP, addedDeleteFiles); + validateProperty(snapshot, ADDED_FILES_PROP, addedDataFiles); + } + + protected void validateProperty(Snapshot snapshot, String property, Set expectedValues) { + String actual = snapshot.summary().get(property); + Assert.assertTrue( + "Snapshot property " + + property + + " has unexpected value, actual = " + + actual + + ", expected one of : " + + String.join(",", expectedValues), + expectedValues.contains(actual)); + } + + protected void validateProperty(Snapshot snapshot, String property, String expectedValue) { + String actual = snapshot.summary().get(property); + Assert.assertEquals( + "Snapshot property " + property + " has unexpected value.", expectedValue, actual); + } + + protected void sleep(long millis) { + try { + Thread.sleep(millis); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + protected DataFile writeDataFile(Table table, List records) { + try { + OutputFile file = Files.localOutput(temp.newFile()); + + DataWriter dataWriter = + Parquet.writeData(file) + .forTable(table) + .createWriterFunc(GenericParquetWriter::buildWriter) + .overwrite() + .build(); + + try { + for (GenericRecord record : records) { + dataWriter.write(record); + } + } finally { + dataWriter.close(); + } + + return dataWriter.toDataFile(); + + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + protected String commitTarget() { + return branch == null ? tableName : String.format("%s.branch_%s", tableName, branch); + } + + @Override + protected String selectTarget() { + return branch == null ? tableName : String.format("%s VERSION AS OF '%s'", tableName, branch); + } + + protected void createBranchIfNeeded() { + if (branch != null && !branch.equals(SnapshotRef.MAIN_BRANCH)) { + sql("ALTER TABLE %s CREATE BRANCH %s", tableName, branch); + } + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java new file mode 100644 index 000000000000..e02065bd6347 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java @@ -0,0 +1,1138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.DatumWriter; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.joda.time.DateTime; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestAddFilesProcedure extends SparkExtensionsTestBase { + + private final String sourceTableName = "source_table"; + private File fileTableDir; + + public TestAddFilesProcedure( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + @Before + public void setupTempDirs() { + try { + fileTableDir = temp.newFolder(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @After + public void dropTables() { + sql("DROP TABLE IF EXISTS %s", sourceTableName); + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void addDataUnpartitioned() { + createUnpartitionedFileTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT * FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void deleteAndAddBackUnpartitioned() { + createUnpartitionedFileTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); + + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + String deleteData = "DELETE FROM %s"; + sql(deleteData, tableName); + + 
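+    // after the DELETE, the imported files are no longer referenced by the table,
+    // so re-importing the same directory is expected to succeed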
List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT * FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Ignore // TODO Classpath issues prevent us from actually writing to a Spark ORC table + public void addDataUnpartitionedOrc() { + createUnpartitionedFileTable("orc"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); + + Object result = + scalarSql( + "CALL %s.system.add_files('%s', '`orc`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + Assert.assertEquals(2L, result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT * FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void addAvroFile() throws Exception { + // Spark Session Catalog cannot load metadata tables + // with "The namespace in session catalog must have exactly one name part" + Assume.assumeFalse(catalogName.equals("spark_catalog")); + + // Create an Avro file + + Schema schema = + SchemaBuilder.record("record") + .fields() + .requiredInt("id") + .requiredString("data") + .endRecord(); + GenericRecord record1 = new GenericData.Record(schema); + record1.put("id", 1L); + record1.put("data", "a"); + GenericRecord record2 = new GenericData.Record(schema); + record2.put("id", 2L); + record2.put("data", "b"); + File outputFile = temp.newFile("test.avro"); + + DatumWriter datumWriter = new GenericDatumWriter(schema); + DataFileWriter dataFileWriter = new DataFileWriter(datumWriter); + dataFileWriter.create(schema, outputFile); + dataFileWriter.append(record1); + dataFileWriter.append(record2); + dataFileWriter.close(); + + String createIceberg = "CREATE TABLE %s (id Long, data String) USING iceberg"; + sql(createIceberg, tableName); + + List result = + sql( + "CALL %s.system.add_files('%s', '`avro`.`%s`')", + catalogName, tableName, outputFile.getPath()); + assertEquals("Procedure output must match", ImmutableList.of(row(1L, 1L)), result); + + List expected = Lists.newArrayList(new Object[] {1L, "a"}, new Object[] {2L, "b"}); + + assertEquals( + "Iceberg table contains correct data", + expected, + sql("SELECT * FROM %s ORDER BY id", tableName)); + + List actualRecordCount = + sql("select %s from %s.files", DataFile.RECORD_COUNT.name(), tableName); + List expectedRecordCount = Lists.newArrayList(); + expectedRecordCount.add(new Object[] {2L}); + assertEquals( + "Iceberg file metadata should have correct metadata count", + expectedRecordCount, + actualRecordCount); + } + + // TODO Adding spark-avro doesn't work in tests + @Ignore + public void addDataUnpartitionedAvro() { + createUnpartitionedFileTable("avro"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); + + Object result = + scalarSql( + "CALL %s.system.add_files('%s', '`avro`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + Assert.assertEquals(2L, result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT * FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void 
addDataUnpartitionedHive() { + createUnpartitionedHiveTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); + + List result = + sql("CALL %s.system.add_files('%s', '%s')", catalogName, tableName, sourceTableName); + + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT * FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void addDataUnpartitionedExtraCol() { + createUnpartitionedFileTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String, foo string) USING iceberg"; + + sql(createIceberg, tableName); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT * FROM %s ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Test + public void addDataUnpartitionedMissingCol() { + createUnpartitionedFileTable("parquet"); + + String createIceberg = "CREATE TABLE %s (id Integer, name String, dept String) USING iceberg"; + + sql(createIceberg, tableName); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void addDataPartitionedMissingCol() { + createPartitionedFileTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertEquals("Procedure output must match", ImmutableList.of(row(8L, 4L)), result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void addDataPartitioned() { + createPartitionedFileTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertEquals("Procedure output must match", ImmutableList.of(row(8L, 4L)), result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Ignore // TODO Classpath issues prevent us from actually writing to a Spark ORC table + public void addDataPartitionedOrc() { + createPartitionedFileTable("orc"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg 
PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + Object result = + scalarSql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + Assert.assertEquals(8L, result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + // TODO Adding spark-avro doesn't work in tests + @Ignore + public void addDataPartitionedAvro() { + createPartitionedFileTable("avro"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + Object result = + scalarSql( + "CALL %s.system.add_files('%s', '`avro`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + Assert.assertEquals(8L, result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Test + public void addDataPartitionedHive() { + createPartitionedHiveTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + List result = + sql("CALL %s.system.add_files('%s', '%s')", catalogName, tableName, sourceTableName); + + assertEquals("Procedure output must match", ImmutableList.of(row(8L, 4L)), result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Test + public void addPartitionToPartitioned() { + createPartitionedFileTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Test + public void deleteAndAddBackPartitioned() { + createPartitionedFileTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + String deleteData = "DELETE FROM %s where id = 1"; + sql(deleteData, tableName); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void 
addPartitionToPartitionedSnapshotIdInheritanceEnabledInTwoRuns() { + createPartitionedFileTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)" + + "TBLPROPERTIES ('%s'='true')"; + + sql(createIceberg, tableName, TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED); + + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 2))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id < 3 ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + + // verify manifest file name has uuid pattern + String manifestPath = (String) sql("select path from %s.manifests", tableName).get(0)[0]; + + Pattern uuidPattern = Pattern.compile("[a-f0-9]{8}(?:-[a-f0-9]{4}){4}[a-f0-9]{8}"); + + Matcher matcher = uuidPattern.matcher(manifestPath); + Assert.assertTrue("verify manifest path has uuid", matcher.find()); + } + + @Test + public void addDataPartitionedByDateToPartitioned() { + createDatePartitionedFileTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, date Date) USING iceberg PARTITIONED BY (date)"; + + sql(createIceberg, tableName); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('date', '2021-01-01'))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, date FROM %s WHERE date = '2021-01-01' ORDER BY id", sourceTableName), + sql("SELECT id, name, date FROM %s ORDER BY id", tableName)); + } + + @Test + public void addDataPartitionedVerifyPartitionTypeInferredCorrectly() { + createTableWithTwoPartitions("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, date Date, dept String) USING iceberg PARTITIONED BY (date, dept)"; + + sql(createIceberg, tableName); + + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('date', '2021-01-01'))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + String sqlFormat = + "SELECT id, name, dept, date FROM %s WHERE date = '2021-01-01' and dept= '01' ORDER BY id"; + assertEquals( + "Iceberg table contains correct data", + sql(sqlFormat, sourceTableName), + sql(sqlFormat, tableName)); + } + + @Test + public void addFilteredPartitionsToPartitioned() { + createCompositePartitionedTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg " + + "PARTITIONED BY (id, dept)"; + + sql(createIceberg, tableName); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Test + public void addFilteredPartitionsToPartitioned2() { + createCompositePartitionedTable("parquet"); + + String createIceberg = + 
"CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg " + + "PARTITIONED BY (id, dept)"; + + sql(createIceberg, tableName); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('dept', 'hr'))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertEquals("Procedure output must match", ImmutableList.of(row(6L, 3L)), result); + + assertEquals( + "Iceberg table contains correct data", + sql( + "SELECT id, name, dept, subdept FROM %s WHERE dept = 'hr' ORDER BY id", + sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Test + public void addFilteredPartitionsToPartitionedWithNullValueFilteringOnId() { + createCompositePartitionedTableWithNullValueInPartitionColumn("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg " + + "PARTITIONED BY (id, dept)"; + + sql(createIceberg, tableName); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Test + public void addFilteredPartitionsToPartitionedWithNullValueFilteringOnDept() { + createCompositePartitionedTableWithNullValueInPartitionColumn("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg " + + "PARTITIONED BY (id, dept)"; + + sql(createIceberg, tableName); + + List result = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('dept', 'hr'))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + assertEquals("Procedure output must match", ImmutableList.of(row(6L, 3L)), result); + + assertEquals( + "Iceberg table contains correct data", + sql( + "SELECT id, name, dept, subdept FROM %s WHERE dept = 'hr' ORDER BY id", + sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Test + public void addWeirdCaseHiveTable() { + createWeirdCaseTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, `naMe` String, dept String, subdept String) USING iceberg " + + "PARTITIONED BY (`naMe`)"; + + sql(createIceberg, tableName); + + List result = + sql( + "CALL %s.system.add_files('%s', '%s', map('naMe', 'John Doe'))", + catalogName, tableName, sourceTableName); + + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result); + + /* + While we would like to use + SELECT id, `naMe`, dept, subdept FROM %s WHERE `naMe` = 'John Doe' ORDER BY id + Spark does not actually handle this pushdown correctly for hive based tables and it returns 0 records + */ + List expected = + sql("SELECT id, `naMe`, dept, subdept from %s ORDER BY id", sourceTableName).stream() + .filter(r -> r[1].equals("John Doe")) + .collect(Collectors.toList()); + + // TODO when this assert breaks Spark fixed the pushdown issue + Assert.assertEquals( + "If this assert breaks it means that Spark has fixed the pushdown issue", + 0, + sql( + "SELECT id, `naMe`, dept, subdept from %s WHERE `naMe` = 'John Doe' ORDER BY id", + sourceTableName) + .size()); + + // Pushdown works for iceberg + Assert.assertEquals( + "We should be able to 
pushdown mixed case partition keys", + 2, + sql( + "SELECT id, `naMe`, dept, subdept FROM %s WHERE `naMe` = 'John Doe' ORDER BY id", + tableName) + .size()); + + assertEquals( + "Iceberg table contains correct data", + expected, + sql("SELECT id, `naMe`, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Test + public void addPartitionToPartitionedHive() { + createPartitionedHiveTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + List result = + sql( + "CALL %s.system.add_files('%s', '%s', map('id', 1))", + catalogName, tableName, sourceTableName); + + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + + @Test + public void invalidDataImport() { + createPartitionedFileTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); + + AssertHelpers.assertThrows( + "Should forbid adding of partitioned data to unpartitioned table", + IllegalArgumentException.class, + "Cannot use partition filter with an unpartitioned table", + () -> + scalarSql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))", + catalogName, tableName, fileTableDir.getAbsolutePath())); + + AssertHelpers.assertThrows( + "Should forbid adding of partitioned data to unpartitioned table", + IllegalArgumentException.class, + "Cannot add partitioned files to an unpartitioned table", + () -> + scalarSql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath())); + } + + @Test + public void invalidDataImportPartitioned() { + createUnpartitionedFileTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + AssertHelpers.assertThrows( + "Should forbid adding with a mismatching partition spec", + IllegalArgumentException.class, + "is greater than the number of partitioned columns", + () -> + scalarSql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('x', '1', 'y', '2'))", + catalogName, tableName, fileTableDir.getAbsolutePath())); + + AssertHelpers.assertThrows( + "Should forbid adding with partition spec with incorrect columns", + IllegalArgumentException.class, + "specified partition filter refers to columns that are not partitioned", + () -> + scalarSql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`', map('dept', '2'))", + catalogName, tableName, fileTableDir.getAbsolutePath())); + } + + @Test + public void addTwice() { + createPartitionedHiveTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + List result1 = + sql( + "CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + "partition_filter => map('id', 1))", + catalogName, tableName, sourceTableName); + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result1); + + List result2 = + sql( + "CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + 
"partition_filter => map('id', 2))", + catalogName, tableName, sourceTableName); + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result2); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", tableName)); + assertEquals( + "Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 2 ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s WHERE id = 2 ORDER BY id", tableName)); + } + + @Test + public void duplicateDataPartitioned() { + createPartitionedHiveTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + sql( + "CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + "partition_filter => map('id', 1))", + catalogName, tableName, sourceTableName); + + AssertHelpers.assertThrows( + "Should not allow adding duplicate files", + IllegalStateException.class, + "Cannot complete import because data files to be imported already" + + " exist within the target table", + () -> + scalarSql( + "CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + "partition_filter => map('id', 1))", + catalogName, tableName, sourceTableName)); + } + + @Test + public void duplicateDataPartitionedAllowed() { + createPartitionedHiveTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + List result1 = + sql( + "CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + "partition_filter => map('id', 1))", + catalogName, tableName, sourceTableName); + + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result1); + + List result2 = + sql( + "CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + "partition_filter => map('id', 1)," + + "check_duplicate_files => false)", + catalogName, tableName, sourceTableName); + + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result2); + + assertEquals( + "Iceberg table contains correct data", + sql( + "SELECT id, name, dept, subdept FROM %s WHERE id = 1 UNION ALL " + + "SELECT id, name, dept, subdept FROM %s WHERE id = 1", + sourceTableName, sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s", tableName, tableName)); + } + + @Test + public void duplicateDataUnpartitioned() { + createUnpartitionedHiveTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); + + sql("CALL %s.system.add_files('%s', '%s')", catalogName, tableName, sourceTableName); + + AssertHelpers.assertThrows( + "Should not allow adding duplicate files", + IllegalStateException.class, + "Cannot complete import because data files to be imported already" + + " exist within the target table", + () -> + scalarSql( + "CALL %s.system.add_files('%s', '%s')", catalogName, tableName, sourceTableName)); + } + + @Test + public void duplicateDataUnpartitionedAllowed() { + createUnpartitionedHiveTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); 
+ + List result1 = + sql("CALL %s.system.add_files('%s', '%s')", catalogName, tableName, sourceTableName); + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result1); + + List result2 = + sql( + "CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s'," + + "check_duplicate_files => false)", + catalogName, tableName, sourceTableName); + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result2); + + assertEquals( + "Iceberg table contains correct data", + sql( + "SELECT * FROM (SELECT * FROM %s UNION ALL " + "SELECT * from %s) ORDER BY id", + sourceTableName, sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testEmptyImportDoesNotThrow() { + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + sql(createIceberg, tableName); + + // Empty path based import + List pathResult = + sql( + "CALL %s.system.add_files('%s', '`parquet`.`%s`')", + catalogName, tableName, fileTableDir.getAbsolutePath()); + assertEquals("Procedure output must match", ImmutableList.of(row(0L, 0L)), pathResult); + assertEquals( + "Iceberg table contains no added data when importing from an empty path", + emptyQueryResult, + sql("SELECT * FROM %s ORDER BY id", tableName)); + + // Empty table based import + String createHive = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) STORED AS parquet"; + sql(createHive, sourceTableName); + + List tableResult = + sql("CALL %s.system.add_files('%s', '%s')", catalogName, tableName, sourceTableName); + assertEquals("Procedure output must match", ImmutableList.of(row(0L, 0L)), tableResult); + assertEquals( + "Iceberg table contains no added data when importing from an empty table", + emptyQueryResult, + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testPartitionedImportFromEmptyPartitionDoesNotThrow() { + createPartitionedHiveTable(); + + final int emptyPartitionId = 999; + // Add an empty partition to the hive table + sql( + "ALTER TABLE %s ADD PARTITION (id = '%d') LOCATION '%d'", + sourceTableName, emptyPartitionId, emptyPartitionId); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)"; + + sql(createIceberg, tableName); + + List tableResult = + sql( + "CALL %s.system.add_files(" + + "table => '%s', " + + "source_table => '%s', " + + "partition_filter => map('id', %d))", + catalogName, tableName, sourceTableName, emptyPartitionId); + + assertEquals("Procedure output must match", ImmutableList.of(row(0L, 0L)), tableResult); + assertEquals( + "Iceberg table contains no added data when importing from an empty table", + emptyQueryResult, + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + private static final List emptyQueryResult = Lists.newArrayList(); + + private static final StructField[] struct = { + new StructField("id", DataTypes.IntegerType, true, Metadata.empty()), + new StructField("name", DataTypes.StringType, true, Metadata.empty()), + new StructField("dept", DataTypes.StringType, true, Metadata.empty()), + new StructField("subdept", DataTypes.StringType, true, Metadata.empty()) + }; + + private static final Dataset unpartitionedDF = + spark + .createDataFrame( + ImmutableList.of( + RowFactory.create(1, "John Doe", "hr", "communications"), + RowFactory.create(2, "Jane Doe", "hr", "salary"), + RowFactory.create(3, "Matt Doe", "hr", "communications"), + 
RowFactory.create(4, "Will Doe", "facilities", "all")), + new StructType(struct)) + .repartition(1); + + private static final Dataset singleNullRecordDF = + spark + .createDataFrame( + ImmutableList.of(RowFactory.create(null, null, null, null)), new StructType(struct)) + .repartition(1); + + private static final Dataset partitionedDF = + unpartitionedDF.select("name", "dept", "subdept", "id"); + + private static final Dataset compositePartitionedDF = + unpartitionedDF.select("name", "subdept", "id", "dept"); + + private static final Dataset compositePartitionedNullRecordDF = + singleNullRecordDF.select("name", "subdept", "id", "dept"); + + private static final Dataset weirdColumnNamesDF = + unpartitionedDF.select( + unpartitionedDF.col("id"), + unpartitionedDF.col("subdept"), + unpartitionedDF.col("dept"), + unpartitionedDF.col("name").as("naMe")); + + private static final StructField[] dateStruct = { + new StructField("id", DataTypes.IntegerType, true, Metadata.empty()), + new StructField("name", DataTypes.StringType, true, Metadata.empty()), + new StructField("ts", DataTypes.DateType, true, Metadata.empty()), + new StructField("dept", DataTypes.StringType, true, Metadata.empty()), + }; + + private static java.sql.Date toDate(String value) { + return new java.sql.Date(DateTime.parse(value).getMillis()); + } + + private static final Dataset dateDF = + spark + .createDataFrame( + ImmutableList.of( + RowFactory.create(1, "John Doe", toDate("2021-01-01"), "01"), + RowFactory.create(2, "Jane Doe", toDate("2021-01-01"), "01"), + RowFactory.create(3, "Matt Doe", toDate("2021-01-02"), "02"), + RowFactory.create(4, "Will Doe", toDate("2021-01-02"), "02")), + new StructType(dateStruct)) + .repartition(2); + + private void createUnpartitionedFileTable(String format) { + String createParquet = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING %s LOCATION '%s'"; + + sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath()); + unpartitionedDF.write().insertInto(sourceTableName); + unpartitionedDF.write().insertInto(sourceTableName); + } + + private void createPartitionedFileTable(String format) { + String createParquet = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING %s PARTITIONED BY (id) " + + "LOCATION '%s'"; + + sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath()); + + partitionedDF.write().insertInto(sourceTableName); + partitionedDF.write().insertInto(sourceTableName); + } + + private void createCompositePartitionedTable(String format) { + String createParquet = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING %s " + + "PARTITIONED BY (id, dept) LOCATION '%s'"; + sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath()); + + compositePartitionedDF.write().insertInto(sourceTableName); + compositePartitionedDF.write().insertInto(sourceTableName); + } + + private void createCompositePartitionedTableWithNullValueInPartitionColumn(String format) { + String createParquet = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING %s " + + "PARTITIONED BY (id, dept) LOCATION '%s'"; + sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath()); + + Dataset unionedDF = + compositePartitionedDF + .unionAll(compositePartitionedNullRecordDF) + .select("name", "subdept", "id", "dept") + .repartition(1); + + unionedDF.write().insertInto(sourceTableName); + unionedDF.write().insertInto(sourceTableName); + } + + private void 
createWeirdCaseTable() { + String createParquet = + "CREATE TABLE %s (id Integer, subdept String, dept String) " + + "PARTITIONED BY (`naMe` String) STORED AS parquet"; + + sql(createParquet, sourceTableName); + + weirdColumnNamesDF.write().insertInto(sourceTableName); + weirdColumnNamesDF.write().insertInto(sourceTableName); + } + + private void createUnpartitionedHiveTable() { + String createHive = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) STORED AS parquet"; + + sql(createHive, sourceTableName); + + unpartitionedDF.write().insertInto(sourceTableName); + unpartitionedDF.write().insertInto(sourceTableName); + } + + private void createPartitionedHiveTable() { + String createHive = + "CREATE TABLE %s (name String, dept String, subdept String) " + + "PARTITIONED BY (id Integer) STORED AS parquet"; + + sql(createHive, sourceTableName); + + partitionedDF.write().insertInto(sourceTableName); + partitionedDF.write().insertInto(sourceTableName); + } + + private void createDatePartitionedFileTable(String format) { + String createParquet = + "CREATE TABLE %s (id Integer, name String, date Date) USING %s " + + "PARTITIONED BY (date) LOCATION '%s'"; + + sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath()); + + dateDF.select("id", "name", "ts").write().insertInto(sourceTableName); + } + + private void createTableWithTwoPartitions(String format) { + String createParquet = + "CREATE TABLE %s (id Integer, name String, date Date, dept String) USING %s " + + "PARTITIONED BY (date, dept) LOCATION '%s'"; + + sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath()); + + dateDF.write().insertInto(sourceTableName); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTablePartitionFields.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTablePartitionFields.java new file mode 100644 index 000000000000..a4928e6552b8 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTablePartitionFields.java @@ -0,0 +1,462 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import java.util.Map; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Table; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.spark.sql.connector.catalog.CatalogManager; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestAlterTablePartitionFields extends SparkExtensionsTestBase { + public TestAlterTablePartitionFields( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testAddIdentityPartition() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD category", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).identity("category").build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testAddBucketPartition() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, id)", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .bucket("id", 16, "id_bucket_16") + .build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testAddTruncatePartition() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD truncate(data, 4)", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .truncate("data", 4, "data_trunc_4") + .build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testAddYearsPartition() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD years(ts)", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).year("ts").build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testAddMonthsPartition() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertTrue("Table should start unpartitioned", 
table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD months(ts)", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).month("ts").build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testAddDaysPartition() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD days(ts)", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).day("ts").build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testAddHoursPartition() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD hours(ts)", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).hour("ts").build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testAddNamedPartition() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, id) AS shard", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).bucket("id", 16, "shard").build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testDropIdentityPartition() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg PARTITIONED BY (category)", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals( + "Table should start with 1 partition field", 1, table.spec().fields().size()); + + sql("ALTER TABLE %s DROP PARTITION FIELD category", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .alwaysNull("category", "category") + .build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testDropDaysPartition() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, ts timestamp, data string) USING iceberg PARTITIONED BY (days(ts))", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals( + "Table should start with 1 partition field", 1, table.spec().fields().size()); + + sql("ALTER TABLE %s DROP PARTITION FIELD days(ts)", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).alwaysNull("ts", "ts_day").build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testDropBucketPartition() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg 
PARTITIONED BY (bucket(16, id))", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals( + "Table should start with 1 partition field", 1, table.spec().fields().size()); + + sql("ALTER TABLE %s DROP PARTITION FIELD bucket(16, id)", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()) + .withSpecId(1) + .alwaysNull("id", "id_bucket") + .build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testDropPartitionByName() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, id) AS shard", tableName); + + table.refresh(); + + Assert.assertEquals("Table should have 1 partition field", 1, table.spec().fields().size()); + + // Should be recognized as iceberg command even with extra white spaces + sql("ALTER TABLE %s DROP PARTITION \n FIELD shard", tableName); + + table.refresh(); + + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(2).alwaysNull("id", "shard").build(); + + Assert.assertEquals("Should have new spec field", expected, table.spec()); + } + + @Test + public void testReplacePartition() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD days(ts)", tableName); + table.refresh(); + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).day("ts").build(); + Assert.assertEquals("Should have new spec field", expected, table.spec()); + + sql("ALTER TABLE %s REPLACE PARTITION FIELD days(ts) WITH hours(ts)", tableName); + table.refresh(); + expected = + PartitionSpec.builderFor(table.schema()) + .withSpecId(2) + .alwaysNull("ts", "ts_day") + .hour("ts") + .build(); + Assert.assertEquals( + "Should changed from daily to hourly partitioned field", expected, table.spec()); + } + + @Test + public void testReplacePartitionAndRename() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD days(ts)", tableName); + table.refresh(); + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).day("ts").build(); + Assert.assertEquals("Should have new spec field", expected, table.spec()); + + sql("ALTER TABLE %s REPLACE PARTITION FIELD days(ts) WITH hours(ts) AS hour_col", tableName); + table.refresh(); + expected = + PartitionSpec.builderFor(table.schema()) + .withSpecId(2) + .alwaysNull("ts", "ts_day") + .hour("ts", "hour_col") + .build(); + Assert.assertEquals( + "Should changed from daily to hourly partitioned field", expected, table.spec()); + } + + @Test + public void testReplaceNamedPartition() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + 
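+    // Replacing the named field day_col with hours(ts) should void day_col and add an hourly transform in the new spec.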
Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD days(ts) AS day_col", tableName); + table.refresh(); + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).day("ts", "day_col").build(); + Assert.assertEquals("Should have new spec field", expected, table.spec()); + + sql("ALTER TABLE %s REPLACE PARTITION FIELD day_col WITH hours(ts)", tableName); + table.refresh(); + expected = + PartitionSpec.builderFor(table.schema()) + .withSpecId(2) + .alwaysNull("ts", "day_col") + .hour("ts") + .build(); + Assert.assertEquals( + "Should changed from daily to hourly partitioned field", expected, table.spec()); + } + + @Test + public void testReplaceNamedPartitionAndRenameDifferently() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned()); + + sql("ALTER TABLE %s ADD PARTITION FIELD days(ts) AS day_col", tableName); + table.refresh(); + PartitionSpec expected = + PartitionSpec.builderFor(table.schema()).withSpecId(1).day("ts", "day_col").build(); + Assert.assertEquals("Should have new spec field", expected, table.spec()); + + sql("ALTER TABLE %s REPLACE PARTITION FIELD day_col WITH hours(ts) AS hour_col", tableName); + table.refresh(); + expected = + PartitionSpec.builderFor(table.schema()) + .withSpecId(2) + .alwaysNull("ts", "day_col") + .hour("ts", "hour_col") + .build(); + Assert.assertEquals( + "Should changed from daily to hourly partitioned field", expected, table.spec()); + } + + @Test + public void testSparkTableAddDropPartitions() throws Exception { + sql("CREATE TABLE %s (id bigint NOT NULL, ts timestamp, data string) USING iceberg", tableName); + Assert.assertEquals( + "spark table partition should be empty", 0, sparkTable().partitioning().length); + + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, id) AS shard", tableName); + assertPartitioningEquals(sparkTable(), 1, "bucket(16, id)"); + + sql("ALTER TABLE %s ADD PARTITION FIELD truncate(data, 4)", tableName); + assertPartitioningEquals(sparkTable(), 2, "truncate(4, data)"); + + sql("ALTER TABLE %s ADD PARTITION FIELD years(ts)", tableName); + assertPartitioningEquals(sparkTable(), 3, "years(ts)"); + + sql("ALTER TABLE %s DROP PARTITION FIELD years(ts)", tableName); + assertPartitioningEquals(sparkTable(), 2, "truncate(4, data)"); + + sql("ALTER TABLE %s DROP PARTITION FIELD truncate(4, data)", tableName); + assertPartitioningEquals(sparkTable(), 1, "bucket(16, id)"); + + sql("ALTER TABLE %s DROP PARTITION FIELD shard", tableName); + sql("DESCRIBE %s", tableName); + Assert.assertEquals( + "spark table partition should be empty", 0, sparkTable().partitioning().length); + } + + @Test + public void testDropColumnOfOldPartitionFieldV1() { + // default table created in v1 format + sql( + "CREATE TABLE %s (id bigint NOT NULL, ts timestamp, day_of_ts date) USING iceberg PARTITIONED BY (day_of_ts) TBLPROPERTIES('format-version' = '1')", + tableName); + + sql("ALTER TABLE %s REPLACE PARTITION FIELD day_of_ts WITH days(ts)", tableName); + + sql("ALTER TABLE %s DROP COLUMN day_of_ts", tableName); + } + + @Test + public void testDropColumnOfOldPartitionFieldV2() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, ts timestamp, day_of_ts date) USING iceberg PARTITIONED BY (day_of_ts) TBLPROPERTIES('format-version' = '2')", + 
tableName); + + sql("ALTER TABLE %s REPLACE PARTITION FIELD day_of_ts WITH days(ts)", tableName); + + sql("ALTER TABLE %s DROP COLUMN day_of_ts", tableName); + } + + private void assertPartitioningEquals(SparkTable table, int len, String transform) { + Assert.assertEquals("spark table partition should be " + len, len, table.partitioning().length); + Assert.assertEquals( + "latest spark table partition transform should match", + transform, + table.partitioning()[len - 1].toString()); + } + + private SparkTable sparkTable() throws Exception { + validationCatalog.loadTable(tableIdent).refresh(); + CatalogManager catalogManager = spark.sessionState().catalogManager(); + TableCatalog catalog = (TableCatalog) catalogManager.catalog(catalogName); + Identifier identifier = Identifier.of(tableIdent.namespace().levels(), tableIdent.name()); + return (SparkTable) catalog.loadTable(identifier); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTableSchema.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTableSchema.java new file mode 100644 index 000000000000..c993c213dc5e --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTableSchema.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+package org.apache.iceberg.spark.extensions;
+
+import java.util.Map;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.relocated.com.google.common.collect.Sets;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestAlterTableSchema extends SparkExtensionsTestBase {
+  public TestAlterTableSchema(
+      String catalogName, String implementation, Map<String, String> config) {
+    super(catalogName, implementation, config);
+  }
+
+  @After
+  public void removeTable() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+  }
+
+  @Test
+  public void testSetIdentifierFields() {
+    sql(
+        "CREATE TABLE %s (id bigint NOT NULL, "
+            + "location struct<lon:bigint,lat:bigint> NOT NULL) USING iceberg",
+        tableName);
+    Table table = validationCatalog.loadTable(tableIdent);
+    Assert.assertTrue(
+        "Table should start without identifier", table.schema().identifierFieldIds().isEmpty());
+
+    sql("ALTER TABLE %s SET IDENTIFIER FIELDS id", tableName);
+    table.refresh();
+    Assert.assertEquals(
+        "Should have new identifier field",
+        Sets.newHashSet(table.schema().findField("id").fieldId()),
+        table.schema().identifierFieldIds());
+
+    sql("ALTER TABLE %s SET IDENTIFIER FIELDS id, location.lon", tableName);
+    table.refresh();
+    Assert.assertEquals(
+        "Should have new identifier field",
+        Sets.newHashSet(
+            table.schema().findField("id").fieldId(),
+            table.schema().findField("location.lon").fieldId()),
+        table.schema().identifierFieldIds());
+
+    sql("ALTER TABLE %s SET IDENTIFIER FIELDS location.lon", tableName);
+    table.refresh();
+    Assert.assertEquals(
+        "Should have new identifier field",
+        Sets.newHashSet(table.schema().findField("location.lon").fieldId()),
+        table.schema().identifierFieldIds());
+  }
+
+  @Test
+  public void testSetInvalidIdentifierFields() {
+    sql("CREATE TABLE %s (id bigint NOT NULL, id2 bigint) USING iceberg", tableName);
+    Table table = validationCatalog.loadTable(tableIdent);
+    Assert.assertTrue(
+        "Table should start without identifier", table.schema().identifierFieldIds().isEmpty());
+    AssertHelpers.assertThrows(
+        "should not allow setting unknown fields",
+        IllegalArgumentException.class,
+        "not found in current schema or added columns",
+        () -> sql("ALTER TABLE %s SET IDENTIFIER FIELDS unknown", tableName));
+
+    AssertHelpers.assertThrows(
+        "should not allow setting optional fields",
+        IllegalArgumentException.class,
+        "not a required field",
+        () -> sql("ALTER TABLE %s SET IDENTIFIER FIELDS id2", tableName));
+  }
+
+  @Test
+  public void testDropIdentifierFields() {
+    sql(
+        "CREATE TABLE %s (id bigint NOT NULL, "
+            + "location struct<lon:bigint,lat:bigint> NOT NULL) USING iceberg",
+        tableName);
+    Table table = validationCatalog.loadTable(tableIdent);
+    Assert.assertTrue(
+        "Table should start without identifier", table.schema().identifierFieldIds().isEmpty());
+
+    sql("ALTER TABLE %s SET IDENTIFIER FIELDS id, location.lon", tableName);
+    table.refresh();
+    Assert.assertEquals(
+        "Should have new identifier fields",
+        Sets.newHashSet(
+            table.schema().findField("id").fieldId(),
+            table.schema().findField("location.lon").fieldId()),
+        table.schema().identifierFieldIds());
+
+    sql("ALTER TABLE %s DROP IDENTIFIER FIELDS id", tableName);
+    table.refresh();
+    Assert.assertEquals(
+        "Should removed identifier field",
+        Sets.newHashSet(table.schema().findField("location.lon").fieldId()),
+        table.schema().identifierFieldIds());
+
+    sql("ALTER TABLE %s SET IDENTIFIER FIELDS id, location.lon", tableName);
+    table.refresh();
+    Assert.assertEquals(
+        "Should have new identifier fields",
+        Sets.newHashSet(
+            table.schema().findField("id").fieldId(),
+            table.schema().findField("location.lon").fieldId()),
+        table.schema().identifierFieldIds());
+
+    sql("ALTER TABLE %s DROP IDENTIFIER FIELDS id, location.lon", tableName);
+    table.refresh();
+    Assert.assertEquals(
+        "Should have no identifier field", Sets.newHashSet(), table.schema().identifierFieldIds());
+  }
+
+  @Test
+  public void testDropInvalidIdentifierFields() {
+    sql(
+        "CREATE TABLE %s (id bigint NOT NULL, data string NOT NULL, "
+            + "location struct<lon:bigint,lat:bigint> NOT NULL) USING iceberg",
+        tableName);
+    Table table = validationCatalog.loadTable(tableIdent);
+    Assert.assertTrue(
+        "Table should start without identifier", table.schema().identifierFieldIds().isEmpty());
+    AssertHelpers.assertThrows(
+        "should not allow dropping unknown fields",
+        IllegalArgumentException.class,
+        "field unknown not found",
+        () -> sql("ALTER TABLE %s DROP IDENTIFIER FIELDS unknown", tableName));
+
+    sql("ALTER TABLE %s SET IDENTIFIER FIELDS id", tableName);
+    AssertHelpers.assertThrows(
+        "should not allow dropping a field that is not an identifier",
+        IllegalArgumentException.class,
+        "data is not an identifier field",
+        () -> sql("ALTER TABLE %s DROP IDENTIFIER FIELDS data", tableName));
+
+    AssertHelpers.assertThrows(
+        "should not allow dropping a nested field that is not an identifier",
+        IllegalArgumentException.class,
+        "location.lon is not an identifier field",
+        () -> sql("ALTER TABLE %s DROP IDENTIFIER FIELDS location.lon", tableName));
+  }
+}
diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAncestorsOfProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAncestorsOfProcedure.java
new file mode 100644
index 000000000000..ae591821e21a
--- /dev/null
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAncestorsOfProcedure.java
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ +package org.apache.iceberg.spark.extensions; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.spark.sql.AnalysisException; +import org.junit.After; +import org.junit.Test; + +public class TestAncestorsOfProcedure extends SparkExtensionsTestBase { + + public TestAncestorsOfProcedure( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testAncestorOfUsingEmptyArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Long currentSnapshotId = table.currentSnapshot().snapshotId(); + Long currentTimestamp = table.currentSnapshot().timestampMillis(); + Long preSnapshotId = table.currentSnapshot().parentId(); + Long preTimeStamp = table.snapshot(table.currentSnapshot().parentId()).timestampMillis(); + + List output = sql("CALL %s.system.ancestors_of('%s')", catalogName, tableIdent); + + assertEquals( + "Procedure output must match", + ImmutableList.of( + row(currentSnapshotId, currentTimestamp), row(preSnapshotId, preTimeStamp)), + output); + } + + @Test + public void testAncestorOfUsingSnapshotId() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Long currentSnapshotId = table.currentSnapshot().snapshotId(); + Long currentTimestamp = table.currentSnapshot().timestampMillis(); + Long preSnapshotId = table.currentSnapshot().parentId(); + Long preTimeStamp = table.snapshot(table.currentSnapshot().parentId()).timestampMillis(); + + assertEquals( + "Procedure output must match", + ImmutableList.of( + row(currentSnapshotId, currentTimestamp), row(preSnapshotId, preTimeStamp)), + sql("CALL %s.system.ancestors_of('%s', %dL)", catalogName, tableIdent, currentSnapshotId)); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(preSnapshotId, preTimeStamp)), + sql("CALL %s.system.ancestors_of('%s', %dL)", catalogName, tableIdent, preSnapshotId)); + } + + @Test + public void testAncestorOfWithRollBack() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + Table table = validationCatalog.loadTable(tableIdent); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + table.refresh(); + Long firstSnapshotId = table.currentSnapshot().snapshotId(); + Long firstTimestamp = table.currentSnapshot().timestampMillis(); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + table.refresh(); + Long secondSnapshotId = table.currentSnapshot().snapshotId(); + Long secondTimestamp = table.currentSnapshot().timestampMillis(); + sql("INSERT INTO TABLE %s VALUES (3, 'c')", tableName); + table.refresh(); + Long thirdSnapshotId = table.currentSnapshot().snapshotId(); + Long thirdTimestamp = table.currentSnapshot().timestampMillis(); + + // roll back + sql( + "CALL %s.system.rollback_to_snapshot('%s', %dL)", + catalogName, tableIdent, secondSnapshotId); + + sql("INSERT INTO TABLE %s VALUES (4, 'd')", 
tableName); + table.refresh(); + Long fourthSnapshotId = table.currentSnapshot().snapshotId(); + Long fourthTimestamp = table.currentSnapshot().timestampMillis(); + + assertEquals( + "Procedure output must match", + ImmutableList.of( + row(fourthSnapshotId, fourthTimestamp), + row(secondSnapshotId, secondTimestamp), + row(firstSnapshotId, firstTimestamp)), + sql("CALL %s.system.ancestors_of('%s', %dL)", catalogName, tableIdent, fourthSnapshotId)); + + assertEquals( + "Procedure output must match", + ImmutableList.of( + row(thirdSnapshotId, thirdTimestamp), + row(secondSnapshotId, secondTimestamp), + row(firstSnapshotId, firstTimestamp)), + sql("CALL %s.system.ancestors_of('%s', %dL)", catalogName, tableIdent, thirdSnapshotId)); + } + + @Test + public void testAncestorOfUsingNamedArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Long firstSnapshotId = table.currentSnapshot().snapshotId(); + Long firstTimestamp = table.currentSnapshot().timestampMillis(); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(firstSnapshotId, firstTimestamp)), + sql( + "CALL %s.system.ancestors_of(snapshot_id => %dL, table => '%s')", + catalogName, firstSnapshotId, tableIdent)); + } + + @Test + public void testInvalidAncestorOfCases() { + AssertHelpers.assertThrows( + "Should reject calls without all required args", + AnalysisException.class, + "Missing required parameters", + () -> sql("CALL %s.system.ancestors_of()", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with empty table identifier", + IllegalArgumentException.class, + "Cannot handle an empty identifier for parameter 'table'", + () -> sql("CALL %s.system.ancestors_of('')", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with invalid arg types", + AnalysisException.class, + "Wrong arg type for snapshot_id: cannot cast", + () -> sql("CALL %s.system.ancestors_of('%s', 1.1)", catalogName, tableIdent)); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestBranchDDL.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestBranchDDL.java new file mode 100644 index 000000000000..cc60be55ba0c --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestBranchDDL.java @@ -0,0 +1,292 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.extensions.IcebergParseException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runners.Parameterized; + +public class TestBranchDDL extends SparkExtensionsTestBase { + + @Before + public void before() { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Parameterized.Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties() + } + }; + } + + public TestBranchDDL(String catalog, String implementation, Map properties) { + super(catalog, implementation, properties); + } + + @Test + public void testCreateBranch() throws NoSuchTableException { + Table table = insertRows(); + long snapshotId = table.currentSnapshot().snapshotId(); + String branchName = "b1"; + Integer minSnapshotsToKeep = 2; + long maxSnapshotAge = 2L; + long maxRefAge = 10L; + sql( + "ALTER TABLE %s CREATE BRANCH %s AS OF VERSION %d RETAIN %d DAYS WITH SNAPSHOT RETENTION %d SNAPSHOTS %d days", + tableName, branchName, snapshotId, maxRefAge, minSnapshotsToKeep, maxSnapshotAge); + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + Assert.assertEquals(table.currentSnapshot().snapshotId(), ref.snapshotId()); + Assert.assertEquals(minSnapshotsToKeep, ref.minSnapshotsToKeep()); + Assert.assertEquals(TimeUnit.DAYS.toMillis(maxSnapshotAge), ref.maxSnapshotAgeMs().longValue()); + Assert.assertEquals(TimeUnit.DAYS.toMillis(maxRefAge), ref.maxRefAgeMs().longValue()); + + AssertHelpers.assertThrows( + "Cannot create an existing branch", + IllegalArgumentException.class, + "Ref b1 already exists", + () -> sql("ALTER TABLE %s CREATE BRANCH %s", tableName, branchName)); + } + + @Test + public void testCreateBranchUseDefaultConfig() throws NoSuchTableException { + Table table = insertRows(); + String branchName = "b1"; + sql("ALTER TABLE %s CREATE BRANCH %s", tableName, branchName); + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + Assert.assertEquals(table.currentSnapshot().snapshotId(), ref.snapshotId()); + Assert.assertNull(ref.minSnapshotsToKeep()); + Assert.assertNull(ref.maxSnapshotAgeMs()); + Assert.assertNull(ref.maxRefAgeMs()); + } + + @Test + public void testCreateBranchUseCustomMinSnapshotsToKeep() throws NoSuchTableException { + Integer minSnapshotsToKeep = 2; + Table table = insertRows(); + String branchName = "b1"; + sql( + "ALTER TABLE %s CREATE BRANCH %s WITH SNAPSHOT RETENTION %d SNAPSHOTS", + tableName, branchName, minSnapshotsToKeep); + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + Assert.assertEquals(table.currentSnapshot().snapshotId(), 
ref.snapshotId()); + Assert.assertEquals(minSnapshotsToKeep, ref.minSnapshotsToKeep()); + Assert.assertNull(ref.maxSnapshotAgeMs()); + Assert.assertNull(ref.maxRefAgeMs()); + } + + @Test + public void testCreateBranchUseCustomMaxSnapshotAge() throws NoSuchTableException { + long maxSnapshotAge = 2L; + Table table = insertRows(); + String branchName = "b1"; + sql( + "ALTER TABLE %s CREATE BRANCH %s WITH SNAPSHOT RETENTION %d DAYS", + tableName, branchName, maxSnapshotAge); + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + Assert.assertNotNull(ref); + Assert.assertNull(ref.minSnapshotsToKeep()); + Assert.assertEquals(TimeUnit.DAYS.toMillis(maxSnapshotAge), ref.maxSnapshotAgeMs().longValue()); + Assert.assertNull(ref.maxRefAgeMs()); + } + + @Test + public void testCreateBranchIfNotExists() throws NoSuchTableException { + long maxSnapshotAge = 2L; + Table table = insertRows(); + String branchName = "b1"; + sql( + "ALTER TABLE %s CREATE BRANCH %s WITH SNAPSHOT RETENTION %d DAYS", + tableName, branchName, maxSnapshotAge); + sql("ALTER TABLE %s CREATE BRANCH IF NOT EXISTS %s", tableName, branchName); + + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + Assert.assertEquals(table.currentSnapshot().snapshotId(), ref.snapshotId()); + Assert.assertNull(ref.minSnapshotsToKeep()); + Assert.assertEquals(TimeUnit.DAYS.toMillis(maxSnapshotAge), ref.maxSnapshotAgeMs().longValue()); + Assert.assertNull(ref.maxRefAgeMs()); + } + + @Test + public void testCreateBranchUseCustomMinSnapshotsToKeepAndMaxSnapshotAge() + throws NoSuchTableException { + Integer minSnapshotsToKeep = 2; + long maxSnapshotAge = 2L; + Table table = insertRows(); + String branchName = "b1"; + sql( + "ALTER TABLE %s CREATE BRANCH %s WITH SNAPSHOT RETENTION %d SNAPSHOTS %d DAYS", + tableName, branchName, minSnapshotsToKeep, maxSnapshotAge); + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + Assert.assertEquals(table.currentSnapshot().snapshotId(), ref.snapshotId()); + Assert.assertEquals(minSnapshotsToKeep, ref.minSnapshotsToKeep()); + Assert.assertEquals(TimeUnit.DAYS.toMillis(maxSnapshotAge), ref.maxSnapshotAgeMs().longValue()); + Assert.assertNull(ref.maxRefAgeMs()); + + AssertHelpers.assertThrows( + "Illegal statement", + IcebergParseException.class, + "no viable alternative at input 'WITH SNAPSHOT RETENTION'", + () -> + sql("ALTER TABLE %s CREATE BRANCH %s WITH SNAPSHOT RETENTION", tableName, branchName)); + } + + @Test + public void testCreateBranchUseCustomMaxRefAge() throws NoSuchTableException { + long maxRefAge = 10L; + Table table = insertRows(); + String branchName = "b1"; + sql("ALTER TABLE %s CREATE BRANCH %s RETAIN %d DAYS", tableName, branchName, maxRefAge); + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + Assert.assertEquals(table.currentSnapshot().snapshotId(), ref.snapshotId()); + Assert.assertNull(ref.minSnapshotsToKeep()); + Assert.assertNull(ref.maxSnapshotAgeMs()); + Assert.assertEquals(TimeUnit.DAYS.toMillis(maxRefAge), ref.maxRefAgeMs().longValue()); + + AssertHelpers.assertThrows( + "Illegal statement", + IcebergParseException.class, + "mismatched input", + () -> sql("ALTER TABLE %s CREATE BRANCH %s RETAIN", tableName, branchName)); + + AssertHelpers.assertThrows( + "Illegal statement", + IcebergParseException.class, + "mismatched input", + () -> sql("ALTER TABLE %s CREATE BRANCH %s RETAIN %s DAYS", tableName, branchName, "abc")); + + AssertHelpers.assertThrows( + "Illegal statement", + IcebergParseException.class, + "mismatched 
input 'SECONDS' expecting {'DAYS', 'HOURS', 'MINUTES'}", + () -> + sql( + "ALTER TABLE %s CREATE BRANCH %s RETAIN %d SECONDS", + tableName, branchName, maxRefAge)); + } + + @Test + public void testDropBranch() throws NoSuchTableException { + insertRows(); + + Table table = validationCatalog.loadTable(tableIdent); + String branchName = "b1"; + table.manageSnapshots().createBranch(branchName, table.currentSnapshot().snapshotId()).commit(); + SnapshotRef ref = table.refs().get(branchName); + Assert.assertEquals(table.currentSnapshot().snapshotId(), ref.snapshotId()); + + sql("ALTER TABLE %s DROP BRANCH %s", tableName, branchName); + table.refresh(); + + ref = table.refs().get(branchName); + Assert.assertNull(ref); + } + + @Test + public void testDropBranchDoesNotExist() { + AssertHelpers.assertThrows( + "Cannot perform drop branch on branch which does not exist", + IllegalArgumentException.class, + "Branch does not exist: nonExistingBranch", + () -> sql("ALTER TABLE %s DROP BRANCH %s", tableName, "nonExistingBranch")); + } + + @Test + public void testDropBranchFailsForTag() throws NoSuchTableException { + String tagName = "b1"; + Table table = insertRows(); + table.manageSnapshots().createTag(tagName, table.currentSnapshot().snapshotId()).commit(); + + AssertHelpers.assertThrows( + "Cannot perform drop branch on tag", + IllegalArgumentException.class, + "Ref b1 is a tag not a branch", + () -> sql("ALTER TABLE %s DROP BRANCH %s", tableName, tagName)); + } + + @Test + public void testDropBranchNonConformingName() { + AssertHelpers.assertThrows( + "Non-conforming branch name", + IcebergParseException.class, + "mismatched input '123'", + () -> sql("ALTER TABLE %s DROP BRANCH %s", tableName, "123")); + } + + @Test + public void testDropMainBranchFails() { + AssertHelpers.assertThrows( + "Cannot drop the main branch", + IllegalArgumentException.class, + "Cannot remove main branch", + () -> sql("ALTER TABLE %s DROP BRANCH main", tableName)); + } + + @Test + public void testDropBranchIfExists() { + String branchName = "nonExistingBranch"; + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertNull(table.refs().get(branchName)); + + sql("ALTER TABLE %s DROP BRANCH IF EXISTS %s", tableName, branchName); + table.refresh(); + + SnapshotRef ref = table.refs().get(branchName); + Assert.assertNull(ref); + } + + private Table insertRows() throws NoSuchTableException { + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + return validationCatalog.loadTable(tableIdent); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCallStatementParser.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCallStatementParser.java new file mode 100644 index 000000000000..9c2233ccb791 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCallStatementParser.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import java.math.BigDecimal; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.List; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.catalyst.expressions.Literal$; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.catalyst.parser.ParserInterface; +import org.apache.spark.sql.catalyst.parser.extensions.IcebergParseException; +import org.apache.spark.sql.catalyst.plans.logical.CallArgument; +import org.apache.spark.sql.catalyst.plans.logical.CallStatement; +import org.apache.spark.sql.catalyst.plans.logical.NamedArgument; +import org.apache.spark.sql.catalyst.plans.logical.PositionalArgument; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import scala.collection.JavaConverters; + +public class TestCallStatementParser { + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private static SparkSession spark = null; + private static ParserInterface parser = null; + + @BeforeClass + public static void startSpark() { + TestCallStatementParser.spark = + SparkSession.builder() + .master("local[2]") + .config("spark.sql.extensions", IcebergSparkSessionExtensions.class.getName()) + .config("spark.extra.prop", "value") + .getOrCreate(); + TestCallStatementParser.parser = spark.sessionState().sqlParser(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestCallStatementParser.spark; + TestCallStatementParser.spark = null; + TestCallStatementParser.parser = null; + currentSpark.stop(); + } + + @Test + public void testCallWithPositionalArgs() throws ParseException { + CallStatement call = + (CallStatement) parser.parsePlan("CALL c.n.func(1, '2', 3L, true, 1.0D, 9.0e1, 900e-1BD)"); + Assert.assertEquals( + ImmutableList.of("c", "n", "func"), JavaConverters.seqAsJavaList(call.name())); + + Assert.assertEquals(7, call.args().size()); + + checkArg(call, 0, 1, DataTypes.IntegerType); + checkArg(call, 1, "2", DataTypes.StringType); + checkArg(call, 2, 3L, DataTypes.LongType); + checkArg(call, 3, true, DataTypes.BooleanType); + checkArg(call, 4, 1.0D, DataTypes.DoubleType); + checkArg(call, 5, 9.0e1, DataTypes.DoubleType); + checkArg(call, 6, new BigDecimal("900e-1"), DataTypes.createDecimalType(3, 1)); + } + + @Test + public void testCallWithNamedArgs() throws ParseException { + CallStatement call = + (CallStatement) parser.parsePlan("CALL cat.system.func(c1 => 1, c2 => '2', c3 => true)"); + Assert.assertEquals( + ImmutableList.of("cat", "system", "func"), 
JavaConverters.seqAsJavaList(call.name())); + + Assert.assertEquals(3, call.args().size()); + + checkArg(call, 0, "c1", 1, DataTypes.IntegerType); + checkArg(call, 1, "c2", "2", DataTypes.StringType); + checkArg(call, 2, "c3", true, DataTypes.BooleanType); + } + + @Test + public void testCallWithMixedArgs() throws ParseException { + CallStatement call = (CallStatement) parser.parsePlan("CALL cat.system.func(c1 => 1, '2')"); + Assert.assertEquals( + ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name())); + + Assert.assertEquals(2, call.args().size()); + + checkArg(call, 0, "c1", 1, DataTypes.IntegerType); + checkArg(call, 1, "2", DataTypes.StringType); + } + + @Test + public void testCallWithTimestampArg() throws ParseException { + CallStatement call = + (CallStatement) + parser.parsePlan("CALL cat.system.func(TIMESTAMP '2017-02-03T10:37:30.00Z')"); + Assert.assertEquals( + ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name())); + + Assert.assertEquals(1, call.args().size()); + + checkArg( + call, 0, Timestamp.from(Instant.parse("2017-02-03T10:37:30.00Z")), DataTypes.TimestampType); + } + + @Test + public void testCallWithVarSubstitution() throws ParseException { + CallStatement call = + (CallStatement) parser.parsePlan("CALL cat.system.func('${spark.extra.prop}')"); + Assert.assertEquals( + ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name())); + + Assert.assertEquals(1, call.args().size()); + + checkArg(call, 0, "value", DataTypes.StringType); + } + + @Test + public void testCallParseError() { + AssertHelpers.assertThrows( + "Should fail with a sensible parse error", + IcebergParseException.class, + "missing '(' at 'radish'", + () -> parser.parsePlan("CALL cat.system radish kebab")); + } + + @Test + public void testCallStripsComments() throws ParseException { + List callStatementsWithComments = + Lists.newArrayList( + "/* bracketed comment */ CALL cat.system.func('${spark.extra.prop}')", + "/**/ CALL cat.system.func('${spark.extra.prop}')", + "-- single line comment \n CALL cat.system.func('${spark.extra.prop}')", + "-- multiple \n-- single line \n-- comments \n CALL cat.system.func('${spark.extra.prop}')", + "/* select * from multiline_comment \n where x like '%sql%'; */ CALL cat.system.func('${spark.extra.prop}')", + "/* {\"app\": \"dbt\", \"dbt_version\": \"1.0.1\", \"profile_name\": \"profile1\", \"target_name\": \"dev\", " + + "\"node_id\": \"model.profile1.stg_users\"} \n*/ CALL cat.system.func('${spark.extra.prop}')", + "/* Some multi-line comment \n" + + "*/ CALL /* inline comment */ cat.system.func('${spark.extra.prop}') -- ending comment", + "CALL -- a line ending comment\n" + "cat.system.func('${spark.extra.prop}')"); + for (String sqlText : callStatementsWithComments) { + CallStatement call = (CallStatement) parser.parsePlan(sqlText); + Assert.assertEquals( + ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name())); + + Assert.assertEquals(1, call.args().size()); + + checkArg(call, 0, "value", DataTypes.StringType); + } + } + + private void checkArg( + CallStatement call, int index, Object expectedValue, DataType expectedType) { + checkArg(call, index, null, expectedValue, expectedType); + } + + private void checkArg( + CallStatement call, + int index, + String expectedName, + Object expectedValue, + DataType expectedType) { + + if (expectedName != null) { + NamedArgument arg = checkCast(call.args().apply(index), NamedArgument.class); + 
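+      // Named arguments should be parsed as NamedArgument and carry the declared parameter name.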
Assert.assertEquals(expectedName, arg.name()); + } else { + CallArgument arg = call.args().apply(index); + checkCast(arg, PositionalArgument.class); + } + + Expression expectedExpr = toSparkLiteral(expectedValue, expectedType); + Expression actualExpr = call.args().apply(index).expr(); + Assert.assertEquals("Arg types must match", expectedExpr.dataType(), actualExpr.dataType()); + Assert.assertEquals("Arg must match", expectedExpr, actualExpr); + } + + private Literal toSparkLiteral(Object value, DataType dataType) { + return Literal$.MODULE$.create(value, dataType); + } + + private <T> T checkCast(Object value, Class<T> expectedClass) { + Assert.assertTrue( + "Expected instance of " + expectedClass.getName(), expectedClass.isInstance(value)); + return expectedClass.cast(value); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestChangelogTable.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestChangelogTable.java new file mode 100644 index 000000000000..603775eb11b7 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestChangelogTable.java @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.AssertHelpers.assertThrows; +import static org.apache.iceberg.TableProperties.FORMAT_VERSION; +import static org.apache.iceberg.TableProperties.MANIFEST_MERGE_ENABLED; +import static org.apache.iceberg.TableProperties.MANIFEST_MIN_MERGE_COUNT; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.DataOperations; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.source.SparkChangelogTable; +import org.apache.spark.sql.DataFrameReader; +import org.apache.spark.sql.Row; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runners.Parameterized.Parameters; + +public class TestChangelogTable extends SparkExtensionsTestBase { + + @Parameters(name = "formatVersion = {0}, catalogName = {1}, implementation = {2}, config = {3}") + public static Object[][] parameters() { + return new Object[][] { + { + 1, + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties() + }, + { + 2, + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + SparkCatalogConfig.HIVE.properties() + } + }; + } + + private final int formatVersion; + + public TestChangelogTable( + int formatVersion, String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + this.formatVersion = formatVersion; + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testDataFilters() { + createTableWithDefaultRows(); + + sql("INSERT INTO %s VALUES (3, 'c')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Snapshot snap3 = table.currentSnapshot(); + + sql("DELETE FROM %s WHERE id = 3", tableName); + + table.refresh(); + + Snapshot snap4 = table.currentSnapshot(); + + assertEquals( + "Should have expected rows", + ImmutableList.of( + row(3, "c", "INSERT", 2, snap3.snapshotId()), + row(3, "c", "DELETE", 3, snap4.snapshotId())), + sql("SELECT * FROM %s.changes WHERE id = 3 ORDER BY _change_ordinal, id", tableName)); + } + + @Test + public void testOverwrites() { + createTableWithDefaultRows(); + + Table table = validationCatalog.loadTable(tableIdent); + + Snapshot snap2 = table.currentSnapshot(); + + sql("INSERT OVERWRITE %s VALUES (-2, 'b')", tableName); + + table.refresh(); + + Snapshot snap3 = table.currentSnapshot(); + + assertEquals( + "Rows should match", + ImmutableList.of( + row(2, "b", "DELETE", 0, snap3.snapshotId()), + row(-2, "b", "INSERT", 0, snap3.snapshotId())), + changelogRecords(snap2, snap3)); + } + + @Test + public void testQueryWithTimeRange() { + createTable(); + + sql("INSERT INTO %s VALUES (1, 'a')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap1 = table.currentSnapshot(); + long rightAfterSnap1 = waitUntilAfter(snap1.timestampMillis()); + + sql("INSERT INTO %s VALUES (2, 'b')", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + long rightAfterSnap2 = waitUntilAfter(snap2.timestampMillis()); + + sql("INSERT OVERWRITE %s VALUES (-2, 'b')", tableName); + table.refresh(); + Snapshot snap3 = 
table.currentSnapshot(); + + assertEquals( + "Should have expected changed rows only from snapshot 3", + ImmutableList.of( + row(2, "b", "DELETE", 0, snap3.snapshotId()), + row(-2, "b", "INSERT", 0, snap3.snapshotId())), + changelogRecords(rightAfterSnap2, snap3.timestampMillis())); + + assertEquals( + "Should have expected changed rows only from snapshot 3", + ImmutableList.of( + row(2, "b", "DELETE", 0, snap3.snapshotId()), + row(-2, "b", "INSERT", 0, snap3.snapshotId())), + changelogRecords(snap2.timestampMillis(), snap3.timestampMillis())); + + assertEquals( + "Should have expected changed rows from snapshot 2 and 3", + ImmutableList.of( + row(2, "b", "INSERT", 0, snap2.snapshotId()), + row(2, "b", "DELETE", 1, snap3.snapshotId()), + row(-2, "b", "INSERT", 1, snap3.snapshotId())), + changelogRecords(rightAfterSnap1, snap3.timestampMillis())); + + assertEquals( + "Should have expected changed rows up to the current snapshot", + ImmutableList.of( + row(2, "b", "INSERT", 0, snap2.snapshotId()), + row(2, "b", "DELETE", 1, snap3.snapshotId()), + row(-2, "b", "INSERT", 1, snap3.snapshotId())), + changelogRecords(rightAfterSnap1, null)); + } + + @Test + public void testTimeRangeValidation() { + createTableWithDefaultRows(); + + Table table = validationCatalog.loadTable(tableIdent); + + Snapshot snap2 = table.currentSnapshot(); + + sql("INSERT OVERWRITE %s VALUES (-2, 'b')", tableName); + table.refresh(); + Snapshot snap3 = table.currentSnapshot(); + long rightAfterSnap3 = waitUntilAfter(snap3.timestampMillis()); + + assertThrows( + "Should fail if start time is after end time", + IllegalArgumentException.class, + () -> changelogRecords(snap3.timestampMillis(), snap2.timestampMillis())); + + assertThrows( + "Should fail if start time is after the current snapshot", + IllegalArgumentException.class, + () -> changelogRecords(rightAfterSnap3, null)); + } + + @Test + public void testMetadataDeletes() { + createTableWithDefaultRows(); + + Table table = validationCatalog.loadTable(tableIdent); + + Snapshot snap2 = table.currentSnapshot(); + + sql("DELETE FROM %s WHERE data = 'a'", tableName); + + table.refresh(); + + Snapshot snap3 = table.currentSnapshot(); + Assert.assertEquals("Operation must match", DataOperations.DELETE, snap3.operation()); + + assertEquals( + "Rows should match", + ImmutableList.of(row(1, "a", "DELETE", 0, snap3.snapshotId())), + changelogRecords(snap2, snap3)); + } + + @Test + public void testExistingEntriesInNewDataManifestsAreIgnored() { + sql( + "CREATE TABLE %s (id INT, data STRING) " + + "USING iceberg " + + "PARTITIONED BY (data) " + + "TBLPROPERTIES ( " + + " '%s' = '%d', " + + " '%s' = '1', " + + " '%s' = 'true' " + + ")", + tableName, FORMAT_VERSION, formatVersion, MANIFEST_MIN_MERGE_COUNT, MANIFEST_MERGE_ENABLED); + + sql("INSERT INTO %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Snapshot snap1 = table.currentSnapshot(); + + sql("INSERT INTO %s VALUES (2, 'b')", tableName); + + table.refresh(); + + Snapshot snap2 = table.currentSnapshot(); + Assert.assertEquals("Manifest number must match", 1, snap2.dataManifests(table.io()).size()); + + assertEquals( + "Rows should match", + ImmutableList.of(row(2, "b", "INSERT", 0, snap2.snapshotId())), + changelogRecords(snap1, snap2)); + } + + @Test + public void testManifestRewritesAreIgnored() { + createTableWithDefaultRows(); + + sql("CALL %s.system.rewrite_manifests('%s')", catalogName, tableIdent); + + Table table = validationCatalog.loadTable(tableIdent); + 
Assert.assertEquals("Num snapshots must match", 3, Iterables.size(table.snapshots())); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "INSERT"), row(2, "INSERT")), + sql("SELECT id, _change_type FROM %s.changes ORDER BY id", tableName)); + } + + @Test + public void testMetadataColumns() { + createTableWithDefaultRows(); + List rows = + sql( + "SELECT id, _file, _pos, _deleted, _spec_id, _partition FROM %s.changes ORDER BY id", + tableName); + + String file1 = rows.get(0)[1].toString(); + Assert.assertTrue(file1.startsWith("file:/")); + String file2 = rows.get(1)[1].toString(); + + assertEquals( + "Rows should match", + ImmutableList.of( + row(1, file1, 0L, false, 0, row("a")), row(2, file2, 0L, false, 0, row("b"))), + rows); + } + + private void createTableWithDefaultRows() { + createTable(); + insertDefaultRows(); + } + + private void createTable() { + sql( + "CREATE TABLE %s (id INT, data STRING) " + + "USING iceberg " + + "PARTITIONED BY (data) " + + "TBLPROPERTIES ( " + + " '%s' = '%d' " + + ")", + tableName, FORMAT_VERSION, formatVersion); + } + + private void insertDefaultRows() { + sql("INSERT INTO %s VALUES (1, 'a')", tableName); + sql("INSERT INTO %s VALUES (2, 'b')", tableName); + } + + private List changelogRecords(Snapshot startSnapshot, Snapshot endSnapshot) { + DataFrameReader reader = spark.read(); + + if (startSnapshot != null) { + reader = reader.option(SparkReadOptions.START_SNAPSHOT_ID, startSnapshot.snapshotId()); + } + + if (endSnapshot != null) { + reader = reader.option(SparkReadOptions.END_SNAPSHOT_ID, endSnapshot.snapshotId()); + } + + return rowsToJava(collect(reader)); + } + + private List changelogRecords(Long startTimestamp, Long endTimeStamp) { + DataFrameReader reader = spark.read(); + + if (startTimestamp != null) { + reader = reader.option(SparkReadOptions.START_TIMESTAMP, startTimestamp); + } + + if (endTimeStamp != null) { + reader = reader.option(SparkReadOptions.END_TIMESTAMP, endTimeStamp); + } + + return rowsToJava(collect(reader)); + } + + private List collect(DataFrameReader reader) { + return reader + .table(tableName + "." + SparkChangelogTable.TABLE_NAME) + .orderBy("_change_ordinal", "_commit_snapshot_id", "_change_type", "id") + .collectAsList(); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCherrypickSnapshotProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCherrypickSnapshotProcedure.java new file mode 100644 index 000000000000..7309a176b922 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCherrypickSnapshotProcedure.java @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.TableProperties.WRITE_AUDIT_PUBLISH_ENABLED; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.junit.After; +import org.junit.Test; + +public class TestCherrypickSnapshotProcedure extends SparkExtensionsTestBase { + + public TestCherrypickSnapshotProcedure( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testCherrypickSnapshotUsingPositionalArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED); + + spark.conf().set("spark.wap.id", "1"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should not see rows from staged snapshot", + ImmutableList.of(), + sql("SELECT * FROM %s", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot wapSnapshot = Iterables.getOnlyElement(table.snapshots()); + + List output = + sql( + "CALL %s.system.cherrypick_snapshot('%s', %dL)", + catalogName, tableIdent, wapSnapshot.snapshotId()); + + table.refresh(); + + Snapshot currentSnapshot = table.currentSnapshot(); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(wapSnapshot.snapshotId(), currentSnapshot.snapshotId())), + output); + + assertEquals( + "Cherrypick must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", tableName)); + } + + @Test + public void testCherrypickSnapshotUsingNamedArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED); + + spark.conf().set("spark.wap.id", "1"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should not see rows from staged snapshot", + ImmutableList.of(), + sql("SELECT * FROM %s", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot wapSnapshot = Iterables.getOnlyElement(table.snapshots()); + + List output = + sql( + "CALL %s.system.cherrypick_snapshot(snapshot_id => %dL, table => '%s')", + catalogName, wapSnapshot.snapshotId(), tableIdent); + + table.refresh(); + + Snapshot currentSnapshot = table.currentSnapshot(); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(wapSnapshot.snapshotId(), currentSnapshot.snapshotId())), + output); + + assertEquals( + "Cherrypick must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", tableName)); + } + + @Test + public void testCherrypickSnapshotRefreshesRelationCache() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED); + + Dataset query = 
spark.sql("SELECT * FROM " + tableName + " WHERE id = 1"); + query.createOrReplaceTempView("tmp"); + + spark.sql("CACHE TABLE tmp"); + + assertEquals("View should not produce rows", ImmutableList.of(), sql("SELECT * FROM tmp")); + + spark.conf().set("spark.wap.id", "1"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should not see rows from staged snapshot", + ImmutableList.of(), + sql("SELECT * FROM %s", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot wapSnapshot = Iterables.getOnlyElement(table.snapshots()); + + sql( + "CALL %s.system.cherrypick_snapshot('%s', %dL)", + catalogName, tableIdent, wapSnapshot.snapshotId()); + + assertEquals( + "Cherrypick snapshot should be visible", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM tmp")); + + sql("UNCACHE TABLE tmp"); + } + + @Test + public void testCherrypickInvalidSnapshot() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + AssertHelpers.assertThrows( + "Should reject invalid snapshot id", + ValidationException.class, + "Cannot cherry-pick unknown snapshot ID", + () -> sql("CALL %s.system.cherrypick_snapshot('%s', -1L)", catalogName, tableIdent)); + } + + @Test + public void testInvalidCherrypickSnapshotCases() { + AssertHelpers.assertThrows( + "Should not allow mixed args", + AnalysisException.class, + "Named and positional arguments cannot be mixed", + () -> sql("CALL %s.system.cherrypick_snapshot('n', table => 't', 1L)", catalogName)); + + AssertHelpers.assertThrows( + "Should not resolve procedures in arbitrary namespaces", + NoSuchProcedureException.class, + "not found", + () -> sql("CALL %s.custom.cherrypick_snapshot('n', 't', 1L)", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls without all required args", + AnalysisException.class, + "Missing required parameters", + () -> sql("CALL %s.system.cherrypick_snapshot('t')", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with empty table identifier", + IllegalArgumentException.class, + "Cannot handle an empty identifier", + () -> sql("CALL %s.system.cherrypick_snapshot('', 1L)", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with invalid arg types", + AnalysisException.class, + "Wrong arg type for snapshot_id: cannot cast", + () -> sql("CALL %s.system.cherrypick_snapshot('t', 2.2)", catalogName)); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestConflictValidation.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestConflictValidation.java new file mode 100644 index 000000000000..10c86015e2e8 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestConflictValidation.java @@ -0,0 +1,436 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.functions; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class TestConflictValidation extends SparkExtensionsTestBase { + + public TestConflictValidation( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Before + public void createTables() { + sql( + "CREATE TABLE %s (id int, data string) USING iceberg " + + "PARTITIONED BY (id)" + + "TBLPROPERTIES" + + "('format-version'='2'," + + "'write.delete.mode'='merge-on-read')", + tableName); + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testOverwriteFilterSerializableIsolation() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + + List records = Lists.newArrayList(new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).writeTo(tableName).append(); + + // Validating from previous snapshot finds conflicts + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + AssertHelpers.assertThrows( + "Conflicting new data files should throw exception", + ValidationException.class, + "Found conflicting files that can contain records matching ref(name=\"id\") == 1:", + () -> { + try { + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwrite(functions.col("id").equalTo(1)); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwrite(functions.col("id").equalTo(1)); + } + + @Test + public void testOverwriteFilterSerializableIsolation2() throws Exception { + List records = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(1, "b")); + spark.createDataFrame(records, SimpleRecord.class).coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = 
table.currentSnapshot().snapshotId(); + + // This should add a delete file + sql("DELETE FROM %s WHERE id='1' and data='b'", tableName); + table.refresh(); + + // Validating from previous snapshot finds conflicts + List conflictingRecords = Lists.newArrayList(new SimpleRecord(1, "a")); + Dataset conflictingDf = spark.createDataFrame(conflictingRecords, SimpleRecord.class); + AssertHelpers.assertThrows( + "Conflicting new delete files should throw exception", + ValidationException.class, + "Found new conflicting delete files that can apply to records matching ref(name=\"id\") == 1:", + () -> { + try { + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwrite(functions.col("id").equalTo(1)); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwrite(functions.col("id").equalTo(1)); + } + + @Test + public void testOverwriteFilterSerializableIsolation3() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + + // This should delete a data file + sql("DELETE FROM %s WHERE id='1'", tableName); + table.refresh(); + + // Validating from previous snapshot finds conflicts + List conflictingRecords = Lists.newArrayList(new SimpleRecord(1, "a")); + Dataset conflictingDf = spark.createDataFrame(conflictingRecords, SimpleRecord.class); + AssertHelpers.assertThrows( + "Conflicting deleted data files should throw exception", + ValidationException.class, + "Found conflicting deleted files that can contain records matching ref(name=\"id\") == 1:", + () -> { + try { + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwrite(functions.col("id").equalTo(1)); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwrite(functions.col("id").equalTo(1)); + } + + @Test + public void testOverwriteFilterNoSnapshotIdValidation() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + + List records = Lists.newArrayList(new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).writeTo(tableName).append(); + + // Validating from no snapshot id defaults to beginning snapshot id and finds conflicts + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + AssertHelpers.assertThrows( + "Conflicting new data files should throw exception", + ValidationException.class, + "Found conflicting files that can contain records matching ref(name=\"id\") == 1:", + () -> { + try { + conflictingDf + .writeTo(tableName) + 
.option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwrite(functions.col("id").equalTo(1)); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwrite(functions.col("id").equalTo(1)); + } + + @Test + public void testOverwriteFilterSnapshotIsolation() throws Exception { + List records = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(1, "b")); + spark.createDataFrame(records, SimpleRecord.class).coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + + // This should add a delete file + sql("DELETE FROM %s WHERE id='1' and data='b'", tableName); + table.refresh(); + + // Validating from previous snapshot finds conflicts + List conflictingRecords = Lists.newArrayList(new SimpleRecord(1, "a")); + Dataset conflictingDf = spark.createDataFrame(conflictingRecords, SimpleRecord.class); + AssertHelpers.assertThrows( + "Conflicting new delete files should throw exception", + ValidationException.class, + "Found new conflicting delete files that can apply to records matching ref(name=\"id\") == 1:", + () -> { + try { + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwrite(functions.col("id").equalTo(1)); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwrite(functions.col("id").equalTo(1)); + } + + @Test + public void testOverwriteFilterSnapshotIsolation2() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + + List records = Lists.newArrayList(new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).writeTo(tableName).append(); + + // Validation should not fail due to conflicting data file in snapshot isolation mode + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwrite(functions.col("id").equalTo(1)); + } + + @Test + public void testOverwritePartitionSerializableIsolation() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + final long snapshotId = table.currentSnapshot().snapshotId(); + + List records = Lists.newArrayList(new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).writeTo(tableName).append(); + + // Validating from previous snapshot finds conflicts + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + 
AssertHelpers.assertThrows( + "Conflicting deleted data files should throw exception", + ValidationException.class, + "Found conflicting files that can contain records matching partitions [id=1]", + () -> { + try { + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwritePartitions(); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwritePartitions(); + } + + @Test + public void testOverwritePartitionSnapshotIsolation() throws Exception { + List records = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(1, "b")); + spark.createDataFrame(records, SimpleRecord.class).coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + final long snapshotId = table.currentSnapshot().snapshotId(); + + // This should generate a delete file + sql("DELETE FROM %s WHERE data='a'", tableName); + + // Validating from previous snapshot finds conflicts + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + AssertHelpers.assertThrows( + "Conflicting deleted data files should throw exception", + ValidationException.class, + "Found new conflicting delete files that can apply to records matching [id=1]", + () -> { + try { + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwritePartitions(); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + // Validating from latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwritePartitions(); + } + + @Test + public void testOverwritePartitionSnapshotIsolation2() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + final long snapshotId = table.currentSnapshot().snapshotId(); + + // This should delete a data file + sql("DELETE FROM %s WHERE id='1'", tableName); + + // Validating from previous snapshot finds conflicts + List records = Lists.newArrayList(new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).coalesce(1).writeTo(tableName).append(); + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + + AssertHelpers.assertThrows( + "Conflicting deleted data files should throw exception", + ValidationException.class, + "Found conflicting deleted files that can apply to records matching [id=1]", + () -> { + try { + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwritePartitions(); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + // Validating from 
latest snapshot should succeed + table.refresh(); + long newSnapshotId = table.currentSnapshot().snapshotId(); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(newSnapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwritePartitions(); + } + + @Test + public void testOverwritePartitionSnapshotIsolation3() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + final long snapshotId = table.currentSnapshot().snapshotId(); + + List records = Lists.newArrayList(new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).writeTo(tableName).append(); + + // Validation should not find conflicting data file in snapshot isolation mode + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SNAPSHOT.toString()) + .overwritePartitions(); + } + + @Test + public void testOverwritePartitionNoSnapshotIdValidation() throws Exception { + Table table = validationCatalog.loadTable(tableIdent); + + List records = Lists.newArrayList(new SimpleRecord(1, "a")); + spark.createDataFrame(records, SimpleRecord.class).writeTo(tableName).append(); + + // Validating from null snapshot is equivalent to validating from beginning + Dataset conflictingDf = spark.createDataFrame(records, SimpleRecord.class); + AssertHelpers.assertThrows( + "Conflicting deleted data files should throw exception", + ValidationException.class, + "Found conflicting files that can contain records matching partitions [id=1]", + () -> { + try { + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwritePartitions(); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + // Validating from latest snapshot should succeed + table.refresh(); + long snapshotId = table.currentSnapshot().snapshotId(); + conflictingDf + .writeTo(tableName) + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID, String.valueOf(snapshotId)) + .option(SparkWriteOptions.ISOLATION_LEVEL, IsolationLevel.SERIALIZABLE.toString()) + .overwritePartitions(); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteDelete.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteDelete.java new file mode 100644 index 000000000000..53177340dadd --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteDelete.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.TableProperties.DELETE_ISOLATION_LEVEL; + +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.internal.SQLConf; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Test; + +public class TestCopyOnWriteDelete extends TestDelete { + + public TestCopyOnWriteDelete( + String catalogName, + String implementation, + Map config, + String fileFormat, + Boolean vectorized, + String distributionMode, + String branch) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode, branch); + } + + @Override + protected Map extraTableProperties() { + return ImmutableMap.of( + TableProperties.DELETE_MODE, RowLevelOperationMode.COPY_ON_WRITE.modeName()); + } + + @Test + public synchronized void testDeleteWithConcurrentTableRefresh() throws Exception { + // this test can only be run with Hive tables as it requires a reliable lock + // also, the table cache must be enabled so that the same table instance can be reused + Assume.assumeTrue(catalogName.equalsIgnoreCase("testhive")); + + createAndInitUnpartitionedTable(); + createOrReplaceView("deleted_id", Collections.singletonList(1), Encoders.INT()); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, DELETE_ISOLATION_LEVEL, "snapshot"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + + ExecutorService executorService = + MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + AtomicBoolean shouldAppend = new AtomicBoolean(true); + + // delete thread + Future deleteFuture = + executorService.submit( + () -> { + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + + sql("DELETE FROM %s WHERE id IN (SELECT * FROM deleted_id)", commitTarget()); + + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = + executorService.submit( 
+ () -> { + GenericRecord record = GenericRecord.create(SnapshotUtil.schemaFor(table, branch)); + record.set(0, 1); // id + record.set(1, "hr"); // dep + + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (shouldAppend.get() && barrier.get() < numOperations * 2) { + sleep(10); + } + + if (!shouldAppend.get()) { + return; + } + + for (int numAppends = 0; numAppends < 5; numAppends++) { + DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + + appendFiles.commit(); + sleep(10); + } + + barrier.incrementAndGet(); + } + }); + + try { + Assertions.assertThatThrownBy(deleteFuture::get) + .isInstanceOf(ExecutionException.class) + .cause() + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("the table has been concurrently modified"); + } finally { + shouldAppend.set(false); + appendFuture.cancel(true); + } + + executorService.shutdown(); + Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES)); + } + + @Test + public void testRuntimeFilteringWithPreservedDataGrouping() throws NoSuchTableException { + createAndInitPartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(3, "hr")); + createBranchIfNeeded(); + append(new Employee(1, "hardware"), new Employee(2, "hardware")); + + Map sqlConf = + ImmutableMap.of( + SQLConf.V2_BUCKETING_ENABLED().key(), + "true", + SparkSQLProperties.PRESERVE_DATA_GROUPING, + "true"); + + withSQLConf(sqlConf, () -> sql("DELETE FROM %s WHERE id = 2", commitTarget())); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + validateCopyOnWrite(currentSnapshot, "1", "1", "1"); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hardware"), row(1, "hr"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java new file mode 100644 index 000000000000..ed1e05f822cf --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.TableProperties.MERGE_ISOLATION_LEVEL; + +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.internal.SQLConf; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Test; + +public class TestCopyOnWriteMerge extends TestMerge { + + public TestCopyOnWriteMerge( + String catalogName, + String implementation, + Map config, + String fileFormat, + boolean vectorized, + String distributionMode, + String branch) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode, branch); + } + + @Override + protected Map extraTableProperties() { + return ImmutableMap.of( + TableProperties.MERGE_MODE, RowLevelOperationMode.COPY_ON_WRITE.modeName()); + } + + @Test + public synchronized void testMergeWithConcurrentTableRefresh() throws Exception { + // this test can only be run with Hive tables as it requires a reliable lock + // also, the table cache must be enabled so that the same table instance can be reused + Assume.assumeTrue(catalogName.equalsIgnoreCase("testhive")); + + createAndInitTable("id INT, dep STRING"); + createOrReplaceView("source", Collections.singletonList(1), Encoders.INT()); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, MERGE_ISOLATION_LEVEL, "snapshot"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + + ExecutorService executorService = + MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + AtomicBoolean shouldAppend = new AtomicBoolean(true); + + // merge thread + Future mergeFuture = + executorService.submit( + () -> { + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.value " + + "WHEN MATCHED THEN " + + " UPDATE SET dep = 'x'", + tableName); + + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = + executorService.submit( + () -> { + GenericRecord record = GenericRecord.create(table.schema()); + record.set(0, 1); // id + record.set(1, "hr"); // dep + + for (int numOperations = 0; 
numOperations < Integer.MAX_VALUE; numOperations++) { + while (shouldAppend.get() && barrier.get() < numOperations * 2) { + sleep(10); + } + + if (!shouldAppend.get()) { + return; + } + + for (int numAppends = 0; numAppends < 5; numAppends++) { + DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); + table.newFastAppend().appendFile(dataFile).commit(); + sleep(10); + } + + barrier.incrementAndGet(); + } + }); + + try { + Assertions.assertThatThrownBy(mergeFuture::get) + .isInstanceOf(ExecutionException.class) + .cause() + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("the table has been concurrently modified"); + } finally { + shouldAppend.set(false); + appendFuture.cancel(true); + } + + executorService.shutdown(); + Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES)); + } + + @Test + public void testRuntimeFilteringWithReportedPartitioning() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 3, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + append( + commitTarget(), + "{ \"id\": 1, \"dep\": \"hardware\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }"); + + createOrReplaceView("source", Collections.singletonList(2), Encoders.INT()); + + Map sqlConf = + ImmutableMap.of( + SQLConf.V2_BUCKETING_ENABLED().key(), + "true", + SparkSQLProperties.PRESERVE_DATA_GROUPING, + "true"); + + withSQLConf( + sqlConf, + () -> + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.value " + + "WHEN MATCHED THEN " + + " UPDATE SET id = -1", + commitTarget())); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + validateCopyOnWrite(currentSnapshot, "1", "1", "1"); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(1, "hardware"), row(1, "hr"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteUpdate.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteUpdate.java new file mode 100644 index 000000000000..f9f48e8f41c7 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteUpdate.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.TableProperties.UPDATE_ISOLATION_LEVEL; + +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.internal.SQLConf; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Test; + +public class TestCopyOnWriteUpdate extends TestUpdate { + + public TestCopyOnWriteUpdate( + String catalogName, + String implementation, + Map config, + String fileFormat, + boolean vectorized, + String distributionMode, + String branch) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode, branch); + } + + @Override + protected Map extraTableProperties() { + return ImmutableMap.of( + TableProperties.UPDATE_MODE, RowLevelOperationMode.COPY_ON_WRITE.modeName()); + } + + @Test + public synchronized void testUpdateWithConcurrentTableRefresh() throws Exception { + // this test can only be run with Hive tables as it requires a reliable lock + // also, the table cache must be enabled so that the same table instance can be reused + Assume.assumeTrue(catalogName.equalsIgnoreCase("testhive")); + + createAndInitTable("id INT, dep STRING"); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, UPDATE_ISOLATION_LEVEL, "snapshot"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + + ExecutorService executorService = + MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + AtomicBoolean shouldAppend = new AtomicBoolean(true); + + // update thread + Future updateFuture = + executorService.submit( + () -> { + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + + sql("UPDATE %s SET id = -1 WHERE id = 1", commitTarget()); + + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = + executorService.submit( + () -> { + GenericRecord record = GenericRecord.create(SnapshotUtil.schemaFor(table, branch)); + record.set(0, 1); // id + record.set(1, "hr"); // dep + + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (shouldAppend.get() && barrier.get() < numOperations * 2) { + sleep(10); + } + + if 
(!shouldAppend.get()) { + return; + } + + for (int numAppends = 0; numAppends < 5; numAppends++) { + DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + + appendFiles.commit(); + sleep(10); + } + + barrier.incrementAndGet(); + } + }); + + try { + Assertions.assertThatThrownBy(updateFuture::get) + .isInstanceOf(ExecutionException.class) + .cause() + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("the table has been concurrently modified"); + } finally { + shouldAppend.set(false); + appendFuture.cancel(true); + } + + executorService.shutdown(); + Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES)); + } + + @Test + public void testRuntimeFilteringWithReportedPartitioning() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 3, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + append( + commitTarget(), + "{ \"id\": 1, \"dep\": \"hardware\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }"); + + Map sqlConf = + ImmutableMap.of( + SQLConf.V2_BUCKETING_ENABLED().key(), + "true", + SparkSQLProperties.PRESERVE_DATA_GROUPING, + "true"); + + withSQLConf(sqlConf, () -> sql("UPDATE %s SET id = -1 WHERE id = 2", commitTarget())); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + validateCopyOnWrite(currentSnapshot, "1", "1", "1"); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(1, "hardware"), row(1, "hr"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCreateChangelogViewProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCreateChangelogViewProcedure.java new file mode 100644 index 000000000000..dc12b0145d50 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCreateChangelogViewProcedure.java @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.ChangelogOperation; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkReadOptions; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestCreateChangelogViewProcedure extends SparkExtensionsTestBase { + private static final String DELETE = ChangelogOperation.DELETE.name(); + private static final String INSERT = ChangelogOperation.INSERT.name(); + private static final String UPDATE_BEFORE = ChangelogOperation.UPDATE_BEFORE.name(); + private static final String UPDATE_AFTER = ChangelogOperation.UPDATE_AFTER.name(); + + public TestCreateChangelogViewProcedure( + String catalogName, String implementation, Map<String, String> config) { + super(catalogName, implementation, config); + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + private void createTableWith2Columns() { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('format-version'='%d')", tableName, 1); + sql("ALTER TABLE %s ADD PARTITION FIELD data", tableName); + } + + private void createTableWith3Columns() { + sql("CREATE TABLE %s (id INT, data STRING, age INT) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('format-version'='%d')", tableName, 1); + sql("ALTER TABLE %s ADD PARTITION FIELD id", tableName); + } + + private void createTableWithIdentifierField() { + sql("CREATE TABLE %s (id INT NOT NULL, data STRING) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('format-version'='%d')", tableName, 1); + sql("ALTER TABLE %s SET IDENTIFIER FIELDS id", tableName); + } + + @Test + public void testCustomizedViewName() { + createTableWith2Columns(); + sql("INSERT INTO %s VALUES (1, 'a')", tableName); + sql("INSERT INTO %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Snapshot snap1 = table.currentSnapshot(); + + sql("INSERT OVERWRITE %s VALUES (-2, 'b')", tableName); + + table.refresh(); + + Snapshot snap2 = table.currentSnapshot(); + + sql( + "CALL %s.system.create_changelog_view(" + + "table => '%s'," + + "options => map('%s','%s','%s','%s')," + + "changelog_view => '%s')", + catalogName, + tableName, + SparkReadOptions.START_SNAPSHOT_ID, + snap1.snapshotId(), + SparkReadOptions.END_SNAPSHOT_ID, + snap2.snapshotId(), + "cdc_view"); + + long rowCount = sql("select * from %s", "cdc_view").stream().count(); + Assert.assertEquals(2, rowCount); + } + + @Test + public void testNoSnapshotIdInput() { + createTableWith2Columns(); + sql("INSERT INTO %s VALUES (1, 'a')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap0 = table.currentSnapshot(); + + sql("INSERT INTO %s VALUES (2, 'b')", tableName); + table.refresh(); + Snapshot snap1 = table.currentSnapshot(); + + sql("INSERT OVERWRITE %s VALUES (-2, 'b')", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + + List<Object[]> returns = + sql( + "CALL %s.system.create_changelog_view(" + "table => '%s')", + catalogName, tableName, "cdc_view"); + + String viewName = (String) returns.get(0)[0]; + assertEquals( + "Rows should match", + ImmutableList.of( + row(1, "a", INSERT, 0, snap0.snapshotId()), + row(2, "b", INSERT, 1, snap1.snapshotId()), + row(-2, "b", INSERT, 2, 
snap2.snapshotId()), + row(2, "b", DELETE, 2, snap2.snapshotId())), + sql("select * from %s order by _change_ordinal, id", viewName)); + } + + @Test + public void testTimestampsBasedQuery() { + createTableWith2Columns(); + long beginning = System.currentTimeMillis(); + + sql("INSERT INTO %s VALUES (1, 'a')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap0 = table.currentSnapshot(); + long afterFirstInsert = waitUntilAfter(snap0.timestampMillis()); + + sql("INSERT INTO %s VALUES (2, 'b')", tableName); + table.refresh(); + Snapshot snap1 = table.currentSnapshot(); + + sql("INSERT OVERWRITE %s VALUES (-2, 'b')", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + + long afterInsertOverwrite = waitUntilAfter(snap2.timestampMillis()); + List returns = + sql( + "CALL %s.system.create_changelog_view(table => '%s', " + + "options => map('%s', '%s','%s', '%s'))", + catalogName, + tableName, + SparkReadOptions.START_TIMESTAMP, + beginning, + SparkReadOptions.END_TIMESTAMP, + afterInsertOverwrite); + + assertEquals( + "Rows should match", + ImmutableList.of( + row(1, "a", INSERT, 0, snap0.snapshotId()), + row(2, "b", INSERT, 1, snap1.snapshotId()), + row(-2, "b", INSERT, 2, snap2.snapshotId()), + row(2, "b", DELETE, 2, snap2.snapshotId())), + sql("select * from %s order by _change_ordinal, id", returns.get(0)[0])); + + // query the timestamps starting from the second insert + returns = + sql( + "CALL %s.system.create_changelog_view(table => '%s', " + + "options => map('%s', '%s', '%s', '%s'))", + catalogName, + tableName, + SparkReadOptions.START_TIMESTAMP, + afterFirstInsert, + SparkReadOptions.END_TIMESTAMP, + afterInsertOverwrite); + + assertEquals( + "Rows should match", + ImmutableList.of( + row(2, "b", INSERT, 0, snap1.snapshotId()), + row(-2, "b", INSERT, 1, snap2.snapshotId()), + row(2, "b", DELETE, 1, snap2.snapshotId())), + sql("select * from %s order by _change_ordinal, id", returns.get(0)[0])); + } + + @Test + public void testWithCarryovers() { + createTableWith2Columns(); + sql("INSERT INTO %s VALUES (1, 'a')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap0 = table.currentSnapshot(); + + sql("INSERT INTO %s VALUES (2, 'b')", tableName); + table.refresh(); + Snapshot snap1 = table.currentSnapshot(); + + sql("INSERT OVERWRITE %s VALUES (-2, 'b'), (2, 'b'), (2, 'b')", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + + List returns = + sql( + "CALL %s.system.create_changelog_view(" + + "remove_carryovers => false," + + "table => '%s')", + catalogName, tableName, "cdc_view"); + + String viewName = (String) returns.get(0)[0]; + assertEquals( + "Rows should match", + ImmutableList.of( + row(1, "a", INSERT, 0, snap0.snapshotId()), + row(2, "b", INSERT, 1, snap1.snapshotId()), + row(-2, "b", INSERT, 2, snap2.snapshotId()), + row(2, "b", DELETE, 2, snap2.snapshotId()), + row(2, "b", INSERT, 2, snap2.snapshotId()), + row(2, "b", INSERT, 2, snap2.snapshotId())), + sql("select * from %s order by _change_ordinal, id, _change_type", viewName)); + } + + @Test + public void testUpdate() { + createTableWith2Columns(); + sql("ALTER TABLE %s DROP PARTITION FIELD data", tableName); + sql("ALTER TABLE %s ADD PARTITION FIELD id", tableName); + + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap1 = table.currentSnapshot(); + + sql("INSERT OVERWRITE %s VALUES (3, 'c'), (2, 'd')", tableName); + 
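+ // with 'id' passed as the identifier column below, the overwrite of (2, 'b') by (2, 'd') is
+ // expected to surface as an UPDATE_BEFORE/UPDATE_AFTER pair instead of a DELETE/INSERT pair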
table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + + List returns = + sql( + "CALL %s.system.create_changelog_view(table => '%s', identifier_columns => array('id'))", + catalogName, tableName); + + String viewName = (String) returns.get(0)[0]; + assertEquals( + "Rows should match", + ImmutableList.of( + row(1, "a", INSERT, 0, snap1.snapshotId()), + row(2, "b", INSERT, 0, snap1.snapshotId()), + row(2, "b", UPDATE_BEFORE, 1, snap2.snapshotId()), + row(2, "d", UPDATE_AFTER, 1, snap2.snapshotId()), + row(3, "c", INSERT, 1, snap2.snapshotId())), + sql("select * from %s order by _change_ordinal, id, data", viewName)); + } + + @Test + public void testUpdateWithIdentifierField() { + createTableWithIdentifierField(); + + sql("INSERT INTO %s VALUES (2, 'b')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap1 = table.currentSnapshot(); + + sql("INSERT OVERWRITE %s VALUES (3, 'c'), (2, 'd')", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + + List returns = + sql( + "CALL %s.system.create_changelog_view(table => '%s', compute_updates => true)", + catalogName, tableName); + + String viewName = (String) returns.get(0)[0]; + assertEquals( + "Rows should match", + ImmutableList.of( + row(2, "b", INSERT, 0, snap1.snapshotId()), + row(2, "b", UPDATE_BEFORE, 1, snap2.snapshotId()), + row(2, "d", UPDATE_AFTER, 1, snap2.snapshotId()), + row(3, "c", INSERT, 1, snap2.snapshotId())), + sql("select * from %s order by _change_ordinal, id, data", viewName)); + } + + @Test + public void testUpdateWithFilter() { + createTableWith2Columns(); + sql("ALTER TABLE %s DROP PARTITION FIELD data", tableName); + sql("ALTER TABLE %s ADD PARTITION FIELD id", tableName); + + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap1 = table.currentSnapshot(); + + sql("INSERT OVERWRITE %s VALUES (3, 'c'), (2, 'd')", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + + List returns = + sql( + "CALL %s.system.create_changelog_view(table => '%s', identifier_columns => array('id'))", + catalogName, tableName); + + String viewName = (String) returns.get(0)[0]; + assertEquals( + "Rows should match", + ImmutableList.of( + row(1, "a", INSERT, 0, snap1.snapshotId()), + row(2, "b", INSERT, 0, snap1.snapshotId()), + row(2, "b", UPDATE_BEFORE, 1, snap2.snapshotId()), + row(2, "d", UPDATE_AFTER, 1, snap2.snapshotId())), + // the predicate on partition columns will filter out the insert of (3, 'c') at the planning + // phase + sql("select * from %s where id != 3 order by _change_ordinal, id, data", viewName)); + } + + @Test + public void testUpdateWithMultipleIdentifierColumns() { + createTableWith3Columns(); + + sql("INSERT INTO %s VALUES (1, 'a', 12), (2, 'b', 11)", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap1 = table.currentSnapshot(); + + sql("INSERT OVERWRITE %s VALUES (3, 'c', 13), (2, 'd', 11), (2, 'e', 12)", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + + List returns = + sql( + "CALL %s.system.create_changelog_view(" + + "identifier_columns => array('id','age')," + + "table => '%s')", + catalogName, tableName); + + String viewName = (String) returns.get(0)[0]; + assertEquals( + "Rows should match", + ImmutableList.of( + row(1, "a", 12, INSERT, 0, snap1.snapshotId()), + row(2, "b", 11, INSERT, 0, snap1.snapshotId()), + row(2, "b", 11, UPDATE_BEFORE, 1, snap2.snapshotId()), + row(2, "d", 11, 
UPDATE_AFTER, 1, snap2.snapshotId()), + row(2, "e", 12, INSERT, 1, snap2.snapshotId()), + row(3, "c", 13, INSERT, 1, snap2.snapshotId())), + sql("select * from %s order by _change_ordinal, id, data", viewName)); + } + + @Test + public void testRemoveCarryOvers() { + createTableWith3Columns(); + + sql("INSERT INTO %s VALUES (1, 'a', 12), (2, 'b', 11), (2, 'e', 12)", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap1 = table.currentSnapshot(); + + // carry-over row (2, 'e', 12) + sql("INSERT OVERWRITE %s VALUES (3, 'c', 13), (2, 'd', 11), (2, 'e', 12)", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + + List returns = + sql( + "CALL %s.system.create_changelog_view(" + + "identifier_columns => array('id','age'), " + + "table => '%s')", + catalogName, tableName); + + String viewName = (String) returns.get(0)[0]; + // the carry-over rows (2, 'e', 12, 'DELETE', 1), (2, 'e', 12, 'INSERT', 1) are removed + assertEquals( + "Rows should match", + ImmutableList.of( + row(1, "a", 12, INSERT, 0, snap1.snapshotId()), + row(2, "b", 11, INSERT, 0, snap1.snapshotId()), + row(2, "e", 12, INSERT, 0, snap1.snapshotId()), + row(2, "b", 11, UPDATE_BEFORE, 1, snap2.snapshotId()), + row(2, "d", 11, UPDATE_AFTER, 1, snap2.snapshotId()), + row(3, "c", 13, INSERT, 1, snap2.snapshotId())), + sql("select * from %s order by _change_ordinal, id, data", viewName)); + } + + @Test + public void testRemoveCarryOversWithoutUpdatedRows() { + createTableWith3Columns(); + + sql("INSERT INTO %s VALUES (1, 'a', 12), (2, 'b', 11), (2, 'e', 12)", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap1 = table.currentSnapshot(); + + // carry-over row (2, 'e', 12) + sql("INSERT OVERWRITE %s VALUES (3, 'c', 13), (2, 'd', 11), (2, 'e', 12)", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + + List returns = + sql("CALL %s.system.create_changelog_view(table => '%s')", catalogName, tableName); + + String viewName = (String) returns.get(0)[0]; + + // the carry-over rows (2, 'e', 12, 'DELETE', 1), (2, 'e', 12, 'INSERT', 1) are removed, even + // though update-row is not computed + assertEquals( + "Rows should match", + ImmutableList.of( + row(1, "a", 12, INSERT, 0, snap1.snapshotId()), + row(2, "b", 11, INSERT, 0, snap1.snapshotId()), + row(2, "e", 12, INSERT, 0, snap1.snapshotId()), + row(2, "b", 11, DELETE, 1, snap2.snapshotId()), + row(2, "d", 11, INSERT, 1, snap2.snapshotId()), + row(3, "c", 13, INSERT, 1, snap2.snapshotId())), + sql("select * from %s order by _change_ordinal, id, data", viewName)); + } + + @Test + public void testNotRemoveCarryOvers() { + createTableWith3Columns(); + + sql("INSERT INTO %s VALUES (1, 'a', 12), (2, 'b', 11), (2, 'e', 12)", tableName); + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snap1 = table.currentSnapshot(); + + // carry-over row (2, 'e', 12) + sql("INSERT OVERWRITE %s VALUES (3, 'c', 13), (2, 'd', 11), (2, 'e', 12)", tableName); + table.refresh(); + Snapshot snap2 = table.currentSnapshot(); + + List returns = + sql( + "CALL %s.system.create_changelog_view(" + + "remove_carryovers => false," + + "table => '%s')", + catalogName, tableName); + + String viewName = (String) returns.get(0)[0]; + + assertEquals( + "Rows should match", + ImmutableList.of( + row(1, "a", 12, INSERT, 0, snap1.snapshotId()), + row(2, "b", 11, INSERT, 0, snap1.snapshotId()), + row(2, "e", 12, INSERT, 0, snap1.snapshotId()), + row(2, "b", 11, DELETE, 1, snap2.snapshotId()), + row(2, "d", 
11, INSERT, 1, snap2.snapshotId()), + // the following two rows are carry-over rows + row(2, "e", 12, DELETE, 1, snap2.snapshotId()), + row(2, "e", 12, INSERT, 1, snap2.snapshotId()), + row(3, "c", 13, INSERT, 1, snap2.snapshotId())), + sql("select * from %s order by _change_ordinal, id, data, _change_type", viewName)); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java new file mode 100644 index 000000000000..bb26d975a414 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java @@ -0,0 +1,1177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE; +import static org.apache.iceberg.TableProperties.DELETE_ISOLATION_LEVEL; +import static org.apache.iceberg.TableProperties.DELETE_MODE; +import static org.apache.iceberg.TableProperties.DELETE_MODE_DEFAULT; +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; +import static org.apache.iceberg.TableProperties.SPLIT_SIZE; +import static org.apache.spark.sql.functions.lit; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.spark.data.TestHelpers; 
+import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.catalyst.plans.logical.DeleteFromIcebergTable; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.execution.datasources.v2.OptimizeMetadataOnlyDeleteFromIcebergTable; +import org.apache.spark.sql.internal.SQLConf; +import org.assertj.core.api.Assertions; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Test; + +public abstract class TestDelete extends SparkRowLevelOperationsTestBase { + + public TestDelete( + String catalogName, + String implementation, + Map config, + String fileFormat, + Boolean vectorized, + String distributionMode, + String branch) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode, branch); + } + + @BeforeClass + public static void setupSparkConf() { + spark.conf().set("spark.sql.shuffle.partitions", "4"); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS deleted_id"); + sql("DROP TABLE IF EXISTS deleted_dep"); + sql("DROP TABLE IF EXISTS parquet_table"); + } + + @Test + public void testDeleteWithoutScanningTable() throws Exception { + createAndInitPartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(3, "hr")); + createBranchIfNeeded(); + append(new Employee(1, "hardware"), new Employee(2, "hardware")); + + Table table = validationCatalog.loadTable(tableIdent); + + List manifestLocations = + SnapshotUtil.latestSnapshot(table, branch).allManifests(table.io()).stream() + .map(ManifestFile::path) + .collect(Collectors.toList()); + + withUnavailableLocations( + manifestLocations, + () -> { + LogicalPlan parsed = parsePlan("DELETE FROM %s WHERE dep = 'hr'", commitTarget()); + + DeleteFromIcebergTable analyzed = + (DeleteFromIcebergTable) spark.sessionState().analyzer().execute(parsed); + Assert.assertTrue("Should have rewrite plan", analyzed.rewritePlan().isDefined()); + + DeleteFromIcebergTable optimized = + (DeleteFromIcebergTable) OptimizeMetadataOnlyDeleteFromIcebergTable.apply(analyzed); + Assert.assertTrue("Should discard rewrite plan", optimized.rewritePlan().isEmpty()); + }); + + sql("DELETE FROM %s WHERE dep = 'hr'", commitTarget()); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hardware"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testDeleteFileThenMetadataDelete() throws Exception { + Assume.assumeFalse("Avro does not support metadata delete", fileFormat.equals("avro")); + createAndInitUnpartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); + + // MOR mode: writes a delete file as null cannot be deleted by metadata + sql("DELETE FROM %s AS t WHERE t.id IS NULL", commitTarget()); + + // Metadata Delete + Table table = Spark3Util.loadIcebergTable(spark, tableName); + Set dataFilesBefore = TestHelpers.dataFiles(table, branch); + + sql("DELETE FROM %s AS t WHERE t.id = 1", commitTarget()); + + Set dataFilesAfter = TestHelpers.dataFiles(table, branch); + Assert.assertTrue( + "Data file 
should have been removed", dataFilesBefore.size() > dataFilesAfter.size()); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testDeleteWithFalseCondition() { + createAndInitUnpartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware')", tableName); + createBranchIfNeeded(); + + sql("DELETE FROM %s WHERE id = 1 AND id > 20", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testDeleteFromEmptyTable() { + Assume.assumeFalse("Custom branch does not exist for empty table", "test".equals(branch)); + createAndInitUnpartitionedTable(); + + sql("DELETE FROM %s WHERE id IN (1)", commitTarget()); + sql("DELETE FROM %s WHERE dep = 'hr'", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); + + assertEquals( + "Should have expected rows", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testDeleteFromNonExistingCustomBranch() { + Assume.assumeTrue("Test only applicable to custom branch", "test".equals(branch)); + createAndInitUnpartitionedTable(); + + Assertions.assertThatThrownBy(() -> sql("DELETE FROM %s WHERE id IN (1)", commitTarget())) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot use branch (does not exist): test"); + } + + @Test + public void testExplain() { + createAndInitUnpartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); + + sql("EXPLAIN DELETE FROM %s WHERE id <=> 1", commitTarget()); + + sql("EXPLAIN DELETE FROM %s WHERE true", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 1 snapshot", 1, Iterables.size(table.snapshots())); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", commitTarget())); + } + + @Test + public void testDeleteWithAlias() { + createAndInitUnpartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); + + sql("DELETE FROM %s AS t WHERE t.id IS NULL", commitTarget()); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testDeleteWithDynamicFileFiltering() throws NoSuchTableException { + createAndInitPartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(3, "hr")); + createBranchIfNeeded(); + append(new Employee(1, "hardware"), new Employee(2, "hardware")); + + sql("DELETE FROM %s WHERE id = 2", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "1", "1", "1"); + } else { + 
validateMergeOnRead(currentSnapshot, "1", "1", null); + } + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hardware"), row(1, "hr"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + } + + @Test + public void testDeleteNonExistingRecords() { + createAndInitPartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); + + sql("DELETE FROM %s AS t WHERE t.id > 10", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + + if (fileFormat.equals("orc") || fileFormat.equals("parquet")) { + validateDelete(currentSnapshot, "0", null); + } else { + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "0", null, null); + } else { + validateMergeOnRead(currentSnapshot, "0", null, null); + } + } + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @Test + public void testDeleteWithoutCondition() { + createAndInitPartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + sql("INSERT INTO TABLE %s VALUES (2, 'hardware')", commitTarget()); + sql("INSERT INTO TABLE %s VALUES (null, 'hr')", commitTarget()); + + sql("DELETE FROM %s", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 4 snapshots", 4, Iterables.size(table.snapshots())); + + // should be a delete instead of an overwrite as it is done through a metadata operation + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + validateDelete(currentSnapshot, "2", "3"); + + assertEquals( + "Should have expected rows", ImmutableList.of(), sql("SELECT * FROM %s", commitTarget())); + } + + @Test + public void testDeleteUsingMetadataWithComplexCondition() { + createAndInitPartitionedTable(); + + sql("INSERT INTO %s VALUES (1, 'dep1')", tableName); + createBranchIfNeeded(); + sql("INSERT INTO %s VALUES (2, 'dep2')", commitTarget()); + sql("INSERT INTO %s VALUES (null, 'dep3')", commitTarget()); + + sql( + "DELETE FROM %s WHERE dep > 'dep2' OR dep = CAST(4 AS STRING) OR dep = 'dep2'", + commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 4 snapshots", 4, Iterables.size(table.snapshots())); + + // should be a delete instead of an overwrite as it is done through a metadata operation + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + validateDelete(currentSnapshot, "2", "2"); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "dep1")), + sql("SELECT * FROM %s", selectTarget())); + } + + @Test + public void testDeleteWithArbitraryPartitionPredicates() { + createAndInitPartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + sql("INSERT INTO TABLE %s VALUES (2, 'hardware')", commitTarget()); + sql("INSERT INTO TABLE %s VALUES (null, 'hr')", commitTarget()); + + // %% is an escaped version of % + sql("DELETE FROM %s WHERE id = 10 OR dep LIKE '%%ware'", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 4 snapshots", 4, 
Iterables.size(table.snapshots())); + + // should be an overwrite since cannot be executed using a metadata operation + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "1", "1", null); + } else { + validateMergeOnRead(currentSnapshot, "1", "1", null); + } + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @Test + public void testDeleteWithNonDeterministicCondition() { + createAndInitPartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware')", tableName); + createBranchIfNeeded(); + + AssertHelpers.assertThrows( + "Should complain about non-deterministic expressions", + AnalysisException.class, + "nondeterministic expressions are only allowed", + () -> sql("DELETE FROM %s WHERE id = 1 AND rand() > 0.5", commitTarget())); + } + + @Test + public void testDeleteWithFoldableConditions() { + createAndInitPartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware')", tableName); + createBranchIfNeeded(); + + // should keep all rows and don't trigger execution + sql("DELETE FROM %s WHERE false", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + // should keep all rows and don't trigger execution + sql("DELETE FROM %s WHERE 50 <> 50", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + // should keep all rows and don't trigger execution + sql("DELETE FROM %s WHERE 1 > null", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + // should remove all rows + sql("DELETE FROM %s WHERE 21 = 21", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); + } + + @Test + public void testDeleteWithNullConditions() { + createAndInitPartitionedTable(); + + sql( + "INSERT INTO TABLE %s VALUES (0, null), (1, 'hr'), (2, 'hardware'), (null, 'hr')", + tableName); + createBranchIfNeeded(); + + // should keep all rows as null is never equal to null + sql("DELETE FROM %s WHERE dep = null", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(0, null), row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + // null = 'software' -> null + // should delete using metadata operation only + sql("DELETE FROM %s WHERE dep = 'software'", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(0, null), row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + // should delete using metadata operation only + sql("DELETE FROM %s WHERE dep <=> NULL", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + 
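+ // note: unlike '=', the null-safe '<=>' comparison above does match the null department,
+ // which is why the (0, null) row is no longer expected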
Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + validateDelete(currentSnapshot, "1", "1"); + } + + @Test + public void testDeleteWithInAndNotInConditions() { + createAndInitUnpartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); + + sql("DELETE FROM %s WHERE id IN (1, null)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql("DELETE FROM %s WHERE id NOT IN (null, 1)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql("DELETE FROM %s WHERE id NOT IN (1, 10)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @Test + public void testDeleteWithMultipleRowGroupsParquet() throws NoSuchTableException { + Assume.assumeTrue(fileFormat.equalsIgnoreCase("parquet")); + + createAndInitPartitionedTable(); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", + tableName, PARQUET_ROW_GROUP_SIZE_BYTES, 100); + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", tableName, SPLIT_SIZE, 100); + + List ids = Lists.newArrayListWithCapacity(200); + for (int id = 1; id <= 200; id++) { + ids.add(id); + } + Dataset df = + spark + .createDataset(ids, Encoders.INT()) + .withColumnRenamed("value", "id") + .withColumn("dep", lit("hr")); + df.coalesce(1).writeTo(tableName).append(); + createBranchIfNeeded(); + + Assert.assertEquals(200, spark.table(commitTarget()).count()); + + // delete a record from one of two row groups and copy over the second one + sql("DELETE FROM %s WHERE id IN (200, 201)", commitTarget()); + + Assert.assertEquals(199, spark.table(commitTarget()).count()); + } + + @Test + public void testDeleteWithConditionOnNestedColumn() { + createAndInitNestedColumnsTable(); + + sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", 3, \"c2\", \"v1\"))", tableName); + createBranchIfNeeded(); + sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2, \"c2\", \"v2\"))", commitTarget()); + + sql("DELETE FROM %s WHERE complex.c1 = id + 2", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2)), + sql("SELECT id FROM %s", selectTarget())); + + sql("DELETE FROM %s t WHERE t.complex.c1 = id", commitTarget()); + assertEquals( + "Should have expected rows", ImmutableList.of(), sql("SELECT id FROM %s", selectTarget())); + } + + @Test + public void testDeleteWithInSubquery() throws NoSuchTableException { + createAndInitUnpartitionedTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); + + createOrReplaceView("deleted_id", Arrays.asList(0, 1, null), Encoders.INT()); + createOrReplaceView("deleted_dep", Arrays.asList("software", "hr"), Encoders.STRING()); + + sql( + "DELETE FROM %s WHERE id IN (SELECT * FROM deleted_id) AND dep IN (SELECT * from deleted_dep)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS 
LAST", selectTarget())); + + append(new Employee(1, "hr"), new Employee(-1, "hr")); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "DELETE FROM %s WHERE id IS NULL OR id IN (SELECT value + 2 FROM deleted_id)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(1, "hr")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + append(new Employee(null, "hr"), new Employee(2, "hr")); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(1, "hr"), row(2, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "DELETE FROM %s WHERE id IN (SELECT value + 2 FROM deleted_id) AND dep = 'hr'", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(1, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @Test + public void testDeleteWithMultiColumnInSubquery() throws NoSuchTableException { + createAndInitUnpartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + createBranchIfNeeded(); + + List deletedEmployees = + Arrays.asList(new Employee(null, "hr"), new Employee(1, "hr")); + createOrReplaceView("deleted_employee", deletedEmployees, Encoders.bean(Employee.class)); + + sql("DELETE FROM %s WHERE (id, dep) IN (SELECT id, dep FROM deleted_employee)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @Test + public void testDeleteWithNotInSubquery() throws NoSuchTableException { + createAndInitUnpartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + createBranchIfNeeded(); + + createOrReplaceView("deleted_id", Arrays.asList(-1, -2, null), Encoders.INT()); + createOrReplaceView("deleted_dep", Arrays.asList("software", "hr"), Encoders.STRING()); + + // the file filter subquery (nested loop lef-anti join) returns 0 records + sql("DELETE FROM %s WHERE id NOT IN (SELECT * FROM deleted_id)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "DELETE FROM %s WHERE id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "DELETE FROM %s WHERE id NOT IN (SELECT * FROM deleted_id) OR dep IN ('software', 'hr')", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "DELETE FROM %s t WHERE " + + "id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL) 
AND " + + "EXISTS (SELECT 1 FROM FROM deleted_dep WHERE t.dep = deleted_dep.value)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "DELETE FROM %s t WHERE " + + "id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL) OR " + + "EXISTS (SELECT 1 FROM FROM deleted_dep WHERE t.dep = deleted_dep.value)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @Test + public void testDeleteOnNonIcebergTableNotSupported() { + Assume.assumeTrue(catalogName.equalsIgnoreCase("spark_catalog")); + + sql("CREATE TABLE parquet_table (c1 INT, c2 INT) USING parquet"); + + AssertHelpers.assertThrows( + "Delete is supported only for Iceberg tables", + AnalysisException.class, + "does not support DELETE", + () -> sql("DELETE FROM parquet_table WHERE c1 = -100")); + } + + @Test + public void testDeleteWithExistSubquery() throws NoSuchTableException { + createAndInitUnpartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + createBranchIfNeeded(); + + createOrReplaceView("deleted_id", Arrays.asList(-1, -2, null), Encoders.INT()); + createOrReplaceView("deleted_dep", Arrays.asList("software", "hr"), Encoders.STRING()); + + sql( + "DELETE FROM %s t WHERE EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "DELETE FROM %s t WHERE EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "DELETE FROM %s t WHERE EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value) OR t.id IS NULL", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware")), + sql("SELECT * FROM %s", selectTarget())); + + sql( + "DELETE FROM %s t WHERE " + + "EXISTS (SELECT 1 FROM deleted_id di WHERE t.id = di.value) AND " + + "EXISTS (SELECT 1 FROM deleted_dep dd WHERE t.dep = dd.value)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware")), + sql("SELECT * FROM %s", selectTarget())); + } + + @Test + public void testDeleteWithNotExistsSubquery() throws NoSuchTableException { + createAndInitUnpartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + createBranchIfNeeded(); + + createOrReplaceView("deleted_id", Arrays.asList(-1, -2, null), Encoders.INT()); + createOrReplaceView("deleted_dep", Arrays.asList("software", "hr"), Encoders.STRING()); + + sql( + "DELETE FROM %s t WHERE " + + "NOT EXISTS (SELECT 1 FROM deleted_id di WHERE t.id = di.value + 2) AND " + + "NOT EXISTS (SELECT 1 FROM deleted_dep dd WHERE t.dep = dd.value)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "DELETE FROM %s t WHERE NOT EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2)", + commitTarget()); + 
assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + String subquery = "SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2"; + sql("DELETE FROM %s t WHERE NOT EXISTS (%s) OR t.id = 1", commitTarget(), subquery); + assertEquals( + "Should have expected rows", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @Test + public void testDeleteWithScalarSubquery() throws NoSuchTableException { + createAndInitUnpartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + createBranchIfNeeded(); + + createOrReplaceView("deleted_id", Arrays.asList(1, 100, null), Encoders.INT()); + + // TODO: Spark does not support AQE and DPP with aggregates at the moment + withSQLConf( + ImmutableMap.of(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "false"), + () -> { + sql("DELETE FROM %s t WHERE id <= (SELECT min(value) FROM deleted_id)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + }); + } + + @Test + public void testDeleteThatRequiresGroupingBeforeWrite() throws NoSuchTableException { + createAndInitPartitionedTable(); + + append(tableName, new Employee(0, "hr"), new Employee(1, "hr"), new Employee(2, "hr")); + createBranchIfNeeded(); + append(new Employee(0, "ops"), new Employee(1, "ops"), new Employee(2, "ops")); + append(new Employee(0, "hr"), new Employee(1, "hr"), new Employee(2, "hr")); + append(new Employee(0, "ops"), new Employee(1, "ops"), new Employee(2, "ops")); + + createOrReplaceView("deleted_id", Arrays.asList(1, 100), Encoders.INT()); + + String originalNumOfShufflePartitions = spark.conf().get("spark.sql.shuffle.partitions"); + try { + // set the num of shuffle partitions to 1 to ensure we have only 1 writing task + spark.conf().set("spark.sql.shuffle.partitions", "1"); + + sql("DELETE FROM %s t WHERE id IN (SELECT * FROM deleted_id)", commitTarget()); + Assert.assertEquals( + "Should have expected num of rows", 8L, spark.table(commitTarget()).count()); + } finally { + spark.conf().set("spark.sql.shuffle.partitions", originalNumOfShufflePartitions); + } + } + + @Test + public synchronized void testDeleteWithSerializableIsolation() throws InterruptedException { + // cannot run tests with concurrency for Hadoop tables without atomic renames + Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop")); + + createAndInitUnpartitionedTable(); + createOrReplaceView("deleted_id", Collections.singletonList(1), Encoders.INT()); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, DELETE_ISOLATION_LEVEL, "serializable"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + + ExecutorService executorService = + MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + AtomicBoolean shouldAppend = new AtomicBoolean(true); + + // delete thread + Future deleteFuture = + executorService.submit( + () -> { + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + + sql("DELETE FROM %s WHERE id IN (SELECT * FROM deleted_id)", commitTarget()); + + barrier.incrementAndGet(); + } + }); + + // append thread + Future 
appendFuture = + executorService.submit( + () -> { + // load the table via the validation catalog to use another table instance + Table table = validationCatalog.loadTable(tableIdent); + + GenericRecord record = GenericRecord.create(SnapshotUtil.schemaFor(table, branch)); + record.set(0, 1); // id + record.set(1, "hr"); // dep + + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (shouldAppend.get() && barrier.get() < numOperations * 2) { + sleep(10); + } + + if (!shouldAppend.get()) { + return; + } + + for (int numAppends = 0; numAppends < 5; numAppends++) { + DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + + appendFiles.commit(); + sleep(10); + } + + barrier.incrementAndGet(); + } + }); + + try { + Assertions.assertThatThrownBy(deleteFuture::get) + .isInstanceOf(ExecutionException.class) + .cause() + .isInstanceOf(ValidationException.class) + .hasMessageContaining("Found conflicting files that can contain"); + } finally { + shouldAppend.set(false); + appendFuture.cancel(true); + } + + executorService.shutdown(); + Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES)); + } + + @Test + public synchronized void testDeleteWithSnapshotIsolation() + throws InterruptedException, ExecutionException { + // cannot run tests with concurrency for Hadoop tables without atomic renames + Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop")); + + createAndInitUnpartitionedTable(); + createOrReplaceView("deleted_id", Collections.singletonList(1), Encoders.INT()); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, DELETE_ISOLATION_LEVEL, "snapshot"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + + ExecutorService executorService = + MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + AtomicBoolean shouldAppend = new AtomicBoolean(true); + + // delete thread + Future deleteFuture = + executorService.submit( + () -> { + for (int numOperations = 0; numOperations < 20; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + + sql("DELETE FROM %s WHERE id IN (SELECT * FROM deleted_id)", commitTarget()); + + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = + executorService.submit( + () -> { + // load the table via the validation catalog to use another table instance for inserts + Table table = validationCatalog.loadTable(tableIdent); + + GenericRecord record = GenericRecord.create(SnapshotUtil.schemaFor(table, branch)); + record.set(0, 1); // id + record.set(1, "hr"); // dep + + for (int numOperations = 0; numOperations < 20; numOperations++) { + while (shouldAppend.get() && barrier.get() < numOperations * 2) { + sleep(10); + } + + if (!shouldAppend.get()) { + return; + } + + for (int numAppends = 0; numAppends < 5; numAppends++) { + DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + + appendFiles.commit(); + sleep(10); + } + + barrier.incrementAndGet(); + } + }); + + try { + deleteFuture.get(); + } finally { + shouldAppend.set(false); + appendFuture.cancel(true); + } + + executorService.shutdown(); + 
Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES)); + } + + @Test + public void testDeleteRefreshesRelationCache() throws NoSuchTableException { + createAndInitPartitionedTable(); + + append(tableName, new Employee(1, "hr"), new Employee(3, "hr")); + createBranchIfNeeded(); + append(new Employee(1, "hardware"), new Employee(2, "hardware")); + + Dataset query = spark.sql("SELECT * FROM " + commitTarget() + " WHERE id = 1"); + query.createOrReplaceTempView("tmp"); + + spark.sql("CACHE TABLE tmp"); + + assertEquals( + "View should have correct data", + ImmutableList.of(row(1, "hardware"), row(1, "hr")), + sql("SELECT * FROM tmp ORDER BY id, dep")); + + sql("DELETE FROM %s WHERE id = 1", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "2", "2", "2"); + } else { + validateMergeOnRead(currentSnapshot, "2", "2", null); + } + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hardware"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", commitTarget())); + + assertEquals( + "Should refresh the relation cache", + ImmutableList.of(), + sql("SELECT * FROM tmp ORDER BY id, dep")); + + spark.sql("UNCACHE TABLE tmp"); + } + + @Test + public void testDeleteWithMultipleSpecs() { + createAndInitTable("id INT, dep STRING, category STRING"); + + // write an unpartitioned file + append(tableName, "{ \"id\": 1, \"dep\": \"hr\", \"category\": \"c1\"}"); + createBranchIfNeeded(); + + // write a file partitioned by dep + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + append( + commitTarget(), + "{ \"id\": 2, \"dep\": \"hr\", \"category\": \"c1\" }\n" + + "{ \"id\": 3, \"dep\": \"hr\", \"category\": \"c1\" }"); + + // write a file partitioned by dep and category + sql("ALTER TABLE %s ADD PARTITION FIELD category", tableName); + append(commitTarget(), "{ \"id\": 5, \"dep\": \"hr\", \"category\": \"c1\"}"); + + // write another file partitioned by dep + sql("ALTER TABLE %s DROP PARTITION FIELD category", tableName); + append(commitTarget(), "{ \"id\": 7, \"dep\": \"hr\", \"category\": \"c1\"}"); + + sql("DELETE FROM %s WHERE id IN (1, 3, 5, 7)", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 5 snapshots", 5, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + if (mode(table) == COPY_ON_WRITE) { + // copy-on-write is tested against v1 and such tables have different partition evolution + // behavior + // that's why the number of changed partitions is 4 for copy-on-write + validateCopyOnWrite(currentSnapshot, "4", "4", "1"); + } else { + validateMergeOnRead(currentSnapshot, "3", "3", null); + } + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(2, "hr", "c1")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testDeleteToWapBranch() throws NoSuchTableException { + Assume.assumeTrue("WAP branch only works for table identifier without branch", branch == null); + + createAndInitPartitionedTable(); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' = 'true')", + tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED); + append(new Employee(0, "hr"), new Employee(1, "hr"), new Employee(2, "hr")); + + 
withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> { + sql("DELETE FROM %s t WHERE id=0", tableName); + Assert.assertEquals( + "Should have expected num of rows when reading table", + 2L, + spark.table(tableName).count()); + Assert.assertEquals( + "Should have expected num of rows when reading WAP branch", + 2L, + spark.table(tableName + ".branch_wap").count()); + Assert.assertEquals( + "Should not modify main branch", 3L, spark.table(tableName + ".branch_main").count()); + }); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> { + sql("DELETE FROM %s t WHERE id=1", tableName); + Assert.assertEquals( + "Should have expected num of rows when reading table with multiple writes", + 1L, + spark.table(tableName).count()); + Assert.assertEquals( + "Should have expected num of rows when reading WAP branch with multiple writes", + 1L, + spark.table(tableName + ".branch_wap").count()); + Assert.assertEquals( + "Should not modify main branch with multiple writes", + 3L, + spark.table(tableName + ".branch_main").count()); + }); + } + + @Test + public void testDeleteToWapBranchWithTableBranchIdentifier() throws NoSuchTableException { + Assume.assumeTrue("Test must have branch name part in table identifier", branch != null); + + createAndInitPartitionedTable(); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' = 'true')", + tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED); + append(tableName, new Employee(0, "hr"), new Employee(1, "hr"), new Employee(2, "hr")); + createBranchIfNeeded(); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> + Assertions.assertThatThrownBy(() -> sql("DELETE FROM %s t WHERE id=0", commitTarget())) + .isInstanceOf(ValidationException.class) + .hasMessage( + String.format( + "Cannot write to both branch and WAP branch, but got branch [%s] and WAP branch [wap]", + branch))); + } + + // TODO: multiple stripes for ORC + + protected void createAndInitPartitionedTable() { + sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg PARTITIONED BY (dep)", tableName); + initTable(); + } + + protected void createAndInitUnpartitionedTable() { + sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg", tableName); + initTable(); + } + + protected void createAndInitNestedColumnsTable() { + sql("CREATE TABLE %s (id INT, complex STRUCT<c1:INT,c2:STRING>) USING iceberg", tableName); + initTable(); + } + + protected void append(Employee... employees) throws NoSuchTableException { + append(commitTarget(), employees); + } + + protected void append(String target, Employee... employees) throws NoSuchTableException { + List<Employee> input = Arrays.asList(employees); + Dataset<Row> inputDF = spark.createDataFrame(input, Employee.class); + inputDF.coalesce(1).writeTo(target).append(); + } + + private RowLevelOperationMode mode(Table table) { + String modeName = table.properties().getOrDefault(DELETE_MODE, DELETE_MODE_DEFAULT); + return RowLevelOperationMode.fromName(modeName); + } + + private LogicalPlan parsePlan(String query, Object... 
args) { + try { + return spark.sessionState().sqlParser().parsePlan(String.format(query, args)); + } catch (ParseException e) { + throw new RuntimeException(e); + } + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestExpireSnapshotsProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestExpireSnapshotsProcedure.java new file mode 100644 index 000000000000..efb3d43668f1 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestExpireSnapshotsProcedure.java @@ -0,0 +1,544 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.TableProperties.GC_ENABLED; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.GenericBlobMetadata; +import org.apache.iceberg.GenericStatisticsFile; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.StatisticsFile; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.puffin.Blob; +import org.apache.iceberg.puffin.Puffin; +import org.apache.iceberg.puffin.PuffinWriter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.assertj.core.api.Assertions; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestExpireSnapshotsProcedure extends SparkExtensionsTestBase { + + public TestExpireSnapshotsProcedure( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testExpireSnapshotsInEmptyTable() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + 
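+ // an empty table has no removable files, so every count in the procedure output row is
+ // expected to be zero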
List output = sql("CALL %s.system.expire_snapshots('%s')", catalogName, tableIdent); + assertEquals( + "Should not delete any files", ImmutableList.of(row(0L, 0L, 0L, 0L, 0L, 0L)), output); + } + + @Test + public void testExpireSnapshotsUsingPositionalArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + waitUntilAfter(firstSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + Timestamp secondSnapshotTimestamp = + Timestamp.from(Instant.ofEpochMilli(secondSnapshot.timestampMillis())); + + Assert.assertEquals("Should be 2 snapshots", 2, Iterables.size(table.snapshots())); + + // expire without retainLast param + List output1 = + sql( + "CALL %s.system.expire_snapshots('%s', TIMESTAMP '%s')", + catalogName, tableIdent, secondSnapshotTimestamp); + assertEquals( + "Procedure output must match", ImmutableList.of(row(0L, 0L, 0L, 0L, 1L, 0L)), output1); + + table.refresh(); + + Assert.assertEquals("Should expire one snapshot", 1, Iterables.size(table.snapshots())); + + sql("INSERT OVERWRITE %s VALUES (3, 'c')", tableName); + sql("INSERT INTO TABLE %s VALUES (4, 'd')", tableName); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(3L, "c"), row(4L, "d")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + waitUntilAfter(table.currentSnapshot().timestampMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + Assert.assertEquals("Should be 3 snapshots", 3, Iterables.size(table.snapshots())); + + // expire with retainLast param + List output = + sql( + "CALL %s.system.expire_snapshots('%s', TIMESTAMP '%s', 2)", + catalogName, tableIdent, currentTimestamp); + assertEquals( + "Procedure output must match", ImmutableList.of(row(2L, 0L, 0L, 2L, 1L, 0L)), output); + } + + @Test + public void testExpireSnapshotUsingNamedArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals("Should be 2 snapshots", 2, Iterables.size(table.snapshots())); + + waitUntilAfter(table.currentSnapshot().timestampMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + List output = + sql( + "CALL %s.system.expire_snapshots(older_than => TIMESTAMP '%s',table => '%s')", + catalogName, currentTimestamp, tableIdent); + assertEquals( + "Procedure output must match", ImmutableList.of(row(0L, 0L, 0L, 0L, 1L, 0L)), output); + } + + @Test + public void testExpireSnapshotsGCDisabled() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'false')", tableName, GC_ENABLED); + + AssertHelpers.assertThrows( + "Should reject call", + ValidationException.class, + "Cannot expire snapshots: GC is disabled", + () -> sql("CALL %s.system.expire_snapshots('%s')", catalogName, tableIdent)); + } + + @Test + public void testInvalidExpireSnapshotsCases() { + AssertHelpers.assertThrows( + "Should not allow mixed args", + AnalysisException.class, + "Named and 
positional arguments cannot be mixed", + () -> sql("CALL %s.system.expire_snapshots('n', table => 't')", catalogName)); + + AssertHelpers.assertThrows( + "Should not resolve procedures in arbitrary namespaces", + NoSuchProcedureException.class, + "not found", + () -> sql("CALL %s.custom.expire_snapshots('n', 't')", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls without all required args", + AnalysisException.class, + "Missing required parameters", + () -> sql("CALL %s.system.expire_snapshots()", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with invalid arg types", + AnalysisException.class, + "Wrong arg type", + () -> sql("CALL %s.system.expire_snapshots('n', 2.2)", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with empty table identifier", + IllegalArgumentException.class, + "Cannot handle an empty identifier", + () -> sql("CALL %s.system.expire_snapshots('')", catalogName)); + } + + @Test + public void testResolvingTableInAnotherCatalog() throws IOException { + String anotherCatalog = "another_" + catalogName; + spark.conf().set("spark.sql.catalog." + anotherCatalog, SparkCatalog.class.getName()); + spark.conf().set("spark.sql.catalog." + anotherCatalog + ".type", "hadoop"); + spark + .conf() + .set( + "spark.sql.catalog." + anotherCatalog + ".warehouse", + "file:" + temp.newFolder().toString()); + + sql( + "CREATE TABLE %s.%s (id bigint NOT NULL, data string) USING iceberg", + anotherCatalog, tableIdent); + + AssertHelpers.assertThrows( + "Should reject calls for a table in another catalog", + IllegalArgumentException.class, + "Cannot run procedure in catalog", + () -> + sql( + "CALL %s.system.expire_snapshots('%s')", + catalogName, anotherCatalog + "." + tableName)); + } + + @Test + public void testConcurrentExpireSnapshots() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + sql("INSERT INTO TABLE %s VALUES (3, 'c')", tableName); + sql("INSERT INTO TABLE %s VALUES (4, 'd')", tableName); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + List output = + sql( + "CALL %s.system.expire_snapshots(" + + "older_than => TIMESTAMP '%s'," + + "table => '%s'," + + "max_concurrent_deletes => %s)", + catalogName, currentTimestamp, tableIdent, 4); + assertEquals( + "Expiring snapshots concurrently should succeed", + ImmutableList.of(row(0L, 0L, 0L, 0L, 3L, 0L)), + output); + } + + @Test + public void testConcurrentExpireSnapshotsWithInvalidInput() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + AssertHelpers.assertThrows( + "Should throw an error when max_concurrent_deletes = 0", + IllegalArgumentException.class, + "max_concurrent_deletes should have value > 0", + () -> + sql( + "CALL %s.system.expire_snapshots(table => '%s', max_concurrent_deletes => %s)", + catalogName, tableIdent, 0)); + + AssertHelpers.assertThrows( + "Should throw an error when max_concurrent_deletes < 0 ", + IllegalArgumentException.class, + "max_concurrent_deletes should have value > 0", + () -> + sql( + "CALL %s.system.expire_snapshots(table => '%s', max_concurrent_deletes => %s)", + catalogName, tableIdent, -1)); + } + + @Test + public void testExpireDeleteFiles() throws Exception { + sql( + "CREATE TABLE %s (id bigint, data string) USING iceberg TBLPROPERTIES" + + "('format-version'='2', 
'write.delete.mode'='merge-on-read')", + tableName); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "d")); + spark + .createDataset(records, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + sql("DELETE FROM %s WHERE id=1", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals( + "Should have 1 delete manifest", 1, TestHelpers.deleteManifests(table).size()); + Assert.assertEquals("Should have 1 delete file", 1, TestHelpers.deleteFiles(table).size()); + Path deleteManifestPath = new Path(TestHelpers.deleteManifests(table).iterator().next().path()); + Path deleteFilePath = + new Path(String.valueOf(TestHelpers.deleteFiles(table).iterator().next().path())); + + sql( + "CALL %s.system.rewrite_data_files(" + + "table => '%s'," + + "options => map(" + + "'delete-file-threshold','1'," + + "'use-starting-sequence-number', 'false'))", + catalogName, tableIdent); + table.refresh(); + + sql( + "INSERT INTO TABLE %s VALUES (5, 'e')", + tableName); // this txn moves the file to the DELETED state + sql("INSERT INTO TABLE %s VALUES (6, 'f')", tableName); // this txn removes the file reference + table.refresh(); + + Assert.assertEquals( + "Should have no delete manifests", 0, TestHelpers.deleteManifests(table).size()); + Assert.assertEquals("Should have no delete files", 0, TestHelpers.deleteFiles(table).size()); + + FileSystem localFs = FileSystem.getLocal(new Configuration()); + Assert.assertTrue("Delete manifest should still exist", localFs.exists(deleteManifestPath)); + Assert.assertTrue("Delete file should still exist", localFs.exists(deleteFilePath)); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + List output = + sql( + "CALL %s.system.expire_snapshots(older_than => TIMESTAMP '%s',table => '%s')", + catalogName, currentTimestamp, tableIdent); + + assertEquals( + "Should deleted 1 data and pos delete file and 4 manifests and lists (one for each txn)", + ImmutableList.of(row(1L, 1L, 0L, 4L, 4L, 0L)), + output); + Assert.assertFalse("Delete manifest should be removed", localFs.exists(deleteManifestPath)); + Assert.assertFalse("Delete file should be removed", localFs.exists(deleteFilePath)); + } + + @Test + public void testExpireSnapshotWithStreamResultsEnabled() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals("Should be 2 snapshots", 2, Iterables.size(table.snapshots())); + + waitUntilAfter(table.currentSnapshot().timestampMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + List output = + sql( + "CALL %s.system.expire_snapshots(" + + "older_than => TIMESTAMP '%s'," + + "table => '%s'," + + "stream_results => true)", + catalogName, currentTimestamp, tableIdent); + assertEquals( + "Procedure output must match", ImmutableList.of(row(0L, 0L, 0L, 0L, 1L, 0L)), output); + } + + @Test + public void testExpireSnapshotsWithSnapshotId() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = 
validationCatalog.loadTable(tableIdent); + + Assert.assertEquals("Should be 2 snapshots", 2, Iterables.size(table.snapshots())); + + // Expiring the snapshot specified by snapshot_id should keep only a single snapshot. + long firstSnapshotId = table.currentSnapshot().parentId(); + sql( + "CALL %s.system.expire_snapshots(" + "table => '%s'," + "snapshot_ids => ARRAY(%d))", + catalogName, tableIdent, firstSnapshotId); + + // There should only be one single snapshot left. + table.refresh(); + Assert.assertEquals("Should be 1 snapshots", 1, Iterables.size(table.snapshots())); + Assert.assertEquals( + "Snapshot ID should not be present", + 0, + Iterables.size( + Iterables.filter( + table.snapshots(), snapshot -> snapshot.snapshotId() == firstSnapshotId))); + } + + @Test + public void testExpireSnapshotShouldFailForCurrentSnapshot() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should be 2 snapshots", 2, Iterables.size(table.snapshots())); + + AssertHelpers.assertThrows( + "Should reject call", + IllegalArgumentException.class, + "Cannot expire", + () -> + sql( + "CALL %s.system.expire_snapshots(" + + "table => '%s'," + + "snapshot_ids => ARRAY(%d, %d))", + catalogName, + tableIdent, + table.currentSnapshot().snapshotId(), + table.currentSnapshot().parentId())); + } + + @Test + public void testExpireSnapshotsProcedureWorksWithSqlComments() { + // Ensure that systems such as dbt, that inject comments into the generated SQL files, will + // work with Iceberg-specific DDL + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals("Should be 2 snapshots", 2, Iterables.size(table.snapshots())); + + waitUntilAfter(table.currentSnapshot().timestampMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + String callStatement = + "/* CALL statement is used to expire snapshots */\n" + + "-- And we have single line comments as well \n" + + "/* And comments that span *multiple* \n" + + " lines */ CALL /* this is the actual CALL */ %s.system.expire_snapshots(" + + " older_than => TIMESTAMP '%s'," + + " table => '%s')"; + List output = sql(callStatement, catalogName, currentTimestamp, tableIdent); + assertEquals( + "Procedure output must match", ImmutableList.of(row(0L, 0L, 0L, 0L, 1L, 0L)), output); + + table.refresh(); + + Assert.assertEquals("Should be 1 snapshot remaining", 1, Iterables.size(table.snapshots())); + } + + @Test + public void testExpireSnapshotsWithStatisticFiles() throws Exception { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (10, 'abc')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + String statsFileLocation1 = statsFileLocation(table.location()); + StatisticsFile statisticsFile1 = + writeStatsFile( + table.currentSnapshot().snapshotId(), + table.currentSnapshot().sequenceNumber(), + statsFileLocation1, + table.io()); + table.updateStatistics().setStatistics(statisticsFile1.snapshotId(), statisticsFile1).commit(); + + sql("INSERT INTO %s SELECT 20, 'def'", tableName); + 
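+    // a second statistics file is registered against the new current snapshot below; after the
+    // expiration call, only this second file and its TableMetadata entry should remain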
table.refresh(); + String statsFileLocation2 = statsFileLocation(table.location()); + StatisticsFile statisticsFile2 = + writeStatsFile( + table.currentSnapshot().snapshotId(), + table.currentSnapshot().sequenceNumber(), + statsFileLocation2, + table.io()); + table.updateStatistics().setStatistics(statisticsFile2.snapshotId(), statisticsFile2).commit(); + + waitUntilAfter(table.currentSnapshot().timestampMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + List output = + sql( + "CALL %s.system.expire_snapshots(older_than => TIMESTAMP '%s',table => '%s')", + catalogName, currentTimestamp, tableIdent); + Assertions.assertThat(output.get(0)[5]).as("should be 1 deleted statistics file").isEqualTo(1L); + + table.refresh(); + List statsWithSnapshotId1 = + table.statisticsFiles().stream() + .filter(statisticsFile -> statisticsFile.snapshotId() == statisticsFile1.snapshotId()) + .collect(Collectors.toList()); + Assertions.assertThat(statsWithSnapshotId1) + .as( + "Statistics file entry in TableMetadata should be deleted for the snapshot %s", + statisticsFile1.snapshotId()) + .isEmpty(); + Assertions.assertThat(table.statisticsFiles()) + .as( + "Statistics file entry in TableMetadata should be present for the snapshot %s", + statisticsFile2.snapshotId()) + .extracting(StatisticsFile::snapshotId) + .containsExactly(statisticsFile2.snapshotId()); + + Assertions.assertThat(new File(statsFileLocation1)) + .as("Statistics file should not exist for snapshot %s", statisticsFile1.snapshotId()) + .doesNotExist(); + + Assertions.assertThat(new File(statsFileLocation2)) + .as("Statistics file should exist for snapshot %s", statisticsFile2.snapshotId()) + .exists(); + } + + private StatisticsFile writeStatsFile( + long snapshotId, long snapshotSequenceNumber, String statsLocation, FileIO fileIO) + throws IOException { + try (PuffinWriter puffinWriter = Puffin.write(fileIO.newOutputFile(statsLocation)).build()) { + puffinWriter.add( + new Blob( + "some-blob-type", + ImmutableList.of(1), + snapshotId, + snapshotSequenceNumber, + ByteBuffer.wrap("blob content".getBytes(StandardCharsets.UTF_8)))); + puffinWriter.finish(); + + return new GenericStatisticsFile( + snapshotId, + statsLocation, + puffinWriter.fileSize(), + puffinWriter.footerSize(), + puffinWriter.writtenBlobsMetadata().stream() + .map(GenericBlobMetadata::from) + .collect(ImmutableList.toImmutableList())); + } + } + + private String statsFileLocation(String tableLocation) { + String statsFileName = "stats-file-" + UUID.randomUUID(); + return tableLocation.replaceFirst("file:", "") + "/metadata/" + statsFileName; + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestIcebergExpressions.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestIcebergExpressions.java new file mode 100644 index 000000000000..8d2e10ea17eb --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestIcebergExpressions.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import java.math.BigDecimal; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.IcebergTruncateTransform; +import org.junit.After; +import org.junit.Test; + +public class TestIcebergExpressions extends SparkExtensionsTestBase { + + public TestIcebergExpressions( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP VIEW IF EXISTS emp"); + sql("DROP VIEW IF EXISTS v"); + } + + @Test + public void testTruncateExpressions() { + sql( + "CREATE TABLE %s ( " + + " int_c INT, long_c LONG, dec_c DECIMAL(4, 2), str_c STRING, binary_c BINARY " + + ") USING iceberg", + tableName); + + sql( + "CREATE TEMPORARY VIEW emp " + + "AS SELECT * FROM VALUES (101, 10001, 10.65, '101-Employee', CAST('1234' AS BINARY)) " + + "AS EMP(int_c, long_c, dec_c, str_c, binary_c)"); + + sql("INSERT INTO %s SELECT * FROM emp", tableName); + + Dataset df = spark.sql("SELECT * FROM " + tableName); + df.select( + new Column(new IcebergTruncateTransform(df.col("int_c").expr(), 2)).as("int_c"), + new Column(new IcebergTruncateTransform(df.col("long_c").expr(), 2)).as("long_c"), + new Column(new IcebergTruncateTransform(df.col("dec_c").expr(), 50)).as("dec_c"), + new Column(new IcebergTruncateTransform(df.col("str_c").expr(), 2)).as("str_c"), + new Column(new IcebergTruncateTransform(df.col("binary_c").expr(), 2)).as("binary_c")) + .createOrReplaceTempView("v"); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(100, 10000L, new BigDecimal("10.50"), "10", "12")), + sql("SELECT int_c, long_c, dec_c, str_c, CAST(binary_c AS STRING) FROM v")); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java new file mode 100644 index 000000000000..4ec78ec38532 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java @@ -0,0 +1,2561 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE; +import static org.apache.iceberg.TableProperties.MERGE_ISOLATION_LEVEL; +import static org.apache.iceberg.TableProperties.MERGE_MODE; +import static org.apache.iceberg.TableProperties.MERGE_MODE_DEFAULT; +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; +import static org.apache.iceberg.TableProperties.SPLIT_SIZE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE; +import static org.apache.spark.sql.functions.lit; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.SparkException; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.execution.SparkPlan; +import org.apache.spark.sql.internal.SQLConf; +import org.assertj.core.api.Assertions; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Test; + +public abstract class TestMerge extends SparkRowLevelOperationsTestBase { + + public TestMerge( + String catalogName, + String implementation, + Map config, + String fileFormat, + boolean vectorized, + String distributionMode, + String branch) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode, branch); + } + + @BeforeClass + public static void setupSparkConf() { + spark.conf().set("spark.sql.shuffle.partitions", "4"); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS source"); + } + + @Test + public void testMergeConditionSplitIntoTargetPredicateAndJoinCondition() { + createAndInitTable( + "id INT, salary INT, dep STRING, sub_dep STRING", + "PARTITIONED BY (dep, sub_dep)", + "{ \"id\": 1, \"salary\": 100, \"dep\": \"d1\", \"sub_dep\": \"sd1\" }\n" + + "{ \"id\": 6, \"salary\": 600, \"dep\": \"d6\", \"sub_dep\": \"sd6\" }"); + + 
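+    // the source overlaps the target only on id = 1; ids 2 and 3 exercise the NOT MATCHED
+    // insert path, and the ON clause mixes an equi-join on id with a target-only predicate
+    // that is expected to be split off into a separate filter, which
+    // checkJoinAndFilterConditions verifies against the produced plan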
createOrReplaceView( + "source", + "id INT, salary INT, dep STRING, sub_dep STRING", + "{ \"id\": 1, \"salary\": 101, \"dep\": \"d1\", \"sub_dep\": \"sd1\" }\n" + + "{ \"id\": 2, \"salary\": 200, \"dep\": \"d2\", \"sub_dep\": \"sd2\" }\n" + + "{ \"id\": 3, \"salary\": 300, \"dep\": \"d3\", \"sub_dep\": \"sd3\" }"); + + String query = + String.format( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id AND ((t.dep = 'd1' AND t.sub_dep IN ('sd1', 'sd3')) OR (t.dep = 'd6' AND t.sub_dep IN ('sd2', 'sd6'))) " + + "WHEN MATCHED THEN " + + " UPDATE SET salary = s.salary " + + "WHEN NOT MATCHED THEN " + + " INSERT *", + commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + + if (mode(table) == COPY_ON_WRITE) { + checkJoinAndFilterConditions( + query, + "Join [id], [id], FullOuter", + "((dep = 'd1' AND sub_dep IN ('sd1', 'sd3')) OR (dep = 'd6' AND sub_dep IN ('sd2', 'sd6')))"); + } else { + checkJoinAndFilterConditions( + query, + "Join [id], [id], RightOuter", + "((dep = 'd1' AND sub_dep IN ('sd1', 'sd3')) OR (dep = 'd6' AND sub_dep IN ('sd2', 'sd6')))"); + } + + assertEquals( + "Should have expected rows", + ImmutableList.of( + row(1, 101, "d1", "sd1"), // updated + row(2, 200, "d2", "sd2"), // new + row(3, 300, "d3", "sd3"), // new + row(6, 600, "d6", "sd6")), // existing + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeWithStaticPredicatePushDown() { + createAndInitTable("id BIGINT, dep STRING"); + + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + // add a data file to the 'software' partition + append(tableName, "{ \"id\": 1, \"dep\": \"software\" }"); + createBranchIfNeeded(); + + // add a data file to the 'hr' partition + append(commitTarget(), "{ \"id\": 1, \"dep\": \"hr\" }"); + + Table table = validationCatalog.loadTable(tableIdent); + + Snapshot snapshot = SnapshotUtil.latestSnapshot(table, branch); + String dataFilesCount = snapshot.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP); + Assert.assertEquals("Must have 2 files before MERGE", "2", dataFilesCount); + + createOrReplaceView( + "source", "{ \"id\": 1, \"dep\": \"finance\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }"); + + // remove the data file from the 'hr' partition to ensure it is not scanned + withUnavailableFiles( + snapshot.addedDataFiles(table.io()), + () -> { + // disable dynamic pruning and rely only on static predicate pushdown + withSQLConf( + ImmutableMap.of(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED().key(), "false"), + () -> { + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id AND t.dep IN ('software') AND source.id < 10 " + + "WHEN MATCHED AND source.id = 1 THEN " + + " UPDATE SET dep = source.dep " + + "WHEN NOT MATCHED THEN " + + " INSERT (dep, id) VALUES (source.dep, source.id)", + commitTarget()); + }); + }); + + ImmutableList expectedRows = + ImmutableList.of( + row(1L, "finance"), // updated + row(1L, "hr"), // kept + row(2L, "hardware") // new + ); + assertEquals( + "Output should match", + expectedRows, + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + } + + @Test + public void testMergeIntoEmptyTargetInsertAllNonMatchingRows() { + Assume.assumeFalse("Custom branch does not exist for empty table", "test".equals(branch)); + createAndInitTable("id INT, dep STRING"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 3, \"dep\": \"emp-id-3\" }"); + + sql( + "MERGE INTO %s AS 
t USING source AS s " + + "ON t.id == s.id " + + "WHEN NOT MATCHED THEN " + + " INSERT *", + tableName); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // new + row(2, "emp-id-2"), // new + row(3, "emp-id-3") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeIntoEmptyTargetInsertOnlyMatchingRows() { + Assume.assumeFalse("Custom branch does not exist for empty table", "test".equals(branch)); + createAndInitTable("id INT, dep STRING"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 3, \"dep\": \"emp-id-3\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN NOT MATCHED AND (s.id >=2) THEN " + + " INSERT *", + tableName); + + ImmutableList expectedRows = + ImmutableList.of( + row(2, "emp-id-2"), // new + row(3, "emp-id-3") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeWithOnlyUpdateClause() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-six\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // updated + row(6, "emp-id-six") // kept + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeWithOnlyUpdateClauseAndNullValues() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": null, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-six\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id AND t.id < 3 " + + "WHEN MATCHED THEN " + + " UPDATE SET *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(null, "emp-id-one"), // kept + row(1, "emp-id-1"), // updated + row(6, "emp-id-six")); // kept + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeWithOnlyDeleteClause() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-one") // kept + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", 
selectTarget())); + } + + @Test + public void testMergeWithAllCauses() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeWithAllCausesWithExplicitColumnSpecification() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET t.id = s.id, t.dep = s.dep " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT (t.id, t.dep) VALUES (s.id, s.dep)", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeWithSourceCTE() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-two\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-3\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 5, \"dep\": \"emp-id-6\" }"); + + sql( + "WITH cte1 AS (SELECT id + 1 AS id, dep FROM source) " + + "MERGE INTO %s AS t USING cte1 AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 2 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 3 THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(2, "emp-id-2"), // updated + row(3, "emp-id-3") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeWithSourceFromSetOps() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + String derivedSource = + "SELECT * FROM source WHERE id = 2 " + + "UNION ALL " + + "SELECT * FROM source WHERE id = 1 OR id = 6"; + + sql( + "MERGE INTO %s AS t USING (%s) AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND 
s.id = 2 THEN " + + " INSERT *", + commitTarget(), derivedSource); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSource() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + List sourceIds = Lists.newArrayList(); + for (int i = 0; i < 10_000; i++) { + sourceIds.add(i); + } + Dataset ds = spark.createDataset(sourceIds, Encoders.INT()); + ds.union(ds).createOrReplaceTempView("source"); + + String errorMsg = "a single row from the target table with multiple rows of the source table"; + AssertHelpers.assertThrowsCause( + "Should complain about multiple matches", + SparkException.class, + errorMsg, + () -> { + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.value " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET id = 10 " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.value = 2 THEN " + + " INSERT (id, dep) VALUES (s.value, null)", + commitTarget()); + }); + + assertEquals( + "Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @Test + public void + testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSourceEnabledHashShuffleJoin() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + List sourceIds = Lists.newArrayList(); + for (int i = 0; i < 10_000; i++) { + sourceIds.add(i); + } + Dataset ds = spark.createDataset(sourceIds, Encoders.INT()); + ds.union(ds).createOrReplaceTempView("source"); + + withSQLConf( + ImmutableMap.of(SQLConf.PREFER_SORTMERGEJOIN().key(), "false"), + () -> { + String errorMsg = + "a single row from the target table with multiple rows of the source table"; + AssertHelpers.assertThrowsCause( + "Should complain about multiple matches", + SparkException.class, + errorMsg, + () -> { + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.value " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET id = 10 " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.value = 2 THEN " + + " INSERT (id, dep) VALUES (s.value, null)", + commitTarget()); + }); + }); + + assertEquals( + "Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @Test + public void testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSourceNoEqualityCondition() { + createAndInitTable("id INT, dep STRING", "{ \"id\": 1, \"dep\": \"emp-id-one\" }"); + + List sourceIds = Lists.newArrayList(); + for (int i = 0; i < 10_000; i++) { + sourceIds.add(i); + } + Dataset ds = spark.createDataset(sourceIds, Encoders.INT()); + ds.union(ds).createOrReplaceTempView("source"); + + withSQLConf( + ImmutableMap.of(SQLConf.PREFER_SORTMERGEJOIN().key(), "false"), + () -> { + String errorMsg = + "a single row from the target table with multiple rows of the source table"; + AssertHelpers.assertThrowsCause( + "Should complain about multiple matches", + SparkException.class, + errorMsg, + () -> { + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id > 
s.value " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET id = 10 " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.value = 2 THEN " + + " INSERT (id, dep) VALUES (s.value, null)", + commitTarget()); + }); + }); + + assertEquals( + "Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @Test + public void testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSourceNoNotMatchedActions() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + List sourceIds = Lists.newArrayList(); + for (int i = 0; i < 10_000; i++) { + sourceIds.add(i); + } + Dataset ds = spark.createDataset(sourceIds, Encoders.INT()); + ds.union(ds).createOrReplaceTempView("source"); + + String errorMsg = "a single row from the target table with multiple rows of the source table"; + AssertHelpers.assertThrowsCause( + "Should complain about multiple matches", + SparkException.class, + errorMsg, + () -> { + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.value " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET id = 10 " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE", + commitTarget()); + }); + + assertEquals( + "Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @Test + public void + testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSourceNoNotMatchedActionsNoEqualityCondition() { + createAndInitTable("id INT, dep STRING", "{ \"id\": 1, \"dep\": \"emp-id-one\" }"); + + List sourceIds = Lists.newArrayList(); + for (int i = 0; i < 10_000; i++) { + sourceIds.add(i); + } + Dataset ds = spark.createDataset(sourceIds, Encoders.INT()); + ds.union(ds).createOrReplaceTempView("source"); + + String errorMsg = "a single row from the target table with multiple rows of the source table"; + AssertHelpers.assertThrowsCause( + "Should complain about multiple matches", + SparkException.class, + errorMsg, + () -> { + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id > s.value " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET id = 10 " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE", + commitTarget()); + }); + + assertEquals( + "Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @Test + public void testMergeWithMultipleUpdatesForTargetRow() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + String errorMsg = "a single row from the target table with multiple rows of the source table"; + AssertHelpers.assertThrowsCause( + "Should complain about multiple matches", + SparkException.class, + errorMsg, + () -> { + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget()); + }); + + assertEquals( + "Target should be unchanged", + 
ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @Test + public void testMergeWithUnconditionalDelete() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(2, "emp-id-2") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeWithSingleConditionalDelete() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + String errorMsg = "a single row from the target table with multiple rows of the source table"; + AssertHelpers.assertThrowsCause( + "Should complain about multiple matches", + SparkException.class, + errorMsg, + () -> { + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget()); + }); + + assertEquals( + "Target should be unchanged", + ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @Test + public void testMergeWithIdentityTransform() { + for (DistributionMode mode : DistributionMode.values()) { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD identity(dep)", tableName); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, WRITE_DISTRIBUTION_MODE, mode.modeName()); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + createBranchIfNeeded(); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + removeTables(); + } + } + + @Test + public void testMergeWithDaysTransform() { + for (DistributionMode mode : DistributionMode.values()) { + createAndInitTable("id INT, ts TIMESTAMP"); + sql("ALTER TABLE %s ADD PARTITION FIELD days(ts)", tableName); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, 
WRITE_DISTRIBUTION_MODE, mode.modeName()); + + append( + tableName, + "id INT, ts TIMESTAMP", + "{ \"id\": 1, \"ts\": \"2000-01-01 00:00:00\" }\n" + + "{ \"id\": 6, \"ts\": \"2000-01-06 00:00:00\" }"); + createBranchIfNeeded(); + + createOrReplaceView( + "source", + "id INT, ts TIMESTAMP", + "{ \"id\": 2, \"ts\": \"2001-01-02 00:00:00\" }\n" + + "{ \"id\": 1, \"ts\": \"2001-01-01 00:00:00\" }\n" + + "{ \"id\": 6, \"ts\": \"2001-01-06 00:00:00\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "2001-01-01 00:00:00"), // updated + row(2, "2001-01-02 00:00:00") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT id, CAST(ts AS STRING) FROM %s ORDER BY id", selectTarget())); + + removeTables(); + } + } + + @Test + public void testMergeWithBucketTransform() { + for (DistributionMode mode : DistributionMode.values()) { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(2, dep)", tableName); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, WRITE_DISTRIBUTION_MODE, mode.modeName()); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + createBranchIfNeeded(); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + removeTables(); + } + } + + @Test + public void testMergeWithTruncateTransform() { + for (DistributionMode mode : DistributionMode.values()) { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD truncate(dep, 2)", tableName); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, WRITE_DISTRIBUTION_MODE, mode.modeName()); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + createBranchIfNeeded(); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + removeTables(); + } + } + + @Test + public void testMergeIntoPartitionedAndOrderedTable() { + for (DistributionMode mode : 
DistributionMode.values()) { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + sql("ALTER TABLE %s WRITE ORDERED BY (id)", tableName); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, WRITE_DISTRIBUTION_MODE, mode.modeName()); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + createBranchIfNeeded(); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + removeTables(); + } + } + + @Test + public void testSelfMerge() { + createAndInitTable( + "id INT, v STRING", "{ \"id\": 1, \"v\": \"v1\" }\n" + "{ \"id\": 2, \"v\": \"v2\" }"); + + sql( + "MERGE INTO %s t USING %s s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET v = 'x' " + + "WHEN NOT MATCHED THEN " + + " INSERT *", + commitTarget(), commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "x"), // updated + row(2, "v2") // kept + ); + assertEquals( + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testSelfMergeWithCaching() { + createAndInitTable( + "id INT, v STRING", "{ \"id\": 1, \"v\": \"v1\" }\n" + "{ \"id\": 2, \"v\": \"v2\" }"); + + sql("CACHE TABLE %s", tableName); + + sql( + "MERGE INTO %s t USING %s s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET v = 'x' " + + "WHEN NOT MATCHED THEN " + + " INSERT *", + commitTarget(), commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "x"), // updated + row(2, "v2") // kept + ); + assertEquals( + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", commitTarget())); + } + + @Test + public void testMergeWithSourceAsSelfSubquery() { + createAndInitTable( + "id INT, v STRING", "{ \"id\": 1, \"v\": \"v1\" }\n" + "{ \"id\": 2, \"v\": \"v2\" }"); + + createOrReplaceView("source", Arrays.asList(1, null), Encoders.INT()); + + sql( + "MERGE INTO %s t USING (SELECT id AS value FROM %s r JOIN source ON r.id = source.value) s " + + "ON t.id == s.value " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET v = 'x' " + + "WHEN NOT MATCHED THEN " + + " INSERT (v, id) VALUES ('invalid', -1) ", + commitTarget(), commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "x"), // updated + row(2, "v2") // kept + ); + assertEquals( + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public synchronized void testMergeWithSerializableIsolation() throws InterruptedException { + // cannot run tests with concurrency for Hadoop tables without atomic renames + Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop")); + + createAndInitTable("id INT, dep STRING"); + createOrReplaceView("source", Collections.singletonList(1), Encoders.INT()); + + sql( + "ALTER 
TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, MERGE_ISOLATION_LEVEL, "serializable"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + + ExecutorService executorService = + MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + AtomicBoolean shouldAppend = new AtomicBoolean(true); + + // merge thread + Future mergeFuture = + executorService.submit( + () -> { + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.value " + + "WHEN MATCHED THEN " + + " UPDATE SET dep = 'x'", + commitTarget()); + + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = + executorService.submit( + () -> { + // load the table via the validation catalog to use another table instance + Table table = validationCatalog.loadTable(tableIdent); + + GenericRecord record = GenericRecord.create(table.schema()); + record.set(0, 1); // id + record.set(1, "hr"); // dep + + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (shouldAppend.get() && barrier.get() < numOperations * 2) { + sleep(10); + } + + if (!shouldAppend.get()) { + return; + } + + for (int numAppends = 0; numAppends < 5; numAppends++) { + DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + appendFiles.commit(); + sleep(10); + } + + barrier.incrementAndGet(); + } + }); + + try { + Assertions.assertThatThrownBy(mergeFuture::get) + .isInstanceOf(ExecutionException.class) + .cause() + .isInstanceOf(ValidationException.class) + .hasMessageContaining("Found conflicting files that can contain"); + } finally { + shouldAppend.set(false); + appendFuture.cancel(true); + } + + executorService.shutdown(); + Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES)); + } + + @Test + public synchronized void testMergeWithSnapshotIsolation() + throws InterruptedException, ExecutionException { + // cannot run tests with concurrency for Hadoop tables without atomic renames + Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop")); + + createAndInitTable("id INT, dep STRING"); + createOrReplaceView("source", Collections.singletonList(1), Encoders.INT()); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, MERGE_ISOLATION_LEVEL, "snapshot"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + + ExecutorService executorService = + MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + AtomicBoolean shouldAppend = new AtomicBoolean(true); + + // merge thread + Future mergeFuture = + executorService.submit( + () -> { + for (int numOperations = 0; numOperations < 20; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.value " + + "WHEN MATCHED THEN " + + " UPDATE SET dep = 'x'", + commitTarget()); + + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = + executorService.submit( + () -> { + // load the table via the validation catalog to use another table instance for inserts + 
Table table = validationCatalog.loadTable(tableIdent); + + GenericRecord record = GenericRecord.create(table.schema()); + record.set(0, 1); // id + record.set(1, "hr"); // dep + + for (int numOperations = 0; numOperations < 20; numOperations++) { + while (shouldAppend.get() && barrier.get() < numOperations * 2) { + sleep(10); + } + + if (!shouldAppend.get()) { + return; + } + + for (int numAppends = 0; numAppends < 5; numAppends++) { + DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + + appendFiles.commit(); + sleep(10); + } + + barrier.incrementAndGet(); + } + }); + + try { + mergeFuture.get(); + } finally { + shouldAppend.set(false); + appendFuture.cancel(true); + } + + executorService.shutdown(); + Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES)); + } + + @Test + public void testMergeWithExtraColumnsInSource() { + createAndInitTable( + "id INT, v STRING", "{ \"id\": 1, \"v\": \"v1\" }\n" + "{ \"id\": 2, \"v\": \"v2\" }"); + createOrReplaceView( + "source", + "{ \"id\": 1, \"extra_col\": -1, \"v\": \"v1_1\" }\n" + + "{ \"id\": 3, \"extra_col\": -1, \"v\": \"v3\" }\n" + + "{ \"id\": 4, \"extra_col\": -1, \"v\": \"v4\" }"); + + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET v = source.v " + + "WHEN NOT MATCHED THEN " + + " INSERT (v, id) VALUES (source.v, source.id)", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "v1_1"), // new + row(2, "v2"), // kept + row(3, "v3"), // new + row(4, "v4") // new + ); + assertEquals( + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeWithNullsInTargetAndSource() { + createAndInitTable( + "id INT, v STRING", "{ \"id\": null, \"v\": \"v1\" }\n" + "{ \"id\": 2, \"v\": \"v2\" }"); + + createOrReplaceView( + "source", "{ \"id\": null, \"v\": \"v1_1\" }\n" + "{ \"id\": 4, \"v\": \"v4\" }"); + + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET v = source.v " + + "WHEN NOT MATCHED THEN " + + " INSERT (v, id) VALUES (source.v, source.id)", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(null, "v1"), // kept + row(null, "v1_1"), // new + row(2, "v2"), // kept + row(4, "v4") // new + ); + assertEquals( + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", selectTarget())); + } + + @Test + public void testMergeWithNullSafeEquals() { + createAndInitTable( + "id INT, v STRING", "{ \"id\": null, \"v\": \"v1\" }\n" + "{ \"id\": 2, \"v\": \"v2\" }"); + + createOrReplaceView( + "source", "{ \"id\": null, \"v\": \"v1_1\" }\n" + "{ \"id\": 4, \"v\": \"v4\" }"); + + sql( + "MERGE INTO %s t USING source " + + "ON t.id <=> source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET v = source.v " + + "WHEN NOT MATCHED THEN " + + " INSERT (v, id) VALUES (source.v, source.id)", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(null, "v1_1"), // updated + row(2, "v2"), // kept + row(4, "v4") // new + ); + assertEquals( + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", selectTarget())); + } + + @Test + public void testMergeWithNullCondition() { + createAndInitTable( + "id INT, v STRING", "{ \"id\": null, \"v\": \"v1\" }\n" + "{ \"id\": 2, \"v\": \"v2\" }"); + + 
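+    // this MERGE uses an ON clause of the form "t.id == source.id AND NULL", which never
+    // evaluates to true; no rows match, so the existing target rows are kept and each source
+    // row goes through the NOT MATCHED insert branch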
createOrReplaceView( + "source", "{ \"id\": null, \"v\": \"v1_1\" }\n" + "{ \"id\": 2, \"v\": \"v2_2\" }"); + + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id AND NULL " + + "WHEN MATCHED THEN " + + " UPDATE SET v = source.v " + + "WHEN NOT MATCHED THEN " + + " INSERT (v, id) VALUES (source.v, source.id)", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(null, "v1"), // kept + row(null, "v1_1"), // new + row(2, "v2"), // kept + row(2, "v2_2") // new + ); + assertEquals( + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", selectTarget())); + } + + @Test + public void testMergeWithNullActionConditions() { + createAndInitTable( + "id INT, v STRING", "{ \"id\": 1, \"v\": \"v1\" }\n" + "{ \"id\": 2, \"v\": \"v2\" }"); + + createOrReplaceView( + "source", + "{ \"id\": 1, \"v\": \"v1_1\" }\n" + + "{ \"id\": 2, \"v\": \"v2_2\" }\n" + + "{ \"id\": 3, \"v\": \"v3_3\" }"); + + // all conditions are NULL and will never match any rows + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED AND source.id = 1 AND NULL THEN " + + " UPDATE SET v = source.v " + + "WHEN MATCHED AND source.v = 'v1_1' AND NULL THEN " + + " DELETE " + + "WHEN NOT MATCHED AND source.id = 3 AND NULL THEN " + + " INSERT (v, id) VALUES (source.v, source.id)", + commitTarget()); + + ImmutableList expectedRows1 = + ImmutableList.of( + row(1, "v1"), // kept + row(2, "v2") // kept + ); + assertEquals( + "Output should match", expectedRows1, sql("SELECT * FROM %s ORDER BY v", selectTarget())); + + // only the update and insert conditions are NULL + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED AND source.id = 1 AND NULL THEN " + + " UPDATE SET v = source.v " + + "WHEN MATCHED AND source.v = 'v1_1' THEN " + + " DELETE " + + "WHEN NOT MATCHED AND source.id = 3 AND NULL THEN " + + " INSERT (v, id) VALUES (source.v, source.id)", + commitTarget()); + + ImmutableList expectedRows2 = + ImmutableList.of( + row(2, "v2") // kept + ); + assertEquals( + "Output should match", expectedRows2, sql("SELECT * FROM %s ORDER BY v", selectTarget())); + } + + @Test + public void testMergeWithMultipleMatchingActions() { + createAndInitTable( + "id INT, v STRING", "{ \"id\": 1, \"v\": \"v1\" }\n" + "{ \"id\": 2, \"v\": \"v2\" }"); + + createOrReplaceView( + "source", "{ \"id\": 1, \"v\": \"v1_1\" }\n" + "{ \"id\": 2, \"v\": \"v2_2\" }"); + + // the order of match actions is important in this case + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED AND source.id = 1 THEN " + + " UPDATE SET v = source.v " + + "WHEN MATCHED AND source.v = 'v1_1' THEN " + + " DELETE " + + "WHEN NOT MATCHED THEN " + + " INSERT (v, id) VALUES (source.v, source.id)", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "v1_1"), // updated (also matches the delete cond but update is first) + row(2, "v2") // kept (matches neither the update nor the delete cond) + ); + assertEquals( + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", selectTarget())); + } + + @Test + public void testMergeWithMultipleRowGroupsParquet() throws NoSuchTableException { + Assume.assumeTrue(fileFormat.equalsIgnoreCase("parquet")); + + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", + tableName, PARQUET_ROW_GROUP_SIZE_BYTES, 100); + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' 
'%d')", tableName, SPLIT_SIZE, 100); + + createOrReplaceView("source", Collections.singletonList(1), Encoders.INT()); + + List ids = Lists.newArrayListWithCapacity(200); + for (int id = 1; id <= 200; id++) { + ids.add(id); + } + Dataset df = + spark + .createDataset(ids, Encoders.INT()) + .withColumnRenamed("value", "id") + .withColumn("dep", lit("hr")); + df.coalesce(1).writeTo(tableName).append(); + createBranchIfNeeded(); + + Assert.assertEquals(200, spark.table(commitTarget()).count()); + + // update a record from one of two row groups and copy over the second one + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.value " + + "WHEN MATCHED THEN " + + " UPDATE SET dep = 'x'", + commitTarget()); + + Assert.assertEquals(200, spark.table(commitTarget()).count()); + } + + @Test + public void testMergeInsertOnly() { + createAndInitTable( + "id STRING, v STRING", + "{ \"id\": \"a\", \"v\": \"v1\" }\n" + "{ \"id\": \"b\", \"v\": \"v2\" }"); + createOrReplaceView( + "source", + "{ \"id\": \"a\", \"v\": \"v1_1\" }\n" + + "{ \"id\": \"a\", \"v\": \"v1_2\" }\n" + + "{ \"id\": \"c\", \"v\": \"v3\" }\n" + + "{ \"id\": \"d\", \"v\": \"v4_1\" }\n" + + "{ \"id\": \"d\", \"v\": \"v4_2\" }"); + + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN NOT MATCHED THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row("a", "v1"), // kept + row("b", "v2"), // kept + row("c", "v3"), // new + row("d", "v4_1"), // new + row("d", "v4_2") // new + ); + assertEquals( + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeInsertOnlyWithCondition() { + createAndInitTable("id INTEGER, v INTEGER", "{ \"id\": 1, \"v\": 1 }"); + createOrReplaceView( + "source", + "{ \"id\": 1, \"v\": 11, \"is_new\": true }\n" + + "{ \"id\": 2, \"v\": 21, \"is_new\": true }\n" + + "{ \"id\": 2, \"v\": 22, \"is_new\": false }"); + + // validate assignments are reordered to match the table attrs + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.id " + + "WHEN NOT MATCHED AND is_new = TRUE THEN " + + " INSERT (v, id) VALUES (s.v + 100, s.id)", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, 1), // kept + row(2, 121) // new + ); + assertEquals( + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeAlignsUpdateAndInsertActions() { + createAndInitTable("id INT, a INT, b STRING", "{ \"id\": 1, \"a\": 2, \"b\": \"str\" }"); + createOrReplaceView( + "source", + "{ \"id\": 1, \"c1\": -2, \"c2\": \"new_str_1\" }\n" + + "{ \"id\": 2, \"c1\": -20, \"c2\": \"new_str_2\" }"); + + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET b = c2, a = c1, t.id = source.id " + + "WHEN NOT MATCHED THEN " + + " INSERT (b, a, id) VALUES (c2, c1, id)", + commitTarget()); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, -2, "new_str_1"), row(2, -20, "new_str_2")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeMixedCaseAlignsUpdateAndInsertActions() { + createAndInitTable("id INT, a INT, b STRING", "{ \"id\": 1, \"a\": 2, \"b\": \"str\" }"); + createOrReplaceView( + "source", + "{ \"id\": 1, \"c1\": -2, \"c2\": \"new_str_1\" }\n" + + "{ \"id\": 2, \"c1\": -20, \"c2\": \"new_str_2\" }"); + + sql( + "MERGE INTO %s t USING source " + + "ON t.iD == source.Id " + + "WHEN MATCHED 
THEN " + + " UPDATE SET B = c2, A = c1, t.Id = source.ID " + + "WHEN NOT MATCHED THEN " + + " INSERT (b, A, iD) VALUES (c2, c1, id)", + commitTarget()); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, -2, "new_str_1"), row(2, -20, "new_str_2")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, -2, "new_str_1")), + sql("SELECT * FROM %s WHERE id = 1 ORDER BY id", selectTarget())); + assertEquals( + "Output should match", + ImmutableList.of(row(2, -20, "new_str_2")), + sql("SELECT * FROM %s WHERE b = 'new_str_2'ORDER BY id", selectTarget())); + } + + @Test + public void testMergeUpdatesNestedStructFields() { + createAndInitTable( + "id INT, s STRUCT,m:MAP>>", + "{ \"id\": 1, \"s\": { \"c1\": 2, \"c2\": { \"a\": [1,2], \"m\": { \"a\": \"b\"} } } } }"); + createOrReplaceView("source", "{ \"id\": 1, \"c1\": -2 }"); + + // update primitive, array, map columns inside a struct + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s.c1 = source.c1, t.s.c2.a = array(-1, -2), t.s.c2.m = map('k', 'v')", + commitTarget()); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, row(-2, row(ImmutableList.of(-1, -2), ImmutableMap.of("k", "v"))))), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + // set primitive, array, map columns to NULL (proper casts should be in place) + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s.c1 = NULL, t.s.c2 = NULL", + commitTarget()); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, row(null, null))), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + // update all fields in a struct + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s = named_struct('c1', 100, 'c2', named_struct('a', array(1), 'm', map('x', 'y')))", + commitTarget()); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, row(100, row(ImmutableList.of(1), ImmutableMap.of("x", "y"))))), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeWithInferredCasts() { + createAndInitTable("id INT, s STRING", "{ \"id\": 1, \"s\": \"value\" }"); + createOrReplaceView("source", "{ \"id\": 1, \"c1\": -2}"); + + // -2 in source should be casted to "-2" in target + sql( + "MERGE INTO %s t USING source " + + "ON t.id == source.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s = source.c1", + commitTarget()); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, "-2")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeModifiesNullStruct() { + createAndInitTable("id INT, s STRUCT", "{ \"id\": 1, \"s\": null }"); + createOrReplaceView("source", "{ \"id\": 1, \"n1\": -10 }"); + + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s.n1 = s.n1", + commitTarget()); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, row(-10, null))), + sql("SELECT * FROM %s", selectTarget())); + } + + @Test + public void testMergeRefreshesRelationCache() { + createAndInitTable("id INT, name STRING", "{ \"id\": 1, \"name\": \"n1\" }"); + createOrReplaceView("source", "{ \"id\": 1, \"name\": \"n2\" }"); + + Dataset query = spark.sql("SELECT name FROM " + commitTarget()); + query.createOrReplaceTempView("tmp"); + + spark.sql("CACHE 
TABLE tmp"); + + assertEquals( + "View should have correct data", ImmutableList.of(row("n1")), sql("SELECT * FROM tmp")); + + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.name = s.name", + commitTarget()); + + assertEquals( + "View should have correct data", ImmutableList.of(row("n2")), sql("SELECT * FROM tmp")); + + spark.sql("UNCACHE TABLE tmp"); + } + + @Test + public void testMergeWithMultipleNotMatchedActions() { + createAndInitTable("id INT, dep STRING", "{ \"id\": 0, \"dep\": \"emp-id-0\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 3, \"dep\": \"emp-id-3\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN NOT MATCHED AND s.id = 1 THEN " + + " INSERT (dep, id) VALUES (s.dep, -1)" + + "WHEN NOT MATCHED THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(-1, "emp-id-1"), // new + row(0, "emp-id-0"), // kept + row(2, "emp-id-2"), // new + row(3, "emp-id-3") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeWithMultipleConditionalNotMatchedActions() { + createAndInitTable("id INT, dep STRING", "{ \"id\": 0, \"dep\": \"emp-id-0\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 3, \"dep\": \"emp-id-3\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN NOT MATCHED AND s.id = 1 THEN " + + " INSERT (dep, id) VALUES (s.dep, -1)" + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(-1, "emp-id-1"), // new + row(0, "emp-id-0"), // kept + row(2, "emp-id-2") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeResolvesColumnsByName() { + createAndInitTable( + "id INT, badge INT, dep STRING", + "{ \"id\": 1, \"badge\": 1000, \"dep\": \"emp-id-one\" }\n" + + "{ \"id\": 6, \"badge\": 6000, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "badge INT, id INT, dep STRING", + "{ \"badge\": 1001, \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"badge\": 6006, \"id\": 6, \"dep\": \"emp-id-6\" }\n" + + "{ \"badge\": 7007, \"id\": 7, \"dep\": \"emp-id-7\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED THEN " + + " UPDATE SET * " + + "WHEN NOT MATCHED THEN " + + " INSERT * ", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, 1001, "emp-id-1"), // updated + row(6, 6006, "emp-id-6"), // updated + row(7, 7007, "emp-id-7") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT id, badge, dep FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeShouldResolveWhenThereAreNoUnresolvedExpressionsOrColumns() { + // ensures that MERGE INTO will resolve into the correct action even if no columns + // or otherwise unresolved expressions exist in the query (testing SPARK-34962) + createAndInitTable("id INT, dep STRING"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ 
\"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 3, \"dep\": \"emp-id-3\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON 1 != 1 " + + "WHEN MATCHED THEN " + + " UPDATE SET * " + + "WHEN NOT MATCHED THEN " + + " INSERT *", + tableName); + createBranchIfNeeded(); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // new + row(2, "emp-id-2"), // new + row(3, "emp-id-3") // new + ); + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeWithTableWithNonNullableColumn() { + createAndInitTable( + "id INT NOT NULL, dep STRING", + "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + createOrReplaceView( + "source", + "id INT NOT NULL, dep STRING", + "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" + + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" + + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + + sql( + "MERGE INTO %s AS t USING source AS s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET * " + + "WHEN MATCHED AND t.id = 6 THEN " + + " DELETE " + + "WHEN NOT MATCHED AND s.id = 2 THEN " + + " INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, "emp-id-1"), // updated + row(2, "emp-id-2")); // new + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testMergeWithNonExistingColumns() { + createAndInitTable( + "id INT, c STRUCT>", + "{ \"id\": 1, \"c\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + + AssertHelpers.assertThrows( + "Should complain about the invalid top-level column", + AnalysisException.class, + "cannot resolve t.invalid_col", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.invalid_col = s.c2", + commitTarget()); + }); + + AssertHelpers.assertThrows( + "Should complain about the invalid nested column", + AnalysisException.class, + "No such struct field `invalid_col`", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n2.invalid_col = s.c2", + commitTarget()); + }); + + AssertHelpers.assertThrows( + "Should complain about the invalid top-level column", + AnalysisException.class, + "cannot resolve invalid_col", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n2.dn1 = s.c2 " + + "WHEN NOT MATCHED THEN " + + " INSERT (id, invalid_col) VALUES (s.c1, null)", + commitTarget()); + }); + } + + @Test + public void testMergeWithInvalidColumnsInInsert() { + createAndInitTable( + "id INT, c STRUCT>", + "{ \"id\": 1, \"c\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + + AssertHelpers.assertThrows( + "Should complain about the nested column", + AnalysisException.class, + "Nested fields are not supported inside INSERT clauses", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n2.dn1 = s.c2 " + + "WHEN NOT MATCHED THEN " + + " INSERT (id, c.n2) VALUES (s.c1, null)", + commitTarget()); + }); + + AssertHelpers.assertThrows( + "Should complain about duplicate columns", + AnalysisException.class, + "Duplicate column names inside INSERT 
clause", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n2.dn1 = s.c2 " + + "WHEN NOT MATCHED THEN " + + " INSERT (id, id) VALUES (s.c1, null)", + commitTarget()); + }); + + AssertHelpers.assertThrows( + "Should complain about missing columns", + AnalysisException.class, + "must provide values for all columns of the target table", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN NOT MATCHED THEN " + + " INSERT (id) VALUES (s.c1)", + commitTarget()); + }); + } + + @Test + public void testMergeWithInvalidUpdates() { + createAndInitTable( + "id INT, a ARRAY>, m MAP", + "{ \"id\": 1, \"a\": [ { \"c1\": 2, \"c2\": 3 } ], \"m\": { \"k\": \"v\"} }"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + + AssertHelpers.assertThrows( + "Should complain about updating an array column", + AnalysisException.class, + "Updating nested fields is only supported for structs", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.a.c1 = s.c2", + commitTarget()); + }); + + AssertHelpers.assertThrows( + "Should complain about updating a map column", + AnalysisException.class, + "Updating nested fields is only supported for structs", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.m.key = 'new_key'", + commitTarget()); + }); + } + + @Test + public void testMergeWithConflictingUpdates() { + createAndInitTable( + "id INT, c STRUCT>", + "{ \"id\": 1, \"c\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + + AssertHelpers.assertThrows( + "Should complain about conflicting updates to a top-level column", + AnalysisException.class, + "Updates are in conflict", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.id = 1, t.c.n1 = 2, t.id = 2", + commitTarget()); + }); + + AssertHelpers.assertThrows( + "Should complain about conflicting updates to a nested column", + AnalysisException.class, + "Updates are in conflict for these columns", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n1 = 1, t.id = 2, t.c.n1 = 2", + commitTarget()); + }); + + AssertHelpers.assertThrows( + "Should complain about conflicting updates to a nested column", + AnalysisException.class, + "Updates are in conflict", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET c.n1 = 1, c = named_struct('n1', 1, 'n2', named_struct('dn1', 1, 'dn2', 2))", + commitTarget()); + }); + } + + @Test + public void testMergeWithInvalidAssignments() { + createAndInitTable( + "id INT NOT NULL, s STRUCT> NOT NULL", + "{ \"id\": 1, \"s\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); + createOrReplaceView( + "source", + "c1 INT, c2 STRUCT NOT NULL, c3 STRING NOT NULL, c4 STRUCT", + "{ \"c1\": -100, \"c2\": { \"n1\" : 1 }, \"c3\" : 'str', \"c4\": { \"dn2\": 1, \"dn2\": 2 } }"); + + for (String policy : new String[] {"ansi", "strict"}) { + withSQLConf( + ImmutableMap.of("spark.sql.storeAssignmentPolicy", policy), + () -> { + AssertHelpers.assertThrows( + "Should complain about writing nulls to a top-level column", + AnalysisException.class, + "Cannot write nullable values to non-null column", + () -> { + sql( + "MERGE INTO 
%s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.id = NULL", + commitTarget()); + }); + + AssertHelpers.assertThrows( + "Should complain about writing nulls to a nested column", + AnalysisException.class, + "Cannot write nullable values to non-null column", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s.n1 = NULL", + commitTarget()); + }); + + AssertHelpers.assertThrows( + "Should complain about writing missing fields in structs", + AnalysisException.class, + "missing fields", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s = s.c2", + commitTarget()); + }); + + AssertHelpers.assertThrows( + "Should complain about writing invalid data types", + AnalysisException.class, + "Cannot safely cast", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s.n1 = s.c3", + commitTarget()); + }); + + AssertHelpers.assertThrows( + "Should complain about writing incompatible structs", + AnalysisException.class, + "field name does not match", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET t.s.n2 = s.c4", + commitTarget()); + }); + }); + } + } + + @Test + public void testMergeWithNonDeterministicConditions() { + createAndInitTable( + "id INT, c STRUCT>", + "{ \"id\": 1, \"c\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + + AssertHelpers.assertThrows( + "Should complain about non-deterministic search conditions", + AnalysisException.class, + "Non-deterministic functions are not supported", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 AND rand() > t.id " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n1 = -1", + commitTarget()); + }); + + AssertHelpers.assertThrows( + "Should complain about non-deterministic update conditions", + AnalysisException.class, + "Non-deterministic functions are not supported", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED AND rand() > t.id THEN " + + " UPDATE SET t.c.n1 = -1", + commitTarget()); + }); + + AssertHelpers.assertThrows( + "Should complain about non-deterministic delete conditions", + AnalysisException.class, + "Non-deterministic functions are not supported", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED AND rand() > t.id THEN " + + " DELETE", + commitTarget()); + }); + + AssertHelpers.assertThrows( + "Should complain about non-deterministic insert conditions", + AnalysisException.class, + "Non-deterministic functions are not supported", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN NOT MATCHED AND rand() > c1 THEN " + + " INSERT (id, c) VALUES (1, null)", + commitTarget()); + }); + } + + @Test + public void testMergeWithAggregateExpressions() { + createAndInitTable( + "id INT, c STRUCT>", + "{ \"id\": 1, \"c\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + + AssertHelpers.assertThrows( + "Should complain about agg expressions in search conditions", + AnalysisException.class, + "Agg functions are not supported", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 AND max(t.id) == 1 " + + "WHEN MATCHED THEN " + + " 
UPDATE SET t.c.n1 = -1", + commitTarget()); + }); + + AssertHelpers.assertThrows( + "Should complain about agg expressions in update conditions", + AnalysisException.class, + "Agg functions are not supported", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED AND sum(t.id) < 1 THEN " + + " UPDATE SET t.c.n1 = -1", + commitTarget()); + }); + + AssertHelpers.assertThrows( + "Should complain about agg expressions in delete conditions", + AnalysisException.class, + "Agg functions are not supported", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED AND sum(t.id) THEN " + + " DELETE", + commitTarget()); + }); + + AssertHelpers.assertThrows( + "Should complain about agg expressions in insert conditions", + AnalysisException.class, + "Agg functions are not supported", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN NOT MATCHED AND sum(c1) < 1 THEN " + + " INSERT (id, c) VALUES (1, null)", + commitTarget()); + }); + } + + @Test + public void testMergeWithSubqueriesInConditions() { + createAndInitTable( + "id INT, c STRUCT>", + "{ \"id\": 1, \"c\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + + AssertHelpers.assertThrows( + "Should complain about subquery expressions", + AnalysisException.class, + "Subqueries are not supported in conditions", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 AND t.id < (SELECT max(c2) FROM source) " + + "WHEN MATCHED THEN " + + " UPDATE SET t.c.n1 = s.c2", + commitTarget()); + }); + + AssertHelpers.assertThrows( + "Should complain about subquery expressions", + AnalysisException.class, + "Subqueries are not supported in conditions", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED AND t.id < (SELECT max(c2) FROM source) THEN " + + " UPDATE SET t.c.n1 = s.c2", + commitTarget()); + }); + + AssertHelpers.assertThrows( + "Should complain about subquery expressions", + AnalysisException.class, + "Subqueries are not supported in conditions", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN MATCHED AND t.id NOT IN (SELECT c2 FROM source) THEN " + + " DELETE", + commitTarget()); + }); + + AssertHelpers.assertThrows( + "Should complain about subquery expressions", + AnalysisException.class, + "Subqueries are not supported in conditions", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.c1 " + + "WHEN NOT MATCHED AND s.c1 IN (SELECT c2 FROM source) THEN " + + " INSERT (id, c) VALUES (1, null)", + commitTarget()); + }); + } + + @Test + public void testMergeWithTargetColumnsInInsertConditions() { + createAndInitTable("id INT, c2 INT", "{ \"id\": 1, \"c2\": 2 }"); + createOrReplaceView("source", "{ \"id\": 1, \"value\": 11 }"); + + AssertHelpers.assertThrows( + "Should complain about the target column", + AnalysisException.class, + "Cannot resolve [c2]", + () -> { + sql( + "MERGE INTO %s t USING source s " + + "ON t.id == s.id " + + "WHEN NOT MATCHED AND c2 = 1 THEN " + + " INSERT (id, c2) VALUES (s.id, null)", + commitTarget()); + }); + } + + @Test + public void testMergeWithNonIcebergTargetTableNotSupported() { + createOrReplaceView("target", "{ \"c1\": -100, \"c2\": -200 }"); + createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); + + AssertHelpers.assertThrows( + "Should complain non iceberg target table", + 
UnsupportedOperationException.class, + "MERGE INTO TABLE is not supported temporarily.", + () -> { + sql( + "MERGE INTO target t USING source s " + + "ON t.c1 == s.c1 " + + "WHEN MATCHED THEN " + + " UPDATE SET *"); + }); + } + + /** + * Tests a merge where both the source and target are evaluated to be partitioned by + * SingePartition at planning time but DynamicFileFilterExec will return an empty target. + */ + @Test + public void testMergeSinglePartitionPartitioning() { + // This table will only have a single file and a single partition + createAndInitTable("id INT", "{\"id\": -1}"); + + // Coalesce forces our source into a SinglePartition distribution + spark.range(0, 5).coalesce(1).createOrReplaceTempView("source"); + + sql( + "MERGE INTO %s t USING source s ON t.id = s.id " + + "WHEN MATCHED THEN UPDATE SET *" + + "WHEN NOT MATCHED THEN INSERT *", + commitTarget()); + + ImmutableList expectedRows = + ImmutableList.of(row(-1), row(0), row(1), row(2), row(3), row(4)); + + List result = sql("SELECT * FROM %s ORDER BY id", selectTarget()); + assertEquals("Should correctly add the non-matching rows", expectedRows, result); + } + + @Test + public void testMergeEmptyTable() { + Assume.assumeFalse("Custom branch does not exist for empty table", "test".equals(branch)); + // This table will only have a single file and a single partition + createAndInitTable("id INT", null); + + // Coalesce forces our source into a SinglePartition distribution + spark.range(0, 5).coalesce(1).createOrReplaceTempView("source"); + + sql( + "MERGE INTO %s t USING source s ON t.id = s.id " + + "WHEN MATCHED THEN UPDATE SET *" + + "WHEN NOT MATCHED THEN INSERT *", + commitTarget()); + + ImmutableList expectedRows = ImmutableList.of(row(0), row(1), row(2), row(3), row(4)); + + List result = sql("SELECT * FROM %s ORDER BY id", selectTarget()); + assertEquals("Should correctly add the non-matching rows", expectedRows, result); + } + + @Test + public void testMergeNonExistingBranch() { + Assume.assumeTrue("Test only applicable to custom branch", "test".equals(branch)); + createAndInitTable("id INT", null); + + // Coalesce forces our source into a SinglePartition distribution + spark.range(0, 5).coalesce(1).createOrReplaceTempView("source"); + Assertions.assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s ON t.id = s.id " + + "WHEN MATCHED THEN UPDATE SET *" + + "WHEN NOT MATCHED THEN INSERT *", + commitTarget())) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot use branch (does not exist): test"); + } + + @Test + public void testMergeToWapBranch() { + Assume.assumeTrue("WAP branch only works for table identifier without branch", branch == null); + + createAndInitTable("id INT", "{\"id\": -1}"); + ImmutableList originalRows = ImmutableList.of(row(-1)); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' = 'true')", + tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED); + spark.range(0, 5).coalesce(1).createOrReplaceTempView("source"); + ImmutableList expectedRows = + ImmutableList.of(row(-1), row(0), row(1), row(2), row(3), row(4)); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> { + sql( + "MERGE INTO %s t USING source s ON t.id = s.id " + + "WHEN MATCHED THEN UPDATE SET *" + + "WHEN NOT MATCHED THEN INSERT *", + tableName); + assertEquals( + "Should have expected rows when reading table", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", tableName)); + assertEquals( + "Should have expected rows when reading WAP branch", + expectedRows, + 
sql("SELECT * FROM %s.branch_wap ORDER BY id", tableName)); + assertEquals( + "Should not modify main branch", + originalRows, + sql("SELECT * FROM %s.branch_main ORDER BY id", tableName)); + }); + + spark.range(3, 6).coalesce(1).createOrReplaceTempView("source2"); + ImmutableList expectedRows2 = + ImmutableList.of(row(-1), row(0), row(1), row(2), row(5)); + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> { + sql( + "MERGE INTO %s t USING source2 s ON t.id = s.id " + + "WHEN MATCHED THEN DELETE " + + "WHEN NOT MATCHED THEN INSERT *", + tableName); + assertEquals( + "Should have expected rows when reading table with multiple writes", + expectedRows2, + sql("SELECT * FROM %s ORDER BY id", tableName)); + assertEquals( + "Should have expected rows when reading WAP branch with multiple writes", + expectedRows2, + sql("SELECT * FROM %s.branch_wap ORDER BY id", tableName)); + assertEquals( + "Should not modify main branch with multiple writes", + originalRows, + sql("SELECT * FROM %s.branch_main ORDER BY id", tableName)); + }); + } + + @Test + public void testMergeToWapBranchWithTableBranchIdentifier() { + Assume.assumeTrue("Test must have branch name part in table identifier", branch != null); + + createAndInitTable("id INT", "{\"id\": -1}"); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' = 'true')", + tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED); + spark.range(0, 5).coalesce(1).createOrReplaceTempView("source"); + ImmutableList expectedRows = + ImmutableList.of(row(-1), row(0), row(1), row(2), row(3), row(4)); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> + Assertions.assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s ON t.id = s.id " + + "WHEN MATCHED THEN UPDATE SET *" + + "WHEN NOT MATCHED THEN INSERT *", + commitTarget())) + .isInstanceOf(ValidationException.class) + .hasMessage( + String.format( + "Cannot write to both branch and WAP branch, but got branch [%s] and WAP branch [wap]", + branch))); + } + + private void checkJoinAndFilterConditions(String query, String join, String icebergFilters) { + // disable runtime filtering for easier validation + withSQLConf( + ImmutableMap.of(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED().key(), "false"), + () -> { + SparkPlan sparkPlan = executeAndKeepPlan(() -> sql(query)); + String planAsString = sparkPlan.toString().replaceAll("#(\\d+L?)", ""); + + Assertions.assertThat(planAsString).as("Join should match").contains(join + "\n"); + + Assertions.assertThat(planAsString) + .as("Pushed filters must match") + .contains("[filters=" + icebergFilters + ","); + }); + } + + private RowLevelOperationMode mode(Table table) { + String modeName = table.properties().getOrDefault(MERGE_MODE, MERGE_MODE_DEFAULT); + return RowLevelOperationMode.fromName(modeName); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java new file mode 100644 index 000000000000..307fe2a8c2d5 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.RowDelta; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.iceberg.spark.source.TestSparkCatalog; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runners.Parameterized; + +public class TestMergeOnReadDelete extends TestDelete { + + public TestMergeOnReadDelete( + String catalogName, + String implementation, + Map config, + String fileFormat, + Boolean vectorized, + String distributionMode, + String branch) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode, branch); + } + + @Override + protected Map extraTableProperties() { + return ImmutableMap.of( + TableProperties.FORMAT_VERSION, + "2", + TableProperties.DELETE_MODE, + RowLevelOperationMode.MERGE_ON_READ.modeName()); + } + + @Parameterized.AfterParam + public static void clearTestSparkCatalogCache() { + TestSparkCatalog.clearTables(); + } + + @Test + public void testCommitUnknownException() { + createAndInitTable("id INT, dep STRING, category STRING"); + + // write unpartitioned files + append(tableName, "{ \"id\": 1, \"dep\": \"hr\", \"category\": \"c1\"}"); + createBranchIfNeeded(); + append( + commitTarget(), + "{ \"id\": 2, \"dep\": \"hr\", \"category\": \"c1\" }\n" + + "{ \"id\": 3, \"dep\": \"hr\", \"category\": \"c1\" }"); + + Table table = validationCatalog.loadTable(tableIdent); + + RowDelta newRowDelta = table.newRowDelta(); + if (branch != null) { + newRowDelta.toBranch(branch); + } + + RowDelta spyNewRowDelta = spy(newRowDelta); + doAnswer( + invocation -> { + newRowDelta.commit(); + throw new CommitStateUnknownException(new RuntimeException("Datacenter on Fire")); + }) + .when(spyNewRowDelta) + .commit(); + + Table spyTable = spy(table); + when(spyTable.newRowDelta()).thenReturn(spyNewRowDelta); + SparkTable sparkTable = + branch == null ? 
new SparkTable(spyTable, false) : new SparkTable(spyTable, branch, false); + + ImmutableMap config = + ImmutableMap.of( + "type", "hive", + "default-namespace", "default"); + spark + .conf() + .set("spark.sql.catalog.dummy_catalog", "org.apache.iceberg.spark.source.TestSparkCatalog"); + config.forEach( + (key, value) -> spark.conf().set("spark.sql.catalog.dummy_catalog." + key, value)); + Identifier ident = Identifier.of(new String[] {"default"}, "table"); + TestSparkCatalog.setTable(ident, sparkTable); + + // Although an exception is thrown here, write and commit have succeeded + AssertHelpers.assertThrows( + "Should throw a Commit State Unknown Exception", + CommitStateUnknownException.class, + "Datacenter on Fire", + () -> sql("DELETE FROM %s WHERE id = 2", "dummy_catalog.default.table")); + + // Since write and commit succeeded, the rows should be readable + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr", "c1"), row(3, "hr", "c1")), + sql("SELECT * FROM %s ORDER BY id", "dummy_catalog.default.table")); + } + + @Test + public void testAggregatePushDownInMergeOnReadDelete() { + createAndInitTable("id LONG, data INT"); + sql( + "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ", + tableName); + createBranchIfNeeded(); + + sql("DELETE FROM %s WHERE data = 1111", commitTarget()); + String select = "SELECT max(data), min(data), count(data) FROM %s"; + + List explain = sql("EXPLAIN " + select, selectTarget()); + String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("max(data)") + || explainString.contains("min(data)") + || explainString.contains("count(data)")) { + explainContainsPushDownAggregates = true; + } + + Assert.assertFalse( + "min/max/count not pushed down for deleted", explainContainsPushDownAggregates); + + List actual = sql(select, selectTarget()); + List expected = Lists.newArrayList(); + expected.add(new Object[] {6666, 2222, 5L}); + assertEquals("min/max/count push down", expected, actual); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadMerge.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadMerge.java new file mode 100644 index 000000000000..86629a127687 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadMerge.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+package org.apache.iceberg.spark.extensions;
+
+import java.util.Map;
+import org.apache.iceberg.RowLevelOperationMode;
+import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+
+public class TestMergeOnReadMerge extends TestMerge {
+
+  public TestMergeOnReadMerge(
+      String catalogName,
+      String implementation,
+      Map<String, String> config,
+      String fileFormat,
+      boolean vectorized,
+      String distributionMode,
+      String branch) {
+    super(catalogName, implementation, config, fileFormat, vectorized, distributionMode, branch);
+  }
+
+  @Override
+  protected Map<String, String> extraTableProperties() {
+    return ImmutableMap.of(
+        TableProperties.FORMAT_VERSION,
+        "2",
+        TableProperties.MERGE_MODE,
+        RowLevelOperationMode.MERGE_ON_READ.modeName());
+  }
+}
diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadUpdate.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadUpdate.java
new file mode 100644
index 000000000000..416ee8773af6
--- /dev/null
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadUpdate.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.extensions;
+
+import java.util.Map;
+import org.apache.iceberg.RowLevelOperationMode;
+import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+
+public class TestMergeOnReadUpdate extends TestUpdate {
+
+  public TestMergeOnReadUpdate(
+      String catalogName,
+      String implementation,
+      Map<String, String> config,
+      String fileFormat,
+      boolean vectorized,
+      String distributionMode,
+      String branch) {
+    super(catalogName, implementation, config, fileFormat, vectorized, distributionMode, branch);
+  }
+
+  @Override
+  protected Map<String, String> extraTableProperties() {
+    return ImmutableMap.of(
+        TableProperties.FORMAT_VERSION,
+        "2",
+        TableProperties.UPDATE_MODE,
+        RowLevelOperationMode.MERGE_ON_READ.modeName());
+  }
+}
diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMetadataTables.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMetadataTables.java
new file mode 100644
index 000000000000..21439163848d
--- /dev/null
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMetadataTables.java
@@ -0,0 +1,724 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.types.Types.NestedField.optional; + +import java.io.IOException; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.avro.generic.GenericData.Record; +import org.apache.commons.collections.ListUtils; +import org.apache.iceberg.FileContent; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.types.StructType; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestMetadataTables extends SparkExtensionsTestBase { + + public TestMetadataTables(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testUnpartitionedTable() throws Exception { + sql( + "CREATE TABLE %s (id bigint, data string) USING iceberg TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", + tableName); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "d")); + spark + .createDataset(records, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + sql("DELETE FROM %s WHERE id=1", tableName); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + List expectedDataManifests = TestHelpers.dataManifests(table); + List expectedDeleteManifests = TestHelpers.deleteManifests(table); + Assert.assertEquals("Should have 1 data manifest", 1, expectedDataManifests.size()); + Assert.assertEquals("Should have 1 delete manifest", 1, expectedDeleteManifests.size()); + + Schema entriesTableSchema = Spark3Util.loadIcebergTable(spark, tableName + ".entries").schema(); + Schema filesTableSchema = Spark3Util.loadIcebergTable(spark, tableName + ".files").schema(); + + // check 
delete files table + Dataset actualDeleteFilesDs = spark.sql("SELECT * FROM " + tableName + ".delete_files"); + List actualDeleteFiles = TestHelpers.selectNonDerived(actualDeleteFilesDs).collectAsList(); + Assert.assertEquals( + "Metadata table should return one delete file", 1, actualDeleteFiles.size()); + + List expectedDeleteFiles = + expectedEntries( + table, FileContent.POSITION_DELETES, entriesTableSchema, expectedDeleteManifests, null); + Assert.assertEquals("Should be one delete file manifest entry", 1, expectedDeleteFiles.size()); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualDeleteFilesDs), + expectedDeleteFiles.get(0), + actualDeleteFiles.get(0)); + + // check data files table + Dataset actualDataFilesDs = spark.sql("SELECT * FROM " + tableName + ".data_files"); + List actualDataFiles = TestHelpers.selectNonDerived(actualDataFilesDs).collectAsList(); + Assert.assertEquals("Metadata table should return one data file", 1, actualDataFiles.size()); + + List expectedDataFiles = + expectedEntries(table, FileContent.DATA, entriesTableSchema, expectedDataManifests, null); + Assert.assertEquals("Should be one data file manifest entry", 1, expectedDataFiles.size()); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualDataFilesDs), + expectedDataFiles.get(0), + actualDataFiles.get(0)); + + // check all files table + Dataset actualFilesDs = + spark.sql("SELECT * FROM " + tableName + ".files ORDER BY content"); + List actualFiles = TestHelpers.selectNonDerived(actualFilesDs).collectAsList(); + + Assert.assertEquals("Metadata table should return two files", 2, actualFiles.size()); + + List expectedFiles = + Stream.concat(expectedDataFiles.stream(), expectedDeleteFiles.stream()) + .collect(Collectors.toList()); + Assert.assertEquals("Should have two files manifest entries", 2, expectedFiles.size()); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualFilesDs), expectedFiles.get(0), actualFiles.get(0)); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualFilesDs), expectedFiles.get(1), actualFiles.get(1)); + } + + @Test + public void testPartitionedTable() throws Exception { + sql( + "CREATE TABLE %s (id bigint, data string) " + + "USING iceberg " + + "PARTITIONED BY (data) " + + "TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", + tableName); + + List recordsA = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "a")); + spark + .createDataset(recordsA, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + List recordsB = + Lists.newArrayList(new SimpleRecord(1, "b"), new SimpleRecord(2, "b")); + spark + .createDataset(recordsB, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + sql("DELETE FROM %s WHERE id=1 AND data='a'", tableName); + sql("DELETE FROM %s WHERE id=1 AND data='b'", tableName); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + Schema entriesTableSchema = Spark3Util.loadIcebergTable(spark, tableName + ".entries").schema(); + + List expectedDataManifests = TestHelpers.dataManifests(table); + List expectedDeleteManifests = TestHelpers.deleteManifests(table); + Assert.assertEquals("Should have 2 data manifests", 2, expectedDataManifests.size()); + Assert.assertEquals("Should have 2 delete manifests", 2, expectedDeleteManifests.size()); + + Schema filesTableSchema = + Spark3Util.loadIcebergTable(spark, tableName + ".delete_files").schema(); + + // Check delete files table + 
List expectedDeleteFiles = + expectedEntries( + table, FileContent.POSITION_DELETES, entriesTableSchema, expectedDeleteManifests, "a"); + Assert.assertEquals( + "Should have one delete file manifest entry", 1, expectedDeleteFiles.size()); + + Dataset actualDeleteFilesDs = + spark.sql("SELECT * FROM " + tableName + ".delete_files " + "WHERE partition.data='a'"); + List actualDeleteFiles = actualDeleteFilesDs.collectAsList(); + + Assert.assertEquals( + "Metadata table should return one delete file", 1, actualDeleteFiles.size()); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualDeleteFilesDs), + expectedDeleteFiles.get(0), + actualDeleteFiles.get(0)); + + // Check data files table + List expectedDataFiles = + expectedEntries(table, FileContent.DATA, entriesTableSchema, expectedDataManifests, "a"); + Assert.assertEquals("Should have one data file manifest entry", 1, expectedDataFiles.size()); + + Dataset actualDataFilesDs = + spark.sql("SELECT * FROM " + tableName + ".data_files " + "WHERE partition.data='a'"); + + List actualDataFiles = TestHelpers.selectNonDerived(actualDataFilesDs).collectAsList(); + Assert.assertEquals("Metadata table should return one data file", 1, actualDataFiles.size()); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualDataFilesDs), + expectedDataFiles.get(0), + actualDataFiles.get(0)); + + List actualPartitionsWithProjection = + spark.sql("SELECT file_count FROM " + tableName + ".partitions ").collectAsList(); + Assert.assertEquals( + "Metadata table should return two partitions record", + 2, + actualPartitionsWithProjection.size()); + for (int i = 0; i < 2; ++i) { + Assert.assertEquals(1, actualPartitionsWithProjection.get(i).get(0)); + } + + // Check files table + List expectedFiles = + Stream.concat(expectedDataFiles.stream(), expectedDeleteFiles.stream()) + .collect(Collectors.toList()); + Assert.assertEquals("Should have two file manifest entries", 2, expectedFiles.size()); + + Dataset actualFilesDs = + spark.sql( + "SELECT * FROM " + tableName + ".files " + "WHERE partition.data='a' ORDER BY content"); + List actualFiles = TestHelpers.selectNonDerived(actualFilesDs).collectAsList(); + Assert.assertEquals("Metadata table should return two files", 2, actualFiles.size()); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualFilesDs), expectedFiles.get(0), actualFiles.get(0)); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualFilesDs), expectedFiles.get(1), actualFiles.get(1)); + } + + @Test + public void testAllFilesUnpartitioned() throws Exception { + sql( + "CREATE TABLE %s (id bigint, data string) USING iceberg TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", + tableName); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "d")); + spark + .createDataset(records, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + // Create delete file + sql("DELETE FROM %s WHERE id=1", tableName); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + List expectedDataManifests = TestHelpers.dataManifests(table); + Assert.assertEquals("Should have 1 data manifest", 1, expectedDataManifests.size()); + List expectedDeleteManifests = TestHelpers.deleteManifests(table); + Assert.assertEquals("Should have 1 delete manifest", 1, expectedDeleteManifests.size()); + + // Clear table to test whether 'all_files' can read past files + List 
results = sql("DELETE FROM %s", tableName); + Assert.assertEquals("Table should be cleared", 0, results.size()); + + Schema entriesTableSchema = Spark3Util.loadIcebergTable(spark, tableName + ".entries").schema(); + Schema filesTableSchema = + Spark3Util.loadIcebergTable(spark, tableName + ".all_data_files").schema(); + + // Check all data files table + Dataset actualDataFilesDs = spark.sql("SELECT * FROM " + tableName + ".all_data_files"); + List actualDataFiles = TestHelpers.selectNonDerived(actualDataFilesDs).collectAsList(); + + List expectedDataFiles = + expectedEntries(table, FileContent.DATA, entriesTableSchema, expectedDataManifests, null); + Assert.assertEquals("Should be one data file manifest entry", 1, expectedDataFiles.size()); + Assert.assertEquals("Metadata table should return one data file", 1, actualDataFiles.size()); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualDataFilesDs), + expectedDataFiles.get(0), + actualDataFiles.get(0)); + + // Check all delete files table + Dataset actualDeleteFilesDs = + spark.sql("SELECT * FROM " + tableName + ".all_delete_files"); + List actualDeleteFiles = TestHelpers.selectNonDerived(actualDeleteFilesDs).collectAsList(); + List expectedDeleteFiles = + expectedEntries( + table, FileContent.POSITION_DELETES, entriesTableSchema, expectedDeleteManifests, null); + Assert.assertEquals("Should be one delete file manifest entry", 1, expectedDeleteFiles.size()); + Assert.assertEquals( + "Metadata table should return one delete file", 1, actualDeleteFiles.size()); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualDeleteFilesDs), + expectedDeleteFiles.get(0), + actualDeleteFiles.get(0)); + + // Check all files table + Dataset actualFilesDs = + spark.sql("SELECT * FROM " + tableName + ".all_files ORDER BY content"); + List actualFiles = actualFilesDs.collectAsList(); + List expectedFiles = ListUtils.union(expectedDataFiles, expectedDeleteFiles); + expectedFiles.sort(Comparator.comparing(r -> ((Integer) r.get("content")))); + Assert.assertEquals("Metadata table should return two files", 2, actualFiles.size()); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualFilesDs), expectedFiles, actualFiles); + } + + @Test + public void testAllFilesPartitioned() throws Exception { + // Create table and insert data + sql( + "CREATE TABLE %s (id bigint, data string) " + + "USING iceberg " + + "PARTITIONED BY (data) " + + "TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", + tableName); + + List recordsA = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "a")); + spark + .createDataset(recordsA, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + List recordsB = + Lists.newArrayList(new SimpleRecord(1, "b"), new SimpleRecord(2, "b")); + spark + .createDataset(recordsB, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + // Create delete file + sql("DELETE FROM %s WHERE id=1", tableName); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + List expectedDataManifests = TestHelpers.dataManifests(table); + Assert.assertEquals("Should have 2 data manifests", 2, expectedDataManifests.size()); + List expectedDeleteManifests = TestHelpers.deleteManifests(table); + Assert.assertEquals("Should have 1 delete manifest", 1, expectedDeleteManifests.size()); + + // Clear table to test whether 'all_files' can read past files + List results = sql("DELETE FROM %s", tableName); + 
Assert.assertEquals("Table should be cleared", 0, results.size()); + + Schema entriesTableSchema = Spark3Util.loadIcebergTable(spark, tableName + ".entries").schema(); + Schema filesTableSchema = + Spark3Util.loadIcebergTable(spark, tableName + ".all_data_files").schema(); + + // Check all data files table + Dataset actualDataFilesDs = + spark.sql("SELECT * FROM " + tableName + ".all_data_files " + "WHERE partition.data='a'"); + List actualDataFiles = TestHelpers.selectNonDerived(actualDataFilesDs).collectAsList(); + List expectedDataFiles = + expectedEntries(table, FileContent.DATA, entriesTableSchema, expectedDataManifests, "a"); + Assert.assertEquals("Should be one data file manifest entry", 1, expectedDataFiles.size()); + Assert.assertEquals("Metadata table should return one data file", 1, actualDataFiles.size()); + TestHelpers.assertEqualsSafe( + SparkSchemaUtil.convert(TestHelpers.selectNonDerived(actualDataFilesDs).schema()) + .asStruct(), + expectedDataFiles.get(0), + actualDataFiles.get(0)); + + // Check all delete files table + Dataset actualDeleteFilesDs = + spark.sql("SELECT * FROM " + tableName + ".all_delete_files " + "WHERE partition.data='a'"); + List actualDeleteFiles = TestHelpers.selectNonDerived(actualDeleteFilesDs).collectAsList(); + + List expectedDeleteFiles = + expectedEntries( + table, FileContent.POSITION_DELETES, entriesTableSchema, expectedDeleteManifests, "a"); + Assert.assertEquals("Should be one data file manifest entry", 1, expectedDeleteFiles.size()); + Assert.assertEquals("Metadata table should return one data file", 1, actualDeleteFiles.size()); + + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualDeleteFilesDs), + expectedDeleteFiles.get(0), + actualDeleteFiles.get(0)); + + // Check all files table + Dataset actualFilesDs = + spark.sql( + "SELECT * FROM " + + tableName + + ".all_files WHERE partition.data='a' " + + "ORDER BY content"); + List actualFiles = TestHelpers.selectNonDerived(actualFilesDs).collectAsList(); + + List expectedFiles = ListUtils.union(expectedDataFiles, expectedDeleteFiles); + expectedFiles.sort(Comparator.comparing(r -> ((Integer) r.get("content")))); + Assert.assertEquals("Metadata table should return two files", 2, actualFiles.size()); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualDataFilesDs), expectedFiles, actualFiles); + } + + @Test + public void testMetadataLogEntries() throws Exception { + // Create table and insert data + sql( + "CREATE TABLE %s (id bigint, data string) " + + "USING iceberg " + + "PARTITIONED BY (data) " + + "TBLPROPERTIES " + + "('format-version'='2')", + tableName); + + List recordsA = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "a")); + spark.createDataset(recordsA, Encoders.bean(SimpleRecord.class)).writeTo(tableName).append(); + + List recordsB = + Lists.newArrayList(new SimpleRecord(1, "b"), new SimpleRecord(2, "b")); + spark.createDataset(recordsB, Encoders.bean(SimpleRecord.class)).writeTo(tableName).append(); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + Long currentSnapshotId = table.currentSnapshot().snapshotId(); + TableMetadata tableMetadata = ((HasTableOperations) table).operations().current(); + Snapshot currentSnapshot = tableMetadata.currentSnapshot(); + Snapshot parentSnapshot = table.snapshot(currentSnapshot.parentId()); + List metadataLogEntries = + Lists.newArrayList(tableMetadata.previousFiles()); + + // Check metadataLog table + List metadataLogs = sql("SELECT * FROM %s.metadata_log_entries", 
tableName); + assertEquals( + "MetadataLogEntriesTable result should match the metadataLog entries", + ImmutableList.of( + row( + DateTimeUtils.toJavaTimestamp(metadataLogEntries.get(0).timestampMillis() * 1000), + metadataLogEntries.get(0).file(), + null, + null, + null), + row( + DateTimeUtils.toJavaTimestamp(metadataLogEntries.get(1).timestampMillis() * 1000), + metadataLogEntries.get(1).file(), + parentSnapshot.snapshotId(), + parentSnapshot.schemaId(), + parentSnapshot.sequenceNumber()), + row( + DateTimeUtils.toJavaTimestamp(currentSnapshot.timestampMillis() * 1000), + tableMetadata.metadataFileLocation(), + currentSnapshot.snapshotId(), + currentSnapshot.schemaId(), + currentSnapshot.sequenceNumber())), + metadataLogs); + + // test filtering + List metadataLogWithFilters = + sql( + "SELECT * FROM %s.metadata_log_entries WHERE latest_snapshot_id = %s", + tableName, currentSnapshotId); + Assert.assertEquals( + "metadataLogEntries table should return 1 row", 1, metadataLogWithFilters.size()); + assertEquals( + "Result should match the latest snapshot entry", + ImmutableList.of( + row( + DateTimeUtils.toJavaTimestamp( + tableMetadata.currentSnapshot().timestampMillis() * 1000), + tableMetadata.metadataFileLocation(), + tableMetadata.currentSnapshot().snapshotId(), + tableMetadata.currentSnapshot().schemaId(), + tableMetadata.currentSnapshot().sequenceNumber())), + metadataLogWithFilters); + + // test projection + List metadataFiles = + metadataLogEntries.stream() + .map(TableMetadata.MetadataLogEntry::file) + .collect(Collectors.toList()); + metadataFiles.add(tableMetadata.metadataFileLocation()); + List metadataLogWithProjection = + sql("SELECT file FROM %s.metadata_log_entries", tableName); + Assert.assertEquals( + "metadataLogEntries table should return 3 rows", 3, metadataLogWithProjection.size()); + assertEquals( + "metadataLog entry should be of same file", + metadataFiles.stream().map(this::row).collect(Collectors.toList()), + metadataLogWithProjection); + } + + @Test + public void testFilesTableTimeTravelWithSchemaEvolution() throws Exception { + // Create table and insert data + sql( + "CREATE TABLE %s (id bigint, data string) " + + "USING iceberg " + + "PARTITIONED BY (data) " + + "TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", + tableName); + + List recordsA = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "a")); + spark + .createDataset(recordsA, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + + table.updateSchema().addColumn("category", Types.StringType.get()).commit(); + + List newRecords = + Lists.newArrayList(RowFactory.create(3, "b", "c"), RowFactory.create(4, "b", "c")); + + StructType newSparkSchema = + SparkSchemaUtil.convert( + new Schema( + optional(1, "id", Types.IntegerType.get()), + optional(2, "data", Types.StringType.get()), + optional(3, "category", Types.StringType.get()))); + + spark.createDataFrame(newRecords, newSparkSchema).coalesce(1).writeTo(tableName).append(); + + Long currentSnapshotId = table.currentSnapshot().snapshotId(); + + Dataset actualFilesDs = + spark.sql( + "SELECT * FROM " + + tableName + + ".files VERSION AS OF " + + currentSnapshotId + + " ORDER BY content"); + List actualFiles = TestHelpers.selectNonDerived(actualFilesDs).collectAsList(); + Schema entriesTableSchema = Spark3Util.loadIcebergTable(spark, tableName + ".entries").schema(); + List expectedDataManifests = 
TestHelpers.dataManifests(table); + List expectedFiles = + expectedEntries(table, FileContent.DATA, entriesTableSchema, expectedDataManifests, null); + + Assert.assertEquals("actualFiles size should be 2", 2, actualFiles.size()); + + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualFilesDs), expectedFiles.get(0), actualFiles.get(0)); + + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(actualFilesDs), expectedFiles.get(1), actualFiles.get(1)); + + Assert.assertEquals( + "expectedFiles and actualFiles size should be the same", + actualFiles.size(), + expectedFiles.size()); + } + + @Test + public void testSnapshotReferencesMetatable() throws Exception { + // Create table and insert data + sql( + "CREATE TABLE %s (id bigint, data string) " + + "USING iceberg " + + "PARTITIONED BY (data) " + + "TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", + tableName); + + List recordsA = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "a")); + spark + .createDataset(recordsA, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + List recordsB = + Lists.newArrayList(new SimpleRecord(1, "b"), new SimpleRecord(2, "b")); + spark + .createDataset(recordsB, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + Long currentSnapshotId = table.currentSnapshot().snapshotId(); + + // Create branch + table + .manageSnapshots() + .createBranch("testBranch", currentSnapshotId) + .setMaxRefAgeMs("testBranch", 10) + .setMinSnapshotsToKeep("testBranch", 20) + .setMaxSnapshotAgeMs("testBranch", 30) + .commit(); + // Create Tag + table + .manageSnapshots() + .createTag("testTag", currentSnapshotId) + .setMaxRefAgeMs("testTag", 50) + .commit(); + // Check refs table + List references = spark.sql("SELECT * FROM " + tableName + ".refs").collectAsList(); + Assert.assertEquals("Refs table should return 3 rows", 3, references.size()); + List branches = + spark.sql("SELECT * FROM " + tableName + ".refs WHERE type='BRANCH'").collectAsList(); + Assert.assertEquals("Refs table should return 2 branches", 2, branches.size()); + List tags = + spark.sql("SELECT * FROM " + tableName + ".refs WHERE type='TAG'").collectAsList(); + Assert.assertEquals("Refs table should return 1 tag", 1, tags.size()); + + // Check branch entries in refs table + List mainBranch = + spark + .sql("SELECT * FROM " + tableName + ".refs WHERE name = 'main' AND type='BRANCH'") + .collectAsList(); + Assert.assertEquals("main", mainBranch.get(0).getAs("name")); + Assert.assertEquals("BRANCH", mainBranch.get(0).getAs("type")); + Assert.assertEquals(currentSnapshotId, mainBranch.get(0).getAs("snapshot_id")); + + List testBranch = + spark + .sql("SELECT * FROM " + tableName + ".refs WHERE name = 'testBranch' AND type='BRANCH'") + .collectAsList(); + Assert.assertEquals("testBranch", testBranch.get(0).getAs("name")); + Assert.assertEquals("BRANCH", testBranch.get(0).getAs("type")); + Assert.assertEquals(currentSnapshotId, testBranch.get(0).getAs("snapshot_id")); + Assert.assertEquals(Long.valueOf(10), testBranch.get(0).getAs("max_reference_age_in_ms")); + Assert.assertEquals(Integer.valueOf(20), testBranch.get(0).getAs("min_snapshots_to_keep")); + Assert.assertEquals(Long.valueOf(30), testBranch.get(0).getAs("max_snapshot_age_in_ms")); + + // Check tag entries in refs table + List testTag = + spark + .sql("SELECT * FROM " + tableName + ".refs WHERE name = 
'testTag' AND type='TAG'") + .collectAsList(); + Assert.assertEquals("testTag", testTag.get(0).getAs("name")); + Assert.assertEquals("TAG", testTag.get(0).getAs("type")); + Assert.assertEquals(currentSnapshotId, testTag.get(0).getAs("snapshot_id")); + Assert.assertEquals(Long.valueOf(50), testTag.get(0).getAs("max_reference_age_in_ms")); + + // Check projection in refs table + List testTagProjection = + spark + .sql( + "SELECT name,type,snapshot_id,max_reference_age_in_ms,min_snapshots_to_keep FROM " + + tableName + + ".refs where type='TAG'") + .collectAsList(); + Assert.assertEquals("testTag", testTagProjection.get(0).getAs("name")); + Assert.assertEquals("TAG", testTagProjection.get(0).getAs("type")); + Assert.assertEquals(currentSnapshotId, testTagProjection.get(0).getAs("snapshot_id")); + Assert.assertEquals( + Long.valueOf(50), testTagProjection.get(0).getAs("max_reference_age_in_ms")); + Assert.assertNull(testTagProjection.get(0).getAs("min_snapshots_to_keep")); + + List mainBranchProjection = + spark + .sql( + "SELECT name, type FROM " + + tableName + + ".refs WHERE name = 'main' AND type = 'BRANCH'") + .collectAsList(); + Assert.assertEquals("main", mainBranchProjection.get(0).getAs("name")); + Assert.assertEquals("BRANCH", mainBranchProjection.get(0).getAs("type")); + + List testBranchProjection = + spark + .sql( + "SELECT type, name, max_reference_age_in_ms, snapshot_id FROM " + + tableName + + ".refs WHERE name = 'testBranch' AND type = 'BRANCH'") + .collectAsList(); + Assert.assertEquals("testBranch", testBranchProjection.get(0).getAs("name")); + Assert.assertEquals("BRANCH", testBranchProjection.get(0).getAs("type")); + Assert.assertEquals(currentSnapshotId, testBranchProjection.get(0).getAs("snapshot_id")); + Assert.assertEquals( + Long.valueOf(10), testBranchProjection.get(0).getAs("max_reference_age_in_ms")); + } + + /** + * Find matching manifest entries of an Iceberg table + * + * @param table iceberg table + * @param expectedContent file content to populate on entries + * @param entriesTableSchema schema of Manifest entries + * @param manifestsToExplore manifests to explore of the table + * @param partValue partition value that manifest entries must match, or null to skip filtering + */ + private List expectedEntries( + Table table, + FileContent expectedContent, + Schema entriesTableSchema, + List manifestsToExplore, + String partValue) + throws IOException { + List expected = Lists.newArrayList(); + for (ManifestFile manifest : manifestsToExplore) { + InputFile in = table.io().newInputFile(manifest.path()); + try (CloseableIterable rows = Avro.read(in).project(entriesTableSchema).build()) { + for (Record record : rows) { + if ((Integer) record.get("status") < 2 /* added or existing */) { + Record file = (Record) record.get("data_file"); + if (partitionMatch(file, partValue)) { + TestHelpers.asMetadataRecord(file, expectedContent); + expected.add(file); + } + } + } + } + } + return expected; + } + + private boolean partitionMatch(Record file, String partValue) { + if (partValue == null) { + return true; + } + Record partition = (Record) file.get(4); + return partValue.equals(partition.get(0).toString()); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMigrateTableProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMigrateTableProcedure.java new file mode 100644 index 000000000000..8b2950b74f8d --- /dev/null +++ 
b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMigrateTableProcedure.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.spark.sql.AnalysisException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestMigrateTableProcedure extends SparkExtensionsTestBase { + + public TestMigrateTableProcedure( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s_BACKUP_", tableName); + } + + @Test + public void testMigrate() throws IOException { + Assume.assumeTrue(catalogName.equals("spark_catalog")); + String location = temp.newFolder().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + tableName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + Object result = scalarSql("CALL %s.system.migrate('%s')", catalogName, tableName); + + Assert.assertEquals("Should have added one file", 1L, result); + + Table createdTable = validationCatalog.loadTable(tableIdent); + + String tableLocation = createdTable.location().replace("file:", ""); + Assert.assertEquals("Table should have original location", location, tableLocation); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("DROP TABLE IF EXISTS %s", tableName + "_BACKUP_"); + } + + @Test + public void testMigrateWithOptions() throws IOException { + Assume.assumeTrue(catalogName.equals("spark_catalog")); + String location = temp.newFolder().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + tableName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Object result = + scalarSql("CALL %s.system.migrate('%s', map('foo', 'bar'))", catalogName, tableName); + + Assert.assertEquals("Should have added one file", 1L, result); + + Table createdTable = validationCatalog.loadTable(tableIdent); + + Map props = createdTable.properties(); + 
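For reference, the call patterns this migrate test suite validates take the following shape in Spark SQL; db.sample is a placeholder source table and spark_catalog stands in for the session catalog required by the procedure:
    CALL spark_catalog.system.migrate('db.sample');
    CALL spark_catalog.system.migrate('db.sample', map('foo', 'bar'));
    CALL spark_catalog.system.migrate(table => 'db.sample', drop_backup => true);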
Assert.assertEquals("Should have extra property set", "bar", props.get("foo")); + + String tableLocation = createdTable.location().replace("file:", ""); + Assert.assertEquals("Table should have original location", location, tableLocation); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("DROP TABLE IF EXISTS %s", tableName + "_BACKUP_"); + } + + @Test + public void testMigrateWithDropBackup() throws IOException { + Assume.assumeTrue(catalogName.equals("spark_catalog")); + String location = temp.newFolder().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + tableName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Object result = + scalarSql( + "CALL %s.system.migrate(table => '%s', drop_backup => true)", catalogName, tableName); + Assert.assertEquals("Should have added one file", 1L, result); + Assert.assertFalse(spark.catalog().tableExists(tableName + "_BACKUP_")); + } + + @Test + public void testMigrateWithInvalidMetricsConfig() throws IOException { + Assume.assumeTrue(catalogName.equals("spark_catalog")); + + String location = temp.newFolder().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + tableName, location); + + AssertHelpers.assertThrows( + "Should reject invalid metrics config", + ValidationException.class, + "Invalid metrics config", + () -> { + String props = "map('write.metadata.metrics.column.x', 'X')"; + sql("CALL %s.system.migrate('%s', %s)", catalogName, tableName, props); + }); + } + + @Test + public void testMigrateWithConflictingProps() throws IOException { + Assume.assumeTrue(catalogName.equals("spark_catalog")); + + String location = temp.newFolder().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + tableName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Object result = + scalarSql("CALL %s.system.migrate('%s', map('migrated', 'false'))", catalogName, tableName); + Assert.assertEquals("Should have added one file", 1L, result); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should override user value", "true", table.properties().get("migrated")); + } + + @Test + public void testInvalidMigrateCases() { + AssertHelpers.assertThrows( + "Should reject calls without all required args", + AnalysisException.class, + "Missing required parameters", + () -> sql("CALL %s.system.migrate()", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with invalid arg types", + AnalysisException.class, + "Wrong arg type", + () -> sql("CALL %s.system.migrate(map('foo','bar'))", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with empty table identifier", + IllegalArgumentException.class, + "Cannot handle an empty identifier", + () -> sql("CALL %s.system.migrate('')", catalogName)); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestPublishChangesProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestPublishChangesProcedure.java new file mode 100644 index 000000000000..2b74cd475fae --- /dev/null +++ 
b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestPublishChangesProcedure.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.TableProperties.WRITE_AUDIT_PUBLISH_ENABLED; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.junit.After; +import org.junit.Test; + +public class TestPublishChangesProcedure extends SparkExtensionsTestBase { + + public TestPublishChangesProcedure( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testApplyWapChangesUsingPositionalArgs() { + String wapId = "wap_id_1"; + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED); + + spark.conf().set("spark.wap.id", wapId); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should not see rows from staged snapshot", + ImmutableList.of(), + sql("SELECT * FROM %s", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot wapSnapshot = Iterables.getOnlyElement(table.snapshots()); + + List output = + sql("CALL %s.system.publish_changes('%s', '%s')", catalogName, tableIdent, wapId); + + table.refresh(); + + Snapshot currentSnapshot = table.currentSnapshot(); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(wapSnapshot.snapshotId(), currentSnapshot.snapshotId())), + output); + + assertEquals( + "Apply of WAP changes must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", tableName)); + } + + @Test + public void testApplyWapChangesUsingNamedArgs() { + String wapId = "wap_id_1"; + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED); + + spark.conf().set("spark.wap.id", wapId); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should not see rows from staged snapshot", + 
ImmutableList.of(), + sql("SELECT * FROM %s", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot wapSnapshot = Iterables.getOnlyElement(table.snapshots()); + + List output = + sql( + "CALL %s.system.publish_changes(wap_id => '%s', table => '%s')", + catalogName, wapId, tableIdent); + + table.refresh(); + + Snapshot currentSnapshot = table.currentSnapshot(); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(wapSnapshot.snapshotId(), currentSnapshot.snapshotId())), + output); + + assertEquals( + "Apply of WAP changes must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", tableName)); + } + + @Test + public void testApplyWapChangesRefreshesRelationCache() { + String wapId = "wap_id_1"; + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED); + + Dataset query = spark.sql("SELECT * FROM " + tableName + " WHERE id = 1"); + query.createOrReplaceTempView("tmp"); + + spark.sql("CACHE TABLE tmp"); + + assertEquals("View should not produce rows", ImmutableList.of(), sql("SELECT * FROM tmp")); + + spark.conf().set("spark.wap.id", wapId); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should not see rows from staged snapshot", + ImmutableList.of(), + sql("SELECT * FROM %s", tableName)); + + sql("CALL %s.system.publish_changes('%s', '%s')", catalogName, tableIdent, wapId); + + assertEquals( + "Apply of WAP changes should be visible", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM tmp")); + + sql("UNCACHE TABLE tmp"); + } + + @Test + public void testApplyInvalidWapId() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + AssertHelpers.assertThrows( + "Should reject invalid wap id", + ValidationException.class, + "Cannot apply unknown WAP ID", + () -> sql("CALL %s.system.publish_changes('%s', 'not_valid')", catalogName, tableIdent)); + } + + @Test + public void testInvalidApplyWapChangesCases() { + AssertHelpers.assertThrows( + "Should not allow mixed args", + AnalysisException.class, + "Named and positional arguments cannot be mixed", + () -> sql("CALL %s.system.publish_changes('n', table => 't', 'not_valid')", catalogName)); + + AssertHelpers.assertThrows( + "Should not resolve procedures in arbitrary namespaces", + NoSuchProcedureException.class, + "not found", + () -> sql("CALL %s.custom.publish_changes('n', 't', 'not_valid')", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls without all required args", + AnalysisException.class, + "Missing required parameters", + () -> sql("CALL %s.system.publish_changes('t')", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with empty table identifier", + IllegalArgumentException.class, + "Cannot handle an empty identifier", + () -> sql("CALL %s.system.publish_changes('', 'not_valid')", catalogName)); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRegisterTableProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRegisterTableProcedure.java new file mode 100644 index 000000000000..2f1165e9cd5e --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRegisterTableProcedure.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license 
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.Table; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.functions; +import org.apache.spark.sql.types.DataTypes; +import org.junit.After; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestRegisterTableProcedure extends SparkExtensionsTestBase { + + private final String targetName; + + public TestRegisterTableProcedure( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + targetName = tableName("register_table"); + } + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + @After + public void dropTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s", targetName); + } + + @Test + public void testRegisterTable() throws NoSuchTableException, ParseException { + long numRows = 1000; + + sql("CREATE TABLE %s (id int, data string) using ICEBERG", tableName); + spark + .range(0, numRows) + .withColumn("data", functions.col("id").cast(DataTypes.StringType)) + .writeTo(tableName) + .append(); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + long originalFileCount = (long) scalarSql("SELECT COUNT(*) from %s.files", tableName); + long currentSnapshotId = table.currentSnapshot().snapshotId(); + String metadataJson = + (((HasTableOperations) table).operations()).current().metadataFileLocation(); + + List result = + sql("CALL %s.system.register_table('%s', '%s')", catalogName, targetName, metadataJson); + Assert.assertEquals("Current Snapshot is not correct", currentSnapshotId, result.get(0)[0]); + + List original = sql("SELECT * FROM %s", tableName); + List registered = sql("SELECT * FROM %s", targetName); + assertEquals("Registered table rows should match original table rows", original, registered); + Assert.assertEquals( + "Should have the right row count in the procedure result", numRows, result.get(0)[1]); + Assert.assertEquals( + "Should have the right datafile count in the procedure result", + originalFileCount, + result.get(0)[2]); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRemoveOrphanFilesProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRemoveOrphanFilesProcedure.java new file mode 100644 index 000000000000..4cf6b10cb293 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRemoveOrphanFilesProcedure.java @@ -0,0 
+1,629 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.TableProperties.GC_ENABLED; +import static org.apache.iceberg.TableProperties.WRITE_AUDIT_PUBLISH_ENABLED; + +import java.io.File; +import java.io.IOException; +import java.net.URI; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.Files; +import org.apache.iceberg.GenericBlobMetadata; +import org.apache.iceberg.GenericStatisticsFile; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.ReachableFileUtil; +import org.apache.iceberg.StatisticsFile; +import org.apache.iceberg.Table; +import org.apache.iceberg.Transaction; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.puffin.Blob; +import org.apache.iceberg.puffin.Puffin; +import org.apache.iceberg.puffin.PuffinWriter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.spark.source.FilePathLastModifiedRecord; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.assertj.core.api.Assertions; +import org.junit.After; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestRemoveOrphanFilesProcedure extends SparkExtensionsTestBase { + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + public TestRemoveOrphanFilesProcedure( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTable() { + // TODO: use the Iceberg catalog to drop the table until SPARK-43203 is fixed + validationCatalog.dropTable(tableIdent, true /* purge 
*/); + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS p PURGE"); + } + + @Test + public void testRemoveOrphanFilesInEmptyTable() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + List output = + sql("CALL %s.system.remove_orphan_files('%s')", catalogName, tableIdent); + assertEquals("Should be no orphan files", ImmutableList.of(), output); + + assertEquals("Should have no rows", ImmutableList.of(), sql("SELECT * FROM %s", tableName)); + } + + @Test + public void testRemoveOrphanFilesInDataFolder() throws IOException { + if (catalogName.equals("testhadoop")) { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + } else { + // give a fresh location to Hive tables as Spark will not clean up the table location + // correctly while dropping tables through spark_catalog + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg LOCATION '%s'", + tableName, temp.newFolder()); + } + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + String metadataLocation = table.location() + "/metadata"; + String dataLocation = table.location() + "/data"; + + // produce orphan files in the data location using parquet + sql("CREATE TABLE p (id bigint) USING parquet LOCATION '%s'", dataLocation); + sql("INSERT INTO TABLE p VALUES (1)"); + + // wait to ensure files are old enough + waitUntilAfter(System.currentTimeMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + // check for orphans in the metadata folder + List output1 = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s'," + + "location => '%s')", + catalogName, tableIdent, currentTimestamp, metadataLocation); + assertEquals("Should be no orphan files in the metadata folder", ImmutableList.of(), output1); + + // check for orphans in the table location + List output2 = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + Assert.assertEquals("Should be orphan files in the data folder", 1, output2.size()); + + // the previous call should have deleted all orphan files + List output3 = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + Assert.assertEquals("Should be no more orphan files in the data folder", 0, output3.size()); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testRemoveOrphanFilesDryRun() throws IOException { + if (catalogName.equals("testhadoop")) { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + } else { + // give a fresh location to Hive tables as Spark will not clean up the table location + // correctly while dropping tables through spark_catalog + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg LOCATION '%s'", + tableName, temp.newFolder()); + } + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + // produce orphan files in the table location using parquet + sql("CREATE TABLE p (id 
bigint) USING parquet LOCATION '%s'", table.location()); + sql("INSERT INTO TABLE p VALUES (1)"); + + // wait to ensure files are old enough + waitUntilAfter(System.currentTimeMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + // check for orphans without deleting + List output1 = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s'," + + "dry_run => true)", + catalogName, tableIdent, currentTimestamp); + Assert.assertEquals("Should be one orphan files", 1, output1.size()); + + // actually delete orphans + List output2 = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + Assert.assertEquals("Should be one orphan files", 1, output2.size()); + + // the previous call should have deleted all orphan files + List output3 = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + Assert.assertEquals("Should be no more orphan files", 0, output3.size()); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testRemoveOrphanFilesGCDisabled() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'false')", tableName, GC_ENABLED); + + AssertHelpers.assertThrows( + "Should reject call", + ValidationException.class, + "Cannot delete orphan files: GC is disabled", + () -> sql("CALL %s.system.remove_orphan_files('%s')", catalogName, tableIdent)); + + // reset the property to enable the table purging in removeTable. 
+ sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, GC_ENABLED); + } + + @Test + public void testRemoveOrphanFilesWap() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED); + + spark.conf().set("spark.wap.id", "1"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should not see rows from staged snapshot", + ImmutableList.of(), + sql("SELECT * FROM %s", tableName)); + + List output = + sql("CALL %s.system.remove_orphan_files('%s')", catalogName, tableIdent); + assertEquals("Should be no orphan files", ImmutableList.of(), output); + } + + @Test + public void testInvalidRemoveOrphanFilesCases() { + AssertHelpers.assertThrows( + "Should not allow mixed args", + AnalysisException.class, + "Named and positional arguments cannot be mixed", + () -> sql("CALL %s.system.remove_orphan_files('n', table => 't')", catalogName)); + + AssertHelpers.assertThrows( + "Should not resolve procedures in arbitrary namespaces", + NoSuchProcedureException.class, + "not found", + () -> sql("CALL %s.custom.remove_orphan_files('n', 't')", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls without all required args", + AnalysisException.class, + "Missing required parameters", + () -> sql("CALL %s.system.remove_orphan_files()", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with invalid arg types", + AnalysisException.class, + "Wrong arg type", + () -> sql("CALL %s.system.remove_orphan_files('n', 2.2)", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with empty table identifier", + IllegalArgumentException.class, + "Cannot handle an empty identifier", + () -> sql("CALL %s.system.remove_orphan_files('')", catalogName)); + } + + @Test + public void testConcurrentRemoveOrphanFiles() throws IOException { + if (catalogName.equals("testhadoop")) { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + } else { + // give a fresh location to Hive tables as Spark will not clean up the table location + // correctly while dropping tables through spark_catalog + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg LOCATION '%s'", + tableName, temp.newFolder()); + } + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + String metadataLocation = table.location() + "/metadata"; + String dataLocation = table.location() + "/data"; + + // produce orphan files in the data location using parquet + sql("CREATE TABLE p (id bigint) USING parquet LOCATION '%s'", dataLocation); + sql("INSERT INTO TABLE p VALUES (1)"); + sql("INSERT INTO TABLE p VALUES (10)"); + sql("INSERT INTO TABLE p VALUES (100)"); + sql("INSERT INTO TABLE p VALUES (1000)"); + + // wait to ensure files are old enough + waitUntilAfter(System.currentTimeMillis()); + + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + // check for orphans in the table location + List output = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "max_concurrent_deletes => %s," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, 4, currentTimestamp); + Assert.assertEquals("Should be orphan files in the data folder", 4, output.size()); + + // the previous call should have deleted all orphan files + List 
output3 = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "max_concurrent_deletes => %s," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, 4, currentTimestamp); + Assert.assertEquals("Should be no more orphan files in the data folder", 0, output3.size()); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testConcurrentRemoveOrphanFilesWithInvalidInput() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + AssertHelpers.assertThrows( + "Should throw an error when max_concurrent_deletes = 0", + IllegalArgumentException.class, + "max_concurrent_deletes should have value > 0", + () -> + sql( + "CALL %s.system.remove_orphan_files(table => '%s', max_concurrent_deletes => %s)", + catalogName, tableIdent, 0)); + + AssertHelpers.assertThrows( + "Should throw an error when max_concurrent_deletes < 0 ", + IllegalArgumentException.class, + "max_concurrent_deletes should have value > 0", + () -> + sql( + "CALL %s.system.remove_orphan_files(table => '%s', max_concurrent_deletes => %s)", + catalogName, tableIdent, -1)); + + String tempViewName = "file_list_test"; + spark.emptyDataFrame().createOrReplaceTempView(tempViewName); + + AssertHelpers.assertThrows( + "Should throw an error if file_list_view is missing required columns", + IllegalArgumentException.class, + "does not exist. Available:", + () -> + sql( + "CALL %s.system.remove_orphan_files(table => '%s', file_list_view => '%s')", + catalogName, tableIdent, tempViewName)); + + spark + .createDataset(Lists.newArrayList(), Encoders.tuple(Encoders.INT(), Encoders.TIMESTAMP())) + .toDF("file_path", "last_modified") + .createOrReplaceTempView(tempViewName); + + AssertHelpers.assertThrows( + "Should throw an error if file_path has wrong type", + IllegalArgumentException.class, + "Invalid file_path column", + () -> + sql( + "CALL %s.system.remove_orphan_files(table => '%s', file_list_view => '%s')", + catalogName, tableIdent, tempViewName)); + + spark + .createDataset(Lists.newArrayList(), Encoders.tuple(Encoders.STRING(), Encoders.STRING())) + .toDF("file_path", "last_modified") + .createOrReplaceTempView(tempViewName); + + AssertHelpers.assertThrows( + "Should throw an error if last_modified has wrong type", + IllegalArgumentException.class, + "Invalid last_modified column", + () -> + sql( + "CALL %s.system.remove_orphan_files(table => '%s', file_list_view => '%s')", + catalogName, tableIdent, tempViewName)); + } + + @Test + public void testRemoveOrphanFilesWithDeleteFiles() throws Exception { + sql( + "CREATE TABLE %s (id int, data string) USING iceberg TBLPROPERTIES" + + "('format-version'='2', 'write.delete.mode'='merge-on-read')", + tableName); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "d")); + spark + .createDataset(records, Encoders.bean(SimpleRecord.class)) + .coalesce(1) + .writeTo(tableName) + .append(); + sql("DELETE FROM %s WHERE id=1", tableName); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + Assert.assertEquals( + "Should have 1 delete manifest", 1, TestHelpers.deleteManifests(table).size()); + Assert.assertEquals("Should have 1 delete file", 1, TestHelpers.deleteFiles(table).size()); + Path deleteManifestPath = new Path(TestHelpers.deleteManifests(table).iterator().next().path()); + Path deleteFilePath = + new 
Path(String.valueOf(TestHelpers.deleteFiles(table).iterator().next().path())); + + // wait to ensure files are old enough + waitUntilAfter(System.currentTimeMillis()); + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + // delete orphans + List output = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + Assert.assertEquals("Should be no orphan files", 0, output.size()); + + FileSystem localFs = FileSystem.getLocal(new Configuration()); + Assert.assertTrue("Delete manifest should still exist", localFs.exists(deleteManifestPath)); + Assert.assertTrue("Delete file should still exist", localFs.exists(deleteFilePath)); + + records.remove(new SimpleRecord(1, "a")); + Dataset resultDF = spark.read().format("iceberg").load(tableName); + List actualRecords = + resultDF.as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Rows must match", records, actualRecords); + } + + @Test + public void testRemoveOrphanFilesWithStatisticFiles() throws Exception { + sql( + "CREATE TABLE %s USING iceberg " + + "TBLPROPERTIES('format-version'='2') " + + "AS SELECT 10 int, 'abc' data", + tableName); + Table table = Spark3Util.loadIcebergTable(spark, tableName); + + String statsFileName = "stats-file-" + UUID.randomUUID(); + File statsLocation = + new File(new URI(table.location())) + .toPath() + .resolve("data") + .resolve(statsFileName) + .toFile(); + StatisticsFile statisticsFile; + try (PuffinWriter puffinWriter = Puffin.write(Files.localOutput(statsLocation)).build()) { + long snapshotId = table.currentSnapshot().snapshotId(); + long snapshotSequenceNumber = table.currentSnapshot().sequenceNumber(); + puffinWriter.add( + new Blob( + "some-blob-type", + ImmutableList.of(1), + snapshotId, + snapshotSequenceNumber, + ByteBuffer.wrap("blob content".getBytes(StandardCharsets.UTF_8)))); + puffinWriter.finish(); + statisticsFile = + new GenericStatisticsFile( + snapshotId, + statsLocation.toString(), + puffinWriter.fileSize(), + puffinWriter.footerSize(), + puffinWriter.writtenBlobsMetadata().stream() + .map(GenericBlobMetadata::from) + .collect(ImmutableList.toImmutableList())); + } + + Transaction transaction = table.newTransaction(); + transaction + .updateStatistics() + .setStatistics(statisticsFile.snapshotId(), statisticsFile) + .commit(); + transaction.commitTransaction(); + + // wait to ensure files are old enough + waitUntilAfter(System.currentTimeMillis()); + Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis())); + + List output = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + Assertions.assertThat(output).as("Should be no orphan files").isEmpty(); + + Assertions.assertThat(statsLocation.exists()).as("stats file should exist").isTrue(); + Assertions.assertThat(statsLocation.length()) + .as("stats file length") + .isEqualTo(statisticsFile.fileSizeInBytes()); + + transaction = table.newTransaction(); + transaction.updateStatistics().removeStatistics(statisticsFile.snapshotId()).commit(); + transaction.commitTransaction(); + + output = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "older_than => TIMESTAMP '%s')", + catalogName, tableIdent, currentTimestamp); + Assertions.assertThat(output).as("Should be orphan files").hasSize(1); + 
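For reference, the remove_orphan_files invocations validated throughout this suite follow this general shape; db.sample and the timestamp are placeholders, and dry_run is optional:
    CALL spark_catalog.system.remove_orphan_files(
      table => 'db.sample',
      older_than => TIMESTAMP '2023-04-01 00:00:00',
      dry_run => true);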
Assertions.assertThat(Iterables.getOnlyElement(output)) + .as("Deleted files") + .containsExactly(statsLocation.toURI().toString()); + Assertions.assertThat(statsLocation.exists()).as("stats file should be deleted").isFalse(); + } + + @Test + public void testRemoveOrphanFilesProcedureWithPrefixMode() + throws NoSuchTableException, ParseException, IOException { + if (catalogName.equals("testhadoop")) { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + } else { + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg LOCATION '%s'", + tableName, temp.newFolder().toURI().toString()); + } + Table table = Spark3Util.loadIcebergTable(spark, tableName); + String location = table.location(); + Path originalPath = new Path(location); + + URI uri = originalPath.toUri(); + Path newParentPath = new Path("file1", uri.getAuthority(), uri.getPath()); + + DataFile dataFile1 = + DataFiles.builder(PartitionSpec.unpartitioned()) + .withPath(new Path(newParentPath, "path/to/data-a.parquet").toString()) + .withFileSizeInBytes(10) + .withRecordCount(1) + .build(); + DataFile dataFile2 = + DataFiles.builder(PartitionSpec.unpartitioned()) + .withPath(new Path(newParentPath, "path/to/data-b.parquet").toString()) + .withFileSizeInBytes(10) + .withRecordCount(1) + .build(); + + table.newFastAppend().appendFile(dataFile1).appendFile(dataFile2).commit(); + + Timestamp lastModifiedTimestamp = new Timestamp(10000); + + List allFiles = + Lists.newArrayList( + new FilePathLastModifiedRecord( + new Path(originalPath, "path/to/data-a.parquet").toString(), lastModifiedTimestamp), + new FilePathLastModifiedRecord( + new Path(originalPath, "path/to/data-b.parquet").toString(), lastModifiedTimestamp), + new FilePathLastModifiedRecord( + ReachableFileUtil.versionHintLocation(table), lastModifiedTimestamp)); + + for (String file : ReachableFileUtil.metadataFileLocations(table, true)) { + allFiles.add(new FilePathLastModifiedRecord(file, lastModifiedTimestamp)); + } + + for (ManifestFile manifest : TestHelpers.dataManifests(table)) { + allFiles.add(new FilePathLastModifiedRecord(manifest.path(), lastModifiedTimestamp)); + } + + Dataset compareToFileList = + spark + .createDataFrame(allFiles, FilePathLastModifiedRecord.class) + .withColumnRenamed("filePath", "file_path") + .withColumnRenamed("lastModified", "last_modified"); + String fileListViewName = "files_view"; + compareToFileList.createOrReplaceTempView(fileListViewName); + List orphanFiles = + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "equal_schemes => map('file1', 'file')," + + "file_list_view => '%s')", + catalogName, tableIdent, fileListViewName); + Assert.assertEquals(0, orphanFiles.size()); + + // Test with no equal schemes + AssertHelpers.assertThrows( + "Should complain about removing orphan files", + ValidationException.class, + "Conflicting authorities/schemes: [(file1, file)]", + () -> + sql( + "CALL %s.system.remove_orphan_files(" + + "table => '%s'," + + "file_list_view => '%s')", + catalogName, tableIdent, fileListViewName)); + + // Drop table in afterEach has purge and fails due to invalid scheme "file1" used in this test + // Dropping the table here + sql("DROP TABLE %s", tableName); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestReplaceBranch.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestReplaceBranch.java new file mode 100644 index 000000000000..b63826e543b8 --- /dev/null +++ 
b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestReplaceBranch.java @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runners.Parameterized; + +public class TestReplaceBranch extends SparkExtensionsTestBase { + + private static final String[] TIME_UNITS = {"DAYS", "HOURS", "MINUTES"}; + + @Parameterized.Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties() + } + }; + } + + public TestReplaceBranch(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testReplaceBranchFailsForTag() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + String tagName = "tag1"; + + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + Table table = validationCatalog.loadTable(tableIdent); + long first = table.currentSnapshot().snapshotId(); + table.manageSnapshots().createTag(tagName, first).commit(); + df.writeTo(tableName).append(); + long second = table.currentSnapshot().snapshotId(); + + AssertHelpers.assertThrows( + "Cannot perform replace branch on tags", + IllegalArgumentException.class, + "Ref tag1 is a tag not a branch", + () -> sql("ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d", tableName, tagName, second)); + } + + @Test + public void testReplaceBranch() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + + Table table = 
validationCatalog.loadTable(tableIdent); + long first = table.currentSnapshot().snapshotId(); + String branchName = "b1"; + long expectedMaxRefAgeMs = 1000; + int expectedMinSnapshotsToKeep = 2; + long expectedMaxSnapshotAgeMs = 1000; + table + .manageSnapshots() + .createBranch(branchName, first) + .setMaxRefAgeMs(branchName, expectedMaxRefAgeMs) + .setMinSnapshotsToKeep(branchName, expectedMinSnapshotsToKeep) + .setMaxSnapshotAgeMs(branchName, expectedMaxSnapshotAgeMs) + .commit(); + + df.writeTo(tableName).append(); + long second = table.currentSnapshot().snapshotId(); + + sql("ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d", tableName, branchName, second); + + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + Assert.assertNotNull(ref); + Assert.assertEquals(second, ref.snapshotId()); + Assert.assertEquals(expectedMinSnapshotsToKeep, ref.minSnapshotsToKeep().intValue()); + Assert.assertEquals(expectedMaxSnapshotAgeMs, ref.maxSnapshotAgeMs().longValue()); + Assert.assertEquals(expectedMaxRefAgeMs, ref.maxRefAgeMs().longValue()); + } + + @Test + public void testReplaceBranchDoesNotExist() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + Table table = validationCatalog.loadTable(tableIdent); + + AssertHelpers.assertThrows( + "Cannot perform replace branch on branch which does not exist", + IllegalArgumentException.class, + "Branch does not exist", + () -> + sql( + "ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d", + tableName, "someBranch", table.currentSnapshot().snapshotId())); + } + + @Test + public void testReplaceBranchWithRetain() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + long first = table.currentSnapshot().snapshotId(); + String branchName = "b1"; + table.manageSnapshots().createBranch(branchName, first).commit(); + SnapshotRef b1 = table.refs().get(branchName); + Integer minSnapshotsToKeep = b1.minSnapshotsToKeep(); + Long maxSnapshotAgeMs = b1.maxSnapshotAgeMs(); + df.writeTo(tableName).append(); + long second = table.currentSnapshot().snapshotId(); + + long maxRefAge = 10; + for (String timeUnit : TIME_UNITS) { + sql( + "ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d RETAIN %d %s", + tableName, branchName, second, maxRefAge, timeUnit); + + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + Assert.assertNotNull(ref); + Assert.assertEquals(second, ref.snapshotId()); + Assert.assertEquals(minSnapshotsToKeep, ref.minSnapshotsToKeep()); + Assert.assertEquals(maxSnapshotAgeMs, ref.maxSnapshotAgeMs()); + Assert.assertEquals( + TimeUnit.valueOf(timeUnit).toMillis(maxRefAge), ref.maxRefAgeMs().longValue()); + } + } + + @Test + public void testReplaceBranchWithSnapshotRetention() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + String branchName = "b1"; + 
Table table = validationCatalog.loadTable(tableIdent); + long first = table.currentSnapshot().snapshotId(); + table.manageSnapshots().createBranch(branchName, first).commit(); + df.writeTo(tableName).append(); + long second = table.currentSnapshot().snapshotId(); + + Integer minSnapshotsToKeep = 2; + long maxSnapshotAge = 2; + Long maxRefAgeMs = table.refs().get(branchName).maxRefAgeMs(); + for (String timeUnit : TIME_UNITS) { + sql( + "ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d WITH SNAPSHOT RETENTION %d SNAPSHOTS %d %s", + tableName, branchName, second, minSnapshotsToKeep, maxSnapshotAge, timeUnit); + + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + Assert.assertNotNull(ref); + Assert.assertEquals(second, ref.snapshotId()); + Assert.assertEquals(minSnapshotsToKeep, ref.minSnapshotsToKeep()); + Assert.assertEquals( + TimeUnit.valueOf(timeUnit).toMillis(maxSnapshotAge), ref.maxSnapshotAgeMs().longValue()); + Assert.assertEquals(maxRefAgeMs, ref.maxRefAgeMs()); + } + } + + @Test + public void testReplaceBranchWithRetainAndSnapshotRetention() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + long first = table.currentSnapshot().snapshotId(); + String branchName = "b1"; + table.manageSnapshots().createBranch(branchName, first).commit(); + df.writeTo(tableName).append(); + long second = table.currentSnapshot().snapshotId(); + + Integer minSnapshotsToKeep = 2; + long maxSnapshotAge = 2; + long maxRefAge = 10; + for (String timeUnit : TIME_UNITS) { + sql( + "ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d RETAIN %d %s WITH SNAPSHOT RETENTION %d SNAPSHOTS %d %s", + tableName, + branchName, + second, + maxRefAge, + timeUnit, + minSnapshotsToKeep, + maxSnapshotAge, + timeUnit); + + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + Assert.assertNotNull(ref); + Assert.assertEquals(second, ref.snapshotId()); + Assert.assertEquals(minSnapshotsToKeep, ref.minSnapshotsToKeep()); + Assert.assertEquals( + TimeUnit.valueOf(timeUnit).toMillis(maxSnapshotAge), ref.maxSnapshotAgeMs().longValue()); + Assert.assertEquals( + TimeUnit.valueOf(timeUnit).toMillis(maxRefAge), ref.maxRefAgeMs().longValue()); + } + } + + @Test + public void testCreateOrReplace() throws NoSuchTableException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + long first = table.currentSnapshot().snapshotId(); + String branchName = "b1"; + df.writeTo(tableName).append(); + long second = table.currentSnapshot().snapshotId(); + table.manageSnapshots().createBranch(branchName, second).commit(); + + sql( + "ALTER TABLE %s CREATE OR REPLACE BRANCH %s AS OF VERSION %d", + tableName, branchName, first); + + table.refresh(); + SnapshotRef ref = table.refs().get(branchName); + Assert.assertNotNull(ref); + Assert.assertEquals(first, ref.snapshotId()); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRequiredDistributionAndOrdering.java 
b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRequiredDistributionAndOrdering.java new file mode 100644 index 000000000000..adb4fab41922 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRequiredDistributionAndOrdering.java @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import java.math.BigDecimal; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.source.ThreeColumnRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.After; +import org.junit.Test; + +public class TestRequiredDistributionAndOrdering extends SparkExtensionsTestBase { + + public TestRequiredDistributionAndOrdering( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void dropTestTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testDefaultLocalSortWithBucketTransforms() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c1))", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should insert a local sort by partition columns by default + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @Test + public void testPartitionColumnsArePrependedForRangeDistribution() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c1))", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new 
ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should automatically prepend partition columns to the ordering + sql("ALTER TABLE %s WRITE ORDERED BY c1, c2", tableName); + + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @Test + public void testSortOrderIncludesPartitionColumns() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c1))", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should succeed with a correct sort order + sql("ALTER TABLE %s WRITE ORDERED BY bucket(2, c3), c1, c2", tableName); + + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @Test + public void testHashDistributionOnBucketedColumn() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c1))", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should automatically prepend partition columns to the local ordering after hash distribution + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION ORDERED BY c1, c2", tableName); + + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @Test + public void testDisabledDistributionAndOrdering() { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c1))", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should fail if ordering is disabled + AssertHelpers.assertThrowsCause( + "Should reject writes without ordering", + IllegalStateException.class, + "Encountered records that belong to already closed files", + () -> { + try { + inputDF + .writeTo(tableName) + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, 
"false") + .append(); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + public void testDefaultSortOnDecimalBucketedColumn() { + sql( + "CREATE TABLE %s (c1 INT, c2 DECIMAL(20, 2)) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c2))", + tableName); + + sql("INSERT INTO %s VALUES (1, 20.2), (2, 40.2), (3, 60.2)", tableName); + + List expected = + ImmutableList.of( + row(1, new BigDecimal("20.20")), + row(2, new BigDecimal("40.20")), + row(3, new BigDecimal("60.20"))); + + assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName)); + } + + @Test + public void testDefaultSortOnStringBucketedColumn() { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c2))", + tableName); + + sql("INSERT INTO %s VALUES (1, 'A'), (2, 'B')", tableName); + + List expected = ImmutableList.of(row(1, "A"), row(2, "B")); + + assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName)); + } + + @Test + public void testDefaultSortOnDecimalTruncatedColumn() { + sql( + "CREATE TABLE %s (c1 INT, c2 DECIMAL(20, 2)) " + + "USING iceberg " + + "PARTITIONED BY (truncate(2, c2))", + tableName); + + sql("INSERT INTO %s VALUES (1, 20.2), (2, 40.2)", tableName); + + List expected = + ImmutableList.of(row(1, new BigDecimal("20.20")), row(2, new BigDecimal("40.20"))); + + assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName)); + } + + @Test + public void testDefaultSortOnLongTruncatedColumn() { + sql( + "CREATE TABLE %s (c1 INT, c2 BIGINT) " + + "USING iceberg " + + "PARTITIONED BY (truncate(2, c2))", + tableName); + + sql("INSERT INTO %s VALUES (1, 22222222222222), (2, 444444444444)", tableName); + + List expected = ImmutableList.of(row(1, 22222222222222L), row(2, 444444444444L)); + + assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName)); + } + + @Test + public void testRangeDistributionWithQuotedColumnNames() throws NoSuchTableException { + sql( + "CREATE TABLE %s (`c.1` INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, `c.1`))", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = + ds.selectExpr("c1 as `c.1`", "c2", "c3").coalesce(1).sortWithinPartitions("`c.1`"); + + sql("ALTER TABLE %s WRITE ORDERED BY `c.1`, c2", tableName); + + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteDataFilesProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteDataFilesProcedure.java new file mode 100644 index 000000000000..44aca898b696 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteDataFilesProcedure.java @@ -0,0 +1,714 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.NamedReference; +import org.apache.iceberg.expressions.Zorder; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.ExtendedParser; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkTableCache; +import org.apache.iceberg.spark.source.ThreeColumnRecord; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Test; + +public class TestRewriteDataFilesProcedure extends SparkExtensionsTestBase { + + private static final String QUOTED_SPECIAL_CHARS_TABLE_NAME = "`table:with.special:chars`"; + + public TestRewriteDataFilesProcedure( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s", tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + } + + @Test + public void testZOrderSortExpression() { + List order = + ExtendedParser.parseSortOrder(spark, "c1, zorder(c2, c3)"); + Assert.assertEquals("Should parse 2 order fields", 2, order.size()); + Assert.assertEquals( + "First field should be a ref", "c1", ((NamedReference) order.get(0).term()).name()); + Assert.assertTrue("Second field should be zorder", order.get(1).term() instanceof Zorder); + } + + @Test + public void testRewriteDataFilesInEmptyTable() { + createTable(); + List output = sql("CALL %s.system.rewrite_data_files('%s')", catalogName, tableIdent); + assertEquals("Procedure output must match", ImmutableList.of(row(0, 0, 0L)), output); + } + + @Test + public void testRewriteDataFilesOnPartitionTable() { + createPartitionTable(); + // create 5 files for each partition (c2 = 'foo' and c2 = 'bar') + insertData(10); + List expectedRecords = currentData(); + + List output = + sql("CALL %s.system.rewrite_data_files(table => '%s')", catalogName, tableIdent); + + assertEquals( + "Action should rewrite 10 data files and add 2 data files (one per partition) ", + row(10, 2), + Arrays.copyOf(output.get(0), 2)); 
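The output row asserted above carries three columns: the rewritten data file count, the added data file count, and the rewritten bytes that are checked against the snapshot summary just below. The same procedure can also be invoked outside the test harness through plain Spark SQL; a minimal sketch, assuming a session configured with the Iceberg SQL extensions and a catalog named my_catalog over a table db.sample (placeholder names, not identifiers from this patch):

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class RewriteDataFilesExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().getOrCreate();
    // Placeholder catalog and table names; the call shape mirrors the test above.
    Dataset<Row> result =
        spark.sql("CALL my_catalog.system.rewrite_data_files(table => 'db.sample')");
    // One row: rewritten data file count, added data file count, rewritten bytes.
    result.show();
  }
}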
+ // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(3); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo(Long.valueOf(snapshotSummary().get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @Test + public void testRewriteDataFilesOnNonPartitionTable() { + createTable(); + // create 10 files under non-partitioned table + insertData(10); + List expectedRecords = currentData(); + + List output = + sql("CALL %s.system.rewrite_data_files(table => '%s')", catalogName, tableIdent); + + assertEquals( + "Action should rewrite 10 data files and add 1 data files", + row(10, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(3); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo(Long.valueOf(snapshotSummary().get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @Test + public void testRewriteDataFilesWithOptions() { + createTable(); + // create 10 files under non-partitioned table + insertData(10); + List expectedRecords = currentData(); + + // set the min-input-files = 12, instead of default 5 to skip compacting the files. + List output = + sql( + "CALL %s.system.rewrite_data_files(table => '%s', options => map('min-input-files','12'))", + catalogName, tableIdent); + + assertEquals( + "Action should rewrite 0 data files and add 0 data files", + ImmutableList.of(row(0, 0, 0L)), + output); + + List actualRecords = currentData(); + assertEquals("Data should not change", expectedRecords, actualRecords); + } + + @Test + public void testRewriteDataFilesWithSortStrategy() { + createTable(); + // create 10 files under non-partitioned table + insertData(10); + List expectedRecords = currentData(); + + // set sort_order = c1 DESC LAST + List output = + sql( + "CALL %s.system.rewrite_data_files(table => '%s', " + + "strategy => 'sort', sort_order => 'c1 DESC NULLS LAST')", + catalogName, tableIdent); + + assertEquals( + "Action should rewrite 10 data files and add 1 data files", + row(10, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(3); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo(Long.valueOf(snapshotSummary().get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @Test + public void testRewriteDataFilesWithZOrder() { + createTable(); + // create 10 files under non-partitioned table + insertData(10); + + // set z_order = c1,c2 + List output = + sql( + "CALL %s.system.rewrite_data_files(table => '%s', " + + "strategy => 'sort', sort_order => 'zorder(c1,c2)')", + catalogName, tableIdent); + + assertEquals( + "Action should rewrite 10 data files and add 1 data files", + row(10, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(3); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo(Long.valueOf(snapshotSummary().get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + // Due to Z_order, the data written will be in the below order. 
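+    // (Z-ordering interleaves the bit representations of the chosen columns, c1 and c2 here,
+    // so rows with equal column values are clustered together in the rewritten output file.)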
+ // As there is only one small output file, we can validate the query ordering (as it will not + // change). + ImmutableList expectedRows = + ImmutableList.of( + row(2, "bar", null), + row(2, "bar", null), + row(2, "bar", null), + row(2, "bar", null), + row(2, "bar", null), + row(1, "foo", null), + row(1, "foo", null), + row(1, "foo", null), + row(1, "foo", null), + row(1, "foo", null)); + assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s", tableName)); + } + + @Test + public void testRewriteDataFilesWithFilter() { + createTable(); + // create 10 files under non-partitioned table + insertData(10); + List expectedRecords = currentData(); + + // select only 5 files for compaction (files that may have c1 = 1) + List output = + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + + " where => 'c1 = 1 and c2 is not null')", + catalogName, tableIdent); + + assertEquals( + "Action should rewrite 5 data files (containing c1 = 1) and add 1 data files", + row(5, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(3); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo(Long.valueOf(snapshotSummary().get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @Test + public void testRewriteDataFilesWithFilterOnPartitionTable() { + createPartitionTable(); + // create 5 files for each partition (c2 = 'foo' and c2 = 'bar') + insertData(10); + List expectedRecords = currentData(); + + // select only 5 files for compaction (files in the partition c2 = 'bar') + List output = + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c2 = \"bar\"')", + catalogName, tableIdent); + + assertEquals( + "Action should rewrite 5 data files from single matching partition" + + "(containing c2 = bar) and add 1 data files", + row(5, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(3); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo(Long.valueOf(snapshotSummary().get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @Test + public void testRewriteDataFilesWithInFilterOnPartitionTable() { + createPartitionTable(); + // create 5 files for each partition (c2 = 'foo' and c2 = 'bar') + insertData(10); + List expectedRecords = currentData(); + + // select only 5 files for compaction (files in the partition c2 in ('bar')) + List output = + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c2 in (\"bar\")')", + catalogName, tableIdent); + + assertEquals( + "Action should rewrite 5 data files from single matching partition" + + "(containing c2 = bar) and add 1 data files", + row(5, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(3); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo(Long.valueOf(snapshotSummary().get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + @Test + public void testRewriteDataFilesWithAllPossibleFilters() { + createPartitionTable(); + // create 5 files for each partition (c2 = 'foo' and c2 = 
'bar') + insertData(10); + + // Pass the literal value which is not present in the data files. + // So that parsing can be tested on a same dataset without actually compacting the files. + + // EqualTo + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c1 = 3')", + catalogName, tableIdent); + // GreaterThan + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c1 > 3')", + catalogName, tableIdent); + // GreaterThanOrEqual + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c1 >= 3')", + catalogName, tableIdent); + // LessThan + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c1 < 0')", + catalogName, tableIdent); + // LessThanOrEqual + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c1 <= 0')", + catalogName, tableIdent); + // In + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c1 in (3,4,5)')", + catalogName, tableIdent); + // IsNull + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c1 is null')", + catalogName, tableIdent); + // IsNotNull + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c3 is not null')", + catalogName, tableIdent); + // And + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c1 = 3 and c2 = \"bar\"')", + catalogName, tableIdent); + // Or + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c1 = 3 or c1 = 5')", + catalogName, tableIdent); + // Not + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c1 not in (1,2)')", + catalogName, tableIdent); + // StringStartsWith + sql( + "CALL %s.system.rewrite_data_files(table => '%s'," + " where => 'c2 like \"%s\"')", + catalogName, tableIdent, "car%"); + + // TODO: Enable when org.apache.iceberg.spark.SparkFilters have implementations for + // StringEndsWith & StringContains + // StringEndsWith + // sql("CALL %s.system.rewrite_data_files(table => '%s'," + + // " where => 'c2 like \"%s\"')", catalogName, tableIdent, "%car"); + // StringContains + // sql("CALL %s.system.rewrite_data_files(table => '%s'," + + // " where => 'c2 like \"%s\"')", catalogName, tableIdent, "%car%"); + } + + @Test + public void testRewriteDataFilesWithInvalidInputs() { + createTable(); + // create 2 files under non-partitioned table + insertData(2); + + // Test for invalid strategy + AssertHelpers.assertThrows( + "Should reject calls with unsupported strategy error message", + IllegalArgumentException.class, + "unsupported strategy: temp. 
Only binpack or sort is supported", + () -> + sql( + "CALL %s.system.rewrite_data_files(table => '%s', options => map('min-input-files','2'), " + + "strategy => 'temp')", + catalogName, tableIdent)); + + // Test for sort_order with binpack strategy + AssertHelpers.assertThrows( + "Should reject calls with error message", + IllegalArgumentException.class, + "Must use only one rewriter type (bin-pack, sort, zorder)", + () -> + sql( + "CALL %s.system.rewrite_data_files(table => '%s', strategy => 'binpack', " + + "sort_order => 'c1 ASC NULLS FIRST')", + catalogName, tableIdent)); + + // Test for sort strategy without any (default/user defined) sort_order + AssertHelpers.assertThrows( + "Should reject calls with error message", + IllegalArgumentException.class, + "Cannot sort data without a valid sort order", + () -> + sql( + "CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort')", + catalogName, tableIdent)); + + // Test for sort_order with invalid null order + AssertHelpers.assertThrows( + "Should reject calls with error message", + IllegalArgumentException.class, + "Unable to parse sortOrder:", + () -> + sql( + "CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " + + "sort_order => 'c1 ASC none')", + catalogName, tableIdent)); + + // Test for sort_order with invalid sort direction + AssertHelpers.assertThrows( + "Should reject calls with error message", + IllegalArgumentException.class, + "Unable to parse sortOrder:", + () -> + sql( + "CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " + + "sort_order => 'c1 none NULLS FIRST')", + catalogName, tableIdent)); + + // Test for sort_order with invalid column name + AssertHelpers.assertThrows( + "Should reject calls with error message", + ValidationException.class, + "Cannot find field 'col1' in struct:" + + " struct<1: c1: optional int, 2: c2: optional string, 3: c3: optional string>", + () -> + sql( + "CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " + + "sort_order => 'col1 DESC NULLS FIRST')", + catalogName, tableIdent)); + + // Test with invalid filter column col1 + AssertHelpers.assertThrows( + "Should reject calls with error message", + IllegalArgumentException.class, + "Cannot parse predicates in where option: col1 = 3", + () -> + sql( + "CALL %s.system.rewrite_data_files(table => '%s', " + "where => 'col1 = 3')", + catalogName, tableIdent)); + + // Test for z_order with invalid column name + AssertHelpers.assertThrows( + "Should reject calls with error message", + IllegalArgumentException.class, + "Cannot find column 'col1' in table schema (case sensitive = false): " + + "struct<1: c1: optional int, 2: c2: optional string, 3: c3: optional string>", + () -> + sql( + "CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " + + "sort_order => 'zorder(col1)')", + catalogName, tableIdent)); + + // Test for z_order with sort_order + AssertHelpers.assertThrows( + "Should reject calls with error message", + IllegalArgumentException.class, + "Cannot mix identity sort columns and a Zorder sort expression:" + " c1,zorder(c2,c3)", + () -> + sql( + "CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " + + "sort_order => 'c1,zorder(c2,c3)')", + catalogName, tableIdent)); + } + + @Test + public void testInvalidCasesForRewriteDataFiles() { + AssertHelpers.assertThrows( + "Should not allow mixed args", + AnalysisException.class, + "Named and positional arguments cannot be mixed", + () -> sql("CALL %s.system.rewrite_data_files('n', table => 't')", 
catalogName)); + + AssertHelpers.assertThrows( + "Should not resolve procedures in arbitrary namespaces", + NoSuchProcedureException.class, + "not found", + () -> sql("CALL %s.custom.rewrite_data_files('n', 't')", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls without all required args", + AnalysisException.class, + "Missing required parameters", + () -> sql("CALL %s.system.rewrite_data_files()", catalogName)); + + AssertHelpers.assertThrows( + "Should reject duplicate arg names name", + AnalysisException.class, + "Duplicate procedure argument: table", + () -> sql("CALL %s.system.rewrite_data_files(table => 't', table => 't')", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with empty table identifier", + IllegalArgumentException.class, + "Cannot handle an empty identifier", + () -> sql("CALL %s.system.rewrite_data_files('')", catalogName)); + } + + @Test + public void testBinPackTableWithSpecialChars() { + Assume.assumeTrue(catalogName.equals(SparkCatalogConfig.HADOOP.catalogName())); + + TableIdentifier identifier = + TableIdentifier.of("default", QUOTED_SPECIAL_CHARS_TABLE_NAME.replaceAll("`", "")); + sql( + "CREATE TABLE %s (c1 int, c2 string, c3 string) USING iceberg", + tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + + insertData(tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME), 10); + + List expectedRecords = currentData(tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + + List output = + sql( + "CALL %s.system.rewrite_data_files(table => '%s', where => 'c2 is not null')", + catalogName, tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + + assertEquals( + "Action should rewrite 10 data files and add 1 data file", + row(10, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(3); + assertThat(output.get(0)[2]) + .isEqualTo( + Long.valueOf(snapshotSummary(identifier).get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + + Assert.assertEquals("Table cache must be empty", 0, SparkTableCache.get().size()); + } + + @Test + public void testSortTableWithSpecialChars() { + Assume.assumeTrue(catalogName.equals(SparkCatalogConfig.HADOOP.catalogName())); + + TableIdentifier identifier = + TableIdentifier.of("default", QUOTED_SPECIAL_CHARS_TABLE_NAME.replaceAll("`", "")); + sql( + "CREATE TABLE %s (c1 int, c2 string, c3 string) USING iceberg", + tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + + insertData(tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME), 10); + + List expectedRecords = currentData(tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + + List output = + sql( + "CALL %s.system.rewrite_data_files(" + + " table => '%s'," + + " strategy => 'sort'," + + " sort_order => 'c1'," + + " where => 'c2 is not null')", + catalogName, tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + + assertEquals( + "Action should rewrite 10 data files and add 1 data file", + row(10, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(3); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo( + Long.valueOf(snapshotSummary(identifier).get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + + Assert.assertEquals("Table cache must 
be empty", 0, SparkTableCache.get().size()); + } + + @Test + public void testZOrderTableWithSpecialChars() { + Assume.assumeTrue(catalogName.equals(SparkCatalogConfig.HADOOP.catalogName())); + + TableIdentifier identifier = + TableIdentifier.of("default", QUOTED_SPECIAL_CHARS_TABLE_NAME.replaceAll("`", "")); + sql( + "CREATE TABLE %s (c1 int, c2 string, c3 string) USING iceberg", + tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + + insertData(tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME), 10); + + List expectedRecords = currentData(tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + + List output = + sql( + "CALL %s.system.rewrite_data_files(" + + " table => '%s'," + + " strategy => 'sort'," + + " sort_order => 'zorder(c1, c2)'," + + " where => 'c2 is not null')", + catalogName, tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + + assertEquals( + "Action should rewrite 10 data files and add 1 data file", + row(10, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(3); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo( + Long.valueOf(snapshotSummary(identifier).get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(tableName(QUOTED_SPECIAL_CHARS_TABLE_NAME)); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + + Assert.assertEquals("Table cache must be empty", 0, SparkTableCache.get().size()); + } + + @Test + public void testDefaultSortOrder() { + createTable(); + // add a default sort order for a table + sql("ALTER TABLE %s WRITE ORDERED BY c2", tableName); + + // this creates 2 files under non-partitioned table due to sort order. + insertData(10); + List expectedRecords = currentData(); + + // When the strategy is set to 'sort' but the sort order is not specified, + // use table's default sort order. 
+ List output = + sql( + "CALL %s.system.rewrite_data_files(table => '%s', " + + "strategy => 'sort', " + + "options => map('min-input-files','2'))", + catalogName, tableIdent); + + assertEquals( + "Action should rewrite 2 data files and add 1 data files", + row(2, 1), + Arrays.copyOf(output.get(0), 2)); + // verify rewritten bytes separately + assertThat(output.get(0)).hasSize(3); + assertThat(output.get(0)[2]) + .isInstanceOf(Long.class) + .isEqualTo(Long.valueOf(snapshotSummary().get(SnapshotSummary.REMOVED_FILE_SIZE_PROP))); + + List actualRecords = currentData(); + assertEquals("Data after compaction should not change", expectedRecords, actualRecords); + } + + private void createTable() { + sql("CREATE TABLE %s (c1 int, c2 string, c3 string) USING iceberg", tableName); + } + + private void createPartitionTable() { + sql( + "CREATE TABLE %s (c1 int, c2 string, c3 string) " + + "USING iceberg " + + "PARTITIONED BY (c2) " + + "TBLPROPERTIES ('%s' '%s')", + tableName, + TableProperties.WRITE_DISTRIBUTION_MODE, + TableProperties.WRITE_DISTRIBUTION_MODE_NONE); + } + + private void insertData(int filesCount) { + insertData(tableName, filesCount); + } + + private void insertData(String table, int filesCount) { + ThreeColumnRecord record1 = new ThreeColumnRecord(1, "foo", null); + ThreeColumnRecord record2 = new ThreeColumnRecord(2, "bar", null); + + List records = Lists.newArrayList(); + IntStream.range(0, filesCount / 2) + .forEach( + i -> { + records.add(record1); + records.add(record2); + }); + + Dataset df = + spark.createDataFrame(records, ThreeColumnRecord.class).repartition(filesCount); + try { + df.writeTo(table).append(); + } catch (org.apache.spark.sql.catalyst.analysis.NoSuchTableException e) { + throw new RuntimeException(e); + } + } + + private Map snapshotSummary() { + return snapshotSummary(tableIdent); + } + + private Map snapshotSummary(TableIdentifier tableIdentifier) { + return validationCatalog.loadTable(tableIdentifier).currentSnapshot().summary(); + } + + private List currentData() { + return currentData(tableName); + } + + private List currentData(String table) { + return rowsToJava(spark.sql("SELECT * FROM " + table + " order by c1, c2, c3").collectAsList()); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteManifestsProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteManifestsProcedure.java new file mode 100644 index 000000000000..40625b5e3450 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteManifestsProcedure.java @@ -0,0 +1,342 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED; + +import java.sql.Date; +import java.sql.Timestamp; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestRewriteManifestsProcedure extends SparkExtensionsTestBase { + + public TestRewriteManifestsProcedure( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testRewriteManifestsInEmptyTable() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + List output = sql("CALL %s.system.rewrite_manifests('%s')", catalogName, tableIdent); + assertEquals("Procedure output must match", ImmutableList.of(row(0, 0)), output); + } + + @Test + public void testRewriteLargeManifests() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg PARTITIONED BY (data)", + tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a'), (2, 'b'), (3, 'c'), (4, 'd')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals( + "Must have 1 manifest", 1, table.currentSnapshot().allManifests(table.io()).size()); + + sql("ALTER TABLE %s SET TBLPROPERTIES ('commit.manifest.target-size-bytes' '1')", tableName); + + List output = sql("CALL %s.system.rewrite_manifests('%s')", catalogName, tableIdent); + assertEquals("Procedure output must match", ImmutableList.of(row(1, 4)), output); + + table.refresh(); + + Assert.assertEquals( + "Must have 4 manifests", 4, table.currentSnapshot().allManifests(table.io()).size()); + } + + @Test + public void testRewriteManifestsNoOp() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg PARTITIONED BY (data)", + tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a'), (2, 'b'), (3, 'c'), (4, 'd')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals( + "Must have 1 manifest", 1, table.currentSnapshot().allManifests(table.io()).size()); + + List output = sql("CALL %s.system.rewrite_manifests('%s')", catalogName, tableIdent); + // should not rewrite any manifests for no-op (output of rewrite is same as before and after) + assertEquals("Procedure output must match", ImmutableList.of(row(0, 0)), output); + + table.refresh(); + + Assert.assertEquals( + "Must have 1 manifests", 1, table.currentSnapshot().allManifests(table.io()).size()); + } + + @Test + public void testRewriteLargeManifestsOnDatePartitionedTableWithJava8APIEnabled() { + withSQLConf( + ImmutableMap.of("spark.sql.datetime.java8API.enabled", "true"), + () -> { + sql( + "CREATE TABLE %s (id INTEGER, name STRING, dept STRING, ts DATE) USING iceberg PARTITIONED BY (ts)", + tableName); + try { + spark + .createDataFrame( + ImmutableList.of( + RowFactory.create(1, "John Doe", "hr", Date.valueOf("2021-01-01")), + RowFactory.create(2, "Jane Doe", 
"hr", Date.valueOf("2021-01-02")), + RowFactory.create(3, "Matt Doe", "hr", Date.valueOf("2021-01-03")), + RowFactory.create(4, "Will Doe", "facilities", Date.valueOf("2021-01-04"))), + spark.table(tableName).schema()) + .writeTo(tableName) + .append(); + } catch (NoSuchTableException e) { + // not possible as we already created the table above. + throw new RuntimeException(e); + } + + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals( + "Must have 1 manifest", 1, table.currentSnapshot().allManifests(table.io()).size()); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('commit.manifest.target-size-bytes' '1')", + tableName); + + List output = + sql("CALL %s.system.rewrite_manifests('%s')", catalogName, tableIdent); + assertEquals("Procedure output must match", ImmutableList.of(row(1, 4)), output); + + table.refresh(); + + Assert.assertEquals( + "Must have 4 manifests", 4, table.currentSnapshot().allManifests(table.io()).size()); + }); + } + + @Test + public void testRewriteLargeManifestsOnTimestampPartitionedTableWithJava8APIEnabled() { + withSQLConf( + ImmutableMap.of("spark.sql.datetime.java8API.enabled", "true"), + () -> { + sql( + "CREATE TABLE %s (id INTEGER, name STRING, dept STRING, ts TIMESTAMP) USING iceberg PARTITIONED BY (ts)", + tableName); + try { + spark + .createDataFrame( + ImmutableList.of( + RowFactory.create( + 1, "John Doe", "hr", Timestamp.valueOf("2021-01-01 00:00:00")), + RowFactory.create( + 2, "Jane Doe", "hr", Timestamp.valueOf("2021-01-02 00:00:00")), + RowFactory.create( + 3, "Matt Doe", "hr", Timestamp.valueOf("2021-01-03 00:00:00")), + RowFactory.create( + 4, "Will Doe", "facilities", Timestamp.valueOf("2021-01-04 00:00:00"))), + spark.table(tableName).schema()) + .writeTo(tableName) + .append(); + } catch (NoSuchTableException e) { + // not possible as we already created the table above. 
+ throw new RuntimeException(e); + } + + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals( + "Must have 1 manifest", 1, table.currentSnapshot().allManifests(table.io()).size()); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('commit.manifest.target-size-bytes' '1')", + tableName); + + List output = + sql("CALL %s.system.rewrite_manifests('%s')", catalogName, tableIdent); + assertEquals("Procedure output must match", ImmutableList.of(row(1, 4)), output); + + table.refresh(); + + Assert.assertEquals( + "Must have 4 manifests", 4, table.currentSnapshot().allManifests(table.io()).size()); + }); + } + + @Test + public void testRewriteSmallManifestsWithSnapshotIdInheritance() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg PARTITIONED BY (data)", + tableName); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s')", + tableName, SNAPSHOT_ID_INHERITANCE_ENABLED, "true"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + sql("INSERT INTO TABLE %s VALUES (3, 'c')", tableName); + sql("INSERT INTO TABLE %s VALUES (4, 'd')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals( + "Must have 4 manifest", 4, table.currentSnapshot().allManifests(table.io()).size()); + + List output = + sql("CALL %s.system.rewrite_manifests(table => '%s')", catalogName, tableIdent); + assertEquals("Procedure output must match", ImmutableList.of(row(4, 1)), output); + + table.refresh(); + + Assert.assertEquals( + "Must have 1 manifests", 1, table.currentSnapshot().allManifests(table.io()).size()); + } + + @Test + public void testRewriteSmallManifestsWithoutCaching() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg PARTITIONED BY (data)", + tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals( + "Must have 2 manifest", 2, table.currentSnapshot().allManifests(table.io()).size()); + + List output = + sql( + "CALL %s.system.rewrite_manifests(use_caching => false, table => '%s')", + catalogName, tableIdent); + assertEquals("Procedure output must match", ImmutableList.of(row(2, 1)), output); + + table.refresh(); + + Assert.assertEquals( + "Must have 1 manifests", 1, table.currentSnapshot().allManifests(table.io()).size()); + } + + @Test + public void testRewriteManifestsCaseInsensitiveArgs() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg PARTITIONED BY (data)", + tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals( + "Must have 2 manifest", 2, table.currentSnapshot().allManifests(table.io()).size()); + + List output = + sql( + "CALL %s.system.rewrite_manifests(usE_cAcHiNg => false, tAbLe => '%s')", + catalogName, tableIdent); + assertEquals("Procedure output must match", ImmutableList.of(row(2, 1)), output); + + table.refresh(); + + Assert.assertEquals( + "Must have 1 manifests", 1, table.currentSnapshot().allManifests(table.io()).size()); + } + + @Test + public void testInvalidRewriteManifestsCases() { + AssertHelpers.assertThrows( + "Should not allow mixed args", + AnalysisException.class, + "Named and positional arguments cannot be mixed", + () -> sql("CALL 
%s.system.rewrite_manifests('n', table => 't')", catalogName)); + + AssertHelpers.assertThrows( + "Should not resolve procedures in arbitrary namespaces", + NoSuchProcedureException.class, + "not found", + () -> sql("CALL %s.custom.rewrite_manifests('n', 't')", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls without all required args", + AnalysisException.class, + "Missing required parameters", + () -> sql("CALL %s.system.rewrite_manifests()", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with invalid arg types", + AnalysisException.class, + "Wrong arg type", + () -> sql("CALL %s.system.rewrite_manifests('n', 2.2)", catalogName)); + + AssertHelpers.assertThrows( + "Should reject duplicate arg names name", + AnalysisException.class, + "Duplicate procedure argument: table", + () -> sql("CALL %s.system.rewrite_manifests(table => 't', tAbLe => 't')", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with empty table identifier", + IllegalArgumentException.class, + "Cannot handle an empty identifier", + () -> sql("CALL %s.system.rewrite_manifests('')", catalogName)); + } + + @Test + public void testReplacePartitionField() { + sql( + "CREATE TABLE %s (id int, ts timestamp, day_of_ts date) USING iceberg PARTITIONED BY (day_of_ts)", + tableName); + + sql("ALTER TABLE %s SET TBLPROPERTIES ('format-version' = '2')", tableName); + sql("ALTER TABLE %s REPLACE PARTITION FIELD day_of_ts WITH days(ts)\n", tableName); + sql( + "INSERT INTO %s VALUES (1, CAST('2022-01-01 10:00:00' AS TIMESTAMP), CAST('2022-01-01' AS DATE))", + tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of( + row(1, Timestamp.valueOf("2022-01-01 10:00:00"), Date.valueOf("2022-01-01"))), + sql("SELECT * FROM %s WHERE ts < current_timestamp()", tableName)); + + sql("CALL %s.system.rewrite_manifests(table => '%s')", catalogName, tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of( + row(1, Timestamp.valueOf("2022-01-01 10:00:00"), Date.valueOf("2022-01-01"))), + sql("SELECT * FROM %s WHERE ts < current_timestamp()", tableName)); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToSnapshotProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToSnapshotProcedure.java new file mode 100644 index 000000000000..af94b456d02e --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToSnapshotProcedure.java @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.junit.After; +import org.junit.Assume; +import org.junit.Test; + +public class TestRollbackToSnapshotProcedure extends SparkExtensionsTestBase { + + public TestRollbackToSnapshotProcedure( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testRollbackToSnapshotUsingPositionalArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = + sql( + "CALL %s.system.rollback_to_snapshot('%s', %dL)", + catalogName, tableIdent, firstSnapshot.snapshotId()); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testRollbackToSnapshotUsingNamedArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = + sql( + "CALL %s.system.rollback_to_snapshot(snapshot_id => %dL, table => '%s')", + catalogName, firstSnapshot.snapshotId(), tableIdent); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testRollbackToSnapshotRefreshesRelationCache() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + Dataset query = spark.sql("SELECT * FROM " + tableName + " WHERE id = 1"); 
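The statements below register this query as a temp view and cache it, so the test can verify that the rollback_to_snapshot procedure invalidates Spark's relation cache. For contrast, a rollback performed through the core API would move the table back to the earlier snapshot without, by itself, touching Spark's relation cache; a minimal sketch, assuming an already-loaded org.apache.iceberg.Table (illustration only, not part of this patch):

import org.apache.iceberg.Table;

class RollbackExample {
  // Roll the table back to a known snapshot id using the snapshot management API.
  static void rollbackTo(Table table, long snapshotId) {
    table.manageSnapshots().rollbackTo(snapshotId).commit();
  }
}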
+ query.createOrReplaceTempView("tmp"); + + spark.sql("CACHE TABLE tmp"); + + assertEquals( + "View should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM tmp")); + + List output = + sql( + "CALL %s.system.rollback_to_snapshot(table => '%s', snapshot_id => %dL)", + catalogName, tableIdent, firstSnapshot.snapshotId()); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "View cache must be invalidated", ImmutableList.of(row(1L, "a")), sql("SELECT * FROM tmp")); + + sql("UNCACHE TABLE tmp"); + } + + @Test + public void testRollbackToSnapshotWithQuotedIdentifiers() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + StringBuilder quotedNamespaceBuilder = new StringBuilder(); + for (String level : tableIdent.namespace().levels()) { + quotedNamespaceBuilder.append("`"); + quotedNamespaceBuilder.append(level); + quotedNamespaceBuilder.append("`"); + } + String quotedNamespace = quotedNamespaceBuilder.toString(); + + List output = + sql( + "CALL %s.system.rollback_to_snapshot('%s', %d)", + catalogName, + quotedNamespace + ".`" + tableIdent.name() + "`", + firstSnapshot.snapshotId()); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testRollbackToSnapshotWithoutExplicitCatalog() { + Assume.assumeTrue("Working only with the session catalog", "spark_catalog".equals(catalogName)); + + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + // use camel case intentionally to test case sensitivity + List output = + sql("CALL SyStEm.rOLlBaCk_to_SnApShOt('%s', %dL)", tableIdent, firstSnapshot.snapshotId()); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testRollbackToInvalidSnapshot() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + + AssertHelpers.assertThrows( + "Should reject invalid snapshot id", + ValidationException.class, + "Cannot roll back to unknown snapshot id", + () -> sql("CALL %s.system.rollback_to_snapshot('%s', -1L)", catalogName, 
tableIdent)); + } + + @Test + public void testInvalidRollbackToSnapshotCases() { + AssertHelpers.assertThrows( + "Should not allow mixed args", + AnalysisException.class, + "Named and positional arguments cannot be mixed", + () -> + sql( + "CALL %s.system.rollback_to_snapshot(namespace => 'n1', table => 't', 1L)", + catalogName)); + + AssertHelpers.assertThrows( + "Should not resolve procedures in arbitrary namespaces", + NoSuchProcedureException.class, + "not found", + () -> sql("CALL %s.custom.rollback_to_snapshot('n', 't', 1L)", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls without all required args", + AnalysisException.class, + "Missing required parameters", + () -> sql("CALL %s.system.rollback_to_snapshot('t')", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls without all required args", + AnalysisException.class, + "Missing required parameters", + () -> sql("CALL %s.system.rollback_to_snapshot(1L)", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls without all required args", + AnalysisException.class, + "Missing required parameters", + () -> sql("CALL %s.system.rollback_to_snapshot(table => 't')", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with invalid arg types", + AnalysisException.class, + "Wrong arg type for snapshot_id: cannot cast", + () -> sql("CALL %s.system.rollback_to_snapshot('t', 2.2)", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with empty table identifier", + IllegalArgumentException.class, + "Cannot handle an empty identifier", + () -> sql("CALL %s.system.rollback_to_snapshot('', 1L)", catalogName)); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToTimestampProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToTimestampProcedure.java new file mode 100644 index 000000000000..6da3853bbe24 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToTimestampProcedure.java @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import java.time.LocalDateTime; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.junit.After; +import org.junit.Assume; +import org.junit.Test; + +public class TestRollbackToTimestampProcedure extends SparkExtensionsTestBase { + + public TestRollbackToTimestampProcedure( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testRollbackToTimestampUsingPositionalArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + String firstSnapshotTimestamp = LocalDateTime.now().toString(); + + waitUntilAfter(firstSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = + sql( + "CALL %s.system.rollback_to_timestamp('%s',TIMESTAMP '%s')", + catalogName, tableIdent, firstSnapshotTimestamp); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testRollbackToTimestampUsingNamedArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + String firstSnapshotTimestamp = LocalDateTime.now().toString(); + + waitUntilAfter(firstSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = + sql( + "CALL %s.system.rollback_to_timestamp(timestamp => TIMESTAMP '%s', table => '%s')", + catalogName, firstSnapshotTimestamp, tableIdent); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testRollbackToTimestampRefreshesRelationCache() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = 
table.currentSnapshot(); + String firstSnapshotTimestamp = LocalDateTime.now().toString(); + + waitUntilAfter(firstSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + Dataset query = spark.sql("SELECT * FROM " + tableName + " WHERE id = 1"); + query.createOrReplaceTempView("tmp"); + + spark.sql("CACHE TABLE tmp"); + + assertEquals( + "View should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM tmp")); + + List output = + sql( + "CALL %s.system.rollback_to_timestamp(table => '%s', timestamp => TIMESTAMP '%s')", + catalogName, tableIdent, firstSnapshotTimestamp); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "View cache must be invalidated", ImmutableList.of(row(1L, "a")), sql("SELECT * FROM tmp")); + + sql("UNCACHE TABLE tmp"); + } + + @Test + public void testRollbackToTimestampWithQuotedIdentifiers() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + String firstSnapshotTimestamp = LocalDateTime.now().toString(); + + waitUntilAfter(firstSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + StringBuilder quotedNamespaceBuilder = new StringBuilder(); + for (String level : tableIdent.namespace().levels()) { + quotedNamespaceBuilder.append("`"); + quotedNamespaceBuilder.append(level); + quotedNamespaceBuilder.append("`"); + } + String quotedNamespace = quotedNamespaceBuilder.toString(); + + List output = + sql( + "CALL %s.system.rollback_to_timestamp('%s', TIMESTAMP '%s')", + catalogName, quotedNamespace + ".`" + tableIdent.name() + "`", firstSnapshotTimestamp); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testRollbackToTimestampWithoutExplicitCatalog() { + Assume.assumeTrue("Working only with the session catalog", "spark_catalog".equals(catalogName)); + + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + String firstSnapshotTimestamp = LocalDateTime.now().toString(); + + waitUntilAfter(firstSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + // use camel case intentionally to test case sensitivity + List output = + sql( + "CALL SyStEm.rOLlBaCk_to_TiMeStaMp('%s', TIMESTAMP '%s')", + tableIdent, firstSnapshotTimestamp); + + assertEquals( + 
"Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Rollback must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testInvalidRollbackToTimestampCases() { + String timestamp = "TIMESTAMP '2007-12-03T10:15:30'"; + + AssertHelpers.assertThrows( + "Should not allow mixed args", + AnalysisException.class, + "Named and positional arguments cannot be mixed", + () -> + sql( + "CALL %s.system.rollback_to_timestamp(namespace => 'n1', 't', %s)", + catalogName, timestamp)); + + AssertHelpers.assertThrows( + "Should not resolve procedures in arbitrary namespaces", + NoSuchProcedureException.class, + "not found", + () -> sql("CALL %s.custom.rollback_to_timestamp('n', 't', %s)", catalogName, timestamp)); + + AssertHelpers.assertThrows( + "Should reject calls without all required args", + AnalysisException.class, + "Missing required parameters", + () -> sql("CALL %s.system.rollback_to_timestamp('t')", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls without all required args", + AnalysisException.class, + "Missing required parameters", + () -> sql("CALL %s.system.rollback_to_timestamp(timestamp => %s)", catalogName, timestamp)); + + AssertHelpers.assertThrows( + "Should reject calls without all required args", + AnalysisException.class, + "Missing required parameters", + () -> sql("CALL %s.system.rollback_to_timestamp(table => 't')", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with extra args", + AnalysisException.class, + "Too many arguments", + () -> + sql("CALL %s.system.rollback_to_timestamp('n', 't', %s, 1L)", catalogName, timestamp)); + + AssertHelpers.assertThrows( + "Should reject calls with invalid arg types", + AnalysisException.class, + "Wrong arg type for timestamp: cannot cast", + () -> sql("CALL %s.system.rollback_to_timestamp('t', 2.2)", catalogName)); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java new file mode 100644 index 000000000000..8a8a974bbebe --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.TableProperties.WRITE_AUDIT_PUBLISH_ENABLED; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.junit.After; +import org.junit.Assume; +import org.junit.Test; + +public class TestSetCurrentSnapshotProcedure extends SparkExtensionsTestBase { + + public TestSetCurrentSnapshotProcedure( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testSetCurrentSnapshotUsingPositionalArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = + sql( + "CALL %s.system.set_current_snapshot('%s', %dL)", + catalogName, tableIdent, firstSnapshot.snapshotId()); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Set must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testSetCurrentSnapshotUsingNamedArgs() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = + sql( + "CALL %s.system.set_current_snapshot(snapshot_id => %dL, table => '%s')", + catalogName, firstSnapshot.snapshotId(), tableIdent); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Set must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testSetCurrentSnapshotWap() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED); + + spark.conf().set("spark.wap.id", "1"); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should not see rows from staged snapshot", + 
        ImmutableList.of(),
+        sql("SELECT * FROM %s", tableName));
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    Snapshot wapSnapshot = Iterables.getOnlyElement(table.snapshots());
+
+    List<Object[]> output =
+        sql(
+            "CALL %s.system.set_current_snapshot(table => '%s', snapshot_id => %dL)",
+            catalogName, tableIdent, wapSnapshot.snapshotId());
+
+    assertEquals(
+        "Procedure output must match",
+        ImmutableList.of(row(null, wapSnapshot.snapshotId())),
+        output);
+
+    assertEquals(
+        "Current snapshot must be set correctly",
+        ImmutableList.of(row(1L, "a")),
+        sql("SELECT * FROM %s", tableName));
+  }
+
+  @Test
+  public void testSetCurrentSnapshotWithoutExplicitCatalog() {
+    Assume.assumeTrue("Working only with the session catalog", "spark_catalog".equals(catalogName));
+
+    sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+    sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    Snapshot firstSnapshot = table.currentSnapshot();
+
+    sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+    assertEquals(
+        "Should have expected rows",
+        ImmutableList.of(row(1L, "a"), row(1L, "a")),
+        sql("SELECT * FROM %s ORDER BY id", tableName));
+
+    table.refresh();
+
+    Snapshot secondSnapshot = table.currentSnapshot();
+
+    // use camel case intentionally to test case sensitivity
+    List<Object[]> output =
+        sql("CALL SyStEm.sEt_cuRrEnT_sNaPsHot('%s', %dL)", tableIdent, firstSnapshot.snapshotId());
+
+    assertEquals(
+        "Procedure output must match",
+        ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())),
+        output);
+
+    assertEquals(
+        "Set must be successful",
+        ImmutableList.of(row(1L, "a")),
+        sql("SELECT * FROM %s ORDER BY id", tableName));
+  }
+
+  @Test
+  public void testSetCurrentSnapshotToInvalidSnapshot() {
+    sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+
+    Namespace namespace = tableIdent.namespace();
+    String tableName = tableIdent.name();
+
+    AssertHelpers.assertThrows(
+        "Should reject invalid snapshot id",
+        ValidationException.class,
+        "Cannot roll back to unknown snapshot id",
+        () -> sql("CALL %s.system.set_current_snapshot('%s', -1L)", catalogName, tableIdent));
+  }
+
+  @Test
+  public void testInvalidSetCurrentSnapshotCases() {
+    AssertHelpers.assertThrows(
+        "Should not allow mixed args",
+        AnalysisException.class,
+        "Named and positional arguments cannot be mixed",
+        () ->
+            sql(
+                "CALL %s.system.set_current_snapshot(namespace => 'n1', table => 't', 1L)",
+                catalogName));
+
+    AssertHelpers.assertThrows(
+        "Should not resolve procedures in arbitrary namespaces",
+        NoSuchProcedureException.class,
+        "not found",
+        () -> sql("CALL %s.custom.set_current_snapshot('n', 't', 1L)", catalogName));
+
+    AssertHelpers.assertThrows(
+        "Should reject calls without all required args",
+        AnalysisException.class,
+        "Missing required parameters",
+        () -> sql("CALL %s.system.set_current_snapshot('t')", catalogName));
+
+    AssertHelpers.assertThrows(
+        "Should reject calls without all required args",
+        AnalysisException.class,
+        "Missing required parameters",
+        () -> sql("CALL %s.system.set_current_snapshot(1L)", catalogName));
+
+    AssertHelpers.assertThrows(
+        "Should reject calls without all required args",
+        AnalysisException.class,
+        "Missing required parameters",
+        () -> sql("CALL %s.system.set_current_snapshot(snapshot_id => 1L)", catalogName));
+
+    AssertHelpers.assertThrows(
+        "Should reject calls without all required args",
+        AnalysisException.class,
+        "Missing 
required parameters", + () -> sql("CALL %s.system.set_current_snapshot(table => 't')", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with invalid arg types", + AnalysisException.class, + "Wrong arg type for snapshot_id: cannot cast", + () -> sql("CALL %s.system.set_current_snapshot('t', 2.2)", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with empty table identifier", + IllegalArgumentException.class, + "Cannot handle an empty identifier", + () -> sql("CALL %s.system.set_current_snapshot('', 1L)", catalogName)); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetWriteDistributionAndOrdering.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetWriteDistributionAndOrdering.java new file mode 100644 index 000000000000..e7e52806792d --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetWriteDistributionAndOrdering.java @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.expressions.Expressions.bucket; + +import java.util.Map; +import org.apache.iceberg.NullOrder; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestSetWriteDistributionAndOrdering extends SparkExtensionsTestBase { + public TestSetWriteDistributionAndOrdering( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testSetWriteOrderByColumn() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE ORDERED BY category, id", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "range", distributionMode); + + SortOrder expected = + SortOrder.builderFor(table.schema()) + .withOrderId(1) + .asc("category", NullOrder.NULLS_FIRST) + .asc("id", NullOrder.NULLS_FIRST) + .build(); + Assert.assertEquals("Should have expected order", expected, table.sortOrder()); + } + + @Test + public void testSetWriteOrderByColumnWithDirection() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE ORDERED BY category ASC, id DESC", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "range", distributionMode); + + SortOrder expected = + SortOrder.builderFor(table.schema()) + .withOrderId(1) + .asc("category", NullOrder.NULLS_FIRST) + .desc("id", NullOrder.NULLS_LAST) + .build(); + Assert.assertEquals("Should have expected order", expected, table.sortOrder()); + } + + @Test + public void testSetWriteOrderByColumnWithDirectionAndNullOrder() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE ORDERED BY category ASC NULLS LAST, id DESC NULLS FIRST", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "range", distributionMode); + + SortOrder expected = + SortOrder.builderFor(table.schema()) + .withOrderId(1) + .asc("category", NullOrder.NULLS_LAST) + .desc("id", NullOrder.NULLS_FIRST) + .build(); + Assert.assertEquals("Should have expected order", expected, table.sortOrder()); + } + + @Test + public void testSetWriteOrderByTransform() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + 
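+    // sanity check: a freshly created table has no sort order before WRITE ORDERED BY is applied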
Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE ORDERED BY category DESC, bucket(16, id), id", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "range", distributionMode); + + SortOrder expected = + SortOrder.builderFor(table.schema()) + .withOrderId(1) + .desc("category") + .asc(bucket("id", 16)) + .asc("id") + .build(); + Assert.assertEquals("Should have expected order", expected, table.sortOrder()); + } + + @Test + public void testSetWriteUnordered() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE ORDERED BY category DESC, bucket(16, id), id", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "range", distributionMode); + + Assert.assertNotEquals("Table must be sorted", SortOrder.unsorted(), table.sortOrder()); + + sql("ALTER TABLE %s WRITE UNORDERED", tableName); + + table.refresh(); + + String newDistributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("New distribution mode must match", "none", newDistributionMode); + + Assert.assertEquals("New sort order must match", SortOrder.unsorted(), table.sortOrder()); + } + + @Test + public void testSetWriteLocallyOrdered() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE LOCALLY ORDERED BY category DESC, bucket(16, id), id", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "none", distributionMode); + + SortOrder expected = + SortOrder.builderFor(table.schema()) + .withOrderId(1) + .desc("category") + .asc(bucket("id", 16)) + .asc("id") + .build(); + Assert.assertEquals("Sort order must match", expected, table.sortOrder()); + } + + @Test + public void testSetWriteDistributedByWithSort() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION ORDERED BY id", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "hash", distributionMode); + + SortOrder expected = SortOrder.builderFor(table.schema()).withOrderId(1).asc("id").build(); + Assert.assertEquals("Sort order must match", expected, table.sortOrder()); + } + + @Test + public void testSetWriteDistributedByWithLocalSort() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table 
should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION LOCALLY ORDERED BY id", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "hash", distributionMode); + + SortOrder expected = SortOrder.builderFor(table.schema()).withOrderId(1).asc("id").build(); + Assert.assertEquals("Sort order must match", expected, table.sortOrder()); + } + + @Test + public void testSetWriteDistributedByAndUnordered() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION UNORDERED", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "hash", distributionMode); + + Assert.assertEquals("Sort order must match", SortOrder.unsorted(), table.sortOrder()); + } + + @Test + public void testSetWriteDistributedByOnly() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION UNORDERED", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "hash", distributionMode); + + Assert.assertEquals("Sort order must match", SortOrder.unsorted(), table.sortOrder()); + } + + @Test + public void testSetWriteDistributedAndUnorderedInverted() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE UNORDERED DISTRIBUTED BY PARTITION", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "hash", distributionMode); + + Assert.assertEquals("Sort order must match", SortOrder.unsorted(), table.sortOrder()); + } + + @Test + public void testSetWriteDistributedAndLocallyOrderedInverted() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", + tableName); + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted()); + + sql("ALTER TABLE %s WRITE ORDERED BY id DISTRIBUTED BY PARTITION", tableName); + + table.refresh(); + + String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE); + Assert.assertEquals("Distribution mode must match", "hash", distributionMode); + + SortOrder expected = SortOrder.builderFor(table.schema()).withOrderId(1).asc("id").build(); + Assert.assertEquals("Sort order must match", expected, table.sortOrder()); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSnapshotTableProcedure.java 
b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSnapshotTableProcedure.java
new file mode 100644
index 000000000000..ed64ef331580
--- /dev/null
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSnapshotTableProcedure.java
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.extensions;
+
+import java.io.IOException;
+import java.util.Map;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.spark.sql.AnalysisException;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Assume;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+public class TestSnapshotTableProcedure extends SparkExtensionsTestBase {
+  private static final String sourceName = "spark_catalog.default.source";
+  // Currently we can only snapshot tables from the Spark session catalog
+
+  public TestSnapshotTableProcedure(
+      String catalogName, String implementation, Map<String, String> config) {
+    super(catalogName, implementation, config);
+  }
+
+  @Rule public TemporaryFolder temp = new TemporaryFolder();
+
+  @After
+  public void removeTables() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+    sql("DROP TABLE IF EXISTS %s", sourceName);
+  }
+
+  @Test
+  public void testSnapshot() throws IOException {
+    String location = temp.newFolder().toString();
+    sql(
+        "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'",
+        sourceName, location);
+    sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName);
+    Object result =
+        scalarSql("CALL %s.system.snapshot('%s', '%s')", catalogName, sourceName, tableName);
+
+    Assert.assertEquals("Should have added one file", 1L, result);
+
+    Table createdTable = validationCatalog.loadTable(tableIdent);
+    String tableLocation = createdTable.location();
+    Assert.assertNotEquals("Table should not have the original location", location, tableLocation);
+
+    sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+    assertEquals(
+        "Should have expected rows",
+        ImmutableList.of(row(1L, "a"), row(1L, "a")),
+        sql("SELECT * FROM %s ORDER BY id", tableName));
+  }
+
+  @Test
+  public void testSnapshotWithProperties() throws IOException {
+    String location = temp.newFolder().toString();
+    sql(
+        "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'",
+        sourceName, location);
+    sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName);
+    Object result =
+        scalarSql(
+            "CALL %s.system.snapshot(source_table => '%s', table => '%s', properties => map('foo','bar'))",
+            catalogName, sourceName,
tableName); + + Assert.assertEquals("Should have added one file", 1L, result); + + Table createdTable = validationCatalog.loadTable(tableIdent); + + String tableLocation = createdTable.location(); + Assert.assertNotEquals("Table should not have the original location", location, tableLocation); + + Map props = createdTable.properties(); + Assert.assertEquals("Should have extra property set", "bar", props.get("foo")); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testSnapshotWithAlternateLocation() throws IOException { + Assume.assumeTrue( + "No Snapshoting with Alternate locations with Hadoop Catalogs", + !catalogName.contains("hadoop")); + String location = temp.newFolder().toString(); + String snapshotLocation = temp.newFolder().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + sourceName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName); + Object[] result = + sql( + "CALL %s.system.snapshot(source_table => '%s', table => '%s', location => '%s')", + catalogName, sourceName, tableName, snapshotLocation) + .get(0); + + Assert.assertEquals("Should have added one file", 1L, result[0]); + + String storageLocation = validationCatalog.loadTable(tableIdent).location(); + Assert.assertEquals( + "Snapshot should be made at specified location", snapshotLocation, storageLocation); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testDropTable() throws IOException { + String location = temp.newFolder().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + sourceName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName); + + Object result = + scalarSql("CALL %s.system.snapshot('%s', '%s')", catalogName, sourceName, tableName); + Assert.assertEquals("Should have added one file", 1L, result); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", tableName)); + + sql("DROP TABLE %s", tableName); + + assertEquals( + "Source table should be intact", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", sourceName)); + } + + @Test + public void testSnapshotWithConflictingProps() throws IOException { + String location = temp.newFolder().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + sourceName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName); + + Object result = + scalarSql( + "CALL %s.system.snapshot(" + + "source_table => '%s'," + + "table => '%s'," + + "properties => map('%s', 'true', 'snapshot', 'false'))", + catalogName, sourceName, tableName, TableProperties.GC_ENABLED); + Assert.assertEquals("Should have added one file", 1L, result); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Map props = table.properties(); + Assert.assertEquals("Should override user value", "true", props.get("snapshot")); + Assert.assertEquals( + "Should override user value", "false", props.get(TableProperties.GC_ENABLED)); + } + + @Test + public void 
testInvalidSnapshotsCases() throws IOException { + String location = temp.newFolder().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + sourceName, location); + + AssertHelpers.assertThrows( + "Should reject calls without all required args", + AnalysisException.class, + "Missing required parameters", + () -> sql("CALL %s.system.snapshot('foo')", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with invalid arg types", + AnalysisException.class, + "Wrong arg type", + () -> sql("CALL %s.system.snapshot('n', 't', map('foo', 'bar'))", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with invalid map args", + AnalysisException.class, + "The `map` requires 2n (n > 0) parameters but the actual number is 3", + () -> + sql( + "CALL %s.system.snapshot('%s', 'fable', 'loc', map(2, 1, 1))", + catalogName, sourceName)); + + AssertHelpers.assertThrows( + "Should reject calls with empty table identifier", + IllegalArgumentException.class, + "Cannot handle an empty identifier", + () -> sql("CALL %s.system.snapshot('', 'dest')", catalogName)); + + AssertHelpers.assertThrows( + "Should reject calls with empty table identifier", + IllegalArgumentException.class, + "Cannot handle an empty identifier", + () -> sql("CALL %s.system.snapshot('src', '')", catalogName)); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestStoragePartitionedJoinsInRowLevelOperations.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestStoragePartitionedJoinsInRowLevelOperations.java new file mode 100644 index 000000000000..bc81c9ea336c --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestStoragePartitionedJoinsInRowLevelOperations.java @@ -0,0 +1,283 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE; +import static org.apache.iceberg.RowLevelOperationMode.MERGE_ON_READ; + +import java.util.Map; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.spark.sql.execution.SparkPlan; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runners.Parameterized; + +public class TestStoragePartitionedJoinsInRowLevelOperations extends SparkExtensionsTestBase { + + private static final String OTHER_TABLE_NAME = "other_table"; + + // open file cost and split size are set as 16 MB to produce a split per file + private static final Map COMMON_TABLE_PROPERTIES = + ImmutableMap.of( + TableProperties.FORMAT_VERSION, + "2", + TableProperties.SPLIT_SIZE, + "16777216", + TableProperties.SPLIT_OPEN_FILE_COST, + "16777216"); + + // only v2 bucketing and preserve data grouping properties have to be enabled to trigger SPJ + // other properties are only to simplify testing and validation + private static final Map ENABLED_SPJ_SQL_CONF = + ImmutableMap.of( + SQLConf.V2_BUCKETING_ENABLED().key(), + "true", + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION().key(), + "false", + SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), + "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD().key(), + "-1", + SparkSQLProperties.PRESERVE_DATA_GROUPING, + "true"); + + @Parameterized.Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + SparkCatalogConfig.HIVE.properties() + } + }; + } + + public TestStoragePartitionedJoinsInRowLevelOperations( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s", tableName(OTHER_TABLE_NAME)); + } + + @Test + public void testCopyOnWriteDeleteWithoutShuffles() { + checkDelete(COPY_ON_WRITE); + } + + @Test + public void testMergeOnReadDeleteWithoutShuffles() { + checkDelete(MERGE_ON_READ); + } + + private void checkDelete(RowLevelOperationMode mode) { + String createTableStmt = + "CREATE TABLE %s (id INT, salary INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep) " + + "TBLPROPERTIES (%s)"; + + sql(createTableStmt, tableName, tablePropsAsString(COMMON_TABLE_PROPERTIES)); + + append(tableName, "{ \"id\": 1, \"salary\": 100, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 2, \"salary\": 200, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 3, \"salary\": 300, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 4, \"salary\": 400, \"dep\": \"hardware\" }"); + + sql(createTableStmt, tableName(OTHER_TABLE_NAME), tablePropsAsString(COMMON_TABLE_PROPERTIES)); + + append(tableName(OTHER_TABLE_NAME), "{ \"id\": 1, \"salary\": 110, \"dep\": \"hr\" }"); + append(tableName(OTHER_TABLE_NAME), "{ \"id\": 5, \"salary\": 500, \"dep\": \"hr\" }"); + + Map deleteTableProps = + ImmutableMap.of( + TableProperties.DELETE_MODE, + 
mode.modeName(), + TableProperties.DELETE_DISTRIBUTION_MODE, + "none"); + + sql("ALTER TABLE %s SET TBLPROPERTIES(%s)", tableName, tablePropsAsString(deleteTableProps)); + + withSQLConf( + ENABLED_SPJ_SQL_CONF, + () -> { + SparkPlan plan = + executeAndKeepPlan( + "DELETE FROM %s t WHERE " + + "EXISTS (SELECT 1 FROM %s s WHERE t.id = s.id AND t.dep = s.dep) AND " + + "dep = 'hr'", + tableName, tableName(OTHER_TABLE_NAME)); + String planAsString = plan.toString(); + Assert.assertFalse("Should be no shuffles with SPJ", planAsString.contains("Exchange")); + }); + + ImmutableList expectedRows = + ImmutableList.of( + row(2, 200, "hr"), // remaining + row(3, 300, "hr"), // remaining + row(4, 400, "hardware")); // remaining + + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id, salary", tableName)); + } + + @Test + public void testCopyOnWriteUpdateWithoutShuffles() { + checkUpdate(COPY_ON_WRITE); + } + + @Test + public void testMergeOnReadUpdateWithoutShuffles() { + checkUpdate(MERGE_ON_READ); + } + + private void checkUpdate(RowLevelOperationMode mode) { + String createTableStmt = + "CREATE TABLE %s (id INT, salary INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep) " + + "TBLPROPERTIES (%s)"; + + sql(createTableStmt, tableName, tablePropsAsString(COMMON_TABLE_PROPERTIES)); + + append(tableName, "{ \"id\": 1, \"salary\": 100, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 2, \"salary\": 200, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 3, \"salary\": 300, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 4, \"salary\": 400, \"dep\": \"hardware\" }"); + + sql(createTableStmt, tableName(OTHER_TABLE_NAME), tablePropsAsString(COMMON_TABLE_PROPERTIES)); + + append(tableName(OTHER_TABLE_NAME), "{ \"id\": 1, \"salary\": 110, \"dep\": \"hr\" }"); + append(tableName(OTHER_TABLE_NAME), "{ \"id\": 5, \"salary\": 500, \"dep\": \"hr\" }"); + + Map updateTableProps = + ImmutableMap.of( + TableProperties.UPDATE_MODE, + mode.modeName(), + TableProperties.UPDATE_DISTRIBUTION_MODE, + "none"); + + sql("ALTER TABLE %s SET TBLPROPERTIES(%s)", tableName, tablePropsAsString(updateTableProps)); + + withSQLConf( + ENABLED_SPJ_SQL_CONF, + () -> { + SparkPlan plan = + executeAndKeepPlan( + "UPDATE %s t SET salary = -1 WHERE " + + "EXISTS (SELECT 1 FROM %s s WHERE t.id = s.id AND t.dep = s.dep) AND " + + "dep = 'hr'", + tableName, tableName(OTHER_TABLE_NAME)); + String planAsString = plan.toString(); + Assert.assertFalse("Should be no shuffles with SPJ", planAsString.contains("Exchange")); + }); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, -1, "hr"), // updated + row(2, 200, "hr"), // existing + row(3, 300, "hr"), // existing + row(4, 400, "hardware")); // existing + + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id, salary", tableName)); + } + + @Test + public void testCopyOnWriteMergeWithoutShuffles() { + checkMerge(COPY_ON_WRITE); + } + + @Test + public void testMergeOnReadMergeWithoutShuffles() { + checkMerge(MERGE_ON_READ); + } + + private void checkMerge(RowLevelOperationMode mode) { + String createTableStmt = + "CREATE TABLE %s (id INT, salary INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep) " + + "TBLPROPERTIES (%s)"; + + sql(createTableStmt, tableName, tablePropsAsString(COMMON_TABLE_PROPERTIES)); + + append(tableName, "{ \"id\": 1, \"salary\": 100, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 2, \"salary\": 200, \"dep\": \"hr\" }"); + append(tableName, 
"{ \"id\": 3, \"salary\": 300, \"dep\": \"hr\" }"); + append(tableName, "{ \"id\": 4, \"salary\": 400, \"dep\": \"hardware\" }"); + + sql(createTableStmt, tableName(OTHER_TABLE_NAME), tablePropsAsString(COMMON_TABLE_PROPERTIES)); + + append(tableName(OTHER_TABLE_NAME), "{ \"id\": 1, \"salary\": 110, \"dep\": \"hr\" }"); + append(tableName(OTHER_TABLE_NAME), "{ \"id\": 5, \"salary\": 500, \"dep\": \"hr\" }"); + + Map mergeTableProps = + ImmutableMap.of( + TableProperties.MERGE_MODE, + mode.modeName(), + TableProperties.MERGE_DISTRIBUTION_MODE, + "none"); + + sql("ALTER TABLE %s SET TBLPROPERTIES(%s)", tableName, tablePropsAsString(mergeTableProps)); + + withSQLConf( + ENABLED_SPJ_SQL_CONF, + () -> { + SparkPlan plan = + executeAndKeepPlan( + "MERGE INTO %s AS t USING %s AS s " + + "ON t.id = s.id AND t.dep = s.dep AND t.dep = 'hr'" + + "WHEN MATCHED THEN " + + " UPDATE SET t.salary = s.salary " + + "WHEN NOT MATCHED THEN " + + " INSERT *", + tableName, tableName(OTHER_TABLE_NAME)); + String planAsString = plan.toString(); + Assert.assertFalse("Should be no shuffles with SPJ", planAsString.contains("Exchange")); + }); + + ImmutableList expectedRows = + ImmutableList.of( + row(1, 110, "hr"), // updated + row(2, 200, "hr"), // existing + row(3, 300, "hr"), // existing + row(4, 400, "hardware"), // existing + row(5, 500, "hr")); // new + + assertEquals( + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id, salary", tableName)); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestTagDDL.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestTagDDL.java new file mode 100644 index 000000000000..25efaaf766ea --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestTagDDL.java @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.extensions.IcebergParseException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runners.Parameterized; + +public class TestTagDDL extends SparkExtensionsTestBase { + private static final String[] TIME_UNITS = {"DAYS", "HOURS", "MINUTES"}; + + @Parameterized.Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties() + } + }; + } + + public TestTagDDL(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Before + public void before() { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testCreateTagWithRetain() throws NoSuchTableException { + Table table = insertRows(); + long firstSnapshotId = table.currentSnapshot().snapshotId(); + long maxRefAge = 10L; + + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + + for (String timeUnit : TIME_UNITS) { + String tagName = "t1" + timeUnit; + sql( + "ALTER TABLE %s CREATE TAG %s AS OF VERSION %d RETAIN %d %s", + tableName, tagName, firstSnapshotId, maxRefAge, timeUnit); + table.refresh(); + SnapshotRef ref = table.refs().get(tagName); + Assert.assertEquals( + "The tag needs to point to a specific snapshot id.", firstSnapshotId, ref.snapshotId()); + Assert.assertEquals( + "The tag needs to have the correct max ref age.", + TimeUnit.valueOf(timeUnit.toUpperCase(Locale.ENGLISH)).toMillis(maxRefAge), + ref.maxRefAgeMs().longValue()); + } + + String tagName = "t1"; + AssertHelpers.assertThrows( + "Illegal statement", + IcebergParseException.class, + "mismatched input", + () -> + sql( + "ALTER TABLE %s CREATE TAG %s AS OF VERSION %d RETAIN", + tableName, tagName, firstSnapshotId, maxRefAge)); + + AssertHelpers.assertThrows( + "Illegal statement", + IcebergParseException.class, + "mismatched input", + () -> sql("ALTER TABLE %s CREATE TAG %s RETAIN %s DAYS", tableName, tagName, "abc")); + + AssertHelpers.assertThrows( + "Illegal statement", + IcebergParseException.class, + "mismatched input 'SECONDS' expecting {'DAYS', 'HOURS', 'MINUTES'}", + () -> + sql( + "ALTER TABLE %s CREATE TAG %s AS OF VERSION %d RETAIN %d SECONDS", + tableName, tagName, firstSnapshotId, maxRefAge)); + } + + @Test + public void testCreateTagUseDefaultConfig() throws NoSuchTableException { + Table table = insertRows(); + long snapshotId = 
table.currentSnapshot().snapshotId(); + String tagName = "t1"; + + AssertHelpers.assertThrows( + "unknown snapshot", + ValidationException.class, + "unknown snapshot: -1", + () -> sql("ALTER TABLE %s CREATE TAG %s AS OF VERSION %d", tableName, tagName, -1)); + + sql("ALTER TABLE %s CREATE TAG %s", tableName, tagName); + table.refresh(); + SnapshotRef ref = table.refs().get(tagName); + Assert.assertEquals( + "The tag needs to point to a specific snapshot id.", snapshotId, ref.snapshotId()); + Assert.assertNull( + "The tag needs to have the default max ref age, which is null.", ref.maxRefAgeMs()); + + AssertHelpers.assertThrows( + "Cannot create an exist tag", + IllegalArgumentException.class, + "already exists", + () -> sql("ALTER TABLE %s CREATE TAG %s", tableName, tagName)); + + AssertHelpers.assertThrows( + "Non-conforming tag name", + IcebergParseException.class, + "mismatched input '123'", + () -> sql("ALTER TABLE %s CREATE TAG %s", tableName, "123")); + + table.manageSnapshots().removeTag(tagName).commit(); + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + snapshotId = table.currentSnapshot().snapshotId(); + sql("ALTER TABLE %s CREATE TAG %s AS OF VERSION %d", tableName, tagName, snapshotId); + table.refresh(); + ref = table.refs().get(tagName); + Assert.assertEquals( + "The tag needs to point to a specific snapshot id.", snapshotId, ref.snapshotId()); + Assert.assertNull( + "The tag needs to have the default max ref age, which is null.", ref.maxRefAgeMs()); + } + + @Test + public void testCreateTagIfNotExists() throws NoSuchTableException { + long maxSnapshotAge = 2L; + Table table = insertRows(); + String tagName = "t1"; + sql("ALTER TABLE %s CREATE TAG %s RETAIN %d days", tableName, tagName, maxSnapshotAge); + sql("ALTER TABLE %s CREATE TAG IF NOT EXISTS %s", tableName, tagName); + + table.refresh(); + SnapshotRef ref = table.refs().get(tagName); + Assert.assertEquals( + "The tag needs to point to a specific snapshot id.", + table.currentSnapshot().snapshotId(), + ref.snapshotId()); + Assert.assertEquals( + "The tag needs to have the correct max ref age.", + TimeUnit.DAYS.toMillis(maxSnapshotAge), + ref.maxRefAgeMs().longValue()); + } + + @Test + public void testReplaceTagFailsForBranch() throws NoSuchTableException { + String branchName = "branch1"; + Table table = insertRows(); + long first = table.currentSnapshot().snapshotId(); + table.manageSnapshots().createBranch(branchName, first).commit(); + insertRows(); + long second = table.currentSnapshot().snapshotId(); + + AssertHelpers.assertThrows( + "Cannot perform replace tag on branches", + IllegalArgumentException.class, + "Ref branch1 is a branch not a tag", + () -> sql("ALTER TABLE %s REPLACE Tag %s", tableName, branchName, second)); + } + + @Test + public void testReplaceTag() throws NoSuchTableException { + Table table = insertRows(); + long first = table.currentSnapshot().snapshotId(); + String tagName = "t1"; + long expectedMaxRefAgeMs = 1000; + table + .manageSnapshots() + .createTag(tagName, first) + .setMaxRefAgeMs(tagName, expectedMaxRefAgeMs) + .commit(); + + insertRows(); + long second = table.currentSnapshot().snapshotId(); + + sql("ALTER TABLE %s REPLACE Tag %s AS OF VERSION %d", tableName, tagName, second); + table.refresh(); + SnapshotRef ref = table.refs().get(tagName); + Assert.assertEquals( + "The tag needs to point to a specific snapshot id.", second, 
ref.snapshotId()); + Assert.assertEquals( + "The tag needs to have the correct max ref age.", + expectedMaxRefAgeMs, + ref.maxRefAgeMs().longValue()); + } + + @Test + public void testReplaceTagDoesNotExist() throws NoSuchTableException { + Table table = insertRows(); + + AssertHelpers.assertThrows( + "Cannot perform replace tag on tag which does not exist", + IllegalArgumentException.class, + "Tag does not exist", + () -> + sql( + "ALTER TABLE %s REPLACE Tag %s AS OF VERSION %d", + tableName, "someTag", table.currentSnapshot().snapshotId())); + } + + @Test + public void testReplaceTagWithRetain() throws NoSuchTableException { + Table table = insertRows(); + long first = table.currentSnapshot().snapshotId(); + String tagName = "t1"; + table.manageSnapshots().createTag(tagName, first).commit(); + insertRows(); + long second = table.currentSnapshot().snapshotId(); + + long maxRefAge = 10; + for (String timeUnit : TIME_UNITS) { + sql( + "ALTER TABLE %s REPLACE Tag %s AS OF VERSION %d RETAIN %d %s", + tableName, tagName, second, maxRefAge, timeUnit); + + table.refresh(); + SnapshotRef ref = table.refs().get(tagName); + Assert.assertEquals( + "The tag needs to point to a specific snapshot id.", second, ref.snapshotId()); + Assert.assertEquals( + "The tag needs to have the correct max ref age.", + TimeUnit.valueOf(timeUnit).toMillis(maxRefAge), + ref.maxRefAgeMs().longValue()); + } + } + + @Test + public void testCreateOrReplace() throws NoSuchTableException { + Table table = insertRows(); + long first = table.currentSnapshot().snapshotId(); + String tagName = "t1"; + insertRows(); + long second = table.currentSnapshot().snapshotId(); + table.manageSnapshots().createTag(tagName, second).commit(); + + sql("ALTER TABLE %s CREATE OR REPLACE TAG %s AS OF VERSION %d", tableName, tagName, first); + table.refresh(); + SnapshotRef ref = table.refs().get(tagName); + Assert.assertEquals( + "The tag needs to point to a specific snapshot id.", first, ref.snapshotId()); + } + + @Test + public void testDropTag() throws NoSuchTableException { + insertRows(); + Table table = validationCatalog.loadTable(tableIdent); + String tagName = "t1"; + table.manageSnapshots().createTag(tagName, table.currentSnapshot().snapshotId()).commit(); + SnapshotRef ref = table.refs().get(tagName); + Assert.assertEquals( + "The tag needs to point to a specific snapshot id.", + table.currentSnapshot().snapshotId(), + ref.snapshotId()); + + sql("ALTER TABLE %s DROP TAG %s", tableName, tagName); + table.refresh(); + ref = table.refs().get(tagName); + Assert.assertNull("The tag needs to be dropped.", ref); + } + + @Test + public void testDropTagNonConformingName() { + AssertHelpers.assertThrows( + "Non-conforming tag name", + IcebergParseException.class, + "mismatched input '123'", + () -> sql("ALTER TABLE %s DROP TAG %s", tableName, "123")); + } + + @Test + public void testDropTagDoesNotExist() { + AssertHelpers.assertThrows( + "Cannot perform drop tag on tag which does not exist", + IllegalArgumentException.class, + "Tag does not exist: nonExistingTag", + () -> sql("ALTER TABLE %s DROP TAG %s", tableName, "nonExistingTag")); + } + + @Test + public void testDropTagFailesForBranch() throws NoSuchTableException { + String branchName = "b1"; + Table table = insertRows(); + table.manageSnapshots().createBranch(branchName, table.currentSnapshot().snapshotId()).commit(); + + AssertHelpers.assertThrows( + "Cannot perform drop tag on branch", + IllegalArgumentException.class, + "Ref b1 is a branch not a tag", + () -> sql("ALTER TABLE %s DROP 
TAG %s", tableName, branchName)); + } + + @Test + public void testDropTagIfExists() throws NoSuchTableException { + String tagName = "nonExistingTag"; + Table table = insertRows(); + Assert.assertNull("The tag does not exists.", table.refs().get(tagName)); + + sql("ALTER TABLE %s DROP TAG IF EXISTS %s", tableName, tagName); + table.refresh(); + Assert.assertNull("The tag still does not exist.", table.refs().get(tagName)); + + table.manageSnapshots().createTag(tagName, table.currentSnapshot().snapshotId()).commit(); + Assert.assertEquals( + "The tag has been created successfully.", + table.currentSnapshot().snapshotId(), + table.refs().get(tagName).snapshotId()); + + sql("ALTER TABLE %s DROP TAG IF EXISTS %s", tableName, tagName); + table.refresh(); + Assert.assertNull("The tag needs to be dropped.", table.refs().get(tagName)); + } + + private Table insertRows() throws NoSuchTableException { + List records = + ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.writeTo(tableName).append(); + return validationCatalog.loadTable(tableIdent); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java new file mode 100644 index 000000000000..776fbb960055 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java @@ -0,0 +1,1330 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.DataOperations.OVERWRITE; +import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE; +import static org.apache.iceberg.SnapshotSummary.ADDED_FILES_PROP; +import static org.apache.iceberg.SnapshotSummary.CHANGED_PARTITION_COUNT_PROP; +import static org.apache.iceberg.SnapshotSummary.DELETED_FILES_PROP; +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; +import static org.apache.iceberg.TableProperties.SPLIT_SIZE; +import static org.apache.iceberg.TableProperties.UPDATE_ISOLATION_LEVEL; +import static org.apache.iceberg.TableProperties.UPDATE_MODE; +import static org.apache.iceberg.TableProperties.UPDATE_MODE_DEFAULT; +import static org.apache.spark.sql.functions.lit; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.internal.SQLConf; +import org.assertj.core.api.Assertions; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Test; + +public abstract class TestUpdate extends SparkRowLevelOperationsTestBase { + + public TestUpdate( + String catalogName, + String implementation, + Map config, + String fileFormat, + boolean vectorized, + String distributionMode, + String branch) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode, branch); + } + + @BeforeClass + public static void setupSparkConf() { + spark.conf().set("spark.sql.shuffle.partitions", "4"); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS updated_id"); + sql("DROP TABLE IF EXISTS updated_dep"); + sql("DROP TABLE IF EXISTS deleted_employee"); + } + + @Test + public void testExplain() { + createAndInitTable("id INT, dep STRING"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + 
createBranchIfNeeded(); + + sql("EXPLAIN UPDATE %s SET dep = 'invalid' WHERE id <=> 1", commitTarget()); + + sql("EXPLAIN UPDATE %s SET dep = 'invalid' WHERE true", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 1 snapshot", 1, Iterables.size(table.snapshots())); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @Test + public void testUpdateEmptyTable() { + Assume.assumeFalse("Custom branch does not exist for empty table", "test".equals(branch)); + createAndInitTable("id INT, dep STRING"); + + sql("UPDATE %s SET dep = 'invalid' WHERE id IN (1)", commitTarget()); + sql("UPDATE %s SET id = -1 WHERE dep = 'hr'", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); + + assertEquals( + "Should have expected rows", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testUpdateNonExistingCustomBranch() { + Assume.assumeTrue("Test only applicable to custom branch", "test".equals(branch)); + createAndInitTable("id INT, dep STRING"); + + Assertions.assertThatThrownBy( + () -> sql("UPDATE %s SET dep = 'invalid' WHERE id IN (1)", commitTarget())) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot use branch (does not exist): test"); + } + + @Test + public void testUpdateWithAlias() { + createAndInitTable("id INT, dep STRING", "{ \"id\": 1, \"dep\": \"a\" }"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql("UPDATE %s AS t SET t.dep = 'invalid'", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "invalid")), + sql("SELECT * FROM %s", selectTarget())); + } + + @Test + public void testUpdateAlignsAssignments() { + createAndInitTable("id INT, c1 INT, c2 INT"); + + sql("INSERT INTO TABLE %s VALUES (1, 11, 111), (2, 22, 222)", tableName); + createBranchIfNeeded(); + + sql("UPDATE %s SET `c2` = c2 - 2, c1 = `c1` - 1 WHERE id <=> 1", commitTarget()); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, 10, 109), row(2, 22, 222)), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testUpdateWithUnsupportedPartitionPredicate() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'software'), (2, 'hr')", tableName); + createBranchIfNeeded(); + + sql("UPDATE %s t SET `t`.`id` = -1 WHERE t.dep LIKE '%%r' ", commitTarget()); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(1, "software")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testUpdateWithDynamicFileFiltering() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 3, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + append( + commitTarget(), + "{ \"id\": 1, \"dep\": \"hardware\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }"); + + sql("UPDATE %s SET id = cast('-1' AS INT) WHERE id = 2", commitTarget()); + + Table table = 
validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "1", "1", "1"); + } else { + validateMergeOnRead(currentSnapshot, "1", "1", "1"); + } + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(1, "hardware"), row(1, "hr"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", commitTarget())); + } + + @Test + public void testUpdateNonExistingRecords() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); + + sql("UPDATE %s SET id = -1 WHERE id > 10", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "0", null, null); + } else { + validateMergeOnRead(currentSnapshot, "0", null, null); + } + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @Test + public void testUpdateWithoutCondition() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + sql("INSERT INTO TABLE %s VALUES (2, 'hardware')", commitTarget()); + sql("INSERT INTO TABLE %s VALUES (null, 'hr')", commitTarget()); + + // set the num of shuffle partitions to 200 instead of default 4 to reduce the chance of hashing + // records for multiple source files to one writing task (needed for a predictable num of output + // files) + withSQLConf( + ImmutableMap.of(SQLConf.SHUFFLE_PARTITIONS().key(), "200"), + () -> { + sql("UPDATE %s SET id = -1", commitTarget()); + }); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 4 snapshots", 4, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + + Assert.assertEquals("Operation must match", OVERWRITE, currentSnapshot.operation()); + if (mode(table) == COPY_ON_WRITE) { + Assert.assertEquals("Operation must match", OVERWRITE, currentSnapshot.operation()); + validateProperty(currentSnapshot, CHANGED_PARTITION_COUNT_PROP, "2"); + validateProperty(currentSnapshot, DELETED_FILES_PROP, "3"); + validateProperty(currentSnapshot, ADDED_FILES_PROP, ImmutableSet.of("2", "3")); + } else { + validateMergeOnRead(currentSnapshot, "2", "2", "2"); + } + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(-1, "hr")), + sql("SELECT * FROM %s ORDER BY dep ASC", selectTarget())); + } + + @Test + public void testUpdateWithNullConditions() { + createAndInitTable("id INT, dep STRING"); + + append( + tableName, + "{ \"id\": 0, \"dep\": null }\n" + + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }"); + createBranchIfNeeded(); + + // should not update any rows as null is never equal to null + 
sql("UPDATE %s SET id = -1 WHERE dep = NULL", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(0, null), row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + // should not update any rows the condition does not match any records + sql("UPDATE %s SET id = -1 WHERE dep = 'software'", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(0, null), row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + // should update one matching row with a null-safe condition + sql("UPDATE %s SET dep = 'invalid', id = -1 WHERE dep <=> NULL", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "invalid"), row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testUpdateWithInAndNotInConditions() { + createAndInitTable("id INT, dep STRING"); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + sql("UPDATE %s SET id = -1 WHERE id IN (1, null)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql("UPDATE %s SET id = 100 WHERE id NOT IN (null, 1)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql("UPDATE %s SET id = 100 WHERE id NOT IN (1, 10)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(100, "hardware"), row(100, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", selectTarget())); + } + + @Test + public void testUpdateWithMultipleRowGroupsParquet() throws NoSuchTableException { + Assume.assumeTrue(fileFormat.equalsIgnoreCase("parquet")); + + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", + tableName, PARQUET_ROW_GROUP_SIZE_BYTES, 100); + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", tableName, SPLIT_SIZE, 100); + + List ids = Lists.newArrayListWithCapacity(200); + for (int id = 1; id <= 200; id++) { + ids.add(id); + } + Dataset df = + spark + .createDataset(ids, Encoders.INT()) + .withColumnRenamed("value", "id") + .withColumn("dep", lit("hr")); + df.coalesce(1).writeTo(tableName).append(); + createBranchIfNeeded(); + + Assert.assertEquals(200, spark.table(commitTarget()).count()); + + // update a record from one of two row groups and copy over the second one + sql("UPDATE %s SET id = -1 WHERE id IN (200, 201)", commitTarget()); + + Assert.assertEquals(200, spark.table(commitTarget()).count()); + } + + @Test + public void testUpdateNestedStructFields() { + createAndInitTable( + "id INT, s STRUCT,m:MAP>>", + "{ \"id\": 1, \"s\": { \"c1\": 2, \"c2\": { \"a\": [1,2], \"m\": { \"a\": \"b\"} } } } }"); + + // update primitive, array, map columns inside a struct + sql("UPDATE %s SET s.c1 = -1, s.c2.m = map('k', 'v'), s.c2.a = array(-1)", commitTarget()); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, row(-1, row(ImmutableList.of(-1), ImmutableMap.of("k", "v"))))), + sql("SELECT * FROM %s", selectTarget())); 
+ + // set primitive, array, map columns to NULL (proper casts should be in place) + sql("UPDATE %s SET s.c1 = NULL, s.c2 = NULL WHERE id IN (1)", commitTarget()); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, row(null, null))), + sql("SELECT * FROM %s", selectTarget())); + + // update all fields in a struct + sql( + "UPDATE %s SET s = named_struct('c1', 1, 'c2', named_struct('a', array(1), 'm', null))", + commitTarget()); + + assertEquals( + "Output should match", + ImmutableList.of(row(1, row(1, row(ImmutableList.of(1), null)))), + sql("SELECT * FROM %s", selectTarget())); + } + + @Test + public void testUpdateWithUserDefinedDistribution() { + createAndInitTable("id INT, c2 INT, c3 INT"); + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(8, c3)", tableName); + + append( + tableName, + "{ \"id\": 1, \"c2\": 11, \"c3\": 1 }\n" + + "{ \"id\": 2, \"c2\": 22, \"c3\": 1 }\n" + + "{ \"id\": 3, \"c2\": 33, \"c3\": 1 }"); + createBranchIfNeeded(); + + // request a global sort + sql("ALTER TABLE %s WRITE ORDERED BY c2", tableName); + sql("UPDATE %s SET c2 = -22 WHERE id NOT IN (1, 3)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, 11, 1), row(2, -22, 1), row(3, 33, 1)), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + // request a local sort + sql("ALTER TABLE %s WRITE LOCALLY ORDERED BY id", tableName); + sql("UPDATE %s SET c2 = -33 WHERE id = 3", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, 11, 1), row(2, -22, 1), row(3, -33, 1)), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + // request a hash distribution + local sort + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION ORDERED BY id", tableName); + sql("UPDATE %s SET c2 = -11 WHERE id = 1", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, -11, 1), row(2, -22, 1), row(3, -33, 1)), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @Test + public synchronized void testUpdateWithSerializableIsolation() throws InterruptedException { + // cannot run tests with concurrency for Hadoop tables without atomic renames + Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop")); + + createAndInitTable("id INT, dep STRING"); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, UPDATE_ISOLATION_LEVEL, "serializable"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + + ExecutorService executorService = + MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + AtomicBoolean shouldAppend = new AtomicBoolean(true); + + // update thread + Future updateFuture = + executorService.submit( + () -> { + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + + sql("UPDATE %s SET id = -1 WHERE id = 1", commitTarget()); + + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = + executorService.submit( + () -> { + // load the table via the validation catalog to use another table instance + Table table = validationCatalog.loadTable(tableIdent); + + GenericRecord record = GenericRecord.create(SnapshotUtil.schemaFor(table, branch)); + record.set(0, 1); // id + record.set(1, "hr"); // dep + + for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) { 
+ while (shouldAppend.get() && barrier.get() < numOperations * 2) { + sleep(10); + } + + if (!shouldAppend.get()) { + return; + } + + for (int numAppends = 0; numAppends < 5; numAppends++) { + DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + + appendFiles.commit(); + sleep(10); + } + + barrier.incrementAndGet(); + } + }); + + try { + Assertions.assertThatThrownBy(updateFuture::get) + .isInstanceOf(ExecutionException.class) + .cause() + .isInstanceOf(ValidationException.class) + .hasMessageContaining("Found conflicting files that can contain"); + } finally { + shouldAppend.set(false); + appendFuture.cancel(true); + } + + executorService.shutdown(); + Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES)); + } + + @Test + public synchronized void testUpdateWithSnapshotIsolation() + throws InterruptedException, ExecutionException { + // cannot run tests with concurrency for Hadoop tables without atomic renames + Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop")); + + createAndInitTable("id INT, dep STRING"); + + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, UPDATE_ISOLATION_LEVEL, "snapshot"); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); + + ExecutorService executorService = + MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(2)); + + AtomicInteger barrier = new AtomicInteger(0); + AtomicBoolean shouldAppend = new AtomicBoolean(true); + + // update thread + Future updateFuture = + executorService.submit( + () -> { + for (int numOperations = 0; numOperations < 20; numOperations++) { + while (barrier.get() < numOperations * 2) { + sleep(10); + } + + sql("UPDATE %s SET id = -1 WHERE id = 1", tableName); + + barrier.incrementAndGet(); + } + }); + + // append thread + Future appendFuture = + executorService.submit( + () -> { + // load the table via the validation catalog to use another table instance for inserts + Table table = validationCatalog.loadTable(tableIdent); + + GenericRecord record = GenericRecord.create(SnapshotUtil.schemaFor(table, branch)); + record.set(0, 1); // id + record.set(1, "hr"); // dep + + for (int numOperations = 0; numOperations < 20; numOperations++) { + while (shouldAppend.get() && barrier.get() < numOperations * 2) { + sleep(10); + } + + if (!shouldAppend.get()) { + return; + } + + for (int numAppends = 0; numAppends < 5; numAppends++) { + DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + + appendFiles.commit(); + sleep(10); + } + + barrier.incrementAndGet(); + } + }); + + try { + updateFuture.get(); + } finally { + shouldAppend.set(false); + appendFuture.cancel(true); + } + + executorService.shutdown(); + Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES)); + } + + @Test + public void testUpdateWithInferredCasts() { + createAndInitTable("id INT, s STRING", "{ \"id\": 1, \"s\": \"value\" }"); + + sql("UPDATE %s SET s = -1 WHERE id = 1", commitTarget()); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "-1")), + sql("SELECT * FROM %s", selectTarget())); + } + + @Test + public void testUpdateModifiesNullStruct() { + createAndInitTable("id INT, s STRUCT", "{ \"id\": 1, 
\"s\": null }"); + + sql("UPDATE %s SET s.n1 = -1 WHERE id = 1", commitTarget()); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, row(-1, null))), + sql("SELECT * FROM %s", selectTarget())); + } + + @Test + public void testUpdateRefreshesRelationCache() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 3, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + append( + commitTarget(), + "{ \"id\": 1, \"dep\": \"hardware\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }"); + + Dataset query = spark.sql("SELECT * FROM " + commitTarget() + " WHERE id = 1"); + query.createOrReplaceTempView("tmp"); + + spark.sql("CACHE TABLE tmp"); + + assertEquals( + "View should have correct data", + ImmutableList.of(row(1, "hardware"), row(1, "hr")), + sql("SELECT * FROM tmp ORDER BY id, dep")); + + sql("UPDATE %s SET id = -1 WHERE id = 1", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "2", "2", "2"); + } else { + validateMergeOnRead(currentSnapshot, "2", "2", "2"); + } + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(2, "hardware"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", commitTarget())); + + assertEquals( + "Should refresh the relation cache", + ImmutableList.of(), + sql("SELECT * FROM tmp ORDER BY id, dep")); + + spark.sql("UNCACHE TABLE tmp"); + } + + @Test + public void testUpdateWithInSubquery() { + createAndInitTable("id INT, dep STRING"); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + createOrReplaceView("updated_id", Arrays.asList(0, 1, null), Encoders.INT()); + createOrReplaceView("updated_dep", Arrays.asList("software", "hr"), Encoders.STRING()); + + sql( + "UPDATE %s SET id = -1 WHERE " + + "id IN (SELECT * FROM updated_id) AND " + + "dep IN (SELECT * from updated_dep)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "UPDATE %s SET id = 5 WHERE id IS NULL OR id IN (SELECT value + 1 FROM updated_id)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(5, "hardware"), row(5, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + + append( + commitTarget(), "{ \"id\": null, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"hr\" }"); + assertEquals( + "Should have expected rows", + ImmutableList.of( + row(-1, "hr"), row(2, "hr"), row(5, "hardware"), row(5, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", selectTarget())); + + sql( + "UPDATE %s SET id = 10 WHERE id IN (SELECT value + 2 FROM updated_id) AND dep = 'hr'", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of( + row(-1, "hr"), row(5, "hardware"), row(5, "hr"), row(10, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", selectTarget())); + } + + @Test + public void 
testUpdateWithInSubqueryAndDynamicFileFiltering() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION", tableName); + + append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 3, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + append( + commitTarget(), + "{ \"id\": 1, \"dep\": \"hardware\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }"); + + createOrReplaceView("updated_id", Arrays.asList(-1, 2), Encoders.INT()); + + sql("UPDATE %s SET id = -1 WHERE id IN (SELECT * FROM updated_id)", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); + + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); + if (mode(table) == COPY_ON_WRITE) { + validateCopyOnWrite(currentSnapshot, "1", "1", "1"); + } else { + validateMergeOnRead(currentSnapshot, "1", "1", "1"); + } + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(1, "hardware"), row(1, "hr"), row(3, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", commitTarget())); + } + + @Test + public void testUpdateWithSelfSubquery() { + createAndInitTable("id INT, dep STRING"); + + append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + sql( + "UPDATE %s SET dep = 'x' WHERE id IN (SELECT id + 1 FROM %s)", + commitTarget(), commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "x")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + + // TODO: Spark does not support AQE and DPP with aggregates at the moment + withSQLConf( + ImmutableMap.of(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "false"), + () -> { + sql( + "UPDATE %s SET dep = 'y' WHERE " + + "id = (SELECT count(*) FROM (SELECT DISTINCT id FROM %s) AS t)", + commitTarget(), commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "y")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + }); + + sql("UPDATE %s SET id = (SELECT id - 2 FROM %s WHERE id = 1)", commitTarget(), commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(-1, "y")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + } + + @Test + public void testUpdateWithMultiColumnInSubquery() { + createAndInitTable("id INT, dep STRING"); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + List deletedEmployees = + Arrays.asList(new Employee(null, "hr"), new Employee(1, "hr")); + createOrReplaceView("deleted_employee", deletedEmployees, Encoders.bean(Employee.class)); + + sql( + "UPDATE %s SET dep = 'x', id = -1 WHERE (id, dep) IN (SELECT id, dep FROM deleted_employee)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "x"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + } + + @Test + public void testUpdateWithNotInSubquery() { + createAndInitTable("id INT, dep STRING"); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + createOrReplaceView("updated_id", Arrays.asList(-1, -2, 
null), Encoders.INT()); + createOrReplaceView("updated_dep", Arrays.asList("software", "hr"), Encoders.STRING()); + + // the file filter subquery (nested loop lef-anti join) returns 0 records + sql("UPDATE %s SET id = -1 WHERE id NOT IN (SELECT * FROM updated_id)", commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "UPDATE %s SET id = -1 WHERE id NOT IN (SELECT * FROM updated_id WHERE value IS NOT NULL)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", selectTarget())); + + sql( + "UPDATE %s SET id = 5 WHERE id NOT IN (SELECT * FROM updated_id) OR dep IN ('software', 'hr')", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(5, "hr"), row(5, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", selectTarget())); + } + + @Test + public void testUpdateWithExistSubquery() { + createAndInitTable("id INT, dep STRING"); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + createOrReplaceView("updated_id", Arrays.asList(-1, -2, null), Encoders.INT()); + createOrReplaceView("updated_dep", Arrays.asList("hr", null), Encoders.STRING()); + + sql( + "UPDATE %s t SET id = -1 WHERE EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "UPDATE %s t SET dep = 'x', id = -1 WHERE " + + "EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value + 2)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "x"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + + sql( + "UPDATE %s t SET id = -2 WHERE " + + "EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value) OR " + + "t.id IS NULL", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-2, "hr"), row(-2, "x"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + + sql( + "UPDATE %s t SET id = 1 WHERE " + + "EXISTS (SELECT 1 FROM updated_id ui WHERE t.id = ui.value) AND " + + "EXISTS (SELECT 1 FROM updated_dep ud WHERE t.dep = ud.value)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-2, "x"), row(1, "hr"), row(2, "hardware")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + } + + @Test + public void testUpdateWithNotExistsSubquery() { + createAndInitTable("id INT, dep STRING"); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + createOrReplaceView("updated_id", Arrays.asList(-1, -2, null), Encoders.INT()); + createOrReplaceView("updated_dep", Arrays.asList("hr", "software"), Encoders.STRING()); + + sql( + "UPDATE %s t SET id = -1 WHERE NOT EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value + 2)", + commitTarget()); + assertEquals( + "Should have expected rows", + 
ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(1, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + + sql( + "UPDATE %s t SET id = 5 WHERE " + + "NOT EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value) OR " + + "t.id = 1", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(5, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + + sql( + "UPDATE %s t SET id = 10 WHERE " + + "NOT EXISTS (SELECT 1 FROM updated_id ui WHERE t.id = ui.value) AND " + + "EXISTS (SELECT 1 FROM updated_dep ud WHERE t.dep = ud.value)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(10, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + } + + @Test + public void testUpdateWithScalarSubquery() { + createAndInitTable("id INT, dep STRING"); + + append( + tableName, + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + + "{ \"id\": null, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + createOrReplaceView("updated_id", Arrays.asList(1, 100, null), Encoders.INT()); + + // TODO: Spark does not support AQE and DPP with aggregates at the moment + withSQLConf( + ImmutableMap.of(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "false"), + () -> { + sql( + "UPDATE %s SET id = -1 WHERE id <= (SELECT min(value) FROM updated_id)", + commitTarget()); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); + }); + } + + @Test + public void testUpdateThatRequiresGroupingBeforeWrite() { + createAndInitTable("id INT, dep STRING"); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + append( + tableName, + "{ \"id\": 0, \"dep\": \"hr\" }\n" + + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + append( + commitTarget(), + "{ \"id\": 0, \"dep\": \"ops\" }\n" + + "{ \"id\": 1, \"dep\": \"ops\" }\n" + + "{ \"id\": 2, \"dep\": \"ops\" }"); + + append( + commitTarget(), + "{ \"id\": 0, \"dep\": \"hr\" }\n" + + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hr\" }"); + + append( + commitTarget(), + "{ \"id\": 0, \"dep\": \"ops\" }\n" + + "{ \"id\": 1, \"dep\": \"ops\" }\n" + + "{ \"id\": 2, \"dep\": \"ops\" }"); + + createOrReplaceView("updated_id", Arrays.asList(1, 100), Encoders.INT()); + + String originalNumOfShufflePartitions = spark.conf().get("spark.sql.shuffle.partitions"); + try { + // set the num of shuffle partitions to 1 to ensure we have only 1 writing task + spark.conf().set("spark.sql.shuffle.partitions", "1"); + + sql("UPDATE %s t SET id = -1 WHERE id IN (SELECT * FROM updated_id)", commitTarget()); + Assert.assertEquals( + "Should have expected num of rows", 12L, spark.table(commitTarget()).count()); + } finally { + spark.conf().set("spark.sql.shuffle.partitions", originalNumOfShufflePartitions); + } + } + + @Test + public void testUpdateWithVectorization() { + createAndInitTable("id INT, dep STRING"); + + append( + tableName, + "{ \"id\": 0, \"dep\": \"hr\" }\n" + + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hr\" }"); + createBranchIfNeeded(); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.VECTORIZATION_ENABLED, "true"), + () -> { + sql("UPDATE %s t SET id = -1", commitTarget()); + + assertEquals( + "Should have expected rows", + 
ImmutableList.of(row(-1, "hr"), row(-1, "hr"), row(-1, "hr")), + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); + }); + } + + @Test + public void testUpdateModifyPartitionSourceField() throws NoSuchTableException { + createAndInitTable("id INT, dep STRING, country STRING"); + + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(4, id)", tableName); + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + List ids = Lists.newArrayListWithCapacity(100); + for (int id = 1; id <= 100; id++) { + ids.add(id); + } + + Dataset df1 = + spark + .createDataset(ids, Encoders.INT()) + .withColumnRenamed("value", "id") + .withColumn("dep", lit("hr")) + .withColumn("country", lit("usa")); + df1.coalesce(1).writeTo(tableName).append(); + createBranchIfNeeded(); + + Dataset df2 = + spark + .createDataset(ids, Encoders.INT()) + .withColumnRenamed("value", "id") + .withColumn("dep", lit("software")) + .withColumn("country", lit("usa")); + df2.coalesce(1).writeTo(commitTarget()).append(); + + Dataset df3 = + spark + .createDataset(ids, Encoders.INT()) + .withColumnRenamed("value", "id") + .withColumn("dep", lit("hardware")) + .withColumn("country", lit("usa")); + df3.coalesce(1).writeTo(commitTarget()).append(); + + sql( + "UPDATE %s SET id = -1 WHERE id IN (10, 11, 12, 13, 14, 15, 16, 17, 18, 19)", + commitTarget()); + Assert.assertEquals(30L, scalarSql("SELECT count(*) FROM %s WHERE id = -1", selectTarget())); + } + + @Test + public void testUpdateWithStaticPredicatePushdown() { + createAndInitTable("id INT, dep STRING"); + + sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); + + // add a data file to the 'software' partition + append(tableName, "{ \"id\": 1, \"dep\": \"software\" }"); + createBranchIfNeeded(); + + // add a data file to the 'hr' partition + append(commitTarget(), "{ \"id\": 1, \"dep\": \"hr\" }"); + + Table table = validationCatalog.loadTable(tableIdent); + + Snapshot snapshot = SnapshotUtil.latestSnapshot(table, branch); + String dataFilesCount = snapshot.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP); + Assert.assertEquals("Must have 2 files before UPDATE", "2", dataFilesCount); + + // remove the data file from the 'hr' partition to ensure it is not scanned + DataFile dataFile = Iterables.getOnlyElement(snapshot.addedDataFiles(table.io())); + table.io().deleteFile(dataFile.path().toString()); + + // disable dynamic pruning and rely only on static predicate pushdown + withSQLConf( + ImmutableMap.of(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED().key(), "false"), + () -> { + sql("UPDATE %s SET id = -1 WHERE dep IN ('software') AND id == 1", commitTarget()); + }); + } + + @Test + public void testUpdateWithInvalidUpdates() { + createAndInitTable( + "id INT, a ARRAY>, m MAP", + "{ \"id\": 0, \"a\": null, \"m\": null }"); + + AssertHelpers.assertThrows( + "Should complain about updating an array column", + AnalysisException.class, + "Updating nested fields is only supported for structs", + () -> sql("UPDATE %s SET a.c1 = 1", commitTarget())); + + AssertHelpers.assertThrows( + "Should complain about updating a map column", + AnalysisException.class, + "Updating nested fields is only supported for structs", + () -> sql("UPDATE %s SET m.key = 'new_key'", commitTarget())); + } + + @Test + public void testUpdateWithConflictingAssignments() { + createAndInitTable( + "id INT, c STRUCT>", "{ \"id\": 0, \"s\": null }"); + + AssertHelpers.assertThrows( + "Should complain about conflicting updates to a top-level column", + AnalysisException.class, + "Updates are in conflict", + () 
-> sql("UPDATE %s t SET t.id = 1, t.c.n1 = 2, t.id = 2", commitTarget())); + + AssertHelpers.assertThrows( + "Should complain about conflicting updates to a nested column", + AnalysisException.class, + "Updates are in conflict for these columns", + () -> sql("UPDATE %s t SET t.c.n1 = 1, t.id = 2, t.c.n1 = 2", commitTarget())); + + AssertHelpers.assertThrows( + "Should complain about conflicting updates to a nested column", + AnalysisException.class, + "Updates are in conflict", + () -> { + sql( + "UPDATE %s SET c.n1 = 1, c = named_struct('n1', 1, 'n2', named_struct('dn1', 1, 'dn2', 2))", + commitTarget()); + }); + } + + @Test + public void testUpdateWithInvalidAssignments() { + createAndInitTable( + "id INT NOT NULL, s STRUCT> NOT NULL", + "{ \"id\": 0, \"s\": { \"n1\": 1, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); + + for (String policy : new String[] {"ansi", "strict"}) { + withSQLConf( + ImmutableMap.of("spark.sql.storeAssignmentPolicy", policy), + () -> { + AssertHelpers.assertThrows( + "Should complain about writing nulls to a top-level column", + AnalysisException.class, + "Cannot write nullable values to non-null column", + () -> sql("UPDATE %s t SET t.id = NULL", commitTarget())); + + AssertHelpers.assertThrows( + "Should complain about writing nulls to a nested column", + AnalysisException.class, + "Cannot write nullable values to non-null column", + () -> sql("UPDATE %s t SET t.s.n1 = NULL", commitTarget())); + + AssertHelpers.assertThrows( + "Should complain about writing missing fields in structs", + AnalysisException.class, + "missing fields", + () -> sql("UPDATE %s t SET t.s = named_struct('n1', 1)", commitTarget())); + + AssertHelpers.assertThrows( + "Should complain about writing invalid data types", + AnalysisException.class, + "Cannot safely cast", + () -> sql("UPDATE %s t SET t.s.n1 = 'str'", commitTarget())); + + AssertHelpers.assertThrows( + "Should complain about writing incompatible structs", + AnalysisException.class, + "field name does not match", + () -> + sql( + "UPDATE %s t SET t.s.n2 = named_struct('dn2', 1, 'dn1', 2)", + commitTarget())); + }); + } + } + + @Test + public void testUpdateWithNonDeterministicCondition() { + createAndInitTable("id INT, dep STRING", "{ \"id\": 1, \"dep\": \"hr\" }"); + + AssertHelpers.assertThrows( + "Should complain about non-deterministic expressions", + AnalysisException.class, + "nondeterministic expressions are only allowed", + () -> sql("UPDATE %s SET id = -1 WHERE id = 1 AND rand() > 0.5", commitTarget())); + } + + @Test + public void testUpdateOnNonIcebergTableNotSupported() { + createOrReplaceView("testtable", "{ \"c1\": -100, \"c2\": -200 }"); + + AssertHelpers.assertThrows( + "UPDATE is not supported for non iceberg table", + UnsupportedOperationException.class, + "not supported temporarily", + () -> sql("UPDATE %s SET c1 = -1 WHERE c2 = 1", "testtable")); + } + + @Test + public void testUpdateToWAPBranch() { + Assume.assumeTrue("WAP branch only works for table identifier without branch", branch == null); + + createAndInitTable( + "id INT, dep STRING", "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"a\" }"); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' = 'true')", + tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> { + sql("UPDATE %s SET dep='hr' WHERE dep='a'", tableName); + Assert.assertEquals( + "Should have expected num of rows when reading table", + 2L, + sql("SELECT * FROM %s WHERE dep='hr'", 
tableName).size()); + Assert.assertEquals( + "Should have expected num of rows when reading WAP branch", + 2L, + sql("SELECT * FROM %s.branch_wap WHERE dep='hr'", tableName).size()); + Assert.assertEquals( + "Should not modify main branch", + 1L, + sql("SELECT * FROM %s.branch_main WHERE dep='hr'", tableName).size()); + }); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> { + sql("UPDATE %s SET dep='b' WHERE dep='hr'", tableName); + Assert.assertEquals( + "Should have expected num of rows when reading table with multiple writes", + 2L, + sql("SELECT * FROM %s WHERE dep='b'", tableName).size()); + Assert.assertEquals( + "Should have expected num of rows when reading WAP branch with multiple writes", + 2L, + sql("SELECT * FROM %s.branch_wap WHERE dep='b'", tableName).size()); + Assert.assertEquals( + "Should not modify main branch with multiple writes", + 0L, + sql("SELECT * FROM %s.branch_main WHERE dep='b'", tableName).size()); + }); + } + + @Test + public void testUpdateToWapBranchWithTableBranchIdentifier() { + Assume.assumeTrue("Test must have branch name part in table identifier", branch != null); + + createAndInitTable("id INT, dep STRING", "{ \"id\": 1, \"dep\": \"hr\" }"); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' = 'true')", + tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> + Assertions.assertThatThrownBy( + () -> sql("UPDATE %s SET dep='hr' WHERE dep='a'", commitTarget())) + .isInstanceOf(ValidationException.class) + .hasMessage( + String.format( + "Cannot write to both branch and WAP branch, but got branch [%s] and WAP branch [wap]", + branch))); + } + + private RowLevelOperationMode mode(Table table) { + String modeName = table.properties().getOrDefault(UPDATE_MODE, UPDATE_MODE_DEFAULT); + return RowLevelOperationMode.fromName(modeName); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestWriteAborts.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestWriteAborts.java new file mode 100644 index 000000000000..37b38621b1a1 --- /dev/null +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestWriteAborts.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.extensions; + +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopFileIO; +import org.apache.iceberg.io.BulkDeletionFailureException; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.io.SupportsBulkOperations; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.SparkException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.After; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runners.Parameterized; + +public class TestWriteAborts extends SparkExtensionsTestBase { + + @Parameterized.Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][] { + { + "testhive", + SparkCatalog.class.getName(), + ImmutableMap.of( + "type", + "hive", + CatalogProperties.FILE_IO_IMPL, + CustomFileIO.class.getName(), + "default-namespace", + "default") + }, + { + "testhivebulk", + SparkCatalog.class.getName(), + ImmutableMap.of( + "type", + "hive", + CatalogProperties.FILE_IO_IMPL, + CustomBulkFileIO.class.getName(), + "default-namespace", + "default") + } + }; + } + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + public TestWriteAborts(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testBatchAppend() throws Exception { + String dataLocation = temp.newFolder().toString(); + + sql( + "CREATE TABLE %s (id INT, data STRING) " + + "USING iceberg " + + "PARTITIONED BY (data)" + + "TBLPROPERTIES ('%s' '%s')", + tableName, TableProperties.WRITE_DATA_LOCATION, dataLocation); + + List records = + ImmutableList.of( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "a"), + new SimpleRecord(4, "b")); + Dataset inputDF = spark.createDataFrame(records, SimpleRecord.class); + + AssertHelpers.assertThrows( + "Write must fail", + SparkException.class, + "Encountered records that belong to already closed files", + () -> { + try { + // incoming records are not ordered by partitions so the job must fail + inputDF + .coalesce(1) + .sortWithinPartitions("id") + .writeTo(tableName) + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false") + .append(); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + assertEquals("Should be no records", sql("SELECT * FROM %s", tableName), ImmutableList.of()); + + assertEquals( + "Should be no orphan data files", + ImmutableList.of(), + sql( + "CALL %s.system.remove_orphan_files(table => '%s', older_than => %dL, location => '%s')", + catalogName, tableName, System.currentTimeMillis() + 5000, dataLocation)); + } + + public static class CustomFileIO implements FileIO { + + 
private final FileIO delegate = new HadoopFileIO(new Configuration()); + + public CustomFileIO() {} + + protected FileIO delegate() { + return delegate; + } + + @Override + public InputFile newInputFile(String path) { + return delegate.newInputFile(path); + } + + @Override + public OutputFile newOutputFile(String path) { + return delegate.newOutputFile(path); + } + + @Override + public void deleteFile(String path) { + delegate.deleteFile(path); + } + + @Override + public Map properties() { + return delegate.properties(); + } + + @Override + public void initialize(Map properties) { + delegate.initialize(properties); + } + + @Override + public void close() { + delegate.close(); + } + } + + public static class CustomBulkFileIO extends CustomFileIO implements SupportsBulkOperations { + + public CustomBulkFileIO() {} + + @Override + public void deleteFile(String path) { + throw new UnsupportedOperationException("Only bulk deletes are supported"); + } + + @Override + public void deleteFiles(Iterable paths) throws BulkDeletionFailureException { + for (String path : paths) { + delegate().deleteFile(path); + } + } + } +} diff --git a/spark/v3.4/spark-runtime/LICENSE b/spark/v3.4/spark-runtime/LICENSE new file mode 100644 index 000000000000..9d1522431696 --- /dev/null +++ b/spark/v3.4/spark-runtime/LICENSE @@ -0,0 +1,629 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache Avro. + +Copyright: 2014-2017 The Apache Software Foundation. +Home page: https://parquet.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains the Jackson JSON processor. + +Copyright: 2007-2019 Tatu Saloranta and other contributors +Home page: http://jackson.codehaus.org/ +License: http://www.apache.org/licenses/LICENSE-2.0.txt + +-------------------------------------------------------------------------------- + +This binary artifact contains Paranamer. 
+ +Copyright: 2000-2007 INRIA, France Telecom, 2006-2018 Paul Hammant & ThoughtWorks Inc +Home page: https://github.com/paul-hammant/paranamer +License: https://github.com/paul-hammant/paranamer/blob/master/LICENSE.txt (BSD) + +License text: +| Portions copyright (c) 2006-2018 Paul Hammant & ThoughtWorks Inc +| Portions copyright (c) 2000-2007 INRIA, France Telecom +| All rights reserved. +| +| Redistribution and use in source and binary forms, with or without +| modification, are permitted provided that the following conditions +| are met: +| 1. Redistributions of source code must retain the above copyright +| notice, this list of conditions and the following disclaimer. +| 2. Redistributions in binary form must reproduce the above copyright +| notice, this list of conditions and the following disclaimer in the +| documentation and/or other materials provided with the distribution. +| 3. Neither the name of the copyright holders nor the names of its +| contributors may be used to endorse or promote products derived from +| this software without specific prior written permission. +| +| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +| ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +| LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +| CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +| SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +| INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +| CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +| ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +| THE POSSIBILITY OF SUCH DAMAGE. + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache Parquet. + +Copyright: 2014-2017 The Apache Software Foundation. +Home page: https://parquet.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache Thrift. + +Copyright: 2006-2010 The Apache Software Foundation. +Home page: https://thrift.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains code from Daniel Lemire's JavaFastPFOR project. + +Copyright: 2013 Daniel Lemire +Home page: https://github.com/lemire/JavaFastPFOR +License: Apache License Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains fastutil. + +Copyright: 2002-2014 Sebastiano Vigna +Home page: http://fastutil.di.unimi.it/ +License: http://www.apache.org/licenses/LICENSE-2.0.html + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache ORC. + +Copyright: 2013-2019 The Apache Software Foundation. +Home page: https://orc.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache Hive's storage API via ORC. + +Copyright: 2013-2019 The Apache Software Foundation. 
+Home page: https://hive.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Google protobuf via ORC. + +Copyright: 2008 Google Inc. +Home page: https://developers.google.com/protocol-buffers +License: https://github.com/protocolbuffers/protobuf/blob/master/LICENSE (BSD) + +License text: + +| Copyright 2008 Google Inc. All rights reserved. +| +| Redistribution and use in source and binary forms, with or without +| modification, are permitted provided that the following conditions are +| met: +| +| * Redistributions of source code must retain the above copyright +| notice, this list of conditions and the following disclaimer. +| * Redistributions in binary form must reproduce the above +| copyright notice, this list of conditions and the following disclaimer +| in the documentation and/or other materials provided with the +| distribution. +| * Neither the name of Google Inc. nor the names of its +| contributors may be used to endorse or promote products derived from +| this software without specific prior written permission. +| +| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +| "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +| LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +| A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +| OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +| SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +| LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +| DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +| THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +| (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +| OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +| +| Code generated by the Protocol Buffer compiler is owned by the owner +| of the input file used when generating it. This code is not +| standalone and requires a support library to be linked with it. This +| support library is itself covered by the above license. + +-------------------------------------------------------------------------------- + +This binary artifact contains Airlift Aircompressor. + +Copyright: 2011-2019 Aircompressor authors. +Home page: https://github.com/airlift/aircompressor +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Airlift Slice. + +Copyright: 2013-2019 Slice authors. +Home page: https://github.com/airlift/slice +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains JetBrains annotations. + +Copyright: 2000-2020 JetBrains s.r.o. +Home page: https://github.com/JetBrains/java-annotations +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains code from Cloudera Kite. + +Copyright: 2013-2017 Cloudera Inc. +Home page: https://kitesdk.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains code from Presto. 
+ +Copyright: 2016 Facebook and contributors +Home page: https://prestodb.io/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Google Guava. + +Copyright: 2006-2019 The Guava Authors +Home page: https://github.com/google/guava +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Google Error Prone Annotations. + +Copyright: Copyright 2011-2019 The Error Prone Authors +Home page: https://github.com/google/error-prone +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains findbugs-annotations by Stephen Connolly. + +Copyright: 2011-2016 Stephen Connolly, Greg Lucas +Home page: https://github.com/stephenc/findbugs-annotations +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Google j2objc Annotations. + +Copyright: Copyright 2012-2018 Google Inc. +Home page: https://github.com/google/j2objc/tree/master/annotations +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains checkerframework checker-qual Annotations. + +Copyright: 2004-2019 the Checker Framework developers +Home page: https://github.com/typetools/checker-framework +License: https://github.com/typetools/checker-framework/blob/master/LICENSE.txt (MIT license) + +License text: +| The annotations are licensed under the MIT License. (The text of this +| license appears below.) More specifically, all the parts of the Checker +| Framework that you might want to include with your own program use the +| MIT License. This is the checker-qual.jar file and all the files that +| appear in it: every file in a qual/ directory, plus utility files such +| as NullnessUtil.java, RegexUtil.java, SignednessUtil.java, etc. +| In addition, the cleanroom implementations of third-party annotations, +| which the Checker Framework recognizes as aliases for its own +| annotations, are licensed under the MIT License. +| +| Permission is hereby granted, free of charge, to any person obtaining a copy +| of this software and associated documentation files (the "Software"), to deal +| in the Software without restriction, including without limitation the rights +| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +| copies of the Software, and to permit persons to whom the Software is +| furnished to do so, subject to the following conditions: +| +| The above copyright notice and this permission notice shall be included in +| all copies or substantial portions of the Software. +| +| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +| THE SOFTWARE. 
+ +-------------------------------------------------------------------------------- + +This binary artifact contains Animal Sniffer Annotations. + +Copyright: 2009-2018 codehaus.org +Home page: https://www.mojohaus.org/animal-sniffer/animal-sniffer-annotations/ +License: https://www.mojohaus.org/animal-sniffer/animal-sniffer-annotations/license.html (MIT license) + +License text: +| The MIT License +| +| Copyright (c) 2009 codehaus.org. +| +| Permission is hereby granted, free of charge, to any person obtaining a copy +| of this software and associated documentation files (the "Software"), to deal +| in the Software without restriction, including without limitation the rights +| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +| copies of the Software, and to permit persons to whom the Software is +| furnished to do so, subject to the following conditions: +| +| The above copyright notice and this permission notice shall be included in +| all copies or substantial portions of the Software. +| +| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +| THE SOFTWARE. + +-------------------------------------------------------------------------------- + +This binary artifact contains Caffeine by Ben Manes. + +Copyright: 2014-2019 Ben Manes and contributors +Home page: https://github.com/ben-manes/caffeine +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache Arrow. + +Copyright: 2016-2019 The Apache Software Foundation. +Home page: https://arrow.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Netty's buffer library. + +Copyright: 2014-2020 The Netty Project +Home page: https://netty.io/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Google FlatBuffers. + +Copyright: 2013-2020 Google Inc. +Home page: https://google.github.io/flatbuffers/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Carrot Search Labs HPPC. + +Copyright: 2002-2019 Carrot Search s.c. +Home page: http://labs.carrotsearch.com/hppc.html +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains code from Apache Lucene via Carrot Search HPPC. + +Copyright: 2011-2020 The Apache Software Foundation. +Home page: https://lucene.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache Yetus audience annotations. + +Copyright: 2008-2020 The Apache Software Foundation. 
+Home page: https://yetus.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains ThreeTen. + +Copyright: 2007-present, Stephen Colebourne & Michael Nascimento Santos. +Home page: https://www.threeten.org/threeten-extra/ +License: https://github.com/ThreeTen/threeten-extra/blob/master/LICENSE.txt (BSD 3-clause) + +License text: + +| All rights reserved. +| +| * Redistribution and use in source and binary forms, with or without +| modification, are permitted provided that the following conditions are met: +| +| * Redistributions of source code must retain the above copyright notice, +| this list of conditions and the following disclaimer. +| +| * Redistributions in binary form must reproduce the above copyright notice, +| this list of conditions and the following disclaimer in the documentation +| and/or other materials provided with the distribution. +| +| * Neither the name of JSR-310 nor the names of its contributors +| may be used to endorse or promote products derived from this software +| without specific prior written permission. +| +| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +| "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +| LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +| A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +| CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +| EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +| PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +| PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +| LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +| NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +| SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +-------------------------------------------------------------------------------- + +This binary artifact contains code from Project Nessie. + +Copyright: 2020 Dremio Corporation. +Home page: https://projectnessie.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This product includes code from Apache Spark. + +* vectorized reading of definition levels in BaseVectorizedParquetValuesReader.java +* portions of the extensions parser +* casting logic in AssignmentAlignmentSupport +* implementation of SetAccumulator. + +Copyright: 2011-2018 The Apache Software Foundation +Home page: https://spark.apache.org/ +License: https://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This product includes code from Delta Lake. + +* AssignmentAlignmentSupport is an independent development but UpdateExpressionsSupport in Delta was used as a reference. + +Copyright: 2020 The Delta Lake Project Authors. +Home page: https://delta.io/ +License: https://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary includes code from Apache Commons. + +* Core ArrayUtil. 
+ +Copyright: 2020 The Apache Software Foundation +Home page: https://commons.apache.org/ +License: https://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This binary artifact contains Apache HttpComponents Client. + +Copyright: 1999-2022 The Apache Software Foundation. +Home page: https://hc.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 diff --git a/spark/v3.4/spark-runtime/NOTICE b/spark/v3.4/spark-runtime/NOTICE new file mode 100644 index 000000000000..4a1f4dfde1cc --- /dev/null +++ b/spark/v3.4/spark-runtime/NOTICE @@ -0,0 +1,508 @@ + +Apache Iceberg +Copyright 2017-2022 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +-------------------------------------------------------------------------------- + +This binary artifact contains code from Kite, developed at Cloudera, Inc. with +the following copyright notice: + +| Copyright 2013 Cloudera Inc. +| +| Licensed under the Apache License, Version 2.0 (the "License"); +| you may not use this file except in compliance with the License. +| You may obtain a copy of the License at +| +| http://www.apache.org/licenses/LICENSE-2.0 +| +| Unless required by applicable law or agreed to in writing, software +| distributed under the License is distributed on an "AS IS" BASIS, +| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +| See the License for the specific language governing permissions and +| limitations under the License. + +-------------------------------------------------------------------------------- + +This binary artifact includes Apache ORC with the following in its NOTICE file: + +| Apache ORC +| Copyright 2013-2019 The Apache Software Foundation +| +| This product includes software developed by The Apache Software +| Foundation (http://www.apache.org/). +| +| This product includes software developed by Hewlett-Packard: +| (c) Copyright [2014-2015] Hewlett-Packard Development Company, L.P + +-------------------------------------------------------------------------------- + +This binary artifact includes Airlift Aircompressor with the following in its +NOTICE file: + +| Snappy Copyright Notices +| ========================= +| +| * Copyright 2011 Dain Sundstrom +| * Copyright 2011, Google Inc. +| +| +| Snappy License +| =============== +| Copyright 2011, Google Inc. +| All rights reserved. +| +| Redistribution and use in source and binary forms, with or without +| modification, are permitted provided that the following conditions are +| met: +| +| * Redistributions of source code must retain the above copyright +| notice, this list of conditions and the following disclaimer. +| * Redistributions in binary form must reproduce the above +| copyright notice, this list of conditions and the following disclaimer +| in the documentation and/or other materials provided with the +| distribution. +| * Neither the name of Google Inc. nor the names of its +| contributors may be used to endorse or promote products derived from +| this software without specific prior written permission. +| +| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +| "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +| LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +| A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +| OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +| SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +| LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +| DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +| THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +| (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +| OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +-------------------------------------------------------------------------------- + +This binary artifact includes Carrot Search Labs HPPC with the following in its +NOTICE file: + +| ACKNOWLEDGEMENT +| =============== +| +| HPPC borrowed code, ideas or both from: +| +| * Apache Lucene, http://lucene.apache.org/ +| (Apache license) +| * Fastutil, http://fastutil.di.unimi.it/ +| (Apache license) +| * Koloboke, https://github.com/OpenHFT/Koloboke +| (Apache license) + +-------------------------------------------------------------------------------- + +This binary artifact includes Apache Yetus with the following in its NOTICE +file: + +| Apache Yetus +| Copyright 2008-2020 The Apache Software Foundation +| +| This product includes software developed at +| The Apache Software Foundation (https://www.apache.org/). +| +| --- +| Additional licenses for the Apache Yetus Source/Website: +| --- +| +| +| See LICENSE for terms. + +-------------------------------------------------------------------------------- + +This binary artifact includes Google Protobuf with the following copyright +notice: + +| Copyright 2008 Google Inc. All rights reserved. +| +| Redistribution and use in source and binary forms, with or without +| modification, are permitted provided that the following conditions are +| met: +| +| * Redistributions of source code must retain the above copyright +| notice, this list of conditions and the following disclaimer. +| * Redistributions in binary form must reproduce the above +| copyright notice, this list of conditions and the following disclaimer +| in the documentation and/or other materials provided with the +| distribution. +| * Neither the name of Google Inc. nor the names of its +| contributors may be used to endorse or promote products derived from +| this software without specific prior written permission. +| +| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +| "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +| LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +| A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +| OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +| SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +| LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +| DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +| THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +| (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +| OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +| +| Code generated by the Protocol Buffer compiler is owned by the owner +| of the input file used when generating it. This code is not +| standalone and requires a support library to be linked with it. This +| support library is itself covered by the above license. 
+ +-------------------------------------------------------------------------------- + +This binary artifact includes Apache Arrow with the following in its NOTICE file: + +| Apache Arrow +| Copyright 2016-2019 The Apache Software Foundation +| +| This product includes software developed at +| The Apache Software Foundation (http://www.apache.org/). +| +| This product includes software from the SFrame project (BSD, 3-clause). +| * Copyright (C) 2015 Dato, Inc. +| * Copyright (c) 2009 Carnegie Mellon University. +| +| This product includes software from the Feather project (Apache 2.0) +| https://github.com/wesm/feather +| +| This product includes software from the DyND project (BSD 2-clause) +| https://github.com/libdynd +| +| This product includes software from the LLVM project +| * distributed under the University of Illinois Open Source +| +| This product includes software from the google-lint project +| * Copyright (c) 2009 Google Inc. All rights reserved. +| +| This product includes software from the mman-win32 project +| * Copyright https://code.google.com/p/mman-win32/ +| * Licensed under the MIT License; +| +| This product includes software from the LevelDB project +| * Copyright (c) 2011 The LevelDB Authors. All rights reserved. +| * Use of this source code is governed by a BSD-style license that can be +| * Moved from Kudu http://github.com/cloudera/kudu +| +| This product includes software from the CMake project +| * Copyright 2001-2009 Kitware, Inc. +| * Copyright 2012-2014 Continuum Analytics, Inc. +| * All rights reserved. +| +| This product includes software from https://github.com/matthew-brett/multibuild (BSD 2-clause) +| * Copyright (c) 2013-2016, Matt Terry and Matthew Brett; all rights reserved. +| +| This product includes software from the Ibis project (Apache 2.0) +| * Copyright (c) 2015 Cloudera, Inc. +| * https://github.com/cloudera/ibis +| +| This product includes software from Dremio (Apache 2.0) +| * Copyright (C) 2017-2018 Dremio Corporation +| * https://github.com/dremio/dremio-oss +| +| This product includes software from Google Guava (Apache 2.0) +| * Copyright (C) 2007 The Guava Authors +| * https://github.com/google/guava +| +| This product include software from CMake (BSD 3-Clause) +| * CMake - Cross Platform Makefile Generator +| * Copyright 2000-2019 Kitware, Inc. and Contributors +| +| The web site includes files generated by Jekyll. +| +| -------------------------------------------------------------------------------- +| +| This product includes code from Apache Kudu, which includes the following in +| its NOTICE file: +| +| Apache Kudu +| Copyright 2016 The Apache Software Foundation +| +| This product includes software developed at +| The Apache Software Foundation (http://www.apache.org/). +| +| Portions of this software were developed at +| Cloudera, Inc (http://www.cloudera.com/). +| +| -------------------------------------------------------------------------------- +| +| This product includes code from Apache ORC, which includes the following in +| its NOTICE file: +| +| Apache ORC +| Copyright 2013-2019 The Apache Software Foundation +| +| This product includes software developed by The Apache Software +| Foundation (http://www.apache.org/). 
+| +| This product includes software developed by Hewlett-Packard: +| (c) Copyright [2014-2015] Hewlett-Packard Development Company, L.P + +-------------------------------------------------------------------------------- + +This binary artifact includes Netty buffers with the following in its NOTICE +file: + +| The Netty Project +| ================= +| +| Please visit the Netty web site for more information: +| +| * https://netty.io/ +| +| Copyright 2014 The Netty Project +| +| The Netty Project licenses this file to you under the Apache License, +| version 2.0 (the "License"); you may not use this file except in compliance +| with the License. You may obtain a copy of the License at: +| +| http://www.apache.org/licenses/LICENSE-2.0 +| +| Unless required by applicable law or agreed to in writing, software +| distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +| WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +| License for the specific language governing permissions and limitations +| under the License. +| +| Also, please refer to each LICENSE..txt file, which is located in +| the 'license' directory of the distribution file, for the license terms of the +| components that this product depends on. +| +| ------------------------------------------------------------------------------- +| This product contains the extensions to Java Collections Framework which has +| been derived from the works by JSR-166 EG, Doug Lea, and Jason T. Greene: +| +| * LICENSE: +| * license/LICENSE.jsr166y.txt (Public Domain) +| * HOMEPAGE: +| * http://gee.cs.oswego.edu/cgi-bin/viewcvs.cgi/jsr166/ +| * http://viewvc.jboss.org/cgi-bin/viewvc.cgi/jbosscache/experimental/jsr166/ +| +| This product contains a modified version of Robert Harder's Public Domain +| Base64 Encoder and Decoder, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.base64.txt (Public Domain) +| * HOMEPAGE: +| * http://iharder.sourceforge.net/current/java/base64/ +| +| This product contains a modified portion of 'Webbit', an event based +| WebSocket and HTTP server, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.webbit.txt (BSD License) +| * HOMEPAGE: +| * https://github.com/joewalnes/webbit +| +| This product contains a modified portion of 'SLF4J', a simple logging +| facade for Java, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.slf4j.txt (MIT License) +| * HOMEPAGE: +| * http://www.slf4j.org/ +| +| This product contains a modified portion of 'Apache Harmony', an open source +| Java SE, which can be obtained at: +| +| * NOTICE: +| * license/NOTICE.harmony.txt +| * LICENSE: +| * license/LICENSE.harmony.txt (Apache License 2.0) +| * HOMEPAGE: +| * http://archive.apache.org/dist/harmony/ +| +| This product contains a modified portion of 'jbzip2', a Java bzip2 compression +| and decompression library written by Matthew J. Francis. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.jbzip2.txt (MIT License) +| * HOMEPAGE: +| * https://code.google.com/p/jbzip2/ +| +| This product contains a modified portion of 'libdivsufsort', a C API library to construct +| the suffix array and the Burrows-Wheeler transformed string for any input string of +| a constant-size alphabet written by Yuta Mori. 
It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.libdivsufsort.txt (MIT License) +| * HOMEPAGE: +| * https://github.com/y-256/libdivsufsort +| +| This product contains a modified portion of Nitsan Wakart's 'JCTools', Java Concurrency Tools for the JVM, +| which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.jctools.txt (ASL2 License) +| * HOMEPAGE: +| * https://github.com/JCTools/JCTools +| +| This product optionally depends on 'JZlib', a re-implementation of zlib in +| pure Java, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.jzlib.txt (BSD style License) +| * HOMEPAGE: +| * http://www.jcraft.com/jzlib/ +| +| This product optionally depends on 'Compress-LZF', a Java library for encoding and +| decoding data in LZF format, written by Tatu Saloranta. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.compress-lzf.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/ning/compress +| +| This product optionally depends on 'lz4', a LZ4 Java compression +| and decompression library written by Adrien Grand. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.lz4.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/jpountz/lz4-java +| +| This product optionally depends on 'lzma-java', a LZMA Java compression +| and decompression library, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.lzma-java.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/jponge/lzma-java +| +| This product contains a modified portion of 'jfastlz', a Java port of FastLZ compression +| and decompression library written by William Kinney. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.jfastlz.txt (MIT License) +| * HOMEPAGE: +| * https://code.google.com/p/jfastlz/ +| +| This product contains a modified portion of and optionally depends on 'Protocol Buffers', Google's data +| interchange format, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.protobuf.txt (New BSD License) +| * HOMEPAGE: +| * https://github.com/google/protobuf +| +| This product optionally depends on 'Bouncy Castle Crypto APIs' to generate +| a temporary self-signed X.509 certificate when the JVM does not provide the +| equivalent functionality. 
It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.bouncycastle.txt (MIT License) +| * HOMEPAGE: +| * http://www.bouncycastle.org/ +| +| This product optionally depends on 'Snappy', a compression library produced +| by Google Inc, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.snappy.txt (New BSD License) +| * HOMEPAGE: +| * https://github.com/google/snappy +| +| This product optionally depends on 'JBoss Marshalling', an alternative Java +| serialization API, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.jboss-marshalling.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/jboss-remoting/jboss-marshalling +| +| This product optionally depends on 'Caliper', Google's micro- +| benchmarking framework, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.caliper.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/google/caliper +| +| This product optionally depends on 'Apache Commons Logging', a logging +| framework, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.commons-logging.txt (Apache License 2.0) +| * HOMEPAGE: +| * http://commons.apache.org/logging/ +| +| This product optionally depends on 'Apache Log4J', a logging framework, which +| can be obtained at: +| +| * LICENSE: +| * license/LICENSE.log4j.txt (Apache License 2.0) +| * HOMEPAGE: +| * http://logging.apache.org/log4j/ +| +| This product optionally depends on 'Aalto XML', an ultra-high performance +| non-blocking XML processor, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.aalto-xml.txt (Apache License 2.0) +| * HOMEPAGE: +| * http://wiki.fasterxml.com/AaltoHome +| +| This product contains a modified version of 'HPACK', a Java implementation of +| the HTTP/2 HPACK algorithm written by Twitter. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.hpack.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/twitter/hpack +| +| This product contains a modified version of 'HPACK', a Java implementation of +| the HTTP/2 HPACK algorithm written by Cory Benfield. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.hyper-hpack.txt (MIT License) +| * HOMEPAGE: +| * https://github.com/python-hyper/hpack/ +| +| This product contains a modified version of 'HPACK', a Java implementation of +| the HTTP/2 HPACK algorithm written by Tatsuhiro Tsujikawa. It can be obtained at: +| +| * LICENSE: +| * license/LICENSE.nghttp2-hpack.txt (MIT License) +| * HOMEPAGE: +| * https://github.com/nghttp2/nghttp2/ +| +| This product contains a modified portion of 'Apache Commons Lang', a Java library +| provides utilities for the java.lang API, which can be obtained at: +| +| * LICENSE: +| * license/LICENSE.commons-lang.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://commons.apache.org/proper/commons-lang/ +| +| +| This product contains the Maven wrapper scripts from 'Maven Wrapper', that provides an easy way to ensure a user has everything necessary to run the Maven build. +| +| * LICENSE: +| * license/LICENSE.mvn-wrapper.txt (Apache License 2.0) +| * HOMEPAGE: +| * https://github.com/takari/maven-wrapper +| +| This product contains the dnsinfo.h header file, that provides a way to retrieve the system DNS configuration on MacOS. +| This private header is also used by Apple's open source +| mDNSResponder (https://opensource.apple.com/tarballs/mDNSResponder/). 
+| +| * LICENSE: +| * license/LICENSE.dnsinfo.txt (Apache License 2.0) +| * HOMEPAGE: +| * http://www.opensource.apple.com/source/configd/configd-453.19/dnsinfo/dnsinfo.h + +-------------------------------------------------------------------------------- + +This binary artifact includes Project Nessie with the following in its NOTICE +file: + +| Dremio +| Copyright 2015-2017 Dremio Corporation +| +| This product includes software developed at +| The Apache Software Foundation (http://www.apache.org/). + diff --git a/spark/v3.4/spark-runtime/src/integration/java/org/apache/iceberg/spark/SmokeTest.java b/spark/v3.4/spark-runtime/src/integration/java/org/apache/iceberg/spark/SmokeTest.java new file mode 100644 index 000000000000..510108af8c21 --- /dev/null +++ b/spark/v3.4/spark-runtime/src/integration/java/org/apache/iceberg/spark/SmokeTest.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.spark.extensions.SparkExtensionsTestBase; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class SmokeTest extends SparkExtensionsTestBase { + + public SmokeTest(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Before + public void dropTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + // Run through our Doc's Getting Started Example + // TODO Update doc example so that it can actually be run, modifications were required for this + // test suite to run + @Test + public void testGettingStarted() throws IOException { + // Creating a table + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + // Writing + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); + Assert.assertEquals( + "Should have inserted 3 rows", 3L, scalarSql("SELECT COUNT(*) FROM %s", tableName)); + + sql("DROP TABLE IF EXISTS source"); + sql( + "CREATE TABLE source (id bigint, data string) USING parquet LOCATION '%s'", + temp.newFolder()); + sql("INSERT INTO source VALUES (10, 'd'), (11, 'ee')"); + + sql("INSERT INTO %s SELECT id, data FROM source WHERE length(data) = 1", tableName); + Assert.assertEquals( + "Table should now have 4 rows", 4L, scalarSql("SELECT COUNT(*) FROM %s", tableName)); + + sql("DROP TABLE IF EXISTS updates"); + sql( + "CREATE TABLE updates (id bigint, data string) USING parquet LOCATION '%s'", + temp.newFolder()); + sql("INSERT INTO updates VALUES (1, 'x'), (2, 'x'), (4, 'z')"); + + sql( + "MERGE INTO %s t USING (SELECT * 
FROM updates) u ON t.id = u.id\n" + + "WHEN MATCHED THEN UPDATE SET t.data = u.data\n" + + "WHEN NOT MATCHED THEN INSERT *", + tableName); + Assert.assertEquals( + "Table should now have 5 rows", 5L, scalarSql("SELECT COUNT(*) FROM %s", tableName)); + Assert.assertEquals( + "Record 1 should now have data x", + "x", + scalarSql("SELECT data FROM %s WHERE id = 1", tableName)); + + // Reading + Assert.assertEquals( + "There should be 2 records with data x", + 2L, + scalarSql("SELECT count(1) as count FROM %s WHERE data = 'x' GROUP BY data ", tableName)); + + // Not supported because of Spark limitation + if (!catalogName.equals("spark_catalog")) { + Assert.assertEquals( + "There should be 3 snapshots", + 3L, + scalarSql("SELECT COUNT(*) FROM %s.snapshots", tableName)); + } + } + + // From Spark DDL Docs section + @Test + public void testAlterTable() throws NoSuchTableException { + sql( + "CREATE TABLE %s (category int, id bigint, data string, ts timestamp) USING iceberg", + tableName); + Table table = getTable(); + // Add examples + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, id)", tableName); + sql("ALTER TABLE %s ADD PARTITION FIELD truncate(data, 4)", tableName); + sql("ALTER TABLE %s ADD PARTITION FIELD years(ts)", tableName); + sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, category) AS shard", tableName); + table = getTable(); + Assert.assertEquals("Table should have 4 partition fields", 4, table.spec().fields().size()); + + // Drop Examples + sql("ALTER TABLE %s DROP PARTITION FIELD bucket(16, id)", tableName); + sql("ALTER TABLE %s DROP PARTITION FIELD truncate(data, 4)", tableName); + sql("ALTER TABLE %s DROP PARTITION FIELD years(ts)", tableName); + sql("ALTER TABLE %s DROP PARTITION FIELD shard", tableName); + + table = getTable(); + Assert.assertEquals("Table should have 4 partition fields", 4, table.spec().fields().size()); + // VoidTransform is package private so we can't reach it here, just checking name + Assert.assertTrue( + "All transforms should be void", + table.spec().fields().stream().allMatch(pf -> pf.transform().toString().equals("void"))); + + // Sort order examples + sql("ALTER TABLE %s WRITE ORDERED BY category, id", tableName); + sql("ALTER TABLE %s WRITE ORDERED BY category ASC, id DESC", tableName); + sql("ALTER TABLE %s WRITE ORDERED BY category ASC NULLS LAST, id DESC NULLS FIRST", tableName); + table = getTable(); + Assert.assertEquals("Table should be sorted on 2 fields", 2, table.sortOrder().fields().size()); + } + + @Test + public void testCreateTable() { + sql("DROP TABLE IF EXISTS %s", tableName("first")); + sql("DROP TABLE IF EXISTS %s", tableName("second")); + sql("DROP TABLE IF EXISTS %s", tableName("third")); + + sql( + "CREATE TABLE %s (\n" + + " id bigint COMMENT 'unique id',\n" + + " data string)\n" + + "USING iceberg", + tableName("first")); + getTable("first"); // Table should exist + + sql( + "CREATE TABLE %s (\n" + + " id bigint,\n" + + " data string,\n" + + " category string)\n" + + "USING iceberg\n" + + "PARTITIONED BY (category)", + tableName("second")); + Table second = getTable("second"); + Assert.assertEquals("Should be partitioned on 1 column", 1, second.spec().fields().size()); + + sql( + "CREATE TABLE %s (\n" + + " id bigint,\n" + + " data string,\n" + + " category string,\n" + + " ts timestamp)\n" + + "USING iceberg\n" + + "PARTITIONED BY (bucket(16, id), days(ts), category)", + tableName("third")); + Table third = getTable("third"); + Assert.assertEquals("Should be partitioned on 3 columns", 3, 
third.spec().fields().size());
+  }
+
+  private Table getTable(String name) {
+    return validationCatalog.loadTable(TableIdentifier.of("default", name));
+  }
+
+  private Table getTable() {
+    return validationCatalog.loadTable(tableIdent);
+  }
+}
diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/SparkBenchmarkUtil.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/SparkBenchmarkUtil.java
new file mode 100644
index 000000000000..d6b0e9c94258
--- /dev/null
+++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/SparkBenchmarkUtil.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark;
+
+import java.util.List;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.types.Types;
+import org.apache.spark.sql.catalyst.expressions.Attribute;
+import org.apache.spark.sql.catalyst.expressions.AttributeReference;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
+import org.apache.spark.sql.types.StructType;
+import scala.collection.JavaConverters;
+
+public class SparkBenchmarkUtil {
+
+  private SparkBenchmarkUtil() {}
+
+  public static UnsafeProjection projection(Schema expectedSchema, Schema actualSchema) {
+    StructType struct = SparkSchemaUtil.convert(actualSchema);
+
+    List<AttributeReference> refs =
+        JavaConverters.seqAsJavaListConverter(struct.toAttributes()).asJava();
+    List<Attribute> attrs = Lists.newArrayListWithExpectedSize(struct.fields().length);
+    List<Expression> exprs = Lists.newArrayListWithExpectedSize(struct.fields().length);
+
+    for (AttributeReference ref : refs) {
+      attrs.add(ref.toAttribute());
+    }
+
+    for (Types.NestedField field : expectedSchema.columns()) {
+      int indexInIterSchema = struct.fieldIndex(field.name());
+      exprs.add(refs.get(indexInIterSchema));
+    }
+
+    return UnsafeProjection.create(
+        JavaConverters.asScalaBufferConverter(exprs).asScala().toSeq(),
+        JavaConverters.asScalaBufferConverter(attrs).asScala().toSeq());
+  }
+}
diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/action/DeleteOrphanFilesBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/action/DeleteOrphanFilesBenchmark.java
new file mode 100644
index 000000000000..5a7df7283728
--- /dev/null
+++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/action/DeleteOrphanFilesBenchmark.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.action; + +import static org.apache.spark.sql.functions.lit; + +import java.sql.Timestamp; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.DeleteOrphanFiles; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.io.Files; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Timeout; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +/** + * A benchmark that evaluates the performance of remove orphan files action in Spark. + * + *
To run this benchmark for spark-3.3: + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=DeleteOrphanFilesBenchmark + * -PjmhOutputPath=benchmark/delete-orphan-files-benchmark-results.txt + * + */ +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.SingleShotTime) +@Timeout(time = 1000, timeUnit = TimeUnit.HOURS) +public class DeleteOrphanFilesBenchmark { + + private static final String TABLE_NAME = "delete_orphan_perf"; + private static final int NUM_SNAPSHOTS = 1000; + private static final int NUM_FILES = 1000; + + private SparkSession spark; + private final List validAndOrphanPaths = Lists.newArrayList(); + private Table table; + + @Setup + public void setupBench() { + setupSpark(); + initTable(); + appendData(); + addOrphans(); + } + + @TearDown + public void teardownBench() { + tearDownSpark(); + } + + @Benchmark + @Threads(1) + public void testDeleteOrphanFiles(Blackhole blackhole) { + Dataset validAndOrphanPathsDF = + spark + .createDataset(validAndOrphanPaths, Encoders.STRING()) + .withColumnRenamed("value", "file_path") + .withColumn("last_modified", lit(new Timestamp(10000))); + + DeleteOrphanFiles.Result results = + SparkActions.get(spark) + .deleteOrphanFiles(table()) + .compareToFileList(validAndOrphanPathsDF) + .execute(); + blackhole.consume(results); + } + + private void initTable() { + spark.sql( + String.format( + "CREATE TABLE %s(id INT, name STRING)" + + " USING ICEBERG" + + " TBLPROPERTIES ( 'format-version' = '2')", + TABLE_NAME)); + } + + private void appendData() { + String location = table().location(); + PartitionSpec partitionSpec = table().spec(); + + for (int i = 0; i < NUM_SNAPSHOTS; i++) { + AppendFiles appendFiles = table().newFastAppend(); + for (int j = 0; j < NUM_FILES; j++) { + String path = String.format("%s/path/to/data-%d-%d.parquet", location, i, j); + validAndOrphanPaths.add(path); + DataFile dataFile = + DataFiles.builder(partitionSpec) + .withPath(path) + .withFileSizeInBytes(10) + .withRecordCount(1) + .build(); + appendFiles.appendFile(dataFile); + } + appendFiles.commit(); + } + } + + private void addOrphans() { + String location = table.location(); + // Generate 10% orphan files + int orphanFileCount = (NUM_FILES * NUM_SNAPSHOTS) / 10; + for (int i = 0; i < orphanFileCount; i++) { + validAndOrphanPaths.add( + String.format("%s/path/to/data-%s.parquet", location, UUID.randomUUID())); + } + } + + private Table table() { + if (table == null) { + try { + table = Spark3Util.loadIcebergTable(spark, TABLE_NAME); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + return table; + } + + private String catalogWarehouse() { + return Files.createTempDir().getAbsolutePath() + "/" + UUID.randomUUID() + "/"; + } + + private void setupSpark() { + SparkSession.Builder builder = + SparkSession.builder() + .config("spark.sql.catalog.spark_catalog", SparkSessionCatalog.class.getName()) + .config("spark.sql.catalog.spark_catalog.type", "hadoop") + .config("spark.sql.catalog.spark_catalog.warehouse", catalogWarehouse()) + .master("local"); + spark = builder.getOrCreate(); + } + + private void tearDownSpark() { + spark.stop(); + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/action/IcebergSortCompactionBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/action/IcebergSortCompactionBenchmark.java new file mode 100644 index 000000000000..eaef8e0bccaa --- /dev/null +++ 
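For context on the action exercised by DeleteOrphanFilesBenchmark above, here is a minimal sketch of the same compareToFileList-based call outside the JMH harness. It is illustrative only and not part of this patch: the SparkSession setup, the table name "db.events", the example file path, and the class name DeleteOrphanFilesExample are assumptions; the action, its builder methods, and the expected file_path/last_modified columns are taken from the benchmark code itself.

// Illustrative sketch only: runs the delete-orphan-files action against a supplied
// file listing, mirroring what DeleteOrphanFilesBenchmark measures.
import static org.apache.spark.sql.functions.lit;

import java.sql.Timestamp;
import java.util.Arrays;
import java.util.List;
import org.apache.iceberg.Table;
import org.apache.iceberg.actions.DeleteOrphanFiles;
import org.apache.iceberg.spark.Spark3Util;
import org.apache.iceberg.spark.actions.SparkActions;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class DeleteOrphanFilesExample {
  public static void main(String[] args) throws Exception {
    SparkSession spark = SparkSession.builder().master("local").getOrCreate();
    Table table = Spark3Util.loadIcebergTable(spark, "db.events"); // hypothetical table name

    // compareToFileList expects a file_path column and a last_modified timestamp column.
    List<String> paths = Arrays.asList("/warehouse/db/events/data/part-00000.parquet");
    Dataset<Row> fileList =
        spark
            .createDataset(paths, Encoders.STRING())
            .withColumnRenamed("value", "file_path")
            .withColumn("last_modified", lit(new Timestamp(System.currentTimeMillis())));

    DeleteOrphanFiles.Result result =
        SparkActions.get(spark)
            .deleteOrphanFiles(table)
            .compareToFileList(fileList)
            .execute();

    // Paths reported as orphans (present in the listing but unreferenced by table metadata).
    result.orphanFileLocations().forEach(System.out::println);
  }
}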
b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/action/IcebergSortCompactionBenchmark.java @@ -0,0 +1,398 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.action; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +import java.io.IOException; +import java.util.Collections; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.NullOrder; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortDirection; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.BinPackStrategy; +import org.apache.iceberg.relocated.com.google.common.io.Files; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.DataTypes; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Timeout; + +@Fork(1) +@State(Scope.Benchmark) +@Measurement(iterations = 10) +@BenchmarkMode(Mode.SingleShotTime) +@Timeout(time = 1000, timeUnit = TimeUnit.HOURS) +public class IcebergSortCompactionBenchmark { + + private static final String[] NAMESPACE = new String[] {"default"}; + private static final String NAME = "sortbench"; + private static final Identifier IDENT = Identifier.of(NAMESPACE, NAME); + private static final int NUM_FILES = 8; + private static final long NUM_ROWS = 7500000L; + private static final long UNIQUE_VALUES = NUM_ROWS / 4; + + private final Configuration hadoopConf = initHadoopConf(); + private SparkSession 
spark; + + @Setup + public void setupBench() { + setupSpark(); + } + + @TearDown + public void teardownBench() { + tearDownSpark(); + } + + @Setup(Level.Iteration) + public void setupIteration() { + initTable(); + appendData(); + } + + @TearDown(Level.Iteration) + public void cleanUpIteration() throws IOException { + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void sortInt() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(BinPackStrategy.REWRITE_ALL, "true") + .sort( + SortOrder.builderFor(table().schema()) + .sortBy("intCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void sortInt2() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(BinPackStrategy.REWRITE_ALL, "true") + .sort( + SortOrder.builderFor(table().schema()) + .sortBy("intCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol2", SortDirection.ASC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void sortInt3() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(BinPackStrategy.REWRITE_ALL, "true") + .sort( + SortOrder.builderFor(table().schema()) + .sortBy("intCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol2", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol3", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol4", SortDirection.ASC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void sortInt4() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(BinPackStrategy.REWRITE_ALL, "true") + .sort( + SortOrder.builderFor(table().schema()) + .sortBy("intCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol2", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol3", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol4", SortDirection.ASC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void sortString() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(BinPackStrategy.REWRITE_ALL, "true") + .sort( + SortOrder.builderFor(table().schema()) + .sortBy("stringCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void sortFourColumns() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(BinPackStrategy.REWRITE_ALL, "true") + .sort( + SortOrder.builderFor(table().schema()) + .sortBy("stringCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("dateCol", SortDirection.DESC, NullOrder.NULLS_FIRST) + .sortBy("doubleCol", SortDirection.DESC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void sortSixColumns() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(BinPackStrategy.REWRITE_ALL, "true") + .sort( + SortOrder.builderFor(table().schema()) + .sortBy("stringCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("intCol", SortDirection.ASC, NullOrder.NULLS_FIRST) + .sortBy("dateCol", SortDirection.DESC, NullOrder.NULLS_FIRST) + .sortBy("timestampCol", SortDirection.DESC, NullOrder.NULLS_FIRST) + .sortBy("doubleCol", SortDirection.DESC, NullOrder.NULLS_FIRST) + .sortBy("longCol", SortDirection.DESC, NullOrder.NULLS_FIRST) + .build()) + .execute(); + } + + @Benchmark + @Threads(1) + public void zSortInt() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(BinPackStrategy.REWRITE_ALL, "true") + 
.zOrder("intCol") + .execute(); + } + + @Benchmark + @Threads(1) + public void zSortInt2() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(BinPackStrategy.REWRITE_ALL, "true") + .zOrder("intCol", "intCol2") + .execute(); + } + + @Benchmark + @Threads(1) + public void zSortInt3() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(BinPackStrategy.REWRITE_ALL, "true") + .zOrder("intCol", "intCol2", "intCol3") + .execute(); + } + + @Benchmark + @Threads(1) + public void zSortInt4() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(BinPackStrategy.REWRITE_ALL, "true") + .zOrder("intCol", "intCol2", "intCol3", "intCol4") + .execute(); + } + + @Benchmark + @Threads(1) + public void zSortString() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(BinPackStrategy.REWRITE_ALL, "true") + .zOrder("stringCol") + .execute(); + } + + @Benchmark + @Threads(1) + public void zSortFourColumns() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(BinPackStrategy.REWRITE_ALL, "true") + .zOrder("stringCol", "intCol", "dateCol", "doubleCol") + .execute(); + } + + @Benchmark + @Threads(1) + public void zSortSixColumns() { + SparkActions.get() + .rewriteDataFiles(table()) + .option(BinPackStrategy.REWRITE_ALL, "true") + .zOrder("stringCol", "intCol", "dateCol", "timestampCol", "doubleCol", "longCol") + .execute(); + } + + protected Configuration initHadoopConf() { + return new Configuration(); + } + + protected final void initTable() { + Schema schema = + new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "intCol2", Types.IntegerType.get()), + required(4, "intCol3", Types.IntegerType.get()), + required(5, "intCol4", Types.IntegerType.get()), + required(6, "floatCol", Types.FloatType.get()), + optional(7, "doubleCol", Types.DoubleType.get()), + optional(8, "dateCol", Types.DateType.get()), + optional(9, "timestampCol", Types.TimestampType.withZone()), + optional(10, "stringCol", Types.StringType.get())); + + SparkSessionCatalog catalog; + try { + catalog = + (SparkSessionCatalog) Spark3Util.catalogAndIdentifier(spark(), "spark_catalog").catalog(); + catalog.dropTable(IDENT); + catalog.createTable( + IDENT, SparkSchemaUtil.convert(schema), new Transform[0], Collections.emptyMap()); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private void appendData() { + Dataset df = + spark() + .range(0, NUM_ROWS * NUM_FILES, 1, NUM_FILES) + .drop("id") + .withColumn("longCol", new RandomGeneratingUDF(UNIQUE_VALUES).randomLongUDF().apply()) + .withColumn( + "intCol", + new RandomGeneratingUDF(UNIQUE_VALUES) + .randomLongUDF() + .apply() + .cast(DataTypes.IntegerType)) + .withColumn( + "intCol2", + new RandomGeneratingUDF(UNIQUE_VALUES) + .randomLongUDF() + .apply() + .cast(DataTypes.IntegerType)) + .withColumn( + "intCol3", + new RandomGeneratingUDF(UNIQUE_VALUES) + .randomLongUDF() + .apply() + .cast(DataTypes.IntegerType)) + .withColumn( + "intCol4", + new RandomGeneratingUDF(UNIQUE_VALUES) + .randomLongUDF() + .apply() + .cast(DataTypes.IntegerType)) + .withColumn( + "floatCol", + new RandomGeneratingUDF(UNIQUE_VALUES) + .randomLongUDF() + .apply() + .cast(DataTypes.FloatType)) + .withColumn( + "doubleCol", + new RandomGeneratingUDF(UNIQUE_VALUES) + .randomLongUDF() + .apply() + .cast(DataTypes.DoubleType)) + .withColumn("dateCol", date_add(current_date(), col("intCol").mod(NUM_FILES))) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", 
new RandomGeneratingUDF(UNIQUE_VALUES).randomString().apply()); + writeData(df); + } + + private void writeData(Dataset df) { + df.write().format("iceberg").mode(SaveMode.Append).save(NAME); + } + + protected final Table table() { + try { + return Spark3Util.loadIcebergTable(spark(), NAME); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + protected final SparkSession spark() { + return spark; + } + + protected String getCatalogWarehouse() { + String location = Files.createTempDir().getAbsolutePath() + "/" + UUID.randomUUID() + "/"; + return location; + } + + protected void cleanupFiles() throws IOException { + spark.sql("DROP TABLE IF EXISTS " + NAME); + } + + protected void setupSpark() { + SparkSession.Builder builder = + SparkSession.builder() + .config( + "spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog") + .config("spark.sql.catalog.spark_catalog.type", "hadoop") + .config("spark.sql.catalog.spark_catalog.warehouse", getCatalogWarehouse()) + .master("local[*]"); + spark = builder.getOrCreate(); + Configuration sparkHadoopConf = spark.sessionState().newHadoopConf(); + hadoopConf.forEach(entry -> sparkHadoopConf.set(entry.getKey(), entry.getValue())); + } + + protected void tearDownSpark() { + spark.stop(); + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/action/RandomGeneratingUDF.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/action/RandomGeneratingUDF.java new file mode 100644 index 000000000000..63d24f7da553 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/action/RandomGeneratingUDF.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.action; + +import static org.apache.spark.sql.functions.udf; + +import java.io.Serializable; +import java.util.Random; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.RandomUtil; +import org.apache.spark.sql.expressions.UserDefinedFunction; +import org.apache.spark.sql.types.DataTypes; + +class RandomGeneratingUDF implements Serializable { + private final long uniqueValues; + private Random rand = new Random(); + + RandomGeneratingUDF(long uniqueValues) { + this.uniqueValues = uniqueValues; + } + + UserDefinedFunction randomLongUDF() { + return udf(() -> rand.nextLong() % (uniqueValues / 2), DataTypes.LongType) + .asNondeterministic() + .asNonNullable(); + } + + UserDefinedFunction randomString() { + return udf( + () -> (String) RandomUtil.generatePrimitive(Types.StringType.get(), rand), + DataTypes.StringType) + .asNondeterministic() + .asNonNullable(); + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetReadersFlatDataBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetReadersFlatDataBenchmark.java new file mode 100644 index 000000000000..63f111a37d62 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetReadersFlatDataBenchmark.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data.parquet; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.common.DynMethods; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.SparkBenchmarkUtil; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkParquetReaders; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; +import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport; +import org.apache.spark.sql.types.StructType; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +/** + * A benchmark that evaluates the performance of reading Parquet data with a flat schema using + * Iceberg and Spark Parquet readers. + * + *
To run this benchmark for spark-3.3: + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=SparkParquetReadersFlatDataBenchmark + * -PjmhOutputPath=benchmark/spark-parquet-readers-flat-data-benchmark-result.txt + * + */ +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.SingleShotTime) +public class SparkParquetReadersFlatDataBenchmark { + + private static final DynMethods.UnboundMethod APPLY_PROJECTION = + DynMethods.builder("apply").impl(UnsafeProjection.class, InternalRow.class).build(); + private static final Schema SCHEMA = + new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "decimalCol", Types.DecimalType.of(20, 5)), + optional(6, "dateCol", Types.DateType.get()), + optional(7, "timestampCol", Types.TimestampType.withZone()), + optional(8, "stringCol", Types.StringType.get())); + private static final Schema PROJECTED_SCHEMA = + new Schema( + required(1, "longCol", Types.LongType.get()), + optional(5, "decimalCol", Types.DecimalType.of(20, 5)), + optional(8, "stringCol", Types.StringType.get())); + private static final int NUM_RECORDS = 1000000; + private File dataFile; + + @Setup + public void setupBenchmark() throws IOException { + dataFile = File.createTempFile("parquet-flat-data-benchmark", ".parquet"); + dataFile.delete(); + List records = RandomData.generateList(SCHEMA, NUM_RECORDS, 0L); + try (FileAppender writer = + Parquet.write(Files.localOutput(dataFile)).schema(SCHEMA).named("benchmark").build()) { + writer.addAll(records); + } + } + + @TearDown + public void tearDownBenchmark() { + if (dataFile != null) { + dataFile.delete(); + } + } + + @Benchmark + @Threads(1) + public void readUsingIcebergReader(Blackhole blackHole) throws IOException { + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(SCHEMA, type)) + .build()) { + + for (InternalRow row : rows) { + blackHole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readUsingIcebergReaderUnsafe(Blackhole blackhole) throws IOException { + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(SCHEMA, type)) + .build()) { + + Iterable unsafeRows = + Iterables.transform( + rows, APPLY_PROJECTION.bind(SparkBenchmarkUtil.projection(SCHEMA, SCHEMA))::invoke); + + for (InternalRow row : unsafeRows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readUsingSparkReader(Blackhole blackhole) throws IOException { + StructType sparkSchema = SparkSchemaUtil.convert(SCHEMA); + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(SCHEMA) + .readSupport(new ParquetReadSupport()) + .set("org.apache.spark.sql.parquet.row.requested_schema", sparkSchema.json()) + .set("spark.sql.parquet.binaryAsString", "false") + .set("spark.sql.parquet.int96AsTimestamp", "false") + .set("spark.sql.caseSensitive", "false") + .set("spark.sql.parquet.fieldId.write.enabled", "false") + .callInit() + .build()) { + + for (InternalRow row : rows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readWithProjectionUsingIcebergReader(Blackhole blackhole) throws IOException { + try 
(CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(PROJECTED_SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(PROJECTED_SCHEMA, type)) + .build()) { + + for (InternalRow row : rows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readWithProjectionUsingIcebergReaderUnsafe(Blackhole blackhole) throws IOException { + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(PROJECTED_SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(PROJECTED_SCHEMA, type)) + .build()) { + + Iterable unsafeRows = + Iterables.transform( + rows, + APPLY_PROJECTION.bind( + SparkBenchmarkUtil.projection(PROJECTED_SCHEMA, PROJECTED_SCHEMA)) + ::invoke); + + for (InternalRow row : unsafeRows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readWithProjectionUsingSparkReader(Blackhole blackhole) throws IOException { + StructType sparkSchema = SparkSchemaUtil.convert(PROJECTED_SCHEMA); + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(PROJECTED_SCHEMA) + .readSupport(new ParquetReadSupport()) + .set("org.apache.spark.sql.parquet.row.requested_schema", sparkSchema.json()) + .set("spark.sql.parquet.binaryAsString", "false") + .set("spark.sql.parquet.int96AsTimestamp", "false") + .set("spark.sql.caseSensitive", "false") + .callInit() + .build()) { + + for (InternalRow row : rows) { + blackhole.consume(row); + } + } + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetReadersNestedDataBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetReadersNestedDataBenchmark.java new file mode 100644 index 000000000000..7a47d7ca53b9 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetReadersNestedDataBenchmark.java @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data.parquet; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.common.DynMethods; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.SparkBenchmarkUtil; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkParquetReaders; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; +import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport; +import org.apache.spark.sql.types.StructType; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +/** + * A benchmark that evaluates the performance of reading nested Parquet data using Iceberg and Spark + * Parquet readers. + * + *
To run this benchmark for spark-3.3: + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=SparkParquetReadersNestedDataBenchmark + * -PjmhOutputPath=benchmark/spark-parquet-readers-nested-data-benchmark-result.txt + * + */ +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.SingleShotTime) +public class SparkParquetReadersNestedDataBenchmark { + + private static final DynMethods.UnboundMethod APPLY_PROJECTION = + DynMethods.builder("apply").impl(UnsafeProjection.class, InternalRow.class).build(); + private static final Schema SCHEMA = + new Schema( + required(0, "id", Types.LongType.get()), + optional( + 4, + "nested", + Types.StructType.of( + required(1, "col1", Types.StringType.get()), + required(2, "col2", Types.DoubleType.get()), + required(3, "col3", Types.LongType.get())))); + private static final Schema PROJECTED_SCHEMA = + new Schema( + optional(4, "nested", Types.StructType.of(required(1, "col1", Types.StringType.get())))); + private static final int NUM_RECORDS = 1000000; + private File dataFile; + + @Setup + public void setupBenchmark() throws IOException { + dataFile = File.createTempFile("parquet-nested-data-benchmark", ".parquet"); + dataFile.delete(); + List records = RandomData.generateList(SCHEMA, NUM_RECORDS, 0L); + try (FileAppender writer = + Parquet.write(Files.localOutput(dataFile)).schema(SCHEMA).named("benchmark").build()) { + writer.addAll(records); + } + } + + @TearDown + public void tearDownBenchmark() { + if (dataFile != null) { + dataFile.delete(); + } + } + + @Benchmark + @Threads(1) + public void readUsingIcebergReader(Blackhole blackhole) throws IOException { + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(SCHEMA, type)) + .build()) { + + for (InternalRow row : rows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readUsingIcebergReaderUnsafe(Blackhole blackhole) throws IOException { + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(SCHEMA, type)) + .build()) { + + Iterable unsafeRows = + Iterables.transform( + rows, APPLY_PROJECTION.bind(SparkBenchmarkUtil.projection(SCHEMA, SCHEMA))::invoke); + + for (InternalRow row : unsafeRows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readUsingSparkReader(Blackhole blackhole) throws IOException { + StructType sparkSchema = SparkSchemaUtil.convert(SCHEMA); + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(SCHEMA) + .readSupport(new ParquetReadSupport()) + .set("org.apache.spark.sql.parquet.row.requested_schema", sparkSchema.json()) + .set("spark.sql.parquet.binaryAsString", "false") + .set("spark.sql.parquet.int96AsTimestamp", "false") + .set("spark.sql.caseSensitive", "false") + .set("spark.sql.parquet.fieldId.write.enabled", "false") + .callInit() + .build()) { + + for (InternalRow row : rows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readWithProjectionUsingIcebergReader(Blackhole blackhole) throws IOException { + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(PROJECTED_SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(PROJECTED_SCHEMA, type)) + .build()) { + + for (InternalRow row : rows) { + 
blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readWithProjectionUsingIcebergReaderUnsafe(Blackhole blackhole) throws IOException { + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(PROJECTED_SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(PROJECTED_SCHEMA, type)) + .build()) { + + Iterable unsafeRows = + Iterables.transform( + rows, + APPLY_PROJECTION.bind( + SparkBenchmarkUtil.projection(PROJECTED_SCHEMA, PROJECTED_SCHEMA)) + ::invoke); + + for (InternalRow row : unsafeRows) { + blackhole.consume(row); + } + } + } + + @Benchmark + @Threads(1) + public void readWithProjectionUsingSparkReader(Blackhole blackhole) throws IOException { + StructType sparkSchema = SparkSchemaUtil.convert(PROJECTED_SCHEMA); + try (CloseableIterable rows = + Parquet.read(Files.localInput(dataFile)) + .project(PROJECTED_SCHEMA) + .readSupport(new ParquetReadSupport()) + .set("org.apache.spark.sql.parquet.row.requested_schema", sparkSchema.json()) + .set("spark.sql.parquet.binaryAsString", "false") + .set("spark.sql.parquet.int96AsTimestamp", "false") + .set("spark.sql.caseSensitive", "false") + .callInit() + .build()) { + + for (InternalRow row : rows) { + blackhole.consume(row); + } + } + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersFlatDataBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersFlatDataBenchmark.java new file mode 100644 index 000000000000..f104b8b88b36 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersFlatDataBenchmark.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data.parquet; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkParquetWriters; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.datasources.parquet.ParquetWriteSupport; +import org.apache.spark.sql.types.StructType; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; + +/** + * A benchmark that evaluates the performance of writing Parquet data with a flat schema using + * Iceberg and Spark Parquet writers. + * + *
To run this benchmark for spark-3.3: + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=SparkParquetWritersFlatDataBenchmark + * -PjmhOutputPath=benchmark/spark-parquet-writers-flat-data-benchmark-result.txt + * + */ +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.SingleShotTime) +public class SparkParquetWritersFlatDataBenchmark { + + private static final Schema SCHEMA = + new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "decimalCol", Types.DecimalType.of(20, 5)), + optional(6, "dateCol", Types.DateType.get()), + optional(7, "timestampCol", Types.TimestampType.withZone()), + optional(8, "stringCol", Types.StringType.get())); + private static final int NUM_RECORDS = 1000000; + private Iterable rows; + private File dataFile; + + @Setup + public void setupBenchmark() throws IOException { + rows = RandomData.generateSpark(SCHEMA, NUM_RECORDS, 0L); + dataFile = File.createTempFile("parquet-flat-data-benchmark", ".parquet"); + dataFile.delete(); + } + + @TearDown(Level.Iteration) + public void tearDownBenchmark() { + if (dataFile != null) { + dataFile.delete(); + } + } + + @Benchmark + @Threads(1) + public void writeUsingIcebergWriter() throws IOException { + try (FileAppender writer = + Parquet.write(Files.localOutput(dataFile)) + .createWriterFunc( + msgType -> + SparkParquetWriters.buildWriter(SparkSchemaUtil.convert(SCHEMA), msgType)) + .schema(SCHEMA) + .build()) { + + writer.addAll(rows); + } + } + + @Benchmark + @Threads(1) + public void writeUsingSparkWriter() throws IOException { + StructType sparkSchema = SparkSchemaUtil.convert(SCHEMA); + try (FileAppender writer = + Parquet.write(Files.localOutput(dataFile)) + .writeSupport(new ParquetWriteSupport()) + .set("org.apache.spark.sql.parquet.row.attributes", sparkSchema.json()) + .set("spark.sql.parquet.writeLegacyFormat", "false") + .set("spark.sql.parquet.binaryAsString", "false") + .set("spark.sql.parquet.int96AsTimestamp", "false") + .set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS") + .set("spark.sql.caseSensitive", "false") + .set("spark.sql.parquet.fieldId.write.enabled", "false") + .schema(SCHEMA) + .build()) { + + writer.addAll(rows); + } + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersNestedDataBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersNestedDataBenchmark.java new file mode 100644 index 000000000000..e375d1c56a6f --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersNestedDataBenchmark.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data.parquet; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkParquetWriters; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.datasources.parquet.ParquetWriteSupport; +import org.apache.spark.sql.types.StructType; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; + +/** + * A benchmark that evaluates the performance of writing nested Parquet data using Iceberg and Spark + * Parquet writers. + * + *
To run this benchmark for spark-3.3: + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh + * -PjmhIncludeRegex=SparkParquetWritersNestedDataBenchmark + * -PjmhOutputPath=benchmark/spark-parquet-writers-nested-data-benchmark-result.txt + * + */ +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.SingleShotTime) +public class SparkParquetWritersNestedDataBenchmark { + + private static final Schema SCHEMA = + new Schema( + required(0, "id", Types.LongType.get()), + optional( + 4, + "nested", + Types.StructType.of( + required(1, "col1", Types.StringType.get()), + required(2, "col2", Types.DoubleType.get()), + required(3, "col3", Types.LongType.get())))); + private static final int NUM_RECORDS = 1000000; + private Iterable rows; + private File dataFile; + + @Setup + public void setupBenchmark() throws IOException { + rows = RandomData.generateSpark(SCHEMA, NUM_RECORDS, 0L); + dataFile = File.createTempFile("parquet-nested-data-benchmark", ".parquet"); + dataFile.delete(); + } + + @TearDown(Level.Iteration) + public void tearDownBenchmark() { + if (dataFile != null) { + dataFile.delete(); + } + } + + @Benchmark + @Threads(1) + public void writeUsingIcebergWriter() throws IOException { + try (FileAppender writer = + Parquet.write(Files.localOutput(dataFile)) + .createWriterFunc( + msgType -> + SparkParquetWriters.buildWriter(SparkSchemaUtil.convert(SCHEMA), msgType)) + .schema(SCHEMA) + .build()) { + + writer.addAll(rows); + } + } + + @Benchmark + @Threads(1) + public void writeUsingSparkWriter() throws IOException { + StructType sparkSchema = SparkSchemaUtil.convert(SCHEMA); + try (FileAppender writer = + Parquet.write(Files.localOutput(dataFile)) + .writeSupport(new ParquetWriteSupport()) + .set("org.apache.spark.sql.parquet.row.attributes", sparkSchema.json()) + .set("spark.sql.parquet.writeLegacyFormat", "false") + .set("spark.sql.parquet.binaryAsString", "false") + .set("spark.sql.parquet.int96AsTimestamp", "false") + .set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS") + .set("spark.sql.caseSensitive", "false") + .set("spark.sql.parquet.fieldId.write.enabled", "false") + .schema(SCHEMA) + .build()) { + + writer.addAll(rows); + } + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/Action.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/Action.java new file mode 100644 index 000000000000..0dbf07285060 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/Action.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +@FunctionalInterface +public interface Action { + void invoke(); +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceBenchmark.java new file mode 100644 index 000000000000..68c537e34a4a --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceBenchmark.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.Map; +import java.util.UUID; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.UpdateProperties; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.StructType; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.SingleShotTime) +public abstract class IcebergSourceBenchmark { + + private final Configuration hadoopConf = initHadoopConf(); + private final Table table = initTable(); + private SparkSession spark; + + protected abstract Configuration initHadoopConf(); + + protected final Configuration hadoopConf() { + return hadoopConf; + } + + protected abstract Table initTable(); + + protected final Table table() { + return table; + } + + protected final SparkSession spark() { + return spark; + } + + protected String newTableLocation() { + String tmpDir = hadoopConf.get("hadoop.tmp.dir"); + Path tablePath = new Path(tmpDir, "spark-iceberg-table-" + UUID.randomUUID()); + return tablePath.toString(); + } + + protected String dataLocation() { + Map properties = table.properties(); + return properties.getOrDefault( + TableProperties.WRITE_DATA_LOCATION, String.format("%s/data", table.location())); + } + + protected void cleanupFiles() throws IOException { + try (FileSystem fileSystem = 
FileSystem.get(hadoopConf)) { + Path dataPath = new Path(dataLocation()); + fileSystem.delete(dataPath, true); + Path tablePath = new Path(table.location()); + fileSystem.delete(tablePath, true); + } + } + + protected void setupSpark(boolean enableDictionaryEncoding) { + SparkSession.Builder builder = SparkSession.builder().config("spark.ui.enabled", false); + if (!enableDictionaryEncoding) { + builder + .config("parquet.dictionary.page.size", "1") + .config("parquet.enable.dictionary", false) + .config(TableProperties.PARQUET_DICT_SIZE_BYTES, "1"); + } + builder.master("local"); + spark = builder.getOrCreate(); + Configuration sparkHadoopConf = spark.sessionState().newHadoopConf(); + hadoopConf.forEach(entry -> sparkHadoopConf.set(entry.getKey(), entry.getValue())); + } + + protected void setupSpark() { + setupSpark(false); + } + + protected void tearDownSpark() { + spark.stop(); + } + + protected void materialize(Dataset ds) { + ds.queryExecution().toRdd().toJavaRDD().foreach(record -> {}); + } + + protected void materialize(Dataset ds, Blackhole blackhole) { + blackhole.consume(ds.queryExecution().toRdd().toJavaRDD().count()); + } + + protected void appendAsFile(Dataset ds) { + // ensure the schema is precise (including nullability) + StructType sparkSchema = SparkSchemaUtil.convert(table.schema()); + spark + .createDataFrame(ds.rdd(), sparkSchema) + .coalesce(1) + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(table.location()); + } + + protected void withSQLConf(Map conf, Action action) { + SQLConf sqlConf = SQLConf.get(); + + Map currentConfValues = Maps.newHashMap(); + conf.keySet() + .forEach( + confKey -> { + if (sqlConf.contains(confKey)) { + String currentConfValue = sqlConf.getConfString(confKey); + currentConfValues.put(confKey, currentConfValue); + } + }); + + conf.forEach( + (confKey, confValue) -> { + if (SQLConf.isStaticConfigKey(confKey)) { + throw new RuntimeException("Cannot modify the value of a static config: " + confKey); + } + sqlConf.setConfString(confKey, confValue); + }); + + try { + action.invoke(); + } finally { + conf.forEach( + (confKey, confValue) -> { + if (currentConfValues.containsKey(confKey)) { + sqlConf.setConfString(confKey, currentConfValues.get(confKey)); + } else { + sqlConf.unsetConf(confKey); + } + }); + } + } + + protected void withTableProperties(Map props, Action action) { + Map tableProps = table.properties(); + Map currentPropValues = Maps.newHashMap(); + props + .keySet() + .forEach( + propKey -> { + if (tableProps.containsKey(propKey)) { + String currentPropValue = tableProps.get(propKey); + currentPropValues.put(propKey, currentPropValue); + } + }); + + UpdateProperties updateProperties = table.updateProperties(); + props.forEach(updateProperties::set); + updateProperties.commit(); + + try { + action.invoke(); + } finally { + UpdateProperties restoreProperties = table.updateProperties(); + props.forEach( + (propKey, propValue) -> { + if (currentPropValues.containsKey(propKey)) { + restoreProperties.set(propKey, currentPropValues.get(propKey)); + } else { + restoreProperties.remove(propKey); + } + }); + restoreProperties.commit(); + } + } + + protected FileFormat fileFormat() { + throw new UnsupportedOperationException("Unsupported file format"); + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceDeleteBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceDeleteBenchmark.java new file mode 100644 index 000000000000..e42707bf102b --- /dev/null 
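IcebergSourceBenchmark above is the abstract base that the remaining source benchmarks extend: it leaves initHadoopConf() and initTable() to subclasses and supplies helpers such as newTableLocation(), withSQLConf(), and withTableProperties(). As a hedged illustration of that contract (not code from this change), a minimal subclass might look like the sketch below; the class name, single-column schema, and HadoopTables-based table creation are assumptions chosen to match imports already used in this patch.

// Hypothetical subclass sketch: shows only the abstract contract of IcebergSourceBenchmark,
// without any @Benchmark methods of its own.
import static org.apache.iceberg.types.Types.NestedField.required;

import org.apache.hadoop.conf.Configuration;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Table;
import org.apache.iceberg.hadoop.HadoopTables;
import org.apache.iceberg.spark.source.IcebergSourceBenchmark;
import org.apache.iceberg.types.Types;

public class ExampleIcebergSourceBenchmark extends IcebergSourceBenchmark {

  @Override
  protected Configuration initHadoopConf() {
    return new Configuration();
  }

  @Override
  protected Table initTable() {
    // Create a throwaway unpartitioned table under the temp location provided by the base class.
    Schema schema = new Schema(required(1, "id", Types.LongType.get()));
    HadoopTables tables = new HadoopTables(hadoopConf());
    return tables.create(schema, PartitionSpec.unpartitioned(), newTableLocation());
  }
}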
+++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceDeleteBenchmark.java @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.TableProperties.PARQUET_VECTORIZATION_ENABLED; +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.RowDelta; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.ClusteredEqualityDeleteWriter; +import org.apache.iceberg.io.ClusteredPositionDeleteWriter; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.infra.Blackhole; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public abstract class IcebergSourceDeleteBenchmark extends IcebergSourceBenchmark { + private static final Logger LOG = LoggerFactory.getLogger(IcebergSourceDeleteBenchmark.class); + private static final long TARGET_FILE_SIZE_IN_BYTES = 512L * 1024 * 1024; + + protected static final int NUM_FILES = 1; + protected static final int NUM_ROWS = 10 * 1000 * 1000; + + @Setup + public void setupBenchmark() throws IOException { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void 
readIceberg(Blackhole blackhole) { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + tableProperties.put(PARQUET_VECTORIZATION_ENABLED, "false"); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df, blackhole); + }); + } + + @Benchmark + @Threads(1) + public void readIcebergWithIsDeletedColumn(Blackhole blackhole) { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + tableProperties.put(PARQUET_VECTORIZATION_ENABLED, "false"); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).filter("_deleted = false"); + materialize(df, blackhole); + }); + } + + @Benchmark + @Threads(1) + public void readDeletedRows(Blackhole blackhole) { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + tableProperties.put(PARQUET_VECTORIZATION_ENABLED, "false"); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).filter("_deleted = true"); + materialize(df, blackhole); + }); + } + + @Benchmark + @Threads(1) + public void readIcebergVectorized(Blackhole blackhole) { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + tableProperties.put(PARQUET_VECTORIZATION_ENABLED, "true"); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df, blackhole); + }); + } + + @Benchmark + @Threads(1) + public void readIcebergWithIsDeletedColumnVectorized(Blackhole blackhole) { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + tableProperties.put(PARQUET_VECTORIZATION_ENABLED, "true"); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).filter("_deleted = false"); + materialize(df, blackhole); + }); + } + + @Benchmark + @Threads(1) + public void readDeletedRowsVectorized(Blackhole blackhole) { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + tableProperties.put(PARQUET_VECTORIZATION_ENABLED, "true"); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).filter("_deleted = true"); + materialize(df, blackhole); + }); + } + + protected abstract void appendData() throws IOException; + + protected void writeData(int fileNum) { + Dataset df = + spark() + .range(NUM_ROWS) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(MOD(longCol, 2147483647) AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("dateCol", date_add(current_date(), fileNum)) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")); + appendAsFile(df); 
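+ // appendAsFile() coalesces the DataFrame to a single task, so each writeData() call is written out by one writer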
+ } + + @Override + protected Table initTable() { + Schema schema = + new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(6, "dateCol", Types.DateType.get()), + optional(7, "timestampCol", Types.TimestampType.withZone()), + optional(8, "stringCol", Types.StringType.get())); + PartitionSpec partitionSpec = PartitionSpec.unpartitioned(); + HadoopTables tables = new HadoopTables(hadoopConf()); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + properties.put(TableProperties.FORMAT_VERSION, "2"); + return tables.create(schema, partitionSpec, properties, newTableLocation()); + } + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + protected void writePosDeletes(CharSequence path, long numRows, double percentage) + throws IOException { + writePosDeletes(path, numRows, percentage, 1); + } + + protected void writePosDeletes( + CharSequence path, long numRows, double percentage, int numDeleteFile) throws IOException { + writePosDeletesWithNoise(path, numRows, percentage, 0, numDeleteFile); + } + + protected void writePosDeletesWithNoise( + CharSequence path, long numRows, double percentage, int numNoise, int numDeleteFile) + throws IOException { + Set deletedPos = Sets.newHashSet(); + while (deletedPos.size() < numRows * percentage) { + deletedPos.add(ThreadLocalRandom.current().nextLong(numRows)); + } + LOG.info("pos delete row count: {}, num of delete files: {}", deletedPos.size(), numDeleteFile); + + int partitionSize = (int) (numRows * percentage) / numDeleteFile; + Iterable> sets = Iterables.partition(deletedPos, partitionSize); + for (List item : sets) { + writePosDeletes(path, item, numNoise); + } + } + + protected void writePosDeletes(CharSequence path, List deletedPos, int numNoise) + throws IOException { + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = + SparkFileWriterFactory.builderFor(table()).dataFileFormat(fileFormat()).build(); + + ClusteredPositionDeleteWriter writer = + new ClusteredPositionDeleteWriter<>( + writerFactory, fileFactory, table().io(), TARGET_FILE_SIZE_IN_BYTES); + + PartitionSpec unpartitionedSpec = table().specs().get(0); + + PositionDelete positionDelete = PositionDelete.create(); + try (ClusteredPositionDeleteWriter closeableWriter = writer) { + for (Long pos : deletedPos) { + positionDelete.set(path, pos, null); + closeableWriter.write(positionDelete, unpartitionedSpec, null); + for (int i = 0; i < numNoise; i++) { + positionDelete.set(noisePath(path), pos, null); + closeableWriter.write(positionDelete, unpartitionedSpec, null); + } + } + } + + RowDelta rowDelta = table().newRowDelta(); + writer.result().deleteFiles().forEach(rowDelta::addDeletes); + rowDelta.validateDeletedFiles().commit(); + } + + protected void writeEqDeletes(long numRows, double percentage) throws IOException { + Set deletedValues = Sets.newHashSet(); + while (deletedValues.size() < numRows * percentage) { + deletedValues.add(ThreadLocalRandom.current().nextLong(numRows)); + } + + List rows = Lists.newArrayList(); + for (Long value : deletedValues) { + GenericInternalRow genericInternalRow = new GenericInternalRow(7); + genericInternalRow.setLong(0, value); + genericInternalRow.setInt(1, (int) (value % Integer.MAX_VALUE)); + genericInternalRow.setFloat(2, (float) value); + 
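+ // only longCol is registered as an equality field (see writeEqDeletes below), so the remaining columns can stay null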
genericInternalRow.setNullAt(3); + genericInternalRow.setNullAt(4); + genericInternalRow.setNullAt(5); + genericInternalRow.setNullAt(6); + rows.add(genericInternalRow); + } + LOG.info("Num of equality deleted rows: {}", rows.size()); + + writeEqDeletes(rows); + } + + private void writeEqDeletes(List rows) throws IOException { + int equalityFieldId = table().schema().findField("longCol").fieldId(); + + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = + SparkFileWriterFactory.builderFor(table()) + .dataFileFormat(fileFormat()) + .equalityDeleteRowSchema(table().schema()) + .equalityFieldIds(new int[] {equalityFieldId}) + .build(); + + ClusteredEqualityDeleteWriter writer = + new ClusteredEqualityDeleteWriter<>( + writerFactory, fileFactory, table().io(), TARGET_FILE_SIZE_IN_BYTES); + + PartitionSpec unpartitionedSpec = table().specs().get(0); + try (ClusteredEqualityDeleteWriter closeableWriter = writer) { + for (InternalRow row : rows) { + closeableWriter.write(row, unpartitionedSpec, null); + } + } + + RowDelta rowDelta = table().newRowDelta(); + LOG.info("Num of Delete File: {}", writer.result().deleteFiles().size()); + writer.result().deleteFiles().forEach(rowDelta::addDeletes); + rowDelta.validateDeletedFiles().commit(); + } + + private OutputFileFactory newFileFactory() { + return OutputFileFactory.builderFor(table(), 1, 1).format(fileFormat()).build(); + } + + private CharSequence noisePath(CharSequence path) { + // assume the data file name would be something like + // "00000-0-30da64e0-56b5-4743-a11b-3188a1695bf7-00001.parquet" + // so the dataFileSuffixLen is the UUID string length + length of "-00001.parquet", which is 36 + // + 14 = 60. It's OK + // to be not accurate here. + int dataFileSuffixLen = 60; + UUID uuid = UUID.randomUUID(); + if (path.length() > dataFileSuffixLen) { + return path.subSequence(0, dataFileSuffixLen) + uuid.toString(); + } else { + return uuid.toString(); + } + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceFlatDataBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceFlatDataBenchmark.java new file mode 100644 index 000000000000..59e6230350d9 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceFlatDataBenchmark.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Types; + +public abstract class IcebergSourceFlatDataBenchmark extends IcebergSourceBenchmark { + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + @Override + protected final Table initTable() { + Schema schema = + new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "decimalCol", Types.DecimalType.of(20, 5)), + optional(6, "dateCol", Types.DateType.get()), + optional(7, "timestampCol", Types.TimestampType.withZone()), + optional(8, "stringCol", Types.StringType.get())); + PartitionSpec partitionSpec = PartitionSpec.unpartitioned(); + HadoopTables tables = new HadoopTables(hadoopConf()); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + return tables.create(schema, partitionSpec, properties, newTableLocation()); + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceNestedDataBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceNestedDataBenchmark.java new file mode 100644 index 000000000000..a1c61b9b4de0 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceNestedDataBenchmark.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Types; + +public abstract class IcebergSourceNestedDataBenchmark extends IcebergSourceBenchmark { + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + @Override + protected final Table initTable() { + Schema schema = + new Schema( + required(0, "id", Types.LongType.get()), + optional( + 4, + "nested", + Types.StructType.of( + required(1, "col1", Types.StringType.get()), + required(2, "col2", Types.DoubleType.get()), + required(3, "col3", Types.LongType.get())))); + PartitionSpec partitionSpec = PartitionSpec.unpartitioned(); + HadoopTables tables = new HadoopTables(hadoopConf()); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + return tables.create(schema, partitionSpec, properties, newTableLocation()); + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceNestedListDataBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceNestedListDataBenchmark.java new file mode 100644 index 000000000000..f68b587735dd --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceNestedListDataBenchmark.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Types; + +public abstract class IcebergSourceNestedListDataBenchmark extends IcebergSourceBenchmark { + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + @Override + protected final Table initTable() { + Schema schema = + new Schema( + required(0, "id", Types.LongType.get()), + optional( + 1, + "outerlist", + Types.ListType.ofOptional( + 2, + Types.StructType.of( + required( + 3, + "innerlist", + Types.ListType.ofRequired(4, Types.StringType.get())))))); + PartitionSpec partitionSpec = PartitionSpec.unpartitioned(); + HadoopTables tables = new HadoopTables(hadoopConf()); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + return tables.create(schema, partitionSpec, properties, newTableLocation()); + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/WritersBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/WritersBenchmark.java new file mode 100644 index 000000000000..13ff034e4bf5 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/WritersBenchmark.java @@ -0,0 +1,369 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.IOException; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionKey; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.ClusteredDataWriter; +import org.apache.iceberg.io.ClusteredEqualityDeleteWriter; +import org.apache.iceberg.io.ClusteredPositionDeleteWriter; +import org.apache.iceberg.io.DeleteSchemaUtil; +import org.apache.iceberg.io.FanoutDataWriter; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.io.TaskWriter; +import org.apache.iceberg.io.UnpartitionedWriter; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.transforms.Transform; +import org.apache.iceberg.transforms.Transforms; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.infra.Blackhole; + +public abstract class WritersBenchmark extends IcebergSourceBenchmark { + + private static final int NUM_ROWS = 2500000; + private static final long TARGET_FILE_SIZE_IN_BYTES = 50L * 1024 * 1024; + + private static final Schema SCHEMA = + new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "decimalCol", Types.DecimalType.of(20, 5)), + optional(6, "timestampCol", Types.TimestampType.withZone()), + optional(7, "stringCol", Types.StringType.get())); + + private Iterable rows; + private Iterable positionDeleteRows; + private PartitionSpec unpartitionedSpec; + private PartitionSpec partitionedSpec; + + @Override + protected abstract FileFormat fileFormat(); + + @Setup + public void setupBenchmark() { + setupSpark(); + + List data = Lists.newArrayList(RandomData.generateSpark(SCHEMA, NUM_ROWS, 0L)); + Transform transform = Transforms.bucket(32); + data.sort( + Comparator.comparingInt( + row -> transform.bind(Types.IntegerType.get()).apply(row.getInt(1)))); + this.rows = data; + + this.positionDeleteRows = + RandomData.generateSpark(DeleteSchemaUtil.pathPosSchema(), NUM_ROWS, 0L); + + this.unpartitionedSpec = table().specs().get(0); + Preconditions.checkArgument(unpartitionedSpec.isUnpartitioned()); + this.partitionedSpec = table().specs().get(1); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + 
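+ // Concrete benchmarks only pick the file format; a minimal sketch of a subclass
+ // (hypothetical name, mirroring the Avro variant added later in this patch):
+ //
+ //   public class OrcWritersBenchmark extends WritersBenchmark {
+ //     @Override
+ //     protected FileFormat fileFormat() {
+ //       return FileFormat.ORC;
+ //     }
+ //   }
+
+ // The table starts unpartitioned (spec 0) and a bucket(intCol, 32) spec is then added (spec 1);
+ // setupBenchmark() resolves these as unpartitionedSpec and partitionedSpec.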
@Override + protected final Table initTable() { + HadoopTables tables = new HadoopTables(hadoopConf()); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map properties = Maps.newHashMap(); + Table table = tables.create(SCHEMA, spec, properties, newTableLocation()); + + // add a partitioned spec to the table + table.updateSpec().addField(Expressions.bucket("intCol", 32)).commit(); + + return table; + } + + @Benchmark + @Threads(1) + public void writeUnpartitionedClusteredDataWriter(Blackhole blackhole) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = + SparkFileWriterFactory.builderFor(table()) + .dataFileFormat(fileFormat()) + .dataSchema(table().schema()) + .build(); + + ClusteredDataWriter writer = + new ClusteredDataWriter<>(writerFactory, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES); + + try (ClusteredDataWriter closeableWriter = writer) { + for (InternalRow row : rows) { + closeableWriter.write(row, unpartitionedSpec, null); + } + } + + blackhole.consume(writer); + } + + @Benchmark + @Threads(1) + public void writeUnpartitionedLegacyDataWriter(Blackhole blackhole) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + + Schema writeSchema = table().schema(); + StructType sparkWriteType = SparkSchemaUtil.convert(writeSchema); + SparkAppenderFactory appenders = + SparkAppenderFactory.builderFor(table(), writeSchema, sparkWriteType) + .spec(unpartitionedSpec) + .build(); + + TaskWriter writer = + new UnpartitionedWriter<>( + unpartitionedSpec, fileFormat(), appenders, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES); + + try (TaskWriter closableWriter = writer) { + for (InternalRow row : rows) { + closableWriter.write(row); + } + } + + blackhole.consume(writer.complete()); + } + + @Benchmark + @Threads(1) + public void writePartitionedClusteredDataWriter(Blackhole blackhole) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = + SparkFileWriterFactory.builderFor(table()) + .dataFileFormat(fileFormat()) + .dataSchema(table().schema()) + .build(); + + ClusteredDataWriter writer = + new ClusteredDataWriter<>(writerFactory, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES); + + PartitionKey partitionKey = new PartitionKey(partitionedSpec, table().schema()); + StructType dataSparkType = SparkSchemaUtil.convert(table().schema()); + InternalRowWrapper internalRowWrapper = new InternalRowWrapper(dataSparkType); + + try (ClusteredDataWriter closeableWriter = writer) { + for (InternalRow row : rows) { + partitionKey.partition(internalRowWrapper.wrap(row)); + closeableWriter.write(row, partitionedSpec, partitionKey); + } + } + + blackhole.consume(writer); + } + + @Benchmark + @Threads(1) + public void writePartitionedLegacyDataWriter(Blackhole blackhole) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + + Schema writeSchema = table().schema(); + StructType sparkWriteType = SparkSchemaUtil.convert(writeSchema); + SparkAppenderFactory appenders = + SparkAppenderFactory.builderFor(table(), writeSchema, sparkWriteType) + .spec(partitionedSpec) + .build(); + + TaskWriter writer = + new SparkPartitionedWriter( + partitionedSpec, + fileFormat(), + appenders, + fileFactory, + io, + TARGET_FILE_SIZE_IN_BYTES, + writeSchema, + sparkWriteType); + + try (TaskWriter closableWriter = writer) { + for (InternalRow row : rows) { + 
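+ // unlike the clustered writer above, the legacy SparkPartitionedWriter derives the partition key from each row internally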
closableWriter.write(row); + } + } + + blackhole.consume(writer.complete()); + } + + @Benchmark + @Threads(1) + public void writePartitionedFanoutDataWriter(Blackhole blackhole) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = + SparkFileWriterFactory.builderFor(table()) + .dataFileFormat(fileFormat()) + .dataSchema(table().schema()) + .build(); + + FanoutDataWriter writer = + new FanoutDataWriter<>(writerFactory, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES); + + PartitionKey partitionKey = new PartitionKey(partitionedSpec, table().schema()); + StructType dataSparkType = SparkSchemaUtil.convert(table().schema()); + InternalRowWrapper internalRowWrapper = new InternalRowWrapper(dataSparkType); + + try (FanoutDataWriter closeableWriter = writer) { + for (InternalRow row : rows) { + partitionKey.partition(internalRowWrapper.wrap(row)); + closeableWriter.write(row, partitionedSpec, partitionKey); + } + } + + blackhole.consume(writer); + } + + @Benchmark + @Threads(1) + public void writePartitionedLegacyFanoutDataWriter(Blackhole blackhole) throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + + Schema writeSchema = table().schema(); + StructType sparkWriteType = SparkSchemaUtil.convert(writeSchema); + SparkAppenderFactory appenders = + SparkAppenderFactory.builderFor(table(), writeSchema, sparkWriteType) + .spec(partitionedSpec) + .build(); + + TaskWriter writer = + new SparkPartitionedFanoutWriter( + partitionedSpec, + fileFormat(), + appenders, + fileFactory, + io, + TARGET_FILE_SIZE_IN_BYTES, + writeSchema, + sparkWriteType); + + try (TaskWriter closableWriter = writer) { + for (InternalRow row : rows) { + closableWriter.write(row); + } + } + + blackhole.consume(writer.complete()); + } + + @Benchmark + @Threads(1) + public void writePartitionedClusteredEqualityDeleteWriter(Blackhole blackhole) + throws IOException { + FileIO io = table().io(); + + int equalityFieldId = table().schema().findField("longCol").fieldId(); + + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = + SparkFileWriterFactory.builderFor(table()) + .dataFileFormat(fileFormat()) + .equalityDeleteRowSchema(table().schema()) + .equalityFieldIds(new int[] {equalityFieldId}) + .build(); + + ClusteredEqualityDeleteWriter writer = + new ClusteredEqualityDeleteWriter<>( + writerFactory, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES); + + PartitionKey partitionKey = new PartitionKey(partitionedSpec, table().schema()); + StructType deleteSparkType = SparkSchemaUtil.convert(table().schema()); + InternalRowWrapper internalRowWrapper = new InternalRowWrapper(deleteSparkType); + + try (ClusteredEqualityDeleteWriter closeableWriter = writer) { + for (InternalRow row : rows) { + partitionKey.partition(internalRowWrapper.wrap(row)); + closeableWriter.write(row, partitionedSpec, partitionKey); + } + } + + blackhole.consume(writer); + } + + @Benchmark + @Threads(1) + public void writeUnpartitionedClusteredPositionDeleteWriter(Blackhole blackhole) + throws IOException { + FileIO io = table().io(); + + OutputFileFactory fileFactory = newFileFactory(); + SparkFileWriterFactory writerFactory = + SparkFileWriterFactory.builderFor(table()).dataFileFormat(fileFormat()).build(); + + ClusteredPositionDeleteWriter writer = + new ClusteredPositionDeleteWriter<>( + writerFactory, fileFactory, io, TARGET_FILE_SIZE_IN_BYTES); + + PositionDelete positionDelete = 
PositionDelete.create(); + try (ClusteredPositionDeleteWriter closeableWriter = writer) { + for (InternalRow row : positionDeleteRows) { + String path = row.getString(0); + long pos = row.getLong(1); + positionDelete.set(path, pos, null); + closeableWriter.write(positionDelete, unpartitionedSpec, null); + } + } + + blackhole.consume(writer); + } + + private OutputFileFactory newFileFactory() { + return OutputFileFactory.builderFor(table(), 1, 1).format(fileFormat()).build(); + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/AvroWritersBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/AvroWritersBenchmark.java new file mode 100644 index 000000000000..5220f65dfa6c --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/AvroWritersBenchmark.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.avro; + +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.spark.source.WritersBenchmark; + +/** + * A benchmark that evaluates the performance of various Iceberg writers for Avro data. + * + *

To run this benchmark for spark-3.4: + * ./gradlew -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.12:jmh + * -PjmhIncludeRegex=AvroWritersBenchmark + * -PjmhOutputPath=benchmark/avro-writers-benchmark-result.txt + * + */ +public class AvroWritersBenchmark extends WritersBenchmark { + + @Override + protected FileFormat fileFormat() { + return FileFormat.AVRO; + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/IcebergSourceFlatAvroDataReadBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/IcebergSourceFlatAvroDataReadBenchmark.java new file mode 100644 index 000000000000..4eb1ee9d92bb --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/IcebergSourceFlatAvroDataReadBenchmark.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.avro; + +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceFlatDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the performance of reading Avro data with a flat schema using Iceberg + * and the built-in file source in Spark. + * + *

To run this benchmark for spark-3.4: + * ./gradlew -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceFlatAvroDataReadBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-flat-avro-data-read-benchmark-result.txt + * + */ +public class IcebergSourceFlatAvroDataReadBenchmark extends IcebergSourceFlatDataBenchmark { + + private static final int NUM_FILES = 10; + private static final int NUM_ROWS = 1000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSource() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().format("avro").load(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).select("longCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSource() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().format("avro").load(dataLocation()).select("longCol"); + materialize(df); + }); + } + + private void appendData() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(DEFAULT_FILE_FORMAT, "avro"); + withTableProperties( + tableProperties, + () -> { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset df = + spark() + .range(NUM_ROWS) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(20, 5))")) + .withColumn("dateCol", date_add(current_date(), fileNum)) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")); + appendAsFile(df); + } + }); + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/IcebergSourceNestedAvroDataReadBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/IcebergSourceNestedAvroDataReadBenchmark.java new file mode 100644 index 000000000000..2e792b6d35e3 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/avro/IcebergSourceNestedAvroDataReadBenchmark.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.avro; + +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.struct; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceNestedDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the performance of reading nested Avro data using Iceberg and the + * built-in file source in Spark. + * + *

To run this benchmark for spark-3.4: + * ./gradlew -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceNestedAvroDataReadBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-avro-data-read-benchmark-result.txt + * + */ +public class IcebergSourceNestedAvroDataReadBenchmark extends IcebergSourceNestedDataBenchmark { + + private static final int NUM_FILES = 10; + private static final int NUM_ROWS = 1000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSource() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().format("avro").load(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIceberg() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).select("nested.col3"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSource() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = + spark().read().format("avro").load(dataLocation()).select("nested.col3"); + materialize(df); + }); + } + + private void appendData() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(DEFAULT_FILE_FORMAT, "avro"); + withTableProperties( + tableProperties, + () -> { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset df = + spark() + .range(NUM_ROWS) + .withColumn( + "nested", + struct( + expr("CAST(id AS string) AS col1"), + expr("CAST(id AS double) AS col2"), + lit(fileNum).cast("long").as("col3"))); + appendAsFile(df); + } + }); + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataBenchmark.java new file mode 100644 index 000000000000..d0fdd8915780 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataBenchmark.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.orc; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceBenchmark; +import org.apache.iceberg.types.Types; + +/** + * Same as {@link org.apache.iceberg.spark.source.IcebergSourceFlatDataBenchmark} but we disable the + * Timestamp with zone type for ORC performance tests as Spark native reader does not support ORC's + * TIMESTAMP_INSTANT type + */ +public abstract class IcebergSourceFlatORCDataBenchmark extends IcebergSourceBenchmark { + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + @Override + protected final Table initTable() { + Schema schema = + new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "decimalCol", Types.DecimalType.of(20, 5)), + optional(6, "dateCol", Types.DateType.get()), + // Disable timestamp column for ORC performance tests as Spark native reader does not + // support ORC's + // TIMESTAMP_INSTANT type + // optional(7, "timestampCol", Types.TimestampType.withZone()), + optional(8, "stringCol", Types.StringType.get())); + PartitionSpec partitionSpec = PartitionSpec.unpartitioned(); + HadoopTables tables = new HadoopTables(hadoopConf()); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + return tables.create(schema, partitionSpec, properties, newTableLocation()); + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataReadBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataReadBenchmark.java new file mode 100644 index 000000000000..8ee467b509e0 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataReadBenchmark.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.orc; + +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the performance of reading ORC data with a flat schema using Iceberg + * and the built-in file source in Spark. + * + *

To run this benchmark for spark-3.4: + * ./gradlew -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceFlatORCDataReadBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-flat-orc-data-read-benchmark-result.txt + * + */ +public class IcebergSourceFlatORCDataReadBenchmark extends IcebergSourceFlatORCDataBenchmark { + + private static final int NUM_FILES = 10; + private static final int NUM_ROWS = 1000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readIcebergNonVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readIcebergVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = + spark() + .read() + .option(SparkReadOptions.VECTORIZATION_ENABLED, "true") + .format("iceberg") + .load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().orc(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().orc(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIcebergNonVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).select("longCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIcebergVectorized() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset df = + spark() + .read() + .option(SparkReadOptions.VECTORIZATION_ENABLED, "true") + .format("iceberg") + .load(tableLocation) + .select("longCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSourceVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().orc(dataLocation()).select("longCol"); + 
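+ // materialize() runs the scan with a no-op foreach, so the single-column projection is actually executed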
materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSourceNonVectorized() { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset df = spark().read().orc(dataLocation()).select("longCol"); + materialize(df); + }); + } + + private void appendData() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(DEFAULT_FILE_FORMAT, "orc"); + withTableProperties( + tableProperties, + () -> { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset df = + spark() + .range(NUM_ROWS) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(20, 5))")) + .withColumn("dateCol", date_add(current_date(), fileNum)) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")); + appendAsFile(df); + } + }); + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedListORCDataWriteBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedListORCDataWriteBenchmark.java new file mode 100644 index 000000000000..15486113493a --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedListORCDataWriteBenchmark.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.orc; + +import static org.apache.spark.sql.functions.array_repeat; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.struct; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceNestedListDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the performance of writing nested ORC data using Iceberg and the + * built-in file source in Spark. + * + *

To run this benchmark for spark-3.4: + * ./gradlew -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceNestedListORCDataWriteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-list-orc-data-write-benchmark-result.txt + * + */ +public class IcebergSourceNestedListORCDataWriteBenchmark + extends IcebergSourceNestedListDataBenchmark { + + @Setup + public void setupBenchmark() { + setupSpark(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Param({"2000", "20000"}) + private int numRows; + + @Benchmark + @Threads(1) + public void writeIceberg() { + String tableLocation = table().location(); + benchmarkData() + .write() + .format("iceberg") + .option("write-format", "orc") + .mode(SaveMode.Append) + .save(tableLocation); + } + + @Benchmark + @Threads(1) + public void writeIcebergDictionaryOff() { + Map tableProperties = Maps.newHashMap(); + tableProperties.put("orc.dictionary.key.threshold", "0"); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + benchmarkData() + .write() + .format("iceberg") + .option("write-format", "orc") + .mode(SaveMode.Append) + .save(tableLocation); + }); + } + + @Benchmark + @Threads(1) + public void writeFileSource() { + benchmarkData().write().mode(SaveMode.Append).orc(dataLocation()); + } + + private Dataset benchmarkData() { + return spark() + .range(numRows) + .withColumn( + "outerlist", + array_repeat(struct(expr("array_repeat(CAST(id AS string), 1000) AS innerlist")), 10)) + .coalesce(1); + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedORCDataReadBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedORCDataReadBenchmark.java new file mode 100644 index 000000000000..c651f9eea8c7 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedORCDataReadBenchmark.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source.orc; + +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.struct; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.source.IcebergSourceNestedDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the performance of reading nested ORC data using Iceberg and the + * built-in file source in Spark. + * + *

To run this benchmark for spark-3.4: + * ./gradlew -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceNestedORCDataReadBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-orc-data-read-benchmark-result.txt + * + */ +public class IcebergSourceNestedORCDataReadBenchmark extends IcebergSourceNestedDataBenchmark { + + private static final int NUM_FILES = 10; + private static final int NUM_ROWS = 1000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readIcebergNonVectorized() { + Map<String, String> tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset<Row> df = spark().read().format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readIcebergVectorized() { + Map<String, String> tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset<Row> df = + spark() + .read() + .option(SparkReadOptions.VECTORIZATION_ENABLED, "true") + .format("iceberg") + .load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceNonVectorized() { + Map<String, String> conf = Maps.newHashMap(); + conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset<Row> df = spark().read().orc(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIcebergNonVectorized() { + Map<String, String> tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset<Row> df = + spark().read().format("iceberg").load(tableLocation).selectExpr("nested.col3"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIcebergVectorized() { + Map<String, String> tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset<Row> df = + spark() + .read() + .option(SparkReadOptions.VECTORIZATION_ENABLED, "true") + .format("iceberg") + .load(tableLocation) + .selectExpr("nested.col3"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSourceNonVectorized() { + Map<String, String> conf = Maps.newHashMap(); + conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset<Row> df = spark().read().orc(dataLocation()).selectExpr("nested.col3"); + materialize(df); + }); + } + + private void appendData() { + Map<String, String> tableProperties = Maps.newHashMap(); + tableProperties.put(DEFAULT_FILE_FORMAT, "orc"); + withTableProperties( + tableProperties, + () -> { + for (int fileNum = 0; fileNum < NUM_FILES; fileNum++) { + Dataset<Row> df = + spark() + .range(NUM_ROWS) + .withColumn( + "nested", + struct( +
expr("CAST(id AS string) AS col1"), + expr("CAST(id AS double) AS col2"), + lit(fileNum).cast("long").as("col3"))); + appendAsFile(df); + } + }); + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataFilterBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataFilterBenchmark.java new file mode 100644 index 000000000000..1633709f4cd2 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataFilterBenchmark.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceFlatDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the file skipping capabilities in the Spark data source for Iceberg. + * + *

This class uses a dataset with a flat schema, where the records are clustered according to the + * column used in the filter predicate. + * + *

The performance is compared to the built-in file source in Spark. + * + *

To run this benchmark for spark-3.4: + * ./gradlew -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceFlatParquetDataFilterBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-flat-parquet-data-filter-benchmark-result.txt + * + */ +public class IcebergSourceFlatParquetDataFilterBenchmark extends IcebergSourceFlatDataBenchmark { + + private static final String FILTER_COND = "dateCol == date_add(current_date(), 1)"; + private static final int NUM_FILES = 500; + private static final int NUM_ROWS = 10000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readWithFilterIceberg() { + Map<String, String> tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset<Row> df = + spark().read().format("iceberg").load(tableLocation).filter(FILTER_COND); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithFilterFileSourceVectorized() { + Map<String, String> conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset<Row> df = spark().read().parquet(dataLocation()).filter(FILTER_COND); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithFilterFileSourceNonVectorized() { + Map<String, String> conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset<Row> df = spark().read().parquet(dataLocation()).filter(FILTER_COND); + materialize(df); + }); + } + + private void appendData() { + for (int fileNum = 1; fileNum < NUM_FILES; fileNum++) { + Dataset<Row> df = + spark() + .range(NUM_ROWS) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(20, 5))")) + .withColumn("dateCol", date_add(current_date(), fileNum)) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")); + appendAsFile(df); + } + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataReadBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataReadBenchmark.java new file mode 100644 index 000000000000..1babed8c5c79 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataReadBenchmark.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceFlatDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the performance of reading Parquet data with a flat schema using + * Iceberg and the built-in file source in Spark. + * + *

To run this benchmark for spark-3.4: + * ./gradlew -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceFlatParquetDataReadBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-flat-parquet-data-read-benchmark-result.txt + * + */ +public class IcebergSourceFlatParquetDataReadBenchmark extends IcebergSourceFlatDataBenchmark { + + private static final int NUM_FILES = 10; + private static final int NUM_ROWS = 1000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readIceberg() { + Map<String, String> tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset<Row> df = spark().read().format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceVectorized() { + Map<String, String> conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset<Row> df = spark().read().parquet(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceNonVectorized() { + Map<String, String> conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset<Row> df = spark().read().parquet(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIceberg() { + Map<String, String> tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset<Row> df = spark().read().format("iceberg").load(tableLocation).select("longCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSourceVectorized() { + Map<String, String> conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset<Row> df = spark().read().parquet(dataLocation()).select("longCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSourceNonVectorized() { + Map<String, String> conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset<Row> df = spark().read().parquet(dataLocation()).select("longCol"); + materialize(df); + }); + } + + private void appendData() { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset<Row> df = + spark() + .range(NUM_ROWS) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(20, 5))")) + .withColumn("dateCol", date_add(current_date(), fileNum)) + .withColumn("timestampCol",
expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")); + appendAsFile(df); + } + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataWriteBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataWriteBenchmark.java new file mode 100644 index 000000000000..0bab9c401935 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceFlatParquetDataWriteBenchmark.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import static org.apache.spark.sql.functions.expr; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceFlatDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the performance of writing Parquet data with a flat schema using + * Iceberg and the built-in file source in Spark. + * + *

To run this benchmark for spark-3.4: + * ./gradlew -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceFlatParquetDataWriteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-flat-parquet-data-write-benchmark-result.txt + * + */ +public class IcebergSourceFlatParquetDataWriteBenchmark extends IcebergSourceFlatDataBenchmark { + + private static final int NUM_ROWS = 5000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void writeIceberg() { + String tableLocation = table().location(); + benchmarkData().write().format("iceberg").mode(SaveMode.Append).save(tableLocation); + } + + @Benchmark + @Threads(1) + public void writeFileSource() { + Map<String, String> conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_COMPRESSION().key(), "gzip"); + withSQLConf(conf, () -> benchmarkData().write().mode(SaveMode.Append).parquet(dataLocation())); + } + + private Dataset<Row> benchmarkData() { + return spark() + .range(NUM_ROWS) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(20, 5))")) + .withColumn("dateCol", expr("DATE_ADD(CURRENT_DATE(), (intCol % 20))")) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")) + .coalesce(1); + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedListParquetDataWriteBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedListParquetDataWriteBenchmark.java new file mode 100644 index 000000000000..47d866f1b803 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedListParquetDataWriteBenchmark.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package org.apache.iceberg.spark.source.parquet; + +import static org.apache.spark.sql.functions.array_repeat; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.struct; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceNestedListDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the performance of writing nested Parquet data using Iceberg and the + * built-in file source in Spark. + * + *

To run this benchmark for spark-3.4: + * ./gradlew -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceNestedListParquetDataWriteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-list-parquet-data-write-benchmark-result.txt + * + */ +public class IcebergSourceNestedListParquetDataWriteBenchmark + extends IcebergSourceNestedListDataBenchmark { + + @Setup + public void setupBenchmark() { + setupSpark(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Param({"2000", "20000"}) + private int numRows; + + @Benchmark + @Threads(1) + public void writeIceberg() { + String tableLocation = table().location(); + benchmarkData().write().format("iceberg").mode(SaveMode.Append).save(tableLocation); + } + + @Benchmark + @Threads(1) + public void writeFileSource() { + Map<String, String> conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_COMPRESSION().key(), "gzip"); + withSQLConf(conf, () -> benchmarkData().write().mode(SaveMode.Append).parquet(dataLocation())); + } + + private Dataset<Row> benchmarkData() { + return spark() + .range(numRows) + .withColumn( + "outerlist", + array_repeat(struct(expr("array_repeat(CAST(id AS string), 1000) AS innerlist")), 10)) + .coalesce(1); + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataFilterBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataFilterBenchmark.java new file mode 100644 index 000000000000..7da6499c14a3 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataFilterBenchmark.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.struct; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceNestedDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the file skipping capabilities in the Spark data source for Iceberg. + * + *

This class uses a dataset with nested data, where the records are clustered according to the + * column used in the filter predicate. + * + *

The performance is compared to the built-in file source in Spark. + * + *

To run this benchmark for spark-3.4: + * ./gradlew -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceNestedParquetDataFilterBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-parquet-data-filter-benchmark-result.txt + * + */ +public class IcebergSourceNestedParquetDataFilterBenchmark + extends IcebergSourceNestedDataBenchmark { + + private static final String FILTER_COND = "nested.col3 == 0"; + private static final int NUM_FILES = 500; + private static final int NUM_ROWS = 10000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readWithFilterIceberg() { + Map<String, String> tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset<Row> df = + spark().read().format("iceberg").load(tableLocation).filter(FILTER_COND); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithFilterFileSourceVectorized() { + Map<String, String> conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset<Row> df = spark().read().parquet(dataLocation()).filter(FILTER_COND); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithFilterFileSourceNonVectorized() { + Map<String, String> conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset<Row> df = spark().read().parquet(dataLocation()).filter(FILTER_COND); + materialize(df); + }); + } + + private void appendData() { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset<Row> df = + spark() + .range(NUM_ROWS) + .withColumn( + "nested", + struct( + expr("CAST(id AS string) AS col1"), + expr("CAST(id AS double) AS col2"), + lit(fileNum).cast("long").as("col3"))); + appendAsFile(df); + } + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataReadBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataReadBenchmark.java new file mode 100644 index 000000000000..e55717fdc442 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataReadBenchmark.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import static org.apache.iceberg.TableProperties.SPLIT_OPEN_FILE_COST; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.struct; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceNestedDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the performance of reading nested Parquet data using Iceberg and the + * built-in file source in Spark. + * + *

To run this benchmark for spark-3.4: + * ./gradlew -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceNestedParquetDataReadBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-parquet-data-read-benchmark-result.txt + * + */ +public class IcebergSourceNestedParquetDataReadBenchmark extends IcebergSourceNestedDataBenchmark { + + private static final int NUM_FILES = 10; + private static final int NUM_ROWS = 1000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void readIceberg() { + Map<String, String> tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset<Row> df = spark().read().format("iceberg").load(tableLocation); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceVectorized() { + Map<String, String> conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset<Row> df = spark().read().parquet(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFileSourceNonVectorized() { + Map<String, String> conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + withSQLConf( + conf, + () -> { + Dataset<Row> df = spark().read().parquet(dataLocation()); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionIceberg() { + Map<String, String> tableProperties = Maps.newHashMap(); + tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 * 1024)); + withTableProperties( + tableProperties, + () -> { + String tableLocation = table().location(); + Dataset<Row> df = + spark().read().format("iceberg").load(tableLocation).selectExpr("nested.col3"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSourceVectorized() { + Map<String, String> conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + conf.put(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED().key(), "true"); + withSQLConf( + conf, + () -> { + Dataset<Row> df = spark().read().parquet(dataLocation()).selectExpr("nested.col3"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readWithProjectionFileSourceNonVectorized() { + Map<String, String> conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "false"); + conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 * 1024 * 1024)); + conf.put(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED().key(), "true"); + withSQLConf( + conf, + () -> { + Dataset<Row> df = spark().read().parquet(dataLocation()).selectExpr("nested.col3"); + materialize(df); + }); + } + + private void appendData() { + for (int fileNum = 0; fileNum < NUM_FILES; fileNum++) { + Dataset<Row> df = + spark() + .range(NUM_ROWS) + .withColumn( + "nested", + struct( + expr("CAST(id AS string) AS col1"), + expr("CAST(id AS double) AS col2"), + lit(fileNum).cast("long").as("col3"))); + appendAsFile(df); + } + } +} diff --git
a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataWriteBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataWriteBenchmark.java new file mode 100644 index 000000000000..981107dc651b --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceNestedParquetDataWriteBenchmark.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.struct; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceNestedDataBenchmark; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * A benchmark that evaluates the performance of writing nested Parquet data using Iceberg and the + * built-in file source in Spark. + * + *

To run this benchmark for spark-3.4: + * ./gradlew -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceNestedParquetDataWriteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-nested-parquet-data-write-benchmark-result.txt + * + */ +public class IcebergSourceNestedParquetDataWriteBenchmark extends IcebergSourceNestedDataBenchmark { + + private static final int NUM_ROWS = 5000000; + + @Setup + public void setupBenchmark() { + setupSpark(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Benchmark + @Threads(1) + public void writeIceberg() { + String tableLocation = table().location(); + benchmarkData().write().format("iceberg").mode(SaveMode.Append).save(tableLocation); + } + + @Benchmark + @Threads(1) + public void writeFileSource() { + Map<String, String> conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_COMPRESSION().key(), "gzip"); + withSQLConf(conf, () -> benchmarkData().write().mode(SaveMode.Append).parquet(dataLocation())); + } + + private Dataset<Row> benchmarkData() { + return spark() + .range(NUM_ROWS) + .withColumn( + "nested", + struct( + expr("CAST(id AS string) AS col1"), + expr("CAST(id AS double) AS col2"), + expr("id AS col3"))) + .coalesce(1); + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetEqDeleteBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetEqDeleteBenchmark.java new file mode 100644 index 000000000000..f1e5956dbdc4 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetEqDeleteBenchmark.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import java.io.IOException; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.spark.source.IcebergSourceDeleteBenchmark; +import org.openjdk.jmh.annotations.Param; + +/** + * A benchmark that evaluates the non-vectorized read and vectorized read with equality delete in + * the Spark data source for Iceberg. + * + *

This class uses a dataset with a flat schema. To run this benchmark for spark-3.4: + * ./gradlew -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceParquetEqDeleteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-parquet-eq-delete-benchmark-result.txt + * + */ +public class IcebergSourceParquetEqDeleteBenchmark extends IcebergSourceDeleteBenchmark { + @Param({"0", "0.000001", "0.05", "0.25", "0.5", "1"}) + private double percentDeleteRow; + + @Override + protected void appendData() throws IOException { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + writeData(fileNum); + + if (percentDeleteRow > 0) { + // add equality deletes + table().refresh(); + writeEqDeletes(NUM_ROWS, percentDeleteRow); + } + } + } + + @Override + protected FileFormat fileFormat() { + return FileFormat.PARQUET; + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetMultiDeleteFileBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetMultiDeleteFileBenchmark.java new file mode 100644 index 000000000000..2ac3de2ff947 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetMultiDeleteFileBenchmark.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import java.io.IOException; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.spark.source.IcebergSourceDeleteBenchmark; +import org.openjdk.jmh.annotations.Param; + +/** + * A benchmark that evaluates the non-vectorized read and vectorized read with pos-delete in the + * Spark data source for Iceberg. + * + *

This class uses a dataset with a flat schema. To run this benchmark for spark-3.4: + * ./gradlew -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.12:jmh \ + * -PjmhIncludeRegex=IcebergSourceParquetMultiDeleteFileBenchmark \ + * -PjmhOutputPath=benchmark/iceberg-source-parquet-multi-delete-file-benchmark-result.txt + * + */ +public class IcebergSourceParquetMultiDeleteFileBenchmark extends IcebergSourceDeleteBenchmark { + @Param({"1", "2", "5", "10"}) + private int numDeleteFile; + + @Override + protected void appendData() throws IOException { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + writeData(fileNum); + + table().refresh(); + for (DataFile file : table().currentSnapshot().addedDataFiles(table().io())) { + writePosDeletes(file.path(), NUM_ROWS, 0.25, numDeleteFile); + } + } + } + + @Override + protected FileFormat fileFormat() { + return FileFormat.PARQUET; + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetPosDeleteBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetPosDeleteBenchmark.java new file mode 100644 index 000000000000..8cd6fb36fcf5 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetPosDeleteBenchmark.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import java.io.IOException; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.spark.source.IcebergSourceDeleteBenchmark; +import org.openjdk.jmh.annotations.Param; + +/** + * A benchmark that evaluates the non-vectorized read and vectorized read with pos-delete in the + * Spark data source for Iceberg. + * + *

This class uses a dataset with a flat schema. To run this benchmark for spark-3.4: + * ./gradlew -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceParquetPosDeleteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-parquet-pos-delete-benchmark-result.txt + * + */ +public class IcebergSourceParquetPosDeleteBenchmark extends IcebergSourceDeleteBenchmark { + @Param({"0", "0.000001", "0.05", "0.25", "0.5", "1"}) + private double percentDeleteRow; + + @Override + protected void appendData() throws IOException { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + writeData(fileNum); + + if (percentDeleteRow > 0) { + // add pos-deletes + table().refresh(); + for (DataFile file : table().currentSnapshot().addedDataFiles(table().io())) { + writePosDeletes(file.path(), NUM_ROWS, percentDeleteRow); + } + } + } + } + + @Override + protected FileFormat fileFormat() { + return FileFormat.PARQUET; + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetWithUnrelatedDeleteBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetWithUnrelatedDeleteBenchmark.java new file mode 100644 index 000000000000..1ae48e213cb7 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/IcebergSourceParquetWithUnrelatedDeleteBenchmark.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import java.io.IOException; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.spark.source.IcebergSourceDeleteBenchmark; +import org.openjdk.jmh.annotations.Param; + +/** + * A benchmark that evaluates the non-vectorized read and vectorized read with pos-delete in the + * Spark data source for Iceberg. + * + *

This class uses a dataset with a flat schema. To run this benchmark for spark-3.4: + * ./gradlew -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.12:jmh + * -PjmhIncludeRegex=IcebergSourceParquetWithUnrelatedDeleteBenchmark + * -PjmhOutputPath=benchmark/iceberg-source-parquet-with-unrelated-delete-benchmark-result.txt + * + */ +public class IcebergSourceParquetWithUnrelatedDeleteBenchmark extends IcebergSourceDeleteBenchmark { + private static final double PERCENT_DELETE_ROW = 0.05; + + @Param({"0", "0.05", "0.25", "0.5"}) + private double percentUnrelatedDeletes; + + @Override + protected void appendData() throws IOException { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + writeData(fileNum); + + table().refresh(); + for (DataFile file : table().currentSnapshot().addedDataFiles(table().io())) { + writePosDeletesWithNoise( + file.path(), + NUM_ROWS, + PERCENT_DELETE_ROW, + (int) (percentUnrelatedDeletes / PERCENT_DELETE_ROW), + 1); + } + } + } + + @Override + protected FileFormat fileFormat() { + return FileFormat.PARQUET; + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/ParquetWritersBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/ParquetWritersBenchmark.java new file mode 100644 index 000000000000..3857aabf5655 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/ParquetWritersBenchmark.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet; + +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.spark.source.WritersBenchmark; + +/** + * A benchmark that evaluates the performance of various Iceberg writers for Parquet data. + * + *

To run this benchmark for spark-3.4: + * ./gradlew -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.12:jmh \ + * -PjmhIncludeRegex=ParquetWritersBenchmark \ + * -PjmhOutputPath=benchmark/parquet-writers-benchmark-result.txt + * + */ +public class ParquetWritersBenchmark extends WritersBenchmark { + + @Override + protected FileFormat fileFormat() { + return FileFormat.PARQUET; + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadDictionaryEncodedFlatParquetDataBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadDictionaryEncodedFlatParquetDataBenchmark.java new file mode 100644 index 000000000000..0a30639d1c79 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadDictionaryEncodedFlatParquetDataBenchmark.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet.vectorized; + +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.pmod; +import static org.apache.spark.sql.functions.to_date; +import static org.apache.spark.sql.functions.to_timestamp; + +import java.util.Map; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.types.DataTypes; +import org.openjdk.jmh.annotations.Setup; + +/** + * Benchmark to compare performance of reading Parquet dictionary encoded data with a flat schema + * using vectorized Iceberg read path and the built-in file source in Spark. + * + *

To run this benchmark for spark-3.4: + * ./gradlew -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.12:jmh \ + * -PjmhIncludeRegex=VectorizedReadDictionaryEncodedFlatParquetDataBenchmark \ + * -PjmhOutputPath=benchmark/results.txt + * + */ +public class VectorizedReadDictionaryEncodedFlatParquetDataBenchmark + extends VectorizedReadFlatParquetDataBenchmark { + + @Setup + @Override + public void setupBenchmark() { + setupSpark(true); + appendData(); + } + + @Override + Map<String, String> parquetWriteProps() { + Map<String, String> properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + return properties; + } + + @Override + void appendData() { + Dataset<Row> df = idDF(); + df = withLongColumnDictEncoded(df); + df = withIntColumnDictEncoded(df); + df = withFloatColumnDictEncoded(df); + df = withDoubleColumnDictEncoded(df); + df = withBigDecimalColumnNotDictEncoded(df); // no dictionary for fixed len binary in Parquet v1 + df = withDecimalColumnDictEncoded(df); + df = withDateColumnDictEncoded(df); + df = withTimestampColumnDictEncoded(df); + df = withStringColumnDictEncoded(df); + df = df.drop("id"); + df.write().format("iceberg").mode(SaveMode.Append).save(table().location()); + } + + private static Column modColumn() { + return pmod(col("id"), lit(9)); + } + + private Dataset<Row> idDF() { + return spark().range(0, NUM_ROWS_PER_FILE * NUM_FILES, 1, NUM_FILES).toDF(); + } + + private static Dataset<Row> withLongColumnDictEncoded(Dataset<Row> df) { + return df.withColumn("longCol", modColumn().cast(DataTypes.LongType)); + } + + private static Dataset<Row> withIntColumnDictEncoded(Dataset<Row> df) { + return df.withColumn("intCol", modColumn().cast(DataTypes.IntegerType)); + } + + private static Dataset<Row> withFloatColumnDictEncoded(Dataset<Row> df) { + return df.withColumn("floatCol", modColumn().cast(DataTypes.FloatType)); + } + + private static Dataset<Row> withDoubleColumnDictEncoded(Dataset<Row> df) { + return df.withColumn("doubleCol", modColumn().cast(DataTypes.DoubleType)); + } + + private static Dataset<Row> withBigDecimalColumnNotDictEncoded(Dataset<Row> df) { + return df.withColumn("bigDecimalCol", modColumn().cast("decimal(20,5)")); + } + + private static Dataset<Row> withDecimalColumnDictEncoded(Dataset<Row> df) { + return df.withColumn("decimalCol", modColumn().cast("decimal(18,5)")); + } + + private static Dataset<Row> withDateColumnDictEncoded(Dataset<Row> df) { + Column days = modColumn().cast(DataTypes.ShortType); + return df.withColumn("dateCol", date_add(to_date(lit("04/12/2019"), "MM/dd/yyyy"), days)); + } + + private static Dataset<Row> withTimestampColumnDictEncoded(Dataset<Row> df) { + Column days = modColumn().cast(DataTypes.ShortType); + return df.withColumn( + "timestampCol", to_timestamp(date_add(to_date(lit("04/12/2019"), "MM/dd/yyyy"), days))); + } + + private static Dataset<Row> withStringColumnDictEncoded(Dataset<Row> df) { + return df.withColumn("stringCol", modColumn().cast(DataTypes.StringType)); + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadFlatParquetDataBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadFlatParquetDataBenchmark.java new file mode 100644 index 000000000000..abfc5d950a0d --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadFlatParquetDataBenchmark.java @@ -0,0 +1,355 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.parquet.vectorized; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.pmod; +import static org.apache.spark.sql.functions.when; + +import java.io.IOException; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceBenchmark; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * Benchmark to compare performance of reading Parquet data with a flat schema using vectorized + * Iceberg read path and the built-in file source in Spark. + * + *

To run this benchmark for spark-3.4: + * ./gradlew -DsparkVersions=3.4 :iceberg-spark:iceberg-spark-3.4_2.12:jmh \ + * -PjmhIncludeRegex=VectorizedReadFlatParquetDataBenchmark \ + * -PjmhOutputPath=benchmark/results.txt + * + */ +public class VectorizedReadFlatParquetDataBenchmark extends IcebergSourceBenchmark { + + static final int NUM_FILES = 5; + static final int NUM_ROWS_PER_FILE = 10_000_000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + @Override + protected Table initTable() { + // bigDecimalCol is big enough to be encoded as fix len binary (9 bytes), + // decimalCol is small enough to be encoded as a 64-bit int + Schema schema = + new Schema( + optional(1, "longCol", Types.LongType.get()), + optional(2, "intCol", Types.IntegerType.get()), + optional(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "bigDecimalCol", Types.DecimalType.of(20, 5)), + optional(6, "decimalCol", Types.DecimalType.of(18, 5)), + optional(7, "dateCol", Types.DateType.get()), + optional(8, "timestampCol", Types.TimestampType.withZone()), + optional(9, "stringCol", Types.StringType.get())); + PartitionSpec partitionSpec = PartitionSpec.unpartitioned(); + HadoopTables tables = new HadoopTables(hadoopConf()); + Map<String, String> properties = parquetWriteProps(); + return tables.create(schema, partitionSpec, properties, newTableLocation()); + } + + Map<String, String> parquetWriteProps() { + Map<String, String> properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + properties.put(TableProperties.PARQUET_DICT_SIZE_BYTES, "1"); + return properties; + } + + void appendData() { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset<Row> df = + spark() + .range(NUM_ROWS_PER_FILE) + .withColumn( + "longCol", + when(pmod(col("id"), lit(10)).equalTo(lit(0)), lit(null)).otherwise(col("id"))) + .drop("id") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("bigDecimalCol", expr("CAST(longCol AS DECIMAL(20, 5))")) + .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(18, 5))")) + .withColumn("dateCol", date_add(current_date(), fileNum)) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(longCol AS STRING)")); + appendAsFile(df); + } + } + + @Benchmark + @Threads(1) + public void readIntegersIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset<Row> df = spark().read().format("iceberg").load(tableLocation).select("intCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readIntegersSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset<Row> df = spark().read().parquet(dataLocation()).select("intCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readLongsIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset<Row> df = spark().read().format("iceberg").load(tableLocation).select("longCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void
readLongsSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("longCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFloatsIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).select("floatCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readFloatsSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("floatCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDoublesIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).select("doubleCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDoublesSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("doubleCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDecimalsIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).select("decimalCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDecimalsSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("decimalCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readBigDecimalsIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).select("bigDecimalCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readBigDecimalsSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("bigDecimalCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDatesIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = spark().read().format("iceberg").load(tableLocation).select("dateCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDatesSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("dateCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readTimestampsIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).select("timestampCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readTimestampsSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = 
spark().read().parquet(dataLocation()).select("timestampCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readStringsIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).select("stringCol"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readStringsSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("stringCol"); + materialize(df); + }); + } + + private static Map tablePropsWithVectorizationEnabled(int batchSize) { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(TableProperties.PARQUET_VECTORIZATION_ENABLED, "true"); + tableProperties.put(TableProperties.PARQUET_BATCH_SIZE, String.valueOf(batchSize)); + return tableProperties; + } + + private static Map sparkConfWithVectorizationEnabled(int batchSize) { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key(), String.valueOf(batchSize)); + return conf; + } +} diff --git a/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadParquetDecimalBenchmark.java b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadParquetDecimalBenchmark.java new file mode 100644 index 000000000000..e915da9d3c91 --- /dev/null +++ b/spark/v3.4/spark/src/jmh/java/org/apache/iceberg/spark/source/parquet/vectorized/VectorizedReadParquetDecimalBenchmark.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
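For context on the knobs the benchmark above flips per run, the same settings can be applied permanently through Iceberg table properties. A minimal sketch, assuming a hypothetical catalog named local and table db.events; the property keys correspond to TableProperties.PARQUET_VECTORIZATION_ENABLED and TableProperties.PARQUET_BATCH_SIZE used in the benchmark.

import org.apache.spark.sql.SparkSession;

public class EnableIcebergVectorizationSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().master("local[*]").getOrCreate();

    // Enable Iceberg's vectorized Parquet read path and set the same 5000-row batch size
    // the benchmark uses.
    spark.sql(
        "ALTER TABLE local.db.events SET TBLPROPERTIES ("
            + "'read.parquet.vectorization.enabled'='true',"
            + "'read.parquet.vectorization.batch-size'='5000')");
  }
}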
+ */ +package org.apache.iceberg.spark.source.parquet.vectorized; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.expr; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.pmod; +import static org.apache.spark.sql.functions.when; + +import java.io.IOException; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.IcebergSourceBenchmark; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; + +/** + * Benchmark to compare performance of reading Parquet decimal data using vectorized Iceberg read + * path and the built-in file source in Spark. + * + *

To run this benchmark for spark-3.3: + * ./gradlew -DsparkVersions=3.3 :iceberg-spark:iceberg-spark-3.3_2.12:jmh \ + * -PjmhIncludeRegex=VectorizedReadParquetDecimalBenchmark \ + * -PjmhOutputPath=benchmark/results.txt + * + */ +public class VectorizedReadParquetDecimalBenchmark extends IcebergSourceBenchmark { + + static final int NUM_FILES = 5; + static final int NUM_ROWS_PER_FILE = 10_000_000; + + @Setup + public void setupBenchmark() { + setupSpark(); + appendData(); + // Allow unsafe memory access to avoid the costly check arrow does to check if index is within + // bounds + System.setProperty("arrow.enable_unsafe_memory_access", "true"); + // Disable expensive null check for every get(index) call. + // Iceberg manages nullability checks itself instead of relying on arrow. + System.setProperty("arrow.enable_null_check_for_get", "false"); + } + + @TearDown + public void tearDownBenchmark() throws IOException { + tearDownSpark(); + cleanupFiles(); + } + + @Override + protected Configuration initHadoopConf() { + return new Configuration(); + } + + @Override + protected Table initTable() { + Schema schema = + new Schema( + optional(1, "decimalCol1", Types.DecimalType.of(7, 2)), + optional(2, "decimalCol2", Types.DecimalType.of(15, 2)), + optional(3, "decimalCol3", Types.DecimalType.of(20, 2))); + PartitionSpec partitionSpec = PartitionSpec.unpartitioned(); + HadoopTables tables = new HadoopTables(hadoopConf()); + Map properties = parquetWriteProps(); + return tables.create(schema, partitionSpec, properties, newTableLocation()); + } + + Map parquetWriteProps() { + Map properties = Maps.newHashMap(); + properties.put(TableProperties.METADATA_COMPRESSION, "gzip"); + properties.put(TableProperties.PARQUET_DICT_SIZE_BYTES, "1"); + return properties; + } + + void appendData() { + for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) { + Dataset df = + spark() + .range(NUM_ROWS_PER_FILE) + .withColumn( + "longCol", + when(pmod(col("id"), lit(10)).equalTo(lit(0)), lit(null)).otherwise(col("id"))) + .drop("id") + .withColumn("decimalCol1", expr("CAST(longCol AS DECIMAL(7, 2))")) + .withColumn("decimalCol2", expr("CAST(longCol AS DECIMAL(15, 2))")) + .withColumn("decimalCol3", expr("CAST(longCol AS DECIMAL(20, 2))")) + .drop("longCol"); + appendAsFile(df); + } + } + + @Benchmark + @Threads(1) + public void readIntBackedDecimalsIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).select("decimalCol1"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readIntBackedDecimalsSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("decimalCol1"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readLongBackedDecimalsIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).select("decimalCol2"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readLongBackedDecimalsSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("decimalCol2"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void 
readDecimalsIcebergVectorized5k() { + withTableProperties( + tablePropsWithVectorizationEnabled(5000), + () -> { + String tableLocation = table().location(); + Dataset df = + spark().read().format("iceberg").load(tableLocation).select("decimalCol3"); + materialize(df); + }); + } + + @Benchmark + @Threads(1) + public void readDecimalsSparkVectorized5k() { + withSQLConf( + sparkConfWithVectorizationEnabled(5000), + () -> { + Dataset df = spark().read().parquet(dataLocation()).select("decimalCol3"); + materialize(df); + }); + } + + private static Map tablePropsWithVectorizationEnabled(int batchSize) { + Map tableProperties = Maps.newHashMap(); + tableProperties.put(TableProperties.PARQUET_VECTORIZATION_ENABLED, "true"); + tableProperties.put(TableProperties.PARQUET_BATCH_SIZE, String.valueOf(batchSize)); + return tableProperties; + } + + private static Map sparkConfWithVectorizationEnabled(int batchSize) { + Map conf = Maps.newHashMap(); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key(), "true"); + conf.put(SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key(), String.valueOf(batchSize)); + return conf; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/BaseCatalog.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/BaseCatalog.java new file mode 100644 index 000000000000..2e5e383baf42 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/BaseCatalog.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
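The Spark-side baselines in these benchmarks toggle the built-in Parquet reader through SQLConf. A minimal sketch of the equivalent session settings, with a hypothetical input path; the keys are the ones behind SQLConf.PARQUET_VECTORIZED_READER_ENABLED and SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE referenced above.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class SparkParquetVectorizationSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().master("local[*]").getOrCreate();

    // Same toggles the benchmark sets through withSQLConf, applied to the session directly.
    spark.conf().set("spark.sql.parquet.enableVectorizedReader", "true");
    spark.conf().set("spark.sql.parquet.columnarReaderBatchSize", "5000");

    Dataset<Row> df = spark.read().parquet("/tmp/benchmark-data"); // hypothetical path
    df.show();
  }
}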
+ */ +package org.apache.iceberg.spark; + +import org.apache.iceberg.spark.functions.SparkFunctions; +import org.apache.iceberg.spark.procedures.SparkProcedures; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.iceberg.spark.source.HasIcebergCatalog; +import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException; +import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.apache.spark.sql.connector.catalog.FunctionCatalog; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.StagingTableCatalog; +import org.apache.spark.sql.connector.catalog.SupportsNamespaces; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.connector.iceberg.catalog.Procedure; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureCatalog; + +abstract class BaseCatalog + implements StagingTableCatalog, + ProcedureCatalog, + SupportsNamespaces, + HasIcebergCatalog, + FunctionCatalog { + + @Override + public Procedure loadProcedure(Identifier ident) throws NoSuchProcedureException { + String[] namespace = ident.namespace(); + String name = ident.name(); + + // namespace resolution is case insensitive until we have a way to configure case sensitivity in + // catalogs + if (isSystemNamespace(namespace)) { + ProcedureBuilder builder = SparkProcedures.newBuilder(name); + if (builder != null) { + return builder.withTableCatalog(this).build(); + } + } + + throw new NoSuchProcedureException(ident); + } + + @Override + public Identifier[] listFunctions(String[] namespace) throws NoSuchNamespaceException { + if (namespace.length == 0 || isSystemNamespace(namespace)) { + return SparkFunctions.list().stream() + .map(name -> Identifier.of(namespace, name)) + .toArray(Identifier[]::new); + } else if (namespaceExists(namespace)) { + return new Identifier[0]; + } + + throw new NoSuchNamespaceException(namespace); + } + + @Override + public UnboundFunction loadFunction(Identifier ident) throws NoSuchFunctionException { + String[] namespace = ident.namespace(); + String name = ident.name(); + + // Allow for empty namespace, as Spark's storage partitioned joins look up + // the corresponding functions to generate transforms for partitioning + // with an empty namespace, such as `bucket`. + // Otherwise, use `system` namespace. + if (namespace.length == 0 || isSystemNamespace(namespace)) { + UnboundFunction func = SparkFunctions.load(name); + if (func != null) { + return func; + } + } + + throw new NoSuchFunctionException(ident); + } + + private static boolean isSystemNamespace(String[] namespace) { + return namespace.length == 1 && namespace[0].equalsIgnoreCase("system"); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/BaseFileRewriteCoordinator.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/BaseFileRewriteCoordinator.java new file mode 100644 index 000000000000..45c46f1a3e12 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/BaseFileRewriteCoordinator.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
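Because BaseCatalog implements FunctionCatalog, Iceberg's built-in functions resolve under the catalog's system namespace (or with an empty namespace, which storage-partitioned joins rely on). A minimal sketch, assuming a hypothetical Hadoop catalog named local; the SQL calls the bucket function exposed through SparkFunctions and resolved by BaseCatalog.loadFunction.

import org.apache.spark.sql.SparkSession;

public class SystemFunctionSketch {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder()
            .master("local[*]")
            .config("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog")
            .config("spark.sql.catalog.local.type", "hadoop")
            .config("spark.sql.catalog.local.warehouse", "/tmp/iceberg-warehouse")
            .getOrCreate();

    // Resolved through the catalog's "system" namespace.
    spark.sql("SELECT local.system.bucket(16, 42L) AS bucket_value").show();
  }
}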
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.util.Pair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +abstract class BaseFileRewriteCoordinator> { + + private static final Logger LOG = LoggerFactory.getLogger(BaseFileRewriteCoordinator.class); + + private final Map, Set> resultMap = Maps.newConcurrentMap(); + + /** + * Called to persist the output of a rewrite action for a specific group. Since the write is done + * via a Spark Datasource, we have to propagate the result through this side-effect call. + * + * @param table table where the rewrite is occurring + * @param fileSetId the id used to identify the source set of files being rewritten + * @param newFiles the new files which have been written + */ + public void stageRewrite(Table table, String fileSetId, Set newFiles) { + LOG.debug( + "Staging the output for {} - fileset {} with {} files", + table.name(), + fileSetId, + newFiles.size()); + Pair id = toId(table, fileSetId); + resultMap.put(id, newFiles); + } + + public Set fetchNewFiles(Table table, String fileSetId) { + Pair id = toId(table, fileSetId); + Set result = resultMap.get(id); + ValidationException.check( + result != null, "No results for rewrite of file set %s in table %s", fileSetId, table); + + return result; + } + + public void clearRewrite(Table table, String fileSetId) { + LOG.debug("Removing entry for {} - id {}", table.name(), fileSetId); + Pair id = toId(table, fileSetId); + resultMap.remove(id); + } + + public Set fetchSetIds(Table table) { + return resultMap.keySet().stream() + .filter(e -> e.first().equals(tableUUID(table))) + .map(Pair::second) + .collect(Collectors.toSet()); + } + + private Pair toId(Table table, String setId) { + String tableUUID = tableUUID(table); + return Pair.of(tableUUID, setId); + } + + private String tableUUID(Table table) { + TableOperations ops = ((HasTableOperations) table).operations(); + return ops.current().uuid(); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/ChangelogIterator.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/ChangelogIterator.java new file mode 100644 index 000000000000..b1f12af272c3 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/ChangelogIterator.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
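A minimal sketch of the hand-off this coordinator implements, using the DataFile-specialized FileRewriteCoordinator added later in this change; the table, file-set id, and written files are placeholders that a rewrite action would supply.

import java.util.Set;
import org.apache.iceberg.DataFile;
import org.apache.iceberg.Table;
import org.apache.iceberg.spark.FileRewriteCoordinator;

public class RewriteHandOffSketch {
  public static void commitFileGroup(Table table, String fileSetId, Set<DataFile> writtenFiles) {
    FileRewriteCoordinator coordinator = FileRewriteCoordinator.get();

    // The write path stages what it produced for this file group...
    coordinator.stageRewrite(table, fileSetId, writtenFiles);

    // ...and the commit path later fetches it by the same (table, file-set id) key.
    Set<DataFile> filesToCommit = coordinator.fetchNewFiles(table, fileSetId);

    // commit a RewriteFiles operation with filesToCommit here (omitted), then clean up
    coordinator.clearRewrite(table, fileSetId);
  }
}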
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; +import org.apache.iceberg.ChangelogOperation; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.relocated.com.google.common.collect.Iterators; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.StructType; + +/** + * An iterator that transforms rows from changelog tables within a single Spark task. It assumes + * that rows are sorted by identifier columns and change type. + * + *

<p>It removes the carry-over rows. Carry-over rows are the result of a removal and insertion of
+ * the same row within an operation because of the copy-on-write mechanism. For example, given a
+ * file which contains row1 (id=1, data='a') and row2 (id=2, data='b'), a copy-on-write delete of
+ * row2 would require erasing this file and preserving row1 in a new file. The changelog table
+ * would report this as (id=1, data='a', op='DELETE') and (id=1, data='a', op='INSERT'), despite it
+ * not being an actual change to the table. The iterator finds the carry-over rows and removes them
+ * from the result.
+ *
+ * <p>This iterator also finds delete/insert rows which represent an update, and converts them into
+ * update records. For example, these two rows
+ *
+ * <ul>
+ *   <li>(id=1, data='a', op='DELETE')
+ *   <li>(id=1, data='b', op='INSERT')
+ * </ul>
+ *
+ * <p>will be marked as update-rows:
+ *
+ * <ul>
+ *   <li>(id=1, data='a', op='UPDATE_BEFORE')
+ *   <li>(id=1, data='b', op='UPDATE_AFTER')
+ * </ul>
+ */ +public class ChangelogIterator implements Iterator { + private static final String DELETE = ChangelogOperation.DELETE.name(); + private static final String INSERT = ChangelogOperation.INSERT.name(); + private static final String UPDATE_BEFORE = ChangelogOperation.UPDATE_BEFORE.name(); + private static final String UPDATE_AFTER = ChangelogOperation.UPDATE_AFTER.name(); + + private final Iterator rowIterator; + private final int changeTypeIndex; + private final List identifierFieldIdx; + private final int[] indicesForIdentifySameRow; + + private Row cachedRow = null; + + private ChangelogIterator( + Iterator rowIterator, StructType rowType, String[] identifierFields) { + this.rowIterator = rowIterator; + this.changeTypeIndex = rowType.fieldIndex(MetadataColumns.CHANGE_TYPE.name()); + this.identifierFieldIdx = + Arrays.stream(identifierFields) + .map(column -> rowType.fieldIndex(column.toString())) + .collect(Collectors.toList()); + this.indicesForIdentifySameRow = generateIndicesForIdentifySameRow(rowType.size()); + } + + /** + * Creates an iterator for records of a changelog table. + * + * @param rowIterator the iterator of rows from a changelog table + * @param rowType the schema of the rows + * @param identifierFields the names of the identifier columns, which determine if rows are the + * same + * @return a new {@link ChangelogIterator} instance concatenated with the null-removal iterator + */ + public static Iterator create( + Iterator rowIterator, StructType rowType, String[] identifierFields) { + ChangelogIterator changelogIterator = + new ChangelogIterator(rowIterator, rowType, identifierFields); + return Iterators.filter(changelogIterator, Objects::nonNull); + } + + @Override + public boolean hasNext() { + if (cachedRow != null) { + return true; + } + return rowIterator.hasNext(); + } + + @Override + public Row next() { + // if there is an updated cached row, return it directly + if (cachedUpdateRecord(cachedRow)) { + Row row = cachedRow; + cachedRow = null; + return row; + } + + Row currentRow = currentRow(); + + if (currentRow.getString(changeTypeIndex).equals(DELETE) && rowIterator.hasNext()) { + Row nextRow = rowIterator.next(); + cachedRow = nextRow; + + if (isUpdateOrCarryoverRecord(currentRow, nextRow)) { + if (isCarryoverRecord(currentRow, nextRow)) { + // set carry-over rows to null for filtering out later + currentRow = null; + cachedRow = null; + } else { + currentRow = modify(currentRow, changeTypeIndex, UPDATE_BEFORE); + cachedRow = modify(nextRow, changeTypeIndex, UPDATE_AFTER); + } + } + } + + return currentRow; + } + + private Row modify(Row row, int valueIndex, Object value) { + if (row instanceof GenericRow) { + GenericRow genericRow = (GenericRow) row; + genericRow.values()[valueIndex] = value; + return genericRow; + } else { + Object[] values = new Object[row.size()]; + for (int index = 0; index < row.size(); index++) { + values[index] = row.get(index); + } + values[valueIndex] = value; + return RowFactory.create(values); + } + } + + private int[] generateIndicesForIdentifySameRow(int columnSize) { + int[] indices = new int[columnSize - 1]; + for (int i = 0; i < indices.length; i++) { + if (i < changeTypeIndex) { + indices[i] = i; + } else { + indices[i] = i + 1; + } + } + return indices; + } + + private boolean isCarryoverRecord(Row currentRow, Row nextRow) { + for (int idx : indicesForIdentifySameRow) { + if (!isColumnSame(currentRow, nextRow, idx)) { + return false; + } + } + + return true; + } + + private boolean cachedUpdateRecord(Row cachedRecord) { + 
return cachedRecord != null + && !cachedRecord.getString(changeTypeIndex).equals(DELETE) + && !cachedRecord.getString(changeTypeIndex).equals(INSERT); + } + + private Row currentRow() { + if (cachedRow != null) { + Row row = cachedRow; + cachedRow = null; + return row; + } else { + return rowIterator.next(); + } + } + + private boolean isUpdateOrCarryoverRecord(Row currentRow, Row nextRow) { + return sameLogicalRow(currentRow, nextRow) + && currentRow.getString(changeTypeIndex).equals(DELETE) + && nextRow.getString(changeTypeIndex).equals(INSERT); + } + + private boolean sameLogicalRow(Row currentRow, Row nextRow) { + for (int idx : identifierFieldIdx) { + if (!isColumnSame(currentRow, nextRow, idx)) { + return false; + } + } + return true; + } + + private static boolean isColumnSame(Row currentRow, Row nextRow, int idx) { + return Objects.equals(nextRow.get(idx), currentRow.get(idx)); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/CommitMetadata.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/CommitMetadata.java new file mode 100644 index 000000000000..641b957d1176 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/CommitMetadata.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
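A minimal, self-contained sketch of the collapsing behavior described in the class javadoc, using hand-built rows. The column names are hypothetical except _change_type, which is assumed to match MetadataColumns.CHANGE_TYPE; rows must already be sorted by the identifier columns and change type, as they would be within a Spark partition.

import java.util.Arrays;
import java.util.Iterator;
import org.apache.iceberg.spark.ChangelogIterator;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class ChangelogIteratorSketch {
  public static void main(String[] args) {
    StructType schema =
        new StructType()
            .add("id", DataTypes.IntegerType)
            .add("data", DataTypes.StringType)
            .add("_change_type", DataTypes.StringType);

    // A DELETE/INSERT pair for the same id with different data represents an update.
    Iterator<Row> rows =
        Arrays.asList(
                RowFactory.create(1, "a", "DELETE"),
                RowFactory.create(1, "b", "INSERT"))
            .iterator();

    Iterator<Row> collapsed = ChangelogIterator.create(rows, schema, new String[] {"id"});

    // Prints the pair rewritten as UPDATE_BEFORE and UPDATE_AFTER.
    collapsed.forEachRemaining(System.out::println);
  }
}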
+ */ +package org.apache.iceberg.spark; + +import java.util.Map; +import java.util.concurrent.Callable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.util.ExceptionUtil; + +/** utility class to accept thread local commit properties */ +public class CommitMetadata { + + private CommitMetadata() {} + + private static final ThreadLocal> COMMIT_PROPERTIES = + ThreadLocal.withInitial(ImmutableMap::of); + + /** + * running the code wrapped as a caller, and any snapshot committed within the callable object + * will be attached with the metadata defined in properties + * + * @param properties extra commit metadata to attach to the snapshot committed within callable + * @param callable the code to be executed + * @param exClass the expected type of exception which would be thrown from callable + */ + public static R withCommitProperties( + Map properties, Callable callable, Class exClass) throws E { + COMMIT_PROPERTIES.set(properties); + try { + return callable.call(); + } catch (Throwable e) { + ExceptionUtil.castAndThrow(e, exClass); + return null; + } finally { + COMMIT_PROPERTIES.set(ImmutableMap.of()); + } + } + + public static Map commitProperties() { + return COMMIT_PROPERTIES.get(); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/ExtendedParser.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/ExtendedParser.java new file mode 100644 index 000000000000..19b3dd8f49be --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/ExtendedParser.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
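A minimal sketch of attaching extra snapshot metadata to a write, with a hypothetical table name and property; any Iceberg commit made on this thread inside the callable picks up the properties.

import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.spark.CommitMetadata;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class CommitMetadataSketch {
  public static void appendWithAudit(Dataset<Row> df) {
    CommitMetadata.withCommitProperties(
        ImmutableMap.of("job-id", "nightly-backfill-42"),
        () -> {
          // the snapshot created by this append carries the job-id property
          df.writeTo("db.events").append();
          return null;
        },
        RuntimeException.class);
  }
}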
+ */ +package org.apache.iceberg.spark; + +import java.util.List; +import org.apache.iceberg.NullOrder; +import org.apache.iceberg.SortDirection; +import org.apache.iceberg.expressions.Term; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.parser.ParserInterface; + +public interface ExtendedParser extends ParserInterface { + class RawOrderField { + private final Term term; + private final SortDirection direction; + private final NullOrder nullOrder; + + public RawOrderField(Term term, SortDirection direction, NullOrder nullOrder) { + this.term = term; + this.direction = direction; + this.nullOrder = nullOrder; + } + + public Term term() { + return term; + } + + public SortDirection direction() { + return direction; + } + + public NullOrder nullOrder() { + return nullOrder; + } + } + + static List parseSortOrder(SparkSession spark, String orderString) { + if (spark.sessionState().sqlParser() instanceof ExtendedParser) { + ExtendedParser parser = (ExtendedParser) spark.sessionState().sqlParser(); + try { + return parser.parseSortOrder(orderString); + } catch (AnalysisException e) { + throw new IllegalArgumentException( + String.format("Unable to parse sortOrder: %s", orderString), e); + } + } else { + throw new IllegalStateException( + "Cannot parse order: parser is not an Iceberg ExtendedParser"); + } + } + + List parseSortOrder(String orderString) throws AnalysisException; +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/FileRewriteCoordinator.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/FileRewriteCoordinator.java new file mode 100644 index 000000000000..4f1d0fffcbd8 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/FileRewriteCoordinator.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Set; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.Table; + +public class FileRewriteCoordinator extends BaseFileRewriteCoordinator { + + private static final FileRewriteCoordinator INSTANCE = new FileRewriteCoordinator(); + + private FileRewriteCoordinator() {} + + public static FileRewriteCoordinator get() { + return INSTANCE; + } + + /** @deprecated will be removed in 1.4.0; use {@link #fetchNewFiles(Table, String)} instead. 
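A minimal sketch of parsing a sort-order string through the session parser, assuming the Iceberg SQL extensions are installed so that the parser implements ExtendedParser; the order expression itself is illustrative.

import java.util.List;
import org.apache.iceberg.spark.ExtendedParser;
import org.apache.spark.sql.SparkSession;

public class ParseSortOrderSketch {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder()
            .master("local[*]")
            .config(
                "spark.sql.extensions",
                "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
            .getOrCreate();

    List<ExtendedParser.RawOrderField> fields =
        ExtendedParser.parseSortOrder(spark, "bucket(16, id), category DESC NULLS LAST");

    for (ExtendedParser.RawOrderField field : fields) {
      System.out.println(field.term() + " " + field.direction() + " " + field.nullOrder());
    }
  }
}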
*/ + @Deprecated + public Set fetchNewDataFiles(Table table, String fileSetId) { + return fetchNewFiles(table, fileSetId); + } + + /** @deprecated will be removed in 1.4.0; use {@link #fetchSetIds(Table)} instead */ + @Deprecated + public Set fetchSetIDs(Table table) { + return fetchSetIds(table); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/FileScanTaskSetManager.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/FileScanTaskSetManager.java new file mode 100644 index 000000000000..782fde510eec --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/FileScanTaskSetManager.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.util.Pair; + +/** @deprecated will be removed in 1.3.0, use {@link ScanTaskSetManager} instead */ +@Deprecated +public class FileScanTaskSetManager { + + private static final FileScanTaskSetManager INSTANCE = new FileScanTaskSetManager(); + + private final Map, List> tasksMap = Maps.newConcurrentMap(); + + private FileScanTaskSetManager() {} + + public static FileScanTaskSetManager get() { + return INSTANCE; + } + + public void stageTasks(Table table, String setID, List tasks) { + Preconditions.checkArgument( + tasks != null && tasks.size() > 0, "Cannot stage null or empty tasks"); + Pair id = toID(table, setID); + tasksMap.put(id, tasks); + } + + public List fetchTasks(Table table, String setID) { + Pair id = toID(table, setID); + return tasksMap.get(id); + } + + public List removeTasks(Table table, String setID) { + Pair id = toID(table, setID); + return tasksMap.remove(id); + } + + public Set fetchSetIDs(Table table) { + return tasksMap.keySet().stream() + .filter(e -> e.first().equals(tableUUID(table))) + .map(Pair::second) + .collect(Collectors.toSet()); + } + + private String tableUUID(Table table) { + TableOperations ops = ((HasTableOperations) table).operations(); + return ops.current().uuid(); + } + + private Pair toID(Table table, String setID) { + return Pair.of(tableUUID(table), setID); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java new file mode 100644 index 000000000000..eb2420c0b254 --- /dev/null +++ 
b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.function.Function; +import org.apache.iceberg.transforms.Transforms; +import org.apache.iceberg.types.Type; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; + +public class IcebergSpark { + private IcebergSpark() {} + + public static void registerBucketUDF( + SparkSession session, String funcName, DataType sourceType, int numBuckets) { + SparkTypeToType typeConverter = new SparkTypeToType(); + Type sourceIcebergType = typeConverter.atomic(sourceType); + Function bucket = Transforms.bucket(numBuckets).bind(sourceIcebergType); + session + .udf() + .register( + funcName, + value -> bucket.apply(SparkValueConverter.convert(sourceIcebergType, value)), + DataTypes.IntegerType); + } + + public static void registerTruncateUDF( + SparkSession session, String funcName, DataType sourceType, int width) { + SparkTypeToType typeConverter = new SparkTypeToType(); + Type sourceIcebergType = typeConverter.atomic(sourceType); + Function truncate = Transforms.truncate(width).bind(sourceIcebergType); + session + .udf() + .register( + funcName, + value -> truncate.apply(SparkValueConverter.convert(sourceIcebergType, value)), + sourceType); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/JobGroupInfo.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/JobGroupInfo.java new file mode 100644 index 000000000000..c0756d924e2f --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/JobGroupInfo.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
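A minimal sketch of registering a bucket UDF that mirrors Iceberg's bucket partition transform, so query results can be grouped the same way a table is bucketed; the function name and bucket count are arbitrary.

import org.apache.iceberg.spark.IcebergSpark;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;

public class RegisterBucketUdfSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().master("local[*]").getOrCreate();

    // Registers a UDF that applies the same bucket(16) transform Iceberg uses for partitioning.
    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket16", DataTypes.LongType, 16);

    spark.sql("SELECT id, iceberg_bucket16(id) AS bucket FROM range(10)").show();
  }
}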
+ */ +package org.apache.iceberg.spark; + +/** Captures information about the current job which is used for displaying on the UI */ +public class JobGroupInfo { + private String groupId; + private String description; + private boolean interruptOnCancel; + + public JobGroupInfo(String groupId, String desc, boolean interruptOnCancel) { + this.groupId = groupId; + this.description = desc; + this.interruptOnCancel = interruptOnCancel; + } + + public String groupId() { + return groupId; + } + + public String description() { + return description; + } + + public boolean interruptOnCancel() { + return interruptOnCancel; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/JobGroupUtils.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/JobGroupUtils.java new file mode 100644 index 000000000000..dc8ba69d40a8 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/JobGroupUtils.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import org.apache.spark.SparkContext; +import org.apache.spark.SparkContext$; + +public class JobGroupUtils { + + private static final String JOB_GROUP_ID = SparkContext$.MODULE$.SPARK_JOB_GROUP_ID(); + private static final String JOB_GROUP_DESC = SparkContext$.MODULE$.SPARK_JOB_DESCRIPTION(); + private static final String JOB_INTERRUPT_ON_CANCEL = + SparkContext$.MODULE$.SPARK_JOB_INTERRUPT_ON_CANCEL(); + + private JobGroupUtils() {} + + public static JobGroupInfo getJobGroupInfo(SparkContext sparkContext) { + String groupId = sparkContext.getLocalProperty(JOB_GROUP_ID); + String description = sparkContext.getLocalProperty(JOB_GROUP_DESC); + String interruptOnCancel = sparkContext.getLocalProperty(JOB_INTERRUPT_ON_CANCEL); + return new JobGroupInfo(groupId, description, Boolean.parseBoolean(interruptOnCancel)); + } + + public static void setJobGroupInfo(SparkContext sparkContext, JobGroupInfo info) { + sparkContext.setLocalProperty(JOB_GROUP_ID, info.groupId()); + sparkContext.setLocalProperty(JOB_GROUP_DESC, info.description()); + sparkContext.setLocalProperty( + JOB_INTERRUPT_ON_CANCEL, String.valueOf(info.interruptOnCancel())); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/PathIdentifier.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/PathIdentifier.java new file mode 100644 index 000000000000..110af6b87de5 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/PathIdentifier.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
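A minimal sketch of tagging the Spark jobs launched by a block of work so they appear grouped in the Spark UI; the group id and description are made up, and the previous group info is restored afterwards.

import org.apache.iceberg.spark.JobGroupInfo;
import org.apache.iceberg.spark.JobGroupUtils;
import org.apache.spark.SparkContext;
import org.apache.spark.sql.SparkSession;

public class JobGroupSketch {
  public static void run(SparkSession spark) {
    SparkContext sc = spark.sparkContext();
    JobGroupInfo previous = JobGroupUtils.getJobGroupInfo(sc);

    JobGroupUtils.setJobGroupInfo(
        sc, new JobGroupInfo("EXPIRE-SNAPSHOTS-1", "expire snapshots", false));
    try {
      spark.range(1_000_000).count(); // jobs triggered here carry the group id and description
    } finally {
      JobGroupUtils.setJobGroupInfo(sc, previous); // restore whatever was set before
    }
  }
}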
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.List; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.Splitter; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.spark.sql.connector.catalog.Identifier; + +public class PathIdentifier implements Identifier { + private static final Splitter SPLIT = Splitter.on("/"); + private static final Joiner JOIN = Joiner.on("/"); + private final String[] namespace; + private final String location; + private final String name; + + public PathIdentifier(String location) { + this.location = location; + List pathParts = SPLIT.splitToList(location); + name = Iterables.getLast(pathParts); + namespace = + pathParts.size() > 1 + ? new String[] {JOIN.join(pathParts.subList(0, pathParts.size() - 1))} + : new String[0]; + } + + @Override + public String[] namespace() { + return namespace; + } + + @Override + public String name() { + return name; + } + + public String location() { + return location; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/PositionDeletesRewriteCoordinator.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/PositionDeletesRewriteCoordinator.java new file mode 100644 index 000000000000..c7568005e22f --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/PositionDeletesRewriteCoordinator.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
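A minimal sketch of how a path-based identifier decomposes; the location is hypothetical, and the values in the comments follow directly from the splitting logic above.

import org.apache.iceberg.spark.PathIdentifier;

public class PathIdentifierSketch {
  public static void main(String[] args) {
    PathIdentifier ident = new PathIdentifier("s3://bucket/warehouse/db/table");

    System.out.println(ident.name());         // table
    System.out.println(ident.namespace()[0]); // s3://bucket/warehouse/db
    System.out.println(ident.location());     // s3://bucket/warehouse/db/table
  }
}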
+ */ +package org.apache.iceberg.spark; + +import org.apache.iceberg.DeleteFile; + +public class PositionDeletesRewriteCoordinator extends BaseFileRewriteCoordinator { + + private static final PositionDeletesRewriteCoordinator INSTANCE = + new PositionDeletesRewriteCoordinator(); + + private PositionDeletesRewriteCoordinator() {} + + public static PositionDeletesRewriteCoordinator get() { + return INSTANCE; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithReordering.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithReordering.java new file mode 100644 index 000000000000..cdc0bf5f3cad --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithReordering.java @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Supplier; +import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Type.TypeID; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampType; + +public class PruneColumnsWithReordering extends TypeUtil.CustomOrderSchemaVisitor { + private final StructType requestedType; + private final Set filterRefs; + private DataType current = null; + + PruneColumnsWithReordering(StructType requestedType, Set filterRefs) { + this.requestedType = requestedType; + this.filterRefs = filterRefs; + } + + @Override + public Type schema(Schema schema, Supplier structResult) { + this.current = requestedType; + try { + return structResult.get(); + } finally { + this.current = null; + } + } + + @Override + public Type struct(Types.StructType struct, Iterable 
fieldResults) { + Preconditions.checkNotNull( + struct, "Cannot prune null struct. Pruning must start with a schema."); + Preconditions.checkArgument(current instanceof StructType, "Not a struct: %s", current); + StructType requestedStruct = (StructType) current; + + List fields = struct.fields(); + List types = Lists.newArrayList(fieldResults); + + boolean changed = false; + // use a LinkedHashMap to preserve the original order of filter fields that are not projected + Map projectedFields = Maps.newLinkedHashMap(); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + Type type = types.get(i); + + if (type == null) { + changed = true; + + } else if (field.type() == type) { + projectedFields.put(field.name(), field); + + } else if (field.isOptional()) { + changed = true; + projectedFields.put( + field.name(), Types.NestedField.optional(field.fieldId(), field.name(), type)); + + } else { + changed = true; + projectedFields.put( + field.name(), Types.NestedField.required(field.fieldId(), field.name(), type)); + } + } + + // Construct a new struct with the projected struct's order + boolean reordered = false; + StructField[] requestedFields = requestedStruct.fields(); + List newFields = Lists.newArrayListWithExpectedSize(requestedFields.length); + for (int i = 0; i < requestedFields.length; i += 1) { + // fields are resolved by name because Spark only sees the current table schema. + String name = requestedFields[i].name(); + if (!fields.get(i).name().equals(name)) { + reordered = true; + } + newFields.add(projectedFields.remove(name)); + } + + // Add remaining filter fields that were not explicitly projected + if (!projectedFields.isEmpty()) { + newFields.addAll(projectedFields.values()); + changed = true; // order probably changed + } + + if (reordered || changed) { + return Types.StructType.of(newFields); + } + + return struct; + } + + @Override + public Type field(Types.NestedField field, Supplier fieldResult) { + Preconditions.checkArgument(current instanceof StructType, "Not a struct: %s", current); + StructType requestedStruct = (StructType) current; + + // fields are resolved by name because Spark only sees the current table schema. + if (requestedStruct.getFieldIndex(field.name()).isEmpty()) { + // make sure that filter fields are projected even if they aren't in the requested schema. 
+ if (filterRefs.contains(field.fieldId())) { + return field.type(); + } + return null; + } + + int fieldIndex = requestedStruct.fieldIndex(field.name()); + StructField requestedField = requestedStruct.fields()[fieldIndex]; + + Preconditions.checkArgument( + requestedField.nullable() || field.isRequired(), + "Cannot project an optional field as non-null: %s", + field.name()); + + this.current = requestedField.dataType(); + try { + return fieldResult.get(); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + "Invalid projection for field " + field.name() + ": " + e.getMessage(), e); + } finally { + this.current = requestedStruct; + } + } + + @Override + public Type list(Types.ListType list, Supplier elementResult) { + Preconditions.checkArgument(current instanceof ArrayType, "Not an array: %s", current); + ArrayType requestedArray = (ArrayType) current; + + Preconditions.checkArgument( + requestedArray.containsNull() || !list.isElementOptional(), + "Cannot project an array of optional elements as required elements: %s", + requestedArray); + + this.current = requestedArray.elementType(); + try { + Type elementType = elementResult.get(); + if (list.elementType() == elementType) { + return list; + } + + // must be a projected element type, create a new list + if (list.isElementOptional()) { + return Types.ListType.ofOptional(list.elementId(), elementType); + } else { + return Types.ListType.ofRequired(list.elementId(), elementType); + } + } finally { + this.current = requestedArray; + } + } + + @Override + public Type map(Types.MapType map, Supplier keyResult, Supplier valueResult) { + Preconditions.checkArgument(current instanceof MapType, "Not a map: %s", current); + MapType requestedMap = (MapType) current; + + Preconditions.checkArgument( + requestedMap.valueContainsNull() || !map.isValueOptional(), + "Cannot project a map of optional values as required values: %s", + map); + Preconditions.checkArgument( + StringType.class.isInstance(requestedMap.keyType()), + "Invalid map key type (not string): %s", + requestedMap.keyType()); + + this.current = requestedMap.valueType(); + try { + Type valueType = valueResult.get(); + if (map.valueType() == valueType) { + return map; + } + + if (map.isValueOptional()) { + return Types.MapType.ofOptional(map.keyId(), map.valueId(), map.keyType(), valueType); + } else { + return Types.MapType.ofRequired(map.keyId(), map.valueId(), map.keyType(), valueType); + } + } finally { + this.current = requestedMap; + } + } + + @Override + public Type primitive(Type.PrimitiveType primitive) { + Class expectedType = TYPES.get(primitive.typeId()); + Preconditions.checkArgument( + expectedType != null && expectedType.isInstance(current), + "Cannot project %s to incompatible type: %s", + primitive, + current); + + // additional checks based on type + switch (primitive.typeId()) { + case DECIMAL: + Types.DecimalType decimal = (Types.DecimalType) primitive; + DecimalType requestedDecimal = (DecimalType) current; + Preconditions.checkArgument( + requestedDecimal.scale() == decimal.scale(), + "Cannot project decimal with incompatible scale: %s != %s", + requestedDecimal.scale(), + decimal.scale()); + Preconditions.checkArgument( + requestedDecimal.precision() >= decimal.precision(), + "Cannot project decimal with incompatible precision: %s < %s", + requestedDecimal.precision(), + decimal.precision()); + break; + case TIMESTAMP: + Types.TimestampType timestamp = (Types.TimestampType) primitive; + Preconditions.checkArgument( + 
timestamp.shouldAdjustToUTC(), + "Cannot project timestamp (without time zone) as timestamptz (with time zone)"); + break; + default: + } + + return primitive; + } + + private static final ImmutableMap> TYPES = + ImmutableMap.>builder() + .put(TypeID.BOOLEAN, BooleanType.class) + .put(TypeID.INTEGER, IntegerType.class) + .put(TypeID.LONG, LongType.class) + .put(TypeID.FLOAT, FloatType.class) + .put(TypeID.DOUBLE, DoubleType.class) + .put(TypeID.DATE, DateType.class) + .put(TypeID.TIMESTAMP, TimestampType.class) + .put(TypeID.DECIMAL, DecimalType.class) + .put(TypeID.UUID, StringType.class) + .put(TypeID.STRING, StringType.class) + .put(TypeID.FIXED, BinaryType.class) + .put(TypeID.BINARY, BinaryType.class) + .buildOrThrow(); +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithoutReordering.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithoutReordering.java new file mode 100644 index 000000000000..a6de035c466e --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/PruneColumnsWithoutReordering.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
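These pruning visitors are package-private and are normally driven through SparkSchemaUtil. A minimal sketch of pruning an Iceberg schema down to a requested Spark struct, assuming a SparkSchemaUtil.prune(Schema, StructType) helper as in the other Iceberg Spark modules; the schema itself is hypothetical.

import static org.apache.iceberg.types.Types.NestedField.optional;
import static org.apache.iceberg.types.Types.NestedField.required;

import org.apache.iceberg.Schema;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class PruneColumnsSketch {
  public static void main(String[] args) {
    Schema schema =
        new Schema(
            required(1, "id", Types.LongType.get()),
            optional(2, "data", Types.StringType.get()),
            optional(3, "category", Types.StringType.get()));

    // Request only the id column, as Spark would after column pruning.
    StructType requested = new StructType().add("id", DataTypes.LongType);

    Schema pruned = SparkSchemaUtil.prune(schema, requested);
    System.out.println(pruned); // contains only the id column
  }
}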
+ */ +package org.apache.iceberg.spark; + +import java.util.List; +import java.util.Set; +import java.util.function.Supplier; +import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Type.TypeID; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampType; + +public class PruneColumnsWithoutReordering extends TypeUtil.CustomOrderSchemaVisitor { + private final StructType requestedType; + private final Set filterRefs; + private DataType current = null; + + PruneColumnsWithoutReordering(StructType requestedType, Set filterRefs) { + this.requestedType = requestedType; + this.filterRefs = filterRefs; + } + + @Override + public Type schema(Schema schema, Supplier structResult) { + this.current = requestedType; + try { + return structResult.get(); + } finally { + this.current = null; + } + } + + @Override + public Type struct(Types.StructType struct, Iterable fieldResults) { + Preconditions.checkNotNull( + struct, "Cannot prune null struct. Pruning must start with a schema."); + Preconditions.checkArgument(current instanceof StructType, "Not a struct: %s", current); + + List fields = struct.fields(); + List types = Lists.newArrayList(fieldResults); + + boolean changed = false; + List newFields = Lists.newArrayListWithExpectedSize(types.size()); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + Type type = types.get(i); + + if (type == null) { + changed = true; + + } else if (field.type() == type) { + newFields.add(field); + + } else if (field.isOptional()) { + changed = true; + newFields.add(Types.NestedField.optional(field.fieldId(), field.name(), type)); + + } else { + changed = true; + newFields.add(Types.NestedField.required(field.fieldId(), field.name(), type)); + } + } + + if (changed) { + return Types.StructType.of(newFields); + } + + return struct; + } + + @Override + public Type field(Types.NestedField field, Supplier fieldResult) { + Preconditions.checkArgument(current instanceof StructType, "Not a struct: %s", current); + StructType requestedStruct = (StructType) current; + + // fields are resolved by name because Spark only sees the current table schema. + if (requestedStruct.getFieldIndex(field.name()).isEmpty()) { + // make sure that filter fields are projected even if they aren't in the requested schema. 
+ if (filterRefs.contains(field.fieldId())) { + return field.type(); + } + return null; + } + + int fieldIndex = requestedStruct.fieldIndex(field.name()); + StructField requestedField = requestedStruct.fields()[fieldIndex]; + + Preconditions.checkArgument( + requestedField.nullable() || field.isRequired(), + "Cannot project an optional field as non-null: %s", + field.name()); + + this.current = requestedField.dataType(); + try { + return fieldResult.get(); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + "Invalid projection for field " + field.name() + ": " + e.getMessage(), e); + } finally { + this.current = requestedStruct; + } + } + + @Override + public Type list(Types.ListType list, Supplier elementResult) { + Preconditions.checkArgument(current instanceof ArrayType, "Not an array: %s", current); + ArrayType requestedArray = (ArrayType) current; + + Preconditions.checkArgument( + requestedArray.containsNull() || !list.isElementOptional(), + "Cannot project an array of optional elements as required elements: %s", + requestedArray); + + this.current = requestedArray.elementType(); + try { + Type elementType = elementResult.get(); + if (list.elementType() == elementType) { + return list; + } + + // must be a projected element type, create a new list + if (list.isElementOptional()) { + return Types.ListType.ofOptional(list.elementId(), elementType); + } else { + return Types.ListType.ofRequired(list.elementId(), elementType); + } + } finally { + this.current = requestedArray; + } + } + + @Override + public Type map(Types.MapType map, Supplier keyResult, Supplier valueResult) { + Preconditions.checkArgument(current instanceof MapType, "Not a map: %s", current); + MapType requestedMap = (MapType) current; + + Preconditions.checkArgument( + requestedMap.valueContainsNull() || !map.isValueOptional(), + "Cannot project a map of optional values as required values: %s", + map); + + this.current = requestedMap.valueType(); + try { + Type valueType = valueResult.get(); + if (map.valueType() == valueType) { + return map; + } + + if (map.isValueOptional()) { + return Types.MapType.ofOptional(map.keyId(), map.valueId(), map.keyType(), valueType); + } else { + return Types.MapType.ofRequired(map.keyId(), map.valueId(), map.keyType(), valueType); + } + } finally { + this.current = requestedMap; + } + } + + @Override + public Type primitive(Type.PrimitiveType primitive) { + Class expectedType = TYPES.get(primitive.typeId()); + Preconditions.checkArgument( + expectedType != null && expectedType.isInstance(current), + "Cannot project %s to incompatible type: %s", + primitive, + current); + + // additional checks based on type + switch (primitive.typeId()) { + case DECIMAL: + Types.DecimalType decimal = (Types.DecimalType) primitive; + DecimalType requestedDecimal = (DecimalType) current; + Preconditions.checkArgument( + requestedDecimal.scale() == decimal.scale(), + "Cannot project decimal with incompatible scale: %s != %s", + requestedDecimal.scale(), + decimal.scale()); + Preconditions.checkArgument( + requestedDecimal.precision() >= decimal.precision(), + "Cannot project decimal with incompatible precision: %s < %s", + requestedDecimal.precision(), + decimal.precision()); + break; + default: + } + + return primitive; + } + + private static final ImmutableMap> TYPES = + ImmutableMap.>builder() + .put(TypeID.BOOLEAN, BooleanType.class) + .put(TypeID.INTEGER, IntegerType.class) + .put(TypeID.LONG, LongType.class) + .put(TypeID.FLOAT, FloatType.class) + .put(TypeID.DOUBLE, 
DoubleType.class) + .put(TypeID.DATE, DateType.class) + .put(TypeID.TIMESTAMP, TimestampType.class) + .put(TypeID.DECIMAL, DecimalType.class) + .put(TypeID.UUID, StringType.class) + .put(TypeID.STRING, StringType.class) + .put(TypeID.FIXED, BinaryType.class) + .put(TypeID.BINARY, BinaryType.class) + .buildOrThrow(); +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/RollbackStagedTable.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/RollbackStagedTable.java new file mode 100644 index 000000000000..bc8a966488ee --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/RollbackStagedTable.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; +import java.util.function.Function; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.StagedTable; +import org.apache.spark.sql.connector.catalog.SupportsDelete; +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.SupportsWrite; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * An implementation of StagedTable that mimics the behavior of Spark's non-atomic CTAS and RTAS. + * + *

A Spark catalog can implement StagingTableCatalog to support atomic operations by producing + * a StagedTable. But if a catalog implements StagingTableCatalog, Spark expects the catalog to be + * able to produce a StagedTable for any table it loads. That assumption does not always hold: {@link + * SparkSessionCatalog} supports atomic operations and can produce a StagedTable for Iceberg tables, + * but it wraps the session catalog and cannot necessarily produce a working StagedTable + * implementation for the other tables that it loads. + * + *

The work-around is this class, which implements the StagedTable interface but does not have + * atomic behavior. Instead, the StagedTable interface is used to reproduce the behavior of the + * non-atomic SQL plans: create the table, write to it, and drop it to roll back if the write fails. + * + *

This StagedTable implements SupportsRead, SupportsWrite, and SupportsDelete by passing the + * calls to the real table. Implementing those interfaces is safe because Spark will not use them + * unless the table supports them and returns the corresponding capabilities from {@link + * #capabilities()}. + */ +public class RollbackStagedTable + implements StagedTable, SupportsRead, SupportsWrite, SupportsDelete { + private final TableCatalog catalog; + private final Identifier ident; + private final Table table; + + public RollbackStagedTable(TableCatalog catalog, Identifier ident, Table table) { + this.catalog = catalog; + this.ident = ident; + this.table = table; + } + + @Override + public void commitStagedChanges() { + // the changes have already been committed to the table at the end of the write + } + + @Override + public void abortStagedChanges() { + // roll back changes by dropping the table + catalog.dropTable(ident); + } + + @Override + public String name() { + return table.name(); + } + + @Override + public StructType schema() { + return table.schema(); + } + + @Override + public Transform[] partitioning() { + return table.partitioning(); + } + + @Override + public Map properties() { + return table.properties(); + } + + @Override + public Set capabilities() { + return table.capabilities(); + } + + @Override + public void deleteWhere(Filter[] filters) { + call(SupportsDelete.class, t -> t.deleteWhere(filters)); + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + return callReturning(SupportsRead.class, t -> t.newScanBuilder(options)); + } + + @Override + public WriteBuilder newWriteBuilder(LogicalWriteInfo info) { + return callReturning(SupportsWrite.class, t -> t.newWriteBuilder(info)); + } + + private void call(Class requiredClass, Consumer task) { + callReturning( + requiredClass, + inst -> { + task.accept(inst); + return null; + }); + } + + private R callReturning(Class requiredClass, Function task) { + if (requiredClass.isInstance(table)) { + return task.apply(requiredClass.cast(table)); + } else { + throw new UnsupportedOperationException( + String.format( + "Table does not implement %s: %s (%s)", + requiredClass.getSimpleName(), table.name(), table.getClass().getName())); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/ScanTaskSetManager.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/ScanTaskSetManager.java new file mode 100644 index 000000000000..84dab88fbad5 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/ScanTaskSetManager.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
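To make the life cycle described in the class javadoc concrete, a hedged sketch (not part of this patch) of how a non-atomic CTAS/RTAS plays out with RollbackStagedTable; the catalog, identifier, created table, and write callback are assumed inputs.

import org.apache.iceberg.spark.RollbackStagedTable;
import org.apache.spark.sql.connector.catalog.Identifier;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.catalog.TableCatalog;

class NonAtomicCtasSketch {
  static void writeOrRollback(TableCatalog catalog, Identifier ident, Table created, Runnable write) {
    RollbackStagedTable staged = new RollbackStagedTable(catalog, ident, created);
    try {
      write.run();                   // Spark writes through staged.newWriteBuilder(...)
      staged.commitStagedChanges();  // no-op: the Iceberg write already committed
    } catch (RuntimeException e) {
      staged.abortStagedChanges();   // rolls back by dropping the table that was just created
      throw e;
    }
  }
}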
+ */ +package org.apache.iceberg.spark; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.util.Pair; + +public class ScanTaskSetManager { + + private static final ScanTaskSetManager INSTANCE = new ScanTaskSetManager(); + + private final Map, List> tasksMap = + Maps.newConcurrentMap(); + + private ScanTaskSetManager() {} + + public static ScanTaskSetManager get() { + return INSTANCE; + } + + public void stageTasks(Table table, String setId, List tasks) { + Preconditions.checkArgument( + tasks != null && tasks.size() > 0, "Cannot stage null or empty tasks"); + Pair id = toId(table, setId); + tasksMap.put(id, tasks); + } + + @SuppressWarnings("unchecked") + public List fetchTasks(Table table, String setId) { + Pair id = toId(table, setId); + return (List) tasksMap.get(id); + } + + @SuppressWarnings("unchecked") + public List removeTasks(Table table, String setId) { + Pair id = toId(table, setId); + return (List) tasksMap.remove(id); + } + + public Set fetchSetIds(Table table) { + return tasksMap.keySet().stream() + .filter(e -> e.first().equals(tableUUID(table))) + .map(Pair::second) + .collect(Collectors.toSet()); + } + + private String tableUUID(Table table) { + TableOperations ops = ((HasTableOperations) table).operations(); + return ops.current().uuid(); + } + + private Pair toId(Table table, String setId) { + return Pair.of(tableUUID(table), setId); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java new file mode 100644 index 000000000000..52d68db2e4f9 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
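A short usage sketch (not part of this patch) for the ScanTaskSetManager above: one side stages planned scan tasks under a set id, and another side (for example a Spark action that must read exactly those files) fetches and then removes them. The helper names are illustrative only.

import java.util.List;
import java.util.UUID;
import org.apache.iceberg.FileScanTask;
import org.apache.iceberg.Table;
import org.apache.iceberg.spark.ScanTaskSetManager;

class ScanTaskStagingSketch {
  static String stage(Table table, List<FileScanTask> tasks) {
    String setId = UUID.randomUUID().toString();
    ScanTaskSetManager.get().stageTasks(table, setId, tasks);  // keyed by table UUID + set id
    return setId;
  }

  static List<FileScanTask> consume(Table table, String setId) {
    List<FileScanTask> staged = ScanTaskSetManager.get().fetchTasks(table, setId);
    ScanTaskSetManager.get().removeTasks(table, setId);        // avoid leaking entries in the shared map
    return staged;
  }
}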
+ */ +package org.apache.iceberg.spark; + +import java.util.Map; +import org.apache.iceberg.NullOrder; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortDirection; +import org.apache.iceberg.transforms.SortOrderVisitor; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.NullOrdering; +import org.apache.spark.sql.connector.expressions.SortOrder; + +class SortOrderToSpark implements SortOrderVisitor { + + private final Map quotedNameById; + + SortOrderToSpark(Schema schema) { + this.quotedNameById = SparkSchemaUtil.indexQuotedNameById(schema); + } + + @Override + public SortOrder field(String sourceName, int id, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort( + Expressions.column(quotedName(id)), toSpark(direction), toSpark(nullOrder)); + } + + @Override + public SortOrder bucket( + String sourceName, int id, int width, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort( + Expressions.bucket(width, quotedName(id)), toSpark(direction), toSpark(nullOrder)); + } + + @Override + public SortOrder truncate( + String sourceName, int id, int width, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort( + Expressions.apply( + "truncate", Expressions.column(quotedName(id)), Expressions.literal(width)), + toSpark(direction), + toSpark(nullOrder)); + } + + @Override + public SortOrder year(String sourceName, int id, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort( + Expressions.years(quotedName(id)), toSpark(direction), toSpark(nullOrder)); + } + + @Override + public SortOrder month(String sourceName, int id, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort( + Expressions.months(quotedName(id)), toSpark(direction), toSpark(nullOrder)); + } + + @Override + public SortOrder day(String sourceName, int id, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort( + Expressions.days(quotedName(id)), toSpark(direction), toSpark(nullOrder)); + } + + @Override + public SortOrder hour(String sourceName, int id, SortDirection direction, NullOrder nullOrder) { + return Expressions.sort( + Expressions.hours(quotedName(id)), toSpark(direction), toSpark(nullOrder)); + } + + private String quotedName(int id) { + return quotedNameById.get(id); + } + + private org.apache.spark.sql.connector.expressions.SortDirection toSpark( + SortDirection direction) { + if (direction == SortDirection.ASC) { + return org.apache.spark.sql.connector.expressions.SortDirection.ASCENDING; + } else { + return org.apache.spark.sql.connector.expressions.SortDirection.DESCENDING; + } + } + + private NullOrdering toSpark(NullOrder nullOrder) { + return nullOrder == NullOrder.NULLS_FIRST ? NullOrdering.NULLS_FIRST : NullOrdering.NULLS_LAST; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java new file mode 100644 index 000000000000..23a53ea9e8c3 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java @@ -0,0 +1,1020 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
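A hedged sketch (not part of this patch) of how the package-private SortOrderToSpark visitor above can be applied: SortOrderVisitor.visit walks an Iceberg sort order and the visitor emits the equivalent Spark DSv2 ordering. The sketch assumes it lives in the same org.apache.iceberg.spark package.

package org.apache.iceberg.spark;

import java.util.List;
import org.apache.iceberg.Table;
import org.apache.iceberg.transforms.SortOrderVisitor;
import org.apache.spark.sql.connector.expressions.SortOrder;

class SortOrderConversionSketch {
  static SortOrder[] toSparkOrdering(Table table) {
    List<SortOrder> converted =
        SortOrderVisitor.visit(table.sortOrder(), new SortOrderToSpark(table.schema()));
    return converted.toArray(new SortOrder[0]);
  }
}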
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.NullOrder; +import org.apache.iceberg.PartitionField; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.UpdateProperties; +import org.apache.iceberg.UpdateSchema; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.expressions.BoundPredicate; +import org.apache.iceberg.expressions.ExpressionVisitors; +import org.apache.iceberg.expressions.Term; +import org.apache.iceberg.expressions.UnboundPredicate; +import org.apache.iceberg.expressions.Zorder; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.io.BaseEncoding; +import org.apache.iceberg.spark.SparkTableUtil.SparkPartition; +import org.apache.iceberg.spark.source.HasIcebergCatalog; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.iceberg.transforms.PartitionSpecVisitor; +import org.apache.iceberg.transforms.SortOrderVisitor; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.ByteBuffers; +import org.apache.iceberg.util.Pair; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.CatalystTypeConverters; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.catalyst.parser.ParserInterface; +import org.apache.spark.sql.connector.catalog.CatalogManager; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.catalog.TableChange; +import 
org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.execution.datasources.FileStatusCache; +import org.apache.spark.sql.execution.datasources.InMemoryFileIndex; +import org.apache.spark.sql.execution.datasources.PartitionDirectory; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import scala.Option; +import scala.collection.JavaConverters; +import scala.collection.immutable.Seq; + +public class Spark3Util { + + private static final Set RESERVED_PROPERTIES = + ImmutableSet.of(TableCatalog.PROP_LOCATION, TableCatalog.PROP_PROVIDER); + private static final Joiner DOT = Joiner.on("."); + + private Spark3Util() {} + + public static CaseInsensitiveStringMap setOption( + String key, String value, CaseInsensitiveStringMap options) { + Map newOptions = Maps.newHashMap(); + newOptions.putAll(options); + newOptions.put(key, value); + return new CaseInsensitiveStringMap(newOptions); + } + + public static Map rebuildCreateProperties(Map createProperties) { + ImmutableMap.Builder tableProperties = ImmutableMap.builder(); + createProperties.entrySet().stream() + .filter(entry -> !RESERVED_PROPERTIES.contains(entry.getKey())) + .forEach(tableProperties::put); + + String provider = createProperties.get(TableCatalog.PROP_PROVIDER); + if ("parquet".equalsIgnoreCase(provider)) { + tableProperties.put(TableProperties.DEFAULT_FILE_FORMAT, "parquet"); + } else if ("avro".equalsIgnoreCase(provider)) { + tableProperties.put(TableProperties.DEFAULT_FILE_FORMAT, "avro"); + } else if ("orc".equalsIgnoreCase(provider)) { + tableProperties.put(TableProperties.DEFAULT_FILE_FORMAT, "orc"); + } else if (provider != null && !"iceberg".equalsIgnoreCase(provider)) { + throw new IllegalArgumentException("Unsupported format in USING: " + provider); + } + + return tableProperties.build(); + } + + /** + * Applies a list of Spark table changes to an {@link UpdateProperties} operation. + * + * @param pendingUpdate an uncommitted UpdateProperties operation to configure + * @param changes a list of Spark table changes + * @return the UpdateProperties operation configured with the changes + */ + public static UpdateProperties applyPropertyChanges( + UpdateProperties pendingUpdate, List changes) { + for (TableChange change : changes) { + if (change instanceof TableChange.SetProperty) { + TableChange.SetProperty set = (TableChange.SetProperty) change; + pendingUpdate.set(set.property(), set.value()); + + } else if (change instanceof TableChange.RemoveProperty) { + TableChange.RemoveProperty remove = (TableChange.RemoveProperty) change; + pendingUpdate.remove(remove.property()); + + } else { + throw new UnsupportedOperationException("Cannot apply unknown table change: " + change); + } + } + + return pendingUpdate; + } + + /** + * Applies a list of Spark table changes to an {@link UpdateSchema} operation. 
+ * + * @param pendingUpdate an uncommitted UpdateSchema operation to configure + * @param changes a list of Spark table changes + * @return the UpdateSchema operation configured with the changes + */ + public static UpdateSchema applySchemaChanges( + UpdateSchema pendingUpdate, List changes) { + for (TableChange change : changes) { + if (change instanceof TableChange.AddColumn) { + apply(pendingUpdate, (TableChange.AddColumn) change); + + } else if (change instanceof TableChange.UpdateColumnType) { + TableChange.UpdateColumnType update = (TableChange.UpdateColumnType) change; + Type newType = SparkSchemaUtil.convert(update.newDataType()); + Preconditions.checkArgument( + newType.isPrimitiveType(), + "Cannot update '%s', not a primitive type: %s", + DOT.join(update.fieldNames()), + update.newDataType()); + pendingUpdate.updateColumn(DOT.join(update.fieldNames()), newType.asPrimitiveType()); + + } else if (change instanceof TableChange.UpdateColumnComment) { + TableChange.UpdateColumnComment update = (TableChange.UpdateColumnComment) change; + pendingUpdate.updateColumnDoc(DOT.join(update.fieldNames()), update.newComment()); + + } else if (change instanceof TableChange.RenameColumn) { + TableChange.RenameColumn rename = (TableChange.RenameColumn) change; + pendingUpdate.renameColumn(DOT.join(rename.fieldNames()), rename.newName()); + + } else if (change instanceof TableChange.DeleteColumn) { + TableChange.DeleteColumn delete = (TableChange.DeleteColumn) change; + pendingUpdate.deleteColumn(DOT.join(delete.fieldNames())); + + } else if (change instanceof TableChange.UpdateColumnNullability) { + TableChange.UpdateColumnNullability update = (TableChange.UpdateColumnNullability) change; + if (update.nullable()) { + pendingUpdate.makeColumnOptional(DOT.join(update.fieldNames())); + } else { + pendingUpdate.requireColumn(DOT.join(update.fieldNames())); + } + + } else if (change instanceof TableChange.UpdateColumnPosition) { + apply(pendingUpdate, (TableChange.UpdateColumnPosition) change); + + } else { + throw new UnsupportedOperationException("Cannot apply unknown table change: " + change); + } + } + + return pendingUpdate; + } + + private static void apply(UpdateSchema pendingUpdate, TableChange.UpdateColumnPosition update) { + Preconditions.checkArgument(update.position() != null, "Invalid position: null"); + + if (update.position() instanceof TableChange.After) { + TableChange.After after = (TableChange.After) update.position(); + String referenceField = peerName(update.fieldNames(), after.column()); + pendingUpdate.moveAfter(DOT.join(update.fieldNames()), referenceField); + + } else if (update.position() instanceof TableChange.First) { + pendingUpdate.moveFirst(DOT.join(update.fieldNames())); + + } else { + throw new IllegalArgumentException("Unknown position for reorder: " + update.position()); + } + } + + private static void apply(UpdateSchema pendingUpdate, TableChange.AddColumn add) { + Preconditions.checkArgument( + add.isNullable(), + "Incompatible change: cannot add required column: %s", + leafName(add.fieldNames())); + Type type = SparkSchemaUtil.convert(add.dataType()); + pendingUpdate.addColumn( + parentName(add.fieldNames()), leafName(add.fieldNames()), type, add.comment()); + + if (add.position() instanceof TableChange.After) { + TableChange.After after = (TableChange.After) add.position(); + String referenceField = peerName(add.fieldNames(), after.column()); + pendingUpdate.moveAfter(DOT.join(add.fieldNames()), referenceField); + + } else if (add.position() instanceof 
TableChange.First) { + pendingUpdate.moveFirst(DOT.join(add.fieldNames())); + + } else { + Preconditions.checkArgument( + add.position() == null, + "Cannot add '%s' at unknown position: %s", + DOT.join(add.fieldNames()), + add.position()); + } + } + + public static org.apache.iceberg.Table toIcebergTable(Table table) { + Preconditions.checkArgument( + table instanceof SparkTable, "Table %s is not an Iceberg table", table); + SparkTable sparkTable = (SparkTable) table; + return sparkTable.table(); + } + + public static Transform[] toTransforms(Schema schema, List fields) { + SpecTransformToSparkTransform visitor = new SpecTransformToSparkTransform(schema); + + List transforms = Lists.newArrayList(); + + for (PartitionField field : fields) { + Transform transform = PartitionSpecVisitor.visit(schema, field, visitor); + if (transform != null) { + transforms.add(transform); + } + } + + return transforms.toArray(new Transform[0]); + } + + /** + * Converts a PartitionSpec to Spark transforms. + * + * @param spec a PartitionSpec + * @return an array of Transforms + */ + public static Transform[] toTransforms(PartitionSpec spec) { + SpecTransformToSparkTransform visitor = new SpecTransformToSparkTransform(spec.schema()); + List transforms = PartitionSpecVisitor.visit(spec, visitor); + return transforms.stream().filter(Objects::nonNull).toArray(Transform[]::new); + } + + private static class SpecTransformToSparkTransform implements PartitionSpecVisitor { + private final Map quotedNameById; + + SpecTransformToSparkTransform(Schema schema) { + this.quotedNameById = SparkSchemaUtil.indexQuotedNameById(schema); + } + + @Override + public Transform identity(String sourceName, int sourceId) { + return Expressions.identity(quotedName(sourceId)); + } + + @Override + public Transform bucket(String sourceName, int sourceId, int numBuckets) { + return Expressions.bucket(numBuckets, quotedName(sourceId)); + } + + @Override + public Transform truncate(String sourceName, int sourceId, int width) { + NamedReference column = Expressions.column(quotedName(sourceId)); + return Expressions.apply("truncate", Expressions.literal(width), column); + } + + @Override + public Transform year(String sourceName, int sourceId) { + return Expressions.years(quotedName(sourceId)); + } + + @Override + public Transform month(String sourceName, int sourceId) { + return Expressions.months(quotedName(sourceId)); + } + + @Override + public Transform day(String sourceName, int sourceId) { + return Expressions.days(quotedName(sourceId)); + } + + @Override + public Transform hour(String sourceName, int sourceId) { + return Expressions.hours(quotedName(sourceId)); + } + + @Override + public Transform alwaysNull(int fieldId, String sourceName, int sourceId) { + // do nothing for alwaysNull, it doesn't need to be converted to a transform + return null; + } + + @Override + public Transform unknown(int fieldId, String sourceName, int sourceId, String transform) { + return Expressions.apply(transform, Expressions.column(quotedName(sourceId))); + } + + private String quotedName(int id) { + return quotedNameById.get(id); + } + } + + public static NamedReference toNamedReference(String name) { + return Expressions.column(name); + } + + public static Term toIcebergTerm(Expression expr) { + if (expr instanceof Transform) { + Transform transform = (Transform) expr; + Preconditions.checkArgument( + "zorder".equals(transform.name()) || transform.references().length == 1, + "Cannot convert transform with more than one column reference: %s", + transform); 
+ String colName = DOT.join(transform.references()[0].fieldNames()); + switch (transform.name().toLowerCase(Locale.ROOT)) { + case "identity": + return org.apache.iceberg.expressions.Expressions.ref(colName); + case "bucket": + return org.apache.iceberg.expressions.Expressions.bucket(colName, findWidth(transform)); + case "years": + return org.apache.iceberg.expressions.Expressions.year(colName); + case "months": + return org.apache.iceberg.expressions.Expressions.month(colName); + case "date": + case "days": + return org.apache.iceberg.expressions.Expressions.day(colName); + case "date_hour": + case "hours": + return org.apache.iceberg.expressions.Expressions.hour(colName); + case "truncate": + return org.apache.iceberg.expressions.Expressions.truncate(colName, findWidth(transform)); + case "zorder": + return new Zorder( + Stream.of(transform.references()) + .map(ref -> DOT.join(ref.fieldNames())) + .map(org.apache.iceberg.expressions.Expressions::ref) + .collect(Collectors.toList())); + default: + throw new UnsupportedOperationException("Transform is not supported: " + transform); + } + + } else if (expr instanceof NamedReference) { + NamedReference ref = (NamedReference) expr; + return org.apache.iceberg.expressions.Expressions.ref(DOT.join(ref.fieldNames())); + + } else { + throw new UnsupportedOperationException("Cannot convert unknown expression: " + expr); + } + } + + /** + * Converts Spark transforms into a {@link PartitionSpec}. + * + * @param schema the table schema + * @param partitioning Spark Transforms + * @return a PartitionSpec + */ + public static PartitionSpec toPartitionSpec(Schema schema, Transform[] partitioning) { + if (partitioning == null || partitioning.length == 0) { + return PartitionSpec.unpartitioned(); + } + + PartitionSpec.Builder builder = PartitionSpec.builderFor(schema); + for (Transform transform : partitioning) { + Preconditions.checkArgument( + transform.references().length == 1, + "Cannot convert transform with more than one column reference: %s", + transform); + String colName = DOT.join(transform.references()[0].fieldNames()); + switch (transform.name().toLowerCase(Locale.ROOT)) { + case "identity": + builder.identity(colName); + break; + case "bucket": + builder.bucket(colName, findWidth(transform)); + break; + case "years": + builder.year(colName); + break; + case "months": + builder.month(colName); + break; + case "date": + case "days": + builder.day(colName); + break; + case "date_hour": + case "hours": + builder.hour(colName); + break; + case "truncate": + builder.truncate(colName, findWidth(transform)); + break; + default: + throw new UnsupportedOperationException("Transform is not supported: " + transform); + } + } + + return builder.build(); + } + + @SuppressWarnings("unchecked") + private static int findWidth(Transform transform) { + for (Expression expr : transform.arguments()) { + if (expr instanceof Literal) { + if (((Literal) expr).dataType() instanceof IntegerType) { + Literal lit = (Literal) expr; + Preconditions.checkArgument( + lit.value() > 0, "Unsupported width for transform: %s", transform.describe()); + return lit.value(); + + } else if (((Literal) expr).dataType() instanceof LongType) { + Literal lit = (Literal) expr; + Preconditions.checkArgument( + lit.value() > 0 && lit.value() < Integer.MAX_VALUE, + "Unsupported width for transform: %s", + transform.describe()); + if (lit.value() > Integer.MAX_VALUE) { + throw new IllegalArgumentException(); + } + return lit.value().intValue(); + } + } + } + + throw new 
IllegalArgumentException("Cannot find width for transform: " + transform.describe()); + } + + private static String leafName(String[] fieldNames) { + Preconditions.checkArgument( + fieldNames.length > 0, "Invalid field name: at least one name is required"); + return fieldNames[fieldNames.length - 1]; + } + + private static String peerName(String[] fieldNames, String fieldName) { + if (fieldNames.length > 1) { + String[] peerNames = Arrays.copyOf(fieldNames, fieldNames.length); + peerNames[fieldNames.length - 1] = fieldName; + return DOT.join(peerNames); + } + return fieldName; + } + + private static String parentName(String[] fieldNames) { + if (fieldNames.length > 1) { + return DOT.join(Arrays.copyOfRange(fieldNames, 0, fieldNames.length - 1)); + } + return null; + } + + public static String describe(List exprs) { + return exprs.stream().map(Spark3Util::describe).collect(Collectors.joining(", ")); + } + + public static String describe(org.apache.iceberg.expressions.Expression expr) { + return ExpressionVisitors.visit(expr, DescribeExpressionVisitor.INSTANCE); + } + + public static String describe(Schema schema) { + return TypeUtil.visit(schema, DescribeSchemaVisitor.INSTANCE); + } + + public static String describe(Type type) { + return TypeUtil.visit(type, DescribeSchemaVisitor.INSTANCE); + } + + public static String describe(org.apache.iceberg.SortOrder order) { + return Joiner.on(", ").join(SortOrderVisitor.visit(order, DescribeSortOrderVisitor.INSTANCE)); + } + + public static boolean extensionsEnabled(SparkSession spark) { + String extensions = spark.conf().get("spark.sql.extensions", ""); + return extensions.contains("IcebergSparkSessionExtensions"); + } + + public static class DescribeSchemaVisitor extends TypeUtil.SchemaVisitor { + private static final Joiner COMMA = Joiner.on(','); + private static final DescribeSchemaVisitor INSTANCE = new DescribeSchemaVisitor(); + + private DescribeSchemaVisitor() {} + + @Override + public String schema(Schema schema, String structResult) { + return structResult; + } + + @Override + public String struct(Types.StructType struct, List fieldResults) { + return "struct<" + COMMA.join(fieldResults) + ">"; + } + + @Override + public String field(Types.NestedField field, String fieldResult) { + return field.name() + ": " + fieldResult + (field.isRequired() ? 
" not null" : ""); + } + + @Override + public String list(Types.ListType list, String elementResult) { + return "list<" + elementResult + ">"; + } + + @Override + public String map(Types.MapType map, String keyResult, String valueResult) { + return "map<" + keyResult + ", " + valueResult + ">"; + } + + @Override + public String primitive(Type.PrimitiveType primitive) { + switch (primitive.typeId()) { + case BOOLEAN: + return "boolean"; + case INTEGER: + return "int"; + case LONG: + return "bigint"; + case FLOAT: + return "float"; + case DOUBLE: + return "double"; + case DATE: + return "date"; + case TIME: + return "time"; + case TIMESTAMP: + return "timestamp"; + case STRING: + case UUID: + return "string"; + case FIXED: + case BINARY: + return "binary"; + case DECIMAL: + Types.DecimalType decimal = (Types.DecimalType) primitive; + return "decimal(" + decimal.precision() + "," + decimal.scale() + ")"; + } + throw new UnsupportedOperationException("Cannot convert type to SQL: " + primitive); + } + } + + private static class DescribeExpressionVisitor + extends ExpressionVisitors.ExpressionVisitor { + private static final DescribeExpressionVisitor INSTANCE = new DescribeExpressionVisitor(); + + private DescribeExpressionVisitor() {} + + @Override + public String alwaysTrue() { + return "true"; + } + + @Override + public String alwaysFalse() { + return "false"; + } + + @Override + public String not(String result) { + return "NOT (" + result + ")"; + } + + @Override + public String and(String leftResult, String rightResult) { + return "(" + leftResult + " AND " + rightResult + ")"; + } + + @Override + public String or(String leftResult, String rightResult) { + return "(" + leftResult + " OR " + rightResult + ")"; + } + + @Override + public String predicate(BoundPredicate pred) { + throw new UnsupportedOperationException("Cannot convert bound predicates to SQL"); + } + + @Override + public String predicate(UnboundPredicate pred) { + switch (pred.op()) { + case IS_NULL: + return pred.ref().name() + " IS NULL"; + case NOT_NULL: + return pred.ref().name() + " IS NOT NULL"; + case IS_NAN: + return "is_nan(" + pred.ref().name() + ")"; + case NOT_NAN: + return "not_nan(" + pred.ref().name() + ")"; + case LT: + return pred.ref().name() + " < " + sqlString(pred.literal()); + case LT_EQ: + return pred.ref().name() + " <= " + sqlString(pred.literal()); + case GT: + return pred.ref().name() + " > " + sqlString(pred.literal()); + case GT_EQ: + return pred.ref().name() + " >= " + sqlString(pred.literal()); + case EQ: + return pred.ref().name() + " = " + sqlString(pred.literal()); + case NOT_EQ: + return pred.ref().name() + " != " + sqlString(pred.literal()); + case STARTS_WITH: + return pred.ref().name() + " LIKE '" + pred.literal().value() + "%'"; + case NOT_STARTS_WITH: + return pred.ref().name() + " NOT LIKE '" + pred.literal().value() + "%'"; + case IN: + return pred.ref().name() + " IN (" + sqlString(pred.literals()) + ")"; + case NOT_IN: + return pred.ref().name() + " NOT IN (" + sqlString(pred.literals()) + ")"; + default: + throw new UnsupportedOperationException("Cannot convert predicate to SQL: " + pred); + } + } + + private static String sqlString(List> literals) { + return literals.stream() + .map(DescribeExpressionVisitor::sqlString) + .collect(Collectors.joining(", ")); + } + + private static String sqlString(org.apache.iceberg.expressions.Literal lit) { + if (lit.value() instanceof String) { + return "'" + lit.value() + "'"; + } else if (lit.value() instanceof ByteBuffer) { + byte[] bytes = 
ByteBuffers.toByteArray((ByteBuffer) lit.value()); + return "X'" + BaseEncoding.base16().encode(bytes) + "'"; + } else { + return lit.value().toString(); + } + } + } + + /** + * Returns an Iceberg Table by its name from a Spark V2 Catalog. If cache is enabled in {@link + * SparkCatalog}, the {@link TableOperations} of the table may be stale, please refresh the table + * to get the latest one. + * + * @param spark SparkSession used for looking up catalog references and tables + * @param name The multipart identifier of the Iceberg table + * @return an Iceberg table + */ + public static org.apache.iceberg.Table loadIcebergTable(SparkSession spark, String name) + throws ParseException, NoSuchTableException { + CatalogAndIdentifier catalogAndIdentifier = catalogAndIdentifier(spark, name); + + TableCatalog catalog = asTableCatalog(catalogAndIdentifier.catalog); + Table sparkTable = catalog.loadTable(catalogAndIdentifier.identifier); + return toIcebergTable(sparkTable); + } + + /** + * Returns the underlying Iceberg Catalog object represented by a Spark Catalog + * + * @param spark SparkSession used for looking up catalog reference + * @param catalogName The name of the Spark Catalog being referenced + * @return the Iceberg catalog class being wrapped by the Spark Catalog + */ + public static Catalog loadIcebergCatalog(SparkSession spark, String catalogName) { + CatalogPlugin catalogPlugin = spark.sessionState().catalogManager().catalog(catalogName); + Preconditions.checkArgument( + catalogPlugin instanceof HasIcebergCatalog, + String.format( + "Cannot load Iceberg catalog from catalog %s because it does not contain an Iceberg Catalog. " + + "Actual Class: %s", + catalogName, catalogPlugin.getClass().getName())); + return ((HasIcebergCatalog) catalogPlugin).icebergCatalog(); + } + + public static CatalogAndIdentifier catalogAndIdentifier(SparkSession spark, String name) + throws ParseException { + return catalogAndIdentifier( + spark, name, spark.sessionState().catalogManager().currentCatalog()); + } + + public static CatalogAndIdentifier catalogAndIdentifier( + SparkSession spark, String name, CatalogPlugin defaultCatalog) throws ParseException { + ParserInterface parser = spark.sessionState().sqlParser(); + Seq multiPartIdentifier = parser.parseMultipartIdentifier(name).toIndexedSeq(); + List javaMultiPartIdentifier = JavaConverters.seqAsJavaList(multiPartIdentifier); + return catalogAndIdentifier(spark, javaMultiPartIdentifier, defaultCatalog); + } + + public static CatalogAndIdentifier catalogAndIdentifier( + String description, SparkSession spark, String name) { + return catalogAndIdentifier( + description, spark, name, spark.sessionState().catalogManager().currentCatalog()); + } + + public static CatalogAndIdentifier catalogAndIdentifier( + String description, SparkSession spark, String name, CatalogPlugin defaultCatalog) { + try { + return catalogAndIdentifier(spark, name, defaultCatalog); + } catch (ParseException e) { + throw new IllegalArgumentException("Cannot parse " + description + ": " + name, e); + } + } + + public static CatalogAndIdentifier catalogAndIdentifier( + SparkSession spark, List nameParts) { + return catalogAndIdentifier( + spark, nameParts, spark.sessionState().catalogManager().currentCatalog()); + } + + /** + * A modified version of Spark's LookupCatalog.CatalogAndIdentifier.unapply Attempts to find the + * catalog and identifier a multipart identifier represents + * + * @param spark Spark session to use for resolution + * @param nameParts Multipart identifier 
representing a table + * @param defaultCatalog Catalog to use if none is specified + * @return The CatalogPlugin and Identifier for the table + */ + public static CatalogAndIdentifier catalogAndIdentifier( + SparkSession spark, List nameParts, CatalogPlugin defaultCatalog) { + CatalogManager catalogManager = spark.sessionState().catalogManager(); + + String[] currentNamespace; + if (defaultCatalog.equals(catalogManager.currentCatalog())) { + currentNamespace = catalogManager.currentNamespace(); + } else { + currentNamespace = defaultCatalog.defaultNamespace(); + } + + Pair catalogIdentifier = + SparkUtil.catalogAndIdentifier( + nameParts, + catalogName -> { + try { + return catalogManager.catalog(catalogName); + } catch (Exception e) { + return null; + } + }, + Identifier::of, + defaultCatalog, + currentNamespace); + return new CatalogAndIdentifier(catalogIdentifier); + } + + private static TableCatalog asTableCatalog(CatalogPlugin catalog) { + if (catalog instanceof TableCatalog) { + return (TableCatalog) catalog; + } + + throw new IllegalArgumentException( + String.format( + "Cannot use catalog %s(%s): not a TableCatalog", + catalog.name(), catalog.getClass().getName())); + } + + /** This mimics a class inside of Spark which is private inside of LookupCatalog. */ + public static class CatalogAndIdentifier { + private final CatalogPlugin catalog; + private final Identifier identifier; + + public CatalogAndIdentifier(CatalogPlugin catalog, Identifier identifier) { + this.catalog = catalog; + this.identifier = identifier; + } + + public CatalogAndIdentifier(Pair identifier) { + this.catalog = identifier.first(); + this.identifier = identifier.second(); + } + + public CatalogPlugin catalog() { + return catalog; + } + + public Identifier identifier() { + return identifier; + } + } + + public static TableIdentifier identifierToTableIdentifier(Identifier identifier) { + return TableIdentifier.of(Namespace.of(identifier.namespace()), identifier.name()); + } + + public static String quotedFullIdentifier(String catalogName, Identifier identifier) { + List parts = + ImmutableList.builder() + .add(catalogName) + .addAll(Arrays.asList(identifier.namespace())) + .add(identifier.name()) + .build(); + + return CatalogV2Implicits.MultipartIdentifierHelper( + JavaConverters.asScalaIteratorConverter(parts.iterator()).asScala().toSeq()) + .quoted(); + } + + /** + * Use Spark to list all partitions in the table. + * + * @param spark a Spark session + * @param rootPath a table identifier + * @param format format of the file + * @param partitionFilter partitionFilter of the file + * @return all table's partitions + * @deprecated use {@link Spark3Util#getPartitions(SparkSession, Path, String, Map, + * PartitionSpec)} + */ + @Deprecated + public static List getPartitions( + SparkSession spark, Path rootPath, String format, Map partitionFilter) { + return getPartitions(spark, rootPath, format, partitionFilter, null); + } + + /** + * Use Spark to list all partitions in the table. + * + * @param spark a Spark session + * @param rootPath a table identifier + * @param format format of the file + * @param partitionFilter partitionFilter of the file + * @param partitionSpec partitionSpec of the table + * @return all table's partitions + */ + public static List getPartitions( + SparkSession spark, + Path rootPath, + String format, + Map partitionFilter, + PartitionSpec partitionSpec) { + FileStatusCache fileStatusCache = FileStatusCache.getOrCreate(spark); + + Option userSpecifiedSchema = + partitionSpec == null + ? 
Option.empty() + : Option.apply( + SparkSchemaUtil.convert(new Schema(partitionSpec.partitionType().fields()))); + + InMemoryFileIndex fileIndex = + new InMemoryFileIndex( + spark, + JavaConverters.collectionAsScalaIterableConverter(ImmutableList.of(rootPath)) + .asScala() + .toSeq(), + scala.collection.immutable.Map$.MODULE$.empty(), + userSpecifiedSchema, + fileStatusCache, + Option.empty(), + Option.empty()); + + org.apache.spark.sql.execution.datasources.PartitionSpec spec = fileIndex.partitionSpec(); + StructType schema = spec.partitionColumns(); + if (schema.isEmpty()) { + return Lists.newArrayList(); + } + + List filterExpressions = + SparkUtil.partitionMapToExpression(schema, partitionFilter); + Seq scalaPartitionFilters = + JavaConverters.asScalaBufferConverter(filterExpressions).asScala().toIndexedSeq(); + + List dataFilters = Lists.newArrayList(); + Seq scalaDataFilters = + JavaConverters.asScalaBufferConverter(dataFilters).asScala().toIndexedSeq(); + + Seq filteredPartitions = + fileIndex.listFiles(scalaPartitionFilters, scalaDataFilters).toIndexedSeq(); + + return JavaConverters.seqAsJavaListConverter(filteredPartitions).asJava().stream() + .map( + partition -> { + Map values = Maps.newHashMap(); + JavaConverters.asJavaIterableConverter(schema) + .asJava() + .forEach( + field -> { + int fieldIndex = schema.fieldIndex(field.name()); + Object catalystValue = partition.values().get(fieldIndex, field.dataType()); + Object value = + CatalystTypeConverters.convertToScala(catalystValue, field.dataType()); + values.put(field.name(), String.valueOf(value)); + }); + + FileStatus fileStatus = + JavaConverters.seqAsJavaListConverter(partition.files()).asJava().get(0); + + return new SparkPartition( + values, fileStatus.getPath().getParent().toString(), format); + }) + .collect(Collectors.toList()); + } + + public static org.apache.spark.sql.catalyst.TableIdentifier toV1TableIdentifier( + Identifier identifier) { + String[] namespace = identifier.namespace(); + + Preconditions.checkArgument( + namespace.length <= 1, + "Cannot convert %s to a Spark v1 identifier, namespace contains more than 1 part", + identifier); + + String table = identifier.name(); + Option database = namespace.length == 1 ? 
Option.apply(namespace[0]) : Option.empty(); + return org.apache.spark.sql.catalyst.TableIdentifier.apply(table, database); + } + + private static class DescribeSortOrderVisitor implements SortOrderVisitor { + private static final DescribeSortOrderVisitor INSTANCE = new DescribeSortOrderVisitor(); + + private DescribeSortOrderVisitor() {} + + @Override + public String field( + String sourceName, + int sourceId, + org.apache.iceberg.SortDirection direction, + NullOrder nullOrder) { + return String.format("%s %s %s", sourceName, direction, nullOrder); + } + + @Override + public String bucket( + String sourceName, + int sourceId, + int numBuckets, + org.apache.iceberg.SortDirection direction, + NullOrder nullOrder) { + return String.format("bucket(%s, %s) %s %s", numBuckets, sourceName, direction, nullOrder); + } + + @Override + public String truncate( + String sourceName, + int sourceId, + int width, + org.apache.iceberg.SortDirection direction, + NullOrder nullOrder) { + return String.format("truncate(%s, %s) %s %s", sourceName, width, direction, nullOrder); + } + + @Override + public String year( + String sourceName, + int sourceId, + org.apache.iceberg.SortDirection direction, + NullOrder nullOrder) { + return String.format("years(%s) %s %s", sourceName, direction, nullOrder); + } + + @Override + public String month( + String sourceName, + int sourceId, + org.apache.iceberg.SortDirection direction, + NullOrder nullOrder) { + return String.format("months(%s) %s %s", sourceName, direction, nullOrder); + } + + @Override + public String day( + String sourceName, + int sourceId, + org.apache.iceberg.SortDirection direction, + NullOrder nullOrder) { + return String.format("days(%s) %s %s", sourceName, direction, nullOrder); + } + + @Override + public String hour( + String sourceName, + int sourceId, + org.apache.iceberg.SortDirection direction, + NullOrder nullOrder) { + return String.format("hours(%s) %s %s", sourceName, direction, nullOrder); + } + + @Override + public String unknown( + String sourceName, + int sourceId, + String transform, + org.apache.iceberg.SortDirection direction, + NullOrder nullOrder) { + return String.format("%s(%s) %s %s", transform, sourceName, direction, nullOrder); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkAggregates.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkAggregates.java new file mode 100644 index 000000000000..153ef11a9eb6 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkAggregates.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
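A hedged usage sketch (not part of this patch) for the Spark3Util helpers documented above: resolve a multipart table name against the session's catalogs, load the underlying Iceberg table, and convert its partition spec into Spark transforms. The table name is a placeholder supplied by the caller.

import org.apache.iceberg.spark.Spark3Util;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.connector.expressions.Transform;

class Spark3UtilSketch {
  static Transform[] partitioningOf(String multipartName) throws Exception {
    SparkSession spark = SparkSession.active();
    org.apache.iceberg.Table table = Spark3Util.loadIcebergTable(spark, multipartName);
    return Spark3Util.toTransforms(table.spec());  // e.g. days(ts), bucket(16, id)
  }
}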
+ */ +package org.apache.iceberg.spark; + +import java.util.Map; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expression.Operation; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc; +import org.apache.spark.sql.connector.expressions.aggregate.Count; +import org.apache.spark.sql.connector.expressions.aggregate.CountStar; +import org.apache.spark.sql.connector.expressions.aggregate.Max; +import org.apache.spark.sql.connector.expressions.aggregate.Min; + +public class SparkAggregates { + private SparkAggregates() {} + + private static final Map, Operation> AGGREGATES = + ImmutableMap., Operation>builder() + .put(Count.class, Operation.COUNT) + .put(CountStar.class, Operation.COUNT_STAR) + .put(Max.class, Operation.MAX) + .put(Min.class, Operation.MIN) + .buildOrThrow(); + + public static Expression convert(AggregateFunc aggregate) { + Operation op = AGGREGATES.get(aggregate.getClass()); + if (op != null) { + switch (op) { + case COUNT: + Count countAgg = (Count) aggregate; + if (countAgg.isDistinct()) { + // manifest file doesn't have count distinct so this can't be pushed down + return null; + } + + if (countAgg.column() instanceof NamedReference) { + return Expressions.count(SparkUtil.toColumnName((NamedReference) countAgg.column())); + } else { + return null; + } + + case COUNT_STAR: + return Expressions.countStar(); + + case MAX: + Max maxAgg = (Max) aggregate; + if (maxAgg.column() instanceof NamedReference) { + return Expressions.max(SparkUtil.toColumnName((NamedReference) maxAgg.column())); + } else { + return null; + } + + case MIN: + Min minAgg = (Min) aggregate; + if (minAgg.column() instanceof NamedReference) { + return Expressions.min(SparkUtil.toColumnName((NamedReference) minAgg.column())); + } else { + return null; + } + } + } + + return null; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkCachedTableCatalog.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkCachedTableCatalog.java new file mode 100644 index 000000000000..2533b3bd75b5 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkCachedTableCatalog.java @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
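A hedged sketch (not part of this patch) of the aggregate push-down mapping above; it assumes the Spark 3.4 connector API shapes for Max and CountStar, and relies on SparkAggregates.convert returning null for anything it cannot translate (for example count(distinct ...)).

import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.spark.SparkAggregates;
import org.apache.spark.sql.connector.expressions.Expressions;
import org.apache.spark.sql.connector.expressions.aggregate.CountStar;
import org.apache.spark.sql.connector.expressions.aggregate.Max;

class AggregatePushDownSketch {
  public static void main(String[] args) {
    Expression max = SparkAggregates.convert(new Max(Expressions.column("id")));  // Iceberg max("id")
    Expression rows = SparkAggregates.convert(new CountStar());                   // Iceberg countStar()
    System.out.println(max + " / " + rows);
  }
}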
+ */ +package org.apache.iceberg.spark; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.TimeUnit; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Stream; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.base.Splitter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.iceberg.util.Pair; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.catalog.TableChange; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** An internal table catalog that is capable of loading tables from a cache. */ +public class SparkCachedTableCatalog implements TableCatalog { + + private static final String CLASS_NAME = SparkCachedTableCatalog.class.getName(); + private static final Splitter COMMA = Splitter.on(","); + private static final Pattern AT_TIMESTAMP = Pattern.compile("at_timestamp_(\\d+)"); + private static final Pattern SNAPSHOT_ID = Pattern.compile("snapshot_id_(\\d+)"); + private static final Pattern BRANCH = Pattern.compile("branch_(.*)"); + private static final Pattern TAG = Pattern.compile("tag_(.*)"); + + private static final SparkTableCache TABLE_CACHE = SparkTableCache.get(); + + private String name = null; + + @Override + public Identifier[] listTables(String[] namespace) { + throw new UnsupportedOperationException(CLASS_NAME + " does not support listing tables"); + } + + @Override + public SparkTable loadTable(Identifier ident) throws NoSuchTableException { + Pair table = load(ident); + return new SparkTable(table.first(), table.second(), false /* refresh eagerly */); + } + + @Override + public SparkTable loadTable(Identifier ident, String version) throws NoSuchTableException { + Pair table = load(ident); + Preconditions.checkArgument( + table.second() == null, "Cannot time travel based on both table identifier and AS OF"); + return new SparkTable(table.first(), Long.parseLong(version), false /* refresh eagerly */); + } + + @Override + public SparkTable loadTable(Identifier ident, long timestampMicros) throws NoSuchTableException { + Pair table = load(ident); + Preconditions.checkArgument( + table.second() == null, "Cannot time travel based on both table identifier and AS OF"); + // Spark passes microseconds but Iceberg uses milliseconds for snapshots + long timestampMillis = TimeUnit.MICROSECONDS.toMillis(timestampMicros); + long snapshotId = SnapshotUtil.snapshotIdAsOfTime(table.first(), timestampMillis); + return new SparkTable(table.first(), snapshotId, false /* refresh eagerly */); + } + + @Override + public void invalidateTable(Identifier ident) { + throw new UnsupportedOperationException(CLASS_NAME + " does not support table invalidation"); + } + + @Override + public SparkTable createTable( + Identifier ident, StructType schema, Transform[] partitions, Map properties) + throws TableAlreadyExistsException { + throw new 
UnsupportedOperationException(CLASS_NAME + " does not support creating tables"); + } + + @Override + public SparkTable alterTable(Identifier ident, TableChange... changes) { + throw new UnsupportedOperationException(CLASS_NAME + " does not support altering tables"); + } + + @Override + public boolean dropTable(Identifier ident) { + throw new UnsupportedOperationException(CLASS_NAME + " does not support dropping tables"); + } + + @Override + public boolean purgeTable(Identifier ident) throws UnsupportedOperationException { + throw new UnsupportedOperationException(CLASS_NAME + " does not support purging tables"); + } + + @Override + public void renameTable(Identifier oldIdent, Identifier newIdent) { + throw new UnsupportedOperationException(CLASS_NAME + " does not support renaming tables"); + } + + @Override + public void initialize(String catalogName, CaseInsensitiveStringMap options) { + this.name = catalogName; + } + + @Override + public String name() { + return name; + } + + private Pair load(Identifier ident) throws NoSuchTableException { + Preconditions.checkArgument( + ident.namespace().length == 0, CLASS_NAME + " does not support namespaces"); + + Pair> parsedIdent = parseIdent(ident); + String key = parsedIdent.first(); + List metadata = parsedIdent.second(); + + Long asOfTimestamp = null; + Long snapshotId = null; + String branch = null; + String tag = null; + for (String meta : metadata) { + Matcher timeBasedMatcher = AT_TIMESTAMP.matcher(meta); + if (timeBasedMatcher.matches()) { + asOfTimestamp = Long.parseLong(timeBasedMatcher.group(1)); + continue; + } + + Matcher snapshotBasedMatcher = SNAPSHOT_ID.matcher(meta); + if (snapshotBasedMatcher.matches()) { + snapshotId = Long.parseLong(snapshotBasedMatcher.group(1)); + continue; + } + + Matcher branchBasedMatcher = BRANCH.matcher(meta); + if (branchBasedMatcher.matches()) { + branch = branchBasedMatcher.group(1); + continue; + } + + Matcher tagBasedMatcher = TAG.matcher(meta); + if (tagBasedMatcher.matches()) { + tag = tagBasedMatcher.group(1); + } + } + + Preconditions.checkArgument( + Stream.of(snapshotId, asOfTimestamp, branch, tag).filter(Objects::nonNull).count() <= 1, + "Can specify only one of snapshot-id (%s), as-of-timestamp (%s), branch (%s), tag (%s)", + snapshotId, + asOfTimestamp, + branch, + tag); + + Table table = TABLE_CACHE.get(key); + + if (table == null) { + throw new NoSuchTableException(ident); + } + + if (snapshotId != null) { + return Pair.of(table, snapshotId); + } else if (asOfTimestamp != null) { + return Pair.of(table, SnapshotUtil.snapshotIdAsOfTime(table, asOfTimestamp)); + } else if (branch != null) { + Snapshot branchSnapshot = table.snapshot(branch); + Preconditions.checkArgument( + branchSnapshot != null, "Cannot find snapshot associated with branch name: %s", branch); + return Pair.of(table, branchSnapshot.snapshotId()); + } else if (tag != null) { + Snapshot tagSnapshot = table.snapshot(tag); + Preconditions.checkArgument( + tagSnapshot != null, "Cannot find snapshot associated with tag name: %s", tag); + return Pair.of(table, tagSnapshot.snapshotId()); + } else { + return Pair.of(table, null); + } + } + + private Pair> parseIdent(Identifier ident) { + int hashIndex = ident.name().lastIndexOf('#'); + if (hashIndex != -1 && !ident.name().endsWith("#")) { + String key = ident.name().substring(0, hashIndex); + List metadata = COMMA.splitToList(ident.name().substring(hashIndex + 1)); + return Pair.of(key, metadata); + } else { + return Pair.of(ident.name(), ImmutableList.of()); + } + } +} diff --git 
a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkCatalog.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkCatalog.java new file mode 100644 index 000000000000..3ad3f5d0ee2a --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkCatalog.java @@ -0,0 +1,795 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.apache.iceberg.TableProperties.GC_ENABLED; +import static org.apache.iceberg.TableProperties.GC_ENABLED_DEFAULT; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.TimeUnit; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Stream; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.CachingCatalog; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.EnvironmentContext; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Transaction; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.SupportsNamespaces; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.HadoopCatalog; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.base.Splitter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.source.SparkChangelogTable; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.iceberg.spark.source.StagedSparkTable; +import org.apache.iceberg.util.Pair; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NamespaceAlreadyExistsException; +import 
org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.NamespaceChange; +import org.apache.spark.sql.connector.catalog.StagedTable; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.catalog.TableChange; +import org.apache.spark.sql.connector.catalog.TableChange.ColumnChange; +import org.apache.spark.sql.connector.catalog.TableChange.RemoveProperty; +import org.apache.spark.sql.connector.catalog.TableChange.SetProperty; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * A Spark TableCatalog implementation that wraps an Iceberg {@link Catalog}. + * + *

+ * <p>This supports the following catalog configuration options:
+ *
+ * <ul>
+ *   <li>type - catalog type, "hive" or "hadoop". To specify a non-hive or hadoop
+ *       catalog, use the catalog-impl option.
+ *   <li>uri - the Hive Metastore URI (Hive catalog only)
+ *   <li>warehouse - the warehouse path (Hadoop catalog only)
+ *   <li>catalog-impl - a custom {@link Catalog} implementation to use
+ *   <li>default-namespace - a namespace to use as the default
+ *   <li>cache-enabled - whether to enable catalog cache
+ *   <li>cache.expiration-interval-ms - interval in millis before expiring tables from
+ *       catalog cache. Refer to {@link CatalogProperties#CACHE_EXPIRATION_INTERVAL_MS} for further
+ *       details and significant values.
+ * </ul>
+ *
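+ * <p>A minimal configuration sketch (the catalog name {@code my_catalog} and the metastore URI
+ * below are illustrative placeholders, not values defined by this class):
+ *
+ * <pre>
+ *   // "my_catalog" and the thrift URI are assumed example values
+ *   SparkSession spark = SparkSession.builder()
+ *       .config("spark.sql.catalog.my_catalog", SparkCatalog.class.getName())
+ *       .config("spark.sql.catalog.my_catalog.type", "hive")
+ *       .config("spark.sql.catalog.my_catalog.uri", "thrift://metastore-host:9083")
+ *       .getOrCreate();
+ * </pre>
+ *
+ * Tables registered in that catalog can then be referenced as {@code my_catalog.db.table} from
+ * Spark SQL.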

+ */ +public class SparkCatalog extends BaseCatalog { + private static final Set DEFAULT_NS_KEYS = ImmutableSet.of(TableCatalog.PROP_OWNER); + private static final Splitter COMMA = Splitter.on(","); + private static final Pattern AT_TIMESTAMP = Pattern.compile("at_timestamp_(\\d+)"); + private static final Pattern SNAPSHOT_ID = Pattern.compile("snapshot_id_(\\d+)"); + private static final Pattern BRANCH = Pattern.compile("branch_(.*)"); + private static final Pattern TAG = Pattern.compile("tag_(.*)"); + + private String catalogName = null; + private Catalog icebergCatalog = null; + private boolean cacheEnabled = CatalogProperties.CACHE_ENABLED_DEFAULT; + private SupportsNamespaces asNamespaceCatalog = null; + private String[] defaultNamespace = null; + private HadoopTables tables; + private boolean useTimestampsWithoutZone; + + /** + * Build an Iceberg {@link Catalog} to be used by this Spark catalog adapter. + * + * @param name Spark's catalog name + * @param options Spark's catalog options + * @return an Iceberg catalog + */ + protected Catalog buildIcebergCatalog(String name, CaseInsensitiveStringMap options) { + Configuration conf = SparkUtil.hadoopConfCatalogOverrides(SparkSession.active(), name); + Map optionsMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + optionsMap.putAll(options.asCaseSensitiveMap()); + optionsMap.put(CatalogProperties.APP_ID, SparkSession.active().sparkContext().applicationId()); + optionsMap.put(CatalogProperties.USER, SparkSession.active().sparkContext().sparkUser()); + return CatalogUtil.buildIcebergCatalog(name, optionsMap, conf); + } + + /** + * Build an Iceberg {@link TableIdentifier} for the given Spark identifier. + * + * @param identifier Spark's identifier + * @return an Iceberg identifier + */ + protected TableIdentifier buildIdentifier(Identifier identifier) { + return Spark3Util.identifierToTableIdentifier(identifier); + } + + @Override + public Table loadTable(Identifier ident) throws NoSuchTableException { + try { + return load(ident); + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + throw new NoSuchTableException(ident); + } + } + + @Override + public Table loadTable(Identifier ident, String version) throws NoSuchTableException { + Table table = loadTable(ident); + + if (table instanceof SparkTable) { + SparkTable sparkTable = (SparkTable) table; + + Preconditions.checkArgument( + sparkTable.snapshotId() == null, + "Cannot do time-travel based on both table identifier and AS OF"); + + try { + return sparkTable.copyWithSnapshotId(Long.parseLong(version)); + } catch (NumberFormatException e) { + SnapshotRef ref = sparkTable.table().refs().get(version); + ValidationException.check( + ref != null, + "Cannot find matching snapshot ID or reference name for version " + version); + + if (ref.isBranch()) { + return sparkTable.copyWithBranch(version); + } else { + return sparkTable.copyWithSnapshotId(ref.snapshotId()); + } + } + + } else if (table instanceof SparkChangelogTable) { + throw new UnsupportedOperationException("AS OF is not supported for changelogs"); + + } else { + throw new IllegalArgumentException("Unknown Spark table type: " + table.getClass().getName()); + } + } + + @Override + public Table loadTable(Identifier ident, long timestamp) throws NoSuchTableException { + Table table = loadTable(ident); + + if (table instanceof SparkTable) { + SparkTable sparkTable = (SparkTable) table; + + Preconditions.checkArgument( + sparkTable.snapshotId() == null, + "Cannot do time-travel based on both table identifier and AS OF"); 
+ + // convert the timestamp to milliseconds as Spark passes microseconds + // but Iceberg uses milliseconds for snapshot timestamps + long timestampMillis = TimeUnit.MICROSECONDS.toMillis(timestamp); + long snapshotId = SnapshotUtil.snapshotIdAsOfTime(sparkTable.table(), timestampMillis); + return sparkTable.copyWithSnapshotId(snapshotId); + + } else if (table instanceof SparkChangelogTable) { + throw new UnsupportedOperationException("AS OF is not supported for changelogs"); + + } else { + throw new IllegalArgumentException("Unknown Spark table type: " + table.getClass().getName()); + } + } + + @Override + public Table createTable( + Identifier ident, StructType schema, Transform[] transforms, Map properties) + throws TableAlreadyExistsException { + Schema icebergSchema = SparkSchemaUtil.convert(schema, useTimestampsWithoutZone); + try { + Catalog.TableBuilder builder = newBuilder(ident, icebergSchema); + org.apache.iceberg.Table icebergTable = + builder + .withPartitionSpec(Spark3Util.toPartitionSpec(icebergSchema, transforms)) + .withLocation(properties.get("location")) + .withProperties(Spark3Util.rebuildCreateProperties(properties)) + .create(); + return new SparkTable(icebergTable, !cacheEnabled); + } catch (AlreadyExistsException e) { + throw new TableAlreadyExistsException(ident); + } + } + + @Override + public StagedTable stageCreate( + Identifier ident, StructType schema, Transform[] transforms, Map properties) + throws TableAlreadyExistsException { + Schema icebergSchema = SparkSchemaUtil.convert(schema, useTimestampsWithoutZone); + try { + Catalog.TableBuilder builder = newBuilder(ident, icebergSchema); + Transaction transaction = + builder + .withPartitionSpec(Spark3Util.toPartitionSpec(icebergSchema, transforms)) + .withLocation(properties.get("location")) + .withProperties(Spark3Util.rebuildCreateProperties(properties)) + .createTransaction(); + return new StagedSparkTable(transaction); + } catch (AlreadyExistsException e) { + throw new TableAlreadyExistsException(ident); + } + } + + @Override + public StagedTable stageReplace( + Identifier ident, StructType schema, Transform[] transforms, Map properties) + throws NoSuchTableException { + Schema icebergSchema = SparkSchemaUtil.convert(schema, useTimestampsWithoutZone); + try { + Catalog.TableBuilder builder = newBuilder(ident, icebergSchema); + Transaction transaction = + builder + .withPartitionSpec(Spark3Util.toPartitionSpec(icebergSchema, transforms)) + .withLocation(properties.get("location")) + .withProperties(Spark3Util.rebuildCreateProperties(properties)) + .replaceTransaction(); + return new StagedSparkTable(transaction); + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + throw new NoSuchTableException(ident); + } + } + + @Override + public StagedTable stageCreateOrReplace( + Identifier ident, StructType schema, Transform[] transforms, Map properties) { + Schema icebergSchema = SparkSchemaUtil.convert(schema, useTimestampsWithoutZone); + Catalog.TableBuilder builder = newBuilder(ident, icebergSchema); + Transaction transaction = + builder + .withPartitionSpec(Spark3Util.toPartitionSpec(icebergSchema, transforms)) + .withLocation(properties.get("location")) + .withProperties(Spark3Util.rebuildCreateProperties(properties)) + .createOrReplaceTransaction(); + return new StagedSparkTable(transaction); + } + + @Override + public Table alterTable(Identifier ident, TableChange... 
changes) throws NoSuchTableException { + SetProperty setLocation = null; + SetProperty setSnapshotId = null; + SetProperty pickSnapshotId = null; + List propertyChanges = Lists.newArrayList(); + List schemaChanges = Lists.newArrayList(); + + for (TableChange change : changes) { + if (change instanceof SetProperty) { + SetProperty set = (SetProperty) change; + if (TableCatalog.PROP_LOCATION.equalsIgnoreCase(set.property())) { + setLocation = set; + } else if ("current-snapshot-id".equalsIgnoreCase(set.property())) { + setSnapshotId = set; + } else if ("cherry-pick-snapshot-id".equalsIgnoreCase(set.property())) { + pickSnapshotId = set; + } else if ("sort-order".equalsIgnoreCase(set.property())) { + throw new UnsupportedOperationException( + "Cannot specify the 'sort-order' because it's a reserved table " + + "property. Please use the command 'ALTER TABLE ... WRITE ORDERED BY' to specify write sort-orders."); + } else { + propertyChanges.add(set); + } + } else if (change instanceof RemoveProperty) { + propertyChanges.add(change); + } else if (change instanceof ColumnChange) { + schemaChanges.add(change); + } else { + throw new UnsupportedOperationException("Cannot apply unknown table change: " + change); + } + } + + try { + org.apache.iceberg.Table table = icebergCatalog.loadTable(buildIdentifier(ident)); + commitChanges( + table, setLocation, setSnapshotId, pickSnapshotId, propertyChanges, schemaChanges); + return new SparkTable(table, true /* refreshEagerly */); + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + throw new NoSuchTableException(ident); + } + } + + @Override + public boolean dropTable(Identifier ident) { + return dropTableWithoutPurging(ident); + } + + @Override + public boolean purgeTable(Identifier ident) { + try { + org.apache.iceberg.Table table = icebergCatalog.loadTable(buildIdentifier(ident)); + ValidationException.check( + PropertyUtil.propertyAsBoolean(table.properties(), GC_ENABLED, GC_ENABLED_DEFAULT), + "Cannot purge table: GC is disabled (deleting files may corrupt other tables)"); + String metadataFileLocation = + ((HasTableOperations) table).operations().current().metadataFileLocation(); + + boolean dropped = dropTableWithoutPurging(ident); + + if (dropped) { + // check whether the metadata file exists because HadoopCatalog/HadoopTables + // will drop the warehouse directly and ignore the `purge` argument + boolean metadataFileExists = table.io().newInputFile(metadataFileLocation).exists(); + + if (metadataFileExists) { + SparkActions.get().deleteReachableFiles(metadataFileLocation).io(table.io()).execute(); + } + } + + return dropped; + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + return false; + } + } + + private boolean dropTableWithoutPurging(Identifier ident) { + if (isPathIdentifier(ident)) { + return tables.dropTable(((PathIdentifier) ident).location(), false /* don't purge data */); + } else { + return icebergCatalog.dropTable(buildIdentifier(ident), false /* don't purge data */); + } + } + + @Override + public void renameTable(Identifier from, Identifier to) + throws NoSuchTableException, TableAlreadyExistsException { + try { + checkNotPathIdentifier(from, "renameTable"); + checkNotPathIdentifier(to, "renameTable"); + icebergCatalog.renameTable(buildIdentifier(from), buildIdentifier(to)); + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + throw new NoSuchTableException(from); + } catch (AlreadyExistsException e) { + throw new TableAlreadyExistsException(to); + } + } + + @Override + public void 
invalidateTable(Identifier ident) { + if (!isPathIdentifier(ident)) { + icebergCatalog.invalidateTable(buildIdentifier(ident)); + } + } + + @Override + public Identifier[] listTables(String[] namespace) { + return icebergCatalog.listTables(Namespace.of(namespace)).stream() + .map(ident -> Identifier.of(ident.namespace().levels(), ident.name())) + .toArray(Identifier[]::new); + } + + @Override + public String[] defaultNamespace() { + if (defaultNamespace != null) { + return defaultNamespace; + } + + return new String[0]; + } + + @Override + public String[][] listNamespaces() { + if (asNamespaceCatalog != null) { + return asNamespaceCatalog.listNamespaces().stream() + .map(Namespace::levels) + .toArray(String[][]::new); + } + + return new String[0][]; + } + + @Override + public String[][] listNamespaces(String[] namespace) throws NoSuchNamespaceException { + if (asNamespaceCatalog != null) { + try { + return asNamespaceCatalog.listNamespaces(Namespace.of(namespace)).stream() + .map(Namespace::levels) + .toArray(String[][]::new); + } catch (org.apache.iceberg.exceptions.NoSuchNamespaceException e) { + throw new NoSuchNamespaceException(namespace); + } + } + + throw new NoSuchNamespaceException(namespace); + } + + @Override + public Map loadNamespaceMetadata(String[] namespace) + throws NoSuchNamespaceException { + if (asNamespaceCatalog != null) { + try { + return asNamespaceCatalog.loadNamespaceMetadata(Namespace.of(namespace)); + } catch (org.apache.iceberg.exceptions.NoSuchNamespaceException e) { + throw new NoSuchNamespaceException(namespace); + } + } + + throw new NoSuchNamespaceException(namespace); + } + + @Override + public void createNamespace(String[] namespace, Map metadata) + throws NamespaceAlreadyExistsException { + if (asNamespaceCatalog != null) { + try { + if (asNamespaceCatalog instanceof HadoopCatalog + && DEFAULT_NS_KEYS.equals(metadata.keySet())) { + // Hadoop catalog will reject metadata properties, but Spark automatically adds "owner". + // If only the automatic properties are present, replace metadata with an empty map. + asNamespaceCatalog.createNamespace(Namespace.of(namespace), ImmutableMap.of()); + } else { + asNamespaceCatalog.createNamespace(Namespace.of(namespace), metadata); + } + } catch (AlreadyExistsException e) { + throw new NamespaceAlreadyExistsException(namespace); + } + } else { + throw new UnsupportedOperationException( + "Namespaces are not supported by catalog: " + catalogName); + } + } + + @Override + public void alterNamespace(String[] namespace, NamespaceChange... 
changes) + throws NoSuchNamespaceException { + if (asNamespaceCatalog != null) { + Map updates = Maps.newHashMap(); + Set removals = Sets.newHashSet(); + for (NamespaceChange change : changes) { + if (change instanceof NamespaceChange.SetProperty) { + NamespaceChange.SetProperty set = (NamespaceChange.SetProperty) change; + updates.put(set.property(), set.value()); + } else if (change instanceof NamespaceChange.RemoveProperty) { + removals.add(((NamespaceChange.RemoveProperty) change).property()); + } else { + throw new UnsupportedOperationException( + "Cannot apply unknown namespace change: " + change); + } + } + + try { + if (!updates.isEmpty()) { + asNamespaceCatalog.setProperties(Namespace.of(namespace), updates); + } + + if (!removals.isEmpty()) { + asNamespaceCatalog.removeProperties(Namespace.of(namespace), removals); + } + + } catch (org.apache.iceberg.exceptions.NoSuchNamespaceException e) { + throw new NoSuchNamespaceException(namespace); + } + } else { + throw new NoSuchNamespaceException(namespace); + } + } + + @Override + public boolean dropNamespace(String[] namespace, boolean cascade) + throws NoSuchNamespaceException { + if (asNamespaceCatalog != null) { + try { + return asNamespaceCatalog.dropNamespace(Namespace.of(namespace)); + } catch (org.apache.iceberg.exceptions.NoSuchNamespaceException e) { + throw new NoSuchNamespaceException(namespace); + } + } + + return false; + } + + @Override + public final void initialize(String name, CaseInsensitiveStringMap options) { + this.cacheEnabled = + PropertyUtil.propertyAsBoolean( + options, CatalogProperties.CACHE_ENABLED, CatalogProperties.CACHE_ENABLED_DEFAULT); + + long cacheExpirationIntervalMs = + PropertyUtil.propertyAsLong( + options, + CatalogProperties.CACHE_EXPIRATION_INTERVAL_MS, + CatalogProperties.CACHE_EXPIRATION_INTERVAL_MS_DEFAULT); + + // An expiration interval of 0ms effectively disables caching. + // Do not wrap with CachingCatalog. + if (cacheExpirationIntervalMs == 0) { + this.cacheEnabled = false; + } + + Catalog catalog = buildIcebergCatalog(name, options); + + this.catalogName = name; + SparkSession sparkSession = SparkSession.active(); + this.useTimestampsWithoutZone = + SparkUtil.useTimestampWithoutZoneInNewTables(sparkSession.conf()); + this.tables = + new HadoopTables(SparkUtil.hadoopConfCatalogOverrides(SparkSession.active(), name)); + this.icebergCatalog = + cacheEnabled ? 
CachingCatalog.wrap(catalog, cacheExpirationIntervalMs) : catalog; + if (catalog instanceof SupportsNamespaces) { + this.asNamespaceCatalog = (SupportsNamespaces) catalog; + if (options.containsKey("default-namespace")) { + this.defaultNamespace = + Splitter.on('.').splitToList(options.get("default-namespace")).toArray(new String[0]); + } + } + + EnvironmentContext.put(EnvironmentContext.ENGINE_NAME, "spark"); + EnvironmentContext.put( + EnvironmentContext.ENGINE_VERSION, sparkSession.sparkContext().version()); + EnvironmentContext.put(CatalogProperties.APP_ID, sparkSession.sparkContext().applicationId()); + } + + @Override + public String name() { + return catalogName; + } + + private static void commitChanges( + org.apache.iceberg.Table table, + SetProperty setLocation, + SetProperty setSnapshotId, + SetProperty pickSnapshotId, + List propertyChanges, + List schemaChanges) { + // don't allow setting the snapshot and picking a commit at the same time because order is + // ambiguous and choosing one order leads to different results + Preconditions.checkArgument( + setSnapshotId == null || pickSnapshotId == null, + "Cannot set the current the current snapshot ID and cherry-pick snapshot changes"); + + if (setSnapshotId != null) { + long newSnapshotId = Long.parseLong(setSnapshotId.value()); + table.manageSnapshots().setCurrentSnapshot(newSnapshotId).commit(); + } + + // if updating the table snapshot, perform that update first in case it fails + if (pickSnapshotId != null) { + long newSnapshotId = Long.parseLong(pickSnapshotId.value()); + table.manageSnapshots().cherrypick(newSnapshotId).commit(); + } + + Transaction transaction = table.newTransaction(); + + if (setLocation != null) { + transaction.updateLocation().setLocation(setLocation.value()).commit(); + } + + if (!propertyChanges.isEmpty()) { + Spark3Util.applyPropertyChanges(transaction.updateProperties(), propertyChanges).commit(); + } + + if (!schemaChanges.isEmpty()) { + Spark3Util.applySchemaChanges(transaction.updateSchema(), schemaChanges).commit(); + } + + transaction.commitTransaction(); + } + + private static boolean isPathIdentifier(Identifier ident) { + return ident instanceof PathIdentifier; + } + + private static void checkNotPathIdentifier(Identifier identifier, String method) { + if (identifier instanceof PathIdentifier) { + throw new IllegalArgumentException( + String.format( + "Cannot pass path based identifier to %s method. 
%s is a path.", method, identifier)); + } + } + + private Table load(Identifier ident) { + if (isPathIdentifier(ident)) { + return loadFromPathIdentifier((PathIdentifier) ident); + } + + try { + org.apache.iceberg.Table table = icebergCatalog.loadTable(buildIdentifier(ident)); + return new SparkTable(table, !cacheEnabled); + + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + if (ident.namespace().length == 0) { + throw e; + } + + // if the original load didn't work, try using the namespace as an identifier because + // the original identifier may include a snapshot selector or may point to the changelog + TableIdentifier namespaceAsIdent = buildIdentifier(namespaceToIdentifier(ident.namespace())); + org.apache.iceberg.Table table; + try { + table = icebergCatalog.loadTable(namespaceAsIdent); + } catch (Exception ignored) { + // the namespace does not identify a table, so it cannot be a table with a snapshot selector + // throw the original exception + throw e; + } + + // loading the namespace as a table worked, check the name to see if it is a valid selector + // or if the name points to the changelog + + if (ident.name().equalsIgnoreCase(SparkChangelogTable.TABLE_NAME)) { + return new SparkChangelogTable(table, !cacheEnabled); + } + + Matcher at = AT_TIMESTAMP.matcher(ident.name()); + if (at.matches()) { + long asOfTimestamp = Long.parseLong(at.group(1)); + long snapshotId = SnapshotUtil.snapshotIdAsOfTime(table, asOfTimestamp); + return new SparkTable(table, snapshotId, !cacheEnabled); + } + + Matcher id = SNAPSHOT_ID.matcher(ident.name()); + if (id.matches()) { + long snapshotId = Long.parseLong(id.group(1)); + return new SparkTable(table, snapshotId, !cacheEnabled); + } + + Matcher branch = BRANCH.matcher(ident.name()); + if (branch.matches()) { + return new SparkTable(table, branch.group(1), !cacheEnabled); + } + + Matcher tag = TAG.matcher(ident.name()); + if (tag.matches()) { + Snapshot tagSnapshot = table.snapshot(tag.group(1)); + if (tagSnapshot != null) { + return new SparkTable(table, tagSnapshot.snapshotId(), !cacheEnabled); + } + } + + // the name wasn't a valid snapshot selector and did not point to the changelog + // throw the original exception + throw e; + } + } + + private Pair> parseLocationString(String location) { + int hashIndex = location.lastIndexOf('#'); + if (hashIndex != -1 && !location.endsWith("#")) { + String baseLocation = location.substring(0, hashIndex); + List metadata = COMMA.splitToList(location.substring(hashIndex + 1)); + return Pair.of(baseLocation, metadata); + } else { + return Pair.of(location, ImmutableList.of()); + } + } + + @SuppressWarnings("CyclomaticComplexity") + private Table loadFromPathIdentifier(PathIdentifier ident) { + Pair> parsed = parseLocationString(ident.location()); + + String metadataTableName = null; + Long asOfTimestamp = null; + Long snapshotId = null; + String branch = null; + String tag = null; + boolean isChangelog = false; + + for (String meta : parsed.second()) { + if (meta.equalsIgnoreCase(SparkChangelogTable.TABLE_NAME)) { + isChangelog = true; + continue; + } + + if (MetadataTableType.from(meta) != null) { + metadataTableName = meta; + continue; + } + + Matcher at = AT_TIMESTAMP.matcher(meta); + if (at.matches()) { + asOfTimestamp = Long.parseLong(at.group(1)); + continue; + } + + Matcher id = SNAPSHOT_ID.matcher(meta); + if (id.matches()) { + snapshotId = Long.parseLong(id.group(1)); + continue; + } + + Matcher branchRef = BRANCH.matcher(meta); + if (branchRef.matches()) { + branch = 
branchRef.group(1); + continue; + } + + Matcher tagRef = TAG.matcher(meta); + if (tagRef.matches()) { + tag = tagRef.group(1); + } + } + + Preconditions.checkArgument( + Stream.of(snapshotId, asOfTimestamp, branch, tag).filter(Objects::nonNull).count() <= 1, + "Can specify only one of snapshot-id (%s), as-of-timestamp (%s), branch (%s), tag (%s)", + snapshotId, + asOfTimestamp, + branch, + tag); + + Preconditions.checkArgument( + !isChangelog || (snapshotId == null && asOfTimestamp == null), + "Cannot specify snapshot-id and as-of-timestamp for changelogs"); + + org.apache.iceberg.Table table = + tables.load(parsed.first() + (metadataTableName != null ? "#" + metadataTableName : "")); + + if (isChangelog) { + return new SparkChangelogTable(table, !cacheEnabled); + + } else if (asOfTimestamp != null) { + long snapshotIdAsOfTime = SnapshotUtil.snapshotIdAsOfTime(table, asOfTimestamp); + return new SparkTable(table, snapshotIdAsOfTime, !cacheEnabled); + + } else if (branch != null) { + return new SparkTable(table, branch, !cacheEnabled); + + } else if (tag != null) { + Snapshot tagSnapshot = table.snapshot(tag); + Preconditions.checkArgument( + tagSnapshot != null, "Cannot find snapshot associated with tag name: %s", tag); + return new SparkTable(table, tagSnapshot.snapshotId(), !cacheEnabled); + + } else { + return new SparkTable(table, snapshotId, !cacheEnabled); + } + } + + private Identifier namespaceToIdentifier(String[] namespace) { + Preconditions.checkArgument( + namespace.length > 0, "Cannot convert empty namespace to identifier"); + String[] ns = Arrays.copyOf(namespace, namespace.length - 1); + String name = namespace[ns.length]; + return Identifier.of(ns, name); + } + + private Catalog.TableBuilder newBuilder(Identifier ident, Schema schema) { + return isPathIdentifier(ident) + ? tables.buildTable(((PathIdentifier) ident).location(), schema) + : icebergCatalog.buildTable(buildIdentifier(ident), schema); + } + + @Override + public Catalog icebergCatalog() { + return icebergCatalog; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkConfParser.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkConfParser.java new file mode 100644 index 000000000000..8242e67da64b --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkConfParser.java @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.function.Function; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.RuntimeConfig; +import org.apache.spark.sql.SparkSession; + +class SparkConfParser { + + private final Map properties; + private final RuntimeConfig sessionConf; + private final Map options; + + SparkConfParser(SparkSession spark, Table table, Map options) { + this.properties = table.properties(); + this.sessionConf = spark.conf(); + this.options = options; + } + + public BooleanConfParser booleanConf() { + return new BooleanConfParser(); + } + + public IntConfParser intConf() { + return new IntConfParser(); + } + + public LongConfParser longConf() { + return new LongConfParser(); + } + + public StringConfParser stringConf() { + return new StringConfParser(); + } + + class BooleanConfParser extends ConfParser { + private Boolean defaultValue; + + @Override + protected BooleanConfParser self() { + return this; + } + + public BooleanConfParser defaultValue(boolean value) { + this.defaultValue = value; + return self(); + } + + public BooleanConfParser defaultValue(String value) { + this.defaultValue = Boolean.parseBoolean(value); + return self(); + } + + public boolean parse() { + Preconditions.checkArgument(defaultValue != null, "Default value cannot be null"); + return parse(Boolean::parseBoolean, defaultValue); + } + } + + class IntConfParser extends ConfParser { + private Integer defaultValue; + + @Override + protected IntConfParser self() { + return this; + } + + public IntConfParser defaultValue(int value) { + this.defaultValue = value; + return self(); + } + + public int parse() { + Preconditions.checkArgument(defaultValue != null, "Default value cannot be null"); + return parse(Integer::parseInt, defaultValue); + } + + public Integer parseOptional() { + return parse(Integer::parseInt, null); + } + } + + class LongConfParser extends ConfParser { + private Long defaultValue; + + @Override + protected LongConfParser self() { + return this; + } + + public LongConfParser defaultValue(long value) { + this.defaultValue = value; + return self(); + } + + public long parse() { + Preconditions.checkArgument(defaultValue != null, "Default value cannot be null"); + return parse(Long::parseLong, defaultValue); + } + + public Long parseOptional() { + return parse(Long::parseLong, null); + } + } + + class StringConfParser extends ConfParser { + private String defaultValue; + + @Override + protected StringConfParser self() { + return this; + } + + public StringConfParser defaultValue(String value) { + this.defaultValue = value; + return self(); + } + + public String parse() { + Preconditions.checkArgument(defaultValue != null, "Default value cannot be null"); + return parse(Function.identity(), defaultValue); + } + + public String parseOptional() { + return parse(Function.identity(), null); + } + } + + abstract class ConfParser { + private final List optionNames = Lists.newArrayList(); + private String sessionConfName; + private String tablePropertyName; + + protected abstract ThisT self(); + + public ThisT option(String name) { + this.optionNames.add(name); + return self(); + } + + public ThisT sessionConf(String name) { + this.sessionConfName = name; + return self(); + } + + public ThisT tableProperty(String name) { + this.tablePropertyName = 
name; + return self(); + } + + protected T parse(Function conversion, T defaultValue) { + if (!optionNames.isEmpty()) { + for (String optionName : optionNames) { + // use lower case comparison as DataSourceOptions.asMap() in Spark 2 returns a lower case + // map + String optionValue = options.get(optionName.toLowerCase(Locale.ROOT)); + if (optionValue != null) { + return conversion.apply(optionValue); + } + } + } + + if (sessionConfName != null) { + String sessionConfValue = sessionConf.get(sessionConfName, null); + if (sessionConfValue != null) { + return conversion.apply(sessionConfValue); + } + } + + if (tablePropertyName != null) { + String propertyValue = properties.get(tablePropertyName); + if (propertyValue != null) { + return conversion.apply(propertyValue); + } + } + + return defaultValue; + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkDataFile.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkDataFile.java new file mode 100644 index 000000000000..76796825894a --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkDataFile.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.StructProjection; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.StructType; + +public class SparkDataFile implements DataFile { + + private final int filePathPosition; + private final int fileFormatPosition; + private final int partitionPosition; + private final int recordCountPosition; + private final int fileSizeInBytesPosition; + private final int columnSizesPosition; + private final int valueCountsPosition; + private final int nullValueCountsPosition; + private final int nanValueCountsPosition; + private final int lowerBoundsPosition; + private final int upperBoundsPosition; + private final int keyMetadataPosition; + private final int splitOffsetsPosition; + private final int sortOrderIdPosition; + private final Type lowerBoundsType; + private final Type upperBoundsType; + private final Type keyMetadataType; + + private final SparkStructLike wrappedPartition; + private final StructLike partitionProjection; + private Row wrapped; + + public SparkDataFile(Types.StructType type, StructType sparkType) { + this(type, null, sparkType); + } + + public SparkDataFile( + Types.StructType type, Types.StructType projectedType, StructType sparkType) { + this.lowerBoundsType = type.fieldType("lower_bounds"); + this.upperBoundsType = type.fieldType("upper_bounds"); + this.keyMetadataType = type.fieldType("key_metadata"); + + Types.StructType partitionType = type.fieldType("partition").asStructType(); + this.wrappedPartition = new SparkStructLike(partitionType); + + if (projectedType != null) { + Types.StructType projectedPartitionType = projectedType.fieldType("partition").asStructType(); + this.partitionProjection = + StructProjection.create(partitionType, projectedPartitionType).wrap(wrappedPartition); + } else { + this.partitionProjection = wrappedPartition; + } + + Map positions = Maps.newHashMap(); + type.fields() + .forEach( + field -> { + String fieldName = field.name(); + positions.put(fieldName, fieldPosition(fieldName, sparkType)); + }); + + filePathPosition = positions.get("file_path"); + fileFormatPosition = positions.get("file_format"); + partitionPosition = positions.get("partition"); + recordCountPosition = positions.get("record_count"); + fileSizeInBytesPosition = positions.get("file_size_in_bytes"); + columnSizesPosition = positions.get("column_sizes"); + valueCountsPosition = positions.get("value_counts"); + nullValueCountsPosition = positions.get("null_value_counts"); + nanValueCountsPosition = positions.get("nan_value_counts"); + lowerBoundsPosition = positions.get("lower_bounds"); + upperBoundsPosition = positions.get("upper_bounds"); + keyMetadataPosition = positions.get("key_metadata"); + splitOffsetsPosition = positions.get("split_offsets"); + sortOrderIdPosition = positions.get("sort_order_id"); + } + + public SparkDataFile wrap(Row row) { + this.wrapped = row; + if (wrappedPartition.size() > 0) { + this.wrappedPartition.wrap(row.getAs(partitionPosition)); + } + return this; + } + + @Override + public Long pos() { + return null; + } + + @Override + public int specId() { + return -1; + } + + @Override + public CharSequence path() { + return 
wrapped.getAs(filePathPosition); + } + + @Override + public FileFormat format() { + return FileFormat.fromString(wrapped.getString(fileFormatPosition)); + } + + @Override + public StructLike partition() { + return partitionProjection; + } + + @Override + public long recordCount() { + return wrapped.getAs(recordCountPosition); + } + + @Override + public long fileSizeInBytes() { + return wrapped.getAs(fileSizeInBytesPosition); + } + + @Override + public Map columnSizes() { + return wrapped.isNullAt(columnSizesPosition) ? null : wrapped.getJavaMap(columnSizesPosition); + } + + @Override + public Map valueCounts() { + return wrapped.isNullAt(valueCountsPosition) ? null : wrapped.getJavaMap(valueCountsPosition); + } + + @Override + public Map nullValueCounts() { + return wrapped.isNullAt(nullValueCountsPosition) + ? null + : wrapped.getJavaMap(nullValueCountsPosition); + } + + @Override + public Map nanValueCounts() { + return wrapped.isNullAt(nanValueCountsPosition) + ? null + : wrapped.getJavaMap(nanValueCountsPosition); + } + + @Override + public Map lowerBounds() { + Map lowerBounds = + wrapped.isNullAt(lowerBoundsPosition) ? null : wrapped.getJavaMap(lowerBoundsPosition); + return convert(lowerBoundsType, lowerBounds); + } + + @Override + public Map upperBounds() { + Map upperBounds = + wrapped.isNullAt(upperBoundsPosition) ? null : wrapped.getJavaMap(upperBoundsPosition); + return convert(upperBoundsType, upperBounds); + } + + @Override + public ByteBuffer keyMetadata() { + return convert(keyMetadataType, wrapped.get(keyMetadataPosition)); + } + + @Override + public DataFile copy() { + throw new UnsupportedOperationException("Not implemented: copy"); + } + + @Override + public DataFile copyWithoutStats() { + throw new UnsupportedOperationException("Not implemented: copyWithoutStats"); + } + + @Override + public List splitOffsets() { + return wrapped.isNullAt(splitOffsetsPosition) ? null : wrapped.getList(splitOffsetsPosition); + } + + @Override + public Integer sortOrderId() { + return wrapped.getAs(sortOrderIdPosition); + } + + private int fieldPosition(String name, StructType sparkType) { + try { + return sparkType.fieldIndex(name); + } catch (IllegalArgumentException e) { + // the partition field is absent for unpartitioned tables + if (name.equals("partition") && wrappedPartition.size() == 0) { + return -1; + } + throw e; + } + } + + @SuppressWarnings("unchecked") + private T convert(Type valueType, Object value) { + return (T) SparkValueConverter.convert(valueType, value); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkDistributionAndOrderingUtil.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkDistributionAndOrderingUtil.java new file mode 100644 index 000000000000..f2c8f6e26ca4 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkDistributionAndOrderingUtil.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE; + +import java.util.List; +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ObjectArrays; +import org.apache.iceberg.transforms.SortOrderVisitor; +import org.apache.iceberg.util.SortOrderUtil; +import org.apache.spark.sql.connector.distributions.ClusteredDistribution; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.distributions.Distributions; +import org.apache.spark.sql.connector.distributions.OrderedDistribution; +import org.apache.spark.sql.connector.distributions.UnspecifiedDistribution; +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.SortDirection; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.connector.write.RowLevelOperation.Command; + +public class SparkDistributionAndOrderingUtil { + + private static final NamedReference SPEC_ID = Expressions.column(MetadataColumns.SPEC_ID.name()); + private static final NamedReference PARTITION = + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME); + private static final NamedReference FILE_PATH = + Expressions.column(MetadataColumns.FILE_PATH.name()); + private static final NamedReference ROW_POSITION = + Expressions.column(MetadataColumns.ROW_POSITION.name()); + + private static final SortOrder SPEC_ID_ORDER = Expressions.sort(SPEC_ID, SortDirection.ASCENDING); + private static final SortOrder PARTITION_ORDER = + Expressions.sort(PARTITION, SortDirection.ASCENDING); + private static final SortOrder FILE_PATH_ORDER = + Expressions.sort(FILE_PATH, SortDirection.ASCENDING); + private static final SortOrder ROW_POSITION_ORDER = + Expressions.sort(ROW_POSITION, SortDirection.ASCENDING); + + private static final SortOrder[] EXISTING_FILE_ORDERING = + new SortOrder[] {FILE_PATH_ORDER, ROW_POSITION_ORDER}; + private static final SortOrder[] POSITION_DELETE_ORDERING = + new SortOrder[] {SPEC_ID_ORDER, PARTITION_ORDER, FILE_PATH_ORDER, ROW_POSITION_ORDER}; + + private SparkDistributionAndOrderingUtil() {} + + public static Distribution buildRequiredDistribution( + Table table, DistributionMode distributionMode) { + switch (distributionMode) { + case NONE: + return Distributions.unspecified(); + + case HASH: + return Distributions.clustered(Spark3Util.toTransforms(table.spec())); + + case RANGE: + return Distributions.ordered(buildTableOrdering(table)); + + default: + throw new IllegalArgumentException("Unsupported distribution mode: " + distributionMode); + } + } + + public static SortOrder[] buildRequiredOrdering(Table table, Distribution distribution) { + if (distribution instanceof 
OrderedDistribution) { + OrderedDistribution orderedDistribution = (OrderedDistribution) distribution; + return orderedDistribution.ordering(); + } else { + return buildTableOrdering(table); + } + } + + public static Distribution buildCopyOnWriteDistribution( + Table table, Command command, DistributionMode distributionMode) { + if (command == DELETE || command == UPDATE) { + return buildCopyOnWriteDeleteUpdateDistribution(table, distributionMode); + } else { + return buildRequiredDistribution(table, distributionMode); + } + } + + private static Distribution buildCopyOnWriteDeleteUpdateDistribution( + Table table, DistributionMode distributionMode) { + switch (distributionMode) { + case NONE: + return Distributions.unspecified(); + + case HASH: + Expression[] clustering = new Expression[] {FILE_PATH}; + return Distributions.clustered(clustering); + + case RANGE: + SortOrder[] tableOrdering = buildTableOrdering(table); + if (table.sortOrder().isSorted()) { + return Distributions.ordered(tableOrdering); + } else { + SortOrder[] ordering = + ObjectArrays.concat(tableOrdering, EXISTING_FILE_ORDERING, SortOrder.class); + return Distributions.ordered(ordering); + } + + default: + throw new IllegalArgumentException("Unexpected distribution mode: " + distributionMode); + } + } + + public static SortOrder[] buildCopyOnWriteOrdering( + Table table, Command command, Distribution distribution) { + if (command == DELETE || command == UPDATE) { + return buildCopyOnWriteDeleteUpdateOrdering(table, distribution); + } else { + return buildRequiredOrdering(table, distribution); + } + } + + private static SortOrder[] buildCopyOnWriteDeleteUpdateOrdering( + Table table, Distribution distribution) { + if (distribution instanceof UnspecifiedDistribution) { + return buildTableOrdering(table); + + } else if (distribution instanceof ClusteredDistribution) { + SortOrder[] tableOrdering = buildTableOrdering(table); + if (table.sortOrder().isSorted()) { + return tableOrdering; + } else { + return ObjectArrays.concat(tableOrdering, EXISTING_FILE_ORDERING, SortOrder.class); + } + + } else if (distribution instanceof OrderedDistribution) { + OrderedDistribution orderedDistribution = (OrderedDistribution) distribution; + return orderedDistribution.ordering(); + + } else { + throw new IllegalArgumentException( + "Unexpected distribution type: " + distribution.getClass().getName()); + } + } + + public static Distribution buildPositionDeltaDistribution( + Table table, Command command, DistributionMode distributionMode) { + if (command == DELETE || command == UPDATE) { + return buildPositionDeleteUpdateDistribution(distributionMode); + } else { + return buildPositionMergeDistribution(table, distributionMode); + } + } + + private static Distribution buildPositionMergeDistribution( + Table table, DistributionMode distributionMode) { + switch (distributionMode) { + case NONE: + return Distributions.unspecified(); + + case HASH: + if (table.spec().isUnpartitioned()) { + Expression[] clustering = new Expression[] {SPEC_ID, PARTITION, FILE_PATH}; + return Distributions.clustered(clustering); + } else { + Distribution dataDistribution = buildRequiredDistribution(table, distributionMode); + Expression[] dataClustering = ((ClusteredDistribution) dataDistribution).clustering(); + Expression[] deleteClustering = new Expression[] {SPEC_ID, PARTITION}; + Expression[] clustering = + ObjectArrays.concat(deleteClustering, dataClustering, Expression.class); + return Distributions.clustered(clustering); + } + + case RANGE: + Distribution 
dataDistribution = buildRequiredDistribution(table, distributionMode); + SortOrder[] dataOrdering = ((OrderedDistribution) dataDistribution).ordering(); + SortOrder[] deleteOrdering = + new SortOrder[] {SPEC_ID_ORDER, PARTITION_ORDER, FILE_PATH_ORDER}; + SortOrder[] ordering = ObjectArrays.concat(deleteOrdering, dataOrdering, SortOrder.class); + return Distributions.ordered(ordering); + + default: + throw new IllegalArgumentException("Unexpected distribution mode: " + distributionMode); + } + } + + private static Distribution buildPositionDeleteUpdateDistribution( + DistributionMode distributionMode) { + switch (distributionMode) { + case NONE: + return Distributions.unspecified(); + + case HASH: + Expression[] clustering = new Expression[] {SPEC_ID, PARTITION}; + return Distributions.clustered(clustering); + + case RANGE: + SortOrder[] ordering = new SortOrder[] {SPEC_ID_ORDER, PARTITION_ORDER, FILE_PATH_ORDER}; + return Distributions.ordered(ordering); + + default: + throw new IllegalArgumentException("Unsupported distribution mode: " + distributionMode); + } + } + + public static SortOrder[] buildPositionDeltaOrdering(Table table, Command command) { + if (command == DELETE || command == UPDATE) { + return POSITION_DELETE_ORDERING; + } else { + // all metadata columns like _spec_id, _file, _pos will be null for new data records + SortOrder[] dataOrdering = buildTableOrdering(table); + return ObjectArrays.concat(POSITION_DELETE_ORDERING, dataOrdering, SortOrder.class); + } + } + + public static SortOrder[] convert(org.apache.iceberg.SortOrder sortOrder) { + List converted = + SortOrderVisitor.visit(sortOrder, new SortOrderToSpark(sortOrder.schema())); + return converted.toArray(new SortOrder[0]); + } + + private static SortOrder[] buildTableOrdering(Table table) { + return convert(SortOrderUtil.buildSortOrder(table)); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkExceptionUtil.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkExceptionUtil.java new file mode 100644 index 000000000000..5c6fe3e0ff96 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkExceptionUtil.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import com.google.errorprone.annotations.FormatMethod; +import java.io.IOException; +import org.apache.iceberg.exceptions.NoSuchNamespaceException; +import org.apache.iceberg.exceptions.NoSuchTableException; +import org.apache.iceberg.exceptions.RuntimeIOException; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.spark.sql.AnalysisException; + +public class SparkExceptionUtil { + + private SparkExceptionUtil() {} + + /** + * Converts checked exceptions to unchecked exceptions. + * + * @param cause a checked exception object which is to be converted to its unchecked equivalent. + * @param message exception message as a format string + * @param args format specifiers + * @return unchecked exception. + */ + @FormatMethod + public static RuntimeException toUncheckedException( + final Throwable cause, final String message, final Object... args) { + // Parameters are required to be final to help @FormatMethod do static analysis + if (cause instanceof RuntimeException) { + return (RuntimeException) cause; + + } else if (cause instanceof org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException) { + return new NoSuchNamespaceException(cause, message, args); + + } else if (cause instanceof org.apache.spark.sql.catalyst.analysis.NoSuchTableException) { + return new NoSuchTableException(cause, message, args); + + } else if (cause instanceof AnalysisException) { + return new ValidationException(cause, message, args); + + } else if (cause instanceof IOException) { + return new RuntimeIOException((IOException) cause, message, args); + + } else { + return new RuntimeException(String.format(message, args), cause); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkFilters.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkFilters.java new file mode 100644 index 000000000000..f70730f9cc13 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkFilters.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import static org.apache.iceberg.expressions.Expressions.and; +import static org.apache.iceberg.expressions.Expressions.equal; +import static org.apache.iceberg.expressions.Expressions.greaterThan; +import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual; +import static org.apache.iceberg.expressions.Expressions.in; +import static org.apache.iceberg.expressions.Expressions.isNaN; +import static org.apache.iceberg.expressions.Expressions.isNull; +import static org.apache.iceberg.expressions.Expressions.lessThan; +import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual; +import static org.apache.iceberg.expressions.Expressions.not; +import static org.apache.iceberg.expressions.Expressions.notIn; +import static org.apache.iceberg.expressions.Expressions.notNull; +import static org.apache.iceberg.expressions.Expressions.or; +import static org.apache.iceberg.expressions.Expressions.startsWith; + +import java.sql.Date; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.util.Map; +import java.util.Objects; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expression.Operation; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.util.NaNUtil; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.sources.AlwaysFalse; +import org.apache.spark.sql.sources.AlwaysFalse$; +import org.apache.spark.sql.sources.AlwaysTrue; +import org.apache.spark.sql.sources.AlwaysTrue$; +import org.apache.spark.sql.sources.And; +import org.apache.spark.sql.sources.EqualNullSafe; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.GreaterThanOrEqual; +import org.apache.spark.sql.sources.In; +import org.apache.spark.sql.sources.IsNotNull; +import org.apache.spark.sql.sources.IsNull; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.LessThanOrEqual; +import org.apache.spark.sql.sources.Not; +import org.apache.spark.sql.sources.Or; +import org.apache.spark.sql.sources.StringStartsWith; + +public class SparkFilters { + + private static final Pattern BACKTICKS_PATTERN = Pattern.compile("([`])(.|$)"); + + private SparkFilters() {} + + private static final Map, Operation> FILTERS = + ImmutableMap., Operation>builder() + .put(AlwaysTrue.class, Operation.TRUE) + .put(AlwaysTrue$.class, Operation.TRUE) + .put(AlwaysFalse$.class, Operation.FALSE) + .put(AlwaysFalse.class, Operation.FALSE) + .put(EqualTo.class, Operation.EQ) + .put(EqualNullSafe.class, Operation.EQ) + .put(GreaterThan.class, Operation.GT) + .put(GreaterThanOrEqual.class, Operation.GT_EQ) + .put(LessThan.class, Operation.LT) + .put(LessThanOrEqual.class, Operation.LT_EQ) + .put(In.class, Operation.IN) + .put(IsNull.class, Operation.IS_NULL) + .put(IsNotNull.class, Operation.NOT_NULL) + .put(And.class, Operation.AND) + .put(Or.class, Operation.OR) + .put(Not.class, Operation.NOT) + .put(StringStartsWith.class, Operation.STARTS_WITH) + .buildOrThrow(); + + public static Expression convert(Filter[] filters) { 
+ Expression expression = Expressions.alwaysTrue(); + for (Filter filter : filters) { + Expression converted = convert(filter); + Preconditions.checkArgument( + converted != null, "Cannot convert filter to Iceberg: %s", filter); + expression = Expressions.and(expression, converted); + } + return expression; + } + + public static Expression convert(Filter filter) { + // avoid using a chain of if instanceof statements by mapping to the expression enum. + Operation op = FILTERS.get(filter.getClass()); + if (op != null) { + switch (op) { + case TRUE: + return Expressions.alwaysTrue(); + + case FALSE: + return Expressions.alwaysFalse(); + + case IS_NULL: + IsNull isNullFilter = (IsNull) filter; + return isNull(unquote(isNullFilter.attribute())); + + case NOT_NULL: + IsNotNull notNullFilter = (IsNotNull) filter; + return notNull(unquote(notNullFilter.attribute())); + + case LT: + LessThan lt = (LessThan) filter; + return lessThan(unquote(lt.attribute()), convertLiteral(lt.value())); + + case LT_EQ: + LessThanOrEqual ltEq = (LessThanOrEqual) filter; + return lessThanOrEqual(unquote(ltEq.attribute()), convertLiteral(ltEq.value())); + + case GT: + GreaterThan gt = (GreaterThan) filter; + return greaterThan(unquote(gt.attribute()), convertLiteral(gt.value())); + + case GT_EQ: + GreaterThanOrEqual gtEq = (GreaterThanOrEqual) filter; + return greaterThanOrEqual(unquote(gtEq.attribute()), convertLiteral(gtEq.value())); + + case EQ: // used for both eq and null-safe-eq + if (filter instanceof EqualTo) { + EqualTo eq = (EqualTo) filter; + // comparison with null in normal equality is always null. this is probably a mistake. + Preconditions.checkNotNull( + eq.value(), "Expression is always false (eq is not null-safe): %s", filter); + return handleEqual(unquote(eq.attribute()), eq.value()); + } else { + EqualNullSafe eq = (EqualNullSafe) filter; + if (eq.value() == null) { + return isNull(unquote(eq.attribute())); + } else { + return handleEqual(unquote(eq.attribute()), eq.value()); + } + } + + case IN: + In inFilter = (In) filter; + return in( + unquote(inFilter.attribute()), + Stream.of(inFilter.values()) + .filter(Objects::nonNull) + .map(SparkFilters::convertLiteral) + .collect(Collectors.toList())); + + case NOT: + Not notFilter = (Not) filter; + Filter childFilter = notFilter.child(); + Operation childOp = FILTERS.get(childFilter.getClass()); + if (childOp == Operation.IN) { + // infer an extra notNull predicate for Spark NOT IN filters + // as Iceberg expressions don't follow the 3-value SQL boolean logic + // col NOT IN (1, 2) in Spark is equivalent to notNull(col) && notIn(col, 1, 2) in + // Iceberg + In childInFilter = (In) childFilter; + Expression notIn = + notIn( + unquote(childInFilter.attribute()), + Stream.of(childInFilter.values()) + .map(SparkFilters::convertLiteral) + .collect(Collectors.toList())); + return and(notNull(childInFilter.attribute()), notIn); + } else if (hasNoInFilter(childFilter)) { + Expression child = convert(childFilter); + if (child != null) { + return not(child); + } + } + return null; + + case AND: + { + And andFilter = (And) filter; + Expression left = convert(andFilter.left()); + Expression right = convert(andFilter.right()); + if (left != null && right != null) { + return and(left, right); + } + return null; + } + + case OR: + { + Or orFilter = (Or) filter; + Expression left = convert(orFilter.left()); + Expression right = convert(orFilter.right()); + if (left != null && right != null) { + return or(left, right); + } + return null; + } + + case STARTS_WITH: + { + 
StringStartsWith stringStartsWith = (StringStartsWith) filter; + return startsWith(unquote(stringStartsWith.attribute()), stringStartsWith.value()); + } + } + } + + return null; + } + + private static Object convertLiteral(Object value) { + if (value instanceof Timestamp) { + return DateTimeUtils.fromJavaTimestamp((Timestamp) value); + } else if (value instanceof Date) { + return DateTimeUtils.fromJavaDate((Date) value); + } else if (value instanceof Instant) { + return DateTimeUtils.instantToMicros((Instant) value); + } else if (value instanceof LocalDate) { + return DateTimeUtils.localDateToDays((LocalDate) value); + } + return value; + } + + private static Expression handleEqual(String attribute, Object value) { + if (NaNUtil.isNaN(value)) { + return isNaN(attribute); + } else { + return equal(attribute, convertLiteral(value)); + } + } + + private static String unquote(String attributeName) { + Matcher matcher = BACKTICKS_PATTERN.matcher(attributeName); + return matcher.replaceAll("$2"); + } + + private static boolean hasNoInFilter(Filter filter) { + Operation op = FILTERS.get(filter.getClass()); + + if (op != null) { + switch (op) { + case AND: + And andFilter = (And) filter; + return hasNoInFilter(andFilter.left()) && hasNoInFilter(andFilter.right()); + case OR: + Or orFilter = (Or) filter; + return hasNoInFilter(orFilter.left()) && hasNoInFilter(orFilter.right()); + case NOT: + Not notFilter = (Not) filter; + return hasNoInFilter(notFilter.child()); + case IN: + return false; + default: + return true; + } + } + + return false; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkFixupTimestampType.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkFixupTimestampType.java new file mode 100644 index 000000000000..b35213501aef --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkFixupTimestampType.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import org.apache.iceberg.Schema; +import org.apache.iceberg.types.FixupTypes; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; + +/** + * By default Spark type {@link org.apache.iceberg.types.Types.TimestampType} should be converted to + * {@link Types.TimestampType#withZone()} iceberg type. 
But we also can convert {@link + * org.apache.iceberg.types.Types.TimestampType} to {@link Types.TimestampType#withoutZone()} + * iceberg type by setting {@link SparkSQLProperties#USE_TIMESTAMP_WITHOUT_TIME_ZONE_IN_NEW_TABLES} + * to 'true' + */ +class SparkFixupTimestampType extends FixupTypes { + + private SparkFixupTimestampType(Schema referenceSchema) { + super(referenceSchema); + } + + static Schema fixup(Schema schema) { + return new Schema( + TypeUtil.visit(schema, new SparkFixupTimestampType(schema)).asStructType().fields()); + } + + @Override + public Type primitive(Type.PrimitiveType primitive) { + if (primitive.typeId() == Type.TypeID.TIMESTAMP) { + return Types.TimestampType.withoutZone(); + } + return primitive; + } + + @Override + protected boolean fixupPrimitive(Type.PrimitiveType type, Type source) { + return Type.TypeID.TIMESTAMP.equals(type.typeId()); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkFixupTypes.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkFixupTypes.java new file mode 100644 index 000000000000..6c4ec39b20f1 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkFixupTypes.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import org.apache.iceberg.Schema; +import org.apache.iceberg.types.FixupTypes; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; + +/** + * Some types, like binary and fixed, are converted to the same Spark type. Conversion back can + * produce only one, which may not be correct. 
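(Editorial aside, not part of this diff: a minimal sketch of how this fixup is exercised through
SparkSchemaUtil, which this change also adds. The table variable and its uuid column are
assumptions for illustration.)

// Spark has no uuid or fixed-length binary type, so converting an Iceberg schema to a Spark
// StructType and back is lossy unless a reference schema guides the reverse mapping.
Schema tableSchema = table.schema();                          // assume a column of type uuid
StructType sparkType = SparkSchemaUtil.convert(tableSchema);  // uuid surfaces as StringType
Schema restored = SparkSchemaUtil.convert(tableSchema, sparkType); // fixup restores uuid/fixed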
+ */ +class SparkFixupTypes extends FixupTypes { + + private SparkFixupTypes(Schema referenceSchema) { + super(referenceSchema); + } + + static Schema fixup(Schema schema, Schema referenceSchema) { + return new Schema( + TypeUtil.visit(schema, new SparkFixupTypes(referenceSchema)).asStructType().fields()); + } + + @Override + protected boolean fixupPrimitive(Type.PrimitiveType type, Type source) { + switch (type.typeId()) { + case STRING: + if (source.typeId() == Type.TypeID.UUID) { + return true; + } + break; + case BINARY: + if (source.typeId() == Type.TypeID.FIXED) { + return true; + } + break; + case TIMESTAMP: + if (source.typeId() == Type.TypeID.TIMESTAMP) { + return true; + } + break; + default: + } + return false; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java new file mode 100644 index 000000000000..1c1182c4da60 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Map; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.Util; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.spark.sql.SparkSession; + +/** + * A class for common Iceberg configs for Spark reads. + * + *
+ * <p>If a config is set at multiple levels, the following order of precedence is used (top to
+ * bottom):
+ *
+ * <ol>
+ *   <li>Read options
+ *   <li>Session configuration
+ *   <li>Table metadata
+ * </ol>
+ *
+ * <p>The most specific value is set in read options and takes precedence over all other configs.
+ * If no read option is provided, this class checks the session configuration for any overrides.
+ * If no applicable value is found in the session configuration, this class uses the table
+ * metadata.
+ *
+ * <p>
Note this class is NOT meant to be serialized and sent to executors. + */ +public class SparkReadConf { + + private final SparkSession spark; + private final Table table; + private final String branch; + private final Map readOptions; + private final SparkConfParser confParser; + + public SparkReadConf(SparkSession spark, Table table, Map readOptions) { + this(spark, table, null, readOptions); + } + + public SparkReadConf( + SparkSession spark, Table table, String branch, Map readOptions) { + this.spark = spark; + this.table = table; + this.branch = branch; + this.readOptions = readOptions; + this.confParser = new SparkConfParser(spark, table, readOptions); + } + + public boolean caseSensitive() { + return SparkUtil.caseSensitive(spark); + } + + public boolean localityEnabled() { + boolean defaultValue = Util.mayHaveBlockLocations(table.io(), table.location()); + return PropertyUtil.propertyAsBoolean(readOptions, SparkReadOptions.LOCALITY, defaultValue); + } + + public Long snapshotId() { + return confParser.longConf().option(SparkReadOptions.SNAPSHOT_ID).parseOptional(); + } + + public Long asOfTimestamp() { + return confParser.longConf().option(SparkReadOptions.AS_OF_TIMESTAMP).parseOptional(); + } + + public Long startSnapshotId() { + return confParser.longConf().option(SparkReadOptions.START_SNAPSHOT_ID).parseOptional(); + } + + public Long endSnapshotId() { + return confParser.longConf().option(SparkReadOptions.END_SNAPSHOT_ID).parseOptional(); + } + + public String branch() { + String optionBranch = confParser.stringConf().option(SparkReadOptions.BRANCH).parseOptional(); + ValidationException.check( + branch == null || optionBranch == null || optionBranch.equals(branch), + "Must not specify different branches in both table identifier and read option, " + + "got [%s] in identifier and [%s] in options", + branch, + optionBranch); + String inputBranch = branch != null ? 
branch : optionBranch; + if (inputBranch != null) { + return inputBranch; + } + + boolean wapEnabled = + PropertyUtil.propertyAsBoolean( + table.properties(), TableProperties.WRITE_AUDIT_PUBLISH_ENABLED, false); + if (wapEnabled) { + String wapBranch = spark.conf().get(SparkSQLProperties.WAP_BRANCH, null); + if (wapBranch != null && table.refs().containsKey(wapBranch)) { + return wapBranch; + } + } + + return null; + } + + public String tag() { + return confParser.stringConf().option(SparkReadOptions.TAG).parseOptional(); + } + + public String scanTaskSetId() { + return confParser.stringConf().option(SparkReadOptions.SCAN_TASK_SET_ID).parseOptional(); + } + + public boolean streamingSkipDeleteSnapshots() { + return confParser + .booleanConf() + .option(SparkReadOptions.STREAMING_SKIP_DELETE_SNAPSHOTS) + .defaultValue(SparkReadOptions.STREAMING_SKIP_DELETE_SNAPSHOTS_DEFAULT) + .parse(); + } + + public boolean streamingSkipOverwriteSnapshots() { + return confParser + .booleanConf() + .option(SparkReadOptions.STREAMING_SKIP_OVERWRITE_SNAPSHOTS) + .defaultValue(SparkReadOptions.STREAMING_SKIP_OVERWRITE_SNAPSHOTS_DEFAULT) + .parse(); + } + + public boolean parquetVectorizationEnabled() { + return confParser + .booleanConf() + .option(SparkReadOptions.VECTORIZATION_ENABLED) + .sessionConf(SparkSQLProperties.VECTORIZATION_ENABLED) + .tableProperty(TableProperties.PARQUET_VECTORIZATION_ENABLED) + .defaultValue(TableProperties.PARQUET_VECTORIZATION_ENABLED_DEFAULT) + .parse(); + } + + public int parquetBatchSize() { + return confParser + .intConf() + .option(SparkReadOptions.VECTORIZATION_BATCH_SIZE) + .tableProperty(TableProperties.PARQUET_BATCH_SIZE) + .defaultValue(TableProperties.PARQUET_BATCH_SIZE_DEFAULT) + .parse(); + } + + public boolean orcVectorizationEnabled() { + return confParser + .booleanConf() + .option(SparkReadOptions.VECTORIZATION_ENABLED) + .sessionConf(SparkSQLProperties.VECTORIZATION_ENABLED) + .tableProperty(TableProperties.ORC_VECTORIZATION_ENABLED) + .defaultValue(TableProperties.ORC_VECTORIZATION_ENABLED_DEFAULT) + .parse(); + } + + public int orcBatchSize() { + return confParser + .intConf() + .option(SparkReadOptions.VECTORIZATION_BATCH_SIZE) + .tableProperty(TableProperties.ORC_BATCH_SIZE) + .defaultValue(TableProperties.ORC_BATCH_SIZE_DEFAULT) + .parse(); + } + + public Long splitSizeOption() { + return confParser.longConf().option(SparkReadOptions.SPLIT_SIZE).parseOptional(); + } + + public long splitSize() { + return confParser + .longConf() + .option(SparkReadOptions.SPLIT_SIZE) + .tableProperty(TableProperties.SPLIT_SIZE) + .defaultValue(TableProperties.SPLIT_SIZE_DEFAULT) + .parse(); + } + + public Integer splitLookbackOption() { + return confParser.intConf().option(SparkReadOptions.LOOKBACK).parseOptional(); + } + + public int splitLookback() { + return confParser + .intConf() + .option(SparkReadOptions.LOOKBACK) + .tableProperty(TableProperties.SPLIT_LOOKBACK) + .defaultValue(TableProperties.SPLIT_LOOKBACK_DEFAULT) + .parse(); + } + + public Long splitOpenFileCostOption() { + return confParser.longConf().option(SparkReadOptions.FILE_OPEN_COST).parseOptional(); + } + + public long splitOpenFileCost() { + return confParser + .longConf() + .option(SparkReadOptions.FILE_OPEN_COST) + .tableProperty(TableProperties.SPLIT_OPEN_FILE_COST) + .defaultValue(TableProperties.SPLIT_OPEN_FILE_COST_DEFAULT) + .parse(); + } + + /** + * Enables reading a timestamp without time zone as a timestamp with time zone. + * + *
+ * <p>Generally, this is not safe as a timestamp without time zone is supposed to represent the
+ * wall-clock time, i.e. no matter the reader/writer timezone 3PM should always be read as 3PM,
+ * but a timestamp with time zone represents instant semantics, i.e. the timestamp is adjusted so
+ * that the corresponding time in the reader timezone is displayed.
+ *
+ * <p>
When set to false (default), an exception must be thrown while reading a timestamp without + * time zone. + * + * @return boolean indicating if reading timestamps without timezone is allowed + */ + public boolean handleTimestampWithoutZone() { + return confParser + .booleanConf() + .option(SparkReadOptions.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE) + .sessionConf(SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE) + .defaultValue(SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE_DEFAULT) + .parse(); + } + + public Long streamFromTimestamp() { + return confParser + .longConf() + .option(SparkReadOptions.STREAM_FROM_TIMESTAMP) + .defaultValue(Long.MIN_VALUE) + .parse(); + } + + public Long startTimestamp() { + return confParser.longConf().option(SparkReadOptions.START_TIMESTAMP).parseOptional(); + } + + public Long endTimestamp() { + return confParser.longConf().option(SparkReadOptions.END_TIMESTAMP).parseOptional(); + } + + public boolean preserveDataGrouping() { + return confParser + .booleanConf() + .sessionConf(SparkSQLProperties.PRESERVE_DATA_GROUPING) + .defaultValue(SparkSQLProperties.PRESERVE_DATA_GROUPING_DEFAULT) + .parse(); + } + + public boolean aggregatePushDownEnabled() { + return confParser + .booleanConf() + .option(SparkReadOptions.AGGREGATE_PUSH_DOWN_ENABLED) + .sessionConf(SparkSQLProperties.AGGREGATE_PUSH_DOWN_ENABLED) + .defaultValue(SparkSQLProperties.AGGREGATE_PUSH_DOWN_ENABLED_DEFAULT) + .parse(); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkReadOptions.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkReadOptions.java new file mode 100644 index 000000000000..9063e0f9aba6 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkReadOptions.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +/** Spark DF read options */ +public class SparkReadOptions { + + private SparkReadOptions() {} + + // Snapshot ID of the table snapshot to read + public static final String SNAPSHOT_ID = "snapshot-id"; + + // Start snapshot ID used in incremental scans (exclusive) + public static final String START_SNAPSHOT_ID = "start-snapshot-id"; + + // End snapshot ID used in incremental scans (inclusive) + public static final String END_SNAPSHOT_ID = "end-snapshot-id"; + + // Start timestamp used in multi-snapshot scans (exclusive) + public static final String START_TIMESTAMP = "start-timestamp"; + + // End timestamp used in multi-snapshot scans (inclusive) + public static final String END_TIMESTAMP = "end-timestamp"; + + // A timestamp in milliseconds; the snapshot used will be the snapshot current at this time. 
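(Editorial aside, not part of this diff: a sketch of how these time-travel options are typically
passed through the DataFrame reader; the table name below is illustrative.)

// Read the snapshot that was current at a given timestamp (milliseconds since epoch).
Dataset<Row> df = spark.read()
    .format("iceberg")
    .option(SparkReadOptions.AS_OF_TIMESTAMP, String.valueOf(System.currentTimeMillis()))
    .load("db.events");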
+ public static final String AS_OF_TIMESTAMP = "as-of-timestamp"; + + // Branch to read from + public static final String BRANCH = "branch"; + + // Tag to read from + public static final String TAG = "tag"; + + // Overrides the table's read.split.target-size and read.split.metadata-target-size + public static final String SPLIT_SIZE = "split-size"; + + // Overrides the table's read.split.planning-lookback + public static final String LOOKBACK = "lookback"; + + // Overrides the table's read.split.open-file-cost + public static final String FILE_OPEN_COST = "file-open-cost"; + + // Overrides the table's read.split.open-file-cost + public static final String VECTORIZATION_ENABLED = "vectorization-enabled"; + + // Overrides the table's read.parquet.vectorization.batch-size + public static final String VECTORIZATION_BATCH_SIZE = "batch-size"; + + // Set ID that is used to fetch scan tasks + public static final String SCAN_TASK_SET_ID = "scan-task-set-id"; + + // skip snapshots of type delete while reading stream out of iceberg table + public static final String STREAMING_SKIP_DELETE_SNAPSHOTS = "streaming-skip-delete-snapshots"; + public static final boolean STREAMING_SKIP_DELETE_SNAPSHOTS_DEFAULT = false; + + // skip snapshots of type overwrite while reading stream out of iceberg table + public static final String STREAMING_SKIP_OVERWRITE_SNAPSHOTS = + "streaming-skip-overwrite-snapshots"; + public static final boolean STREAMING_SKIP_OVERWRITE_SNAPSHOTS_DEFAULT = false; + + // Controls whether to allow reading timestamps without zone info + public static final String HANDLE_TIMESTAMP_WITHOUT_TIMEZONE = + "handle-timestamp-without-timezone"; + + // Controls whether to report locality information to Spark while allocating input partitions + public static final String LOCALITY = "locality"; + + // Timestamp in milliseconds; start a stream from the snapshot that occurs after this timestamp + public static final String STREAM_FROM_TIMESTAMP = "stream-from-timestamp"; + + // Table path + public static final String PATH = "path"; + + public static final String VERSION_AS_OF = "versionAsOf"; + + public static final String TIMESTAMP_AS_OF = "timestampAsOf"; + + public static final String AGGREGATE_PUSH_DOWN_ENABLED = "aggregate-push-down-enabled"; +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java new file mode 100644 index 000000000000..d7ff4311c907 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +public class SparkSQLProperties { + + private SparkSQLProperties() {} + + // Controls whether vectorized reads are enabled + public static final String VECTORIZATION_ENABLED = "spark.sql.iceberg.vectorization.enabled"; + + // Controls whether reading/writing timestamps without timezones is allowed + public static final String HANDLE_TIMESTAMP_WITHOUT_TIMEZONE = + "spark.sql.iceberg.handle-timestamp-without-timezone"; + public static final boolean HANDLE_TIMESTAMP_WITHOUT_TIMEZONE_DEFAULT = false; + + // Controls whether timestamp types for new tables should be stored with timezone info + public static final String USE_TIMESTAMP_WITHOUT_TIME_ZONE_IN_NEW_TABLES = + "spark.sql.iceberg.use-timestamp-without-timezone-in-new-tables"; + public static final boolean USE_TIMESTAMP_WITHOUT_TIME_ZONE_IN_NEW_TABLES_DEFAULT = false; + + // Controls whether to perform the nullability check during writes + public static final String CHECK_NULLABILITY = "spark.sql.iceberg.check-nullability"; + public static final boolean CHECK_NULLABILITY_DEFAULT = true; + + // Controls whether to check the order of fields during writes + public static final String CHECK_ORDERING = "spark.sql.iceberg.check-ordering"; + public static final boolean CHECK_ORDERING_DEFAULT = true; + + // Controls whether to preserve the existing grouping of data while planning splits + public static final String PRESERVE_DATA_GROUPING = + "spark.sql.iceberg.planning.preserve-data-grouping"; + public static final boolean PRESERVE_DATA_GROUPING_DEFAULT = false; + + // Controls whether to push down aggregate (MAX/MIN/COUNT) to Iceberg + public static final String AGGREGATE_PUSH_DOWN_ENABLED = + "spark.sql.iceberg.aggregate-push-down.enabled"; + public static final boolean AGGREGATE_PUSH_DOWN_ENABLED_DEFAULT = true; + + // Controls write distribution mode + public static final String DISTRIBUTION_MODE = "spark.sql.iceberg.distribution-mode"; + + // Controls the WAP ID used for write-audit-publish workflow. + // When set, new snapshots will be staged with this ID in snapshot summary. + public static final String WAP_ID = "spark.wap.id"; + + // Controls the WAP branch used for write-audit-publish workflow. + // When set, new snapshots will be committed to this branch. + public static final String WAP_BRANCH = "spark.wap.branch"; +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkSchemaUtil.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkSchemaUtil.java new file mode 100644 index 000000000000..6075aba7ac5f --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkSchemaUtil.java @@ -0,0 +1,395 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Binder; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.relocated.com.google.common.base.Splitter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.math.LongMath; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalog.Column; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.StructType; + +/** Helper methods for working with Spark/Hive metadata. */ +public class SparkSchemaUtil { + private SparkSchemaUtil() {} + + /** + * Returns a {@link Schema} for the given table with fresh field ids. + * + *
+ * <p>
This creates a Schema for an existing table by looking up the table's schema with Spark and + * converting that schema. Spark/Hive partition columns are included in the schema. + * + * @param spark a Spark session + * @param name a table name and (optional) database + * @return a Schema for the table, if found + */ + public static Schema schemaForTable(SparkSession spark, String name) { + StructType sparkType = spark.table(name).schema(); + Type converted = SparkTypeVisitor.visit(sparkType, new SparkTypeToType(sparkType)); + return new Schema(converted.asNestedType().asStructType().fields()); + } + + /** + * Returns a {@link PartitionSpec} for the given table. + * + *
+ * <p>
This creates a partition spec for an existing table by looking up the table's schema and + * creating a spec with identity partitions for each partition column. + * + * @param spark a Spark session + * @param name a table name and (optional) database + * @return a PartitionSpec for the table + * @throws AnalysisException if thrown by the Spark catalog + */ + public static PartitionSpec specForTable(SparkSession spark, String name) + throws AnalysisException { + List parts = Lists.newArrayList(Splitter.on('.').limit(2).split(name)); + String db = parts.size() == 1 ? "default" : parts.get(0); + String table = parts.get(parts.size() == 1 ? 0 : 1); + + PartitionSpec spec = + identitySpec( + schemaForTable(spark, name), spark.catalog().listColumns(db, table).collectAsList()); + return spec == null ? PartitionSpec.unpartitioned() : spec; + } + + /** + * Convert a {@link Schema} to a {@link DataType Spark type}. + * + * @param schema a Schema + * @return the equivalent Spark type + * @throws IllegalArgumentException if the type cannot be converted to Spark + */ + public static StructType convert(Schema schema) { + return (StructType) TypeUtil.visit(schema, new TypeToSparkType()); + } + + /** + * Convert a {@link Type} to a {@link DataType Spark type}. + * + * @param type a Type + * @return the equivalent Spark type + * @throws IllegalArgumentException if the type cannot be converted to Spark + */ + public static DataType convert(Type type) { + return TypeUtil.visit(type, new TypeToSparkType()); + } + + /** + * Convert a Spark {@link StructType struct} to a {@link Schema} with new field ids. + * + *
+ * <p>This conversion assigns fresh ids.
+ *
+ * <p>Some data types are represented as the same Spark type. These are converted to a default
+ * type.
+ *
+ * <p>
To convert using a reference schema for field ids and ambiguous types, use {@link + * #convert(Schema, StructType)}. + * + * @param sparkType a Spark StructType + * @return the equivalent Schema + * @throws IllegalArgumentException if the type cannot be converted + */ + public static Schema convert(StructType sparkType) { + return convert(sparkType, false); + } + + /** + * Convert a Spark {@link StructType struct} to a {@link Schema} with new field ids. + * + *
+ * <p>This conversion assigns fresh ids.
+ *
+ * <p>Some data types are represented as the same Spark type. These are converted to a default
+ * type.
+ *
+ * <p>
To convert using a reference schema for field ids and ambiguous types, use {@link + * #convert(Schema, StructType)}. + * + * @param sparkType a Spark StructType + * @param useTimestampWithoutZone boolean flag indicates that timestamp should be stored without + * timezone + * @return the equivalent Schema + * @throws IllegalArgumentException if the type cannot be converted + */ + public static Schema convert(StructType sparkType, boolean useTimestampWithoutZone) { + Type converted = SparkTypeVisitor.visit(sparkType, new SparkTypeToType(sparkType)); + Schema schema = new Schema(converted.asNestedType().asStructType().fields()); + if (useTimestampWithoutZone) { + schema = SparkFixupTimestampType.fixup(schema); + } + return schema; + } + + /** + * Convert a Spark {@link DataType struct} to a {@link Type} with new field ids. + * + *
+ * <p>This conversion assigns fresh ids.
+ *
+ * <p>Some data types are represented as the same Spark type. These are converted to a default
+ * type.
+ *
+ * <p>
To convert using a reference schema for field ids and ambiguous types, use {@link + * #convert(Schema, StructType)}. + * + * @param sparkType a Spark DataType + * @return the equivalent Type + * @throws IllegalArgumentException if the type cannot be converted + */ + public static Type convert(DataType sparkType) { + return SparkTypeVisitor.visit(sparkType, new SparkTypeToType()); + } + + /** + * Convert a Spark {@link StructType struct} to a {@link Schema} based on the given schema. + * + *
+ * <p>This conversion does not assign new ids; it uses ids from the base schema.
+ *
+ * <p>
Data types, field order, and nullability will match the spark type. This conversion may + * return a schema that is not compatible with base schema. + * + * @param baseSchema a Schema on which conversion is based + * @param sparkType a Spark StructType + * @return the equivalent Schema + * @throws IllegalArgumentException if the type cannot be converted or there are missing ids + */ + public static Schema convert(Schema baseSchema, StructType sparkType) { + return convert(baseSchema, sparkType, true); + } + + /** + * Convert a Spark {@link StructType struct} to a {@link Schema} based on the given schema. + * + *
+ * <p>This conversion does not assign new ids; it uses ids from the base schema.
+ *
+ * <p>
Data types, field order, and nullability will match the spark type. This conversion may + * return a schema that is not compatible with base schema. + * + * @param baseSchema a Schema on which conversion is based + * @param sparkType a Spark StructType + * @param caseSensitive when false, the case of schema fields is ignored + * @return the equivalent Schema + * @throws IllegalArgumentException if the type cannot be converted or there are missing ids + */ + public static Schema convert(Schema baseSchema, StructType sparkType, boolean caseSensitive) { + // convert to a type with fresh ids + Types.StructType struct = + SparkTypeVisitor.visit(sparkType, new SparkTypeToType(sparkType)).asStructType(); + // reassign ids to match the base schema + Schema schema = TypeUtil.reassignIds(new Schema(struct.fields()), baseSchema, caseSensitive); + // fix types that can't be represented in Spark (UUID and Fixed) + return SparkFixupTypes.fixup(schema, baseSchema); + } + + /** + * Convert a Spark {@link StructType struct} to a {@link Schema} based on the given schema. + * + *
+ * <p>This conversion will assign new ids for fields that are not found in the base schema.
+ *
+ * <p>
Data types, field order, and nullability will match the spark type. This conversion may + * return a schema that is not compatible with base schema. + * + * @param baseSchema a Schema on which conversion is based + * @param sparkType a Spark StructType + * @return the equivalent Schema + * @throws IllegalArgumentException if the type cannot be converted or there are missing ids + */ + public static Schema convertWithFreshIds(Schema baseSchema, StructType sparkType) { + return convertWithFreshIds(baseSchema, sparkType, true); + } + + /** + * Convert a Spark {@link StructType struct} to a {@link Schema} based on the given schema. + * + *
+ * <p>This conversion will assign new ids for fields that are not found in the base schema.
+ *
+ * <p>
Data types, field order, and nullability will match the spark type. This conversion may + * return a schema that is not compatible with base schema. + * + * @param baseSchema a Schema on which conversion is based + * @param sparkType a Spark StructType + * @param caseSensitive when false, case of field names in schema is ignored + * @return the equivalent Schema + * @throws IllegalArgumentException if the type cannot be converted or there are missing ids + */ + public static Schema convertWithFreshIds( + Schema baseSchema, StructType sparkType, boolean caseSensitive) { + // convert to a type with fresh ids + Types.StructType struct = + SparkTypeVisitor.visit(sparkType, new SparkTypeToType(sparkType)).asStructType(); + // reassign ids to match the base schema + Schema schema = + TypeUtil.reassignOrRefreshIds(new Schema(struct.fields()), baseSchema, caseSensitive); + // fix types that can't be represented in Spark (UUID and Fixed) + return SparkFixupTypes.fixup(schema, baseSchema); + } + + /** + * Prune columns from a {@link Schema} using a {@link StructType Spark type} projection. + * + *
+ * <p>
This requires that the Spark type is a projection of the Schema. Nullability and types must + * match. + * + * @param schema a Schema + * @param requestedType a projection of the Spark representation of the Schema + * @return a Schema corresponding to the Spark projection + * @throws IllegalArgumentException if the Spark type does not match the Schema + */ + public static Schema prune(Schema schema, StructType requestedType) { + return new Schema( + TypeUtil.visit(schema, new PruneColumnsWithoutReordering(requestedType, ImmutableSet.of())) + .asNestedType() + .asStructType() + .fields()); + } + + /** + * Prune columns from a {@link Schema} using a {@link StructType Spark type} projection. + * + *
+ * <p>This requires that the Spark type is a projection of the Schema. Nullability and types must
+ * match.
+ *
+ * <p>
The filters list of {@link Expression} is used to ensure that columns referenced by filters + * are projected. + * + * @param schema a Schema + * @param requestedType a projection of the Spark representation of the Schema + * @param filters a list of filters + * @return a Schema corresponding to the Spark projection + * @throws IllegalArgumentException if the Spark type does not match the Schema + */ + public static Schema prune(Schema schema, StructType requestedType, List filters) { + Set filterRefs = Binder.boundReferences(schema.asStruct(), filters, true); + return new Schema( + TypeUtil.visit(schema, new PruneColumnsWithoutReordering(requestedType, filterRefs)) + .asNestedType() + .asStructType() + .fields()); + } + + /** + * Prune columns from a {@link Schema} using a {@link StructType Spark type} projection. + * + *
+ * <p>This requires that the Spark type is a projection of the Schema. Nullability and types must
+ * match.
+ *
+ * <p>
The filters list of {@link Expression} is used to ensure that columns referenced by filters + * are projected. + * + * @param schema a Schema + * @param requestedType a projection of the Spark representation of the Schema + * @param filter a filters + * @return a Schema corresponding to the Spark projection + * @throws IllegalArgumentException if the Spark type does not match the Schema + */ + public static Schema prune( + Schema schema, StructType requestedType, Expression filter, boolean caseSensitive) { + Set filterRefs = + Binder.boundReferences(schema.asStruct(), Collections.singletonList(filter), caseSensitive); + + return new Schema( + TypeUtil.visit(schema, new PruneColumnsWithoutReordering(requestedType, filterRefs)) + .asNestedType() + .asStructType() + .fields()); + } + + private static PartitionSpec identitySpec(Schema schema, Collection columns) { + List names = Lists.newArrayList(); + for (Column column : columns) { + if (column.isPartition()) { + names.add(column.name()); + } + } + + return identitySpec(schema, names); + } + + private static PartitionSpec identitySpec(Schema schema, List partitionNames) { + if (partitionNames == null || partitionNames.isEmpty()) { + return null; + } + + PartitionSpec.Builder builder = PartitionSpec.builderFor(schema); + for (String partitionName : partitionNames) { + builder.identity(partitionName); + } + + return builder.build(); + } + + /** + * Estimate approximate table size based on Spark schema and total records. + * + * @param tableSchema Spark schema + * @param totalRecords total records in the table + * @return approximate size based on table schema + */ + public static long estimateSize(StructType tableSchema, long totalRecords) { + if (totalRecords == Long.MAX_VALUE) { + return totalRecords; + } + + long result; + try { + result = LongMath.checkedMultiply(tableSchema.defaultSize(), totalRecords); + } catch (ArithmeticException e) { + result = Long.MAX_VALUE; + } + return result; + } + + public static void validateMetadataColumnReferences(Schema tableSchema, Schema readSchema) { + List conflictingColumnNames = + readSchema.columns().stream() + .map(Types.NestedField::name) + .filter( + name -> + MetadataColumns.isMetadataColumn(name) && tableSchema.findField(name) != null) + .collect(Collectors.toList()); + + ValidationException.check( + conflictingColumnNames.isEmpty(), + "Table column names conflict with names reserved for Iceberg metadata columns: %s.\n" + + "Please, use ALTER TABLE statements to rename the conflicting table columns.", + conflictingColumnNames); + } + + public static Map indexQuotedNameById(Schema schema) { + Function quotingFunc = name -> String.format("`%s`", name.replace("`", "``")); + return TypeUtil.indexQuotedNameById(schema.asStruct(), quotingFunc); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkSessionCatalog.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkSessionCatalog.java new file mode 100644 index 000000000000..c891985b383d --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkSessionCatalog.java @@ -0,0 +1,394 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.source.HasIcebergCatalog; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NamespaceAlreadyExistsException; +import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException; +import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException; +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.connector.catalog.CatalogExtension; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.FunctionCatalog; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.NamespaceChange; +import org.apache.spark.sql.connector.catalog.StagedTable; +import org.apache.spark.sql.connector.catalog.StagingTableCatalog; +import org.apache.spark.sql.connector.catalog.SupportsNamespaces; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.catalog.TableChange; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * A Spark catalog that can also load non-Iceberg tables. + * + * @param CatalogPlugin class to avoid casting to TableCatalog, FunctionCatalog and + * SupportsNamespaces. + */ +public class SparkSessionCatalog + extends BaseCatalog implements CatalogExtension { + private static final String[] DEFAULT_NAMESPACE = new String[] {"default"}; + + private String catalogName = null; + private TableCatalog icebergCatalog = null; + private StagingTableCatalog asStagingCatalog = null; + private T sessionCatalog = null; + private boolean createParquetAsIceberg = false; + private boolean createAvroAsIceberg = false; + private boolean createOrcAsIceberg = false; + + /** + * Build a {@link SparkCatalog} to be used for Iceberg operations. + * + *
+ * <p>
The default implementation creates a new SparkCatalog with the session catalog's name and + * options. + * + * @param name catalog name + * @param options catalog options + * @return a SparkCatalog to be used for Iceberg tables + */ + protected TableCatalog buildSparkCatalog(String name, CaseInsensitiveStringMap options) { + SparkCatalog newCatalog = new SparkCatalog(); + newCatalog.initialize(name, options); + return newCatalog; + } + + @Override + public String[] defaultNamespace() { + return DEFAULT_NAMESPACE; + } + + @Override + public String[][] listNamespaces() throws NoSuchNamespaceException { + return getSessionCatalog().listNamespaces(); + } + + @Override + public String[][] listNamespaces(String[] namespace) throws NoSuchNamespaceException { + return getSessionCatalog().listNamespaces(namespace); + } + + @Override + public boolean namespaceExists(String[] namespace) { + return getSessionCatalog().namespaceExists(namespace); + } + + @Override + public Map loadNamespaceMetadata(String[] namespace) + throws NoSuchNamespaceException { + return getSessionCatalog().loadNamespaceMetadata(namespace); + } + + @Override + public void createNamespace(String[] namespace, Map metadata) + throws NamespaceAlreadyExistsException { + getSessionCatalog().createNamespace(namespace, metadata); + } + + @Override + public void alterNamespace(String[] namespace, NamespaceChange... changes) + throws NoSuchNamespaceException { + getSessionCatalog().alterNamespace(namespace, changes); + } + + @Override + public boolean dropNamespace(String[] namespace, boolean cascade) + throws NoSuchNamespaceException, NonEmptyNamespaceException { + return getSessionCatalog().dropNamespace(namespace, cascade); + } + + @Override + public Identifier[] listTables(String[] namespace) throws NoSuchNamespaceException { + // delegate to the session catalog because all tables share the same namespace + return getSessionCatalog().listTables(namespace); + } + + @Override + public Table loadTable(Identifier ident) throws NoSuchTableException { + try { + return icebergCatalog.loadTable(ident); + } catch (NoSuchTableException e) { + return getSessionCatalog().loadTable(ident); + } + } + + @Override + public Table loadTable(Identifier ident, String version) throws NoSuchTableException { + try { + return icebergCatalog.loadTable(ident, version); + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + return getSessionCatalog().loadTable(ident, version); + } + } + + @Override + public Table loadTable(Identifier ident, long timestamp) throws NoSuchTableException { + try { + return icebergCatalog.loadTable(ident, timestamp); + } catch (org.apache.iceberg.exceptions.NoSuchTableException e) { + return getSessionCatalog().loadTable(ident, timestamp); + } + } + + @Override + public void invalidateTable(Identifier ident) { + // We do not need to check whether the table exists and whether + // it is an Iceberg table to reduce remote service requests. 
+ icebergCatalog.invalidateTable(ident); + getSessionCatalog().invalidateTable(ident); + } + + @Override + public Table createTable( + Identifier ident, StructType schema, Transform[] partitions, Map properties) + throws TableAlreadyExistsException, NoSuchNamespaceException { + String provider = properties.get("provider"); + if (useIceberg(provider)) { + return icebergCatalog.createTable(ident, schema, partitions, properties); + } else { + // delegate to the session catalog + return getSessionCatalog().createTable(ident, schema, partitions, properties); + } + } + + @Override + public StagedTable stageCreate( + Identifier ident, StructType schema, Transform[] partitions, Map properties) + throws TableAlreadyExistsException, NoSuchNamespaceException { + String provider = properties.get("provider"); + TableCatalog catalog; + if (useIceberg(provider)) { + if (asStagingCatalog != null) { + return asStagingCatalog.stageCreate(ident, schema, partitions, properties); + } + catalog = icebergCatalog; + } else { + catalog = getSessionCatalog(); + } + + // create the table with the session catalog, then wrap it in a staged table that will delete to + // roll back + Table table = catalog.createTable(ident, schema, partitions, properties); + return new RollbackStagedTable(catalog, ident, table); + } + + @Override + public StagedTable stageReplace( + Identifier ident, StructType schema, Transform[] partitions, Map properties) + throws NoSuchNamespaceException, NoSuchTableException { + String provider = properties.get("provider"); + TableCatalog catalog; + if (useIceberg(provider)) { + if (asStagingCatalog != null) { + return asStagingCatalog.stageReplace(ident, schema, partitions, properties); + } + catalog = icebergCatalog; + } else { + catalog = getSessionCatalog(); + } + + // attempt to drop the table and fail if it doesn't exist + if (!catalog.dropTable(ident)) { + throw new NoSuchTableException(ident); + } + + try { + // create the table with the session catalog, then wrap it in a staged table that will delete + // to roll back + Table table = catalog.createTable(ident, schema, partitions, properties); + return new RollbackStagedTable(catalog, ident, table); + + } catch (TableAlreadyExistsException e) { + // the table was deleted, but now already exists again. retry the replace. + return stageReplace(ident, schema, partitions, properties); + } + } + + @Override + public StagedTable stageCreateOrReplace( + Identifier ident, StructType schema, Transform[] partitions, Map properties) + throws NoSuchNamespaceException { + String provider = properties.get("provider"); + TableCatalog catalog; + if (useIceberg(provider)) { + if (asStagingCatalog != null) { + return asStagingCatalog.stageCreateOrReplace(ident, schema, partitions, properties); + } + catalog = icebergCatalog; + } else { + catalog = getSessionCatalog(); + } + + // drop the table if it exists + catalog.dropTable(ident); + + try { + // create the table with the session catalog, then wrap it in a staged table that will delete + // to roll back + Table sessionCatalogTable = catalog.createTable(ident, schema, partitions, properties); + return new RollbackStagedTable(catalog, ident, sessionCatalogTable); + + } catch (TableAlreadyExistsException e) { + // the table was deleted, but now already exists again. retry the replace. + return stageCreateOrReplace(ident, schema, partitions, properties); + } + } + + @Override + public Table alterTable(Identifier ident, TableChange... 
changes) throws NoSuchTableException { + if (icebergCatalog.tableExists(ident)) { + return icebergCatalog.alterTable(ident, changes); + } else { + return getSessionCatalog().alterTable(ident, changes); + } + } + + @Override + public boolean dropTable(Identifier ident) { + // no need to check table existence to determine which catalog to use. if a table doesn't exist + // then both are + // required to return false. + return icebergCatalog.dropTable(ident) || getSessionCatalog().dropTable(ident); + } + + @Override + public boolean purgeTable(Identifier ident) { + // no need to check table existence to determine which catalog to use. if a table doesn't exist + // then both are + // required to return false. + return icebergCatalog.purgeTable(ident) || getSessionCatalog().purgeTable(ident); + } + + @Override + public void renameTable(Identifier from, Identifier to) + throws NoSuchTableException, TableAlreadyExistsException { + // rename is not supported by HadoopCatalog. to avoid UnsupportedOperationException for session + // catalog tables, + // check table existence first to ensure that the table belongs to the Iceberg catalog. + if (icebergCatalog.tableExists(from)) { + icebergCatalog.renameTable(from, to); + } else { + getSessionCatalog().renameTable(from, to); + } + } + + @Override + public final void initialize(String name, CaseInsensitiveStringMap options) { + if (options.containsKey("type") && options.get("type").equalsIgnoreCase("hive")) { + validateHmsUri(options.get(CatalogProperties.URI)); + } + + this.catalogName = name; + this.icebergCatalog = buildSparkCatalog(name, options); + if (icebergCatalog instanceof StagingTableCatalog) { + this.asStagingCatalog = (StagingTableCatalog) icebergCatalog; + } + + this.createParquetAsIceberg = options.getBoolean("parquet-enabled", createParquetAsIceberg); + this.createAvroAsIceberg = options.getBoolean("avro-enabled", createAvroAsIceberg); + this.createOrcAsIceberg = options.getBoolean("orc-enabled", createOrcAsIceberg); + } + + private void validateHmsUri(String catalogHmsUri) { + if (catalogHmsUri == null) { + return; + } + + Configuration conf = SparkSession.active().sessionState().newHadoopConf(); + String envHmsUri = conf.get(HiveConf.ConfVars.METASTOREURIS.varname, null); + if (envHmsUri == null) { + return; + } + + Preconditions.checkArgument( + catalogHmsUri.equals(envHmsUri), + "Inconsistent Hive metastore URIs: %s (Spark session) != %s (spark_catalog)", + envHmsUri, + catalogHmsUri); + } + + @Override + @SuppressWarnings("unchecked") + public void setDelegateCatalog(CatalogPlugin sparkSessionCatalog) { + if (sparkSessionCatalog instanceof TableCatalog + && sparkSessionCatalog instanceof FunctionCatalog + && sparkSessionCatalog instanceof SupportsNamespaces) { + this.sessionCatalog = (T) sparkSessionCatalog; + } else { + throw new IllegalArgumentException("Invalid session catalog: " + sparkSessionCatalog); + } + } + + @Override + public String name() { + return catalogName; + } + + private boolean useIceberg(String provider) { + if (provider == null || "iceberg".equalsIgnoreCase(provider)) { + return true; + } else if (createParquetAsIceberg && "parquet".equalsIgnoreCase(provider)) { + return true; + } else if (createAvroAsIceberg && "avro".equalsIgnoreCase(provider)) { + return true; + } else if (createOrcAsIceberg && "orc".equalsIgnoreCase(provider)) { + return true; + } + + return false; + } + + private T getSessionCatalog() { + Preconditions.checkNotNull( + sessionCatalog, + "Delegated SessionCatalog is missing. 
" + + "Please make sure your are replacing Spark's default catalog, named 'spark_catalog'."); + return sessionCatalog; + } + + @Override + public Catalog icebergCatalog() { + Preconditions.checkArgument( + icebergCatalog instanceof HasIcebergCatalog, + "Cannot return underlying Iceberg Catalog, wrapped catalog does not contain an Iceberg Catalog"); + return ((HasIcebergCatalog) icebergCatalog).icebergCatalog(); + } + + @Override + public UnboundFunction loadFunction(Identifier ident) throws NoSuchFunctionException { + try { + return super.loadFunction(ident); + } catch (NoSuchFunctionException e) { + return getSessionCatalog().loadFunction(ident); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkStructLike.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkStructLike.java new file mode 100644 index 000000000000..77cfa0f34c63 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkStructLike.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import org.apache.iceberg.StructLike; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Row; + +public class SparkStructLike implements StructLike { + + private final Types.StructType type; + private Row wrapped; + + public SparkStructLike(Types.StructType type) { + this.type = type; + } + + public SparkStructLike wrap(Row row) { + this.wrapped = row; + return this; + } + + @Override + public int size() { + return type.fields().size(); + } + + @Override + public T get(int pos, Class javaClass) { + Types.NestedField field = type.fields().get(pos); + return javaClass.cast(SparkValueConverter.convert(field.type(), wrapped.get(pos))); + } + + @Override + public void set(int pos, T value) { + throw new UnsupportedOperationException("Not implemented: set"); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkTableCache.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkTableCache.java new file mode 100644 index 000000000000..6218423db491 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkTableCache.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Map; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; + +public class SparkTableCache { + + private static final SparkTableCache INSTANCE = new SparkTableCache(); + + private final Map cache = Maps.newConcurrentMap(); + + public static SparkTableCache get() { + return INSTANCE; + } + + public int size() { + return cache.size(); + } + + public void add(String key, Table table) { + cache.put(key, table); + } + + public boolean contains(String key) { + return cache.containsKey(key); + } + + public Table get(String key) { + return cache.get(key); + } + + public Table remove(String key) { + return cache.remove(key); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java new file mode 100644 index 000000000000..90bdbfc1d9ba --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java @@ -0,0 +1,770 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
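For context on how the SparkSessionCatalog above is wired in, a minimal, illustrative sketch of the Spark configuration that replaces the built-in 'spark_catalog'. The catalog type, the parquet-enabled flag, and the SQL statement are assumptions for illustration; type = 'hive' presumes a reachable Hive metastore.

    import org.apache.spark.sql.SparkSession;

    // Sketch: route Iceberg (and optionally parquet) tables through Iceberg while
    // all other tables fall back to Spark's built-in session catalog.
    public class SessionCatalogSetup {
      public static void main(String[] args) {
        SparkSession spark =
            SparkSession.builder()
                .appName("iceberg-session-catalog-demo")
                .master("local[1]")
                // replace Spark's default catalog, named 'spark_catalog'
                .config(
                    "spark.sql.catalog.spark_catalog",
                    "org.apache.iceberg.spark.SparkSessionCatalog")
                // options below are passed to initialize(); 'hive' triggers the HMS URI check
                .config("spark.sql.catalog.spark_catalog.type", "hive")
                // also create CREATE TABLE ... USING parquet tables as Iceberg tables
                .config("spark.sql.catalog.spark_catalog.parquet-enabled", "true")
                .getOrCreate();

        spark.sql("CREATE TABLE IF NOT EXISTS default.demo (id BIGINT) USING iceberg");
      }
    }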
+ */ +package org.apache.iceberg.spark; + +import static org.apache.spark.sql.functions.col; + +import java.io.IOException; +import java.io.Serializable; +import java.net.URI; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.ManifestWriter; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.MetadataTableUtils; +import org.apache.iceberg.MetricsConfig; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.TableMigrationUtil; +import org.apache.iceberg.hadoop.HadoopFileIO; +import org.apache.iceberg.hadoop.SerializableConfiguration; +import org.apache.iceberg.hadoop.Util; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.mapping.NameMapping; +import org.apache.iceberg.mapping.NameMappingParser; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; +import org.apache.iceberg.relocated.com.google.common.base.Objects; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.Tasks; +import org.apache.iceberg.util.ThreadPools; +import org.apache.spark.TaskContext; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.api.java.function.MapPartitionsFunction; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute; +import org.apache.spark.sql.catalyst.catalog.CatalogTable; +import org.apache.spark.sql.catalyst.catalog.CatalogTablePartition; +import org.apache.spark.sql.catalyst.catalog.SessionCatalog; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import scala.Function2; +import scala.Option; +import scala.Some; +import scala.Tuple2; 
+import scala.collection.JavaConverters; +import scala.collection.immutable.Map$; +import scala.collection.immutable.Seq; +import scala.collection.mutable.Builder; +import scala.runtime.AbstractPartialFunction; + +/** + * Java version of the original SparkTableUtil.scala + * https://github.com/apache/iceberg/blob/apache-iceberg-0.8.0-incubating/spark/src/main/scala/org/apache/iceberg/spark/SparkTableUtil.scala + */ +public class SparkTableUtil { + + private static final String DUPLICATE_FILE_MESSAGE = + "Cannot complete import because data files " + + "to be imported already exist within the target table: %s. " + + "This is disabled by default as Iceberg is not designed for multiple references to the same file" + + " within the same table. If you are sure, you may set 'check_duplicate_files' to false to force the import."; + + private SparkTableUtil() {} + + /** + * Returns a DataFrame with a row for each partition in the table. + * + *
The DataFrame has 3 columns, partition key (a=1/b=2), partition location, and format (avro + * or parquet). + * + * @param spark a Spark session + * @param table a table name and (optional) database + * @return a DataFrame of the table's partitions + */ + public static Dataset partitionDF(SparkSession spark, String table) { + List partitions = getPartitions(spark, table); + return spark + .createDataFrame(partitions, SparkPartition.class) + .toDF("partition", "uri", "format"); + } + + /** + * Returns a DataFrame with a row for each partition that matches the specified 'expression'. + * + * @param spark a Spark session. + * @param table name of the table. + * @param expression The expression whose matching partitions are returned. + * @return a DataFrame of the table partitions. + */ + public static Dataset partitionDFByFilter( + SparkSession spark, String table, String expression) { + List partitions = getPartitionsByFilter(spark, table, expression); + return spark + .createDataFrame(partitions, SparkPartition.class) + .toDF("partition", "uri", "format"); + } + + /** + * Returns all partitions in the table. + * + * @param spark a Spark session + * @param table a table name and (optional) database + * @return all table's partitions + */ + public static List getPartitions(SparkSession spark, String table) { + try { + TableIdentifier tableIdent = spark.sessionState().sqlParser().parseTableIdentifier(table); + return getPartitions(spark, tableIdent, null); + } catch (ParseException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unable to parse table identifier: %s", table); + } + } + + /** + * Returns all partitions in the table. + * + * @param spark a Spark session + * @param tableIdent a table identifier + * @param partitionFilter partition filter, or null if no filter + * @return all table's partitions + */ + public static List getPartitions( + SparkSession spark, TableIdentifier tableIdent, Map partitionFilter) { + try { + SessionCatalog catalog = spark.sessionState().catalog(); + CatalogTable catalogTable = catalog.getTableMetadata(tableIdent); + + Option> scalaPartitionFilter; + if (partitionFilter != null && !partitionFilter.isEmpty()) { + Builder, scala.collection.immutable.Map> builder = + Map$.MODULE$.newBuilder(); + partitionFilter.forEach((key, value) -> builder.$plus$eq(Tuple2.apply(key, value))); + scalaPartitionFilter = Option.apply(builder.result()); + } else { + scalaPartitionFilter = Option.empty(); + } + Seq partitions = + catalog.listPartitions(tableIdent, scalaPartitionFilter).toIndexedSeq(); + return JavaConverters.seqAsJavaListConverter(partitions).asJava().stream() + .map(catalogPartition -> toSparkPartition(catalogPartition, catalogTable)) + .collect(Collectors.toList()); + } catch (NoSuchDatabaseException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unknown table: %s. Database not found in catalog.", tableIdent); + } catch (NoSuchTableException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unknown table: %s. Table not found in catalog.", tableIdent); + } + } + + /** + * Returns partitions that match the specified 'predicate'. 
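A short usage sketch for the partition helpers above; the table name and the filter expression are hypothetical.

    import org.apache.iceberg.spark.SparkTableUtil;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;

    // Sketch: inspect the partitions of an existing (non-Iceberg) Spark table.
    public class PartitionListingExample {
      public static void main(String[] args) {
        SparkSession spark =
            SparkSession.builder().appName("partition-df-demo").master("local[1]").getOrCreate();

        // one row per partition, with columns: partition (e.g. dt=2023-01-01), uri, format
        Dataset<Row> all = SparkTableUtil.partitionDF(spark, "db.events");
        all.show(false);

        // same, restricted to partitions matching a predicate on the partition columns
        Dataset<Row> filtered =
            SparkTableUtil.partitionDFByFilter(spark, "db.events", "dt = '2023-01-01'");
        filtered.show(false);
      }
    }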
+ * + * @param spark a Spark session + * @param table a table name and (optional) database + * @param predicate a predicate on partition columns + * @return matching table's partitions + */ + public static List getPartitionsByFilter( + SparkSession spark, String table, String predicate) { + TableIdentifier tableIdent; + try { + tableIdent = spark.sessionState().sqlParser().parseTableIdentifier(table); + } catch (ParseException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unable to parse the table identifier: %s", table); + } + + Expression unresolvedPredicateExpr; + try { + unresolvedPredicateExpr = spark.sessionState().sqlParser().parseExpression(predicate); + } catch (ParseException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unable to parse the predicate expression: %s", predicate); + } + + Expression resolvedPredicateExpr = resolveAttrs(spark, table, unresolvedPredicateExpr); + return getPartitionsByFilter(spark, tableIdent, resolvedPredicateExpr); + } + + /** + * Returns partitions that match the specified 'predicate'. + * + * @param spark a Spark session + * @param tableIdent a table identifier + * @param predicateExpr a predicate expression on partition columns + * @return matching table's partitions + */ + public static List getPartitionsByFilter( + SparkSession spark, TableIdentifier tableIdent, Expression predicateExpr) { + try { + SessionCatalog catalog = spark.sessionState().catalog(); + CatalogTable catalogTable = catalog.getTableMetadata(tableIdent); + + Expression resolvedPredicateExpr; + if (!predicateExpr.resolved()) { + resolvedPredicateExpr = resolveAttrs(spark, tableIdent.quotedString(), predicateExpr); + } else { + resolvedPredicateExpr = predicateExpr; + } + Seq predicates = + JavaConverters.collectionAsScalaIterableConverter(ImmutableList.of(resolvedPredicateExpr)) + .asScala() + .toIndexedSeq(); + + Seq partitions = + catalog.listPartitionsByFilter(tableIdent, predicates).toIndexedSeq(); + + return JavaConverters.seqAsJavaListConverter(partitions).asJava().stream() + .map(catalogPartition -> toSparkPartition(catalogPartition, catalogTable)) + .collect(Collectors.toList()); + } catch (NoSuchDatabaseException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unknown table: %s. Database not found in catalog.", tableIdent); + } catch (NoSuchTableException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unknown table: %s. Table not found in catalog.", tableIdent); + } + } + + /** + * Returns the data files in a partition by listing the partition location. + * + *
For Parquet and ORC partitions, this will read metrics from the file footer. For Avro + * partitions, metrics are set to null. + * + * @param partition a partition + * @param conf a serializable Hadoop conf + * @param metricsConfig a metrics conf + * @return a List of DataFile + * @deprecated use {@link TableMigrationUtil#listPartition(Map, String, String, PartitionSpec, + * Configuration, MetricsConfig, NameMapping)} + */ + @Deprecated + public static List listPartition( + SparkPartition partition, + PartitionSpec spec, + SerializableConfiguration conf, + MetricsConfig metricsConfig) { + return listPartition(partition, spec, conf, metricsConfig, null); + } + + /** + * Returns the data files in a partition by listing the partition location. + * + *
For Parquet and ORC partitions, this will read metrics from the file footer. For Avro + * partitions, metrics are set to null. + * + * @param partition a partition + * @param conf a serializable Hadoop conf + * @param metricsConfig a metrics conf + * @param mapping a name mapping + * @return a List of DataFile + * @deprecated use {@link TableMigrationUtil#listPartition(Map, String, String, PartitionSpec, + * Configuration, MetricsConfig, NameMapping)} + */ + @Deprecated + public static List listPartition( + SparkPartition partition, + PartitionSpec spec, + SerializableConfiguration conf, + MetricsConfig metricsConfig, + NameMapping mapping) { + return TableMigrationUtil.listPartition( + partition.values, + partition.uri, + partition.format, + spec, + conf.get(), + metricsConfig, + mapping); + } + + private static SparkPartition toSparkPartition( + CatalogTablePartition partition, CatalogTable table) { + Option locationUri = partition.storage().locationUri(); + Option serde = partition.storage().serde(); + + Preconditions.checkArgument(locationUri.nonEmpty(), "Partition URI should be defined"); + Preconditions.checkArgument( + serde.nonEmpty() || table.provider().nonEmpty(), "Partition format should be defined"); + + String uri = Util.uriToString(locationUri.get()); + String format = serde.nonEmpty() ? serde.get() : table.provider().get(); + + Map partitionSpec = + JavaConverters.mapAsJavaMapConverter(partition.spec()).asJava(); + return new SparkPartition(partitionSpec, uri, format); + } + + private static Expression resolveAttrs(SparkSession spark, String table, Expression expr) { + Function2 resolver = spark.sessionState().analyzer().resolver(); + LogicalPlan plan = spark.table(table).queryExecution().analyzed(); + return expr.transform( + new AbstractPartialFunction() { + @Override + public Expression apply(Expression attr) { + UnresolvedAttribute unresolvedAttribute = (UnresolvedAttribute) attr; + Option namedExpressionOption = + plan.resolve(unresolvedAttribute.nameParts(), resolver); + if (namedExpressionOption.isDefined()) { + return (Expression) namedExpressionOption.get(); + } else { + throw new IllegalArgumentException( + String.format("Could not resolve %s using columns: %s", attr, plan.output())); + } + } + + @Override + public boolean isDefinedAt(Expression attr) { + return attr instanceof UnresolvedAttribute; + } + }); + } + + private static Iterator buildManifest( + SerializableConfiguration conf, + PartitionSpec spec, + String basePath, + Iterator> fileTuples) { + if (fileTuples.hasNext()) { + FileIO io = new HadoopFileIO(conf.get()); + TaskContext ctx = TaskContext.get(); + String suffix = + String.format( + "stage-%d-task-%d-manifest-%s", + ctx.stageId(), ctx.taskAttemptId(), UUID.randomUUID()); + Path location = new Path(basePath, suffix); + String outputPath = FileFormat.AVRO.addExtension(location.toString()); + OutputFile outputFile = io.newOutputFile(outputPath); + ManifestWriter writer = ManifestFiles.write(spec, outputFile); + + try (ManifestWriter writerRef = writer) { + fileTuples.forEachRemaining(fileTuple -> writerRef.add(fileTuple._2)); + } catch (IOException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unable to close the manifest writer: %s", outputPath); + } + + ManifestFile manifestFile = writer.toManifestFile(); + return ImmutableList.of(manifestFile).iterator(); + } else { + return Collections.emptyIterator(); + } + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + *
The import uses the Spark session to get table metadata. It assumes no operation is going on + * the original and target table and thus is not thread-safe. + * + * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + * @param partitionFilter only import partitions whose values match those in the map, can be + * partially defined + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + */ + public static void importSparkTable( + SparkSession spark, + TableIdentifier sourceTableIdent, + Table targetTable, + String stagingDir, + Map partitionFilter, + boolean checkDuplicateFiles) { + SessionCatalog catalog = spark.sessionState().catalog(); + + String db = + sourceTableIdent.database().nonEmpty() + ? sourceTableIdent.database().get() + : catalog.getCurrentDatabase(); + TableIdentifier sourceTableIdentWithDB = + new TableIdentifier(sourceTableIdent.table(), Some.apply(db)); + + if (!catalog.tableExists(sourceTableIdentWithDB)) { + throw new org.apache.iceberg.exceptions.NoSuchTableException( + "Table %s does not exist", sourceTableIdentWithDB); + } + + try { + PartitionSpec spec = + SparkSchemaUtil.specForTable(spark, sourceTableIdentWithDB.unquotedString()); + + if (Objects.equal(spec, PartitionSpec.unpartitioned())) { + importUnpartitionedSparkTable( + spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles); + } else { + List sourceTablePartitions = + getPartitions(spark, sourceTableIdent, partitionFilter); + Preconditions.checkArgument( + !sourceTablePartitions.isEmpty(), + "Cannot find any partitions in table %s", + sourceTableIdent); + importSparkPartitions( + spark, sourceTablePartitions, targetTable, spec, stagingDir, checkDuplicateFiles); + } + } catch (AnalysisException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unable to get partition spec for table: %s", sourceTableIdentWithDB); + } + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + *
The import uses the Spark session to get table metadata. It assumes no operation is going on + * the original and target table and thus is not thread-safe. + * + * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + */ + public static void importSparkTable( + SparkSession spark, + TableIdentifier sourceTableIdent, + Table targetTable, + String stagingDir, + boolean checkDuplicateFiles) { + importSparkTable( + spark, + sourceTableIdent, + targetTable, + stagingDir, + Collections.emptyMap(), + checkDuplicateFiles); + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + *
The import uses the Spark session to get table metadata. It assumes no operation is going on + * the original and target table and thus is not thread-safe. + * + * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + */ + public static void importSparkTable( + SparkSession spark, TableIdentifier sourceTableIdent, Table targetTable, String stagingDir) { + importSparkTable( + spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false); + } + + private static void importUnpartitionedSparkTable( + SparkSession spark, + TableIdentifier sourceTableIdent, + Table targetTable, + boolean checkDuplicateFiles) { + try { + CatalogTable sourceTable = spark.sessionState().catalog().getTableMetadata(sourceTableIdent); + Option format = + sourceTable.storage().serde().nonEmpty() + ? sourceTable.storage().serde() + : sourceTable.provider(); + Preconditions.checkArgument(format.nonEmpty(), "Could not determine table format"); + + Map partition = Collections.emptyMap(); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Configuration conf = spark.sessionState().newHadoopConf(); + MetricsConfig metricsConfig = MetricsConfig.forTable(targetTable); + String nameMappingString = targetTable.properties().get(TableProperties.DEFAULT_NAME_MAPPING); + NameMapping nameMapping = + nameMappingString != null ? NameMappingParser.fromJson(nameMappingString) : null; + + List files = + TableMigrationUtil.listPartition( + partition, + Util.uriToString(sourceTable.location()), + format.get(), + spec, + conf, + metricsConfig, + nameMapping); + + if (checkDuplicateFiles) { + Dataset importedFiles = + spark + .createDataset(Lists.transform(files, f -> f.path().toString()), Encoders.STRING()) + .toDF("file_path"); + Dataset existingFiles = + loadMetadataTable(spark, targetTable, MetadataTableType.ENTRIES).filter("status != 2"); + Column joinCond = + existingFiles.col("data_file.file_path").equalTo(importedFiles.col("file_path")); + Dataset duplicates = + importedFiles.join(existingFiles, joinCond).select("file_path").as(Encoders.STRING()); + Preconditions.checkState( + duplicates.isEmpty(), + String.format( + DUPLICATE_FILE_MESSAGE, Joiner.on(",").join((String[]) duplicates.take(10)))); + } + + AppendFiles append = targetTable.newAppend(); + files.forEach(append::appendFile); + append.commit(); + } catch (NoSuchDatabaseException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unknown table: %s. Database not found in catalog.", sourceTableIdent); + } catch (NoSuchTableException e) { + throw SparkExceptionUtil.toUncheckedException( + e, "Unknown table: %s. Table not found in catalog.", sourceTableIdent); + } + } + + /** + * Import files from given partitions to an Iceberg table. 
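A minimal sketch of the one-shot import path described above. The identifiers, the staging path, and the way the target Iceberg table is obtained are assumptions; only the importSparkTable signature comes from this file.

    import java.util.Collections;
    import org.apache.iceberg.Table;
    import org.apache.iceberg.spark.SparkTableUtil;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.catalyst.TableIdentifier;
    import scala.Some;

    // Sketch: import the files of an existing Hive/Spark table into an
    // already-created Iceberg table (obtained elsewhere and passed in).
    public class ImportExample {
      public static void importEvents(SparkSession spark, Table targetIcebergTable) {
        TableIdentifier source = new TableIdentifier("events", Some.apply("db"));

        SparkTableUtil.importSparkTable(
            spark,
            source,
            targetIcebergTable,
            "/tmp/iceberg-import-staging", // staging dir for temporary manifests
            Collections.emptyMap(),        // empty partition filter: import everything
            true);                         // fail if a data file is already referenced
      }
    }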
+ * + * @param spark a Spark session + * @param partitions partitions to import + * @param targetTable an Iceberg table where to import the data + * @param spec a partition spec + * @param stagingDir a staging directory to store temporary manifest files + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + */ + public static void importSparkPartitions( + SparkSession spark, + List partitions, + Table targetTable, + PartitionSpec spec, + String stagingDir, + boolean checkDuplicateFiles) { + Configuration conf = spark.sessionState().newHadoopConf(); + SerializableConfiguration serializableConf = new SerializableConfiguration(conf); + int parallelism = + Math.min( + partitions.size(), spark.sessionState().conf().parallelPartitionDiscoveryParallelism()); + int numShufflePartitions = spark.sessionState().conf().numShufflePartitions(); + MetricsConfig metricsConfig = MetricsConfig.fromProperties(targetTable.properties()); + String nameMappingString = targetTable.properties().get(TableProperties.DEFAULT_NAME_MAPPING); + NameMapping nameMapping = + nameMappingString != null ? NameMappingParser.fromJson(nameMappingString) : null; + + JavaSparkContext sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + JavaRDD partitionRDD = sparkContext.parallelize(partitions, parallelism); + + Dataset partitionDS = + spark.createDataset(partitionRDD.rdd(), Encoders.javaSerialization(SparkPartition.class)); + + Dataset filesToImport = + partitionDS.flatMap( + (FlatMapFunction) + sparkPartition -> + listPartition( + sparkPartition, spec, serializableConf, metricsConfig, nameMapping) + .iterator(), + Encoders.javaSerialization(DataFile.class)); + + if (checkDuplicateFiles) { + Dataset importedFiles = + filesToImport + .map((MapFunction) f -> f.path().toString(), Encoders.STRING()) + .toDF("file_path"); + Dataset existingFiles = + loadMetadataTable(spark, targetTable, MetadataTableType.ENTRIES).filter("status != 2"); + Column joinCond = + existingFiles.col("data_file.file_path").equalTo(importedFiles.col("file_path")); + Dataset duplicates = + importedFiles.join(existingFiles, joinCond).select("file_path").as(Encoders.STRING()); + Preconditions.checkState( + duplicates.isEmpty(), + String.format( + DUPLICATE_FILE_MESSAGE, Joiner.on(",").join((String[]) duplicates.take(10)))); + } + + List manifests = + filesToImport + .repartition(numShufflePartitions) + .map( + (MapFunction>) + file -> Tuple2.apply(file.path().toString(), file), + Encoders.tuple(Encoders.STRING(), Encoders.javaSerialization(DataFile.class))) + .orderBy(col("_1")) + .mapPartitions( + (MapPartitionsFunction, ManifestFile>) + fileTuple -> buildManifest(serializableConf, spec, stagingDir, fileTuple), + Encoders.javaSerialization(ManifestFile.class)) + .collectAsList(); + + try { + boolean snapshotIdInheritanceEnabled = + PropertyUtil.propertyAsBoolean( + targetTable.properties(), + TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, + TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED_DEFAULT); + + AppendFiles append = targetTable.newAppend(); + manifests.forEach(append::appendManifest); + append.commit(); + + if (!snapshotIdInheritanceEnabled) { + // delete original manifests as they were rewritten before the commit + deleteManifests(targetTable.io(), manifests); + } + } catch (Throwable e) { + deleteManifests(targetTable.io(), manifests); + throw e; + } + } + + /** + * Import files from given partitions to an Iceberg table. 
+ * + * @param spark a Spark session + * @param partitions partitions to import + * @param targetTable an Iceberg table where to import the data + * @param spec a partition spec + * @param stagingDir a staging directory to store temporary manifest files + */ + public static void importSparkPartitions( + SparkSession spark, + List partitions, + Table targetTable, + PartitionSpec spec, + String stagingDir) { + importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, false); + } + + public static List filterPartitions( + List partitions, Map partitionFilter) { + if (partitionFilter.isEmpty()) { + return partitions; + } else { + return partitions.stream() + .filter(p -> p.getValues().entrySet().containsAll(partitionFilter.entrySet())) + .collect(Collectors.toList()); + } + } + + private static void deleteManifests(FileIO io, List manifests) { + Tasks.foreach(manifests) + .executeWith(ThreadPools.getWorkerPool()) + .noRetry() + .suppressFailureWhenFinished() + .run(item -> io.deleteFile(item.path())); + } + + /** + * Loads a metadata table. + * + * @deprecated since 0.14.0, will be removed in 0.15.0; use {@link + * #loadMetadataTable(SparkSession, Table, MetadataTableType)}. + */ + @Deprecated + public static Dataset loadCatalogMetadataTable( + SparkSession spark, Table table, MetadataTableType type) { + return loadMetadataTable(spark, table, type); + } + + public static Dataset loadMetadataTable( + SparkSession spark, Table table, MetadataTableType type) { + return loadMetadataTable(spark, table, type, ImmutableMap.of()); + } + + public static Dataset loadMetadataTable( + SparkSession spark, Table table, MetadataTableType type, Map extraOptions) { + SparkTable metadataTable = + new SparkTable(MetadataTableUtils.createMetadataTableInstance(table, type), false); + CaseInsensitiveStringMap options = new CaseInsensitiveStringMap(extraOptions); + return Dataset.ofRows( + spark, DataSourceV2Relation.create(metadataTable, Some.empty(), Some.empty(), options)); + } + + /** Class representing a table partition. */ + public static class SparkPartition implements Serializable { + private final Map values; + private final String uri; + private final String format; + + public SparkPartition(Map values, String uri, String format) { + this.values = Maps.newHashMap(values); + this.uri = uri; + this.format = format; + } + + public Map getValues() { + return values; + } + + public String getUri() { + return uri; + } + + public String getFormat() { + return format; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("values", values) + .add("uri", uri) + .add("format", format) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SparkPartition that = (SparkPartition) o; + return Objects.equal(values, that.values) + && Objects.equal(uri, that.uri) + && Objects.equal(format, that.format); + } + + @Override + public int hashCode() { + return Objects.hashCode(values, uri, format); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkTypeToType.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkTypeToType.java new file mode 100644 index 000000000000..17499736fbeb --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkTypeToType.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
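The duplicate-file checks above read the ENTRIES metadata table through loadMetadataTable; a small sketch of using the same helper directly. The selected columns assume the standard FILES metadata table layout.

    import org.apache.iceberg.MetadataTableType;
    import org.apache.iceberg.Table;
    import org.apache.iceberg.spark.SparkTableUtil;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;

    // Sketch: expose an Iceberg metadata table as a DataFrame.
    public class MetadataTableExample {
      public static void showDataFiles(SparkSession spark, Table table) {
        Dataset<Row> files =
            SparkTableUtil.loadMetadataTable(spark, table, MetadataTableType.FILES);
        files.select("file_path", "record_count").show(20, false);
      }
    }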
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.List; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.CharType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.sql.types.VarcharType; + +class SparkTypeToType extends SparkTypeVisitor { + private final StructType root; + private int nextId = 0; + + SparkTypeToType() { + this.root = null; + } + + SparkTypeToType(StructType root) { + this.root = root; + // the root struct's fields use the first ids + this.nextId = root.fields().length; + } + + private int getNextId() { + int next = nextId; + nextId += 1; + return next; + } + + @Override + @SuppressWarnings("ReferenceEquality") + public Type struct(StructType struct, List types) { + StructField[] fields = struct.fields(); + List newFields = Lists.newArrayListWithExpectedSize(fields.length); + boolean isRoot = root == struct; + for (int i = 0; i < fields.length; i += 1) { + StructField field = fields[i]; + Type type = types.get(i); + + int id; + if (isRoot) { + // for new conversions, use ordinals for ids in the root struct + id = i; + } else { + id = getNextId(); + } + + String doc = field.getComment().isDefined() ? 
field.getComment().get() : null; + + if (field.nullable()) { + newFields.add(Types.NestedField.optional(id, field.name(), type, doc)); + } else { + newFields.add(Types.NestedField.required(id, field.name(), type, doc)); + } + } + + return Types.StructType.of(newFields); + } + + @Override + public Type field(StructField field, Type typeResult) { + return typeResult; + } + + @Override + public Type array(ArrayType array, Type elementType) { + if (array.containsNull()) { + return Types.ListType.ofOptional(getNextId(), elementType); + } else { + return Types.ListType.ofRequired(getNextId(), elementType); + } + } + + @Override + public Type map(MapType map, Type keyType, Type valueType) { + if (map.valueContainsNull()) { + return Types.MapType.ofOptional(getNextId(), getNextId(), keyType, valueType); + } else { + return Types.MapType.ofRequired(getNextId(), getNextId(), keyType, valueType); + } + } + + @SuppressWarnings("checkstyle:CyclomaticComplexity") + @Override + public Type atomic(DataType atomic) { + if (atomic instanceof BooleanType) { + return Types.BooleanType.get(); + + } else if (atomic instanceof IntegerType + || atomic instanceof ShortType + || atomic instanceof ByteType) { + return Types.IntegerType.get(); + + } else if (atomic instanceof LongType) { + return Types.LongType.get(); + + } else if (atomic instanceof FloatType) { + return Types.FloatType.get(); + + } else if (atomic instanceof DoubleType) { + return Types.DoubleType.get(); + + } else if (atomic instanceof StringType + || atomic instanceof CharType + || atomic instanceof VarcharType) { + return Types.StringType.get(); + + } else if (atomic instanceof DateType) { + return Types.DateType.get(); + + } else if (atomic instanceof TimestampType) { + return Types.TimestampType.withZone(); + + } else if (atomic instanceof DecimalType) { + return Types.DecimalType.of( + ((DecimalType) atomic).precision(), ((DecimalType) atomic).scale()); + } else if (atomic instanceof BinaryType) { + return Types.BinaryType.get(); + } + + throw new UnsupportedOperationException("Not a supported type: " + atomic.catalogString()); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkTypeVisitor.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkTypeVisitor.java new file mode 100644 index 000000000000..1ef694263fa4 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkTypeVisitor.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
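A sketch of the Spark-to-Iceberg type conversion implemented above. Because SparkTypeVisitor and SparkTypeToType are package-private, the sketch lives in the same package; the resulting field ids follow the rule in struct() (ordinals for the root struct).

    package org.apache.iceberg.spark;

    import org.apache.iceberg.types.Type;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.Metadata;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;

    // Sketch: convert a Spark schema into an Iceberg struct type.
    class SparkTypeConversionExample {
      static Type toIcebergStruct() {
        StructType sparkSchema =
            new StructType(
                new StructField[] {
                  new StructField("id", DataTypes.LongType, false, Metadata.empty()),
                  new StructField("ts", DataTypes.TimestampType, true, Metadata.empty()),
                  new StructField("name", DataTypes.StringType, true, Metadata.empty())
                });

        // result: struct<0: id required long, 1: ts optional timestamptz, 2: name optional string>
        return SparkTypeVisitor.visit(sparkSchema, new SparkTypeToType(sparkSchema));
      }
    }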
+ */ +package org.apache.iceberg.spark; + +import java.util.List; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.UserDefinedType; + +class SparkTypeVisitor { + static T visit(DataType type, SparkTypeVisitor visitor) { + if (type instanceof StructType) { + StructField[] fields = ((StructType) type).fields(); + List fieldResults = Lists.newArrayListWithExpectedSize(fields.length); + + for (StructField field : fields) { + fieldResults.add(visitor.field(field, visit(field.dataType(), visitor))); + } + + return visitor.struct((StructType) type, fieldResults); + + } else if (type instanceof MapType) { + return visitor.map( + (MapType) type, + visit(((MapType) type).keyType(), visitor), + visit(((MapType) type).valueType(), visitor)); + + } else if (type instanceof ArrayType) { + return visitor.array((ArrayType) type, visit(((ArrayType) type).elementType(), visitor)); + + } else if (type instanceof UserDefinedType) { + throw new UnsupportedOperationException("User-defined types are not supported"); + + } else { + return visitor.atomic(type); + } + } + + public T struct(StructType struct, List fieldResults) { + return null; + } + + public T field(StructField field, T typeResult) { + return null; + } + + public T array(ArrayType array, T elementResult) { + return null; + } + + public T map(MapType map, T keyResult, T valueResult) { + return null; + } + + public T atomic(DataType atomic) { + return null; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java new file mode 100644 index 000000000000..23d0d9303e4f --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import java.sql.Date; +import java.sql.Timestamp; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionField; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.hadoop.HadoopConfigurable; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.transforms.Transform; +import org.apache.iceberg.transforms.UnknownTransform; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.Pair; +import org.apache.spark.sql.RuntimeConfig; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.expressions.BoundReference; +import org.apache.spark.sql.catalyst.expressions.EqualTo; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.SerializableConfiguration; +import org.joda.time.DateTime; + +public class SparkUtil { + + public static final String TIMESTAMP_WITHOUT_TIMEZONE_ERROR = + String.format( + "Cannot handle timestamp without" + + " timezone fields in Spark. Spark does not natively support this type but if you would like to handle all" + + " timestamps as timestamp with timezone set '%s' to true. This will not change the underlying values stored" + + " but will change their displayed values in Spark. For more information please see" + + " https://docs.databricks.com/spark/latest/dataframes-datasets/dates-timestamps.html#ansi-sql-and" + + "-spark-sql-timestamps", + SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE); + + private static final String SPARK_CATALOG_CONF_PREFIX = "spark.sql.catalog"; + // Format string used as the prefix for Spark configuration keys to override Hadoop configuration + // values for Iceberg tables from a given catalog. These keys can be specified as + // `spark.sql.catalog.$catalogName.hadoop.*`, similar to using `spark.hadoop.*` to override + // Hadoop configurations globally for a given Spark session. + private static final String SPARK_CATALOG_HADOOP_CONF_OVERRIDE_FMT_STR = + SPARK_CATALOG_CONF_PREFIX + ".%s.hadoop."; + + private static final Joiner DOT = Joiner.on("."); + + private SparkUtil() {} + + /** + * Using this to broadcast FileIO can lead to unexpected behavior, as broadcast variables that + * implement {@link AutoCloseable} will be closed by Spark during broadcast removal. As an + * alternative, use {@link org.apache.iceberg.SerializableTable}. 
+ * + * @deprecated will be removed in 1.4.0 + */ + @Deprecated + public static FileIO serializableFileIO(Table table) { + if (table.io() instanceof HadoopConfigurable) { + // we need to use Spark's SerializableConfiguration to avoid issues with Kryo serialization + ((HadoopConfigurable) table.io()) + .serializeConfWith(conf -> new SerializableConfiguration(conf)::value); + } + + return table.io(); + } + + /** + * Check whether the partition transforms in a spec can be used to write data. + * + * @param spec a PartitionSpec + * @throws UnsupportedOperationException if the spec contains unknown partition transforms + */ + public static void validatePartitionTransforms(PartitionSpec spec) { + if (spec.fields().stream().anyMatch(field -> field.transform() instanceof UnknownTransform)) { + String unsupported = + spec.fields().stream() + .map(PartitionField::transform) + .filter(transform -> transform instanceof UnknownTransform) + .map(Transform::toString) + .collect(Collectors.joining(", ")); + + throw new UnsupportedOperationException( + String.format("Cannot write using unsupported transforms: %s", unsupported)); + } + } + + /** + * A modified version of Spark's LookupCatalog.CatalogAndIdentifier.unapply Attempts to find the + * catalog and identifier a multipart identifier represents + * + * @param nameParts Multipart identifier representing a table + * @return The CatalogPlugin and Identifier for the table + */ + public static Pair catalogAndIdentifier( + List nameParts, + Function catalogProvider, + BiFunction identiferProvider, + C currentCatalog, + String[] currentNamespace) { + Preconditions.checkArgument( + !nameParts.isEmpty(), "Cannot determine catalog and identifier from empty name"); + + int lastElementIndex = nameParts.size() - 1; + String name = nameParts.get(lastElementIndex); + + if (nameParts.size() == 1) { + // Only a single element, use current catalog and namespace + return Pair.of(currentCatalog, identiferProvider.apply(currentNamespace, name)); + } else { + C catalog = catalogProvider.apply(nameParts.get(0)); + if (catalog == null) { + // The first element was not a valid catalog, treat it like part of the namespace + String[] namespace = nameParts.subList(0, lastElementIndex).toArray(new String[0]); + return Pair.of(currentCatalog, identiferProvider.apply(namespace, name)); + } else { + // Assume the first element is a valid catalog + String[] namespace = nameParts.subList(1, lastElementIndex).toArray(new String[0]); + return Pair.of(catalog, identiferProvider.apply(namespace, name)); + } + } + } + + /** + * Responsible for checking if the table schema has a timestamp without timezone column + * + * @param schema table schema to check if it contains a timestamp without timezone column + * @return boolean indicating if the schema passed in has a timestamp field without a timezone + */ + public static boolean hasTimestampWithoutZone(Schema schema) { + return TypeUtil.find(schema, t -> Types.TimestampType.withoutZone().equals(t)) != null; + } + + /** + * Checks whether timestamp types for new tables should be stored with timezone info. + * + *
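A toy sketch of how the catalogAndIdentifier helper above resolves multipart names, using plain strings in place of the generic catalog and identifier types; the catalog and table names are hypothetical.

    import java.util.Arrays;
    import java.util.List;
    import org.apache.iceberg.spark.SparkUtil;
    import org.apache.iceberg.util.Pair;

    // Sketch: only "my_catalog" is known to the toy catalog provider.
    public class CatalogAndIdentifierExample {
      public static void main(String[] args) {
        List<String> withCatalog = Arrays.asList("my_catalog", "db", "tbl");
        List<String> withoutCatalog = Arrays.asList("db", "tbl");

        // first part is a valid catalog -> ("my_catalog", "db.tbl")
        Pair<String, String> resolved = SparkUtil.catalogAndIdentifier(
            withCatalog,
            name -> "my_catalog".equals(name) ? name : null,
            (namespace, name) -> String.join(".", namespace) + "." + name,
            "session_catalog",
            new String[] {"default"});

        // first part is not a catalog -> falls back to ("session_catalog", "db.tbl")
        Pair<String, String> fallback = SparkUtil.catalogAndIdentifier(
            withoutCatalog,
            name -> "my_catalog".equals(name) ? name : null,
            (namespace, name) -> String.join(".", namespace) + "." + name,
            "session_catalog",
            new String[] {"default"});

        System.out.println(resolved.first() + " " + fallback.first());
      }
    }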
The default value is false and all timestamp fields are stored as {@link + * Types.TimestampType#withZone()}. If enabled, all timestamp fields in new tables will be stored + * as {@link Types.TimestampType#withoutZone()}. + * + * @param sessionConf a Spark runtime config + * @return true if timestamp types for new tables should be stored with timezone info + */ + public static boolean useTimestampWithoutZoneInNewTables(RuntimeConfig sessionConf) { + String sessionConfValue = + sessionConf.get(SparkSQLProperties.USE_TIMESTAMP_WITHOUT_TIME_ZONE_IN_NEW_TABLES, null); + if (sessionConfValue != null) { + return Boolean.parseBoolean(sessionConfValue); + } + return SparkSQLProperties.USE_TIMESTAMP_WITHOUT_TIME_ZONE_IN_NEW_TABLES_DEFAULT; + } + + /** + * Pulls any Catalog specific overrides for the Hadoop conf from the current SparkSession, which + * can be set via `spark.sql.catalog.$catalogName.hadoop.*` + * + *
Mirrors the override of hadoop configurations for a given spark session using + * `spark.hadoop.*`. + * + *
The SparkCatalog allows for hadoop configurations to be overridden per catalog, by setting + * them on the SQLConf, where the following will add the property "fs.default.name" with value + * "hdfs://hanksnamenode:8020" to the catalog's hadoop configuration. SparkSession.builder() + * .config(s"spark.sql.catalog.$catalogName.hadoop.fs.default.name", "hdfs://hanksnamenode:8020") + * .getOrCreate() + * + * @param spark The current Spark session + * @param catalogName Name of the catalog to find overrides for. + * @return the Hadoop Configuration that should be used for this catalog, with catalog specific + * overrides applied. + */ + public static Configuration hadoopConfCatalogOverrides(SparkSession spark, String catalogName) { + // Find keys for the catalog intended to be hadoop configurations + final String hadoopConfCatalogPrefix = hadoopConfPrefixForCatalog(catalogName); + final Configuration conf = spark.sessionState().newHadoopConf(); + spark + .sqlContext() + .conf() + .settings() + .forEach( + (k, v) -> { + // these checks are copied from `spark.sessionState().newHadoopConfWithOptions()` + // to avoid converting back and forth between Scala / Java map types + if (v != null && k != null && k.startsWith(hadoopConfCatalogPrefix)) { + conf.set(k.substring(hadoopConfCatalogPrefix.length()), v); + } + }); + return conf; + } + + private static String hadoopConfPrefixForCatalog(String catalogName) { + return String.format(SPARK_CATALOG_HADOOP_CONF_OVERRIDE_FMT_STR, catalogName); + } + + /** + * Get a List of Spark filter Expression. + * + * @param schema table schema + * @param filters filters in the format of a Map, where key is one of the table column name, and + * value is the specific value to be filtered on the column. + * @return a List of filters in the format of Spark Expression. 
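A minimal sketch of the per-catalog Hadoop override mechanism described above; the catalog name and the fs.s3a.endpoint value are illustrative.

    import org.apache.hadoop.conf.Configuration;
    import org.apache.iceberg.spark.SparkUtil;
    import org.apache.spark.sql.SparkSession;

    // Sketch: any key under spark.sql.catalog.<name>.hadoop. is copied into the
    // catalog's Hadoop configuration with the prefix stripped.
    public class CatalogHadoopOverrideExample {
      public static void main(String[] args) {
        SparkSession spark =
            SparkSession.builder()
                .appName("hadoop-override-demo")
                .master("local[1]")
                .config(
                    "spark.sql.catalog.my_catalog.hadoop.fs.s3a.endpoint",
                    "http://localhost:9000")
                .getOrCreate();

        Configuration conf = SparkUtil.hadoopConfCatalogOverrides(spark, "my_catalog");
        System.out.println(conf.get("fs.s3a.endpoint")); // http://localhost:9000
      }
    }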
+ */ + public static List partitionMapToExpression( + StructType schema, Map filters) { + List filterExpressions = Lists.newArrayList(); + for (Map.Entry entry : filters.entrySet()) { + try { + int index = schema.fieldIndex(entry.getKey()); + DataType dataType = schema.fields()[index].dataType(); + BoundReference ref = new BoundReference(index, dataType, true); + switch (dataType.typeName()) { + case "integer": + filterExpressions.add( + new EqualTo( + ref, + Literal.create(Integer.parseInt(entry.getValue()), DataTypes.IntegerType))); + break; + case "string": + filterExpressions.add( + new EqualTo(ref, Literal.create(entry.getValue(), DataTypes.StringType))); + break; + case "short": + filterExpressions.add( + new EqualTo( + ref, Literal.create(Short.parseShort(entry.getValue()), DataTypes.ShortType))); + break; + case "long": + filterExpressions.add( + new EqualTo( + ref, Literal.create(Long.parseLong(entry.getValue()), DataTypes.LongType))); + break; + case "float": + filterExpressions.add( + new EqualTo( + ref, Literal.create(Float.parseFloat(entry.getValue()), DataTypes.FloatType))); + break; + case "double": + filterExpressions.add( + new EqualTo( + ref, + Literal.create(Double.parseDouble(entry.getValue()), DataTypes.DoubleType))); + break; + case "date": + filterExpressions.add( + new EqualTo( + ref, + Literal.create( + new Date(DateTime.parse(entry.getValue()).getMillis()), + DataTypes.DateType))); + break; + case "timestamp": + filterExpressions.add( + new EqualTo( + ref, + Literal.create( + new Timestamp(DateTime.parse(entry.getValue()).getMillis()), + DataTypes.TimestampType))); + break; + default: + throw new IllegalStateException( + "Unexpected data type in partition filters: " + dataType); + } + } catch (IllegalArgumentException e) { + // ignore if filter is not on table columns + } + } + + return filterExpressions; + } + + public static String toColumnName(NamedReference ref) { + return DOT.join(ref.fieldNames()); + } + + public static boolean caseSensitive(SparkSession spark) { + return Boolean.parseBoolean(spark.conf().get("spark.sql.caseSensitive")); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java new file mode 100644 index 000000000000..6d564bbd623b --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
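A usage sketch for partitionMapToExpression above; the schema and filter values are illustrative. All filter values arrive as strings and are parsed according to the column's data type.

    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import org.apache.iceberg.spark.SparkUtil;
    import org.apache.spark.sql.catalyst.expressions.Expression;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.Metadata;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;

    // Sketch: turn a partition-column -> value map into catalyst EqualTo expressions.
    public class PartitionFilterExpressionExample {
      public static void main(String[] args) {
        StructType schema =
            new StructType(
                new StructField[] {
                  new StructField("dt", DataTypes.StringType, true, Metadata.empty()),
                  new StructField("hour", DataTypes.IntegerType, true, Metadata.empty())
                });

        Map<String, String> filter = new HashMap<>();
        filter.put("dt", "2023-01-01");
        filter.put("hour", "3");

        // two EqualTo expressions: dt = '2023-01-01' (string) and hour = 3 (integer)
        List<Expression> expressions = SparkUtil.partitionMapToExpression(schema, filter);
        expressions.forEach(System.out::println);
      }
    }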
+ */ +package org.apache.iceberg.spark; + +import static org.apache.iceberg.expressions.Expressions.and; +import static org.apache.iceberg.expressions.Expressions.equal; +import static org.apache.iceberg.expressions.Expressions.greaterThan; +import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual; +import static org.apache.iceberg.expressions.Expressions.in; +import static org.apache.iceberg.expressions.Expressions.isNaN; +import static org.apache.iceberg.expressions.Expressions.isNull; +import static org.apache.iceberg.expressions.Expressions.lessThan; +import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual; +import static org.apache.iceberg.expressions.Expressions.not; +import static org.apache.iceberg.expressions.Expressions.notIn; +import static org.apache.iceberg.expressions.Expressions.notNull; +import static org.apache.iceberg.expressions.Expressions.or; +import static org.apache.iceberg.expressions.Expressions.startsWith; + +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expression.Operation; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.util.NaNUtil; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.filter.And; +import org.apache.spark.sql.connector.expressions.filter.Not; +import org.apache.spark.sql.connector.expressions.filter.Or; +import org.apache.spark.sql.connector.expressions.filter.Predicate; +import org.apache.spark.unsafe.types.UTF8String; + +public class SparkV2Filters { + + private static final String TRUE = "ALWAYS_TRUE"; + private static final String FALSE = "ALWAYS_FALSE"; + private static final String EQ = "="; + private static final String EQ_NULL_SAFE = "<=>"; + private static final String GT = ">"; + private static final String GT_EQ = ">="; + private static final String LT = "<"; + private static final String LT_EQ = "<="; + private static final String IN = "IN"; + private static final String IS_NULL = "IS_NULL"; + private static final String NOT_NULL = "IS_NOT_NULL"; + private static final String AND = "AND"; + private static final String OR = "OR"; + private static final String NOT = "NOT"; + private static final String STARTS_WITH = "STARTS_WITH"; + + private static final Map FILTERS = + ImmutableMap.builder() + .put(TRUE, Operation.TRUE) + .put(FALSE, Operation.FALSE) + .put(EQ, Operation.EQ) + .put(EQ_NULL_SAFE, Operation.EQ) + .put(GT, Operation.GT) + .put(GT_EQ, Operation.GT_EQ) + .put(LT, Operation.LT) + .put(LT_EQ, Operation.LT_EQ) + .put(IN, Operation.IN) + .put(IS_NULL, Operation.IS_NULL) + .put(NOT_NULL, Operation.NOT_NULL) + .put(AND, Operation.AND) + .put(OR, Operation.OR) + .put(NOT, Operation.NOT) + .put(STARTS_WITH, Operation.STARTS_WITH) + .buildOrThrow(); + + private SparkV2Filters() {} + + @SuppressWarnings({"checkstyle:CyclomaticComplexity", "checkstyle:MethodLength"}) + public static Expression convert(Predicate predicate) { + Operation op = FILTERS.get(predicate.name()); + if (op != null) { + switch (op) { + case TRUE: + return Expressions.alwaysTrue(); + + case FALSE: + return Expressions.alwaysFalse(); + + case IS_NULL: + return 
isRef(child(predicate)) ? isNull(SparkUtil.toColumnName(child(predicate))) : null; + + case NOT_NULL: + return isRef(child(predicate)) ? notNull(SparkUtil.toColumnName(child(predicate))) : null; + + case LT: + if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) { + String columnName = SparkUtil.toColumnName(leftChild(predicate)); + return lessThan(columnName, convertLiteral(rightChild(predicate))); + } else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) { + String columnName = SparkUtil.toColumnName(rightChild(predicate)); + return greaterThan(columnName, convertLiteral(leftChild(predicate))); + } else { + return null; + } + + case LT_EQ: + if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) { + String columnName = SparkUtil.toColumnName(leftChild(predicate)); + return lessThanOrEqual(columnName, convertLiteral(rightChild(predicate))); + } else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) { + String columnName = SparkUtil.toColumnName(rightChild(predicate)); + return greaterThanOrEqual(columnName, convertLiteral(leftChild(predicate))); + } else { + return null; + } + + case GT: + if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) { + String columnName = SparkUtil.toColumnName(leftChild(predicate)); + return greaterThan(columnName, convertLiteral(rightChild(predicate))); + } else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) { + String columnName = SparkUtil.toColumnName(rightChild(predicate)); + return lessThan(columnName, convertLiteral(leftChild(predicate))); + } else { + return null; + } + + case GT_EQ: + if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) { + String columnName = SparkUtil.toColumnName(leftChild(predicate)); + return greaterThanOrEqual(columnName, convertLiteral(rightChild(predicate))); + } else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) { + String columnName = SparkUtil.toColumnName(rightChild(predicate)); + return lessThanOrEqual(columnName, convertLiteral(leftChild(predicate))); + } else { + return null; + } + + case EQ: // used for both eq and null-safe-eq + Object value; + String columnName; + if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) { + columnName = SparkUtil.toColumnName(leftChild(predicate)); + value = convertLiteral(rightChild(predicate)); + } else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) { + columnName = SparkUtil.toColumnName(rightChild(predicate)); + value = convertLiteral(leftChild(predicate)); + } else { + return null; + } + + if (predicate.name().equals(EQ)) { + // comparison with null in normal equality is always null. this is probably a mistake. 
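+                // under SQL three-valued logic `col = NULL` evaluates to NULL for every row and
+                // can never match, so a null literal here is rejected below rather than converted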
+ Preconditions.checkNotNull( + value, "Expression is always false (eq is not null-safe): %s", predicate); + return handleEqual(columnName, value); + } else { // "<=>" + if (value == null) { + return isNull(columnName); + } else { + return handleEqual(columnName, value); + } + } + + case IN: + if (isSupportedInPredicate(predicate)) { + return in( + SparkUtil.toColumnName(childAtIndex(predicate, 0)), + Arrays.stream(predicate.children()) + .skip(1) + .map(val -> convertLiteral(((Literal) val))) + .filter(Objects::nonNull) + .collect(Collectors.toList())); + } else { + return null; + } + + case NOT: + Not notPredicate = (Not) predicate; + Predicate childPredicate = notPredicate.child(); + if (childPredicate.name().equals(IN) && isSupportedInPredicate(childPredicate)) { + // infer an extra notNull predicate for Spark NOT IN filters + // as Iceberg expressions don't follow the 3-value SQL boolean logic + // col NOT IN (1, 2) in Spark is equal to notNull(col) && notIn(col, 1, 2) in Iceberg + Expression notIn = + notIn( + SparkUtil.toColumnName(childAtIndex(childPredicate, 0)), + Arrays.stream(childPredicate.children()) + .skip(1) + .map(val -> convertLiteral(((Literal) val))) + .filter(Objects::nonNull) + .collect(Collectors.toList())); + return and(notNull(SparkUtil.toColumnName(childAtIndex(childPredicate, 0))), notIn); + } else if (hasNoInFilter(childPredicate)) { + Expression child = convert(childPredicate); + if (child != null) { + return not(child); + } + } + return null; + + case AND: + { + And andPredicate = (And) predicate; + Expression left = convert(andPredicate.left()); + Expression right = convert(andPredicate.right()); + if (left != null && right != null) { + return and(left, right); + } + return null; + } + + case OR: + { + Or orPredicate = (Or) predicate; + Expression left = convert(orPredicate.left()); + Expression right = convert(orPredicate.right()); + if (left != null && right != null) { + return or(left, right); + } + return null; + } + + case STARTS_WITH: + String colName = SparkUtil.toColumnName(leftChild(predicate)); + return startsWith(colName, convertLiteral(rightChild(predicate)).toString()); + } + } + + return null; + } + + @SuppressWarnings("unchecked") + private static T child(Predicate predicate) { + org.apache.spark.sql.connector.expressions.Expression[] children = predicate.children(); + Preconditions.checkArgument( + children.length == 1, "Predicate should have one child: %s", predicate); + return (T) children[0]; + } + + @SuppressWarnings("unchecked") + private static T leftChild(Predicate predicate) { + org.apache.spark.sql.connector.expressions.Expression[] children = predicate.children(); + Preconditions.checkArgument( + children.length == 2, "Predicate should have two children: %s", predicate); + return (T) children[0]; + } + + @SuppressWarnings("unchecked") + private static T rightChild(Predicate predicate) { + org.apache.spark.sql.connector.expressions.Expression[] children = predicate.children(); + Preconditions.checkArgument( + children.length == 2, "Predicate should have two children: %s", predicate); + return (T) children[1]; + } + + @SuppressWarnings("unchecked") + private static T childAtIndex(Predicate predicate, int index) { + return (T) predicate.children()[index]; + } + + private static boolean isRef(org.apache.spark.sql.connector.expressions.Expression expr) { + return expr instanceof NamedReference; + } + + private static boolean isLiteral(org.apache.spark.sql.connector.expressions.Expression expr) { + return expr instanceof Literal; + } + + 
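+  // Illustrative note (a sketch, not part of the original sources): a Spark DSv2 predicate such as
+  //   new Predicate("=", new org.apache.spark.sql.connector.expressions.Expression[] {
+  //       FieldReference.apply("id"), LiteralValue.apply(1, DataTypes.IntegerType)})
+  // converts to Expressions.equal("id", 1); predicate shapes that are not a reference/literal pair
+  // (or a supported IN/NOT/AND/OR) make convert(...) return null so Spark evaluates the filter itself.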
private static Object convertLiteral(Literal literal) { + if (literal.value() instanceof UTF8String) { + return ((UTF8String) literal.value()).toString(); + } + return literal.value(); + } + + private static Expression handleEqual(String attribute, Object value) { + if (NaNUtil.isNaN(value)) { + return isNaN(attribute); + } else { + return equal(attribute, value); + } + } + + private static boolean hasNoInFilter(Predicate predicate) { + Operation op = FILTERS.get(predicate.name()); + + if (op != null) { + switch (op) { + case AND: + And andPredicate = (And) predicate; + return hasNoInFilter(andPredicate.left()) && hasNoInFilter(andPredicate.right()); + case OR: + Or orPredicate = (Or) predicate; + return hasNoInFilter(orPredicate.left()) && hasNoInFilter(orPredicate.right()); + case NOT: + Not notPredicate = (Not) predicate; + return hasNoInFilter(notPredicate.child()); + case IN: + return false; + default: + return true; + } + } + + return false; + } + + private static boolean isSupportedInPredicate(Predicate predicate) { + if (!isRef(childAtIndex(predicate, 0))) { + return false; + } else { + return Arrays.stream(predicate.children()).skip(1).allMatch(SparkV2Filters::isLiteral); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java new file mode 100644 index 000000000000..687d9f43ade8 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; + +/** A utility class that converts Spark values to Iceberg's internal representation. 
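+ *
+ * <p>For example (illustrative only): {@code SparkValueConverter.convert(table.schema(), row)}
+ * turns a Spark SQL {@code Row} into an Iceberg {@code GenericRecord}, mapping DATE/TIMESTAMP
+ * values to days/microseconds, byte arrays to {@code ByteBuffer}s, and nested structs, lists and
+ * maps recursively.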
*/ +public class SparkValueConverter { + + private SparkValueConverter() {} + + public static Record convert(Schema schema, Row row) { + return convert(schema.asStruct(), row); + } + + public static Object convert(Type type, Object object) { + if (object == null) { + return null; + } + + switch (type.typeId()) { + case STRUCT: + return convert(type.asStructType(), (Row) object); + + case LIST: + List convertedList = Lists.newArrayList(); + List list = (List) object; + for (Object element : list) { + convertedList.add(convert(type.asListType().elementType(), element)); + } + return convertedList; + + case MAP: + Map convertedMap = Maps.newLinkedHashMap(); + Map map = (Map) object; + for (Map.Entry entry : map.entrySet()) { + convertedMap.put( + convert(type.asMapType().keyType(), entry.getKey()), + convert(type.asMapType().valueType(), entry.getValue())); + } + return convertedMap; + + case DATE: + // if spark.sql.datetime.java8API.enabled is set to true, java.time.LocalDate + // for Spark SQL DATE type otherwise java.sql.Date is returned. + return DateTimeUtils.anyToDays(object); + case TIMESTAMP: + // if spark.sql.datetime.java8API.enabled is set to true, java.time.Instant + // for Spark SQL TIMESTAMP type is returned otherwise java.sql.Timestamp is returned. + return DateTimeUtils.anyToMicros(object); + case BINARY: + return ByteBuffer.wrap((byte[]) object); + case INTEGER: + return ((Number) object).intValue(); + case BOOLEAN: + case LONG: + case FLOAT: + case DOUBLE: + case DECIMAL: + case STRING: + case FIXED: + return object; + default: + throw new UnsupportedOperationException("Not a supported type: " + type); + } + } + + private static Record convert(Types.StructType struct, Row row) { + if (row == null) { + return null; + } + + Record record = GenericRecord.create(struct); + List fields = struct.fields(); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + + Type fieldType = field.type(); + + switch (fieldType.typeId()) { + case STRUCT: + record.set(i, convert(fieldType.asStructType(), row.getStruct(i))); + break; + case LIST: + record.set(i, convert(fieldType.asListType(), row.getList(i))); + break; + case MAP: + record.set(i, convert(fieldType.asMapType(), row.getJavaMap(i))); + break; + default: + record.set(i, convert(fieldType, row.get(i))); + } + } + return record; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java new file mode 100644 index 000000000000..41777c515582 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java @@ -0,0 +1,376 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import static org.apache.iceberg.DistributionMode.HASH; +import static org.apache.iceberg.DistributionMode.NONE; +import static org.apache.iceberg.DistributionMode.RANGE; + +import java.util.Locale; +import java.util.Map; +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.spark.sql.RuntimeConfig; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.SQLConf; + +/** + * A class for common Iceberg configs for Spark writes. + * + *

<p>If a config is set at multiple levels, the following order of precedence is used (top to
+ * bottom):
+ *
+ * <ol>
+ *   <li>Write options
+ *   <li>Session configuration
+ *   <li>Table metadata
+ * </ol>
+ *
+ * The most specific value is set in write options and takes precedence over all other configs. If
+ * no write option is provided, this class checks the session configuration for any overrides. If
+ * no applicable value is found in the session configuration, this class uses the table metadata.
+ *
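+ * <p>For example (illustrative): if a table sets {@code write.format.default=avro} but the write
+ * supplies {@code .option("write-format", "orc")}, data files are written as ORC because the
+ * write option is the most specific level.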

Note this class is NOT meant to be serialized and sent to executors. + */ +public class SparkWriteConf { + + private final Table table; + private final String branch; + private final RuntimeConfig sessionConf; + private final Map writeOptions; + private final SparkConfParser confParser; + + public SparkWriteConf(SparkSession spark, Table table, Map writeOptions) { + this(spark, table, null, writeOptions); + } + + public SparkWriteConf( + SparkSession spark, Table table, String branch, Map writeOptions) { + this.table = table; + this.branch = branch; + this.sessionConf = spark.conf(); + this.writeOptions = writeOptions; + this.confParser = new SparkConfParser(spark, table, writeOptions); + } + + public boolean checkNullability() { + return confParser + .booleanConf() + .option(SparkWriteOptions.CHECK_NULLABILITY) + .sessionConf(SparkSQLProperties.CHECK_NULLABILITY) + .defaultValue(SparkSQLProperties.CHECK_NULLABILITY_DEFAULT) + .parse(); + } + + public boolean checkOrdering() { + return confParser + .booleanConf() + .option(SparkWriteOptions.CHECK_ORDERING) + .sessionConf(SparkSQLProperties.CHECK_ORDERING) + .defaultValue(SparkSQLProperties.CHECK_ORDERING_DEFAULT) + .parse(); + } + + /** + * Enables writing a timestamp with time zone as a timestamp without time zone. + * + *

<p>Generally, this is not safe: a timestamp without time zone is supposed to represent
+ * wall-clock time, so no matter the reader/writer time zone, 3PM should always be read as 3PM,
+ * whereas a timestamp with time zone has instant semantics, i.e. the value is adjusted so that
+ * the corresponding time in the reader's time zone is displayed.
+ *
+ *
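+ * <p>Illustrative example: writers that accept this risk can enable it per write with
+ * {@code .option("handle-timestamp-without-timezone", "true")} or for the whole session through
+ * {@code SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE}.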

When set to false (default), an exception must be thrown if the table contains a timestamp + * without time zone. + * + * @return boolean indicating if writing timestamps without timezone is allowed + */ + public boolean handleTimestampWithoutZone() { + return confParser + .booleanConf() + .option(SparkWriteOptions.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE) + .sessionConf(SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE) + .defaultValue(SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE_DEFAULT) + .parse(); + } + + public String overwriteMode() { + String overwriteMode = writeOptions.get(SparkWriteOptions.OVERWRITE_MODE); + return overwriteMode != null ? overwriteMode.toLowerCase(Locale.ROOT) : null; + } + + public boolean wapEnabled() { + return confParser + .booleanConf() + .tableProperty(TableProperties.WRITE_AUDIT_PUBLISH_ENABLED) + .defaultValue(TableProperties.WRITE_AUDIT_PUBLISH_ENABLED_DEFAULT) + .parse(); + } + + public String wapId() { + return sessionConf.get(SparkSQLProperties.WAP_ID, null); + } + + public boolean mergeSchema() { + return confParser + .booleanConf() + .option(SparkWriteOptions.MERGE_SCHEMA) + .option(SparkWriteOptions.SPARK_MERGE_SCHEMA) + .defaultValue(SparkWriteOptions.MERGE_SCHEMA_DEFAULT) + .parse(); + } + + public int outputSpecId() { + int outputSpecId = + confParser + .intConf() + .option(SparkWriteOptions.OUTPUT_SPEC_ID) + .defaultValue(table.spec().specId()) + .parse(); + Preconditions.checkArgument( + table.specs().containsKey(outputSpecId), + "Output spec id %s is not a valid spec id for table", + outputSpecId); + return outputSpecId; + } + + public FileFormat dataFileFormat() { + String valueAsString = + confParser + .stringConf() + .option(SparkWriteOptions.WRITE_FORMAT) + .tableProperty(TableProperties.DEFAULT_FILE_FORMAT) + .defaultValue(TableProperties.DEFAULT_FILE_FORMAT_DEFAULT) + .parse(); + return FileFormat.fromString(valueAsString); + } + + public long targetDataFileSize() { + return confParser + .longConf() + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES) + .tableProperty(TableProperties.WRITE_TARGET_FILE_SIZE_BYTES) + .defaultValue(TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT) + .parse(); + } + + public boolean fanoutWriterEnabled() { + return confParser + .booleanConf() + .option(SparkWriteOptions.FANOUT_ENABLED) + .tableProperty(TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED) + .defaultValue(TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED_DEFAULT) + .parse(); + } + + public FileFormat deleteFileFormat() { + String valueAsString = + confParser + .stringConf() + .option(SparkWriteOptions.DELETE_FORMAT) + .tableProperty(TableProperties.DELETE_DEFAULT_FILE_FORMAT) + .parseOptional(); + return valueAsString != null ? 
FileFormat.fromString(valueAsString) : dataFileFormat(); + } + + public long targetDeleteFileSize() { + return confParser + .longConf() + .option(SparkWriteOptions.TARGET_DELETE_FILE_SIZE_BYTES) + .tableProperty(TableProperties.DELETE_TARGET_FILE_SIZE_BYTES) + .defaultValue(TableProperties.DELETE_TARGET_FILE_SIZE_BYTES_DEFAULT) + .parse(); + } + + public Map extraSnapshotMetadata() { + Map extraSnapshotMetadata = Maps.newHashMap(); + + writeOptions.forEach( + (key, value) -> { + if (key.startsWith(SnapshotSummary.EXTRA_METADATA_PREFIX)) { + extraSnapshotMetadata.put( + key.substring(SnapshotSummary.EXTRA_METADATA_PREFIX.length()), value); + } + }); + + return extraSnapshotMetadata; + } + + public String rewrittenFileSetId() { + return confParser + .stringConf() + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID) + .parseOptional(); + } + + public DistributionMode distributionMode() { + String modeName = + confParser + .stringConf() + .option(SparkWriteOptions.DISTRIBUTION_MODE) + .sessionConf(SparkSQLProperties.DISTRIBUTION_MODE) + .tableProperty(TableProperties.WRITE_DISTRIBUTION_MODE) + .parseOptional(); + + if (modeName != null) { + DistributionMode mode = DistributionMode.fromName(modeName); + return adjustWriteDistributionMode(mode); + } else { + return defaultWriteDistributionMode(); + } + } + + private DistributionMode adjustWriteDistributionMode(DistributionMode mode) { + if (mode == RANGE && table.spec().isUnpartitioned() && table.sortOrder().isUnsorted()) { + return NONE; + } else if (mode == HASH && table.spec().isUnpartitioned()) { + return NONE; + } else { + return mode; + } + } + + private DistributionMode defaultWriteDistributionMode() { + if (table.sortOrder().isSorted()) { + return RANGE; + } else if (table.spec().isPartitioned()) { + return HASH; + } else { + return NONE; + } + } + + public DistributionMode deleteDistributionMode() { + String deleteModeName = + confParser + .stringConf() + .option(SparkWriteOptions.DISTRIBUTION_MODE) + .sessionConf(SparkSQLProperties.DISTRIBUTION_MODE) + .tableProperty(TableProperties.DELETE_DISTRIBUTION_MODE) + .defaultValue(TableProperties.WRITE_DISTRIBUTION_MODE_HASH) + .parse(); + return DistributionMode.fromName(deleteModeName); + } + + public DistributionMode updateDistributionMode() { + String updateModeName = + confParser + .stringConf() + .option(SparkWriteOptions.DISTRIBUTION_MODE) + .sessionConf(SparkSQLProperties.DISTRIBUTION_MODE) + .tableProperty(TableProperties.UPDATE_DISTRIBUTION_MODE) + .defaultValue(TableProperties.WRITE_DISTRIBUTION_MODE_HASH) + .parse(); + return DistributionMode.fromName(updateModeName); + } + + public DistributionMode copyOnWriteMergeDistributionMode() { + String mergeModeName = + confParser + .stringConf() + .option(SparkWriteOptions.DISTRIBUTION_MODE) + .sessionConf(SparkSQLProperties.DISTRIBUTION_MODE) + .tableProperty(TableProperties.MERGE_DISTRIBUTION_MODE) + .parseOptional(); + + if (mergeModeName != null) { + DistributionMode mergeMode = DistributionMode.fromName(mergeModeName); + return adjustWriteDistributionMode(mergeMode); + + } else if (table.spec().isPartitioned()) { + return HASH; + + } else { + return distributionMode(); + } + } + + public DistributionMode positionDeltaMergeDistributionMode() { + String mergeModeName = + confParser + .stringConf() + .option(SparkWriteOptions.DISTRIBUTION_MODE) + .sessionConf(SparkSQLProperties.DISTRIBUTION_MODE) + .tableProperty(TableProperties.MERGE_DISTRIBUTION_MODE) + .defaultValue(TableProperties.WRITE_DISTRIBUTION_MODE_HASH) + .parse(); + 
return DistributionMode.fromName(mergeModeName); + } + + public boolean useTableDistributionAndOrdering() { + return confParser + .booleanConf() + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING) + .defaultValue(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING_DEFAULT) + .parse(); + } + + public Long validateFromSnapshotId() { + return confParser + .longConf() + .option(SparkWriteOptions.VALIDATE_FROM_SNAPSHOT_ID) + .parseOptional(); + } + + public IsolationLevel isolationLevel() { + String isolationLevelName = + confParser.stringConf().option(SparkWriteOptions.ISOLATION_LEVEL).parseOptional(); + return isolationLevelName != null ? IsolationLevel.fromName(isolationLevelName) : null; + } + + public boolean caseSensitive() { + return confParser + .booleanConf() + .sessionConf(SQLConf.CASE_SENSITIVE().key()) + .defaultValue(SQLConf.CASE_SENSITIVE().defaultValueString()) + .parse(); + } + + public String branch() { + if (wapEnabled()) { + String wapId = wapId(); + String wapBranch = + confParser.stringConf().sessionConf(SparkSQLProperties.WAP_BRANCH).parseOptional(); + + ValidationException.check( + wapId == null || wapBranch == null, + "Cannot set both WAP ID and branch, but got ID [%s] and branch [%s]", + wapId, + wapBranch); + + if (wapBranch != null) { + ValidationException.check( + branch == null, + "Cannot write to both branch and WAP branch, but got branch [%s] and WAP branch [%s]", + branch, + wapBranch); + + return wapBranch; + } + } + + return branch; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java new file mode 100644 index 000000000000..c4eacb7b98a4 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +/** Spark DF write options */ +public class SparkWriteOptions { + + private SparkWriteOptions() {} + + // Fileformat for write operations(default: Table write.format.default ) + public static final String WRITE_FORMAT = "write-format"; + + // Overrides this table's write.target-file-size-bytes + public static final String TARGET_FILE_SIZE_BYTES = "target-file-size-bytes"; + + // Overrides the default file format for delete files + public static final String DELETE_FORMAT = "delete-format"; + + // Overrides the default size for delete files + public static final String TARGET_DELETE_FILE_SIZE_BYTES = "target-delete-file-size-bytes"; + + // Sets the nullable check on fields(default: true) + public static final String CHECK_NULLABILITY = "check-nullability"; + + // Adds an entry with custom-key and corresponding value in the snapshot summary + // ex: df.write().format(iceberg) + // .option(SparkWriteOptions.SNAPSHOT_PROPERTY_PREFIX."key1", "value1") + // .save(location) + public static final String SNAPSHOT_PROPERTY_PREFIX = "snapshot-property"; + + // Overrides table property write.spark.fanout.enabled(default: false) + public static final String FANOUT_ENABLED = "fanout-enabled"; + + // Checks if input schema and table schema are same(default: true) + public static final String CHECK_ORDERING = "check-ordering"; + + // File scan task set ID that indicates which files must be replaced + public static final String REWRITTEN_FILE_SCAN_TASK_SET_ID = "rewritten-file-scan-task-set-id"; + + // Controls whether to allow writing timestamps without zone info + public static final String HANDLE_TIMESTAMP_WITHOUT_TIMEZONE = + "handle-timestamp-without-timezone"; + + public static final String OUTPUT_SPEC_ID = "output-spec-id"; + + public static final String OVERWRITE_MODE = "overwrite-mode"; + + // Overrides the default distribution mode for a write operation + public static final String DISTRIBUTION_MODE = "distribution-mode"; + + // Controls whether to take into account the table distribution and sort order during a write + // operation + public static final String USE_TABLE_DISTRIBUTION_AND_ORDERING = + "use-table-distribution-and-ordering"; + public static final boolean USE_TABLE_DISTRIBUTION_AND_ORDERING_DEFAULT = true; + + public static final String MERGE_SCHEMA = "merge-schema"; + public static final String SPARK_MERGE_SCHEMA = "mergeSchema"; + public static final boolean MERGE_SCHEMA_DEFAULT = false; + + // Identifies snapshot from which to start validating conflicting changes + public static final String VALIDATE_FROM_SNAPSHOT_ID = "validate-from-snapshot-id"; + + // Isolation Level for DataFrame calls. Currently supported by overwritePartitions + public static final String ISOLATION_LEVEL = "isolation-level"; +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/TypeToSparkType.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/TypeToSparkType.java new file mode 100644 index 000000000000..4d4ec6782c72 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/TypeToSparkType.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.List; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.types.ArrayType$; +import org.apache.spark.sql.types.BinaryType$; +import org.apache.spark.sql.types.BooleanType$; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType$; +import org.apache.spark.sql.types.DecimalType$; +import org.apache.spark.sql.types.DoubleType$; +import org.apache.spark.sql.types.FloatType$; +import org.apache.spark.sql.types.IntegerType$; +import org.apache.spark.sql.types.LongType$; +import org.apache.spark.sql.types.MapType$; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.MetadataBuilder; +import org.apache.spark.sql.types.StringType$; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType$; +import org.apache.spark.sql.types.TimestampType$; + +class TypeToSparkType extends TypeUtil.SchemaVisitor { + TypeToSparkType() {} + + public static final String METADATA_COL_ATTR_KEY = "__metadata_col"; + + @Override + public DataType schema(Schema schema, DataType structType) { + return structType; + } + + @Override + public DataType struct(Types.StructType struct, List fieldResults) { + List fields = struct.fields(); + + List sparkFields = Lists.newArrayListWithExpectedSize(fieldResults.size()); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + DataType type = fieldResults.get(i); + Metadata metadata = fieldMetadata(field.fieldId()); + StructField sparkField = StructField.apply(field.name(), type, field.isOptional(), metadata); + if (field.doc() != null) { + sparkField = sparkField.withComment(field.doc()); + } + sparkFields.add(sparkField); + } + + return StructType$.MODULE$.apply(sparkFields); + } + + @Override + public DataType field(Types.NestedField field, DataType fieldResult) { + return fieldResult; + } + + @Override + public DataType list(Types.ListType list, DataType elementResult) { + return ArrayType$.MODULE$.apply(elementResult, list.isElementOptional()); + } + + @Override + public DataType map(Types.MapType map, DataType keyResult, DataType valueResult) { + return MapType$.MODULE$.apply(keyResult, valueResult, map.isValueOptional()); + } + + @Override + public DataType primitive(Type.PrimitiveType primitive) { + switch (primitive.typeId()) { + case BOOLEAN: + return BooleanType$.MODULE$; + case INTEGER: + return IntegerType$.MODULE$; + case LONG: + return LongType$.MODULE$; + case FLOAT: + return FloatType$.MODULE$; + case DOUBLE: + return DoubleType$.MODULE$; + case DATE: + return DateType$.MODULE$; + case TIME: + throw new UnsupportedOperationException("Spark does not support time fields"); + case TIMESTAMP: + return TimestampType$.MODULE$; + case STRING: + return StringType$.MODULE$; + case UUID: + // use String + return 
StringType$.MODULE$; + case FIXED: + return BinaryType$.MODULE$; + case BINARY: + return BinaryType$.MODULE$; + case DECIMAL: + Types.DecimalType decimal = (Types.DecimalType) primitive; + return DecimalType$.MODULE$.apply(decimal.precision(), decimal.scale()); + default: + throw new UnsupportedOperationException( + "Cannot convert unknown type to Spark: " + primitive); + } + } + + private Metadata fieldMetadata(int fieldId) { + if (MetadataColumns.metadataFieldIds().contains(fieldId)) { + return new MetadataBuilder().putBoolean(METADATA_COL_ATTR_KEY, true).build(); + } + + return Metadata.empty(); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSnapshotUpdateSparkAction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSnapshotUpdateSparkAction.java new file mode 100644 index 000000000000..77debe1e589d --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSnapshotUpdateSparkAction.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.spark.sql.SparkSession; + +abstract class BaseSnapshotUpdateSparkAction extends BaseSparkAction { + + private final Map summary = Maps.newHashMap(); + + protected BaseSnapshotUpdateSparkAction(SparkSession spark) { + super(spark); + } + + public ThisT snapshotProperty(String property, String value) { + summary.put(property, value); + return self(); + } + + protected void commit(org.apache.iceberg.SnapshotUpdate update) { + summary.forEach(update::set); + update.commit(); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSparkAction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSparkAction.java new file mode 100644 index 000000000000..3c007c6214c2 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/BaseSparkAction.java @@ -0,0 +1,454 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.MetadataTableType.ALL_MANIFESTS; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.lit; + +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; +import java.util.function.Predicate; +import java.util.function.Supplier; +import org.apache.iceberg.AllManifestsTable; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileContent; +import org.apache.iceberg.ManifestContent; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.ReachableFileUtil; +import org.apache.iceberg.StaticTableOperations; +import org.apache.iceberg.StatisticsFile; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.exceptions.NotFoundException; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.io.BulkDeletionFailureException; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.ClosingIterator; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.SupportsBulkOperations; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.Splitter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterators; +import org.apache.iceberg.relocated.com.google.common.collect.ListMultimap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Multimaps; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.spark.JobGroupUtils; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.source.SerializableTableWithSize; +import org.apache.iceberg.util.Tasks; +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +abstract class BaseSparkAction { + + protected static final String MANIFEST = "Manifest"; + protected static final String MANIFEST_LIST = "Manifest List"; + protected static final String STATISTICS_FILES = "Statistics Files"; + protected static final String OTHERS = "Others"; + + protected static final String FILE_PATH = "file_path"; + protected static final String LAST_MODIFIED = "last_modified"; + + protected static final Splitter COMMA_SPLITTER = Splitter.on(","); + protected static final Joiner COMMA_JOINER = Joiner.on(','); + + private static final Logger LOG = LoggerFactory.getLogger(BaseSparkAction.class); + private static final 
AtomicInteger JOB_COUNTER = new AtomicInteger(); + private static final int DELETE_NUM_RETRIES = 3; + private static final int DELETE_GROUP_SIZE = 100000; + + private final SparkSession spark; + private final JavaSparkContext sparkContext; + private final Map options = Maps.newHashMap(); + + protected BaseSparkAction(SparkSession spark) { + this.spark = spark; + this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + } + + protected SparkSession spark() { + return spark; + } + + protected JavaSparkContext sparkContext() { + return sparkContext; + } + + protected abstract ThisT self(); + + public ThisT option(String name, String value) { + options.put(name, value); + return self(); + } + + public ThisT options(Map newOptions) { + options.putAll(newOptions); + return self(); + } + + protected Map options() { + return options; + } + + protected T withJobGroupInfo(JobGroupInfo info, Supplier supplier) { + SparkContext context = spark().sparkContext(); + JobGroupInfo previousInfo = JobGroupUtils.getJobGroupInfo(context); + try { + JobGroupUtils.setJobGroupInfo(context, info); + return supplier.get(); + } finally { + JobGroupUtils.setJobGroupInfo(context, previousInfo); + } + } + + protected JobGroupInfo newJobGroupInfo(String groupId, String desc) { + return new JobGroupInfo(groupId + "-" + JOB_COUNTER.incrementAndGet(), desc, false); + } + + protected Table newStaticTable(TableMetadata metadata, FileIO io) { + String metadataFileLocation = metadata.metadataFileLocation(); + StaticTableOperations ops = new StaticTableOperations(metadataFileLocation, io); + return new BaseTable(ops, metadataFileLocation); + } + + protected Dataset contentFileDS(Table table) { + return contentFileDS(table, null); + } + + protected Dataset contentFileDS(Table table, Set snapshotIds) { + Table serializableTable = SerializableTableWithSize.copyOf(table); + Broadcast tableBroadcast = sparkContext.broadcast(serializableTable); + int numShufflePartitions = spark.sessionState().conf().numShufflePartitions(); + + Dataset manifestBeanDS = + manifestDF(table, snapshotIds) + .selectExpr( + "content", + "path", + "length", + "partition_spec_id as partitionSpecId", + "added_snapshot_id as addedSnapshotId") + .dropDuplicates("path") + .repartition(numShufflePartitions) // avoid adaptive execution combining tasks + .as(ManifestFileBean.ENCODER); + + return manifestBeanDS.flatMap(new ReadManifest(tableBroadcast), FileInfo.ENCODER); + } + + protected Dataset manifestDS(Table table) { + return manifestDS(table, null); + } + + protected Dataset manifestDS(Table table, Set snapshotIds) { + return manifestDF(table, snapshotIds) + .select(col("path"), lit(MANIFEST).as("type")) + .as(FileInfo.ENCODER); + } + + private Dataset manifestDF(Table table, Set snapshotIds) { + Dataset manifestDF = loadMetadataTable(table, ALL_MANIFESTS); + if (snapshotIds != null) { + Column filterCond = col(AllManifestsTable.REF_SNAPSHOT_ID.name()).isInCollection(snapshotIds); + return manifestDF.filter(filterCond); + } else { + return manifestDF; + } + } + + protected Dataset manifestListDS(Table table) { + return manifestListDS(table, null); + } + + protected Dataset manifestListDS(Table table, Set snapshotIds) { + List manifestLists = ReachableFileUtil.manifestListLocations(table, snapshotIds); + return toFileInfoDS(manifestLists, MANIFEST_LIST); + } + + protected Dataset statisticsFileDS(Table table, Set snapshotIds) { + Predicate predicate; + if (snapshotIds == null) { + predicate = statisticsFile -> true; + } else { + predicate = 
statisticsFile -> snapshotIds.contains(statisticsFile.snapshotId()); + } + + List statisticsFiles = ReachableFileUtil.statisticsFilesLocations(table, predicate); + return toFileInfoDS(statisticsFiles, STATISTICS_FILES); + } + + protected Dataset otherMetadataFileDS(Table table) { + return otherMetadataFileDS(table, false /* include all reachable old metadata locations */); + } + + protected Dataset allReachableOtherMetadataFileDS(Table table) { + return otherMetadataFileDS(table, true /* include all reachable old metadata locations */); + } + + private Dataset otherMetadataFileDS(Table table, boolean recursive) { + List otherMetadataFiles = Lists.newArrayList(); + otherMetadataFiles.addAll(ReachableFileUtil.metadataFileLocations(table, recursive)); + otherMetadataFiles.add(ReachableFileUtil.versionHintLocation(table)); + otherMetadataFiles.addAll(ReachableFileUtil.statisticsFilesLocations(table)); + return toFileInfoDS(otherMetadataFiles, OTHERS); + } + + protected Dataset loadMetadataTable(Table table, MetadataTableType type) { + return SparkTableUtil.loadMetadataTable(spark, table, type); + } + + private Dataset toFileInfoDS(List paths, String type) { + List fileInfoList = Lists.transform(paths, path -> new FileInfo(path, type)); + return spark.createDataset(fileInfoList, FileInfo.ENCODER); + } + + /** + * Deletes files and keeps track of how many files were removed for each file type. + * + * @param executorService an executor service to use for parallel deletes + * @param deleteFunc a delete func + * @param files an iterator of Spark rows of the structure (path: String, type: String) + * @return stats on which files were deleted + */ + protected DeleteSummary deleteFiles( + ExecutorService executorService, Consumer deleteFunc, Iterator files) { + + DeleteSummary summary = new DeleteSummary(); + + Tasks.foreach(files) + .retry(DELETE_NUM_RETRIES) + .stopRetryOn(NotFoundException.class) + .suppressFailureWhenFinished() + .executeWith(executorService) + .onFailure( + (fileInfo, exc) -> { + String path = fileInfo.getPath(); + String type = fileInfo.getType(); + LOG.warn("Delete failed for {}: {}", type, path, exc); + }) + .run( + fileInfo -> { + String path = fileInfo.getPath(); + String type = fileInfo.getType(); + deleteFunc.accept(path); + summary.deletedFile(path, type); + }); + + return summary; + } + + protected DeleteSummary deleteFiles(SupportsBulkOperations io, Iterator files) { + DeleteSummary summary = new DeleteSummary(); + Iterator> fileGroups = Iterators.partition(files, DELETE_GROUP_SIZE); + + Tasks.foreach(fileGroups) + .suppressFailureWhenFinished() + .run(fileGroup -> deleteFileGroup(fileGroup, io, summary)); + + return summary; + } + + private static void deleteFileGroup( + List fileGroup, SupportsBulkOperations io, DeleteSummary summary) { + + ListMultimap filesByType = Multimaps.index(fileGroup, FileInfo::getType); + ListMultimap pathsByType = + Multimaps.transformValues(filesByType, FileInfo::getPath); + + for (Map.Entry> entry : pathsByType.asMap().entrySet()) { + String type = entry.getKey(); + Collection paths = entry.getValue(); + int failures = 0; + try { + io.deleteFiles(paths); + } catch (BulkDeletionFailureException e) { + failures = e.numberFailedObjects(); + } + summary.deletedFiles(type, paths.size() - failures); + } + } + + static class DeleteSummary { + private final AtomicLong dataFilesCount = new AtomicLong(0L); + private final AtomicLong positionDeleteFilesCount = new AtomicLong(0L); + private final AtomicLong equalityDeleteFilesCount = new 
AtomicLong(0L); + private final AtomicLong manifestsCount = new AtomicLong(0L); + private final AtomicLong manifestListsCount = new AtomicLong(0L); + private final AtomicLong statisticsFilesCount = new AtomicLong(0L); + private final AtomicLong otherFilesCount = new AtomicLong(0L); + + public void deletedFiles(String type, int numFiles) { + if (FileContent.DATA.name().equalsIgnoreCase(type)) { + dataFilesCount.addAndGet(numFiles); + + } else if (FileContent.POSITION_DELETES.name().equalsIgnoreCase(type)) { + positionDeleteFilesCount.addAndGet(numFiles); + + } else if (FileContent.EQUALITY_DELETES.name().equalsIgnoreCase(type)) { + equalityDeleteFilesCount.addAndGet(numFiles); + + } else if (MANIFEST.equalsIgnoreCase(type)) { + manifestsCount.addAndGet(numFiles); + + } else if (MANIFEST_LIST.equalsIgnoreCase(type)) { + manifestListsCount.addAndGet(numFiles); + + } else if (STATISTICS_FILES.equalsIgnoreCase(type)) { + statisticsFilesCount.addAndGet(numFiles); + + } else if (OTHERS.equalsIgnoreCase(type)) { + otherFilesCount.addAndGet(numFiles); + + } else { + throw new ValidationException("Illegal file type: %s", type); + } + } + + public void deletedFile(String path, String type) { + if (FileContent.DATA.name().equalsIgnoreCase(type)) { + dataFilesCount.incrementAndGet(); + LOG.trace("Deleted data file: {}", path); + + } else if (FileContent.POSITION_DELETES.name().equalsIgnoreCase(type)) { + positionDeleteFilesCount.incrementAndGet(); + LOG.trace("Deleted positional delete file: {}", path); + + } else if (FileContent.EQUALITY_DELETES.name().equalsIgnoreCase(type)) { + equalityDeleteFilesCount.incrementAndGet(); + LOG.trace("Deleted equality delete file: {}", path); + + } else if (MANIFEST.equalsIgnoreCase(type)) { + manifestsCount.incrementAndGet(); + LOG.debug("Deleted manifest: {}", path); + + } else if (MANIFEST_LIST.equalsIgnoreCase(type)) { + manifestListsCount.incrementAndGet(); + LOG.debug("Deleted manifest list: {}", path); + + } else if (STATISTICS_FILES.equalsIgnoreCase(type)) { + statisticsFilesCount.incrementAndGet(); + LOG.debug("Deleted statistics file: {}", path); + + } else if (OTHERS.equalsIgnoreCase(type)) { + otherFilesCount.incrementAndGet(); + LOG.debug("Deleted other metadata file: {}", path); + + } else { + throw new ValidationException("Illegal file type: %s", type); + } + } + + public long dataFilesCount() { + return dataFilesCount.get(); + } + + public long positionDeleteFilesCount() { + return positionDeleteFilesCount.get(); + } + + public long equalityDeleteFilesCount() { + return equalityDeleteFilesCount.get(); + } + + public long manifestsCount() { + return manifestsCount.get(); + } + + public long manifestListsCount() { + return manifestListsCount.get(); + } + + public long statisticsFilesCount() { + return statisticsFilesCount.get(); + } + + public long otherFilesCount() { + return otherFilesCount.get(); + } + + public long totalFilesCount() { + return dataFilesCount() + + positionDeleteFilesCount() + + equalityDeleteFilesCount() + + manifestsCount() + + manifestListsCount() + + statisticsFilesCount() + + otherFilesCount(); + } + } + + private static class ReadManifest implements FlatMapFunction { + private final Broadcast
<Table> table; + + ReadManifest(Broadcast<Table>
table) { + this.table = table; + } + + @Override + public Iterator call(ManifestFileBean manifest) { + return new ClosingIterator<>(entries(manifest)); + } + + public CloseableIterator entries(ManifestFileBean manifest) { + ManifestContent content = manifest.content(); + FileIO io = table.getValue().io(); + Map specs = table.getValue().specs(); + List proj = ImmutableList.of(DataFile.FILE_PATH.name(), DataFile.CONTENT.name()); + + switch (content) { + case DATA: + return CloseableIterator.transform( + ManifestFiles.read(manifest, io, specs).select(proj).iterator(), + ReadManifest::toFileInfo); + case DELETES: + return CloseableIterator.transform( + ManifestFiles.readDeleteManifest(manifest, io, specs).select(proj).iterator(), + ReadManifest::toFileInfo); + default: + throw new IllegalArgumentException("Unsupported manifest content type:" + content); + } + } + + static FileInfo toFileInfo(ContentFile file) { + return new FileInfo(file.path().toString(), file.content().toString()); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/BaseTableCreationSparkAction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/BaseTableCreationSparkAction.java new file mode 100644 index 000000000000..520c520484dc --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/BaseTableCreationSparkAction.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.exceptions.NoSuchNamespaceException; +import org.apache.iceberg.exceptions.NoSuchTableException; +import org.apache.iceberg.mapping.MappingUtil; +import org.apache.iceberg.mapping.NameMapping; +import org.apache.iceberg.mapping.NameMappingParser; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.source.StagedSparkTable; +import org.apache.iceberg.util.LocationUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.catalog.CatalogTable; +import org.apache.spark.sql.catalyst.catalog.CatalogUtils; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.StagingTableCatalog; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.catalog.V1Table; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.StructType; + +abstract class BaseTableCreationSparkAction extends BaseSparkAction { + private static final Set ALLOWED_SOURCES = + ImmutableSet.of("parquet", "avro", "orc", "hive"); + protected static final String LOCATION = "location"; + protected static final String ICEBERG_METADATA_FOLDER = "metadata"; + protected static final List EXCLUDED_PROPERTIES = + ImmutableList.of("path", "transient_lastDdlTime", "serialization.format"); + + // Source Fields + private final V1Table sourceTable; + private final CatalogTable sourceCatalogTable; + private final String sourceTableLocation; + private final TableCatalog sourceCatalog; + private final Identifier sourceTableIdent; + + // Optional Parameters for destination + private final Map additionalProperties = Maps.newHashMap(); + + BaseTableCreationSparkAction( + SparkSession spark, CatalogPlugin sourceCatalog, Identifier sourceTableIdent) { + super(spark); + + this.sourceCatalog = checkSourceCatalog(sourceCatalog); + this.sourceTableIdent = sourceTableIdent; + + try { + this.sourceTable = (V1Table) this.sourceCatalog.loadTable(sourceTableIdent); + this.sourceCatalogTable = sourceTable.v1Table(); + } catch (org.apache.spark.sql.catalyst.analysis.NoSuchTableException e) { + throw new NoSuchTableException("Cannot not find source table '%s'", sourceTableIdent); + } catch (ClassCastException e) { + throw new IllegalArgumentException( + String.format("Cannot use non-v1 table '%s' as a source", sourceTableIdent), e); + } + validateSourceTable(); + + this.sourceTableLocation = + CatalogUtils.URIToString(sourceCatalogTable.storage().locationUri().get()); + } + + protected abstract TableCatalog checkSourceCatalog(CatalogPlugin catalog); + + protected abstract StagingTableCatalog destCatalog(); + + protected abstract Identifier destTableIdent(); + + protected abstract Map destTableProps(); + + protected String sourceTableLocation() { + return 
sourceTableLocation; + } + + protected CatalogTable v1SourceTable() { + return sourceCatalogTable; + } + + protected TableCatalog sourceCatalog() { + return sourceCatalog; + } + + protected Identifier sourceTableIdent() { + return sourceTableIdent; + } + + protected void setProperties(Map properties) { + additionalProperties.putAll(properties); + } + + protected void setProperty(String key, String value) { + additionalProperties.put(key, value); + } + + protected Map additionalProperties() { + return additionalProperties; + } + + private void validateSourceTable() { + String sourceTableProvider = sourceCatalogTable.provider().get().toLowerCase(Locale.ROOT); + Preconditions.checkArgument( + ALLOWED_SOURCES.contains(sourceTableProvider), + "Cannot create an Iceberg table from source provider: '%s'", + sourceTableProvider); + Preconditions.checkArgument( + !sourceCatalogTable.storage().locationUri().isEmpty(), + "Cannot create an Iceberg table from a source without an explicit location"); + } + + protected StagingTableCatalog checkDestinationCatalog(CatalogPlugin catalog) { + Preconditions.checkArgument( + catalog instanceof SparkSessionCatalog || catalog instanceof SparkCatalog, + "Cannot create Iceberg table in non-Iceberg Catalog. " + + "Catalog '%s' was of class '%s' but '%s' or '%s' are required", + catalog.name(), + catalog.getClass().getName(), + SparkSessionCatalog.class.getName(), + SparkCatalog.class.getName()); + + return (StagingTableCatalog) catalog; + } + + protected StagedSparkTable stageDestTable() { + try { + Map props = destTableProps(); + StructType schema = sourceTable.schema(); + Transform[] partitioning = sourceTable.partitioning(); + return (StagedSparkTable) + destCatalog().stageCreate(destTableIdent(), schema, partitioning, props); + } catch (org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException e) { + throw new NoSuchNamespaceException( + "Cannot create table %s as the namespace does not exist", destTableIdent()); + } catch (org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException e) { + throw new AlreadyExistsException( + "Cannot create table %s as it already exists", destTableIdent()); + } + } + + protected void ensureNameMappingPresent(Table table) { + if (!table.properties().containsKey(TableProperties.DEFAULT_NAME_MAPPING)) { + NameMapping nameMapping = MappingUtil.create(table.schema()); + String nameMappingJson = NameMappingParser.toJson(nameMapping); + table.updateProperties().set(TableProperties.DEFAULT_NAME_MAPPING, nameMappingJson).commit(); + } + } + + protected String getMetadataLocation(Table table) { + String defaultValue = + LocationUtil.stripTrailingSlash(table.location()) + "/" + ICEBERG_METADATA_FOLDER; + return LocationUtil.stripTrailingSlash( + table.properties().getOrDefault(TableProperties.WRITE_METADATA_LOCATION, defaultValue)); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/DeleteOrphanFilesSparkAction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/DeleteOrphanFilesSparkAction.java new file mode 100644 index 000000000000..b00ed42008f1 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/DeleteOrphanFilesSparkAction.java @@ -0,0 +1,676 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.TableProperties.GC_ENABLED; +import static org.apache.iceberg.TableProperties.GC_ENABLED_DEFAULT; + +import java.io.IOException; +import java.io.Serializable; +import java.io.UncheckedIOException; +import java.net.URI; +import java.sql.Timestamp; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.fs.PathFilter; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.DeleteOrphanFiles; +import org.apache.iceberg.actions.ImmutableDeleteOrphanFiles; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.HiddenPathFilter; +import org.apache.iceberg.io.BulkDeletionFailureException; +import org.apache.iceberg.io.SupportsBulkOperations; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.base.Strings; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterators; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.util.Pair; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.Tasks; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.MapPartitionsFunction; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.SerializableConfiguration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.Tuple2; + +/** + * An action that removes orphan metadata, data and delete files by listing a given location and + * comparing the actual files in that location with content and metadata files referenced by all + * valid snapshots. 
The location must be accessible for listing via the Hadoop {@link FileSystem}. + * + *
<p>
By default, this action cleans up the table location returned by {@link Table#location()} and + * removes unreachable files that are older than 3 days using {@link Table#io()}. The behavior can + * be modified by passing a custom location to {@link #location} and a custom timestamp to {@link + * #olderThan(long)}. For example, someone might point this action to the data folder to clean up + * only orphan data files. + * + *
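+ * <p>For example, a minimal sketch of invoking this action through the {@code SparkActions}
+ * entry point, assuming {@code spark} and {@code table} are defined elsewhere:
+ *
+ * <pre>{@code
+ * DeleteOrphanFiles.Result result =
+ *     SparkActions.get(spark)
+ *         .deleteOrphanFiles(table)
+ *         .olderThan(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(7))
+ *         .execute();
+ * }</pre>
+ *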
<p>
Configure an alternative delete method using {@link #deleteWith(Consumer)}. + * + *
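+ * <p>A dry-run style sketch that only logs candidates instead of deleting them (the {@code LOG}
+ * and {@code table} references are illustrative):
+ *
+ * <pre>{@code
+ * SparkActions.get(spark)
+ *     .deleteOrphanFiles(table)
+ *     .deleteWith(path -> LOG.info("Orphan file candidate: {}", path))
+ *     .execute();
+ * }</pre>
+ *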
<p>
For full control of the set of files being evaluated, use the {@link + * #compareToFileList(Dataset)} argument. This skips the directory listing - any files in the + * dataset provided which are not found in table metadata will be deleted, using the same {@link + * Table#location()} and {@link #olderThan(long)} filtering as above. + * + *
<p>
Note: It is dangerous to call this action with a short retention interval as it might + * corrupt the state of the table if another operation is writing at the same time. + */ +public class DeleteOrphanFilesSparkAction extends BaseSparkAction + implements DeleteOrphanFiles { + + private static final Logger LOG = LoggerFactory.getLogger(DeleteOrphanFilesSparkAction.class); + private static final Map EQUAL_SCHEMES_DEFAULT = ImmutableMap.of("s3n,s3a", "s3"); + private static final int MAX_DRIVER_LISTING_DEPTH = 3; + private static final int MAX_DRIVER_LISTING_DIRECT_SUB_DIRS = 10; + private static final int MAX_EXECUTOR_LISTING_DEPTH = 2000; + private static final int MAX_EXECUTOR_LISTING_DIRECT_SUB_DIRS = Integer.MAX_VALUE; + + private final SerializableConfiguration hadoopConf; + private final int listingParallelism; + private final Table table; + private Map equalSchemes = flattenMap(EQUAL_SCHEMES_DEFAULT); + private Map equalAuthorities = Collections.emptyMap(); + private PrefixMismatchMode prefixMismatchMode = PrefixMismatchMode.ERROR; + private String location = null; + private long olderThanTimestamp = System.currentTimeMillis() - TimeUnit.DAYS.toMillis(3); + private Dataset compareToFileList; + private Consumer deleteFunc = null; + private ExecutorService deleteExecutorService = null; + + DeleteOrphanFilesSparkAction(SparkSession spark, Table table) { + super(spark); + + this.hadoopConf = new SerializableConfiguration(spark.sessionState().newHadoopConf()); + this.listingParallelism = spark.sessionState().conf().parallelPartitionDiscoveryParallelism(); + this.table = table; + this.location = table.location(); + + ValidationException.check( + PropertyUtil.propertyAsBoolean(table.properties(), GC_ENABLED, GC_ENABLED_DEFAULT), + "Cannot delete orphan files: GC is disabled (deleting files may corrupt other tables)"); + } + + @Override + protected DeleteOrphanFilesSparkAction self() { + return this; + } + + @Override + public DeleteOrphanFilesSparkAction executeDeleteWith(ExecutorService executorService) { + this.deleteExecutorService = executorService; + return this; + } + + @Override + public DeleteOrphanFilesSparkAction prefixMismatchMode(PrefixMismatchMode newPrefixMismatchMode) { + this.prefixMismatchMode = newPrefixMismatchMode; + return this; + } + + @Override + public DeleteOrphanFilesSparkAction equalSchemes(Map newEqualSchemes) { + this.equalSchemes = Maps.newHashMap(); + equalSchemes.putAll(flattenMap(EQUAL_SCHEMES_DEFAULT)); + equalSchemes.putAll(flattenMap(newEqualSchemes)); + return this; + } + + @Override + public DeleteOrphanFilesSparkAction equalAuthorities(Map newEqualAuthorities) { + this.equalAuthorities = Maps.newHashMap(); + equalAuthorities.putAll(flattenMap(newEqualAuthorities)); + return this; + } + + @Override + public DeleteOrphanFilesSparkAction location(String newLocation) { + this.location = newLocation; + return this; + } + + @Override + public DeleteOrphanFilesSparkAction olderThan(long newOlderThanTimestamp) { + this.olderThanTimestamp = newOlderThanTimestamp; + return this; + } + + @Override + public DeleteOrphanFilesSparkAction deleteWith(Consumer newDeleteFunc) { + this.deleteFunc = newDeleteFunc; + return this; + } + + public DeleteOrphanFilesSparkAction compareToFileList(Dataset files) { + StructType schema = files.schema(); + + StructField filePathField = schema.apply(FILE_PATH); + Preconditions.checkArgument( + filePathField.dataType() == DataTypes.StringType, + "Invalid %s column: %s is not a string", + FILE_PATH, + filePathField.dataType()); + + 
StructField lastModifiedField = schema.apply(LAST_MODIFIED); + Preconditions.checkArgument( + lastModifiedField.dataType() == DataTypes.TimestampType, + "Invalid %s column: %s is not a timestamp", + LAST_MODIFIED, + lastModifiedField.dataType()); + + this.compareToFileList = files; + return this; + } + + private Dataset filteredCompareToFileList() { + Dataset files = compareToFileList; + if (location != null) { + files = files.filter(files.col(FILE_PATH).startsWith(location)); + } + return files + .filter(files.col(LAST_MODIFIED).lt(new Timestamp(olderThanTimestamp))) + .select(files.col(FILE_PATH)) + .as(Encoders.STRING()); + } + + @Override + public DeleteOrphanFiles.Result execute() { + JobGroupInfo info = newJobGroupInfo("DELETE-ORPHAN-FILES", jobDesc()); + return withJobGroupInfo(info, this::doExecute); + } + + private String jobDesc() { + List options = Lists.newArrayList(); + options.add("older_than=" + olderThanTimestamp); + if (location != null) { + options.add("location=" + location); + } + String optionsAsString = COMMA_JOINER.join(options); + return String.format("Deleting orphan files (%s) from %s", optionsAsString, table.name()); + } + + private void deleteFiles(SupportsBulkOperations io, List paths) { + try { + io.deleteFiles(paths); + LOG.info("Deleted {} files using bulk deletes", paths.size()); + } catch (BulkDeletionFailureException e) { + int deletedFilesCount = paths.size() - e.numberFailedObjects(); + LOG.warn("Deleted only {} of {} files using bulk deletes", deletedFilesCount, paths.size()); + } + } + + private DeleteOrphanFiles.Result doExecute() { + Dataset actualFileIdentDS = actualFileIdentDS(); + Dataset validFileIdentDS = validFileIdentDS(); + + List orphanFiles = + findOrphanFiles(spark(), actualFileIdentDS, validFileIdentDS, prefixMismatchMode); + + if (deleteFunc == null && table.io() instanceof SupportsBulkOperations) { + deleteFiles((SupportsBulkOperations) table.io(), orphanFiles); + } else { + + Tasks.Builder deleteTasks = + Tasks.foreach(orphanFiles) + .noRetry() + .executeWith(deleteExecutorService) + .suppressFailureWhenFinished() + .onFailure((file, exc) -> LOG.warn("Failed to delete file: {}", file, exc)); + + if (deleteFunc == null) { + LOG.info( + "Table IO {} does not support bulk operations. Using non-bulk deletes.", + table.io().getClass().getName()); + deleteTasks.run(table.io()::deleteFile); + } else { + LOG.info("Custom delete function provided. 
Using non-bulk deletes"); + deleteTasks.run(deleteFunc::accept); + } + } + + return ImmutableDeleteOrphanFiles.Result.builder().orphanFileLocations(orphanFiles).build(); + } + + private Dataset validFileIdentDS() { + // transform before union to avoid extra serialization/deserialization + FileInfoToFileURI toFileURI = new FileInfoToFileURI(equalSchemes, equalAuthorities); + + Dataset contentFileIdentDS = toFileURI.apply(contentFileDS(table)); + Dataset manifestFileIdentDS = toFileURI.apply(manifestDS(table)); + Dataset manifestListIdentDS = toFileURI.apply(manifestListDS(table)); + Dataset otherMetadataFileIdentDS = toFileURI.apply(otherMetadataFileDS(table)); + + return contentFileIdentDS + .union(manifestFileIdentDS) + .union(manifestListIdentDS) + .union(otherMetadataFileIdentDS); + } + + private Dataset actualFileIdentDS() { + StringToFileURI toFileURI = new StringToFileURI(equalSchemes, equalAuthorities); + if (compareToFileList == null) { + return toFileURI.apply(listedFileDS()); + } else { + return toFileURI.apply(filteredCompareToFileList()); + } + } + + private Dataset listedFileDS() { + List subDirs = Lists.newArrayList(); + List matchingFiles = Lists.newArrayList(); + + Predicate predicate = file -> file.getModificationTime() < olderThanTimestamp; + PathFilter pathFilter = PartitionAwareHiddenPathFilter.forSpecs(table.specs()); + + // list at most MAX_DRIVER_LISTING_DEPTH levels and only dirs that have + // less than MAX_DRIVER_LISTING_DIRECT_SUB_DIRS direct sub dirs on the driver + listDirRecursively( + location, + predicate, + hadoopConf.value(), + MAX_DRIVER_LISTING_DEPTH, + MAX_DRIVER_LISTING_DIRECT_SUB_DIRS, + subDirs, + pathFilter, + matchingFiles); + + JavaRDD matchingFileRDD = sparkContext().parallelize(matchingFiles, 1); + + if (subDirs.isEmpty()) { + return spark().createDataset(matchingFileRDD.rdd(), Encoders.STRING()); + } + + int parallelism = Math.min(subDirs.size(), listingParallelism); + JavaRDD subDirRDD = sparkContext().parallelize(subDirs, parallelism); + + Broadcast conf = sparkContext().broadcast(hadoopConf); + ListDirsRecursively listDirs = new ListDirsRecursively(conf, olderThanTimestamp, pathFilter); + JavaRDD matchingLeafFileRDD = subDirRDD.mapPartitions(listDirs); + + JavaRDD completeMatchingFileRDD = matchingFileRDD.union(matchingLeafFileRDD); + return spark().createDataset(completeMatchingFileRDD.rdd(), Encoders.STRING()); + } + + private static void listDirRecursively( + String dir, + Predicate predicate, + Configuration conf, + int maxDepth, + int maxDirectSubDirs, + List remainingSubDirs, + PathFilter pathFilter, + List matchingFiles) { + + // stop listing whenever we reach the max depth + if (maxDepth <= 0) { + remainingSubDirs.add(dir); + return; + } + + try { + Path path = new Path(dir); + FileSystem fs = path.getFileSystem(conf); + + List subDirs = Lists.newArrayList(); + + for (FileStatus file : fs.listStatus(path, pathFilter)) { + if (file.isDirectory()) { + subDirs.add(file.getPath().toString()); + } else if (file.isFile() && predicate.test(file)) { + matchingFiles.add(file.getPath().toString()); + } + } + + // stop listing if the number of direct sub dirs is bigger than maxDirectSubDirs + if (subDirs.size() > maxDirectSubDirs) { + remainingSubDirs.addAll(subDirs); + return; + } + + for (String subDir : subDirs) { + listDirRecursively( + subDir, + predicate, + conf, + maxDepth - 1, + maxDirectSubDirs, + remainingSubDirs, + pathFilter, + matchingFiles); + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + 
@VisibleForTesting + static List findOrphanFiles( + SparkSession spark, + Dataset actualFileIdentDS, + Dataset validFileIdentDS, + PrefixMismatchMode prefixMismatchMode) { + + SetAccumulator> conflicts = new SetAccumulator<>(); + spark.sparkContext().register(conflicts); + + Column joinCond = actualFileIdentDS.col("path").equalTo(validFileIdentDS.col("path")); + + List orphanFiles = + actualFileIdentDS + .joinWith(validFileIdentDS, joinCond, "leftouter") + .mapPartitions(new FindOrphanFiles(prefixMismatchMode, conflicts), Encoders.STRING()) + .collectAsList(); + + if (prefixMismatchMode == PrefixMismatchMode.ERROR && !conflicts.value().isEmpty()) { + throw new ValidationException( + "Unable to determine whether certain files are orphan. " + + "Metadata references files that match listed/provided files except for authority/scheme. " + + "Please, inspect the conflicting authorities/schemes and provide which of them are equal " + + "by further configuring the action via equalSchemes() and equalAuthorities() methods. " + + "Set the prefix mismatch mode to 'NONE' to ignore remaining locations with conflicting " + + "authorities/schemes or to 'DELETE' iff you are ABSOLUTELY confident that remaining conflicting " + + "authorities/schemes are different. It will be impossible to recover deleted files. " + + "Conflicting authorities/schemes: %s.", + conflicts.value()); + } + + return orphanFiles; + } + + private static Map flattenMap(Map map) { + Map flattenedMap = Maps.newHashMap(); + if (map != null) { + for (String key : map.keySet()) { + String value = map.get(key); + for (String splitKey : COMMA_SPLITTER.split(key)) { + flattenedMap.put(splitKey.trim(), value.trim()); + } + } + } + return flattenedMap; + } + + private static class ListDirsRecursively implements FlatMapFunction, String> { + + private final Broadcast hadoopConf; + private final long olderThanTimestamp; + private final PathFilter pathFilter; + + ListDirsRecursively( + Broadcast hadoopConf, + long olderThanTimestamp, + PathFilter pathFilter) { + + this.hadoopConf = hadoopConf; + this.olderThanTimestamp = olderThanTimestamp; + this.pathFilter = pathFilter; + } + + @Override + public Iterator call(Iterator dirs) throws Exception { + List subDirs = Lists.newArrayList(); + List files = Lists.newArrayList(); + + Predicate predicate = file -> file.getModificationTime() < olderThanTimestamp; + + while (dirs.hasNext()) { + listDirRecursively( + dirs.next(), + predicate, + hadoopConf.value().value(), + MAX_EXECUTOR_LISTING_DEPTH, + MAX_EXECUTOR_LISTING_DIRECT_SUB_DIRS, + subDirs, + pathFilter, + files); + } + + if (!subDirs.isEmpty()) { + throw new RuntimeException( + "Could not list sub directories, reached maximum depth: " + MAX_EXECUTOR_LISTING_DEPTH); + } + + return files.iterator(); + } + } + + private static class FindOrphanFiles + implements MapPartitionsFunction, String> { + + private final PrefixMismatchMode mode; + private final SetAccumulator> conflicts; + + FindOrphanFiles(PrefixMismatchMode mode, SetAccumulator> conflicts) { + this.mode = mode; + this.conflicts = conflicts; + } + + @Override + public Iterator call(Iterator> rows) throws Exception { + Iterator orphanFiles = Iterators.transform(rows, this::toOrphanFile); + return Iterators.filter(orphanFiles, Objects::nonNull); + } + + private String toOrphanFile(Tuple2 row) { + FileURI actual = row._1; + FileURI valid = row._2; + + if (valid == null) { + return actual.uriAsString; + } + + boolean schemeMatch = uriComponentMatch(valid.scheme, actual.scheme); + boolean 
authorityMatch = uriComponentMatch(valid.authority, actual.authority); + + if ((!schemeMatch || !authorityMatch) && mode == PrefixMismatchMode.DELETE) { + return actual.uriAsString; + } else { + if (!schemeMatch) { + conflicts.add(Pair.of(valid.scheme, actual.scheme)); + } + + if (!authorityMatch) { + conflicts.add(Pair.of(valid.authority, actual.authority)); + } + + return null; + } + } + + private boolean uriComponentMatch(String valid, String actual) { + return Strings.isNullOrEmpty(valid) || valid.equalsIgnoreCase(actual); + } + } + + @VisibleForTesting + static class StringToFileURI extends ToFileURI { + StringToFileURI(Map equalSchemes, Map equalAuthorities) { + super(equalSchemes, equalAuthorities); + } + + @Override + protected String uriAsString(String input) { + return input; + } + } + + @VisibleForTesting + static class FileInfoToFileURI extends ToFileURI { + FileInfoToFileURI(Map equalSchemes, Map equalAuthorities) { + super(equalSchemes, equalAuthorities); + } + + @Override + protected String uriAsString(FileInfo fileInfo) { + return fileInfo.getPath(); + } + } + + private abstract static class ToFileURI implements MapPartitionsFunction { + + private final Map equalSchemes; + private final Map equalAuthorities; + + ToFileURI(Map equalSchemes, Map equalAuthorities) { + this.equalSchemes = equalSchemes; + this.equalAuthorities = equalAuthorities; + } + + protected abstract String uriAsString(I input); + + Dataset apply(Dataset ds) { + return ds.mapPartitions(this, FileURI.ENCODER); + } + + @Override + public Iterator call(Iterator rows) throws Exception { + return Iterators.transform(rows, this::toFileURI); + } + + private FileURI toFileURI(I input) { + String uriAsString = uriAsString(input); + URI uri = new Path(uriAsString).toUri(); + String scheme = equalSchemes.getOrDefault(uri.getScheme(), uri.getScheme()); + String authority = equalAuthorities.getOrDefault(uri.getAuthority(), uri.getAuthority()); + return new FileURI(scheme, authority, uri.getPath(), uriAsString); + } + } + + /** + * A {@link PathFilter} that filters out hidden path, but does not filter out paths that would be + * marked as hidden by {@link HiddenPathFilter} due to a partition field that starts with one of + * the characters that indicate a hidden path. 
+ */ + @VisibleForTesting + static class PartitionAwareHiddenPathFilter implements PathFilter, Serializable { + + private final Set hiddenPathPartitionNames; + + PartitionAwareHiddenPathFilter(Set hiddenPathPartitionNames) { + this.hiddenPathPartitionNames = hiddenPathPartitionNames; + } + + @Override + public boolean accept(Path path) { + return isHiddenPartitionPath(path) || HiddenPathFilter.get().accept(path); + } + + private boolean isHiddenPartitionPath(Path path) { + return hiddenPathPartitionNames.stream().anyMatch(path.getName()::startsWith); + } + + static PathFilter forSpecs(Map specs) { + if (specs == null) { + return HiddenPathFilter.get(); + } + + Set partitionNames = + specs.values().stream() + .map(PartitionSpec::fields) + .flatMap(List::stream) + .filter(field -> field.name().startsWith("_") || field.name().startsWith(".")) + .map(field -> field.name() + "=") + .collect(Collectors.toSet()); + + if (partitionNames.isEmpty()) { + return HiddenPathFilter.get(); + } else { + return new PartitionAwareHiddenPathFilter(partitionNames); + } + } + } + + public static class FileURI { + public static final Encoder ENCODER = Encoders.bean(FileURI.class); + + private String scheme; + private String authority; + private String path; + private String uriAsString; + + public FileURI(String scheme, String authority, String path, String uriAsString) { + this.scheme = scheme; + this.authority = authority; + this.path = path; + this.uriAsString = uriAsString; + } + + public FileURI() {} + + public void setScheme(String scheme) { + this.scheme = scheme; + } + + public void setAuthority(String authority) { + this.authority = authority; + } + + public void setPath(String path) { + this.path = path; + } + + public void setUriAsString(String uriAsString) { + this.uriAsString = uriAsString; + } + + public String getScheme() { + return scheme; + } + + public String getAuthority() { + return authority; + } + + public String getPath() { + return path; + } + + public String getUriAsString() { + return uriAsString; + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/DeleteReachableFilesSparkAction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/DeleteReachableFilesSparkAction.java new file mode 100644 index 000000000000..ea6ac9f3dbf5 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/DeleteReachableFilesSparkAction.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.TableProperties.GC_ENABLED; +import static org.apache.iceberg.TableProperties.GC_ENABLED_DEFAULT; + +import java.util.Iterator; +import java.util.concurrent.ExecutorService; +import java.util.function.Consumer; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableMetadataParser; +import org.apache.iceberg.actions.DeleteReachableFiles; +import org.apache.iceberg.actions.ImmutableDeleteReachableFiles; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.HadoopFileIO; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.SupportsBulkOperations; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.SparkSession; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An implementation of {@link DeleteReachableFiles} that uses metadata tables in Spark to determine + * which files should be deleted. + */ +@SuppressWarnings("UnnecessaryAnonymousClass") +public class DeleteReachableFilesSparkAction + extends BaseSparkAction implements DeleteReachableFiles { + + public static final String STREAM_RESULTS = "stream-results"; + public static final boolean STREAM_RESULTS_DEFAULT = false; + + private static final Logger LOG = LoggerFactory.getLogger(DeleteReachableFilesSparkAction.class); + + private final String metadataFileLocation; + + private Consumer deleteFunc = null; + private ExecutorService deleteExecutorService = null; + private FileIO io = new HadoopFileIO(spark().sessionState().newHadoopConf()); + + DeleteReachableFilesSparkAction(SparkSession spark, String metadataFileLocation) { + super(spark); + this.metadataFileLocation = metadataFileLocation; + } + + @Override + protected DeleteReachableFilesSparkAction self() { + return this; + } + + @Override + public DeleteReachableFilesSparkAction io(FileIO fileIO) { + this.io = fileIO; + return this; + } + + @Override + public DeleteReachableFilesSparkAction deleteWith(Consumer newDeleteFunc) { + this.deleteFunc = newDeleteFunc; + return this; + } + + @Override + public DeleteReachableFilesSparkAction executeDeleteWith(ExecutorService executorService) { + this.deleteExecutorService = executorService; + return this; + } + + @Override + public Result execute() { + Preconditions.checkArgument(io != null, "File IO cannot be null"); + String jobDesc = String.format("Deleting files reachable from %s", metadataFileLocation); + JobGroupInfo info = newJobGroupInfo("DELETE-REACHABLE-FILES", jobDesc); + return withJobGroupInfo(info, this::doExecute); + } + + private Result doExecute() { + TableMetadata metadata = TableMetadataParser.read(io, metadataFileLocation); + + ValidationException.check( + PropertyUtil.propertyAsBoolean(metadata.properties(), GC_ENABLED, GC_ENABLED_DEFAULT), + "Cannot delete files: GC is disabled (deleting files may corrupt other tables)"); + + Dataset reachableFileDS = reachableFileDS(metadata); + + if (streamResults()) { + return deleteFiles(reachableFileDS.toLocalIterator()); + } else { + return deleteFiles(reachableFileDS.collectAsList().iterator()); + } + } + + private boolean streamResults() { + return PropertyUtil.propertyAsBoolean(options(), STREAM_RESULTS, STREAM_RESULTS_DEFAULT); + } + + private Dataset reachableFileDS(TableMetadata metadata) { 
+ Table staticTable = newStaticTable(metadata, io); + return contentFileDS(staticTable) + .union(manifestDS(staticTable)) + .union(manifestListDS(staticTable)) + .union(allReachableOtherMetadataFileDS(staticTable)) + .distinct(); + } + + private DeleteReachableFiles.Result deleteFiles(Iterator files) { + DeleteSummary summary; + if (deleteFunc == null && io instanceof SupportsBulkOperations) { + summary = deleteFiles((SupportsBulkOperations) io, files); + } else { + + if (deleteFunc == null) { + LOG.info( + "Table IO {} does not support bulk operations. Using non-bulk deletes.", + io.getClass().getName()); + summary = deleteFiles(deleteExecutorService, io::deleteFile, files); + } else { + LOG.info("Custom delete function provided. Using non-bulk deletes"); + summary = deleteFiles(deleteExecutorService, deleteFunc, files); + } + } + + LOG.info("Deleted {} total files", summary.totalFilesCount()); + + return ImmutableDeleteReachableFiles.Result.builder() + .deletedDataFilesCount(summary.dataFilesCount()) + .deletedPositionDeleteFilesCount(summary.positionDeleteFilesCount()) + .deletedEqualityDeleteFilesCount(summary.equalityDeleteFilesCount()) + .deletedManifestsCount(summary.manifestsCount()) + .deletedManifestListsCount(summary.manifestListsCount()) + .deletedOtherFilesCount(summary.otherFilesCount()) + .build(); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/ExpireSnapshotsSparkAction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/ExpireSnapshotsSparkAction.java new file mode 100644 index 000000000000..2468497e42d0 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/ExpireSnapshotsSparkAction.java @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.TableProperties.GC_ENABLED; +import static org.apache.iceberg.TableProperties.GC_ENABLED_DEFAULT; + +import java.util.Iterator; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.actions.ExpireSnapshots; +import org.apache.iceberg.actions.ImmutableExpireSnapshots; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.io.SupportsBulkOperations; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.SparkSession; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An action that performs the same operation as {@link org.apache.iceberg.ExpireSnapshots} but uses + * Spark to determine the delta in files between the pre and post-expiration table metadata. All of + * the same restrictions of {@link org.apache.iceberg.ExpireSnapshots} also apply to this action. + * + *
<p>
This action first leverages {@link org.apache.iceberg.ExpireSnapshots} to expire snapshots and + * then uses metadata tables to find files that can be safely deleted. This is done by anti-joining + * two Datasets that contain all manifest and content files before and after the expiration. The + * snapshot expiration will be fully committed before any deletes are issued. + * + *
<p>
This operation performs a shuffle so the parallelism can be controlled through + * 'spark.sql.shuffle.partitions'. + * + *
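+ * <p>A minimal usage sketch through {@code SparkActions}, assuming {@code spark}, {@code table}
+ * and {@code tsMillis} are defined elsewhere:
+ *
+ * <pre>{@code
+ * ExpireSnapshots.Result result =
+ *     SparkActions.get(spark)
+ *         .expireSnapshots(table)
+ *         .expireOlderThan(tsMillis)
+ *         .retainLast(10)
+ *         .execute();
+ * }</pre>
+ *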
<p>
Deletes are still performed locally after retrieving the results from the Spark executors. + */ +@SuppressWarnings("UnnecessaryAnonymousClass") +public class ExpireSnapshotsSparkAction extends BaseSparkAction + implements ExpireSnapshots { + + public static final String STREAM_RESULTS = "stream-results"; + public static final boolean STREAM_RESULTS_DEFAULT = false; + + private static final Logger LOG = LoggerFactory.getLogger(ExpireSnapshotsSparkAction.class); + + private final Table table; + private final TableOperations ops; + + private final Set expiredSnapshotIds = Sets.newHashSet(); + private Long expireOlderThanValue = null; + private Integer retainLastValue = null; + private Consumer deleteFunc = null; + private ExecutorService deleteExecutorService = null; + private Dataset expiredFileDS = null; + + ExpireSnapshotsSparkAction(SparkSession spark, Table table) { + super(spark); + this.table = table; + this.ops = ((HasTableOperations) table).operations(); + + ValidationException.check( + PropertyUtil.propertyAsBoolean(table.properties(), GC_ENABLED, GC_ENABLED_DEFAULT), + "Cannot expire snapshots: GC is disabled (deleting files may corrupt other tables)"); + } + + @Override + protected ExpireSnapshotsSparkAction self() { + return this; + } + + @Override + public ExpireSnapshotsSparkAction executeDeleteWith(ExecutorService executorService) { + this.deleteExecutorService = executorService; + return this; + } + + @Override + public ExpireSnapshotsSparkAction expireSnapshotId(long snapshotId) { + expiredSnapshotIds.add(snapshotId); + return this; + } + + @Override + public ExpireSnapshotsSparkAction expireOlderThan(long timestampMillis) { + this.expireOlderThanValue = timestampMillis; + return this; + } + + @Override + public ExpireSnapshotsSparkAction retainLast(int numSnapshots) { + Preconditions.checkArgument( + 1 <= numSnapshots, + "Number of snapshots to retain must be at least 1, cannot be: %s", + numSnapshots); + this.retainLastValue = numSnapshots; + return this; + } + + @Override + public ExpireSnapshotsSparkAction deleteWith(Consumer newDeleteFunc) { + this.deleteFunc = newDeleteFunc; + return this; + } + + /** + * Expires snapshots and commits the changes to the table, returning a Dataset of files to delete. + * + *
<p>
This does not delete data files. To delete data files, run {@link #execute()}. + * + *
<p>
This may be called before or after {@link #execute()} to return the expired files. + * + * @return a Dataset of files that are no longer referenced by the table + */ + public Dataset expireFiles() { + if (expiredFileDS == null) { + // fetch metadata before expiration + TableMetadata originalMetadata = ops.current(); + + // perform expiration + org.apache.iceberg.ExpireSnapshots expireSnapshots = table.expireSnapshots(); + + for (long id : expiredSnapshotIds) { + expireSnapshots = expireSnapshots.expireSnapshotId(id); + } + + if (expireOlderThanValue != null) { + expireSnapshots = expireSnapshots.expireOlderThan(expireOlderThanValue); + } + + if (retainLastValue != null) { + expireSnapshots = expireSnapshots.retainLast(retainLastValue); + } + + expireSnapshots.cleanExpiredFiles(false).commit(); + + // fetch valid files after expiration + TableMetadata updatedMetadata = ops.refresh(); + Dataset validFileDS = fileDS(updatedMetadata); + + // fetch files referenced by expired snapshots + Set deletedSnapshotIds = findExpiredSnapshotIds(originalMetadata, updatedMetadata); + Dataset deleteCandidateFileDS = fileDS(originalMetadata, deletedSnapshotIds); + + // determine expired files + this.expiredFileDS = deleteCandidateFileDS.except(validFileDS); + } + + return expiredFileDS; + } + + @Override + public ExpireSnapshots.Result execute() { + JobGroupInfo info = newJobGroupInfo("EXPIRE-SNAPSHOTS", jobDesc()); + return withJobGroupInfo(info, this::doExecute); + } + + private String jobDesc() { + List options = Lists.newArrayList(); + + if (expireOlderThanValue != null) { + options.add("older_than=" + expireOlderThanValue); + } + + if (retainLastValue != null) { + options.add("retain_last=" + retainLastValue); + } + + if (!expiredSnapshotIds.isEmpty()) { + Long first = expiredSnapshotIds.stream().findFirst().get(); + if (expiredSnapshotIds.size() > 1) { + options.add( + String.format("snapshot_ids: %s (%s more...)", first, expiredSnapshotIds.size() - 1)); + } else { + options.add(String.format("snapshot_id: %s", first)); + } + } + + return String.format("Expiring snapshots (%s) in %s", COMMA_JOINER.join(options), table.name()); + } + + private ExpireSnapshots.Result doExecute() { + if (streamResults()) { + return deleteFiles(expireFiles().toLocalIterator()); + } else { + return deleteFiles(expireFiles().collectAsList().iterator()); + } + } + + private boolean streamResults() { + return PropertyUtil.propertyAsBoolean(options(), STREAM_RESULTS, STREAM_RESULTS_DEFAULT); + } + + private Dataset fileDS(TableMetadata metadata) { + return fileDS(metadata, null); + } + + private Dataset fileDS(TableMetadata metadata, Set snapshotIds) { + Table staticTable = newStaticTable(metadata, table.io()); + return contentFileDS(staticTable, snapshotIds) + .union(manifestDS(staticTable, snapshotIds)) + .union(manifestListDS(staticTable, snapshotIds)) + .union(statisticsFileDS(staticTable, snapshotIds)); + } + + private Set findExpiredSnapshotIds( + TableMetadata originalMetadata, TableMetadata updatedMetadata) { + Set retainedSnapshots = + updatedMetadata.snapshots().stream().map(Snapshot::snapshotId).collect(Collectors.toSet()); + return originalMetadata.snapshots().stream() + .map(Snapshot::snapshotId) + .filter(id -> !retainedSnapshots.contains(id)) + .collect(Collectors.toSet()); + } + + private ExpireSnapshots.Result deleteFiles(Iterator files) { + DeleteSummary summary; + if (deleteFunc == null && table.io() instanceof SupportsBulkOperations) { + summary = deleteFiles((SupportsBulkOperations) table.io(), files); + } 
else { + + if (deleteFunc == null) { + LOG.info( + "Table IO {} does not support bulk operations. Using non-bulk deletes.", + table.io().getClass().getName()); + summary = deleteFiles(deleteExecutorService, table.io()::deleteFile, files); + } else { + LOG.info("Custom delete function provided. Using non-bulk deletes"); + summary = deleteFiles(deleteExecutorService, deleteFunc, files); + } + } + + LOG.info("Deleted {} total files", summary.totalFilesCount()); + + return ImmutableExpireSnapshots.Result.builder() + .deletedDataFilesCount(summary.dataFilesCount()) + .deletedPositionDeleteFilesCount(summary.positionDeleteFilesCount()) + .deletedEqualityDeleteFilesCount(summary.equalityDeleteFilesCount()) + .deletedManifestsCount(summary.manifestsCount()) + .deletedManifestListsCount(summary.manifestListsCount()) + .deletedStatisticsFilesCount(summary.statisticsFilesCount()) + .build(); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/FileInfo.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/FileInfo.java new file mode 100644 index 000000000000..51ff7c80fd18 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/FileInfo.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; + +public class FileInfo { + public static final Encoder ENCODER = Encoders.bean(FileInfo.class); + + private String path; + private String type; + + public FileInfo(String path, String type) { + this.path = path; + this.type = type; + } + + public FileInfo() {} + + public String getPath() { + return path; + } + + public void setPath(String path) { + this.path = path; + } + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/ManifestFileBean.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/ManifestFileBean.java new file mode 100644 index 000000000000..45647070e602 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/ManifestFileBean.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.nio.ByteBuffer; +import java.util.List; +import org.apache.iceberg.ManifestContent; +import org.apache.iceberg.ManifestFile; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; + +public class ManifestFileBean implements ManifestFile { + public static final Encoder ENCODER = Encoders.bean(ManifestFileBean.class); + + private String path = null; + private Long length = null; + private Integer partitionSpecId = null; + private Long addedSnapshotId = null; + private Integer content = null; + + public String getPath() { + return path; + } + + public void setPath(String path) { + this.path = path; + } + + public Long getLength() { + return length; + } + + public void setLength(Long length) { + this.length = length; + } + + public Integer getPartitionSpecId() { + return partitionSpecId; + } + + public void setPartitionSpecId(Integer partitionSpecId) { + this.partitionSpecId = partitionSpecId; + } + + public Long getAddedSnapshotId() { + return addedSnapshotId; + } + + public void setAddedSnapshotId(Long addedSnapshotId) { + this.addedSnapshotId = addedSnapshotId; + } + + public Integer getContent() { + return content; + } + + public void setContent(Integer content) { + this.content = content; + } + + @Override + public String path() { + return path; + } + + @Override + public long length() { + return length; + } + + @Override + public int partitionSpecId() { + return partitionSpecId; + } + + @Override + public ManifestContent content() { + return ManifestContent.fromId(content); + } + + @Override + public long sequenceNumber() { + return 0; + } + + @Override + public long minSequenceNumber() { + return 0; + } + + @Override + public Long snapshotId() { + return addedSnapshotId; + } + + @Override + public Integer addedFilesCount() { + return null; + } + + @Override + public Long addedRowsCount() { + return null; + } + + @Override + public Integer existingFilesCount() { + return null; + } + + @Override + public Long existingRowsCount() { + return null; + } + + @Override + public Integer deletedFilesCount() { + return null; + } + + @Override + public Long deletedRowsCount() { + return null; + } + + @Override + public List partitions() { + return null; + } + + @Override + public ByteBuffer keyMetadata() { + return null; + } + + @Override + public ManifestFile copy() { + throw new UnsupportedOperationException("Cannot copy"); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/MigrateTableSparkAction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/MigrateTableSparkAction.java new file mode 100644 index 000000000000..fe8acf0157d3 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/MigrateTableSparkAction.java @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.util.Map; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.ImmutableMigrateTable; +import org.apache.iceberg.actions.MigrateTable; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.exceptions.NoSuchTableException; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.source.StagedSparkTable; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.StagingTableCatalog; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.Some; +import scala.collection.JavaConverters; + +/** + * Takes a Spark table in the source catalog and attempts to transform it into an Iceberg table in + * the same location with the same identifier. Once complete the identifier which previously + * referred to a non-Iceberg table will refer to the newly migrated Iceberg table. 
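+ *
+ * <p>A minimal sketch of running this action through {@code SparkActions}; the
+ * {@code "db.sample"} identifier and the {@code spark} session are illustrative placeholders:
+ *
+ * <pre>{@code
+ * MigrateTable.Result result =
+ *     SparkActions.get(spark)
+ *         .migrateTable("db.sample")
+ *         .tableProperty("format-version", "2")
+ *         .execute();
+ * }</pre>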
+ */ +public class MigrateTableSparkAction extends BaseTableCreationSparkAction + implements MigrateTable { + + private static final Logger LOG = LoggerFactory.getLogger(MigrateTableSparkAction.class); + private static final String BACKUP_SUFFIX = "_BACKUP_"; + + private final StagingTableCatalog destCatalog; + private final Identifier destTableIdent; + private final Identifier backupIdent; + + private boolean dropBackup = false; + + MigrateTableSparkAction( + SparkSession spark, CatalogPlugin sourceCatalog, Identifier sourceTableIdent) { + super(spark, sourceCatalog, sourceTableIdent); + this.destCatalog = checkDestinationCatalog(sourceCatalog); + this.destTableIdent = sourceTableIdent; + String backupName = sourceTableIdent.name() + BACKUP_SUFFIX; + this.backupIdent = Identifier.of(sourceTableIdent.namespace(), backupName); + } + + @Override + protected MigrateTableSparkAction self() { + return this; + } + + @Override + protected StagingTableCatalog destCatalog() { + return destCatalog; + } + + @Override + protected Identifier destTableIdent() { + return destTableIdent; + } + + @Override + public MigrateTableSparkAction tableProperties(Map properties) { + setProperties(properties); + return this; + } + + @Override + public MigrateTableSparkAction tableProperty(String property, String value) { + setProperty(property, value); + return this; + } + + @Override + public MigrateTableSparkAction dropBackup() { + this.dropBackup = true; + return this; + } + + @Override + public MigrateTable.Result execute() { + String desc = String.format("Migrating table %s", destTableIdent().toString()); + JobGroupInfo info = newJobGroupInfo("MIGRATE-TABLE", desc); + return withJobGroupInfo(info, this::doExecute); + } + + private MigrateTable.Result doExecute() { + LOG.info("Starting the migration of {} to Iceberg", sourceTableIdent()); + + // move the source table to a new name, halting all modifications and allowing us to stage + // the creation of a new Iceberg table in its place + renameAndBackupSourceTable(); + + StagedSparkTable stagedTable = null; + Table icebergTable; + boolean threw = true; + try { + LOG.info("Staging a new Iceberg table {}", destTableIdent()); + stagedTable = stageDestTable(); + icebergTable = stagedTable.table(); + + LOG.info("Ensuring {} has a valid name mapping", destTableIdent()); + ensureNameMappingPresent(icebergTable); + + Some backupNamespace = Some.apply(backupIdent.namespace()[0]); + TableIdentifier v1BackupIdent = new TableIdentifier(backupIdent.name(), backupNamespace); + String stagingLocation = getMetadataLocation(icebergTable); + LOG.info("Generating Iceberg metadata for {} in {}", destTableIdent(), stagingLocation); + SparkTableUtil.importSparkTable(spark(), v1BackupIdent, icebergTable, stagingLocation); + + LOG.info("Committing staged changes to {}", destTableIdent()); + stagedTable.commitStagedChanges(); + threw = false; + } finally { + if (threw) { + LOG.error( + "Failed to perform the migration, aborting table creation and restoring the original table"); + + restoreSourceTable(); + + if (stagedTable != null) { + try { + stagedTable.abortStagedChanges(); + } catch (Exception abortException) { + LOG.error("Cannot abort staged changes", abortException); + } + } + } else if (dropBackup) { + dropBackupTable(); + } + } + + Snapshot snapshot = icebergTable.currentSnapshot(); + long migratedDataFilesCount = + Long.parseLong(snapshot.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP)); + LOG.info( + "Successfully loaded Iceberg metadata for {} files to {}", + 
migratedDataFilesCount, + destTableIdent()); + return ImmutableMigrateTable.Result.builder() + .migratedDataFilesCount(migratedDataFilesCount) + .build(); + } + + @Override + protected Map destTableProps() { + Map properties = Maps.newHashMap(); + + // copy over relevant source table props + properties.putAll(JavaConverters.mapAsJavaMapConverter(v1SourceTable().properties()).asJava()); + EXCLUDED_PROPERTIES.forEach(properties::remove); + + // set default and user-provided props + properties.put(TableCatalog.PROP_PROVIDER, "iceberg"); + properties.putAll(additionalProperties()); + + // make sure we mark this table as migrated + properties.put("migrated", "true"); + + // inherit the source table location + properties.putIfAbsent(LOCATION, sourceTableLocation()); + + return properties; + } + + @Override + protected TableCatalog checkSourceCatalog(CatalogPlugin catalog) { + // currently the import code relies on being able to look up the table in the session catalog + Preconditions.checkArgument( + catalog instanceof SparkSessionCatalog, + "Cannot migrate a table from a non-Iceberg Spark Session Catalog. Found %s of class %s as the source catalog.", + catalog.name(), + catalog.getClass().getName()); + + return (TableCatalog) catalog; + } + + private void renameAndBackupSourceTable() { + try { + LOG.info("Renaming {} as {} for backup", sourceTableIdent(), backupIdent); + destCatalog().renameTable(sourceTableIdent(), backupIdent); + + } catch (org.apache.spark.sql.catalyst.analysis.NoSuchTableException e) { + throw new NoSuchTableException("Cannot find source table %s", sourceTableIdent()); + + } catch (org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException e) { + throw new AlreadyExistsException( + "Cannot rename %s as %s for backup. The backup table already exists.", + sourceTableIdent(), backupIdent); + } + } + + private void restoreSourceTable() { + try { + LOG.info("Restoring {} from {}", sourceTableIdent(), backupIdent); + destCatalog().renameTable(backupIdent, sourceTableIdent()); + + } catch (org.apache.spark.sql.catalyst.analysis.NoSuchTableException e) { + LOG.error( + "Cannot restore the original table, the backup table {} cannot be found", backupIdent, e); + + } catch (org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException e) { + LOG.error( + "Cannot restore the original table, a table with the original name exists. " + + "Use the backup table {} to restore the original table manually.", + backupIdent, + e); + } + } + + private void dropBackupTable() { + try { + destCatalog().dropTable(backupIdent); + } catch (Exception e) { + LOG.error( + "Cannot drop the backup table {}, after the migration is completed.", backupIdent, e); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/RewriteDataFilesSparkAction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/RewriteDataFilesSparkAction.java new file mode 100644 index 000000000000..5f95ef3ed4c9 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/RewriteDataFilesSparkAction.java @@ -0,0 +1,517 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.io.IOException; +import java.math.RoundingMode; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.RewriteJobOrder; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.FileRewriter; +import org.apache.iceberg.actions.ImmutableRewriteDataFiles; +import org.apache.iceberg.actions.RewriteDataFiles; +import org.apache.iceberg.actions.RewriteDataFilesCommitManager; +import org.apache.iceberg.actions.RewriteFileGroup; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.exceptions.CommitFailedException; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Queues; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.relocated.com.google.common.math.IntMath; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.iceberg.types.Types.StructType; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.StructLikeMap; +import org.apache.iceberg.util.Tasks; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.SQLConf; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class RewriteDataFilesSparkAction + extends BaseSnapshotUpdateSparkAction implements RewriteDataFiles { + + private static final Logger LOG = LoggerFactory.getLogger(RewriteDataFilesSparkAction.class); + private static final Set VALID_OPTIONS = + ImmutableSet.of( + MAX_CONCURRENT_FILE_GROUP_REWRITES, + MAX_FILE_GROUP_SIZE_BYTES, + PARTIAL_PROGRESS_ENABLED, + PARTIAL_PROGRESS_MAX_COMMITS, + TARGET_FILE_SIZE_BYTES, + USE_STARTING_SEQUENCE_NUMBER, + REWRITE_JOB_ORDER); + + private final Table table; + + private Expression filter = Expressions.alwaysTrue(); + 
private int maxConcurrentFileGroupRewrites; + private int maxCommits; + private boolean partialProgressEnabled; + private boolean useStartingSequenceNumber; + private RewriteJobOrder rewriteJobOrder; + private FileRewriter rewriter = null; + + RewriteDataFilesSparkAction(SparkSession spark, Table table) { + super(spark.cloneSession()); + // Disable Adaptive Query Execution as this may change the output partitioning of our write + spark().conf().set(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), false); + this.table = table; + } + + @Override + protected RewriteDataFilesSparkAction self() { + return this; + } + + @Override + public RewriteDataFilesSparkAction binPack() { + Preconditions.checkArgument( + rewriter == null, "Must use only one rewriter type (bin-pack, sort, zorder)"); + this.rewriter = new SparkBinPackDataRewriter(spark(), table); + return this; + } + + @Override + public RewriteDataFilesSparkAction sort(SortOrder sortOrder) { + Preconditions.checkArgument( + rewriter == null, "Must use only one rewriter type (bin-pack, sort, zorder)"); + this.rewriter = new SparkSortDataRewriter(spark(), table, sortOrder); + return this; + } + + @Override + public RewriteDataFilesSparkAction sort() { + Preconditions.checkArgument( + rewriter == null, "Must use only one rewriter type (bin-pack, sort, zorder)"); + this.rewriter = new SparkSortDataRewriter(spark(), table); + return this; + } + + @Override + public RewriteDataFilesSparkAction zOrder(String... columnNames) { + Preconditions.checkArgument( + rewriter == null, "Must use only one rewriter type (bin-pack, sort, zorder)"); + this.rewriter = new SparkZOrderDataRewriter(spark(), table, Arrays.asList(columnNames)); + return this; + } + + @Override + public RewriteDataFilesSparkAction filter(Expression expression) { + filter = Expressions.and(filter, expression); + return this; + } + + @Override + public RewriteDataFiles.Result execute() { + if (table.currentSnapshot() == null) { + return ImmutableRewriteDataFiles.Result.builder().rewriteResults(ImmutableList.of()).build(); + } + + long startingSnapshotId = table.currentSnapshot().snapshotId(); + + // Default to BinPack if no strategy selected + if (this.rewriter == null) { + this.rewriter = new SparkBinPackDataRewriter(spark(), table); + } + + validateAndInitOptions(); + + Map>> fileGroupsByPartition = + planFileGroups(startingSnapshotId); + RewriteExecutionContext ctx = new RewriteExecutionContext(fileGroupsByPartition); + + if (ctx.totalGroupCount() == 0) { + LOG.info("Nothing found to rewrite in {}", table.name()); + return ImmutableRewriteDataFiles.Result.builder().rewriteResults(ImmutableList.of()).build(); + } + + Stream groupStream = toGroupStream(ctx, fileGroupsByPartition); + + RewriteDataFilesCommitManager commitManager = commitManager(startingSnapshotId); + if (partialProgressEnabled) { + return doExecuteWithPartialProgress(ctx, groupStream, commitManager); + } else { + return doExecute(ctx, groupStream, commitManager); + } + } + + Map>> planFileGroups(long startingSnapshotId) { + CloseableIterable fileScanTasks = + table + .newScan() + .useSnapshot(startingSnapshotId) + .filter(filter) + .ignoreResiduals() + .planFiles(); + + try { + StructType partitionType = table.spec().partitionType(); + StructLikeMap> filesByPartition = StructLikeMap.create(partitionType); + StructLike emptyStruct = GenericRecord.create(partitionType); + + fileScanTasks.forEach( + task -> { + // If a task uses an incompatible partition spec the data inside could contain values + // which + // belong to 
multiple partitions in the current spec. Treating all such files as + // un-partitioned and + // grouping them together helps to minimize new files made. + StructLike taskPartition = + task.file().specId() == table.spec().specId() + ? task.file().partition() + : emptyStruct; + + List files = filesByPartition.get(taskPartition); + if (files == null) { + files = Lists.newArrayList(); + } + + files.add(task); + filesByPartition.put(taskPartition, files); + }); + + StructLikeMap>> fileGroupsByPartition = + StructLikeMap.create(partitionType); + + filesByPartition.forEach( + (partition, tasks) -> { + Iterable> plannedFileGroups = rewriter.planFileGroups(tasks); + List> fileGroups = ImmutableList.copyOf(plannedFileGroups); + if (fileGroups.size() > 0) { + fileGroupsByPartition.put(partition, fileGroups); + } + }); + + return fileGroupsByPartition; + } finally { + try { + fileScanTasks.close(); + } catch (IOException io) { + LOG.error("Cannot properly close file iterable while planning for rewrite", io); + } + } + } + + @VisibleForTesting + RewriteFileGroup rewriteFiles(RewriteExecutionContext ctx, RewriteFileGroup fileGroup) { + String desc = jobDesc(fileGroup, ctx); + Set addedFiles = + withJobGroupInfo( + newJobGroupInfo("REWRITE-DATA-FILES", desc), + () -> rewriter.rewrite(fileGroup.fileScans())); + + fileGroup.setOutputFiles(addedFiles); + LOG.info("Rewrite Files Ready to be Committed - {}", desc); + return fileGroup; + } + + private ExecutorService rewriteService() { + return MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) + Executors.newFixedThreadPool( + maxConcurrentFileGroupRewrites, + new ThreadFactoryBuilder().setNameFormat("Rewrite-Service-%d").build())); + } + + @VisibleForTesting + RewriteDataFilesCommitManager commitManager(long startingSnapshotId) { + return new RewriteDataFilesCommitManager(table, startingSnapshotId, useStartingSequenceNumber); + } + + private Result doExecute( + RewriteExecutionContext ctx, + Stream groupStream, + RewriteDataFilesCommitManager commitManager) { + ExecutorService rewriteService = rewriteService(); + + ConcurrentLinkedQueue rewrittenGroups = Queues.newConcurrentLinkedQueue(); + + Tasks.Builder rewriteTaskBuilder = + Tasks.foreach(groupStream) + .executeWith(rewriteService) + .stopOnFailure() + .noRetry() + .onFailure( + (fileGroup, exception) -> { + LOG.warn( + "Failure during rewrite process for group {}", fileGroup.info(), exception); + }); + + try { + rewriteTaskBuilder.run( + fileGroup -> { + rewrittenGroups.add(rewriteFiles(ctx, fileGroup)); + }); + } catch (Exception e) { + // At least one rewrite group failed, clean up all completed rewrites + LOG.error( + "Cannot complete rewrite, {} is not enabled and one of the file set groups failed to " + + "be rewritten. This error occurred during the writing of new files, not during the commit process. This " + + "indicates something is wrong that doesn't involve conflicts with other Iceberg operations. Enabling " + + "{} may help in this case but the root cause should be investigated. 
Cleaning up {} groups which finished " + + "being written.", + PARTIAL_PROGRESS_ENABLED, + PARTIAL_PROGRESS_ENABLED, + rewrittenGroups.size(), + e); + + Tasks.foreach(rewrittenGroups) + .suppressFailureWhenFinished() + .run(group -> commitManager.abortFileGroup(group)); + throw e; + } finally { + rewriteService.shutdown(); + } + + try { + commitManager.commitOrClean(Sets.newHashSet(rewrittenGroups)); + } catch (ValidationException | CommitFailedException e) { + String errorMessage = + String.format( + "Cannot commit rewrite because of a ValidationException or CommitFailedException. This usually means that " + + "this rewrite has conflicted with another concurrent Iceberg operation. To reduce the likelihood of " + + "conflicts, set %s which will break up the rewrite into multiple smaller commits controlled by %s. " + + "Separate smaller rewrite commits can succeed independently while any commits that conflict with " + + "another Iceberg operation will be ignored. This mode will create additional snapshots in the table " + + "history, one for each commit.", + PARTIAL_PROGRESS_ENABLED, PARTIAL_PROGRESS_MAX_COMMITS); + throw new RuntimeException(errorMessage, e); + } + + List rewriteResults = + rewrittenGroups.stream().map(RewriteFileGroup::asResult).collect(Collectors.toList()); + return ImmutableRewriteDataFiles.Result.builder().rewriteResults(rewriteResults).build(); + } + + private Result doExecuteWithPartialProgress( + RewriteExecutionContext ctx, + Stream groupStream, + RewriteDataFilesCommitManager commitManager) { + ExecutorService rewriteService = rewriteService(); + + // Start Commit Service + int groupsPerCommit = IntMath.divide(ctx.totalGroupCount(), maxCommits, RoundingMode.CEILING); + RewriteDataFilesCommitManager.CommitService commitService = + commitManager.service(groupsPerCommit); + commitService.start(); + + // Start rewrite tasks + Tasks.foreach(groupStream) + .suppressFailureWhenFinished() + .executeWith(rewriteService) + .noRetry() + .onFailure( + (fileGroup, exception) -> + LOG.error("Failure during rewrite group {}", fileGroup.info(), exception)) + .run(fileGroup -> commitService.offer(rewriteFiles(ctx, fileGroup))); + rewriteService.shutdown(); + + // Stop Commit service + commitService.close(); + List commitResults = commitService.results(); + if (commitResults.size() == 0) { + LOG.error( + "{} is true but no rewrite commits succeeded. Check the logs to determine why the individual " + + "commits failed. 
If this is persistent it may help to increase {} which will break the rewrite operation " + + "into smaller commits.", + PARTIAL_PROGRESS_ENABLED, + PARTIAL_PROGRESS_MAX_COMMITS); + } + + List rewriteResults = + commitResults.stream().map(RewriteFileGroup::asResult).collect(Collectors.toList()); + return ImmutableRewriteDataFiles.Result.builder().rewriteResults(rewriteResults).build(); + } + + Stream toGroupStream( + RewriteExecutionContext ctx, + Map>> fileGroupsByPartition) { + Stream rewriteFileGroupStream = + fileGroupsByPartition.entrySet().stream() + .flatMap( + e -> { + StructLike partition = e.getKey(); + List> fileGroups = e.getValue(); + return fileGroups.stream() + .map( + tasks -> { + int globalIndex = ctx.currentGlobalIndex(); + int partitionIndex = ctx.currentPartitionIndex(partition); + FileGroupInfo info = + ImmutableRewriteDataFiles.FileGroupInfo.builder() + .globalIndex(globalIndex) + .partitionIndex(partitionIndex) + .partition(partition) + .build(); + return new RewriteFileGroup(info, tasks); + }); + }); + + return rewriteFileGroupStream.sorted(rewriteGroupComparator()); + } + + private Comparator rewriteGroupComparator() { + switch (rewriteJobOrder) { + case BYTES_ASC: + return Comparator.comparing(RewriteFileGroup::sizeInBytes); + case BYTES_DESC: + return Comparator.comparing(RewriteFileGroup::sizeInBytes, Comparator.reverseOrder()); + case FILES_ASC: + return Comparator.comparing(RewriteFileGroup::numFiles); + case FILES_DESC: + return Comparator.comparing(RewriteFileGroup::numFiles, Comparator.reverseOrder()); + default: + return (fileGroupOne, fileGroupTwo) -> 0; + } + } + + void validateAndInitOptions() { + Set validOptions = Sets.newHashSet(rewriter.validOptions()); + validOptions.addAll(VALID_OPTIONS); + + Set invalidKeys = Sets.newHashSet(options().keySet()); + invalidKeys.removeAll(validOptions); + + Preconditions.checkArgument( + invalidKeys.isEmpty(), + "Cannot use options %s, they are not supported by the action or the rewriter %s", + invalidKeys, + rewriter.description()); + + rewriter.init(options()); + + maxConcurrentFileGroupRewrites = + PropertyUtil.propertyAsInt( + options(), + MAX_CONCURRENT_FILE_GROUP_REWRITES, + MAX_CONCURRENT_FILE_GROUP_REWRITES_DEFAULT); + + maxCommits = + PropertyUtil.propertyAsInt( + options(), PARTIAL_PROGRESS_MAX_COMMITS, PARTIAL_PROGRESS_MAX_COMMITS_DEFAULT); + + partialProgressEnabled = + PropertyUtil.propertyAsBoolean( + options(), PARTIAL_PROGRESS_ENABLED, PARTIAL_PROGRESS_ENABLED_DEFAULT); + + useStartingSequenceNumber = + PropertyUtil.propertyAsBoolean( + options(), USE_STARTING_SEQUENCE_NUMBER, USE_STARTING_SEQUENCE_NUMBER_DEFAULT); + + rewriteJobOrder = + RewriteJobOrder.fromName( + PropertyUtil.propertyAsString(options(), REWRITE_JOB_ORDER, REWRITE_JOB_ORDER_DEFAULT)); + + Preconditions.checkArgument( + maxConcurrentFileGroupRewrites >= 1, + "Cannot set %s to %s, the value must be positive.", + MAX_CONCURRENT_FILE_GROUP_REWRITES, + maxConcurrentFileGroupRewrites); + + Preconditions.checkArgument( + !partialProgressEnabled || maxCommits > 0, + "Cannot set %s to %s, the value must be positive when %s is true", + PARTIAL_PROGRESS_MAX_COMMITS, + maxCommits, + PARTIAL_PROGRESS_ENABLED); + } + + private String jobDesc(RewriteFileGroup group, RewriteExecutionContext ctx) { + StructLike partition = group.info().partition(); + if (partition.size() > 0) { + return String.format( + "Rewriting %d files (%s, file group %d/%d, %s (%d/%d)) in %s", + group.rewrittenFiles().size(), + rewriter.description(), + 
group.info().globalIndex(), + ctx.totalGroupCount(), + partition, + group.info().partitionIndex(), + ctx.groupsInPartition(partition), + table.name()); + } else { + return String.format( + "Rewriting %d files (%s, file group %d/%d) in %s", + group.rewrittenFiles().size(), + rewriter.description(), + group.info().globalIndex(), + ctx.totalGroupCount(), + table.name()); + } + } + + @VisibleForTesting + static class RewriteExecutionContext { + private final Map numGroupsByPartition; + private final int totalGroupCount; + private final Map partitionIndexMap; + private final AtomicInteger groupIndex; + + RewriteExecutionContext(Map>> fileGroupsByPartition) { + this.numGroupsByPartition = + fileGroupsByPartition.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().size())); + this.totalGroupCount = numGroupsByPartition.values().stream().reduce(Integer::sum).orElse(0); + this.partitionIndexMap = Maps.newConcurrentMap(); + this.groupIndex = new AtomicInteger(1); + } + + public int currentGlobalIndex() { + return groupIndex.getAndIncrement(); + } + + public int currentPartitionIndex(StructLike partition) { + return partitionIndexMap.merge(partition, 1, Integer::sum); + } + + public int groupsInPartition(StructLike partition) { + return numGroupsByPartition.get(partition); + } + + public int totalGroupCount() { + return totalGroupCount; + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/RewriteManifestsSparkAction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/RewriteManifestsSparkAction.java new file mode 100644 index 000000000000..06a5c8c5720f --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/RewriteManifestsSparkAction.java @@ -0,0 +1,463 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
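To connect the pieces of the rewrite action above, a hedged usage sketch (not taken from this patch); the filter column and option values are illustrative.

import org.apache.iceberg.Table;
import org.apache.iceberg.actions.RewriteDataFiles;
import org.apache.iceberg.expressions.Expressions;
import org.apache.iceberg.spark.actions.SparkActions;
import org.apache.spark.sql.SparkSession;

public class RewriteDataFilesExample {
  public static RewriteDataFiles.Result compact(SparkSession spark, Table table) {
    return SparkActions.get(spark)
        .rewriteDataFiles(table)
        .binPack()                                                   // optional: bin-pack is also the default rewriter
        .filter(Expressions.equal("event_level", "ERROR"))           // illustrative planning filter
        .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true")   // commit file groups in batches as they finish
        .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "10")
        .execute();
  }
}

The returned Result carries per-file-group outcomes via rewriteResults().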
+ */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.MetadataTableType.ENTRIES; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.UUID; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.ManifestWriter; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.SerializableTable; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.ImmutableRewriteManifests; +import org.apache.iceberg.actions.RewriteManifests; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.spark.SparkDataFile; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.Tasks; +import org.apache.iceberg.util.ThreadPools; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.api.java.function.MapPartitionsFunction; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An action that rewrites manifests in a distributed manner and co-locates metadata for partitions. + * + *
<p>
By default, this action rewrites all manifests for the current partition spec and writes the + * result to the metadata folder. The behavior can be modified by passing a custom predicate to + * {@link #rewriteIf(Predicate)} and a custom spec id to {@link #specId(int)}. In addition, there is + * a way to configure a custom location for new manifests via {@link #stagingLocation}. + */ +public class RewriteManifestsSparkAction + extends BaseSnapshotUpdateSparkAction implements RewriteManifests { + + public static final String USE_CACHING = "use-caching"; + public static final boolean USE_CACHING_DEFAULT = true; + + private static final Logger LOG = LoggerFactory.getLogger(RewriteManifestsSparkAction.class); + + private final Encoder manifestEncoder; + private final Table table; + private final int formatVersion; + private final long targetManifestSizeBytes; + + private PartitionSpec spec = null; + private Predicate predicate = manifest -> true; + private String stagingLocation = null; + + RewriteManifestsSparkAction(SparkSession spark, Table table) { + super(spark); + this.manifestEncoder = Encoders.javaSerialization(ManifestFile.class); + this.table = table; + this.spec = table.spec(); + this.targetManifestSizeBytes = + PropertyUtil.propertyAsLong( + table.properties(), + TableProperties.MANIFEST_TARGET_SIZE_BYTES, + TableProperties.MANIFEST_TARGET_SIZE_BYTES_DEFAULT); + + // default the staging location to the metadata location + TableOperations ops = ((HasTableOperations) table).operations(); + Path metadataFilePath = new Path(ops.metadataFileLocation("file")); + this.stagingLocation = metadataFilePath.getParent().toString(); + + // use the current table format version for new manifests + this.formatVersion = ops.current().formatVersion(); + } + + @Override + protected RewriteManifestsSparkAction self() { + return this; + } + + @Override + public RewriteManifestsSparkAction specId(int specId) { + Preconditions.checkArgument(table.specs().containsKey(specId), "Invalid spec id %s", specId); + this.spec = table.specs().get(specId); + return this; + } + + @Override + public RewriteManifestsSparkAction rewriteIf(Predicate newPredicate) { + this.predicate = newPredicate; + return this; + } + + @Override + public RewriteManifestsSparkAction stagingLocation(String newStagingLocation) { + this.stagingLocation = newStagingLocation; + return this; + } + + @Override + public RewriteManifests.Result execute() { + String desc = + String.format( + "Rewriting manifests (staging location=%s) of %s", stagingLocation, table.name()); + JobGroupInfo info = newJobGroupInfo("REWRITE-MANIFESTS", desc); + return withJobGroupInfo(info, this::doExecute); + } + + private RewriteManifests.Result doExecute() { + List matchingManifests = findMatchingManifests(); + if (matchingManifests.isEmpty()) { + return ImmutableRewriteManifests.Result.builder() + .addedManifests(ImmutableList.of()) + .rewrittenManifests(ImmutableList.of()) + .build(); + } + + long totalSizeBytes = 0L; + int numEntries = 0; + + for (ManifestFile manifest : matchingManifests) { + ValidationException.check( + hasFileCounts(manifest), "No file counts in manifest: %s", manifest.path()); + + totalSizeBytes += manifest.length(); + numEntries += + manifest.addedFilesCount() + manifest.existingFilesCount() + manifest.deletedFilesCount(); + } + + int targetNumManifests = targetNumManifests(totalSizeBytes); + int targetNumManifestEntries = targetNumManifestEntries(numEntries, targetNumManifests); + + if (targetNumManifests == 1 && matchingManifests.size() == 
1) { + return ImmutableRewriteManifests.Result.builder() + .addedManifests(ImmutableList.of()) + .rewrittenManifests(ImmutableList.of()) + .build(); + } + + Dataset manifestEntryDF = buildManifestEntryDF(matchingManifests); + + List newManifests; + if (spec.fields().size() < 1) { + newManifests = writeManifestsForUnpartitionedTable(manifestEntryDF, targetNumManifests); + } else { + newManifests = + writeManifestsForPartitionedTable( + manifestEntryDF, targetNumManifests, targetNumManifestEntries); + } + + replaceManifests(matchingManifests, newManifests); + + return ImmutableRewriteManifests.Result.builder() + .rewrittenManifests(matchingManifests) + .addedManifests(newManifests) + .build(); + } + + private Dataset buildManifestEntryDF(List manifests) { + Dataset manifestDF = + spark() + .createDataset(Lists.transform(manifests, ManifestFile::path), Encoders.STRING()) + .toDF("manifest"); + + Dataset manifestEntryDF = + loadMetadataTable(table, ENTRIES) + .filter("status < 2") // select only live entries + .selectExpr( + "input_file_name() as manifest", + "snapshot_id", + "sequence_number", + "file_sequence_number", + "data_file"); + + Column joinCond = manifestDF.col("manifest").equalTo(manifestEntryDF.col("manifest")); + return manifestEntryDF + .join(manifestDF, joinCond, "left_semi") + .select("snapshot_id", "sequence_number", "file_sequence_number", "data_file"); + } + + private List writeManifestsForUnpartitionedTable( + Dataset manifestEntryDF, int numManifests) { + Broadcast
<Table>
tableBroadcast = sparkContext().broadcast(SerializableTable.copyOf(table)); + StructType sparkType = (StructType) manifestEntryDF.schema().apply("data_file").dataType(); + Types.StructType combinedPartitionType = Partitioning.partitionType(table); + + // we rely only on the target number of manifests for unpartitioned tables + // as we should not worry about having too much metadata per partition + long maxNumManifestEntries = Long.MAX_VALUE; + + return manifestEntryDF + .repartition(numManifests) + .mapPartitions( + toManifests( + tableBroadcast, + maxNumManifestEntries, + stagingLocation, + formatVersion, + combinedPartitionType, + spec, + sparkType), + manifestEncoder) + .collectAsList(); + } + + private List writeManifestsForPartitionedTable( + Dataset manifestEntryDF, int numManifests, int targetNumManifestEntries) { + + Broadcast
tableBroadcast = sparkContext().broadcast(SerializableTable.copyOf(table)); + StructType sparkType = (StructType) manifestEntryDF.schema().apply("data_file").dataType(); + Types.StructType combinedPartitionType = Partitioning.partitionType(table); + + // we allow the actual size of manifests to be 10% higher if the estimation is not precise + // enough + long maxNumManifestEntries = (long) (1.1 * targetNumManifestEntries); + + return withReusableDS( + manifestEntryDF, + df -> { + Column partitionColumn = df.col("data_file.partition"); + return df.repartitionByRange(numManifests, partitionColumn) + .sortWithinPartitions(partitionColumn) + .mapPartitions( + toManifests( + tableBroadcast, + maxNumManifestEntries, + stagingLocation, + formatVersion, + combinedPartitionType, + spec, + sparkType), + manifestEncoder) + .collectAsList(); + }); + } + + private U withReusableDS(Dataset ds, Function, U> func) { + Dataset reusableDS; + boolean useCaching = + PropertyUtil.propertyAsBoolean(options(), USE_CACHING, USE_CACHING_DEFAULT); + if (useCaching) { + reusableDS = ds.cache(); + } else { + int parallelism = SQLConf.get().numShufflePartitions(); + reusableDS = + ds.repartition(parallelism).map((MapFunction) value -> value, ds.exprEnc()); + } + + try { + return func.apply(reusableDS); + } finally { + if (useCaching) { + reusableDS.unpersist(false); + } + } + } + + private List findMatchingManifests() { + Snapshot currentSnapshot = table.currentSnapshot(); + + if (currentSnapshot == null) { + return ImmutableList.of(); + } + + return currentSnapshot.dataManifests(table.io()).stream() + .filter(manifest -> manifest.partitionSpecId() == spec.specId() && predicate.test(manifest)) + .collect(Collectors.toList()); + } + + private int targetNumManifests(long totalSizeBytes) { + return (int) ((totalSizeBytes + targetManifestSizeBytes - 1) / targetManifestSizeBytes); + } + + private int targetNumManifestEntries(int numEntries, int numManifests) { + return (numEntries + numManifests - 1) / numManifests; + } + + private boolean hasFileCounts(ManifestFile manifest) { + return manifest.addedFilesCount() != null + && manifest.existingFilesCount() != null + && manifest.deletedFilesCount() != null; + } + + private void replaceManifests( + Iterable deletedManifests, Iterable addedManifests) { + try { + boolean snapshotIdInheritanceEnabled = + PropertyUtil.propertyAsBoolean( + table.properties(), + TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, + TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED_DEFAULT); + + org.apache.iceberg.RewriteManifests rewriteManifests = table.rewriteManifests(); + deletedManifests.forEach(rewriteManifests::deleteManifest); + addedManifests.forEach(rewriteManifests::addManifest); + commit(rewriteManifests); + + if (!snapshotIdInheritanceEnabled) { + // delete new manifests as they were rewritten before the commit + deleteFiles(Iterables.transform(addedManifests, ManifestFile::path)); + } + } catch (CommitStateUnknownException commitStateUnknownException) { + // don't clean up added manifest files, because they may have been successfully committed. 
+ throw commitStateUnknownException; + } catch (Exception e) { + // delete all new manifests because the rewrite failed + deleteFiles(Iterables.transform(addedManifests, ManifestFile::path)); + throw e; + } + } + + private void deleteFiles(Iterable locations) { + Tasks.foreach(locations) + .executeWith(ThreadPools.getWorkerPool()) + .noRetry() + .suppressFailureWhenFinished() + .onFailure((location, exc) -> LOG.warn("Failed to delete: {}", location, exc)) + .run(location -> table.io().deleteFile(location)); + } + + private static ManifestFile writeManifest( + List rows, + int startIndex, + int endIndex, + Broadcast
tableBroadcast, + String location, + int format, + Types.StructType combinedPartitionType, + PartitionSpec spec, + StructType sparkType) + throws IOException { + + String manifestName = "optimized-m-" + UUID.randomUUID(); + Path manifestPath = new Path(location, manifestName); + OutputFile outputFile = + tableBroadcast + .value() + .io() + .newOutputFile(FileFormat.AVRO.addExtension(manifestPath.toString())); + + Types.StructType combinedFileType = DataFile.getType(combinedPartitionType); + Types.StructType manifestFileType = DataFile.getType(spec.partitionType()); + SparkDataFile wrapper = new SparkDataFile(combinedFileType, manifestFileType, sparkType); + + ManifestWriter writer = ManifestFiles.write(format, spec, outputFile, null); + + try { + for (int index = startIndex; index < endIndex; index++) { + Row row = rows.get(index); + long snapshotId = row.getLong(0); + long sequenceNumber = row.getLong(1); + Long fileSequenceNumber = row.isNullAt(2) ? null : row.getLong(2); + Row file = row.getStruct(3); + writer.existing(wrapper.wrap(file), snapshotId, sequenceNumber, fileSequenceNumber); + } + } finally { + writer.close(); + } + + return writer.toManifestFile(); + } + + private static MapPartitionsFunction toManifests( + Broadcast
tableBroadcast, + long maxNumManifestEntries, + String location, + int format, + Types.StructType combinedPartitionType, + PartitionSpec spec, + StructType sparkType) { + + return rows -> { + List rowsAsList = Lists.newArrayList(rows); + + if (rowsAsList.isEmpty()) { + return Collections.emptyIterator(); + } + + List manifests = Lists.newArrayList(); + if (rowsAsList.size() <= maxNumManifestEntries) { + manifests.add( + writeManifest( + rowsAsList, + 0, + rowsAsList.size(), + tableBroadcast, + location, + format, + combinedPartitionType, + spec, + sparkType)); + } else { + int midIndex = rowsAsList.size() / 2; + manifests.add( + writeManifest( + rowsAsList, + 0, + midIndex, + tableBroadcast, + location, + format, + combinedPartitionType, + spec, + sparkType)); + manifests.add( + writeManifest( + rowsAsList, + midIndex, + rowsAsList.size(), + tableBroadcast, + location, + format, + combinedPartitionType, + spec, + sparkType)); + } + + return manifests.iterator(); + }; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SetAccumulator.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SetAccumulator.java new file mode 100644 index 000000000000..745169fc1efd --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SetAccumulator.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.util.Collections; +import java.util.Set; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.spark.util.AccumulatorV2; + +public class SetAccumulator extends AccumulatorV2> { + + private final Set set = Collections.synchronizedSet(Sets.newHashSet()); + + @Override + public boolean isZero() { + return set.isEmpty(); + } + + @Override + public AccumulatorV2> copy() { + SetAccumulator newAccumulator = new SetAccumulator<>(); + newAccumulator.set.addAll(set); + return newAccumulator; + } + + @Override + public void reset() { + set.clear(); + } + + @Override + public void add(T v) { + set.add(v); + } + + @Override + public void merge(AccumulatorV2> other) { + set.addAll(other.value()); + } + + @Override + public Set value() { + return set; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SnapshotTableSparkAction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SnapshotTableSparkAction.java new file mode 100644 index 000000000000..8e59c13543f8 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SnapshotTableSparkAction.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
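As a usage-level companion to the manifest rewrite code above, a hedged sketch (not taken from this patch); the size threshold and staging path are illustrative.

import org.apache.iceberg.Table;
import org.apache.iceberg.actions.RewriteManifests;
import org.apache.iceberg.spark.actions.RewriteManifestsSparkAction;
import org.apache.iceberg.spark.actions.SparkActions;
import org.apache.spark.sql.SparkSession;

public class RewriteManifestsExample {
  public static RewriteManifests.Result rewriteSmallManifests(SparkSession spark, Table table) {
    return SparkActions.get(spark)
        .rewriteManifests(table)
        .rewriteIf(manifest -> manifest.length() < 8 * 1024 * 1024)  // only rewrite manifests smaller than 8 MB
        .stagingLocation("s3://bucket/tmp/manifest-staging")         // illustrative; defaults to the table metadata folder
        .option(RewriteManifestsSparkAction.USE_CACHING, "false")    // repartition the entries dataset instead of caching it
        .execute();
  }
}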
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.util.Map; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.ImmutableSnapshotTable; +import org.apache.iceberg.actions.SnapshotTable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.Spark3Util.CatalogAndIdentifier; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.source.StagedSparkTable; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.StagingTableCatalog; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.collection.JavaConverters; + +/** + * Creates a new Iceberg table based on a source Spark table. The new Iceberg table will have a + * different data and metadata directory allowing it to exist independently of the source table. 
+ */ +public class SnapshotTableSparkAction extends BaseTableCreationSparkAction + implements SnapshotTable { + + private static final Logger LOG = LoggerFactory.getLogger(SnapshotTableSparkAction.class); + + private StagingTableCatalog destCatalog; + private Identifier destTableIdent; + private String destTableLocation = null; + + SnapshotTableSparkAction( + SparkSession spark, CatalogPlugin sourceCatalog, Identifier sourceTableIdent) { + super(spark, sourceCatalog, sourceTableIdent); + } + + @Override + protected SnapshotTableSparkAction self() { + return this; + } + + @Override + protected StagingTableCatalog destCatalog() { + return destCatalog; + } + + @Override + protected Identifier destTableIdent() { + return destTableIdent; + } + + @Override + public SnapshotTableSparkAction as(String ident) { + String ctx = "snapshot destination"; + CatalogPlugin defaultCatalog = spark().sessionState().catalogManager().currentCatalog(); + CatalogAndIdentifier catalogAndIdent = + Spark3Util.catalogAndIdentifier(ctx, spark(), ident, defaultCatalog); + this.destCatalog = checkDestinationCatalog(catalogAndIdent.catalog()); + this.destTableIdent = catalogAndIdent.identifier(); + return this; + } + + @Override + public SnapshotTableSparkAction tableProperties(Map properties) { + setProperties(properties); + return this; + } + + @Override + public SnapshotTableSparkAction tableProperty(String property, String value) { + setProperty(property, value); + return this; + } + + @Override + public SnapshotTable.Result execute() { + String desc = String.format("Snapshotting table %s as %s", sourceTableIdent(), destTableIdent); + JobGroupInfo info = newJobGroupInfo("SNAPSHOT-TABLE", desc); + return withJobGroupInfo(info, this::doExecute); + } + + private SnapshotTable.Result doExecute() { + Preconditions.checkArgument( + destCatalog() != null && destTableIdent() != null, + "The destination catalog and identifier cannot be null. 
" + + "Make sure to configure the action with a valid destination table identifier via the `as` method."); + + LOG.info( + "Staging a new Iceberg table {} as a snapshot of {}", destTableIdent(), sourceTableIdent()); + StagedSparkTable stagedTable = stageDestTable(); + Table icebergTable = stagedTable.table(); + + // TODO: Check the dest table location does not overlap with the source table location + + boolean threw = true; + try { + LOG.info("Ensuring {} has a valid name mapping", destTableIdent()); + ensureNameMappingPresent(icebergTable); + + TableIdentifier v1TableIdent = v1SourceTable().identifier(); + String stagingLocation = getMetadataLocation(icebergTable); + LOG.info("Generating Iceberg metadata for {} in {}", destTableIdent(), stagingLocation); + SparkTableUtil.importSparkTable(spark(), v1TableIdent, icebergTable, stagingLocation); + + LOG.info("Committing staged changes to {}", destTableIdent()); + stagedTable.commitStagedChanges(); + threw = false; + } finally { + if (threw) { + LOG.error("Error when populating the staged table with metadata, aborting changes"); + + try { + stagedTable.abortStagedChanges(); + } catch (Exception abortException) { + LOG.error("Cannot abort staged changes", abortException); + } + } + } + + Snapshot snapshot = icebergTable.currentSnapshot(); + long importedDataFilesCount = + Long.parseLong(snapshot.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP)); + LOG.info( + "Successfully loaded Iceberg metadata for {} files to {}", + importedDataFilesCount, + destTableIdent()); + return ImmutableSnapshotTable.Result.builder() + .importedDataFilesCount(importedDataFilesCount) + .build(); + } + + @Override + protected Map destTableProps() { + Map properties = Maps.newHashMap(); + + // copy over relevant source table props + properties.putAll(JavaConverters.mapAsJavaMapConverter(v1SourceTable().properties()).asJava()); + EXCLUDED_PROPERTIES.forEach(properties::remove); + + // remove any possible location properties from origin properties + properties.remove(LOCATION); + properties.remove(TableProperties.WRITE_METADATA_LOCATION); + properties.remove(TableProperties.WRITE_FOLDER_STORAGE_LOCATION); + properties.remove(TableProperties.OBJECT_STORE_PATH); + properties.remove(TableProperties.WRITE_DATA_LOCATION); + + // set default and user-provided props + properties.put(TableCatalog.PROP_PROVIDER, "iceberg"); + properties.putAll(additionalProperties()); + + // make sure we mark this table as a snapshot table + properties.put(TableProperties.GC_ENABLED, "false"); + properties.put("snapshot", "true"); + + // set the destination table location if provided + if (destTableLocation != null) { + properties.put(LOCATION, destTableLocation); + } + + return properties; + } + + @Override + protected TableCatalog checkSourceCatalog(CatalogPlugin catalog) { + // currently the import code relies on being able to look up the table in the session catalog + Preconditions.checkArgument( + catalog.name().equalsIgnoreCase("spark_catalog"), + "Cannot snapshot a table that isn't in the session catalog (i.e. spark_catalog). 
" + + "Found source catalog: %s.", + catalog.name()); + + Preconditions.checkArgument( + catalog instanceof TableCatalog, + "Cannot snapshot as catalog %s of class %s in not a table catalog", + catalog.name(), + catalog.getClass().getName()); + + return (TableCatalog) catalog; + } + + @Override + public SnapshotTableSparkAction tableLocation(String location) { + Preconditions.checkArgument( + !sourceTableLocation().equals(location), + "The snapshot table location cannot be same as the source table location. " + + "This would mix snapshot table files with original table files."); + this.destTableLocation = location; + return this; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkActions.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkActions.java new file mode 100644 index 000000000000..8c886adf510e --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkActions.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.ActionsProvider; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.Spark3Util.CatalogAndIdentifier; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; + +/** + * An implementation of {@link ActionsProvider} for Spark. + * + *
<p>
This class is the primary API for interacting with actions in Spark that users should use to + * instantiate particular actions. + */ +public class SparkActions implements ActionsProvider { + + private final SparkSession spark; + + private SparkActions(SparkSession spark) { + this.spark = spark; + } + + public static SparkActions get(SparkSession spark) { + return new SparkActions(spark); + } + + public static SparkActions get() { + return new SparkActions(SparkSession.active()); + } + + @Override + public SnapshotTableSparkAction snapshotTable(String tableIdent) { + String ctx = "snapshot source"; + CatalogPlugin defaultCatalog = spark.sessionState().catalogManager().currentCatalog(); + CatalogAndIdentifier catalogAndIdent = + Spark3Util.catalogAndIdentifier(ctx, spark, tableIdent, defaultCatalog); + return new SnapshotTableSparkAction( + spark, catalogAndIdent.catalog(), catalogAndIdent.identifier()); + } + + @Override + public MigrateTableSparkAction migrateTable(String tableIdent) { + String ctx = "migrate target"; + CatalogPlugin defaultCatalog = spark.sessionState().catalogManager().currentCatalog(); + CatalogAndIdentifier catalogAndIdent = + Spark3Util.catalogAndIdentifier(ctx, spark, tableIdent, defaultCatalog); + return new MigrateTableSparkAction( + spark, catalogAndIdent.catalog(), catalogAndIdent.identifier()); + } + + @Override + public RewriteDataFilesSparkAction rewriteDataFiles(Table table) { + return new RewriteDataFilesSparkAction(spark, table); + } + + @Override + public DeleteOrphanFilesSparkAction deleteOrphanFiles(Table table) { + return new DeleteOrphanFilesSparkAction(spark, table); + } + + @Override + public RewriteManifestsSparkAction rewriteManifests(Table table) { + return new RewriteManifestsSparkAction(spark, table); + } + + @Override + public ExpireSnapshotsSparkAction expireSnapshots(Table table) { + return new ExpireSnapshotsSparkAction(spark, table); + } + + @Override + public DeleteReachableFilesSparkAction deleteReachableFiles(String metadataLocation) { + return new DeleteReachableFilesSparkAction(spark, metadataLocation); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackDataRewriter.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackDataRewriter.java new file mode 100644 index 000000000000..21e94ef9b4bf --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackDataRewriter.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
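Since SparkActions above is the entry point users are expected to go through, a hedged sketch of driving the snapshot action with it (not taken from this patch; identifiers and the location are illustrative).

import org.apache.iceberg.actions.SnapshotTable;
import org.apache.iceberg.spark.actions.SparkActions;
import org.apache.spark.sql.SparkSession;

public class SnapshotTableExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.active();

    SnapshotTable.Result result =
        SparkActions.get(spark)
            .snapshotTable("spark_catalog.db.source")          // source must live in the session catalog
            .as("iceberg_catalog.db.source_snapshot")          // destination catalog and identifier
            .tableLocation("s3://bucket/tmp/source_snapshot")  // optional; must differ from the source table location
            .execute();

    System.out.println("Imported data files: " + result.importedDataFilesCount());
  }
}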
+ */ +package org.apache.iceberg.spark.actions; + +import java.util.List; +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Table; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; + +class SparkBinPackDataRewriter extends SparkSizeBasedDataRewriter { + + private static final long SPLIT_OVERHEAD = 5 * 1024; + + SparkBinPackDataRewriter(SparkSession spark, Table table) { + super(spark, table); + } + + @Override + public String description() { + return "BIN-PACK"; + } + + @Override + protected void doRewrite(String groupId, List group) { + // read the files packing them into splits of the required size + Dataset scanDF = + spark() + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, groupId) + .option(SparkReadOptions.SPLIT_SIZE, splitSize(inputSize(group))) + .option(SparkReadOptions.FILE_OPEN_COST, "0") + .load(groupId); + + // write the packed data into new files where each split becomes a new file + scanDF + .write() + .format("iceberg") + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, groupId) + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES, writeMaxFileSize()) + .option(SparkWriteOptions.DISTRIBUTION_MODE, distributionMode(group).modeName()) + .mode("append") + .save(groupId); + } + + // invoke a shuffle if the original spec does not match the output spec + private DistributionMode distributionMode(List group) { + boolean requiresRepartition = !group.get(0).spec().equals(table().spec()); + return requiresRepartition ? DistributionMode.RANGE : DistributionMode.NONE; + } + + /** + * Returns the smallest of our max write file threshold and our estimated split size based on the + * number of output files we want to generate. Add an overhead onto the estimated split size to + * try to avoid small errors in size creating brand-new files. + */ + private long splitSize(long inputSize) { + long estimatedSplitSize = (inputSize / numOutputFiles(inputSize)) + SPLIT_OVERHEAD; + return Math.min(estimatedSplitSize, writeMaxFileSize()); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackStrategy.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackStrategy.java new file mode 100644 index 000000000000..07d3210ead66 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackStrategy.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
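To make the split-size estimate in SparkBinPackDataRewriter above concrete, a small self-contained sketch that mirrors the same arithmetic; numOutputFiles and writeMaxFileSize are plain parameters here only because the real values come from the size-based parent rewriter.

public class SplitSizeSketch {
  private static final long SPLIT_OVERHEAD = 5 * 1024;

  static long splitSize(long inputSize, long numOutputFiles, long writeMaxFileSize) {
    long estimated = (inputSize / numOutputFiles) + SPLIT_OVERHEAD;  // spread the input across the desired file count
    return Math.min(estimated, writeMaxFileSize);                    // never plan splits above the max write size
  }

  public static void main(String[] args) {
    long tenGiB = 10L * 1024 * 1024 * 1024;
    long halfGiB = 512L * 1024 * 1024;
    // 10 GiB across 20 output files: ~512 MiB + 5 KiB estimated, then capped at the 512 MiB max
    System.out.println(splitSize(tenGiB, 20, halfGiB));  // prints 536870912
  }
}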
+ */ +package org.apache.iceberg.spark.actions; + +import java.util.List; +import java.util.Set; +import java.util.UUID; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.BinPackStrategy; +import org.apache.iceberg.spark.FileRewriteCoordinator; +import org.apache.iceberg.spark.ScanTaskSetManager; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkTableCache; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; + +/** + * A Spark strategy to bin-pack data. + * + * @deprecated since 1.3.0, will be removed in 1.4.0; use {@link SparkBinPackDataRewriter} instead. + */ +@Deprecated +public class SparkBinPackStrategy extends BinPackStrategy { + private final Table table; + private final SparkSession spark; + private final SparkTableCache tableCache = SparkTableCache.get(); + private final ScanTaskSetManager manager = ScanTaskSetManager.get(); + private final FileRewriteCoordinator rewriteCoordinator = FileRewriteCoordinator.get(); + + public SparkBinPackStrategy(Table table, SparkSession spark) { + this.table = table; + this.spark = spark; + } + + @Override + public Table table() { + return table; + } + + @Override + public Set rewriteFiles(List filesToRewrite) { + String groupID = UUID.randomUUID().toString(); + try { + tableCache.add(groupID, table); + manager.stageTasks(table, groupID, filesToRewrite); + + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, groupID) + .option(SparkReadOptions.SPLIT_SIZE, splitSize(inputFileSize(filesToRewrite))) + .option(SparkReadOptions.FILE_OPEN_COST, "0") + .load(groupID); + + // All files within a file group are written with the same spec, so check the first + boolean requiresRepartition = !filesToRewrite.get(0).spec().equals(table.spec()); + + // Invoke a shuffle if the partition spec of the incoming partition does not match the table + String distributionMode = + requiresRepartition + ? DistributionMode.RANGE.modeName() + : DistributionMode.NONE.modeName(); + + // write the packed data into new files where each split becomes a new file + scanDF + .write() + .format("iceberg") + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, groupID) + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES, writeMaxFileSize()) + .option(SparkWriteOptions.DISTRIBUTION_MODE, distributionMode) + .mode("append") + .save(groupID); + + return rewriteCoordinator.fetchNewDataFiles(table, groupID); + } finally { + tableCache.remove(groupID); + manager.removeTasks(table, groupID); + rewriteCoordinator.clearRewrite(table, groupID); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkShufflingDataRewriter.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkShufflingDataRewriter.java new file mode 100644 index 000000000000..1add6383c618 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkShufflingDataRewriter.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.spark.SparkDistributionAndOrderingUtil; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.SortOrderUtil; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.utils.DistributionAndOrderingUtils$; +import org.apache.spark.sql.connector.distributions.Distributions; +import org.apache.spark.sql.connector.distributions.OrderedDistribution; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.internal.SQLConf; + +abstract class SparkShufflingDataRewriter extends SparkSizeBasedDataRewriter { + + /** + * The number of shuffle partitions and consequently the number of output files created by the + * Spark sort is based on the size of the input data files used in this file rewriter. Due to + * compression, the disk file sizes may not accurately represent the size of files in the output. + * This parameter lets the user adjust the file size used for estimating actual output data size. + * A factor greater than 1.0 would generate more files than we would expect based on the on-disk + * file size. A value less than 1.0 would create fewer files than we would expect based on the + * on-disk size. 
+ */ + public static final String COMPRESSION_FACTOR = "compression-factor"; + + public static final double COMPRESSION_FACTOR_DEFAULT = 1.0; + + private double compressionFactor; + + protected SparkShufflingDataRewriter(SparkSession spark, Table table) { + super(spark, table); + } + + protected abstract Dataset sortedDF(Dataset df, List group); + + @Override + public Set validOptions() { + return ImmutableSet.builder() + .addAll(super.validOptions()) + .add(COMPRESSION_FACTOR) + .build(); + } + + @Override + public void init(Map options) { + super.init(options); + this.compressionFactor = compressionFactor(options); + } + + @Override + public void doRewrite(String groupId, List group) { + // the number of shuffle partition controls the number of output files + spark().conf().set(SQLConf.SHUFFLE_PARTITIONS().key(), numShufflePartitions(group)); + + Dataset scanDF = + spark() + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, groupId) + .load(groupId); + + Dataset sortedDF = sortedDF(scanDF, group); + + sortedDF + .write() + .format("iceberg") + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, groupId) + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES, writeMaxFileSize()) + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false") + .mode("append") + .save(groupId); + } + + protected Dataset sort(Dataset df, org.apache.iceberg.SortOrder sortOrder) { + SortOrder[] ordering = SparkDistributionAndOrderingUtil.convert(sortOrder); + OrderedDistribution distribution = Distributions.ordered(ordering); + SQLConf conf = spark().sessionState().conf(); + LogicalPlan plan = df.logicalPlan(); + LogicalPlan sortPlan = + DistributionAndOrderingUtils$.MODULE$.prepareQuery(distribution, ordering, plan, conf); + return new Dataset<>(spark(), sortPlan, df.encoder()); + } + + protected org.apache.iceberg.SortOrder outputSortOrder( + List group, org.apache.iceberg.SortOrder sortOrder) { + boolean includePartitionColumns = !group.get(0).spec().equals(table().spec()); + if (includePartitionColumns) { + // build in the requirement for partition sorting into our sort order + // as the original spec for this group does not match the output spec + return SortOrderUtil.buildSortOrder(table(), sortOrder); + } else { + return sortOrder; + } + } + + private long numShufflePartitions(List group) { + long numOutputFiles = numOutputFiles((long) (inputSize(group) * compressionFactor)); + return Math.max(1, numOutputFiles); + } + + private double compressionFactor(Map options) { + double value = + PropertyUtil.propertyAsDouble(options, COMPRESSION_FACTOR, COMPRESSION_FACTOR_DEFAULT); + Preconditions.checkArgument( + value > 0, "'%s' is set to %s but must be > 0", COMPRESSION_FACTOR, value); + return value; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSizeBasedDataRewriter.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSizeBasedDataRewriter.java new file mode 100644 index 000000000000..d40cbbb871b3 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSizeBasedDataRewriter.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
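A hedged sketch of how the compression-factor option above is normally passed through the rewrite action (not taken from this patch); it assumes the table already defines a sort order, since the no-argument sort() requires one, and the factor value is illustrative.

import org.apache.iceberg.Table;
import org.apache.iceberg.actions.RewriteDataFiles;
import org.apache.iceberg.spark.actions.SparkActions;
import org.apache.spark.sql.SparkSession;

public class SortRewriteExample {
  public static RewriteDataFiles.Result sortRewrite(SparkSession spark, Table table) {
    return SparkActions.get(spark)
        .rewriteDataFiles(table)
        .sort()                               // or sort(SortOrder) / zOrder(columns...) to pick a shuffling rewriter
        .option("compression-factor", "2.0")  // plan roughly twice the shuffle partitions the on-disk size suggests
        .execute();
  }
}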
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.util.List; +import java.util.Set; +import java.util.UUID; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.SizeBasedDataRewriter; +import org.apache.iceberg.spark.FileRewriteCoordinator; +import org.apache.iceberg.spark.ScanTaskSetManager; +import org.apache.iceberg.spark.SparkTableCache; +import org.apache.spark.sql.SparkSession; + +abstract class SparkSizeBasedDataRewriter extends SizeBasedDataRewriter { + + private final SparkSession spark; + private final SparkTableCache tableCache = SparkTableCache.get(); + private final ScanTaskSetManager taskSetManager = ScanTaskSetManager.get(); + private final FileRewriteCoordinator coordinator = FileRewriteCoordinator.get(); + + SparkSizeBasedDataRewriter(SparkSession spark, Table table) { + super(table); + this.spark = spark; + } + + protected abstract void doRewrite(String groupId, List group); + + protected SparkSession spark() { + return spark; + } + + @Override + public Set rewrite(List group) { + String groupId = UUID.randomUUID().toString(); + try { + tableCache.add(groupId, table()); + taskSetManager.stageTasks(table(), groupId, group); + + doRewrite(groupId, group); + + return coordinator.fetchNewDataFiles(table(), groupId); + } finally { + tableCache.remove(groupId); + taskSetManager.removeTasks(table(), groupId); + coordinator.clearRewrite(table(), groupId); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSortDataRewriter.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSortDataRewriter.java new file mode 100644 index 000000000000..4615f3cebc92 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSortDataRewriter.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import java.util.List; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; + +class SparkSortDataRewriter extends SparkShufflingDataRewriter { + + private final SortOrder sortOrder; + + SparkSortDataRewriter(SparkSession spark, Table table) { + super(spark, table); + Preconditions.checkArgument( + table.sortOrder().isSorted(), + "Cannot sort data without a valid sort order, table '%s' is unsorted and no sort order is provided", + table.name()); + this.sortOrder = table.sortOrder(); + } + + SparkSortDataRewriter(SparkSession spark, Table table, SortOrder sortOrder) { + super(spark, table); + Preconditions.checkArgument( + sortOrder != null && sortOrder.isSorted(), + "Cannot sort data without a valid sort order, the provided sort order is null or empty"); + this.sortOrder = sortOrder; + } + + @Override + public String description() { + return "SORT"; + } + + @Override + protected Dataset sortedDF(Dataset df, List group) { + return sort(df, outputSortOrder(group, sortOrder)); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSortStrategy.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSortStrategy.java new file mode 100644 index 000000000000..21e29263c925 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkSortStrategy.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.RewriteStrategy; +import org.apache.iceberg.actions.SortStrategy; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.spark.FileRewriteCoordinator; +import org.apache.iceberg.spark.ScanTaskSetManager; +import org.apache.iceberg.spark.SparkDistributionAndOrderingUtil; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkTableCache; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.SortOrderUtil; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.utils.DistributionAndOrderingUtils$; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.distributions.Distributions; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.internal.SQLConf; + +/** + * A Spark strategy to sort data. + * + * @deprecated since 1.3.0, will be removed in 1.4.0; use {@link SparkSortDataRewriter} instead. + */ +@Deprecated +public class SparkSortStrategy extends SortStrategy { + + /** + * The number of shuffle partitions and consequently the number of output files created by the + * Spark Sort is based on the size of the input data files used in this rewrite operation. Due to + * compression, the disk file sizes may not accurately represent the size of files in the output. + * This parameter lets the user adjust the file size used for estimating actual output data size. + * A factor greater than 1.0 would generate more files than we would expect based on the on-disk + * file size. A value less than 1.0 would create fewer files than we would expect due to the + * on-disk size. 
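* A usage sketch, assuming the option is forwarded unchanged by the rewrite action and that
* {@code table} is an Iceberg {@link Table}: {@code
* SparkActions.get().rewriteDataFiles(table).sort().option("compression-factor", "2.0").execute()}.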
+ */ + public static final String COMPRESSION_FACTOR = "compression-factor"; + + private final Table table; + private final SparkSession spark; + private final SparkTableCache tableCache = SparkTableCache.get(); + private final ScanTaskSetManager manager = ScanTaskSetManager.get(); + private final FileRewriteCoordinator rewriteCoordinator = FileRewriteCoordinator.get(); + + private double sizeEstimateMultiple; + + public SparkSortStrategy(Table table, SparkSession spark) { + this.table = table; + this.spark = spark; + } + + @Override + public Table table() { + return table; + } + + @Override + public Set validOptions() { + return ImmutableSet.builder() + .addAll(super.validOptions()) + .add(COMPRESSION_FACTOR) + .build(); + } + + @Override + public RewriteStrategy options(Map options) { + sizeEstimateMultiple = PropertyUtil.propertyAsDouble(options, COMPRESSION_FACTOR, 1.0); + + Preconditions.checkArgument( + sizeEstimateMultiple > 0, + "Invalid compression factor: %s (not positive)", + sizeEstimateMultiple); + + return super.options(options); + } + + @Override + public Set rewriteFiles(List filesToRewrite) { + String groupID = UUID.randomUUID().toString(); + boolean requiresRepartition = !filesToRewrite.get(0).spec().equals(table.spec()); + + SortOrder[] ordering; + if (requiresRepartition) { + // Build in the requirement for Partition Sorting into our sort order + ordering = + SparkDistributionAndOrderingUtil.convert( + SortOrderUtil.buildSortOrder(table, sortOrder())); + } else { + ordering = SparkDistributionAndOrderingUtil.convert(sortOrder()); + } + + Distribution distribution = Distributions.ordered(ordering); + + try { + tableCache.add(groupID, table); + manager.stageTasks(table, groupID, filesToRewrite); + + // Reset Shuffle Partitions for our sort + long numOutputFiles = + numOutputFiles((long) (inputFileSize(filesToRewrite) * sizeEstimateMultiple)); + spark.conf().set(SQLConf.SHUFFLE_PARTITIONS().key(), Math.max(1, numOutputFiles)); + + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, groupID) + .load(groupID); + + // write the packed data into new files where each split becomes a new file + SQLConf sqlConf = spark.sessionState().conf(); + LogicalPlan sortPlan = sortPlan(distribution, ordering, scanDF.logicalPlan(), sqlConf); + Dataset sortedDf = new Dataset<>(spark, sortPlan, scanDF.encoder()); + + sortedDf + .write() + .format("iceberg") + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, groupID) + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES, writeMaxFileSize()) + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false") + .mode("append") // This will only write files without modifying the table, see + // SparkWrite.RewriteFiles + .save(groupID); + + return rewriteCoordinator.fetchNewDataFiles(table, groupID); + } finally { + tableCache.remove(groupID); + manager.removeTasks(table, groupID); + rewriteCoordinator.clearRewrite(table, groupID); + } + } + + protected SparkSession spark() { + return this.spark; + } + + protected LogicalPlan sortPlan( + Distribution distribution, SortOrder[] ordering, LogicalPlan plan, SQLConf conf) { + return DistributionAndOrderingUtils$.MODULE$.prepareQuery(distribution, ordering, plan, conf); + } + + protected double sizeEstimateMultiple() { + return sizeEstimateMultiple; + } + + protected SparkTableCache tableCache() { + return tableCache; + } + + protected ScanTaskSetManager manager() { + return manager; + } + + protected FileRewriteCoordinator 
rewriteCoordinator() { + return rewriteCoordinator; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderDataRewriter.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderDataRewriter.java new file mode 100644 index 000000000000..68db76d37fcb --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderDataRewriter.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import static org.apache.spark.sql.functions.array; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.NullOrder; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortDirection; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.ZOrderByteUtils; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class SparkZOrderDataRewriter extends SparkShufflingDataRewriter { + + private static final Logger LOG = LoggerFactory.getLogger(SparkZOrderDataRewriter.class); + + private static final String Z_COLUMN = "ICEZVALUE"; + private static final Schema Z_SCHEMA = + new Schema(Types.NestedField.required(0, Z_COLUMN, Types.BinaryType.get())); + private static final SortOrder Z_SORT_ORDER = + SortOrder.builderFor(Z_SCHEMA) + .sortBy(Z_COLUMN, SortDirection.ASC, NullOrder.NULLS_LAST) + .build(); + + /** + * Controls the amount of bytes interleaved in the ZOrder algorithm. Default is all bytes being + * interleaved. + */ + public static final String MAX_OUTPUT_SIZE = "max-output-size"; + + public static final int MAX_OUTPUT_SIZE_DEFAULT = Integer.MAX_VALUE; + + /** + * Controls the number of bytes considered from an input column of a type with variable length + * (String, Binary). + * + *
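* For instance, a contribution of 8 means only the first 8 bytes of each string or binary value
* are considered when building the interleaved z-value; longer values are truncated and shorter
* ones are padded.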

Default is to use the same size as primitives {@link ZOrderByteUtils#PRIMITIVE_BUFFER_SIZE}. + */ + public static final String VAR_LENGTH_CONTRIBUTION = "var-length-contribution"; + + public static final int VAR_LENGTH_CONTRIBUTION_DEFAULT = ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE; + + private final List zOrderColNames; + private int maxOutputSize; + private int varLengthContribution; + + SparkZOrderDataRewriter(SparkSession spark, Table table, List zOrderColNames) { + super(spark, table); + this.zOrderColNames = validZOrderColNames(spark, table, zOrderColNames); + } + + @Override + public String description() { + return "Z-ORDER"; + } + + @Override + public Set validOptions() { + return ImmutableSet.builder() + .addAll(super.validOptions()) + .add(MAX_OUTPUT_SIZE) + .add(VAR_LENGTH_CONTRIBUTION) + .build(); + } + + @Override + public void init(Map options) { + super.init(options); + this.maxOutputSize = maxOutputSize(options); + this.varLengthContribution = varLengthContribution(options); + } + + @Override + protected Dataset sortedDF(Dataset df, List group) { + Dataset zValueDF = df.withColumn(Z_COLUMN, zValue(df)); + Dataset sortedDF = sort(zValueDF, outputSortOrder(group, Z_SORT_ORDER)); + return sortedDF.drop(Z_COLUMN); + } + + private Column zValue(Dataset df) { + SparkZOrderUDF zOrderUDF = + new SparkZOrderUDF(zOrderColNames.size(), varLengthContribution, maxOutputSize); + + Column[] zOrderCols = + zOrderColNames.stream() + .map(df.schema()::apply) + .map(col -> zOrderUDF.sortedLexicographically(df.col(col.name()), col.dataType())) + .toArray(Column[]::new); + + return zOrderUDF.interleaveBytes(array(zOrderCols)); + } + + private int varLengthContribution(Map options) { + int value = + PropertyUtil.propertyAsInt( + options, VAR_LENGTH_CONTRIBUTION, VAR_LENGTH_CONTRIBUTION_DEFAULT); + Preconditions.checkArgument( + value > 0, + "Cannot use less than 1 byte for variable length types with ZOrder, '%s' was set to %s", + VAR_LENGTH_CONTRIBUTION, + value); + return value; + } + + private int maxOutputSize(Map options) { + int value = PropertyUtil.propertyAsInt(options, MAX_OUTPUT_SIZE, MAX_OUTPUT_SIZE_DEFAULT); + Preconditions.checkArgument( + value > 0, + "Cannot have the interleaved ZOrder value use less than 1 byte, '%s' was set to %s", + MAX_OUTPUT_SIZE, + value); + return value; + } + + private List validZOrderColNames( + SparkSession spark, Table table, List inputZOrderColNames) { + + Preconditions.checkArgument( + inputZOrderColNames != null && !inputZOrderColNames.isEmpty(), + "Cannot ZOrder when no columns are specified"); + + Schema schema = table.schema(); + Set identityPartitionFieldIds = table.spec().identitySourceIds(); + boolean caseSensitive = SparkUtil.caseSensitive(spark); + + List validZOrderColNames = Lists.newArrayList(); + + for (String colName : inputZOrderColNames) { + Types.NestedField field = + caseSensitive ? 
schema.findField(colName) : schema.caseInsensitiveFindField(colName); + Preconditions.checkArgument( + field != null, + "Cannot find column '%s' in table schema (case sensitive = %s): %s", + colName, + caseSensitive, + schema.asStruct()); + + if (identityPartitionFieldIds.contains(field.fieldId())) { + LOG.warn("Ignoring '{}' as such values are constant within a partition", colName); + } else { + validZOrderColNames.add(colName); + } + } + + Preconditions.checkArgument( + validZOrderColNames.size() > 0, + "Cannot ZOrder, all columns provided were identity partition columns and cannot be used"); + + return validZOrderColNames; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderStrategy.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderStrategy.java new file mode 100644 index 000000000000..26d2b4837b4b --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderStrategy.java @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.NullOrder; +import org.apache.iceberg.PartitionField; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortDirection; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.RewriteStrategy; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.spark.SparkDistributionAndOrderingUtil; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.SortOrderUtil; +import org.apache.iceberg.util.ZOrderByteUtils; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.distributions.Distributions; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.functions; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.StructField; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A Spark strategy to zOrder data. + * + * @deprecated since 1.3.0, will be removed in 1.4.0; use {@link SparkZOrderDataRewriter} instead. + */ +@Deprecated +public class SparkZOrderStrategy extends SparkSortStrategy { + private static final Logger LOG = LoggerFactory.getLogger(SparkZOrderStrategy.class); + + private static final String Z_COLUMN = "ICEZVALUE"; + private static final Schema Z_SCHEMA = + new Schema(NestedField.required(0, Z_COLUMN, Types.BinaryType.get())); + private static final org.apache.iceberg.SortOrder Z_SORT_ORDER = + org.apache.iceberg.SortOrder.builderFor(Z_SCHEMA) + .sortBy(Z_COLUMN, SortDirection.ASC, NullOrder.NULLS_LAST) + .build(); + + /** + * Controls the amount of bytes interleaved in the ZOrder Algorithm. Default is all bytes being + * interleaved. + */ + private static final String MAX_OUTPUT_SIZE_KEY = "max-output-size"; + + private static final int DEFAULT_MAX_OUTPUT_SIZE = Integer.MAX_VALUE; + + /** + * Controls the number of bytes considered from an input column of a type with variable length + * (String, Binary). 
Default is to use the same size as primitives {@link + * ZOrderByteUtils#PRIMITIVE_BUFFER_SIZE} + */ + private static final String VAR_LENGTH_CONTRIBUTION_KEY = "var-length-contribution"; + + private static final int DEFAULT_VAR_LENGTH_CONTRIBUTION = ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE; + + private final List zOrderColNames; + + private int maxOutputSize; + private int varLengthContribution; + + @Override + public Set validOptions() { + return ImmutableSet.builder() + .addAll(super.validOptions()) + .add(VAR_LENGTH_CONTRIBUTION_KEY) + .add(MAX_OUTPUT_SIZE_KEY) + .build(); + } + + @Override + public RewriteStrategy options(Map options) { + super.options(options); + + varLengthContribution = + PropertyUtil.propertyAsInt( + options, VAR_LENGTH_CONTRIBUTION_KEY, DEFAULT_VAR_LENGTH_CONTRIBUTION); + Preconditions.checkArgument( + varLengthContribution > 0, + "Cannot use less than 1 byte for variable length types with zOrder, %s was set to %s", + VAR_LENGTH_CONTRIBUTION_KEY, + varLengthContribution); + + maxOutputSize = + PropertyUtil.propertyAsInt(options, MAX_OUTPUT_SIZE_KEY, DEFAULT_MAX_OUTPUT_SIZE); + Preconditions.checkArgument( + maxOutputSize > 0, + "Cannot have the interleaved ZOrder value use less than 1 byte, %s was set to %s", + MAX_OUTPUT_SIZE_KEY, + maxOutputSize); + + return this; + } + + public SparkZOrderStrategy(Table table, SparkSession spark, List zOrderColNames) { + super(table, spark); + + Preconditions.checkArgument( + zOrderColNames != null && !zOrderColNames.isEmpty(), + "Cannot ZOrder when no columns are specified"); + + Stream identityPartitionColumns = + table.spec().fields().stream() + .filter(f -> f.transform().isIdentity()) + .map(PartitionField::name); + List partZOrderCols = + identityPartitionColumns.filter(zOrderColNames::contains).collect(Collectors.toList()); + + if (!partZOrderCols.isEmpty()) { + LOG.warn( + "Cannot ZOrder on an Identity partition column as these values are constant within a partition " + + "and will be removed from the ZOrder expression: {}", + partZOrderCols); + zOrderColNames.removeAll(partZOrderCols); + Preconditions.checkArgument( + !zOrderColNames.isEmpty(), + "Cannot perform ZOrdering, all columns provided were identity partition columns and cannot be used."); + } + + validateColumnsExistence(table, spark, zOrderColNames); + + this.zOrderColNames = zOrderColNames; + } + + private void validateColumnsExistence(Table table, SparkSession spark, List colNames) { + boolean caseSensitive = SparkUtil.caseSensitive(spark); + Schema schema = table.schema(); + colNames.forEach( + col -> { + NestedField nestedField = + caseSensitive ? 
schema.findField(col) : schema.caseInsensitiveFindField(col); + if (nestedField == null) { + throw new IllegalArgumentException( + String.format( + "Cannot find column '%s' in table schema: %s", col, schema.asStruct())); + } + }); + } + + @Override + public String name() { + return "Z-ORDER"; + } + + @Override + protected void validateOptions() { + // Ignore SortStrategy validation + return; + } + + @Override + public Set rewriteFiles(List filesToRewrite) { + SparkZOrderUDF zOrderUDF = + new SparkZOrderUDF(zOrderColNames.size(), varLengthContribution, maxOutputSize); + + String groupID = UUID.randomUUID().toString(); + boolean requiresRepartition = !filesToRewrite.get(0).spec().equals(table().spec()); + + SortOrder[] ordering; + if (requiresRepartition) { + ordering = + SparkDistributionAndOrderingUtil.convert( + SortOrderUtil.buildSortOrder(table(), sortOrder())); + } else { + ordering = SparkDistributionAndOrderingUtil.convert(sortOrder()); + } + + Distribution distribution = Distributions.ordered(ordering); + + try { + tableCache().add(groupID, table()); + manager().stageTasks(table(), groupID, filesToRewrite); + + // spark session from parent + SparkSession spark = spark(); + // Reset Shuffle Partitions for our sort + long numOutputFiles = + numOutputFiles((long) (inputFileSize(filesToRewrite) * sizeEstimateMultiple())); + spark.conf().set(SQLConf.SHUFFLE_PARTITIONS().key(), Math.max(1, numOutputFiles)); + + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, groupID) + .load(groupID); + + Column[] originalColumns = + Arrays.stream(scanDF.schema().names()).map(n -> functions.col(n)).toArray(Column[]::new); + + List zOrderColumns = + zOrderColNames.stream().map(scanDF.schema()::apply).collect(Collectors.toList()); + + Column zvalueArray = + functions.array( + zOrderColumns.stream() + .map( + colStruct -> + zOrderUDF.sortedLexicographically( + functions.col(colStruct.name()), colStruct.dataType())) + .toArray(Column[]::new)); + + Dataset zvalueDF = scanDF.withColumn(Z_COLUMN, zOrderUDF.interleaveBytes(zvalueArray)); + + SQLConf sqlConf = spark.sessionState().conf(); + LogicalPlan sortPlan = sortPlan(distribution, ordering, zvalueDF.logicalPlan(), sqlConf); + Dataset sortedDf = new Dataset<>(spark, sortPlan, zvalueDF.encoder()); + sortedDf + .select(originalColumns) + .write() + .format("iceberg") + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, groupID) + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES, writeMaxFileSize()) + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false") + .mode("append") + .save(groupID); + + return rewriteCoordinator().fetchNewDataFiles(table(), groupID); + } finally { + tableCache().remove(groupID); + manager().removeTasks(table(), groupID); + rewriteCoordinator().clearRewrite(table(), groupID); + } + } + + @Override + protected org.apache.iceberg.SortOrder sortOrder() { + return Z_SORT_ORDER; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderUDF.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderUDF.java new file mode 100644 index 000000000000..db359fdd62fc --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/SparkZOrderUDF.java @@ -0,0 +1,325 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.nio.charset.CharsetEncoder; +import java.nio.charset.StandardCharsets; +import org.apache.iceberg.util.ZOrderByteUtils; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.expressions.UserDefinedFunction; +import org.apache.spark.sql.functions; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.TimestampType; +import scala.collection.JavaConverters; +import scala.collection.Seq; + +class SparkZOrderUDF implements Serializable { + private static final byte[] PRIMITIVE_EMPTY = new byte[ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE]; + + /** + * Every Spark task runs iteratively on a rows in a single thread so ThreadLocal should protect + * from concurrent access to any of these structures. 
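* The buffers below are transient and re-created in readObject(), so every executor that
* deserializes this UDF gets its own freshly initialized thread-local state.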
+ */ + private transient ThreadLocal outputBuffer; + + private transient ThreadLocal inputHolder; + private transient ThreadLocal inputBuffers; + private transient ThreadLocal encoder; + + private final int numCols; + + private int inputCol = 0; + private int totalOutputBytes = 0; + private final int varTypeSize; + private final int maxOutputSize; + + SparkZOrderUDF(int numCols, int varTypeSize, int maxOutputSize) { + this.numCols = numCols; + this.varTypeSize = varTypeSize; + this.maxOutputSize = maxOutputSize; + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + in.defaultReadObject(); + inputBuffers = ThreadLocal.withInitial(() -> new ByteBuffer[numCols]); + inputHolder = ThreadLocal.withInitial(() -> new byte[numCols][]); + outputBuffer = ThreadLocal.withInitial(() -> ByteBuffer.allocate(totalOutputBytes)); + encoder = ThreadLocal.withInitial(() -> StandardCharsets.UTF_8.newEncoder()); + } + + private ByteBuffer inputBuffer(int position, int size) { + ByteBuffer buffer = inputBuffers.get()[position]; + if (buffer == null) { + buffer = ByteBuffer.allocate(size); + inputBuffers.get()[position] = buffer; + } + return buffer; + } + + byte[] interleaveBits(Seq scalaBinary) { + byte[][] columnsBinary = JavaConverters.seqAsJavaList(scalaBinary).toArray(inputHolder.get()); + return ZOrderByteUtils.interleaveBits(columnsBinary, totalOutputBytes, outputBuffer.get()); + } + + private UserDefinedFunction tinyToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = + functions + .udf( + (Byte value) -> { + if (value == null) { + return PRIMITIVE_EMPTY; + } + return ZOrderByteUtils.tinyintToOrderedBytes( + value, inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE)) + .array(); + }, + DataTypes.BinaryType) + .withName("TINY_ORDERED_BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + + return udf; + } + + private UserDefinedFunction shortToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = + functions + .udf( + (Short value) -> { + if (value == null) { + return PRIMITIVE_EMPTY; + } + return ZOrderByteUtils.shortToOrderedBytes( + value, inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE)) + .array(); + }, + DataTypes.BinaryType) + .withName("SHORT_ORDERED_BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + + return udf; + } + + private UserDefinedFunction intToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = + functions + .udf( + (Integer value) -> { + if (value == null) { + return PRIMITIVE_EMPTY; + } + return ZOrderByteUtils.intToOrderedBytes( + value, inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE)) + .array(); + }, + DataTypes.BinaryType) + .withName("INT_ORDERED_BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + + return udf; + } + + private UserDefinedFunction longToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = + functions + .udf( + (Long value) -> { + if (value == null) { + return PRIMITIVE_EMPTY; + } + return ZOrderByteUtils.longToOrderedBytes( + value, inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE)) + .array(); + }, + DataTypes.BinaryType) + .withName("LONG_ORDERED_BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + + return udf; + } + + private UserDefinedFunction floatToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf 
= + functions + .udf( + (Float value) -> { + if (value == null) { + return PRIMITIVE_EMPTY; + } + return ZOrderByteUtils.floatToOrderedBytes( + value, inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE)) + .array(); + }, + DataTypes.BinaryType) + .withName("FLOAT_ORDERED_BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + + return udf; + } + + private UserDefinedFunction doubleToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = + functions + .udf( + (Double value) -> { + if (value == null) { + return PRIMITIVE_EMPTY; + } + return ZOrderByteUtils.doubleToOrderedBytes( + value, inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE)) + .array(); + }, + DataTypes.BinaryType) + .withName("DOUBLE_ORDERED_BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + + return udf; + } + + private UserDefinedFunction booleanToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = + functions + .udf( + (Boolean value) -> { + ByteBuffer buffer = inputBuffer(position, ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + buffer.put(0, (byte) (value ? -127 : 0)); + return buffer.array(); + }, + DataTypes.BinaryType) + .withName("BOOLEAN-LEXICAL-BYTES"); + + this.inputCol++; + increaseOutputSize(ZOrderByteUtils.PRIMITIVE_BUFFER_SIZE); + return udf; + } + + private UserDefinedFunction stringToOrderedBytesUDF() { + int position = inputCol; + UserDefinedFunction udf = + functions + .udf( + (String value) -> + ZOrderByteUtils.stringToOrderedBytes( + value, varTypeSize, inputBuffer(position, varTypeSize), encoder.get()) + .array(), + DataTypes.BinaryType) + .withName("STRING-LEXICAL-BYTES"); + + this.inputCol++; + increaseOutputSize(varTypeSize); + + return udf; + } + + private UserDefinedFunction bytesTruncateUDF() { + int position = inputCol; + UserDefinedFunction udf = + functions + .udf( + (byte[] value) -> + ZOrderByteUtils.byteTruncateOrFill( + value, varTypeSize, inputBuffer(position, varTypeSize)) + .array(), + DataTypes.BinaryType) + .withName("BYTE-TRUNCATE"); + + this.inputCol++; + increaseOutputSize(varTypeSize); + + return udf; + } + + private final UserDefinedFunction interleaveUDF = + functions + .udf((Seq arrayBinary) -> interleaveBits(arrayBinary), DataTypes.BinaryType) + .withName("INTERLEAVE_BYTES"); + + Column interleaveBytes(Column arrayBinary) { + return interleaveUDF.apply(arrayBinary); + } + + @SuppressWarnings("checkstyle:CyclomaticComplexity") + Column sortedLexicographically(Column column, DataType type) { + if (type instanceof ByteType) { + return tinyToOrderedBytesUDF().apply(column); + } else if (type instanceof ShortType) { + return shortToOrderedBytesUDF().apply(column); + } else if (type instanceof IntegerType) { + return intToOrderedBytesUDF().apply(column); + } else if (type instanceof LongType) { + return longToOrderedBytesUDF().apply(column); + } else if (type instanceof FloatType) { + return floatToOrderedBytesUDF().apply(column); + } else if (type instanceof DoubleType) { + return doubleToOrderedBytesUDF().apply(column); + } else if (type instanceof StringType) { + return stringToOrderedBytesUDF().apply(column); + } else if (type instanceof BinaryType) { + return bytesTruncateUDF().apply(column); + } else if (type instanceof BooleanType) { + return booleanToOrderedBytesUDF().apply(column); + } else if (type instanceof TimestampType) { + return longToOrderedBytesUDF().apply(column.cast(DataTypes.LongType)); + } else if (type instanceof DateType) { + 
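// dates are stored as days from the epoch, so casting to long preserves their ordering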
return longToOrderedBytesUDF().apply(column.cast(DataTypes.LongType)); + } else { + throw new IllegalArgumentException( + String.format( + "Cannot use column %s of type %s in ZOrdering, the type is unsupported", + column, type)); + } + } + + private void increaseOutputSize(int bytes) { + totalOutputBytes = Math.min(totalOutputBytes + bytes, maxOutputSize); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/AvroWithSparkSchemaVisitor.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/AvroWithSparkSchemaVisitor.java new file mode 100644 index 000000000000..74454fc1e466 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/AvroWithSparkSchemaVisitor.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import org.apache.iceberg.avro.AvroWithPartnerByStructureVisitor; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.util.Pair; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public abstract class AvroWithSparkSchemaVisitor + extends AvroWithPartnerByStructureVisitor { + + @Override + protected boolean isStringType(DataType dataType) { + return dataType instanceof StringType; + } + + @Override + protected boolean isMapType(DataType dataType) { + return dataType instanceof MapType; + } + + @Override + protected DataType arrayElementType(DataType arrayType) { + Preconditions.checkArgument( + arrayType instanceof ArrayType, "Invalid array: %s is not an array", arrayType); + return ((ArrayType) arrayType).elementType(); + } + + @Override + protected DataType mapKeyType(DataType mapType) { + Preconditions.checkArgument(isMapType(mapType), "Invalid map: %s is not a map", mapType); + return ((MapType) mapType).keyType(); + } + + @Override + protected DataType mapValueType(DataType mapType) { + Preconditions.checkArgument(isMapType(mapType), "Invalid map: %s is not a map", mapType); + return ((MapType) mapType).valueType(); + } + + @Override + protected Pair fieldNameAndType(DataType structType, int pos) { + Preconditions.checkArgument( + structType instanceof StructType, "Invalid struct: %s is not a struct", structType); + StructField field = ((StructType) structType).apply(pos); + return Pair.of(field.name(), field.dataType()); + } + + @Override + protected DataType nullType() { + return DataTypes.NullType; + } +} diff --git 
a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/ParquetWithSparkSchemaVisitor.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/ParquetWithSparkSchemaVisitor.java new file mode 100644 index 000000000000..d74a76f94e87 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/ParquetWithSparkSchemaVisitor.java @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import java.util.Deque; +import java.util.List; +import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.OriginalType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.parquet.schema.Type.Repetition; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * Visitor for traversing a Parquet type with a companion Spark type. 
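* Subclasses override message/struct/list/map/primitive to produce a value for each pair of
* Spark and Parquet nodes; the visit methods assemble those results bottom-up while tracking the
* current field path.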
+ * + * @param the Java class returned by the visitor + */ +public class ParquetWithSparkSchemaVisitor { + private final Deque fieldNames = Lists.newLinkedList(); + + public static T visit(DataType sType, Type type, ParquetWithSparkSchemaVisitor visitor) { + Preconditions.checkArgument(sType != null, "Invalid DataType: null"); + if (type instanceof MessageType) { + Preconditions.checkArgument( + sType instanceof StructType, "Invalid struct: %s is not a struct", sType); + StructType struct = (StructType) sType; + return visitor.message( + struct, (MessageType) type, visitFields(struct, type.asGroupType(), visitor)); + + } else if (type.isPrimitive()) { + return visitor.primitive(sType, type.asPrimitiveType()); + + } else { + // if not a primitive, the typeId must be a group + GroupType group = type.asGroupType(); + OriginalType annotation = group.getOriginalType(); + if (annotation != null) { + switch (annotation) { + case LIST: + Preconditions.checkArgument( + !group.isRepetition(Repetition.REPEATED), + "Invalid list: top-level group is repeated: %s", + group); + Preconditions.checkArgument( + group.getFieldCount() == 1, + "Invalid list: does not contain single repeated field: %s", + group); + + GroupType repeatedElement = group.getFields().get(0).asGroupType(); + Preconditions.checkArgument( + repeatedElement.isRepetition(Repetition.REPEATED), + "Invalid list: inner group is not repeated"); + Preconditions.checkArgument( + repeatedElement.getFieldCount() <= 1, + "Invalid list: repeated group is not a single field: %s", + group); + + Preconditions.checkArgument( + sType instanceof ArrayType, "Invalid list: %s is not an array", sType); + ArrayType array = (ArrayType) sType; + StructField element = + new StructField( + "element", array.elementType(), array.containsNull(), Metadata.empty()); + + visitor.fieldNames.push(repeatedElement.getName()); + try { + T elementResult = null; + if (repeatedElement.getFieldCount() > 0) { + elementResult = visitField(element, repeatedElement.getType(0), visitor); + } + + return visitor.list(array, group, elementResult); + + } finally { + visitor.fieldNames.pop(); + } + + case MAP: + Preconditions.checkArgument( + !group.isRepetition(Repetition.REPEATED), + "Invalid map: top-level group is repeated: %s", + group); + Preconditions.checkArgument( + group.getFieldCount() == 1, + "Invalid map: does not contain single repeated field: %s", + group); + + GroupType repeatedKeyValue = group.getType(0).asGroupType(); + Preconditions.checkArgument( + repeatedKeyValue.isRepetition(Repetition.REPEATED), + "Invalid map: inner group is not repeated"); + Preconditions.checkArgument( + repeatedKeyValue.getFieldCount() <= 2, + "Invalid map: repeated group does not have 2 fields"); + + Preconditions.checkArgument( + sType instanceof MapType, "Invalid map: %s is not a map", sType); + MapType map = (MapType) sType; + StructField keyField = new StructField("key", map.keyType(), false, Metadata.empty()); + StructField valueField = + new StructField( + "value", map.valueType(), map.valueContainsNull(), Metadata.empty()); + + visitor.fieldNames.push(repeatedKeyValue.getName()); + try { + T keyResult = null; + T valueResult = null; + switch (repeatedKeyValue.getFieldCount()) { + case 2: + // if there are 2 fields, both key and value are projected + keyResult = visitField(keyField, repeatedKeyValue.getType(0), visitor); + valueResult = visitField(valueField, repeatedKeyValue.getType(1), visitor); + break; + case 1: + // if there is just one, use the name to determine what it is + Type 
keyOrValue = repeatedKeyValue.getType(0); + if (keyOrValue.getName().equalsIgnoreCase("key")) { + keyResult = visitField(keyField, keyOrValue, visitor); + // value result remains null + } else { + valueResult = visitField(valueField, keyOrValue, visitor); + // key result remains null + } + break; + default: + // both results will remain null + } + + return visitor.map(map, group, keyResult, valueResult); + + } finally { + visitor.fieldNames.pop(); + } + + default: + } + } + + Preconditions.checkArgument( + sType instanceof StructType, "Invalid struct: %s is not a struct", sType); + StructType struct = (StructType) sType; + return visitor.struct(struct, group, visitFields(struct, group, visitor)); + } + } + + private static T visitField( + StructField sField, Type field, ParquetWithSparkSchemaVisitor visitor) { + visitor.fieldNames.push(field.getName()); + try { + return visit(sField.dataType(), field, visitor); + } finally { + visitor.fieldNames.pop(); + } + } + + private static List visitFields( + StructType struct, GroupType group, ParquetWithSparkSchemaVisitor visitor) { + StructField[] sFields = struct.fields(); + Preconditions.checkArgument( + sFields.length == group.getFieldCount(), "Structs do not match: %s and %s", struct, group); + List results = Lists.newArrayListWithExpectedSize(group.getFieldCount()); + for (int i = 0; i < sFields.length; i += 1) { + Type field = group.getFields().get(i); + StructField sField = sFields[i]; + Preconditions.checkArgument( + field.getName().equals(AvroSchemaUtil.makeCompatibleName(sField.name())), + "Structs do not match: field %s != %s", + field.getName(), + sField.name()); + results.add(visitField(sField, field, visitor)); + } + + return results; + } + + public T message(StructType sStruct, MessageType message, List fields) { + return null; + } + + public T struct(StructType sStruct, GroupType struct, List fields) { + return null; + } + + public T list(ArrayType sArray, GroupType array, T element) { + return null; + } + + public T map(MapType sMap, GroupType map, T key, T value) { + return null; + } + + public T primitive(DataType sPrimitive, PrimitiveType primitive) { + return null; + } + + protected String[] currentPath() { + return Lists.newArrayList(fieldNames.descendingIterator()).toArray(new String[0]); + } + + protected String[] path(String name) { + List list = Lists.newArrayList(fieldNames.descendingIterator()); + list.add(name); + return list.toArray(new String[0]); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java new file mode 100644 index 000000000000..4622d2928ac4 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; +import org.apache.avro.LogicalType; +import org.apache.avro.LogicalTypes; +import org.apache.avro.Schema; +import org.apache.avro.io.DatumReader; +import org.apache.avro.io.Decoder; +import org.apache.iceberg.avro.AvroSchemaWithTypeVisitor; +import org.apache.iceberg.avro.SupportsRowPosition; +import org.apache.iceberg.avro.ValueReader; +import org.apache.iceberg.avro.ValueReaders; +import org.apache.iceberg.data.avro.DecoderResolver; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; + +public class SparkAvroReader implements DatumReader, SupportsRowPosition { + + private final Schema readSchema; + private final ValueReader reader; + private Schema fileSchema = null; + + public SparkAvroReader(org.apache.iceberg.Schema expectedSchema, Schema readSchema) { + this(expectedSchema, readSchema, ImmutableMap.of()); + } + + @SuppressWarnings("unchecked") + public SparkAvroReader( + org.apache.iceberg.Schema expectedSchema, Schema readSchema, Map constants) { + this.readSchema = readSchema; + this.reader = + (ValueReader) + AvroSchemaWithTypeVisitor.visit(expectedSchema, readSchema, new ReadBuilder(constants)); + } + + @Override + public void setSchema(Schema newFileSchema) { + this.fileSchema = Schema.applyAliases(newFileSchema, readSchema); + } + + @Override + public InternalRow read(InternalRow reuse, Decoder decoder) throws IOException { + return DecoderResolver.resolveAndRead(decoder, readSchema, fileSchema, reader, reuse); + } + + @Override + public void setRowPositionSupplier(Supplier posSupplier) { + if (reader instanceof SupportsRowPosition) { + ((SupportsRowPosition) reader).setRowPositionSupplier(posSupplier); + } + } + + private static class ReadBuilder extends AvroSchemaWithTypeVisitor> { + private final Map idToConstant; + + private ReadBuilder(Map idToConstant) { + this.idToConstant = idToConstant; + } + + @Override + public ValueReader record( + Types.StructType expected, Schema record, List names, List> fields) { + return SparkValueReaders.struct(fields, expected, idToConstant); + } + + @Override + public ValueReader union(Type expected, Schema union, List> options) { + return ValueReaders.union(options); + } + + @Override + public ValueReader array( + Types.ListType expected, Schema array, ValueReader elementReader) { + return SparkValueReaders.array(elementReader); + } + + @Override + public ValueReader map( + Types.MapType expected, Schema map, ValueReader keyReader, ValueReader valueReader) { + return SparkValueReaders.arrayMap(keyReader, valueReader); + } + + @Override + public ValueReader map(Types.MapType expected, Schema map, ValueReader valueReader) { + return SparkValueReaders.map(SparkValueReaders.strings(), valueReader); + } + + @Override + public ValueReader primitive(Type.PrimitiveType expected, Schema primitive) { + LogicalType logicalType = primitive.getLogicalType(); + if (logicalType != null) { + switch (logicalType.getName()) { + case "date": + // Spark uses the same representation + return ValueReaders.ints(); + + case "timestamp-millis": + // adjust to microseconds + ValueReader longs = ValueReaders.longs(); + return (ValueReader) (decoder, 
ignored) -> longs.read(decoder, null) * 1000L; + + case "timestamp-micros": + // Spark uses the same representation + return ValueReaders.longs(); + + case "decimal": + return SparkValueReaders.decimal( + ValueReaders.decimalBytesReader(primitive), + ((LogicalTypes.Decimal) logicalType).getScale()); + + case "uuid": + return SparkValueReaders.uuids(); + + default: + throw new IllegalArgumentException("Unknown logical type: " + logicalType); + } + } + + switch (primitive.getType()) { + case NULL: + return ValueReaders.nulls(); + case BOOLEAN: + return ValueReaders.booleans(); + case INT: + return ValueReaders.ints(); + case LONG: + return ValueReaders.longs(); + case FLOAT: + return ValueReaders.floats(); + case DOUBLE: + return ValueReaders.doubles(); + case STRING: + return SparkValueReaders.strings(); + case FIXED: + return ValueReaders.fixed(primitive.getFixedSize()); + case BYTES: + return ValueReaders.bytes(); + case ENUM: + return SparkValueReaders.enums(primitive.getEnumSymbols()); + default: + throw new IllegalArgumentException("Unsupported type: " + primitive); + } + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java new file mode 100644 index 000000000000..15465568c231 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data; + +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import org.apache.avro.LogicalType; +import org.apache.avro.LogicalTypes; +import org.apache.avro.Schema; +import org.apache.avro.io.Encoder; +import org.apache.iceberg.FieldMetrics; +import org.apache.iceberg.avro.MetricsAwareDatumWriter; +import org.apache.iceberg.avro.ValueWriter; +import org.apache.iceberg.avro.ValueWriters; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StructType; + +public class SparkAvroWriter implements MetricsAwareDatumWriter { + private final StructType dsSchema; + private ValueWriter writer = null; + + public SparkAvroWriter(StructType dsSchema) { + this.dsSchema = dsSchema; + } + + @Override + @SuppressWarnings("unchecked") + public void setSchema(Schema schema) { + this.writer = + (ValueWriter) + AvroWithSparkSchemaVisitor.visit(dsSchema, schema, new WriteBuilder()); + } + + @Override + public void write(InternalRow datum, Encoder out) throws IOException { + writer.write(datum, out); + } + + @Override + public Stream metrics() { + return writer.metrics(); + } + + private static class WriteBuilder extends AvroWithSparkSchemaVisitor> { + @Override + public ValueWriter record( + DataType struct, Schema record, List names, List> fields) { + return SparkValueWriters.struct( + fields, + IntStream.range(0, names.size()) + .mapToObj(i -> fieldNameAndType(struct, i).second()) + .collect(Collectors.toList())); + } + + @Override + public ValueWriter union(DataType type, Schema union, List> options) { + Preconditions.checkArgument( + options.contains(ValueWriters.nulls()), + "Cannot create writer for non-option union: %s", + union); + Preconditions.checkArgument( + options.size() == 2, "Cannot create writer for non-option union: %s", union); + if (union.getTypes().get(0).getType() == Schema.Type.NULL) { + return ValueWriters.option(0, options.get(1)); + } else { + return ValueWriters.option(1, options.get(0)); + } + } + + @Override + public ValueWriter array(DataType sArray, Schema array, ValueWriter elementWriter) { + return SparkValueWriters.array(elementWriter, arrayElementType(sArray)); + } + + @Override + public ValueWriter map(DataType sMap, Schema map, ValueWriter valueReader) { + return SparkValueWriters.map( + SparkValueWriters.strings(), mapKeyType(sMap), valueReader, mapValueType(sMap)); + } + + @Override + public ValueWriter map( + DataType sMap, Schema map, ValueWriter keyWriter, ValueWriter valueWriter) { + return SparkValueWriters.arrayMap( + keyWriter, mapKeyType(sMap), valueWriter, mapValueType(sMap)); + } + + @Override + public ValueWriter primitive(DataType type, Schema primitive) { + LogicalType logicalType = primitive.getLogicalType(); + if (logicalType != null) { + switch (logicalType.getName()) { + case "date": + // Spark uses the same representation + return ValueWriters.ints(); + + case "timestamp-micros": + // Spark uses the same representation + return ValueWriters.longs(); + + case "decimal": + LogicalTypes.Decimal decimal = (LogicalTypes.Decimal) logicalType; + return SparkValueWriters.decimal(decimal.getPrecision(), decimal.getScale()); + + case "uuid": + return ValueWriters.uuids(); + + 
default: + throw new IllegalArgumentException("Unsupported logical type: " + logicalType); + } + } + + switch (primitive.getType()) { + case NULL: + return ValueWriters.nulls(); + case BOOLEAN: + return ValueWriters.booleans(); + case INT: + if (type instanceof ByteType) { + return ValueWriters.tinyints(); + } else if (type instanceof ShortType) { + return ValueWriters.shorts(); + } + return ValueWriters.ints(); + case LONG: + return ValueWriters.longs(); + case FLOAT: + return ValueWriters.floats(); + case DOUBLE: + return ValueWriters.doubles(); + case STRING: + return SparkValueWriters.strings(); + case FIXED: + return ValueWriters.fixed(primitive.getFixedSize()); + case BYTES: + return ValueWriters.bytes(); + default: + throw new IllegalArgumentException("Unsupported type: " + primitive); + } + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java new file mode 100644 index 000000000000..78db137054bc --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.orc.OrcRowReader; +import org.apache.iceberg.orc.OrcSchemaWithTypeVisitor; +import org.apache.iceberg.orc.OrcValueReader; +import org.apache.iceberg.orc.OrcValueReaders; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.orc.TypeDescription; +import org.apache.orc.storage.ql.exec.vector.StructColumnVector; +import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch; +import org.apache.spark.sql.catalyst.InternalRow; + +/** + * Converts the OrcIterator, which returns ORC's VectorizedRowBatch to a set of Spark's UnsafeRows. + * + *
<p>
It minimizes allocations by reusing most of the objects in the implementation. + */ +public class SparkOrcReader implements OrcRowReader { + private final OrcValueReader reader; + + public SparkOrcReader(org.apache.iceberg.Schema expectedSchema, TypeDescription readSchema) { + this(expectedSchema, readSchema, ImmutableMap.of()); + } + + @SuppressWarnings("unchecked") + public SparkOrcReader( + org.apache.iceberg.Schema expectedSchema, + TypeDescription readOrcSchema, + Map idToConstant) { + this.reader = + OrcSchemaWithTypeVisitor.visit( + expectedSchema, readOrcSchema, new ReadBuilder(idToConstant)); + } + + @Override + public InternalRow read(VectorizedRowBatch batch, int row) { + return (InternalRow) reader.read(new StructColumnVector(batch.size, batch.cols), row); + } + + @Override + public void setBatchContext(long batchOffsetInFile) { + reader.setBatchContext(batchOffsetInFile); + } + + private static class ReadBuilder extends OrcSchemaWithTypeVisitor> { + private final Map idToConstant; + + private ReadBuilder(Map idToConstant) { + this.idToConstant = idToConstant; + } + + @Override + public OrcValueReader record( + Types.StructType expected, + TypeDescription record, + List names, + List> fields) { + return SparkOrcValueReaders.struct(fields, expected, idToConstant); + } + + @Override + public OrcValueReader list( + Types.ListType iList, TypeDescription array, OrcValueReader elementReader) { + return SparkOrcValueReaders.array(elementReader); + } + + @Override + public OrcValueReader map( + Types.MapType iMap, + TypeDescription map, + OrcValueReader keyReader, + OrcValueReader valueReader) { + return SparkOrcValueReaders.map(keyReader, valueReader); + } + + @Override + public OrcValueReader primitive(Type.PrimitiveType iPrimitive, TypeDescription primitive) { + switch (primitive.getCategory()) { + case BOOLEAN: + return OrcValueReaders.booleans(); + case BYTE: + // Iceberg does not have a byte type. Use int + case SHORT: + // Iceberg does not have a short type. Use int + case DATE: + case INT: + return OrcValueReaders.ints(); + case LONG: + return OrcValueReaders.longs(); + case FLOAT: + return OrcValueReaders.floats(); + case DOUBLE: + return OrcValueReaders.doubles(); + case TIMESTAMP_INSTANT: + case TIMESTAMP: + return SparkOrcValueReaders.timestampTzs(); + case DECIMAL: + return SparkOrcValueReaders.decimals(primitive.getPrecision(), primitive.getScale()); + case CHAR: + case VARCHAR: + case STRING: + return SparkOrcValueReaders.utf8String(); + case BINARY: + return OrcValueReaders.bytes(); + default: + throw new IllegalArgumentException("Unhandled type " + primitive); + } + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java new file mode 100644 index 000000000000..9e9b3e53bbcc --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import java.math.BigDecimal; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.orc.OrcValueReader; +import org.apache.iceberg.orc.OrcValueReaders; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Types; +import org.apache.orc.storage.ql.exec.vector.BytesColumnVector; +import org.apache.orc.storage.ql.exec.vector.ColumnVector; +import org.apache.orc.storage.ql.exec.vector.DecimalColumnVector; +import org.apache.orc.storage.ql.exec.vector.ListColumnVector; +import org.apache.orc.storage.ql.exec.vector.MapColumnVector; +import org.apache.orc.storage.ql.exec.vector.TimestampColumnVector; +import org.apache.orc.storage.serde2.io.HiveDecimalWritable; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; + +public class SparkOrcValueReaders { + private SparkOrcValueReaders() {} + + public static OrcValueReader utf8String() { + return StringReader.INSTANCE; + } + + public static OrcValueReader timestampTzs() { + return TimestampTzReader.INSTANCE; + } + + public static OrcValueReader decimals(int precision, int scale) { + if (precision <= Decimal.MAX_LONG_DIGITS()) { + return new SparkOrcValueReaders.Decimal18Reader(precision, scale); + } else if (precision <= 38) { + return new SparkOrcValueReaders.Decimal38Reader(precision, scale); + } else { + throw new IllegalArgumentException("Invalid precision: " + precision); + } + } + + static OrcValueReader struct( + List> readers, Types.StructType struct, Map idToConstant) { + return new StructReader(readers, struct, idToConstant); + } + + static OrcValueReader array(OrcValueReader elementReader) { + return new ArrayReader(elementReader); + } + + static OrcValueReader map(OrcValueReader keyReader, OrcValueReader valueReader) { + return new MapReader(keyReader, valueReader); + } + + private static class ArrayReader implements OrcValueReader { + private final OrcValueReader elementReader; + + private ArrayReader(OrcValueReader elementReader) { + this.elementReader = elementReader; + } + + @Override + public ArrayData nonNullRead(ColumnVector vector, int row) { + ListColumnVector listVector = (ListColumnVector) vector; + int offset = (int) listVector.offsets[row]; + int length = (int) listVector.lengths[row]; + List elements = Lists.newArrayListWithExpectedSize(length); + for (int c = 0; c < length; ++c) { + elements.add(elementReader.read(listVector.child, offset + c)); + } + return new GenericArrayData(elements.toArray()); + } + + @Override + public void setBatchContext(long batchOffsetInFile) { + elementReader.setBatchContext(batchOffsetInFile); + } + } + + 
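  // Illustrative sketch, not part of this patch: how the decimals() factory above chooses a
  // reader. Spark's Decimal.MAX_LONG_DIGITS() is 18, so precisions whose unscaled value fits in a
  // long get the Decimal18Reader below, while wider precisions (up to 38) fall back to the
  // Decimal38Reader, which goes through HiveDecimal/BigDecimal.
  private static void exampleDecimalReaderSelection() {
    OrcValueReader<Decimal> compact = SparkOrcValueReaders.decimals(10, 2); // Decimal18Reader
    OrcValueReader<Decimal> wide = SparkOrcValueReaders.decimals(38, 10); // Decimal38Reader
  }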
private static class MapReader implements OrcValueReader { + private final OrcValueReader keyReader; + private final OrcValueReader valueReader; + + private MapReader(OrcValueReader keyReader, OrcValueReader valueReader) { + this.keyReader = keyReader; + this.valueReader = valueReader; + } + + @Override + public MapData nonNullRead(ColumnVector vector, int row) { + MapColumnVector mapVector = (MapColumnVector) vector; + int offset = (int) mapVector.offsets[row]; + long length = mapVector.lengths[row]; + List keys = Lists.newArrayListWithExpectedSize((int) length); + List values = Lists.newArrayListWithExpectedSize((int) length); + for (int c = 0; c < length; c++) { + keys.add(keyReader.read(mapVector.keys, offset + c)); + values.add(valueReader.read(mapVector.values, offset + c)); + } + + return new ArrayBasedMapData( + new GenericArrayData(keys.toArray()), new GenericArrayData(values.toArray())); + } + + @Override + public void setBatchContext(long batchOffsetInFile) { + keyReader.setBatchContext(batchOffsetInFile); + valueReader.setBatchContext(batchOffsetInFile); + } + } + + static class StructReader extends OrcValueReaders.StructReader { + private final int numFields; + + protected StructReader( + List> readers, Types.StructType struct, Map idToConstant) { + super(readers, struct, idToConstant); + this.numFields = struct.fields().size(); + } + + @Override + protected InternalRow create() { + return new GenericInternalRow(numFields); + } + + @Override + protected void set(InternalRow struct, int pos, Object value) { + if (value != null) { + struct.update(pos, value); + } else { + struct.setNullAt(pos); + } + } + } + + private static class StringReader implements OrcValueReader { + private static final StringReader INSTANCE = new StringReader(); + + private StringReader() {} + + @Override + public UTF8String nonNullRead(ColumnVector vector, int row) { + BytesColumnVector bytesVector = (BytesColumnVector) vector; + return UTF8String.fromBytes( + bytesVector.vector[row], bytesVector.start[row], bytesVector.length[row]); + } + } + + private static class TimestampTzReader implements OrcValueReader { + private static final TimestampTzReader INSTANCE = new TimestampTzReader(); + + private TimestampTzReader() {} + + @Override + public Long nonNullRead(ColumnVector vector, int row) { + TimestampColumnVector tcv = (TimestampColumnVector) vector; + return Math.floorDiv(tcv.time[row], 1_000) * 1_000_000 + Math.floorDiv(tcv.nanos[row], 1000); + } + } + + private static class Decimal18Reader implements OrcValueReader { + private final int precision; + private final int scale; + + Decimal18Reader(int precision, int scale) { + this.precision = precision; + this.scale = scale; + } + + @Override + public Decimal nonNullRead(ColumnVector vector, int row) { + HiveDecimalWritable value = ((DecimalColumnVector) vector).vector[row]; + + // The scale of decimal read from hive ORC file may be not equals to the expected scale. For + // data type + // decimal(10,3) and the value 10.100, the hive ORC writer will remove its trailing zero and + // store it + // as 101*10^(-1), its scale will adjust from 3 to 1. So here we could not assert that + // value.scale() == scale. + // we also need to convert the hive orc decimal to a decimal with expected precision and + // scale. 
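      // Worked example, illustrative and not part of this patch: for an expected decimal(10,3),
      // Hive ORC can hand back 10.100 in the compact form 101 * 10^-1 (unscaled value 101, scale 1).
      // value.serialize64(3) rescales that to the unscaled long 10100, and new Decimal().set(10100,
      // 10, 3) below reproduces 10.100 at the expected precision and scale.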
+ Preconditions.checkArgument( + value.precision() <= precision, + "Cannot read value as decimal(%s,%s), too large: %s", + precision, + scale, + value); + + return new Decimal().set(value.serialize64(scale), precision, scale); + } + } + + private static class Decimal38Reader implements OrcValueReader { + private final int precision; + private final int scale; + + Decimal38Reader(int precision, int scale) { + this.precision = precision; + this.scale = scale; + } + + @Override + public Decimal nonNullRead(ColumnVector vector, int row) { + BigDecimal value = + ((DecimalColumnVector) vector).vector[row].getHiveDecimal().bigDecimalValue(); + + Preconditions.checkArgument( + value.precision() <= precision, + "Cannot read value as decimal(%s,%s), too large: %s", + precision, + scale, + value); + + return new Decimal().set(new scala.math.BigDecimal(value), precision, scale); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java new file mode 100644 index 000000000000..780090f99109 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data; + +import java.util.List; +import java.util.stream.Stream; +import org.apache.iceberg.FieldMetrics; +import org.apache.iceberg.orc.OrcValueWriter; +import org.apache.orc.TypeDescription; +import org.apache.orc.storage.common.type.HiveDecimal; +import org.apache.orc.storage.ql.exec.vector.BytesColumnVector; +import org.apache.orc.storage.ql.exec.vector.ColumnVector; +import org.apache.orc.storage.ql.exec.vector.DecimalColumnVector; +import org.apache.orc.storage.ql.exec.vector.ListColumnVector; +import org.apache.orc.storage.ql.exec.vector.MapColumnVector; +import org.apache.orc.storage.ql.exec.vector.TimestampColumnVector; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; + +class SparkOrcValueWriters { + private SparkOrcValueWriters() {} + + static OrcValueWriter strings() { + return StringWriter.INSTANCE; + } + + static OrcValueWriter timestampTz() { + return TimestampTzWriter.INSTANCE; + } + + static OrcValueWriter decimal(int precision, int scale) { + if (precision <= 18) { + return new Decimal18Writer(scale); + } else { + return new Decimal38Writer(); + } + } + + static OrcValueWriter list(OrcValueWriter element, List orcType) { + return new ListWriter<>(element, orcType); + } + + static OrcValueWriter map( + OrcValueWriter keyWriter, OrcValueWriter valueWriter, List orcTypes) { + return new MapWriter<>(keyWriter, valueWriter, orcTypes); + } + + private static class StringWriter implements OrcValueWriter { + private static final StringWriter INSTANCE = new StringWriter(); + + @Override + public void nonNullWrite(int rowId, UTF8String data, ColumnVector output) { + byte[] value = data.getBytes(); + ((BytesColumnVector) output).setRef(rowId, value, 0, value.length); + } + } + + private static class TimestampTzWriter implements OrcValueWriter { + private static final TimestampTzWriter INSTANCE = new TimestampTzWriter(); + + @Override + public void nonNullWrite(int rowId, Long micros, ColumnVector output) { + TimestampColumnVector cv = (TimestampColumnVector) output; + cv.time[rowId] = Math.floorDiv(micros, 1_000); // millis + cv.nanos[rowId] = (int) Math.floorMod(micros, 1_000_000) * 1_000; // nanos + } + } + + private static class Decimal18Writer implements OrcValueWriter { + private final int scale; + + Decimal18Writer(int scale) { + this.scale = scale; + } + + @Override + public void nonNullWrite(int rowId, Decimal decimal, ColumnVector output) { + ((DecimalColumnVector) output) + .vector[rowId].setFromLongAndScale(decimal.toUnscaledLong(), scale); + } + } + + private static class Decimal38Writer implements OrcValueWriter { + + @Override + public void nonNullWrite(int rowId, Decimal decimal, ColumnVector output) { + ((DecimalColumnVector) output) + .vector[rowId].set(HiveDecimal.create(decimal.toJavaBigDecimal())); + } + } + + private static class ListWriter implements OrcValueWriter { + private final OrcValueWriter writer; + private final SparkOrcWriter.FieldGetter fieldGetter; + + @SuppressWarnings("unchecked") + ListWriter(OrcValueWriter writer, List orcTypes) { + if (orcTypes.size() != 1) { + throw new IllegalArgumentException( + "Expected one (and same) ORC type for list elements, got: " + orcTypes); + } + this.writer = writer; + this.fieldGetter = + (SparkOrcWriter.FieldGetter) SparkOrcWriter.createFieldGetter(orcTypes.get(0)); + } + + @Override + public void nonNullWrite(int rowId, 
ArrayData value, ColumnVector output) { + ListColumnVector cv = (ListColumnVector) output; + // record the length and start of the list elements + cv.lengths[rowId] = value.numElements(); + cv.offsets[rowId] = cv.childCount; + cv.childCount = (int) (cv.childCount + cv.lengths[rowId]); + // make sure the child is big enough + growColumnVector(cv.child, cv.childCount); + // Add each element + for (int e = 0; e < cv.lengths[rowId]; ++e) { + writer.write((int) (e + cv.offsets[rowId]), fieldGetter.getFieldOrNull(value, e), cv.child); + } + } + + @Override + public Stream> metrics() { + return writer.metrics(); + } + } + + private static class MapWriter implements OrcValueWriter { + private final OrcValueWriter keyWriter; + private final OrcValueWriter valueWriter; + private final SparkOrcWriter.FieldGetter keyFieldGetter; + private final SparkOrcWriter.FieldGetter valueFieldGetter; + + @SuppressWarnings("unchecked") + MapWriter( + OrcValueWriter keyWriter, + OrcValueWriter valueWriter, + List orcTypes) { + if (orcTypes.size() != 2) { + throw new IllegalArgumentException( + "Expected two ORC type descriptions for a map, got: " + orcTypes); + } + this.keyWriter = keyWriter; + this.valueWriter = valueWriter; + this.keyFieldGetter = + (SparkOrcWriter.FieldGetter) SparkOrcWriter.createFieldGetter(orcTypes.get(0)); + this.valueFieldGetter = + (SparkOrcWriter.FieldGetter) SparkOrcWriter.createFieldGetter(orcTypes.get(1)); + } + + @Override + public void nonNullWrite(int rowId, MapData map, ColumnVector output) { + ArrayData key = map.keyArray(); + ArrayData value = map.valueArray(); + MapColumnVector cv = (MapColumnVector) output; + // record the length and start of the list elements + cv.lengths[rowId] = value.numElements(); + cv.offsets[rowId] = cv.childCount; + cv.childCount = (int) (cv.childCount + cv.lengths[rowId]); + // make sure the child is big enough + growColumnVector(cv.keys, cv.childCount); + growColumnVector(cv.values, cv.childCount); + // Add each element + for (int e = 0; e < cv.lengths[rowId]; ++e) { + int pos = (int) (e + cv.offsets[rowId]); + keyWriter.write(pos, keyFieldGetter.getFieldOrNull(key, e), cv.keys); + valueWriter.write(pos, valueFieldGetter.getFieldOrNull(value, e), cv.values); + } + } + + @Override + public Stream> metrics() { + return Stream.concat(keyWriter.metrics(), valueWriter.metrics()); + } + } + + private static void growColumnVector(ColumnVector cv, int requestedSize) { + if (cv.isNull.length < requestedSize) { + // Use growth factor of 3 to avoid frequent array allocations + cv.ensureSize(requestedSize * 3, true); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java new file mode 100644 index 000000000000..60868b8700a3 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import java.io.Serializable; +import java.util.List; +import java.util.stream.Stream; +import javax.annotation.Nullable; +import org.apache.iceberg.FieldMetrics; +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.orc.GenericOrcWriters; +import org.apache.iceberg.orc.ORCSchemaUtil; +import org.apache.iceberg.orc.OrcRowWriter; +import org.apache.iceberg.orc.OrcSchemaWithTypeVisitor; +import org.apache.iceberg.orc.OrcValueWriter; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.orc.TypeDescription; +import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters; + +/** This class acts as an adaptor from an OrcFileAppender to a FileAppender<InternalRow>. */ +public class SparkOrcWriter implements OrcRowWriter { + + private final InternalRowWriter writer; + + public SparkOrcWriter(Schema iSchema, TypeDescription orcSchema) { + Preconditions.checkArgument( + orcSchema.getCategory() == TypeDescription.Category.STRUCT, + "Top level must be a struct " + orcSchema); + + writer = + (InternalRowWriter) OrcSchemaWithTypeVisitor.visit(iSchema, orcSchema, new WriteBuilder()); + } + + @Override + public void write(InternalRow value, VectorizedRowBatch output) { + Preconditions.checkArgument(value != null, "value must not be null"); + writer.writeRow(value, output); + } + + @Override + public List> writers() { + return writer.writers(); + } + + @Override + public Stream> metrics() { + return writer.metrics(); + } + + private static class WriteBuilder extends OrcSchemaWithTypeVisitor> { + private WriteBuilder() {} + + @Override + public OrcValueWriter record( + Types.StructType iStruct, + TypeDescription record, + List names, + List> fields) { + return new InternalRowWriter(fields, record.getChildren()); + } + + @Override + public OrcValueWriter list( + Types.ListType iList, TypeDescription array, OrcValueWriter element) { + return SparkOrcValueWriters.list(element, array.getChildren()); + } + + @Override + public OrcValueWriter map( + Types.MapType iMap, TypeDescription map, OrcValueWriter key, OrcValueWriter value) { + return SparkOrcValueWriters.map(key, value, map.getChildren()); + } + + @Override + public OrcValueWriter primitive(Type.PrimitiveType iPrimitive, TypeDescription primitive) { + switch (primitive.getCategory()) { + case BOOLEAN: + return GenericOrcWriters.booleans(); + case BYTE: + return GenericOrcWriters.bytes(); + case SHORT: + return GenericOrcWriters.shorts(); + case DATE: + case INT: + return GenericOrcWriters.ints(); + case LONG: + return GenericOrcWriters.longs(); + case FLOAT: + return GenericOrcWriters.floats(ORCSchemaUtil.fieldId(primitive)); + case DOUBLE: + return GenericOrcWriters.doubles(ORCSchemaUtil.fieldId(primitive)); + case BINARY: + return GenericOrcWriters.byteArrays(); + case 
STRING: + case CHAR: + case VARCHAR: + return SparkOrcValueWriters.strings(); + case DECIMAL: + return SparkOrcValueWriters.decimal(primitive.getPrecision(), primitive.getScale()); + case TIMESTAMP_INSTANT: + case TIMESTAMP: + return SparkOrcValueWriters.timestampTz(); + default: + throw new IllegalArgumentException("Unhandled type " + primitive); + } + } + } + + private static class InternalRowWriter extends GenericOrcWriters.StructWriter { + private final List> fieldGetters; + + InternalRowWriter(List> writers, List orcTypes) { + super(writers); + this.fieldGetters = Lists.newArrayListWithExpectedSize(orcTypes.size()); + + for (TypeDescription orcType : orcTypes) { + fieldGetters.add(createFieldGetter(orcType)); + } + } + + @Override + protected Object get(InternalRow struct, int index) { + return fieldGetters.get(index).getFieldOrNull(struct, index); + } + } + + static FieldGetter createFieldGetter(TypeDescription fieldType) { + final FieldGetter fieldGetter; + switch (fieldType.getCategory()) { + case BOOLEAN: + fieldGetter = SpecializedGetters::getBoolean; + break; + case BYTE: + fieldGetter = SpecializedGetters::getByte; + break; + case SHORT: + fieldGetter = SpecializedGetters::getShort; + break; + case DATE: + case INT: + fieldGetter = SpecializedGetters::getInt; + break; + case LONG: + case TIMESTAMP: + case TIMESTAMP_INSTANT: + fieldGetter = SpecializedGetters::getLong; + break; + case FLOAT: + fieldGetter = SpecializedGetters::getFloat; + break; + case DOUBLE: + fieldGetter = SpecializedGetters::getDouble; + break; + case BINARY: + fieldGetter = SpecializedGetters::getBinary; + // getBinary always makes a copy, so we don't need to worry about it + // being changed behind our back. + break; + case DECIMAL: + fieldGetter = + (row, ordinal) -> + row.getDecimal(ordinal, fieldType.getPrecision(), fieldType.getScale()); + break; + case STRING: + case CHAR: + case VARCHAR: + fieldGetter = SpecializedGetters::getUTF8String; + break; + case STRUCT: + fieldGetter = (row, ordinal) -> row.getStruct(ordinal, fieldType.getChildren().size()); + break; + case LIST: + fieldGetter = SpecializedGetters::getArray; + break; + case MAP: + fieldGetter = SpecializedGetters::getMap; + break; + default: + throw new IllegalArgumentException( + "Encountered an unsupported ORC type during a write from Spark."); + } + + return (row, ordinal) -> { + if (row.isNullAt(ordinal)) { + return null; + } + return fieldGetter.getFieldOrNull(row, ordinal); + }; + } + + interface FieldGetter extends Serializable { + + /** + * Returns a value from a complex Spark data holder such ArrayData, InternalRow, etc... Calls + * the appropriate getter for the expected data type. + * + * @param row Spark's data representation + * @param ordinal index in the data structure (e.g. column index for InterRow, list index in + * ArrayData, etc..) + * @return field value at ordinal + */ + @Nullable + T getFieldOrNull(SpecializedGetters row, int ordinal); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java new file mode 100644 index 000000000000..59f81de6ae4a --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java @@ -0,0 +1,769 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.parquet.ParquetSchemaUtil; +import org.apache.iceberg.parquet.ParquetUtil; +import org.apache.iceberg.parquet.ParquetValueReader; +import org.apache.iceberg.parquet.ParquetValueReaders; +import org.apache.iceberg.parquet.ParquetValueReaders.FloatAsDoubleReader; +import org.apache.iceberg.parquet.ParquetValueReaders.IntAsLongReader; +import org.apache.iceberg.parquet.ParquetValueReaders.PrimitiveReader; +import org.apache.iceberg.parquet.ParquetValueReaders.RepeatedKeyValueReader; +import org.apache.iceberg.parquet.ParquetValueReaders.RepeatedReader; +import org.apache.iceberg.parquet.ParquetValueReaders.ReusableEntry; +import org.apache.iceberg.parquet.ParquetValueReaders.StructReader; +import org.apache.iceberg.parquet.ParquetValueReaders.UnboxedReader; +import org.apache.iceberg.parquet.TypeWithSchemaVisitor; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Type.TypeID; +import org.apache.iceberg.types.Types; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +public class SparkParquetReaders { + private SparkParquetReaders() {} + + public static ParquetValueReader buildReader( + Schema expectedSchema, MessageType fileSchema) { + return buildReader(expectedSchema, fileSchema, ImmutableMap.of()); + } + + @SuppressWarnings("unchecked") + public static ParquetValueReader buildReader( + Schema expectedSchema, MessageType fileSchema, Map idToConstant) { + if (ParquetSchemaUtil.hasIds(fileSchema)) { + return (ParquetValueReader) + TypeWithSchemaVisitor.visit( + 
expectedSchema.asStruct(), fileSchema, new ReadBuilder(fileSchema, idToConstant)); + } else { + return (ParquetValueReader) + TypeWithSchemaVisitor.visit( + expectedSchema.asStruct(), + fileSchema, + new FallbackReadBuilder(fileSchema, idToConstant)); + } + } + + private static class FallbackReadBuilder extends ReadBuilder { + FallbackReadBuilder(MessageType type, Map idToConstant) { + super(type, idToConstant); + } + + @Override + public ParquetValueReader message( + Types.StructType expected, MessageType message, List> fieldReaders) { + // the top level matches by ID, but the remaining IDs are missing + return super.struct(expected, message, fieldReaders); + } + + @Override + public ParquetValueReader struct( + Types.StructType ignored, GroupType struct, List> fieldReaders) { + // the expected struct is ignored because nested fields are never found when the + List> newFields = + Lists.newArrayListWithExpectedSize(fieldReaders.size()); + List types = Lists.newArrayListWithExpectedSize(fieldReaders.size()); + List fields = struct.getFields(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i); + int fieldD = type().getMaxDefinitionLevel(path(fieldType.getName())) - 1; + newFields.add(ParquetValueReaders.option(fieldType, fieldD, fieldReaders.get(i))); + types.add(fieldType); + } + + return new InternalRowReader(types, newFields); + } + } + + private static class ReadBuilder extends TypeWithSchemaVisitor> { + private final MessageType type; + private final Map idToConstant; + + ReadBuilder(MessageType type, Map idToConstant) { + this.type = type; + this.idToConstant = idToConstant; + } + + @Override + public ParquetValueReader message( + Types.StructType expected, MessageType message, List> fieldReaders) { + return struct(expected, message.asGroupType(), fieldReaders); + } + + @Override + public ParquetValueReader struct( + Types.StructType expected, GroupType struct, List> fieldReaders) { + // match the expected struct's order + Map> readersById = Maps.newHashMap(); + Map typesById = Maps.newHashMap(); + Map maxDefinitionLevelsById = Maps.newHashMap(); + List fields = struct.getFields(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i); + int fieldD = type.getMaxDefinitionLevel(path(fieldType.getName())) - 1; + if (fieldType.getId() != null) { + int id = fieldType.getId().intValue(); + readersById.put(id, ParquetValueReaders.option(fieldType, fieldD, fieldReaders.get(i))); + typesById.put(id, fieldType); + if (idToConstant.containsKey(id)) { + maxDefinitionLevelsById.put(id, fieldD); + } + } + } + + List expectedFields = + expected != null ? 
expected.fields() : ImmutableList.of(); + List> reorderedFields = + Lists.newArrayListWithExpectedSize(expectedFields.size()); + List types = Lists.newArrayListWithExpectedSize(expectedFields.size()); + // Defaulting to parent max definition level + int defaultMaxDefinitionLevel = type.getMaxDefinitionLevel(currentPath()); + for (Types.NestedField field : expectedFields) { + int id = field.fieldId(); + if (idToConstant.containsKey(id)) { + // containsKey is used because the constant may be null + int fieldMaxDefinitionLevel = + maxDefinitionLevelsById.getOrDefault(id, defaultMaxDefinitionLevel); + reorderedFields.add( + ParquetValueReaders.constant(idToConstant.get(id), fieldMaxDefinitionLevel)); + types.add(null); + } else if (id == MetadataColumns.ROW_POSITION.fieldId()) { + reorderedFields.add(ParquetValueReaders.position()); + types.add(null); + } else if (id == MetadataColumns.IS_DELETED.fieldId()) { + reorderedFields.add(ParquetValueReaders.constant(false)); + types.add(null); + } else { + ParquetValueReader reader = readersById.get(id); + if (reader != null) { + reorderedFields.add(reader); + types.add(typesById.get(id)); + } else { + reorderedFields.add(ParquetValueReaders.nulls()); + types.add(null); + } + } + } + + return new InternalRowReader(types, reorderedFields); + } + + @Override + public ParquetValueReader list( + Types.ListType expectedList, GroupType array, ParquetValueReader elementReader) { + String[] repeatedPath = currentPath(); + + int repeatedD = type.getMaxDefinitionLevel(repeatedPath) - 1; + int repeatedR = type.getMaxRepetitionLevel(repeatedPath) - 1; + + Type elementType = ParquetSchemaUtil.determineListElementType(array); + int elementD = type.getMaxDefinitionLevel(path(elementType.getName())) - 1; + + return new ArrayReader<>( + repeatedD, repeatedR, ParquetValueReaders.option(elementType, elementD, elementReader)); + } + + @Override + public ParquetValueReader map( + Types.MapType expectedMap, + GroupType map, + ParquetValueReader keyReader, + ParquetValueReader valueReader) { + GroupType repeatedKeyValue = map.getFields().get(0).asGroupType(); + String[] repeatedPath = currentPath(); + + int repeatedD = type.getMaxDefinitionLevel(repeatedPath) - 1; + int repeatedR = type.getMaxRepetitionLevel(repeatedPath) - 1; + + Type keyType = repeatedKeyValue.getType(0); + int keyD = type.getMaxDefinitionLevel(path(keyType.getName())) - 1; + Type valueType = repeatedKeyValue.getType(1); + int valueD = type.getMaxDefinitionLevel(path(valueType.getName())) - 1; + + return new MapReader<>( + repeatedD, + repeatedR, + ParquetValueReaders.option(keyType, keyD, keyReader), + ParquetValueReaders.option(valueType, valueD, valueReader)); + } + + @Override + public ParquetValueReader primitive( + org.apache.iceberg.types.Type.PrimitiveType expected, PrimitiveType primitive) { + ColumnDescriptor desc = type.getColumnDescription(currentPath()); + + if (primitive.getOriginalType() != null) { + switch (primitive.getOriginalType()) { + case ENUM: + case JSON: + case UTF8: + return new StringReader(desc); + case INT_8: + case INT_16: + case INT_32: + if (expected != null && expected.typeId() == Types.LongType.get().typeId()) { + return new IntAsLongReader(desc); + } else { + return new UnboxedReader(desc); + } + case DATE: + case INT_64: + case TIMESTAMP_MICROS: + return new UnboxedReader<>(desc); + case TIMESTAMP_MILLIS: + return new TimestampMillisReader(desc); + case DECIMAL: + DecimalLogicalTypeAnnotation decimal = + (DecimalLogicalTypeAnnotation) 
primitive.getLogicalTypeAnnotation(); + switch (primitive.getPrimitiveTypeName()) { + case BINARY: + case FIXED_LEN_BYTE_ARRAY: + return new BinaryDecimalReader(desc, decimal.getScale()); + case INT64: + return new LongDecimalReader(desc, decimal.getPrecision(), decimal.getScale()); + case INT32: + return new IntegerDecimalReader(desc, decimal.getPrecision(), decimal.getScale()); + default: + throw new UnsupportedOperationException( + "Unsupported base type for decimal: " + primitive.getPrimitiveTypeName()); + } + case BSON: + return new ParquetValueReaders.ByteArrayReader(desc); + default: + throw new UnsupportedOperationException( + "Unsupported logical type: " + primitive.getOriginalType()); + } + } + + switch (primitive.getPrimitiveTypeName()) { + case FIXED_LEN_BYTE_ARRAY: + case BINARY: + return new ParquetValueReaders.ByteArrayReader(desc); + case INT32: + if (expected != null && expected.typeId() == TypeID.LONG) { + return new IntAsLongReader(desc); + } else { + return new UnboxedReader<>(desc); + } + case FLOAT: + if (expected != null && expected.typeId() == TypeID.DOUBLE) { + return new FloatAsDoubleReader(desc); + } else { + return new UnboxedReader<>(desc); + } + case BOOLEAN: + case INT64: + case DOUBLE: + return new UnboxedReader<>(desc); + case INT96: + // Impala & Spark used to write timestamps as INT96 without a logical type. For backwards + // compatibility we try to read INT96 as timestamps. + return new TimestampInt96Reader(desc); + default: + throw new UnsupportedOperationException("Unsupported type: " + primitive); + } + } + + protected MessageType type() { + return type; + } + } + + private static class BinaryDecimalReader extends PrimitiveReader { + private final int scale; + + BinaryDecimalReader(ColumnDescriptor desc, int scale) { + super(desc); + this.scale = scale; + } + + @Override + public Decimal read(Decimal ignored) { + Binary binary = column.nextBinary(); + return Decimal.fromDecimal(new BigDecimal(new BigInteger(binary.getBytes()), scale)); + } + } + + private static class IntegerDecimalReader extends PrimitiveReader { + private final int precision; + private final int scale; + + IntegerDecimalReader(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + } + + @Override + public Decimal read(Decimal ignored) { + return Decimal.apply(column.nextInteger(), precision, scale); + } + } + + private static class LongDecimalReader extends PrimitiveReader { + private final int precision; + private final int scale; + + LongDecimalReader(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + } + + @Override + public Decimal read(Decimal ignored) { + return Decimal.apply(column.nextLong(), precision, scale); + } + } + + private static class TimestampMillisReader extends UnboxedReader { + TimestampMillisReader(ColumnDescriptor desc) { + super(desc); + } + + @Override + public Long read(Long ignored) { + return readLong(); + } + + @Override + public long readLong() { + return 1000 * column.nextLong(); + } + } + + private static class TimestampInt96Reader extends UnboxedReader { + + TimestampInt96Reader(ColumnDescriptor desc) { + super(desc); + } + + @Override + public Long read(Long ignored) { + return readLong(); + } + + @Override + public long readLong() { + final ByteBuffer byteBuffer = + column.nextBinary().toByteBuffer().order(ByteOrder.LITTLE_ENDIAN); + return ParquetUtil.extractTimestampInt96(byteBuffer); + } + } + + private static 
class StringReader extends PrimitiveReader { + StringReader(ColumnDescriptor desc) { + super(desc); + } + + @Override + public UTF8String read(UTF8String ignored) { + Binary binary = column.nextBinary(); + ByteBuffer buffer = binary.toByteBuffer(); + if (buffer.hasArray()) { + return UTF8String.fromBytes( + buffer.array(), buffer.arrayOffset() + buffer.position(), buffer.remaining()); + } else { + return UTF8String.fromBytes(binary.getBytes()); + } + } + } + + private static class ArrayReader extends RepeatedReader { + private int readPos = 0; + private int writePos = 0; + + ArrayReader(int definitionLevel, int repetitionLevel, ParquetValueReader reader) { + super(definitionLevel, repetitionLevel, reader); + } + + @Override + @SuppressWarnings("unchecked") + protected ReusableArrayData newListData(ArrayData reuse) { + this.readPos = 0; + this.writePos = 0; + + if (reuse instanceof ReusableArrayData) { + return (ReusableArrayData) reuse; + } else { + return new ReusableArrayData(); + } + } + + @Override + @SuppressWarnings("unchecked") + protected E getElement(ReusableArrayData list) { + E value = null; + if (readPos < list.capacity()) { + value = (E) list.values[readPos]; + } + + readPos += 1; + + return value; + } + + @Override + protected void addElement(ReusableArrayData reused, E element) { + if (writePos >= reused.capacity()) { + reused.grow(); + } + + reused.values[writePos] = element; + + writePos += 1; + } + + @Override + protected ArrayData buildList(ReusableArrayData list) { + list.setNumElements(writePos); + return list; + } + } + + private static class MapReader + extends RepeatedKeyValueReader { + private int readPos = 0; + private int writePos = 0; + + private final ReusableEntry entry = new ReusableEntry<>(); + private final ReusableEntry nullEntry = new ReusableEntry<>(); + + MapReader( + int definitionLevel, + int repetitionLevel, + ParquetValueReader keyReader, + ParquetValueReader valueReader) { + super(definitionLevel, repetitionLevel, keyReader, valueReader); + } + + @Override + @SuppressWarnings("unchecked") + protected ReusableMapData newMapData(MapData reuse) { + this.readPos = 0; + this.writePos = 0; + + if (reuse instanceof ReusableMapData) { + return (ReusableMapData) reuse; + } else { + return new ReusableMapData(); + } + } + + @Override + @SuppressWarnings("unchecked") + protected Map.Entry getPair(ReusableMapData map) { + Map.Entry kv = nullEntry; + if (readPos < map.capacity()) { + entry.set((K) map.keys.values[readPos], (V) map.values.values[readPos]); + kv = entry; + } + + readPos += 1; + + return kv; + } + + @Override + protected void addPair(ReusableMapData map, K key, V value) { + if (writePos >= map.capacity()) { + map.grow(); + } + + map.keys.values[writePos] = key; + map.values.values[writePos] = value; + + writePos += 1; + } + + @Override + protected MapData buildMap(ReusableMapData map) { + map.setNumElements(writePos); + return map; + } + } + + private static class InternalRowReader extends StructReader { + private final int numFields; + + InternalRowReader(List types, List> readers) { + super(types, readers); + this.numFields = readers.size(); + } + + @Override + protected GenericInternalRow newStructData(InternalRow reuse) { + if (reuse instanceof GenericInternalRow) { + return (GenericInternalRow) reuse; + } else { + return new GenericInternalRow(numFields); + } + } + + @Override + protected Object getField(GenericInternalRow intermediate, int pos) { + return intermediate.genericGet(pos); + } + + @Override + protected InternalRow 
buildStruct(GenericInternalRow struct) { + return struct; + } + + @Override + protected void set(GenericInternalRow row, int pos, Object value) { + row.update(pos, value); + } + + @Override + protected void setNull(GenericInternalRow row, int pos) { + row.setNullAt(pos); + } + + @Override + protected void setBoolean(GenericInternalRow row, int pos, boolean value) { + row.setBoolean(pos, value); + } + + @Override + protected void setInteger(GenericInternalRow row, int pos, int value) { + row.setInt(pos, value); + } + + @Override + protected void setLong(GenericInternalRow row, int pos, long value) { + row.setLong(pos, value); + } + + @Override + protected void setFloat(GenericInternalRow row, int pos, float value) { + row.setFloat(pos, value); + } + + @Override + protected void setDouble(GenericInternalRow row, int pos, double value) { + row.setDouble(pos, value); + } + } + + private static class ReusableMapData extends MapData { + private final ReusableArrayData keys; + private final ReusableArrayData values; + private int numElements; + + private ReusableMapData() { + this.keys = new ReusableArrayData(); + this.values = new ReusableArrayData(); + } + + private void grow() { + keys.grow(); + values.grow(); + } + + private int capacity() { + return keys.capacity(); + } + + public void setNumElements(int numElements) { + this.numElements = numElements; + keys.setNumElements(numElements); + values.setNumElements(numElements); + } + + @Override + public int numElements() { + return numElements; + } + + @Override + public MapData copy() { + return new ArrayBasedMapData(keyArray().copy(), valueArray().copy()); + } + + @Override + public ReusableArrayData keyArray() { + return keys; + } + + @Override + public ReusableArrayData valueArray() { + return values; + } + } + + private static class ReusableArrayData extends ArrayData { + private static final Object[] EMPTY = new Object[0]; + + private Object[] values = EMPTY; + private int numElements = 0; + + private void grow() { + if (values.length == 0) { + this.values = new Object[20]; + } else { + Object[] old = values; + this.values = new Object[old.length << 2]; + // copy the old array in case it has values that can be reused + System.arraycopy(old, 0, values, 0, old.length); + } + } + + private int capacity() { + return values.length; + } + + public void setNumElements(int numElements) { + this.numElements = numElements; + } + + @Override + public Object get(int ordinal, DataType dataType) { + return values[ordinal]; + } + + @Override + public int numElements() { + return numElements; + } + + @Override + public ArrayData copy() { + return new GenericArrayData(array()); + } + + @Override + public Object[] array() { + return Arrays.copyOfRange(values, 0, numElements); + } + + @Override + public void setNullAt(int i) { + values[i] = null; + } + + @Override + public void update(int ordinal, Object value) { + values[ordinal] = value; + } + + @Override + public boolean isNullAt(int ordinal) { + return null == values[ordinal]; + } + + @Override + public boolean getBoolean(int ordinal) { + return (boolean) values[ordinal]; + } + + @Override + public byte getByte(int ordinal) { + return (byte) values[ordinal]; + } + + @Override + public short getShort(int ordinal) { + return (short) values[ordinal]; + } + + @Override + public int getInt(int ordinal) { + return (int) values[ordinal]; + } + + @Override + public long getLong(int ordinal) { + return (long) values[ordinal]; + } + + @Override + public float getFloat(int ordinal) { + return (float) 
values[ordinal]; + } + + @Override + public double getDouble(int ordinal) { + return (double) values[ordinal]; + } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + return (Decimal) values[ordinal]; + } + + @Override + public UTF8String getUTF8String(int ordinal) { + return (UTF8String) values[ordinal]; + } + + @Override + public byte[] getBinary(int ordinal) { + return (byte[]) values[ordinal]; + } + + @Override + public CalendarInterval getInterval(int ordinal) { + return (CalendarInterval) values[ordinal]; + } + + @Override + public InternalRow getStruct(int ordinal, int numFields) { + return (InternalRow) values[ordinal]; + } + + @Override + public ArrayData getArray(int ordinal) { + return (ArrayData) values[ordinal]; + } + + @Override + public MapData getMap(int ordinal) { + return (MapData) values[ordinal]; + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java new file mode 100644 index 000000000000..3637fa4a2604 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java @@ -0,0 +1,458 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import org.apache.iceberg.parquet.ParquetValueReaders.ReusableEntry; +import org.apache.iceberg.parquet.ParquetValueWriter; +import org.apache.iceberg.parquet.ParquetValueWriters; +import org.apache.iceberg.parquet.ParquetValueWriters.PrimitiveWriter; +import org.apache.iceberg.parquet.ParquetValueWriters.RepeatedKeyValueWriter; +import org.apache.iceberg.parquet.ParquetValueWriters.RepeatedWriter; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.util.DecimalUtil; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +public class SparkParquetWriters { + private SparkParquetWriters() {} + + @SuppressWarnings("unchecked") + public static ParquetValueWriter buildWriter(StructType dfSchema, MessageType type) { + return (ParquetValueWriter) + ParquetWithSparkSchemaVisitor.visit(dfSchema, type, new WriteBuilder(type)); + } + + private static class WriteBuilder extends ParquetWithSparkSchemaVisitor> { + private final MessageType type; + + WriteBuilder(MessageType type) { + this.type = type; + } + + @Override + public ParquetValueWriter message( + StructType sStruct, MessageType message, List> fieldWriters) { + return struct(sStruct, message.asGroupType(), fieldWriters); + } + + @Override + public ParquetValueWriter struct( + StructType sStruct, GroupType struct, List> fieldWriters) { + List fields = struct.getFields(); + StructField[] sparkFields = sStruct.fields(); + List> writers = Lists.newArrayListWithExpectedSize(fieldWriters.size()); + List sparkTypes = Lists.newArrayList(); + for (int i = 0; i < fields.size(); i += 1) { + writers.add(newOption(struct.getType(i), fieldWriters.get(i))); + sparkTypes.add(sparkFields[i].dataType()); + } + + return new InternalRowWriter(writers, sparkTypes); + } + + @Override + public ParquetValueWriter list( + ArrayType sArray, GroupType array, ParquetValueWriter elementWriter) { + GroupType repeated = array.getFields().get(0).asGroupType(); + String[] repeatedPath = currentPath(); + + int repeatedD = type.getMaxDefinitionLevel(repeatedPath); + int repeatedR = type.getMaxRepetitionLevel(repeatedPath); + + return new ArrayDataWriter<>( + repeatedD, + repeatedR, + newOption(repeated.getType(0), elementWriter), + sArray.elementType()); + } + + @Override + public ParquetValueWriter map( + MapType sMap, + GroupType map, + ParquetValueWriter keyWriter, + ParquetValueWriter 
valueWriter) { + GroupType repeatedKeyValue = map.getFields().get(0).asGroupType(); + String[] repeatedPath = currentPath(); + + int repeatedD = type.getMaxDefinitionLevel(repeatedPath); + int repeatedR = type.getMaxRepetitionLevel(repeatedPath); + + return new MapDataWriter<>( + repeatedD, + repeatedR, + newOption(repeatedKeyValue.getType(0), keyWriter), + newOption(repeatedKeyValue.getType(1), valueWriter), + sMap.keyType(), + sMap.valueType()); + } + + private ParquetValueWriter newOption(Type fieldType, ParquetValueWriter writer) { + int maxD = type.getMaxDefinitionLevel(path(fieldType.getName())); + return ParquetValueWriters.option(fieldType, maxD, writer); + } + + @Override + public ParquetValueWriter primitive(DataType sType, PrimitiveType primitive) { + ColumnDescriptor desc = type.getColumnDescription(currentPath()); + + if (primitive.getOriginalType() != null) { + switch (primitive.getOriginalType()) { + case ENUM: + case JSON: + case UTF8: + return utf8Strings(desc); + case DATE: + case INT_8: + case INT_16: + case INT_32: + return ints(sType, desc); + case INT_64: + case TIME_MICROS: + case TIMESTAMP_MICROS: + return ParquetValueWriters.longs(desc); + case DECIMAL: + DecimalLogicalTypeAnnotation decimal = + (DecimalLogicalTypeAnnotation) primitive.getLogicalTypeAnnotation(); + switch (primitive.getPrimitiveTypeName()) { + case INT32: + return decimalAsInteger(desc, decimal.getPrecision(), decimal.getScale()); + case INT64: + return decimalAsLong(desc, decimal.getPrecision(), decimal.getScale()); + case BINARY: + case FIXED_LEN_BYTE_ARRAY: + return decimalAsFixed(desc, decimal.getPrecision(), decimal.getScale()); + default: + throw new UnsupportedOperationException( + "Unsupported base type for decimal: " + primitive.getPrimitiveTypeName()); + } + case BSON: + return byteArrays(desc); + default: + throw new UnsupportedOperationException( + "Unsupported logical type: " + primitive.getOriginalType()); + } + } + + switch (primitive.getPrimitiveTypeName()) { + case FIXED_LEN_BYTE_ARRAY: + case BINARY: + return byteArrays(desc); + case BOOLEAN: + return ParquetValueWriters.booleans(desc); + case INT32: + return ints(sType, desc); + case INT64: + return ParquetValueWriters.longs(desc); + case FLOAT: + return ParquetValueWriters.floats(desc); + case DOUBLE: + return ParquetValueWriters.doubles(desc); + default: + throw new UnsupportedOperationException("Unsupported type: " + primitive); + } + } + } + + private static PrimitiveWriter ints(DataType type, ColumnDescriptor desc) { + if (type instanceof ByteType) { + return ParquetValueWriters.tinyints(desc); + } else if (type instanceof ShortType) { + return ParquetValueWriters.shorts(desc); + } + return ParquetValueWriters.ints(desc); + } + + private static PrimitiveWriter utf8Strings(ColumnDescriptor desc) { + return new UTF8StringWriter(desc); + } + + private static PrimitiveWriter decimalAsInteger( + ColumnDescriptor desc, int precision, int scale) { + return new IntegerDecimalWriter(desc, precision, scale); + } + + private static PrimitiveWriter decimalAsLong( + ColumnDescriptor desc, int precision, int scale) { + return new LongDecimalWriter(desc, precision, scale); + } + + private static PrimitiveWriter decimalAsFixed( + ColumnDescriptor desc, int precision, int scale) { + return new FixedDecimalWriter(desc, precision, scale); + } + + private static PrimitiveWriter byteArrays(ColumnDescriptor desc) { + return new ByteArrayWriter(desc); + } + + private static class UTF8StringWriter extends PrimitiveWriter { + private 
UTF8StringWriter(ColumnDescriptor desc) { + super(desc); + } + + @Override + public void write(int repetitionLevel, UTF8String value) { + column.writeBinary(repetitionLevel, Binary.fromReusedByteArray(value.getBytes())); + } + } + + private static class IntegerDecimalWriter extends PrimitiveWriter { + private final int precision; + private final int scale; + + private IntegerDecimalWriter(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + } + + @Override + public void write(int repetitionLevel, Decimal decimal) { + Preconditions.checkArgument( + decimal.scale() == scale, + "Cannot write value as decimal(%s,%s), wrong scale: %s", + precision, + scale, + decimal); + Preconditions.checkArgument( + decimal.precision() <= precision, + "Cannot write value as decimal(%s,%s), too large: %s", + precision, + scale, + decimal); + + column.writeInteger(repetitionLevel, (int) decimal.toUnscaledLong()); + } + } + + private static class LongDecimalWriter extends PrimitiveWriter { + private final int precision; + private final int scale; + + private LongDecimalWriter(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + } + + @Override + public void write(int repetitionLevel, Decimal decimal) { + Preconditions.checkArgument( + decimal.scale() == scale, + "Cannot write value as decimal(%s,%s), wrong scale: %s", + precision, + scale, + decimal); + Preconditions.checkArgument( + decimal.precision() <= precision, + "Cannot write value as decimal(%s,%s), too large: %s", + precision, + scale, + decimal); + + column.writeLong(repetitionLevel, decimal.toUnscaledLong()); + } + } + + private static class FixedDecimalWriter extends PrimitiveWriter { + private final int precision; + private final int scale; + private final ThreadLocal bytes; + + private FixedDecimalWriter(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + this.bytes = + ThreadLocal.withInitial(() -> new byte[TypeUtil.decimalRequiredBytes(precision)]); + } + + @Override + public void write(int repetitionLevel, Decimal decimal) { + byte[] binary = + DecimalUtil.toReusedFixLengthBytes( + precision, scale, decimal.toJavaBigDecimal(), bytes.get()); + column.writeBinary(repetitionLevel, Binary.fromReusedByteArray(binary)); + } + } + + private static class ByteArrayWriter extends PrimitiveWriter { + private ByteArrayWriter(ColumnDescriptor desc) { + super(desc); + } + + @Override + public void write(int repetitionLevel, byte[] bytes) { + column.writeBinary(repetitionLevel, Binary.fromReusedByteArray(bytes)); + } + } + + private static class ArrayDataWriter extends RepeatedWriter { + private final DataType elementType; + + private ArrayDataWriter( + int definitionLevel, + int repetitionLevel, + ParquetValueWriter writer, + DataType elementType) { + super(definitionLevel, repetitionLevel, writer); + this.elementType = elementType; + } + + @Override + protected Iterator elements(ArrayData list) { + return new ElementIterator<>(list); + } + + private class ElementIterator implements Iterator { + private final int size; + private final ArrayData list; + private int index; + + private ElementIterator(ArrayData list) { + this.list = list; + size = list.numElements(); + index = 0; + } + + @Override + public boolean hasNext() { + return index != size; + } + + @Override + @SuppressWarnings("unchecked") + public E next() { + if (index >= size) { + throw new 
NoSuchElementException(); + } + + E element; + if (list.isNullAt(index)) { + element = null; + } else { + element = (E) list.get(index, elementType); + } + + index += 1; + + return element; + } + } + } + + private static class MapDataWriter extends RepeatedKeyValueWriter { + private final DataType keyType; + private final DataType valueType; + + private MapDataWriter( + int definitionLevel, + int repetitionLevel, + ParquetValueWriter keyWriter, + ParquetValueWriter valueWriter, + DataType keyType, + DataType valueType) { + super(definitionLevel, repetitionLevel, keyWriter, valueWriter); + this.keyType = keyType; + this.valueType = valueType; + } + + @Override + protected Iterator> pairs(MapData map) { + return new EntryIterator<>(map); + } + + private class EntryIterator implements Iterator> { + private final int size; + private final ArrayData keys; + private final ArrayData values; + private final ReusableEntry entry; + private int index; + + private EntryIterator(MapData map) { + size = map.numElements(); + keys = map.keyArray(); + values = map.valueArray(); + entry = new ReusableEntry<>(); + index = 0; + } + + @Override + public boolean hasNext() { + return index != size; + } + + @Override + @SuppressWarnings("unchecked") + public Map.Entry next() { + if (index >= size) { + throw new NoSuchElementException(); + } + + if (values.isNullAt(index)) { + entry.set((K) keys.get(index, keyType), null); + } else { + entry.set((K) keys.get(index, keyType), (V) values.get(index, valueType)); + } + + index += 1; + + return entry; + } + } + } + + private static class InternalRowWriter extends ParquetValueWriters.StructWriter { + private final DataType[] types; + + private InternalRowWriter(List> writers, List types) { + super(writers); + this.types = types.toArray(new DataType[types.size()]); + } + + @Override + protected Object get(InternalRow struct, int index) { + return struct.get(index, types[index]); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java new file mode 100644 index 000000000000..3cbf38d88bf4 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data; + +import java.io.IOException; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; +import org.apache.avro.io.Decoder; +import org.apache.avro.util.Utf8; +import org.apache.iceberg.avro.ValueReader; +import org.apache.iceberg.avro.ValueReaders; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.UUIDUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; + +public class SparkValueReaders { + + private SparkValueReaders() {} + + static ValueReader strings() { + return StringReader.INSTANCE; + } + + static ValueReader enums(List symbols) { + return new EnumReader(symbols); + } + + static ValueReader uuids() { + return UUIDReader.INSTANCE; + } + + static ValueReader decimal(ValueReader unscaledReader, int scale) { + return new DecimalReader(unscaledReader, scale); + } + + static ValueReader array(ValueReader elementReader) { + return new ArrayReader(elementReader); + } + + static ValueReader arrayMap( + ValueReader keyReader, ValueReader valueReader) { + return new ArrayMapReader(keyReader, valueReader); + } + + static ValueReader map(ValueReader keyReader, ValueReader valueReader) { + return new MapReader(keyReader, valueReader); + } + + static ValueReader struct( + List> readers, Types.StructType struct, Map idToConstant) { + return new StructReader(readers, struct, idToConstant); + } + + private static class StringReader implements ValueReader { + private static final StringReader INSTANCE = new StringReader(); + + private StringReader() {} + + @Override + public UTF8String read(Decoder decoder, Object reuse) throws IOException { + // use the decoder's readString(Utf8) method because it may be a resolving decoder + Utf8 utf8 = null; + if (reuse instanceof UTF8String) { + utf8 = new Utf8(((UTF8String) reuse).getBytes()); + } + + Utf8 string = decoder.readString(utf8); + return UTF8String.fromBytes(string.getBytes(), 0, string.getByteLength()); + } + } + + private static class EnumReader implements ValueReader { + private final UTF8String[] symbols; + + private EnumReader(List symbols) { + this.symbols = new UTF8String[symbols.size()]; + for (int i = 0; i < this.symbols.length; i += 1) { + this.symbols[i] = UTF8String.fromBytes(symbols.get(i).getBytes(StandardCharsets.UTF_8)); + } + } + + @Override + public UTF8String read(Decoder decoder, Object ignore) throws IOException { + int index = decoder.readEnum(); + return symbols[index]; + } + } + + private static class UUIDReader implements ValueReader { + private static final ThreadLocal BUFFER = + ThreadLocal.withInitial( + () -> { + ByteBuffer buffer = ByteBuffer.allocate(16); + buffer.order(ByteOrder.BIG_ENDIAN); + return buffer; + }); + + private static final UUIDReader INSTANCE = new UUIDReader(); + + private UUIDReader() {} + + @Override + @SuppressWarnings("ByteBufferBackingArray") + public UTF8String read(Decoder decoder, Object reuse) throws IOException { + ByteBuffer buffer = BUFFER.get(); + buffer.rewind(); + + 
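+      // read the 16-byte big-endian UUID into the reusable thread-local buffer, then return its canonical string form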
decoder.readFixed(buffer.array(), 0, 16); + + return UTF8String.fromString(UUIDUtil.convert(buffer).toString()); + } + } + + private static class DecimalReader implements ValueReader { + private final ValueReader bytesReader; + private final int scale; + + private DecimalReader(ValueReader bytesReader, int scale) { + this.bytesReader = bytesReader; + this.scale = scale; + } + + @Override + public Decimal read(Decoder decoder, Object reuse) throws IOException { + byte[] bytes = bytesReader.read(decoder, null); + return Decimal.apply(new BigDecimal(new BigInteger(bytes), scale)); + } + } + + private static class ArrayReader implements ValueReader { + private final ValueReader elementReader; + private final List reusedList = Lists.newArrayList(); + + private ArrayReader(ValueReader elementReader) { + this.elementReader = elementReader; + } + + @Override + public GenericArrayData read(Decoder decoder, Object reuse) throws IOException { + reusedList.clear(); + long chunkLength = decoder.readArrayStart(); + + while (chunkLength > 0) { + for (int i = 0; i < chunkLength; i += 1) { + reusedList.add(elementReader.read(decoder, null)); + } + + chunkLength = decoder.arrayNext(); + } + + // this will convert the list to an array so it is okay to reuse the list + return new GenericArrayData(reusedList.toArray()); + } + } + + private static class ArrayMapReader implements ValueReader { + private final ValueReader keyReader; + private final ValueReader valueReader; + + private final List reusedKeyList = Lists.newArrayList(); + private final List reusedValueList = Lists.newArrayList(); + + private ArrayMapReader(ValueReader keyReader, ValueReader valueReader) { + this.keyReader = keyReader; + this.valueReader = valueReader; + } + + @Override + public ArrayBasedMapData read(Decoder decoder, Object reuse) throws IOException { + reusedKeyList.clear(); + reusedValueList.clear(); + + long chunkLength = decoder.readArrayStart(); + + while (chunkLength > 0) { + for (int i = 0; i < chunkLength; i += 1) { + reusedKeyList.add(keyReader.read(decoder, null)); + reusedValueList.add(valueReader.read(decoder, null)); + } + + chunkLength = decoder.arrayNext(); + } + + return new ArrayBasedMapData( + new GenericArrayData(reusedKeyList.toArray()), + new GenericArrayData(reusedValueList.toArray())); + } + } + + private static class MapReader implements ValueReader { + private final ValueReader keyReader; + private final ValueReader valueReader; + + private final List reusedKeyList = Lists.newArrayList(); + private final List reusedValueList = Lists.newArrayList(); + + private MapReader(ValueReader keyReader, ValueReader valueReader) { + this.keyReader = keyReader; + this.valueReader = valueReader; + } + + @Override + public ArrayBasedMapData read(Decoder decoder, Object reuse) throws IOException { + reusedKeyList.clear(); + reusedValueList.clear(); + + long chunkLength = decoder.readMapStart(); + + while (chunkLength > 0) { + for (int i = 0; i < chunkLength; i += 1) { + reusedKeyList.add(keyReader.read(decoder, null)); + reusedValueList.add(valueReader.read(decoder, null)); + } + + chunkLength = decoder.mapNext(); + } + + return new ArrayBasedMapData( + new GenericArrayData(reusedKeyList.toArray()), + new GenericArrayData(reusedValueList.toArray())); + } + } + + static class StructReader extends ValueReaders.StructReader { + private final int numFields; + + protected StructReader( + List> readers, Types.StructType struct, Map idToConstant) { + super(readers, struct, idToConstant); + this.numFields = readers.size(); + } + + 
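+    // reuse the incoming row only when it is a GenericInternalRow with the expected field count; otherwise allocate a new one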
@Override + protected InternalRow reuseOrCreate(Object reuse) { + if (reuse instanceof GenericInternalRow + && ((GenericInternalRow) reuse).numFields() == numFields) { + return (InternalRow) reuse; + } + return new GenericInternalRow(numFields); + } + + @Override + protected Object get(InternalRow struct, int pos) { + return null; + } + + @Override + protected void set(InternalRow struct, int pos, Object value) { + if (value != null) { + struct.update(pos, value); + } else { + struct.setNullAt(pos); + } + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueWriters.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueWriters.java new file mode 100644 index 000000000000..5f2e2c054888 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueWriters.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import java.io.IOException; +import java.lang.reflect.Array; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.List; +import java.util.UUID; +import org.apache.avro.io.Encoder; +import org.apache.avro.util.Utf8; +import org.apache.iceberg.avro.ValueWriter; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.util.DecimalUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; + +public class SparkValueWriters { + + private SparkValueWriters() {} + + static ValueWriter strings() { + return StringWriter.INSTANCE; + } + + static ValueWriter uuids() { + return UUIDWriter.INSTANCE; + } + + static ValueWriter decimal(int precision, int scale) { + return new DecimalWriter(precision, scale); + } + + static ValueWriter array(ValueWriter elementWriter, DataType elementType) { + return new ArrayWriter<>(elementWriter, elementType); + } + + static ValueWriter arrayMap( + ValueWriter keyWriter, DataType keyType, ValueWriter valueWriter, DataType valueType) { + return new ArrayMapWriter<>(keyWriter, keyType, valueWriter, valueType); + } + + static ValueWriter map( + ValueWriter keyWriter, DataType keyType, ValueWriter valueWriter, DataType valueType) { + return new MapWriter<>(keyWriter, keyType, valueWriter, valueType); + } + + static ValueWriter struct(List> writers, List types) { + return new StructWriter(writers, types); + } + + private static class StringWriter implements ValueWriter { + private static final StringWriter INSTANCE = new StringWriter(); + + private StringWriter() {} + + @Override + public void 
write(UTF8String s, Encoder encoder) throws IOException { + // use getBytes because it may return the backing byte array if available. + // otherwise, it copies to a new byte array, which is still cheaper than Avro + // calling toString, which incurs encoding costs + encoder.writeString(new Utf8(s.getBytes())); + } + } + + private static class UUIDWriter implements ValueWriter { + private static final ThreadLocal BUFFER = + ThreadLocal.withInitial( + () -> { + ByteBuffer buffer = ByteBuffer.allocate(16); + buffer.order(ByteOrder.BIG_ENDIAN); + return buffer; + }); + + private static final UUIDWriter INSTANCE = new UUIDWriter(); + + private UUIDWriter() {} + + @Override + @SuppressWarnings("ByteBufferBackingArray") + public void write(UTF8String s, Encoder encoder) throws IOException { + // TODO: direct conversion from string to byte buffer + UUID uuid = UUID.fromString(s.toString()); + ByteBuffer buffer = BUFFER.get(); + buffer.rewind(); + buffer.putLong(uuid.getMostSignificantBits()); + buffer.putLong(uuid.getLeastSignificantBits()); + encoder.writeFixed(buffer.array()); + } + } + + private static class DecimalWriter implements ValueWriter { + private final int precision; + private final int scale; + private final ThreadLocal bytes; + + private DecimalWriter(int precision, int scale) { + this.precision = precision; + this.scale = scale; + this.bytes = + ThreadLocal.withInitial(() -> new byte[TypeUtil.decimalRequiredBytes(precision)]); + } + + @Override + public void write(Decimal d, Encoder encoder) throws IOException { + encoder.writeFixed( + DecimalUtil.toReusedFixLengthBytes(precision, scale, d.toJavaBigDecimal(), bytes.get())); + } + } + + private static class ArrayWriter implements ValueWriter { + private final ValueWriter elementWriter; + private final DataType elementType; + + private ArrayWriter(ValueWriter elementWriter, DataType elementType) { + this.elementWriter = elementWriter; + this.elementType = elementType; + } + + @Override + @SuppressWarnings("unchecked") + public void write(ArrayData array, Encoder encoder) throws IOException { + encoder.writeArrayStart(); + int numElements = array.numElements(); + encoder.setItemCount(numElements); + for (int i = 0; i < numElements; i += 1) { + encoder.startItem(); + elementWriter.write((T) array.get(i, elementType), encoder); + } + encoder.writeArrayEnd(); + } + } + + private static class ArrayMapWriter implements ValueWriter { + private final ValueWriter keyWriter; + private final ValueWriter valueWriter; + private final DataType keyType; + private final DataType valueType; + + private ArrayMapWriter( + ValueWriter keyWriter, + DataType keyType, + ValueWriter valueWriter, + DataType valueType) { + this.keyWriter = keyWriter; + this.keyType = keyType; + this.valueWriter = valueWriter; + this.valueType = valueType; + } + + @Override + @SuppressWarnings("unchecked") + public void write(MapData map, Encoder encoder) throws IOException { + encoder.writeArrayStart(); + int numElements = map.numElements(); + encoder.setItemCount(numElements); + ArrayData keyArray = map.keyArray(); + ArrayData valueArray = map.valueArray(); + for (int i = 0; i < numElements; i += 1) { + encoder.startItem(); + keyWriter.write((K) keyArray.get(i, keyType), encoder); + valueWriter.write((V) valueArray.get(i, valueType), encoder); + } + encoder.writeArrayEnd(); + } + } + + private static class MapWriter implements ValueWriter { + private final ValueWriter keyWriter; + private final ValueWriter valueWriter; + private final DataType keyType; + private final 
DataType valueType; + + private MapWriter( + ValueWriter keyWriter, + DataType keyType, + ValueWriter valueWriter, + DataType valueType) { + this.keyWriter = keyWriter; + this.keyType = keyType; + this.valueWriter = valueWriter; + this.valueType = valueType; + } + + @Override + @SuppressWarnings("unchecked") + public void write(MapData map, Encoder encoder) throws IOException { + encoder.writeMapStart(); + int numElements = map.numElements(); + encoder.setItemCount(numElements); + ArrayData keyArray = map.keyArray(); + ArrayData valueArray = map.valueArray(); + for (int i = 0; i < numElements; i += 1) { + encoder.startItem(); + keyWriter.write((K) keyArray.get(i, keyType), encoder); + valueWriter.write((V) valueArray.get(i, valueType), encoder); + } + encoder.writeMapEnd(); + } + } + + static class StructWriter implements ValueWriter { + private final ValueWriter[] writers; + private final DataType[] types; + + @SuppressWarnings("unchecked") + private StructWriter(List> writers, List types) { + this.writers = (ValueWriter[]) Array.newInstance(ValueWriter.class, writers.size()); + this.types = new DataType[writers.size()]; + for (int i = 0; i < writers.size(); i += 1) { + this.writers[i] = writers.get(i); + this.types[i] = types.get(i); + } + } + + ValueWriter[] writers() { + return writers; + } + + @Override + public void write(InternalRow row, Encoder encoder) throws IOException { + for (int i = 0; i < types.length; i += 1) { + if (row.isNullAt(i)) { + writers[i].write(null, encoder); + } else { + write(row, i, writers[i], encoder); + } + } + } + + @SuppressWarnings("unchecked") + private void write(InternalRow row, int pos, ValueWriter writer, Encoder encoder) + throws IOException { + writer.write((T) row.get(pos, types[pos]), encoder); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessorFactory.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessorFactory.java new file mode 100644 index 000000000000..e32ebcb02bbc --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessorFactory.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data.vectorized; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.iceberg.arrow.vectorized.GenericArrowVectorAccessorFactory; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ArrowColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.unsafe.types.UTF8String; + +final class ArrowVectorAccessorFactory + extends GenericArrowVectorAccessorFactory< + Decimal, UTF8String, ColumnarArray, ArrowColumnVector> { + + ArrowVectorAccessorFactory() { + super( + DecimalFactoryImpl::new, + StringFactoryImpl::new, + StructChildFactoryImpl::new, + ArrayFactoryImpl::new); + } + + private static final class DecimalFactoryImpl implements DecimalFactory { + @Override + public Class getGenericClass() { + return Decimal.class; + } + + @Override + public Decimal ofLong(long value, int precision, int scale) { + return Decimal.apply(value, precision, scale); + } + + @Override + public Decimal ofBigDecimal(BigDecimal value, int precision, int scale) { + return Decimal.apply(value, precision, scale); + } + } + + private static final class StringFactoryImpl implements StringFactory { + @Override + public Class getGenericClass() { + return UTF8String.class; + } + + @Override + public UTF8String ofRow(VarCharVector vector, int rowId) { + int start = vector.getStartOffset(rowId); + int end = vector.getEndOffset(rowId); + + return UTF8String.fromAddress( + null, vector.getDataBuffer().memoryAddress() + start, end - start); + } + + @Override + public UTF8String ofBytes(byte[] bytes) { + return UTF8String.fromBytes(bytes); + } + + @Override + public UTF8String ofByteBuffer(ByteBuffer byteBuffer) { + if (byteBuffer.hasArray()) { + return UTF8String.fromBytes( + byteBuffer.array(), + byteBuffer.arrayOffset() + byteBuffer.position(), + byteBuffer.remaining()); + } + byte[] bytes = new byte[byteBuffer.remaining()]; + byteBuffer.get(bytes); + return UTF8String.fromBytes(bytes); + } + } + + private static final class ArrayFactoryImpl + implements ArrayFactory { + @Override + public ArrowColumnVector ofChild(ValueVector childVector) { + return new ArrowColumnVector(childVector); + } + + @Override + public ColumnarArray ofRow(ValueVector vector, ArrowColumnVector childData, int rowId) { + ArrowBuf offsets = vector.getOffsetBuffer(); + int index = rowId * ListVector.OFFSET_WIDTH; + int start = offsets.getInt(index); + int end = offsets.getInt(index + ListVector.OFFSET_WIDTH); + return new ColumnarArray(childData, start, end - start); + } + } + + private static final class StructChildFactoryImpl + implements StructChildFactory { + @Override + public Class getGenericClass() { + return ArrowColumnVector.class; + } + + @Override + public ArrowColumnVector of(ValueVector childVector) { + return new ArrowColumnVector(childVector); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessors.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessors.java new file mode 100644 index 000000000000..810fef81b5bb --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessors.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.iceberg.arrow.vectorized.ArrowVectorAccessor; +import org.apache.iceberg.arrow.vectorized.VectorHolder; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ArrowColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.unsafe.types.UTF8String; + +public class ArrowVectorAccessors { + + private static final ArrowVectorAccessorFactory factory = new ArrowVectorAccessorFactory(); + + static ArrowVectorAccessor + getVectorAccessor(VectorHolder holder) { + return factory.getVectorAccessor(holder); + } + + private ArrowVectorAccessors() {} +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnVectorBuilder.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnVectorBuilder.java new file mode 100644 index 000000000000..8080a946c6f7 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnVectorBuilder.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.iceberg.arrow.vectorized.VectorHolder; +import org.apache.iceberg.arrow.vectorized.VectorHolder.ConstantVectorHolder; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.vectorized.ColumnVector; + +class ColumnVectorBuilder { + private boolean[] isDeleted; + private int[] rowIdMapping; + + public ColumnVectorBuilder withDeletedRows(int[] rowIdMappingArray, boolean[] isDeletedArray) { + this.rowIdMapping = rowIdMappingArray; + this.isDeleted = isDeletedArray; + return this; + } + + public ColumnVector build(VectorHolder holder, int numRows) { + if (holder.isDummy()) { + if (holder instanceof VectorHolder.DeletedVectorHolder) { + return new DeletedColumnVector(Types.BooleanType.get(), isDeleted); + } else if (holder instanceof ConstantVectorHolder) { + return new ConstantColumnVector( + Types.IntegerType.get(), numRows, ((ConstantVectorHolder) holder).getConstant()); + } else { + throw new IllegalStateException("Unknown dummy vector holder: " + holder); + } + } else if (rowIdMapping != null) { + return new ColumnVectorWithFilter(holder, rowIdMapping); + } else { + return new IcebergArrowColumnVector(holder); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnVectorWithFilter.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnVectorWithFilter.java new file mode 100644 index 000000000000..ab0d652321d3 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnVectorWithFilter.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.iceberg.arrow.vectorized.VectorHolder; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.unsafe.types.UTF8String; + +public class ColumnVectorWithFilter extends IcebergArrowColumnVector { + private final int[] rowIdMapping; + + public ColumnVectorWithFilter(VectorHolder holder, int[] rowIdMapping) { + super(holder); + this.rowIdMapping = rowIdMapping; + } + + @Override + public boolean isNullAt(int rowId) { + return nullabilityHolder().isNullAt(rowIdMapping[rowId]) == 1; + } + + @Override + public boolean getBoolean(int rowId) { + return accessor().getBoolean(rowIdMapping[rowId]); + } + + @Override + public int getInt(int rowId) { + return accessor().getInt(rowIdMapping[rowId]); + } + + @Override + public long getLong(int rowId) { + return accessor().getLong(rowIdMapping[rowId]); + } + + @Override + public float getFloat(int rowId) { + return accessor().getFloat(rowIdMapping[rowId]); + } + + @Override + public double getDouble(int rowId) { + return accessor().getDouble(rowIdMapping[rowId]); + } + + @Override + public ColumnarArray getArray(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor().getArray(rowIdMapping[rowId]); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) { + return null; + } + return accessor().getDecimal(rowIdMapping[rowId], precision, scale); + } + + @Override + public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor().getUTF8String(rowIdMapping[rowId]); + } + + @Override + public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor().getBinary(rowIdMapping[rowId]); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnarBatchReader.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnarBatchReader.java new file mode 100644 index 000000000000..f07d8c545e35 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnarBatchReader.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data.vectorized; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.arrow.vectorized.BaseBatchReader; +import org.apache.iceberg.arrow.vectorized.VectorizedArrowReader; +import org.apache.iceberg.arrow.vectorized.VectorizedArrowReader.DeletedVectorReader; +import org.apache.iceberg.data.DeleteFilter; +import org.apache.iceberg.deletes.PositionDeleteIndex; +import org.apache.iceberg.parquet.VectorizedReader; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.util.Pair; +import org.apache.parquet.column.page.PageReadStore; +import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData; +import org.apache.parquet.hadoop.metadata.ColumnPath; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** + * {@link VectorizedReader} that returns Spark's {@link ColumnarBatch} to support Spark's vectorized + * read path. The {@link ColumnarBatch} returned is created by passing in the Arrow vectors + * populated via delegated read calls to {@linkplain VectorizedArrowReader VectorReader(s)}. + */ +public class ColumnarBatchReader extends BaseBatchReader { + private final boolean hasIsDeletedColumn; + private DeleteFilter deletes = null; + private long rowStartPosInBatch = 0; + + public ColumnarBatchReader(List> readers) { + super(readers); + this.hasIsDeletedColumn = + readers.stream().anyMatch(reader -> reader instanceof DeletedVectorReader); + } + + @Override + public void setRowGroupInfo( + PageReadStore pageStore, Map metaData, long rowPosition) { + super.setRowGroupInfo(pageStore, metaData, rowPosition); + this.rowStartPosInBatch = rowPosition; + } + + public void setDeleteFilter(DeleteFilter deleteFilter) { + this.deletes = deleteFilter; + } + + @Override + public final ColumnarBatch read(ColumnarBatch reuse, int numRowsToRead) { + if (reuse == null) { + closeVectors(); + } + + ColumnarBatch columnarBatch = new ColumnBatchLoader(numRowsToRead).loadDataToColumnBatch(); + rowStartPosInBatch += numRowsToRead; + return columnarBatch; + } + + private class ColumnBatchLoader { + private final int numRowsToRead; + // the rowId mapping to skip deleted rows for all column vectors inside a batch, it is null when + // there is no deletes + private int[] rowIdMapping; + // the array to indicate if a row is deleted or not, it is null when there is no "_deleted" + // metadata column + private boolean[] isDeleted; + + ColumnBatchLoader(int numRowsToRead) { + Preconditions.checkArgument( + numRowsToRead > 0, "Invalid number of rows to read: %s", numRowsToRead); + this.numRowsToRead = numRowsToRead; + if (hasIsDeletedColumn) { + isDeleted = new boolean[numRowsToRead]; + } + } + + ColumnarBatch loadDataToColumnBatch() { + int numRowsUndeleted = initRowIdMapping(); + + ColumnVector[] arrowColumnVectors = readDataToColumnVectors(); + + ColumnarBatch newColumnarBatch = new ColumnarBatch(arrowColumnVectors); + newColumnarBatch.setNumRows(numRowsUndeleted); + + if (hasEqDeletes()) { + applyEqDelete(newColumnarBatch); + } + + if (hasIsDeletedColumn && rowIdMapping != null) { + // reset the row id mapping array, so that it doesn't filter out the deleted rows + for (int i = 0; i < numRowsToRead; i++) { + rowIdMapping[i] = i; + } + newColumnarBatch.setNumRows(numRowsToRead); + } + + return newColumnarBatch; + } + + ColumnVector[] readDataToColumnVectors() { + ColumnVector[] 
arrowColumnVectors = new ColumnVector[readers.length]; + + ColumnVectorBuilder columnVectorBuilder = new ColumnVectorBuilder(); + for (int i = 0; i < readers.length; i += 1) { + vectorHolders[i] = readers[i].read(vectorHolders[i], numRowsToRead); + int numRowsInVector = vectorHolders[i].numValues(); + Preconditions.checkState( + numRowsInVector == numRowsToRead, + "Number of rows in the vector %s didn't match expected %s ", + numRowsInVector, + numRowsToRead); + + arrowColumnVectors[i] = + columnVectorBuilder + .withDeletedRows(rowIdMapping, isDeleted) + .build(vectorHolders[i], numRowsInVector); + } + return arrowColumnVectors; + } + + boolean hasEqDeletes() { + return deletes != null && deletes.hasEqDeletes(); + } + + int initRowIdMapping() { + Pair posDeleteRowIdMapping = posDelRowIdMapping(); + if (posDeleteRowIdMapping != null) { + rowIdMapping = posDeleteRowIdMapping.first(); + return posDeleteRowIdMapping.second(); + } else { + rowIdMapping = initEqDeleteRowIdMapping(); + return numRowsToRead; + } + } + + Pair posDelRowIdMapping() { + if (deletes != null && deletes.hasPosDeletes()) { + return buildPosDelRowIdMapping(deletes.deletedRowPositions()); + } else { + return null; + } + } + + /** + * Build a row id mapping inside a batch, which skips deleted rows. Here is an example of how we + * delete 2 rows in a batch with 8 rows in total. [0,1,2,3,4,5,6,7] -- Original status of the + * row id mapping array [F,F,F,F,F,F,F,F] -- Original status of the isDeleted array Position + * delete 2, 6 [0,1,3,4,5,7,-,-] -- After applying position deletes [Set Num records to 6] + * [F,F,T,F,F,F,T,F] -- After applying position deletes + * + * @param deletedRowPositions a set of deleted row positions + * @return the mapping array and the new num of rows in a batch, null if no row is deleted + */ + Pair buildPosDelRowIdMapping(PositionDeleteIndex deletedRowPositions) { + if (deletedRowPositions == null) { + return null; + } + + int[] posDelRowIdMapping = new int[numRowsToRead]; + int originalRowId = 0; + int currentRowId = 0; + while (originalRowId < numRowsToRead) { + if (!deletedRowPositions.isDeleted(originalRowId + rowStartPosInBatch)) { + posDelRowIdMapping[currentRowId] = originalRowId; + currentRowId++; + } else { + if (hasIsDeletedColumn) { + isDeleted[originalRowId] = true; + } + + deletes.incrementDeleteCount(); + } + originalRowId++; + } + + if (currentRowId == numRowsToRead) { + // there is no delete in this batch + return null; + } else { + return Pair.of(posDelRowIdMapping, currentRowId); + } + } + + int[] initEqDeleteRowIdMapping() { + int[] eqDeleteRowIdMapping = null; + if (hasEqDeletes()) { + eqDeleteRowIdMapping = new int[numRowsToRead]; + for (int i = 0; i < numRowsToRead; i++) { + eqDeleteRowIdMapping[i] = i; + } + } + + return eqDeleteRowIdMapping; + } + + /** + * Filter out the equality deleted rows. 
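+   * At this point the row id mapping already excludes position-deleted rows; this pass compacts it further so that only rows passing the equality-delete predicate keep an entry.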
Here is an example, [0,1,2,3,4,5,6,7] -- Original + * status of the row id mapping array [F,F,F,F,F,F,F,F] -- Original status of the isDeleted + * array Position delete 2, 6 [0,1,3,4,5,7,-,-] -- After applying position deletes [Set Num + * records to 6] [F,F,T,F,F,F,T,F] -- After applying position deletes Equality delete 1 <= x <= + * 3 [0,4,5,7,-,-,-,-] -- After applying equality deletes [Set Num records to 4] + * [F,T,T,T,F,F,T,F] -- After applying equality deletes + * + * @param columnarBatch the {@link ColumnarBatch} to apply the equality delete + */ + void applyEqDelete(ColumnarBatch columnarBatch) { + Iterator it = columnarBatch.rowIterator(); + int rowId = 0; + int currentRowId = 0; + while (it.hasNext()) { + InternalRow row = it.next(); + if (deletes.eqDeletedRowFilter().test(row)) { + // the row is NOT deleted + // skip deleted rows by pointing to the next undeleted row Id + rowIdMapping[currentRowId] = rowIdMapping[rowId]; + currentRowId++; + } else { + if (hasIsDeletedColumn) { + isDeleted[rowIdMapping[rowId]] = true; + } + + deletes.incrementDeleteCount(); + } + + rowId++; + } + + columnarBatch.setNumRows(currentRowId); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ConstantColumnVector.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ConstantColumnVector.java new file mode 100644 index 000000000000..42683ffa901e --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ConstantColumnVector.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Type; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +class ConstantColumnVector extends ColumnVector { + + private final Object constant; + private final int batchSize; + + ConstantColumnVector(Type type, int batchSize, Object constant) { + super(SparkSchemaUtil.convert(type)); + this.constant = constant; + this.batchSize = batchSize; + } + + @Override + public void close() {} + + @Override + public boolean hasNull() { + return constant == null; + } + + @Override + public int numNulls() { + return constant == null ? 
batchSize : 0; + } + + @Override + public boolean isNullAt(int rowId) { + return constant == null; + } + + @Override + public boolean getBoolean(int rowId) { + return (boolean) constant; + } + + @Override + public byte getByte(int rowId) { + return (byte) constant; + } + + @Override + public short getShort(int rowId) { + return (short) constant; + } + + @Override + public int getInt(int rowId) { + return (int) constant; + } + + @Override + public long getLong(int rowId) { + return (long) constant; + } + + @Override + public float getFloat(int rowId) { + return (float) constant; + } + + @Override + public double getDouble(int rowId) { + return (double) constant; + } + + @Override + public ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException("ConstantColumnVector only supports primitives"); + } + + @Override + public ColumnarMap getMap(int ordinal) { + throw new UnsupportedOperationException("ConstantColumnVector only supports primitives"); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + return (Decimal) constant; + } + + @Override + public UTF8String getUTF8String(int rowId) { + return (UTF8String) constant; + } + + @Override + public byte[] getBinary(int rowId) { + return (byte[]) constant; + } + + @Override + public ColumnVector getChild(int ordinal) { + throw new UnsupportedOperationException("ConstantColumnVector only supports primitives"); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/DeletedColumnVector.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/DeletedColumnVector.java new file mode 100644 index 000000000000..eec6ecb9ace4 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/DeletedColumnVector.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Type; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +public class DeletedColumnVector extends ColumnVector { + private final boolean[] isDeleted; + + public DeletedColumnVector(Type type, boolean[] isDeleted) { + super(SparkSchemaUtil.convert(type)); + Preconditions.checkArgument(isDeleted != null, "Boolean array isDeleted cannot be null"); + this.isDeleted = isDeleted; + } + + @Override + public void close() {} + + @Override + public boolean hasNull() { + return false; + } + + @Override + public int numNulls() { + return 0; + } + + @Override + public boolean isNullAt(int rowId) { + return false; + } + + @Override + public boolean getBoolean(int rowId) { + return isDeleted[rowId]; + } + + @Override + public byte getByte(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public int getInt(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public float getFloat(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarMap getMap(int ordinal) { + throw new UnsupportedOperationException(); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + throw new UnsupportedOperationException(); + } + + @Override + public UTF8String getUTF8String(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte[] getBinary(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnVector getChild(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java new file mode 100644 index 000000000000..38ec3a0e838c --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.iceberg.arrow.vectorized.ArrowVectorAccessor; +import org.apache.iceberg.arrow.vectorized.NullabilityHolder; +import org.apache.iceberg.arrow.vectorized.VectorHolder; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ArrowColumnVector; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Implementation of Spark's {@link ColumnVector} interface. The code for this class is heavily + * inspired from Spark's {@link ArrowColumnVector} The main difference is in how nullability checks + * are made in this class by relying on {@link NullabilityHolder} instead of the validity vector in + * the Arrow vector. + */ +public class IcebergArrowColumnVector extends ColumnVector { + + private final ArrowVectorAccessor accessor; + private final NullabilityHolder nullabilityHolder; + + public IcebergArrowColumnVector(VectorHolder holder) { + super(SparkSchemaUtil.convert(holder.icebergType())); + this.nullabilityHolder = holder.nullabilityHolder(); + this.accessor = ArrowVectorAccessors.getVectorAccessor(holder); + } + + protected ArrowVectorAccessor accessor() { + return accessor; + } + + protected NullabilityHolder nullabilityHolder() { + return nullabilityHolder; + } + + @Override + public void close() { + accessor.close(); + } + + @Override + public boolean hasNull() { + return nullabilityHolder.hasNulls(); + } + + @Override + public int numNulls() { + return nullabilityHolder.numNulls(); + } + + @Override + public boolean isNullAt(int rowId) { + return nullabilityHolder.isNullAt(rowId) == 1; + } + + @Override + public boolean getBoolean(int rowId) { + return accessor.getBoolean(rowId); + } + + @Override + public byte getByte(int rowId) { + throw new UnsupportedOperationException("Unsupported type - byte"); + } + + @Override + public short getShort(int rowId) { + throw new UnsupportedOperationException("Unsupported type - short"); + } + + @Override + public int getInt(int rowId) { + return accessor.getInt(rowId); + } + + @Override + public long getLong(int rowId) { + return accessor.getLong(rowId); + } + + @Override + public float getFloat(int rowId) { + return accessor.getFloat(rowId); + } + + @Override + public double getDouble(int rowId) { + return accessor.getDouble(rowId); + } + + @Override + public ColumnarArray getArray(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor.getArray(rowId); + } + + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException("Unsupported type - map"); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) { + return null; + } + return accessor.getDecimal(rowId, precision, scale); + } + + @Override + public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor.getUTF8String(rowId); + } + + @Override + public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor.getBinary(rowId); + } + + @Override + public ArrowColumnVector getChild(int ordinal) { + return accessor.childColumn(ordinal); + } + + public ArrowVectorAccessor + vectorAccessor() { + return 
accessor; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/RowPositionColumnVector.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/RowPositionColumnVector.java new file mode 100644 index 000000000000..a389cd8286e5 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/RowPositionColumnVector.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +public class RowPositionColumnVector extends ColumnVector { + + private final long batchOffsetInFile; + + RowPositionColumnVector(long batchOffsetInFile) { + super(SparkSchemaUtil.convert(Types.LongType.get())); + this.batchOffsetInFile = batchOffsetInFile; + } + + @Override + public void close() {} + + @Override + public boolean hasNull() { + return false; + } + + @Override + public int numNulls() { + return 0; + } + + @Override + public boolean isNullAt(int rowId) { + return false; + } + + @Override + public boolean getBoolean(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte getByte(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public int getInt(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(int rowId) { + return batchOffsetInFile + rowId; + } + + @Override + public float getFloat(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarMap getMap(int ordinal) { + throw new UnsupportedOperationException(); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + throw new UnsupportedOperationException(); + } + + @Override + public UTF8String getUTF8String(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte[] getBinary(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnVector getChild(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git 
a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java new file mode 100644 index 000000000000..b2d8bd14beca --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java @@ -0,0 +1,494 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data.vectorized; + +import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.orc.OrcBatchReader; +import org.apache.iceberg.orc.OrcSchemaWithTypeVisitor; +import org.apache.iceberg.orc.OrcValueReader; +import org.apache.iceberg.orc.OrcValueReaders; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.SparkOrcValueReaders; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.orc.TypeDescription; +import org.apache.orc.storage.ql.exec.vector.ListColumnVector; +import org.apache.orc.storage.ql.exec.vector.MapColumnVector; +import org.apache.orc.storage.ql.exec.vector.StructColumnVector; +import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +public class VectorizedSparkOrcReaders { + + private VectorizedSparkOrcReaders() {} + + public static OrcBatchReader buildReader( + Schema expectedSchema, TypeDescription fileSchema, Map idToConstant) { + Converter converter = + OrcSchemaWithTypeVisitor.visit(expectedSchema, fileSchema, new ReadBuilder(idToConstant)); + + return new OrcBatchReader() { + private long batchOffsetInFile; + + @Override + public ColumnarBatch read(VectorizedRowBatch batch) { + BaseOrcColumnVector cv = + (BaseOrcColumnVector) + converter.convert( + new StructColumnVector(batch.size, batch.cols), + batch.size, + batchOffsetInFile, + batch.selectedInUse, + batch.selected); + ColumnarBatch columnarBatch = + new ColumnarBatch( + IntStream.range(0, expectedSchema.columns().size()) + .mapToObj(cv::getChild) + .toArray(ColumnVector[]::new)); + columnarBatch.setNumRows(batch.size); + return columnarBatch; + } + + @Override + public void setBatchContext(long batchOffsetInFile) { + this.batchOffsetInFile = batchOffsetInFile; + } + }; + } + + private interface 
Converter { + ColumnVector convert( + org.apache.orc.storage.ql.exec.vector.ColumnVector columnVector, + int batchSize, + long batchOffsetInFile, + boolean isSelectedInUse, + int[] selected); + } + + private static class ReadBuilder extends OrcSchemaWithTypeVisitor { + private final Map idToConstant; + + private ReadBuilder(Map idToConstant) { + this.idToConstant = idToConstant; + } + + @Override + public Converter record( + Types.StructType iStruct, + TypeDescription record, + List names, + List fields) { + return new StructConverter(iStruct, fields, idToConstant); + } + + @Override + public Converter list(Types.ListType iList, TypeDescription array, Converter element) { + return new ArrayConverter(iList, element); + } + + @Override + public Converter map(Types.MapType iMap, TypeDescription map, Converter key, Converter value) { + return new MapConverter(iMap, key, value); + } + + @Override + public Converter primitive(Type.PrimitiveType iPrimitive, TypeDescription primitive) { + final OrcValueReader primitiveValueReader; + switch (primitive.getCategory()) { + case BOOLEAN: + primitiveValueReader = OrcValueReaders.booleans(); + break; + case BYTE: + // Iceberg does not have a byte type. Use int + case SHORT: + // Iceberg does not have a short type. Use int + case DATE: + case INT: + primitiveValueReader = OrcValueReaders.ints(); + break; + case LONG: + primitiveValueReader = OrcValueReaders.longs(); + break; + case FLOAT: + primitiveValueReader = OrcValueReaders.floats(); + break; + case DOUBLE: + primitiveValueReader = OrcValueReaders.doubles(); + break; + case TIMESTAMP_INSTANT: + case TIMESTAMP: + primitiveValueReader = SparkOrcValueReaders.timestampTzs(); + break; + case DECIMAL: + primitiveValueReader = + SparkOrcValueReaders.decimals(primitive.getPrecision(), primitive.getScale()); + break; + case CHAR: + case VARCHAR: + case STRING: + primitiveValueReader = SparkOrcValueReaders.utf8String(); + break; + case BINARY: + primitiveValueReader = OrcValueReaders.bytes(); + break; + default: + throw new IllegalArgumentException("Unhandled type " + primitive); + } + return (columnVector, batchSize, batchOffsetInFile, isSelectedInUse, selected) -> + new PrimitiveOrcColumnVector( + iPrimitive, + batchSize, + columnVector, + primitiveValueReader, + batchOffsetInFile, + isSelectedInUse, + selected); + } + } + + private abstract static class BaseOrcColumnVector extends ColumnVector { + private final org.apache.orc.storage.ql.exec.vector.ColumnVector vector; + private final int batchSize; + private final boolean isSelectedInUse; + private final int[] selected; + private Integer numNulls; + + BaseOrcColumnVector( + Type type, + int batchSize, + org.apache.orc.storage.ql.exec.vector.ColumnVector vector, + boolean isSelectedInUse, + int[] selected) { + super(SparkSchemaUtil.convert(type)); + this.vector = vector; + this.batchSize = batchSize; + this.isSelectedInUse = isSelectedInUse; + this.selected = selected; + } + + @Override + public void close() {} + + @Override + public boolean hasNull() { + return !vector.noNulls; + } + + @Override + public int numNulls() { + if (numNulls == null) { + numNulls = numNullsHelper(); + } + return numNulls; + } + + private int numNullsHelper() { + if (vector.isRepeating) { + if (vector.isNull[0]) { + return batchSize; + } else { + return 0; + } + } else if (vector.noNulls) { + return 0; + } else { + int count = 0; + for (int i = 0; i < batchSize; i++) { + if (vector.isNull[i]) { + count++; + } + } + return count; + } + } + + protected int getRowIndex(int rowId) { 
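+      // Map the logical row id to the physical index in the ORC vector: apply the batch's
+      // selected[] indirection when it is in use, then collapse to 0 for repeating vectors,
+      // which hold a single value for the whole batch.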
+ int row = isSelectedInUse ? selected[rowId] : rowId; + return vector.isRepeating ? 0 : row; + } + + @Override + public boolean isNullAt(int rowId) { + return vector.isNull[getRowIndex(rowId)]; + } + + @Override + public boolean getBoolean(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte getByte(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public int getInt(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public float getFloat(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + throw new UnsupportedOperationException(); + } + + @Override + public UTF8String getUTF8String(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte[] getBinary(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnVector getChild(int ordinal) { + throw new UnsupportedOperationException(); + } + } + + private static class PrimitiveOrcColumnVector extends BaseOrcColumnVector { + private final org.apache.orc.storage.ql.exec.vector.ColumnVector vector; + private final OrcValueReader primitiveValueReader; + private final long batchOffsetInFile; + + PrimitiveOrcColumnVector( + Type type, + int batchSize, + org.apache.orc.storage.ql.exec.vector.ColumnVector vector, + OrcValueReader primitiveValueReader, + long batchOffsetInFile, + boolean isSelectedInUse, + int[] selected) { + super(type, batchSize, vector, isSelectedInUse, selected); + this.vector = vector; + this.primitiveValueReader = primitiveValueReader; + this.batchOffsetInFile = batchOffsetInFile; + } + + @Override + public boolean getBoolean(int rowId) { + return (Boolean) primitiveValueReader.read(vector, getRowIndex(rowId)); + } + + @Override + public int getInt(int rowId) { + return (Integer) primitiveValueReader.read(vector, getRowIndex(rowId)); + } + + @Override + public long getLong(int rowId) { + return (Long) primitiveValueReader.read(vector, getRowIndex(rowId)); + } + + @Override + public float getFloat(int rowId) { + return (Float) primitiveValueReader.read(vector, getRowIndex(rowId)); + } + + @Override + public double getDouble(int rowId) { + return (Double) primitiveValueReader.read(vector, getRowIndex(rowId)); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + // TODO: Is it okay to assume that (precision,scale) parameters == (precision,scale) of the + // decimal type + // and return a Decimal with (precision,scale) of the decimal type? 
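+      // Hedged note: the value reader built in ReadBuilder above uses
+      // SparkOrcValueReaders.decimals(precision, scale) from the ORC file type, so for Iceberg
+      // decimal columns it is expected to agree with the (precision, scale) Spark passes here.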
+ return (Decimal) primitiveValueReader.read(vector, getRowIndex(rowId)); + } + + @Override + public UTF8String getUTF8String(int rowId) { + return (UTF8String) primitiveValueReader.read(vector, getRowIndex(rowId)); + } + + @Override + public byte[] getBinary(int rowId) { + return (byte[]) primitiveValueReader.read(vector, getRowIndex(rowId)); + } + } + + private static class ArrayConverter implements Converter { + private final Types.ListType listType; + private final Converter elementConverter; + + private ArrayConverter(Types.ListType listType, Converter elementConverter) { + this.listType = listType; + this.elementConverter = elementConverter; + } + + @Override + public ColumnVector convert( + org.apache.orc.storage.ql.exec.vector.ColumnVector vector, + int batchSize, + long batchOffsetInFile, + boolean isSelectedInUse, + int[] selected) { + ListColumnVector listVector = (ListColumnVector) vector; + ColumnVector elementVector = + elementConverter.convert(listVector.child, batchSize, batchOffsetInFile, false, null); + + return new BaseOrcColumnVector(listType, batchSize, vector, isSelectedInUse, selected) { + @Override + public ColumnarArray getArray(int rowId) { + int index = getRowIndex(rowId); + return new ColumnarArray( + elementVector, (int) listVector.offsets[index], (int) listVector.lengths[index]); + } + }; + } + } + + private static class MapConverter implements Converter { + private final Types.MapType mapType; + private final Converter keyConverter; + private final Converter valueConverter; + + private MapConverter(Types.MapType mapType, Converter keyConverter, Converter valueConverter) { + this.mapType = mapType; + this.keyConverter = keyConverter; + this.valueConverter = valueConverter; + } + + @Override + public ColumnVector convert( + org.apache.orc.storage.ql.exec.vector.ColumnVector vector, + int batchSize, + long batchOffsetInFile, + boolean isSelectedInUse, + int[] selected) { + MapColumnVector mapVector = (MapColumnVector) vector; + ColumnVector keyVector = + keyConverter.convert(mapVector.keys, batchSize, batchOffsetInFile, false, null); + ColumnVector valueVector = + valueConverter.convert(mapVector.values, batchSize, batchOffsetInFile, false, null); + + return new BaseOrcColumnVector(mapType, batchSize, vector, isSelectedInUse, selected) { + @Override + public ColumnarMap getMap(int rowId) { + int index = getRowIndex(rowId); + return new ColumnarMap( + keyVector, + valueVector, + (int) mapVector.offsets[index], + (int) mapVector.lengths[index]); + } + }; + } + } + + private static class StructConverter implements Converter { + private final Types.StructType structType; + private final List fieldConverters; + private final Map idToConstant; + + private StructConverter( + Types.StructType structType, + List fieldConverters, + Map idToConstant) { + this.structType = structType; + this.fieldConverters = fieldConverters; + this.idToConstant = idToConstant; + } + + @Override + public ColumnVector convert( + org.apache.orc.storage.ql.exec.vector.ColumnVector vector, + int batchSize, + long batchOffsetInFile, + boolean isSelectedInUse, + int[] selected) { + StructColumnVector structVector = (StructColumnVector) vector; + List fields = structType.fields(); + List fieldVectors = Lists.newArrayListWithExpectedSize(fields.size()); + for (int pos = 0, vectorIndex = 0; pos < fields.size(); pos += 1) { + Types.NestedField field = fields.get(pos); + if (idToConstant.containsKey(field.fieldId())) { + fieldVectors.add( + new ConstantColumnVector(field.type(), batchSize, 
idToConstant.get(field.fieldId()))); + } else if (field.equals(MetadataColumns.ROW_POSITION)) { + fieldVectors.add(new RowPositionColumnVector(batchOffsetInFile)); + } else if (field.equals(MetadataColumns.IS_DELETED)) { + fieldVectors.add(new ConstantColumnVector(field.type(), batchSize, false)); + } else { + fieldVectors.add( + fieldConverters + .get(vectorIndex) + .convert( + structVector.fields[vectorIndex], + batchSize, + batchOffsetInFile, + isSelectedInUse, + selected)); + vectorIndex++; + } + } + + return new BaseOrcColumnVector(structType, batchSize, vector, isSelectedInUse, selected) { + @Override + public ColumnVector getChild(int ordinal) { + return fieldVectors.get(ordinal); + } + }; + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkParquetReaders.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkParquetReaders.java new file mode 100644 index 000000000000..7a849979509f --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkParquetReaders.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data.vectorized; + +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import org.apache.arrow.vector.NullCheckingForGet; +import org.apache.iceberg.Schema; +import org.apache.iceberg.arrow.vectorized.VectorizedReaderBuilder; +import org.apache.iceberg.data.DeleteFilter; +import org.apache.iceberg.parquet.TypeWithSchemaVisitor; +import org.apache.iceberg.parquet.VectorizedReader; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.parquet.schema.MessageType; +import org.apache.spark.sql.catalyst.InternalRow; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class VectorizedSparkParquetReaders { + + private static final Logger LOG = LoggerFactory.getLogger(VectorizedSparkParquetReaders.class); + private static final String ENABLE_UNSAFE_MEMORY_ACCESS = "arrow.enable_unsafe_memory_access"; + private static final String ENABLE_UNSAFE_MEMORY_ACCESS_ENV = "ARROW_ENABLE_UNSAFE_MEMORY_ACCESS"; + private static final String ENABLE_NULL_CHECK_FOR_GET = "arrow.enable_null_check_for_get"; + private static final String ENABLE_NULL_CHECK_FOR_GET_ENV = "ARROW_ENABLE_NULL_CHECK_FOR_GET"; + + static { + try { + enableUnsafeMemoryAccess(); + disableNullCheckForGet(); + } catch (Exception e) { + LOG.warn("Couldn't set Arrow properties, which may impact read performance", e); + } + } + + private VectorizedSparkParquetReaders() {} + + /** + * @deprecated will be removed in 1.3.0, use {@link #buildReader(Schema, MessageType, Map, + * DeleteFilter)} instead. 
+ */ + @Deprecated + public static ColumnarBatchReader buildReader( + Schema expectedSchema, MessageType fileSchema, boolean setArrowValidityVector) { + return buildReader(expectedSchema, fileSchema, setArrowValidityVector, Maps.newHashMap()); + } + + /** + * @deprecated will be removed in 1.3.0, use {@link #buildReader(Schema, MessageType, Map, + * DeleteFilter)} instead. + */ + @Deprecated + public static ColumnarBatchReader buildReader( + Schema expectedSchema, + MessageType fileSchema, + boolean setArrowValidityVector, + Map idToConstant) { + return (ColumnarBatchReader) + TypeWithSchemaVisitor.visit( + expectedSchema.asStruct(), + fileSchema, + new VectorizedReaderBuilder( + expectedSchema, + fileSchema, + setArrowValidityVector, + idToConstant, + ColumnarBatchReader::new)); + } + + /** + * @deprecated will be removed in 1.3.0, use {@link #buildReader(Schema, MessageType, Map, + * DeleteFilter)} instead. + */ + @Deprecated + public static ColumnarBatchReader buildReader( + Schema expectedSchema, + MessageType fileSchema, + boolean setArrowValidityVector, + Map idToConstant, + DeleteFilter deleteFilter) { + return (ColumnarBatchReader) + TypeWithSchemaVisitor.visit( + expectedSchema.asStruct(), + fileSchema, + new ReaderBuilder( + expectedSchema, + fileSchema, + setArrowValidityVector, + idToConstant, + ColumnarBatchReader::new, + deleteFilter)); + } + + public static ColumnarBatchReader buildReader( + Schema expectedSchema, + MessageType fileSchema, + Map idToConstant, + DeleteFilter deleteFilter) { + return (ColumnarBatchReader) + TypeWithSchemaVisitor.visit( + expectedSchema.asStruct(), + fileSchema, + new ReaderBuilder( + expectedSchema, + fileSchema, + NullCheckingForGet.NULL_CHECKING_ENABLED, + idToConstant, + ColumnarBatchReader::new, + deleteFilter)); + } + + // enables unsafe memory access to avoid costly checks to see if index is within bounds + // as long as it is not configured explicitly (see BoundsChecking in Arrow) + private static void enableUnsafeMemoryAccess() { + String value = confValue(ENABLE_UNSAFE_MEMORY_ACCESS, ENABLE_UNSAFE_MEMORY_ACCESS_ENV); + if (value == null) { + LOG.info("Enabling {}", ENABLE_UNSAFE_MEMORY_ACCESS); + System.setProperty(ENABLE_UNSAFE_MEMORY_ACCESS, "true"); + } else { + LOG.info("Unsafe memory access was configured explicitly: {}", value); + } + } + + // disables expensive null checks for every get call in favor of Iceberg nullability + // as long as it is not configured explicitly (see NullCheckingForGet in Arrow) + private static void disableNullCheckForGet() { + String value = confValue(ENABLE_NULL_CHECK_FOR_GET, ENABLE_NULL_CHECK_FOR_GET_ENV); + if (value == null) { + LOG.info("Disabling {}", ENABLE_NULL_CHECK_FOR_GET); + System.setProperty(ENABLE_NULL_CHECK_FOR_GET, "false"); + } else { + LOG.info("Null checking for get calls was configured explicitly: {}", value); + } + } + + private static String confValue(String propName, String envName) { + String propValue = System.getProperty(propName); + if (propValue != null) { + return propValue; + } + + return System.getenv(envName); + } + + private static class ReaderBuilder extends VectorizedReaderBuilder { + private final DeleteFilter deleteFilter; + + ReaderBuilder( + Schema expectedSchema, + MessageType parquetSchema, + boolean setArrowValidityVector, + Map idToConstant, + Function>, VectorizedReader> readerFactory, + DeleteFilter deleteFilter) { + super(expectedSchema, parquetSchema, setArrowValidityVector, idToConstant, readerFactory); + this.deleteFilter = deleteFilter; + } + + 
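+    // The override below lets the parent VectorizedReaderBuilder assemble the reader and then
+    // attaches the optional DeleteFilter so the resulting ColumnarBatchReader can apply it.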
@Override + protected VectorizedReader vectorizedReader(List> reorderedFields) { + VectorizedReader reader = super.vectorizedReader(reorderedFields); + if (deleteFilter != null) { + ((ColumnarBatchReader) reader).setDeleteFilter(deleteFilter); + } + return reader; + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java new file mode 100644 index 000000000000..b5736a866b57 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.functions; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.Set; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.util.BucketUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A Spark function implementation for the Iceberg bucket transform. + * + *

Example usage: {@code SELECT system.bucket(128, 'abc')}, which returns the bucket 122. + * + *

Note that for performance reasons, the given input number of buckets is not validated in the + * implementations used in code-gen. The number of buckets must be positive to give meaningful + * results. + */ +public class BucketFunction implements UnboundFunction { + + private static final int NUM_BUCKETS_ORDINAL = 0; + private static final int VALUE_ORDINAL = 1; + + private static final Set SUPPORTED_NUM_BUCKETS_TYPES = + ImmutableSet.of(DataTypes.ByteType, DataTypes.ShortType, DataTypes.IntegerType); + + @Override + public BoundFunction bind(StructType inputType) { + if (inputType.size() != 2) { + throw new UnsupportedOperationException( + "Wrong number of inputs (expected numBuckets and value)"); + } + + StructField numBucketsField = inputType.fields()[NUM_BUCKETS_ORDINAL]; + StructField valueField = inputType.fields()[VALUE_ORDINAL]; + + if (!SUPPORTED_NUM_BUCKETS_TYPES.contains(numBucketsField.dataType())) { + throw new UnsupportedOperationException( + "Expected number of buckets to be tinyint, shortint or int"); + } + + DataType type = valueField.dataType(); + if (type instanceof DateType) { + return new BucketInt(type); + } else if (type instanceof ByteType + || type instanceof ShortType + || type instanceof IntegerType) { + return new BucketInt(DataTypes.IntegerType); + } else if (type instanceof LongType) { + return new BucketLong(type); + } else if (type instanceof TimestampType) { + return new BucketLong(type); + } else if (type instanceof DecimalType) { + return new BucketDecimal(type); + } else if (type instanceof StringType) { + return new BucketString(); + } else if (type instanceof BinaryType) { + return new BucketBinary(); + } else { + throw new UnsupportedOperationException( + "Expected column to be date, tinyint, smallint, int, bigint, decimal, timestamp, string, or binary"); + } + } + + @Override + public String description() { + return name() + + "(numBuckets, col) - Call Iceberg's bucket transform\n" + + " numBuckets :: number of buckets to divide the rows into, e.g. bucket(100, 34) -> 79 (must be a tinyint, smallint, or int)\n" + + " col :: column to bucket (must be a date, integer, long, timestamp, decimal, string, or binary)"; + } + + @Override + public String name() { + return "bucket"; + } + + public abstract static class BucketBase implements ScalarFunction { + public static int apply(int numBuckets, int hashedValue) { + return (hashedValue & Integer.MAX_VALUE) % numBuckets; + } + + @Override + public String name() { + return "bucket"; + } + + @Override + public DataType resultType() { + return DataTypes.IntegerType; + } + } + + // Used for both int and date - tinyint and smallint are upcasted to int by Spark. + public static class BucketInt extends BucketBase { + private final DataType sqlType; + + // magic method used in codegen + public static int invoke(int numBuckets, int value) { + return apply(numBuckets, hash(value)); + } + + // Visible for testing + public static int hash(int value) { + return BucketUtil.hash(value); + } + + public BucketInt(DataType sqlType) { + this.sqlType = sqlType; + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, sqlType}; + } + + @Override + public String canonicalName() { + return String.format("iceberg.bucket(%s)", sqlType.catalogString()); + } + + @Override + public Integer produceResult(InternalRow input) { + // return null for null input to match what Spark does in the code-generated versions. 
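+      // The same null-propagation pattern is repeated in each produceResult below; for non-null
+      // inputs Spark's generated code is expected to call the static invoke magic method
+      // directly rather than produceResult.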
+ if (input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + return invoke(input.getInt(NUM_BUCKETS_ORDINAL), input.getInt(VALUE_ORDINAL)); + } + } + } + + // Used for both BigInt and Timestamp + public static class BucketLong extends BucketBase { + private final DataType sqlType; + + // magic function for usage with codegen - needs to be static + public static int invoke(int numBuckets, long value) { + return apply(numBuckets, hash(value)); + } + + // Visible for testing + public static int hash(long value) { + return BucketUtil.hash(value); + } + + public BucketLong(DataType sqlType) { + this.sqlType = sqlType; + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, sqlType}; + } + + @Override + public String canonicalName() { + return String.format("iceberg.bucket(%s)", sqlType.catalogString()); + } + + @Override + public Integer produceResult(InternalRow input) { + if (input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + return invoke(input.getInt(NUM_BUCKETS_ORDINAL), input.getLong(VALUE_ORDINAL)); + } + } + } + + public static class BucketString extends BucketBase { + // magic function for usage with codegen + public static Integer invoke(int numBuckets, UTF8String value) { + if (value == null) { + return null; + } + + // TODO - We can probably hash the bytes directly given they're already UTF-8 input. + return apply(numBuckets, hash(value.toString())); + } + + // Visible for testing + public static int hash(String value) { + return BucketUtil.hash(value); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.StringType}; + } + + @Override + public String canonicalName() { + return "iceberg.bucket(string)"; + } + + @Override + public Integer produceResult(InternalRow input) { + if (input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + return invoke(input.getInt(NUM_BUCKETS_ORDINAL), input.getUTF8String(VALUE_ORDINAL)); + } + } + } + + public static class BucketBinary extends BucketBase { + public static Integer invoke(int numBuckets, byte[] value) { + if (value == null) { + return null; + } + + return apply(numBuckets, hash(ByteBuffer.wrap(value))); + } + + // Visible for testing + public static int hash(ByteBuffer value) { + return BucketUtil.hash(value); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.BinaryType}; + } + + @Override + public Integer produceResult(InternalRow input) { + if (input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + return invoke(input.getInt(NUM_BUCKETS_ORDINAL), input.getBinary(VALUE_ORDINAL)); + } + } + + @Override + public String canonicalName() { + return "iceberg.bucket(binary)"; + } + } + + public static class BucketDecimal extends BucketBase { + private final DataType sqlType; + private final int precision; + private final int scale; + + // magic method used in codegen + public static Integer invoke(int numBuckets, Decimal value) { + if (value == null) { + return null; + } + + return apply(numBuckets, hash(value.toJavaBigDecimal())); + } + + // Visible for testing + public static int hash(BigDecimal value) { + return BucketUtil.hash(value); + } + + public BucketDecimal(DataType sqlType) { + this.sqlType = sqlType; + this.precision = ((DecimalType) sqlType).precision(); + this.scale = ((DecimalType) 
sqlType).scale(); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, sqlType}; + } + + @Override + public Integer produceResult(InternalRow input) { + if (input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + int numBuckets = input.getInt(NUM_BUCKETS_ORDINAL); + Decimal value = input.getDecimal(VALUE_ORDINAL, precision, scale); + return invoke(numBuckets, value); + } + } + + @Override + public String canonicalName() { + return "iceberg.bucket(decimal)"; + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/DaysFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/DaysFunction.java new file mode 100644 index 000000000000..c2bd9f37aa23 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/DaysFunction.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.functions; + +import org.apache.iceberg.util.DateTimeUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.TimestampType; + +/** + * A Spark function implementation for the Iceberg day transform. + * + *

Example usage: {@code SELECT system.days('source_col')}. + */ +public class DaysFunction extends UnaryUnboundFunction { + + @Override + protected BoundFunction doBind(DataType valueType) { + if (valueType instanceof DateType) { + return new DateToDaysFunction(); + } else if (valueType instanceof TimestampType) { + return new TimestampToDaysFunction(); + } else { + throw new UnsupportedOperationException( + "Expected value to be date or timestamp: " + valueType.catalogString()); + } + } + + @Override + public String description() { + return name() + + "(col) - Call Iceberg's day transform\n" + + " col :: source column (must be date or timestamp)"; + } + + @Override + public String name() { + return "days"; + } + + private abstract static class BaseToDaysFunction implements ScalarFunction { + @Override + public String name() { + return "days"; + } + + @Override + public DataType resultType() { + return DataTypes.DateType; + } + } + + // Spark and Iceberg internal representations of dates match so no transformation is required + public static class DateToDaysFunction extends BaseToDaysFunction { + // magic method used in codegen + public static int invoke(int days) { + return days; + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.DateType}; + } + + @Override + public String canonicalName() { + return "iceberg.days(date)"; + } + + @Override + public Integer produceResult(InternalRow input) { + // return null for null input to match what Spark does in codegen + return input.isNullAt(0) ? null : input.getInt(0); + } + } + + public static class TimestampToDaysFunction extends BaseToDaysFunction { + // magic method used in codegen + public static int invoke(long micros) { + return DateTimeUtil.microsToDays(micros); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.TimestampType}; + } + + @Override + public String canonicalName() { + return "iceberg.days(timestamp)"; + } + + @Override + public Integer produceResult(InternalRow input) { + // return null for null input to match what Spark does in codegen + return input.isNullAt(0) ? null : invoke(input.getLong(0)); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/HoursFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/HoursFunction.java new file mode 100644 index 000000000000..5d8ae97f4c6d --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/HoursFunction.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.functions; + +import org.apache.iceberg.util.DateTimeUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.TimestampType; + +/** + * A Spark function implementation for the Iceberg hour transform. + * + *

Example usage: {@code SELECT system.hours('source_col')}. + */ +public class HoursFunction extends UnaryUnboundFunction { + + @Override + protected BoundFunction doBind(DataType valueType) { + if (valueType instanceof TimestampType) { + return new TimestampToHoursFunction(); + } else { + throw new UnsupportedOperationException( + "Expected value to be timestamp: " + valueType.catalogString()); + } + } + + @Override + public String description() { + return name() + + "(col) - Call Iceberg's hour transform\n" + + " col :: source column (must be timestamp)"; + } + + @Override + public String name() { + return "hours"; + } + + public static class TimestampToHoursFunction implements ScalarFunction { + // magic method used in codegen + public static int invoke(long micros) { + return DateTimeUtil.microsToHours(micros); + } + + @Override + public String name() { + return "hours"; + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.TimestampType}; + } + + @Override + public DataType resultType() { + return DataTypes.IntegerType; + } + + @Override + public String canonicalName() { + return "iceberg.hours(timestamp)"; + } + + @Override + public Integer produceResult(InternalRow input) { + // return null for null input to match what Spark does in codegen + return input.isNullAt(0) ? null : invoke(input.getLong(0)); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/IcebergVersionFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/IcebergVersionFunction.java new file mode 100644 index 000000000000..9cd059377ce3 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/IcebergVersionFunction.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.functions; + +import org.apache.iceberg.IcebergBuild; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A function for use in SQL that returns the current Iceberg version, e.g. 
{@code SELECT + * system.iceberg_version()} will return a String such as "0.14.0" or "0.15.0-SNAPSHOT" + */ +public class IcebergVersionFunction implements UnboundFunction { + @Override + public BoundFunction bind(StructType inputType) { + if (inputType.fields().length > 0) { + throw new UnsupportedOperationException( + String.format("Cannot bind: %s does not accept arguments", name())); + } + + return new IcebergVersionFunctionImpl(); + } + + @Override + public String description() { + return name() + " - Returns the runtime Iceberg version"; + } + + @Override + public String name() { + return "iceberg_version"; + } + + // Implementing class cannot be private, otherwise Spark is unable to access the static invoke + // function during code-gen and calling the function fails + static class IcebergVersionFunctionImpl implements ScalarFunction { + private static final UTF8String VERSION = UTF8String.fromString(IcebergBuild.version()); + + // magic function used in code-gen. must be named `invoke`. + public static UTF8String invoke() { + return VERSION; + } + + @Override + public DataType[] inputTypes() { + return new DataType[0]; + } + + @Override + public DataType resultType() { + return DataTypes.StringType; + } + + @Override + public boolean isResultNullable() { + return false; + } + + @Override + public String canonicalName() { + return "iceberg." + name(); + } + + @Override + public String name() { + return "iceberg_version"; + } + + @Override + public UTF8String produceResult(InternalRow input) { + return invoke(); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/MonthsFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/MonthsFunction.java new file mode 100644 index 000000000000..c073c048a5fa --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/MonthsFunction.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.functions; + +import org.apache.iceberg.util.DateTimeUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.TimestampType; + +/** + * A Spark function implementation for the Iceberg month transform. + * + *

Example usage: {@code SELECT system.months('source_col')}. + */ +public class MonthsFunction extends UnaryUnboundFunction { + + @Override + protected BoundFunction doBind(DataType valueType) { + if (valueType instanceof DateType) { + return new DateToMonthsFunction(); + } else if (valueType instanceof TimestampType) { + return new TimestampToMonthsFunction(); + } else { + throw new UnsupportedOperationException( + "Expected value to be date or timestamp: " + valueType.catalogString()); + } + } + + @Override + public String description() { + return name() + + "(col) - Call Iceberg's month transform\n" + + " col :: source column (must be date or timestamp)"; + } + + @Override + public String name() { + return "months"; + } + + private abstract static class BaseToMonthsFunction implements ScalarFunction { + @Override + public String name() { + return "months"; + } + + @Override + public DataType resultType() { + return DataTypes.IntegerType; + } + } + + public static class DateToMonthsFunction extends BaseToMonthsFunction { + // magic method used in codegen + public static int invoke(int days) { + return DateTimeUtil.daysToMonths(days); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.DateType}; + } + + @Override + public String canonicalName() { + return "iceberg.months(date)"; + } + + @Override + public Integer produceResult(InternalRow input) { + // return null for null input to match what Spark does in codegen + return input.isNullAt(0) ? null : invoke(input.getInt(0)); + } + } + + public static class TimestampToMonthsFunction extends BaseToMonthsFunction { + // magic method used in codegen + public static int invoke(long micros) { + return DateTimeUtil.microsToMonths(micros); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.TimestampType}; + } + + @Override + public String canonicalName() { + return "iceberg.months(timestamp)"; + } + + @Override + public Integer produceResult(InternalRow input) { + // return null for null input to match what Spark does in codegen + return input.isNullAt(0) ? null : invoke(input.getLong(0)); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java new file mode 100644 index 000000000000..d14bd4583134 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.functions; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; + +public class SparkFunctions { + + private SparkFunctions() {} + + private static final Map FUNCTIONS = + ImmutableMap.of( + "iceberg_version", new IcebergVersionFunction(), + "years", new YearsFunction(), + "months", new MonthsFunction(), + "days", new DaysFunction(), + "hours", new HoursFunction(), + "bucket", new BucketFunction(), + "truncate", new TruncateFunction()); + + private static final List FUNCTION_NAMES = ImmutableList.copyOf(FUNCTIONS.keySet()); + + // Functions that are added to all Iceberg catalogs should be accessed with the `system` + // namespace. They can also be accessed with no namespace at all if qualified with the + // catalog name, e.g. my_hadoop_catalog.iceberg_version(). + // As namespace resolution is handled by those rules in BaseCatalog, a list of names + // alone is returned. + public static List list() { + return FUNCTION_NAMES; + } + + public static UnboundFunction load(String name) { + // function resolution is case-insensitive to match the existing Spark behavior for functions + return FUNCTIONS.get(name.toLowerCase(Locale.ROOT)); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java new file mode 100644 index 000000000000..8cfb529e1028 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.functions; + +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.util.Set; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.util.BinaryUtil; +import org.apache.iceberg.util.ByteBuffers; +import org.apache.iceberg.util.TruncateUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A Spark function implementation for the Iceberg truncate transform. + * + *

Example usage: {@code SELECT system.truncate(1, 'abc')}, which returns the String 'a'. + * + *

Note that for performance reasons, the given input width is not validated in the + * implementations used in code-gen. The width must remain non-negative to give meaningful results. + */ +public class TruncateFunction implements UnboundFunction { + + private static final int WIDTH_ORDINAL = 0; + private static final int VALUE_ORDINAL = 1; + + private static final Set SUPPORTED_WIDTH_TYPES = + ImmutableSet.of(DataTypes.ByteType, DataTypes.ShortType, DataTypes.IntegerType); + + @Override + public BoundFunction bind(StructType inputType) { + if (inputType.size() != 2) { + throw new UnsupportedOperationException("Wrong number of inputs (expected width and value)"); + } + + StructField widthField = inputType.fields()[WIDTH_ORDINAL]; + StructField valueField = inputType.fields()[VALUE_ORDINAL]; + + if (!SUPPORTED_WIDTH_TYPES.contains(widthField.dataType())) { + throw new UnsupportedOperationException( + "Expected truncation width to be tinyint, shortint or int"); + } + + DataType valueType = valueField.dataType(); + if (valueType instanceof ByteType) { + return new TruncateTinyInt(); + } else if (valueType instanceof ShortType) { + return new TruncateSmallInt(); + } else if (valueType instanceof IntegerType) { + return new TruncateInt(); + } else if (valueType instanceof LongType) { + return new TruncateBigInt(); + } else if (valueType instanceof DecimalType) { + return new TruncateDecimal( + ((DecimalType) valueType).precision(), ((DecimalType) valueType).scale()); + } else if (valueType instanceof StringType) { + return new TruncateString(); + } else if (valueType instanceof BinaryType) { + return new TruncateBinary(); + } else { + throw new UnsupportedOperationException( + "Expected truncation col to be tinyint, shortint, int, bigint, decimal, string, or binary"); + } + } + + @Override + public String description() { + return name() + + "(width, col) - Call Iceberg's truncate transform\n" + + " width :: width for truncation, e.g. 
truncate(10, 255) -> 250 (must be an integer)\n" + + " col :: column to truncate (must be an integer, decimal, string, or binary)"; + } + + @Override + public String name() { + return "truncate"; + } + + public abstract static class TruncateBase implements ScalarFunction { + @Override + public String name() { + return "truncate"; + } + } + + public static class TruncateTinyInt extends TruncateBase { + public static byte invoke(int width, byte value) { + return TruncateUtil.truncateByte(width, value); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.ByteType}; + } + + @Override + public DataType resultType() { + return DataTypes.ByteType; + } + + @Override + public String canonicalName() { + return "iceberg.truncate(tinyint)"; + } + + @Override + public Byte produceResult(InternalRow input) { + if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + return invoke(input.getInt(WIDTH_ORDINAL), input.getByte(VALUE_ORDINAL)); + } + } + } + + public static class TruncateSmallInt extends TruncateBase { + // magic method used in codegen + public static short invoke(int width, short value) { + return TruncateUtil.truncateShort(width, value); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.ShortType}; + } + + @Override + public DataType resultType() { + return DataTypes.ShortType; + } + + @Override + public String canonicalName() { + return "iceberg.truncate(smallint)"; + } + + @Override + public Short produceResult(InternalRow input) { + if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + return invoke(input.getInt(WIDTH_ORDINAL), input.getShort(VALUE_ORDINAL)); + } + } + } + + public static class TruncateInt extends TruncateBase { + // magic method used in codegen + public static int invoke(int width, int value) { + return TruncateUtil.truncateInt(width, value); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.IntegerType}; + } + + @Override + public DataType resultType() { + return DataTypes.IntegerType; + } + + @Override + public String canonicalName() { + return "iceberg.truncate(int)"; + } + + @Override + public Integer produceResult(InternalRow input) { + if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + return invoke(input.getInt(WIDTH_ORDINAL), input.getInt(VALUE_ORDINAL)); + } + } + } + + public static class TruncateBigInt extends TruncateBase { + // magic function for usage with codegen + public static long invoke(int width, long value) { + return TruncateUtil.truncateLong(width, value); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.LongType}; + } + + @Override + public DataType resultType() { + return DataTypes.LongType; + } + + @Override + public String canonicalName() { + return "iceberg.truncate(bigint)"; + } + + @Override + public Long produceResult(InternalRow input) { + if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + return invoke(input.getInt(WIDTH_ORDINAL), input.getLong(VALUE_ORDINAL)); + } + } + } + + public static class TruncateString extends TruncateBase { + // magic function for usage with codegen + public static UTF8String invoke(int width, UTF8String value) { + if (value == null) { + return null; + } + + return value.substring(0, width); + } + 
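+    // Minimal usage sketch, mirroring the class javadoc: invoke(1, UTF8String.fromString("abc"))
+    // returns "a". The width is assumed non-negative; as noted above it is not re-validated in
+    // the code-gen path.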
+ @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.StringType}; + } + + @Override + public DataType resultType() { + return DataTypes.StringType; + } + + @Override + public String canonicalName() { + return "iceberg.truncate(string)"; + } + + @Override + public UTF8String produceResult(InternalRow input) { + if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + return invoke(input.getInt(WIDTH_ORDINAL), input.getUTF8String(VALUE_ORDINAL)); + } + } + } + + public static class TruncateBinary extends TruncateBase { + // magic method used in codegen + public static byte[] invoke(int width, byte[] value) { + if (value == null) { + return null; + } + + return ByteBuffers.toByteArray( + BinaryUtil.truncateBinaryUnsafe(ByteBuffer.wrap(value), width)); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.BinaryType}; + } + + @Override + public DataType resultType() { + return DataTypes.BinaryType; + } + + @Override + public String canonicalName() { + return "iceberg.truncate(binary)"; + } + + @Override + public byte[] produceResult(InternalRow input) { + if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + return invoke(input.getInt(WIDTH_ORDINAL), input.getBinary(VALUE_ORDINAL)); + } + } + } + + public static class TruncateDecimal extends TruncateBase { + private final int precision; + private final int scale; + + public TruncateDecimal(int precision, int scale) { + this.precision = precision; + this.scale = scale; + } + + // magic method used in codegen + public static Decimal invoke(int width, Decimal value) { + if (value == null) { + return null; + } + + return Decimal.apply( + TruncateUtil.truncateDecimal(BigInteger.valueOf(width), value.toJavaBigDecimal())); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.createDecimalType(precision, scale)}; + } + + @Override + public DataType resultType() { + return DataTypes.createDecimalType(precision, scale); + } + + @Override + public String canonicalName() { + return String.format("iceberg.truncate(decimal(%d,%d))", precision, scale); + } + + @Override + public Decimal produceResult(InternalRow input) { + if (input.isNullAt(WIDTH_ORDINAL) || input.isNullAt(VALUE_ORDINAL)) { + return null; + } else { + int width = input.getInt(WIDTH_ORDINAL); + Decimal value = input.getDecimal(VALUE_ORDINAL, precision, scale); + return invoke(width, value); + } + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/UnaryUnboundFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/UnaryUnboundFunction.java new file mode 100644 index 000000000000..9003c68919dc --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/UnaryUnboundFunction.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.functions; + +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.StructType; + +/** An unbound function that accepts only one argument */ +abstract class UnaryUnboundFunction implements UnboundFunction { + + @Override + public BoundFunction bind(StructType inputType) { + DataType valueType = valueType(inputType); + return doBind(valueType); + } + + protected abstract BoundFunction doBind(DataType valueType); + + private DataType valueType(StructType inputType) { + if (inputType.size() != 1) { + throw new UnsupportedOperationException("Wrong number of inputs (expected value)"); + } + + return inputType.fields()[0].dataType(); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/YearsFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/YearsFunction.java new file mode 100644 index 000000000000..779e7a28ca4b --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/YearsFunction.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.functions; + +import org.apache.iceberg.util.DateTimeUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.TimestampType; + +/** + * A Spark function implementation for the Iceberg year transform. + * + *

Example usage: {@code SELECT system.years('source_col')}. + */ +public class YearsFunction extends UnaryUnboundFunction { + + @Override + protected BoundFunction doBind(DataType valueType) { + if (valueType instanceof DateType) { + return new DateToYearsFunction(); + } else if (valueType instanceof TimestampType) { + return new TimestampToYearsFunction(); + } else { + throw new UnsupportedOperationException( + "Expected value to be date or timestamp: " + valueType.catalogString()); + } + } + + @Override + public String description() { + return name() + + "(col) - Call Iceberg's year transform\n" + + " col :: source column (must be date or timestamp)"; + } + + @Override + public String name() { + return "years"; + } + + private abstract static class BaseToYearsFunction implements ScalarFunction { + @Override + public String name() { + return "years"; + } + + @Override + public DataType resultType() { + return DataTypes.IntegerType; + } + } + + public static class DateToYearsFunction extends BaseToYearsFunction { + // magic method used in codegen + public static int invoke(int days) { + return DateTimeUtil.daysToYears(days); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.DateType}; + } + + @Override + public String canonicalName() { + return "iceberg.years(date)"; + } + + @Override + public Integer produceResult(InternalRow input) { + // return null for null input to match what Spark does in codegen + return input.isNullAt(0) ? null : invoke(input.getInt(0)); + } + } + + public static class TimestampToYearsFunction extends BaseToYearsFunction { + // magic method used in codegen + public static int invoke(long micros) { + return DateTimeUtil.microsToYears(micros); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.TimestampType}; + } + + @Override + public String canonicalName() { + return "iceberg.years(timestamp)"; + } + + @Override + public Integer produceResult(InternalRow input) { + // return null for null input to match what Spark does in codegen + return input.isNullAt(0) ? null : invoke(input.getLong(0)); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/AddFilesProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/AddFilesProcedure.java new file mode 100644 index 000000000000..b349694130d3 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/AddFilesProcedure.java @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.procedures; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.PartitionField; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.mapping.MappingUtil; +import org.apache.iceberg.mapping.NameMapping; +import org.apache.iceberg.mapping.NameMappingParser; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.SparkTableUtil.SparkPartition; +import org.apache.iceberg.util.LocationUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +class AddFilesProcedure extends BaseProcedure { + + private static final ProcedureParameter TABLE_PARAM = + ProcedureParameter.required("table", DataTypes.StringType); + private static final ProcedureParameter SOURCE_TABLE_PARAM = + ProcedureParameter.required("source_table", DataTypes.StringType); + private static final ProcedureParameter PARTITION_FILTER_PARAM = + ProcedureParameter.optional("partition_filter", STRING_MAP); + private static final ProcedureParameter CHECK_DUPLICATE_FILES_PARAM = + ProcedureParameter.optional("check_duplicate_files", DataTypes.BooleanType); + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + TABLE_PARAM, SOURCE_TABLE_PARAM, PARTITION_FILTER_PARAM, CHECK_DUPLICATE_FILES_PARAM + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("added_files_count", DataTypes.LongType, false, Metadata.empty()), + new StructField("changed_partition_count", DataTypes.LongType, false, Metadata.empty()), + }); + + private AddFilesProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + public static SparkProcedures.ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected AddFilesProcedure doBuild() { + return new AddFilesProcedure(tableCatalog()); + } + }; + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + ProcedureInput input = new ProcedureInput(spark(), tableCatalog(), PARAMETERS, args); + + Identifier tableIdent = input.ident(TABLE_PARAM); + + CatalogPlugin sessionCat = spark().sessionState().catalogManager().v2SessionCatalog(); + Identifier sourceIdent = input.ident(SOURCE_TABLE_PARAM, sessionCat); + + Map partitionFilter = + input.asStringMap(PARTITION_FILTER_PARAM, 
ImmutableMap.of()); + + boolean checkDuplicateFiles = input.asBoolean(CHECK_DUPLICATE_FILES_PARAM, true); + + return importToIceberg(tableIdent, sourceIdent, partitionFilter, checkDuplicateFiles); + } + + private InternalRow[] toOutputRows(Snapshot snapshot) { + Map summary = snapshot.summary(); + return new InternalRow[] { + newInternalRow( + Long.parseLong(summary.getOrDefault(SnapshotSummary.ADDED_FILES_PROP, "0")), + Long.parseLong(summary.getOrDefault(SnapshotSummary.CHANGED_PARTITION_COUNT_PROP, "0"))) + }; + } + + private boolean isFileIdentifier(Identifier ident) { + String[] namespace = ident.namespace(); + return namespace.length == 1 + && (namespace[0].equalsIgnoreCase("orc") + || namespace[0].equalsIgnoreCase("parquet") + || namespace[0].equalsIgnoreCase("avro")); + } + + private InternalRow[] importToIceberg( + Identifier destIdent, + Identifier sourceIdent, + Map partitionFilter, + boolean checkDuplicateFiles) { + return modifyIcebergTable( + destIdent, + table -> { + validatePartitionSpec(table, partitionFilter); + ensureNameMappingPresent(table); + + if (isFileIdentifier(sourceIdent)) { + Path sourcePath = new Path(sourceIdent.name()); + String format = sourceIdent.namespace()[0]; + importFileTable( + table, sourcePath, format, partitionFilter, checkDuplicateFiles, table.spec()); + } else { + importCatalogTable(table, sourceIdent, partitionFilter, checkDuplicateFiles); + } + + Snapshot snapshot = table.currentSnapshot(); + return toOutputRows(snapshot); + }); + } + + private static void ensureNameMappingPresent(Table table) { + if (table.properties().get(TableProperties.DEFAULT_NAME_MAPPING) == null) { + // Forces Name based resolution instead of position based resolution + NameMapping mapping = MappingUtil.create(table.schema()); + String mappingJson = NameMappingParser.toJson(mapping); + table.updateProperties().set(TableProperties.DEFAULT_NAME_MAPPING, mappingJson).commit(); + } + } + + private void importFileTable( + Table table, + Path tableLocation, + String format, + Map partitionFilter, + boolean checkDuplicateFiles, + PartitionSpec spec) { + // List Partitions via Spark InMemory file search interface + List partitions = + Spark3Util.getPartitions(spark(), tableLocation, format, partitionFilter, spec); + + if (table.spec().isUnpartitioned()) { + Preconditions.checkArgument( + partitions.isEmpty(), "Cannot add partitioned files to an unpartitioned table"); + Preconditions.checkArgument( + partitionFilter.isEmpty(), + "Cannot use a partition filter when importing" + "to an unpartitioned table"); + + // Build a Global Partition for the source + SparkPartition partition = + new SparkPartition(Collections.emptyMap(), tableLocation.toString(), format); + importPartitions(table, ImmutableList.of(partition), checkDuplicateFiles); + } else { + Preconditions.checkArgument( + !partitions.isEmpty(), "Cannot find any matching partitions in table %s", partitions); + importPartitions(table, partitions, checkDuplicateFiles); + } + } + + private void importCatalogTable( + Table table, + Identifier sourceIdent, + Map partitionFilter, + boolean checkDuplicateFiles) { + String stagingLocation = getMetadataLocation(table); + TableIdentifier sourceTableIdentifier = Spark3Util.toV1TableIdentifier(sourceIdent); + SparkTableUtil.importSparkTable( + spark(), + sourceTableIdentifier, + table, + stagingLocation, + partitionFilter, + checkDuplicateFiles); + } + + private void importPartitions( + Table table, List partitions, boolean checkDuplicateFiles) { + String stagingLocation = 
getMetadataLocation(table); + SparkTableUtil.importSparkPartitions( + spark(), partitions, table, table.spec(), stagingLocation, checkDuplicateFiles); + } + + private String getMetadataLocation(Table table) { + String defaultValue = LocationUtil.stripTrailingSlash(table.location()) + "/metadata"; + return LocationUtil.stripTrailingSlash( + table.properties().getOrDefault(TableProperties.WRITE_METADATA_LOCATION, defaultValue)); + } + + @Override + public String description() { + return "AddFiles"; + } + + private void validatePartitionSpec(Table table, Map partitionFilter) { + List partitionFields = table.spec().fields(); + Set partitionNames = + table.spec().fields().stream().map(PartitionField::name).collect(Collectors.toSet()); + + boolean tablePartitioned = !partitionFields.isEmpty(); + boolean partitionSpecPassed = !partitionFilter.isEmpty(); + + // Check for any non-identity partition columns + List nonIdentityFields = + partitionFields.stream() + .filter(x -> !x.transform().isIdentity()) + .collect(Collectors.toList()); + Preconditions.checkArgument( + nonIdentityFields.isEmpty(), + "Cannot add data files to target table %s because that table is partitioned and contains non-identity" + + "partition transforms which will not be compatible. Found non-identity fields %s", + table.name(), + nonIdentityFields); + + if (tablePartitioned && partitionSpecPassed) { + // Check to see there are sufficient partition columns to satisfy the filter + Preconditions.checkArgument( + partitionFields.size() >= partitionFilter.size(), + "Cannot add data files to target table %s because that table is partitioned, " + + "but the number of columns in the provided partition filter (%s) " + + "is greater than the number of partitioned columns in table (%s)", + table.name(), + partitionFilter.size(), + partitionFields.size()); + + // Check for any filters of non existent columns + List unMatchedFilters = + partitionFilter.keySet().stream() + .filter(filterName -> !partitionNames.contains(filterName)) + .collect(Collectors.toList()); + Preconditions.checkArgument( + unMatchedFilters.isEmpty(), + "Cannot add files to target table %s. %s is partitioned but the specified partition filter " + + "refers to columns that are not partitioned: '%s' . Valid partition columns %s", + table.name(), + table.name(), + unMatchedFilters, + String.join(",", partitionNames)); + } else { + Preconditions.checkArgument( + !partitionSpecPassed, + "Cannot use partition filter with an unpartitioned table %s", + table.name()); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/AncestorsOfProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/AncestorsOfProcedure.java new file mode 100644 index 000000000000..c3a6ca138358 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/AncestorsOfProcedure.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import java.util.List; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class AncestorsOfProcedure extends BaseProcedure { + + private static final ProcedureParameter TABLE_PARAM = + ProcedureParameter.required("table", DataTypes.StringType); + private static final ProcedureParameter SNAPSHOT_ID_PARAM = + ProcedureParameter.optional("snapshot_id", DataTypes.LongType); + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] {TABLE_PARAM, SNAPSHOT_ID_PARAM}; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("snapshot_id", DataTypes.LongType, true, Metadata.empty()), + new StructField("timestamp", DataTypes.LongType, true, Metadata.empty()) + }); + + private AncestorsOfProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + public static SparkProcedures.ProcedureBuilder builder() { + return new Builder() { + @Override + protected AncestorsOfProcedure doBuild() { + return new AncestorsOfProcedure(tableCatalog()); + } + }; + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + ProcedureInput input = new ProcedureInput(spark(), tableCatalog(), PARAMETERS, args); + + Identifier tableIdent = input.ident(TABLE_PARAM); + Long toSnapshotId = input.asLong(SNAPSHOT_ID_PARAM, null); + + SparkTable sparkTable = loadSparkTable(tableIdent); + Table icebergTable = sparkTable.table(); + + if (toSnapshotId == null) { + toSnapshotId = + icebergTable.currentSnapshot() != null ? 
icebergTable.currentSnapshot().snapshotId() : -1; + } + + List snapshotIds = + Lists.newArrayList( + SnapshotUtil.ancestorIdsBetween(toSnapshotId, null, icebergTable::snapshot)); + + return toOutputRow(icebergTable, snapshotIds); + } + + @Override + public String description() { + return "AncestorsOf"; + } + + private InternalRow[] toOutputRow(Table table, List snapshotIds) { + if (snapshotIds.isEmpty()) { + return new InternalRow[0]; + } + + InternalRow[] internalRows = new InternalRow[snapshotIds.size()]; + for (int i = 0; i < snapshotIds.size(); i++) { + Long snapshotId = snapshotIds.get(i); + internalRows[i] = newInternalRow(snapshotId, table.snapshot(snapshotId).timestampMillis()); + } + + return internalRows; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/BaseProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/BaseProcedure.java new file mode 100644 index 000000000000..ed0156adc55b --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/BaseProcedure.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.procedures; + +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.function.Function; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.relocated.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.Spark3Util.CatalogAndIdentifier; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.Procedure; +import org.apache.spark.sql.execution.CacheManager; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import scala.Option; + +abstract class BaseProcedure implements Procedure { + protected static final DataType STRING_MAP = + DataTypes.createMapType(DataTypes.StringType, DataTypes.StringType); + protected static final DataType STRING_ARRAY = DataTypes.createArrayType(DataTypes.StringType); + + private final SparkSession spark; + private final TableCatalog tableCatalog; + + private SparkActions actions; + private ExecutorService executorService = null; + + protected BaseProcedure(TableCatalog tableCatalog) { + this.spark = SparkSession.active(); + this.tableCatalog = tableCatalog; + } + + protected SparkSession spark() { + return this.spark; + } + + protected SparkActions actions() { + if (actions == null) { + this.actions = SparkActions.get(spark); + } + return actions; + } + + protected TableCatalog tableCatalog() { + return this.tableCatalog; + } + + protected T modifyIcebergTable(Identifier ident, Function func) { + try { + return execute(ident, true, func); + } finally { + closeService(); + } + } + + protected T withIcebergTable(Identifier ident, Function func) { + try { + return execute(ident, false, func); + } finally { + closeService(); + } + } + + private T execute( + Identifier ident, boolean refreshSparkCache, Function func) { + SparkTable sparkTable = loadSparkTable(ident); + org.apache.iceberg.Table icebergTable = sparkTable.table(); + + T result = func.apply(icebergTable); + + if (refreshSparkCache) { + refreshSparkCache(ident, sparkTable); + } + + return result; + } + + protected Identifier toIdentifier(String identifierAsString, String argName) { + CatalogAndIdentifier catalogAndIdentifier = + toCatalogAndIdentifier(identifierAsString, argName, tableCatalog); + + Preconditions.checkArgument( + catalogAndIdentifier.catalog().equals(tableCatalog), + "Cannot run procedure in catalog '%s': '%s' is a table in catalog '%s'", + 
tableCatalog.name(), + identifierAsString, + catalogAndIdentifier.catalog().name()); + + return catalogAndIdentifier.identifier(); + } + + protected CatalogAndIdentifier toCatalogAndIdentifier( + String identifierAsString, String argName, CatalogPlugin catalog) { + Preconditions.checkArgument( + identifierAsString != null && !identifierAsString.isEmpty(), + "Cannot handle an empty identifier for argument %s", + argName); + + return Spark3Util.catalogAndIdentifier( + "identifier for arg " + argName, spark, identifierAsString, catalog); + } + + protected SparkTable loadSparkTable(Identifier ident) { + try { + Table table = tableCatalog.loadTable(ident); + ValidationException.check( + table instanceof SparkTable, "%s is not %s", ident, SparkTable.class.getName()); + return (SparkTable) table; + } catch (NoSuchTableException e) { + String errMsg = + String.format("Couldn't load table '%s' in catalog '%s'", ident, tableCatalog.name()); + throw new RuntimeException(errMsg, e); + } + } + + protected Dataset loadRows(Identifier tableIdent, Map options) { + String tableName = Spark3Util.quotedFullIdentifier(tableCatalog().name(), tableIdent); + return spark().read().options(options).table(tableName); + } + + protected void refreshSparkCache(Identifier ident, Table table) { + CacheManager cacheManager = spark.sharedState().cacheManager(); + DataSourceV2Relation relation = + DataSourceV2Relation.create(table, Option.apply(tableCatalog), Option.apply(ident)); + cacheManager.recacheByPlan(spark, relation); + } + + protected InternalRow newInternalRow(Object... values) { + return new GenericInternalRow(values); + } + + protected abstract static class Builder implements ProcedureBuilder { + private TableCatalog tableCatalog; + + @Override + public Builder withTableCatalog(TableCatalog newTableCatalog) { + this.tableCatalog = newTableCatalog; + return this; + } + + @Override + public T build() { + return doBuild(); + } + + protected abstract T doBuild(); + + TableCatalog tableCatalog() { + return tableCatalog; + } + } + + /** + * Closes this procedure's executor service if a new one was created with {@link + * #executorService(int, String)}. Does not block for any remaining tasks. + */ + protected void closeService() { + if (executorService != null) { + executorService.shutdown(); + } + } + + /** + * Starts a new executor service which can be used by this procedure in its work. The pool will be + * automatically shut down if {@link #withIcebergTable(Identifier, Function)} or {@link + * #modifyIcebergTable(Identifier, Function)} are called. If these methods are not used then the + * service can be shut down with {@link #closeService()} or left to be closed when this class is + * finalized. 
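+   *
+   * <p>For example, a sketch of typical usage (mirroring how ExpireSnapshotsProcedure below wires
+   * the pool into an action; illustrative only, not a required pattern):
+   *
+   * <pre>{@code
+   * ExpireSnapshots action = actions().expireSnapshots(table);
+   * action.executeDeleteWith(executorService(maxConcurrentDeletes, "expire-snapshots"));
+   * }</pre>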
+ * + * @param threadPoolSize number of threads in the service + * @param nameFormat name prefix for threads created in this service + * @return the new executor service owned by this procedure + */ + protected ExecutorService executorService(int threadPoolSize, String nameFormat) { + Preconditions.checkArgument( + executorService == null, "Cannot create a new executor service, one already exists."); + Preconditions.checkArgument( + nameFormat != null, "Cannot create a service with null nameFormat arg"); + this.executorService = + MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) + Executors.newFixedThreadPool( + threadPoolSize, + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat(nameFormat + "-%d") + .build())); + + return executorService; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/CherrypickSnapshotProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/CherrypickSnapshotProcedure.java new file mode 100644 index 000000000000..efe9aeb9e8e8 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/CherrypickSnapshotProcedure.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A procedure that applies changes in a given snapshot and creates a new snapshot which will be set + * as the current snapshot in a table. + * + *

Note: this procedure invalidates all cached Spark plans that reference the affected + * table. + * + * @see org.apache.iceberg.ManageSnapshots#cherrypick(long) + */ +class CherrypickSnapshotProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.required("snapshot_id", DataTypes.LongType) + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("source_snapshot_id", DataTypes.LongType, false, Metadata.empty()), + new StructField("current_snapshot_id", DataTypes.LongType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected CherrypickSnapshotProcedure doBuild() { + return new CherrypickSnapshotProcedure(tableCatalog()); + } + }; + } + + private CherrypickSnapshotProcedure(TableCatalog catalog) { + super(catalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + long snapshotId = args.getLong(1); + + return modifyIcebergTable( + tableIdent, + table -> { + table.manageSnapshots().cherrypick(snapshotId).commit(); + + Snapshot currentSnapshot = table.currentSnapshot(); + + InternalRow outputRow = newInternalRow(snapshotId, currentSnapshot.snapshotId()); + return new InternalRow[] {outputRow}; + }); + } + + @Override + public String description() { + return "CherrypickSnapshotProcedure"; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/CreateChangelogViewProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/CreateChangelogViewProcedure.java new file mode 100644 index 000000000000..9cd52df40217 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/CreateChangelogViewProcedure.java @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.procedures; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.ChangelogIterator; +import org.apache.iceberg.spark.source.SparkChangelogTable; +import org.apache.spark.api.java.function.MapPartitionsFunction; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.encoders.RowEncoder; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A procedure that creates a view for changed rows. + * + *

The procedure removes the carry-over rows by default. If you want to keep them, you can set + * "remove_carryovers" to be false in the options. + * + *

The procedure doesn't compute the pre/post update images by default. If you want to compute + * them, you can set "compute_updates" to be true in the options. + * + *
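+ * <p>For example, assuming an Iceberg catalog named "spark_catalog" and a table "db.tbl", an
+ * invocation could look like {@code CALL spark_catalog.system.create_changelog_view(table =>
+ * 'db.tbl', identifier_columns => array('id'))}; the "options" map, if given, is passed through as
+ * read options when the underlying changelog table is loaded.
+ *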

Carry-over rows are the result of a removal and insertion of the same row within an operation + * because of the copy-on-write mechanism. For example, given a file that contains row1 (id=1, + * data='a') and row2 (id=2, data='b'), a copy-on-write delete of row2 would require erasing this + * file and preserving row1 in a new file. The changelog table would report this as (id=1, data='a', + * op='DELETE') and (id=1, data='a', op='INSERT'), despite it not being an actual change to the + * table. The procedure finds the carry-over rows and removes them from the result. + * + *

Pre/post update images are converted from a pair of a delete row and an insert row. Identifier + * columns are used for determining whether an insert and a delete record refer to the same row. If + * the two records share the same values for the identifier columns, they are considered to be the + * before and after states of the same row. You can either set identifier fields in the table schema + * or input them as the procedure parameters. Here is an example of pre/post update images with an + * identifier column (id). A pair of a delete row and an insert row with the same id: + * + *

+ * <ul>
+ *   <li>(id=1, data='a', op='DELETE')
+ *   <li>(id=1, data='b', op='INSERT')
+ * </ul>
+ *
+ * <p>will be marked as pre/post update images:
+ *
+ * <ul>
+ *   <li>(id=1, data='a', op='UPDATE_BEFORE')
+ *   <li>(id=1, data='b', op='UPDATE_AFTER')
+ * </ul>
+ */ +public class CreateChangelogViewProcedure extends BaseProcedure { + + private static final ProcedureParameter TABLE_PARAM = + ProcedureParameter.required("table", DataTypes.StringType); + private static final ProcedureParameter CHANGELOG_VIEW_PARAM = + ProcedureParameter.optional("changelog_view", DataTypes.StringType); + private static final ProcedureParameter OPTIONS_PARAM = + ProcedureParameter.optional("options", STRING_MAP); + private static final ProcedureParameter COMPUTE_UPDATES_PARAM = + ProcedureParameter.optional("compute_updates", DataTypes.BooleanType); + private static final ProcedureParameter REMOVE_CARRYOVERS_PARAM = + ProcedureParameter.optional("remove_carryovers", DataTypes.BooleanType); + private static final ProcedureParameter IDENTIFIER_COLUMNS_PARAM = + ProcedureParameter.optional("identifier_columns", STRING_ARRAY); + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + TABLE_PARAM, + CHANGELOG_VIEW_PARAM, + OPTIONS_PARAM, + COMPUTE_UPDATES_PARAM, + REMOVE_CARRYOVERS_PARAM, + IDENTIFIER_COLUMNS_PARAM, + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("changelog_view", DataTypes.StringType, false, Metadata.empty()) + }); + + public static SparkProcedures.ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected CreateChangelogViewProcedure doBuild() { + return new CreateChangelogViewProcedure(tableCatalog()); + } + }; + } + + private CreateChangelogViewProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + ProcedureInput input = new ProcedureInput(spark(), tableCatalog(), PARAMETERS, args); + + Identifier tableIdent = input.ident(TABLE_PARAM); + + // load insert and deletes from the changelog table + Identifier changelogTableIdent = changelogTableIdent(tableIdent); + Dataset df = loadRows(changelogTableIdent, options(input)); + + if (shouldComputeUpdateImages(input)) { + df = computeUpdateImages(identifierColumns(input, tableIdent), df); + } else if (shouldRemoveCarryoverRows(input)) { + df = removeCarryoverRows(df); + } + + String viewName = viewName(input, tableIdent.name()); + + df.createOrReplaceTempView(viewName); + + return toOutputRows(viewName); + } + + private Dataset computeUpdateImages(String[] identifierColumns, Dataset df) { + Preconditions.checkArgument( + identifierColumns.length > 0, + "Cannot compute the update images because identifier columns are not set"); + + Column[] repartitionSpec = new Column[identifierColumns.length + 1]; + for (int i = 0; i < identifierColumns.length; i++) { + repartitionSpec[i] = df.col(identifierColumns[i]); + } + repartitionSpec[repartitionSpec.length - 1] = df.col(MetadataColumns.CHANGE_ORDINAL.name()); + + return applyChangelogIterator(df, repartitionSpec); + } + + private boolean shouldComputeUpdateImages(ProcedureInput input) { + // If the identifier columns are set, we compute pre/post update images by default. 
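+    // For example, a call that provides identifier_columns but omits compute_updates computes
+    // update images, while an explicit compute_updates => false still disables them.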
+ boolean defaultValue = input.isProvided(IDENTIFIER_COLUMNS_PARAM); + return input.asBoolean(COMPUTE_UPDATES_PARAM, defaultValue); + } + + private boolean shouldRemoveCarryoverRows(ProcedureInput input) { + return input.asBoolean(REMOVE_CARRYOVERS_PARAM, true); + } + + private Dataset removeCarryoverRows(Dataset df) { + Column[] repartitionSpec = + Arrays.stream(df.columns()) + .filter(c -> !c.equals(MetadataColumns.CHANGE_TYPE.name())) + .map(df::col) + .toArray(Column[]::new); + return applyChangelogIterator(df, repartitionSpec); + } + + private String[] identifierColumns(ProcedureInput input, Identifier tableIdent) { + if (input.isProvided(IDENTIFIER_COLUMNS_PARAM)) { + return input.asStringArray(IDENTIFIER_COLUMNS_PARAM); + } else { + Table table = loadSparkTable(tableIdent).table(); + return table.schema().identifierFieldNames().toArray(new String[0]); + } + } + + private Identifier changelogTableIdent(Identifier tableIdent) { + List namespace = Lists.newArrayList(); + namespace.addAll(Arrays.asList(tableIdent.namespace())); + namespace.add(tableIdent.name()); + return Identifier.of(namespace.toArray(new String[0]), SparkChangelogTable.TABLE_NAME); + } + + private Map options(ProcedureInput input) { + return input.asStringMap(OPTIONS_PARAM, ImmutableMap.of()); + } + + private String viewName(ProcedureInput input, String tableName) { + String defaultValue = String.format("`%s_changes`", tableName); + return input.asString(CHANGELOG_VIEW_PARAM, defaultValue); + } + + private Dataset applyChangelogIterator(Dataset df, Column[] repartitionSpec) { + Column[] sortSpec = sortSpec(df, repartitionSpec); + StructType schema = df.schema(); + String[] identifierFields = + Arrays.stream(repartitionSpec).map(Column::toString).toArray(String[]::new); + + return df.repartition(repartitionSpec) + .sortWithinPartitions(sortSpec) + .mapPartitions( + (MapPartitionsFunction) + rowIterator -> ChangelogIterator.create(rowIterator, schema, identifierFields), + RowEncoder.apply(schema)); + } + + private static Column[] sortSpec(Dataset df, Column[] repartitionSpec) { + Column[] sortSpec = new Column[repartitionSpec.length + 1]; + System.arraycopy(repartitionSpec, 0, sortSpec, 0, repartitionSpec.length); + sortSpec[sortSpec.length - 1] = df.col(MetadataColumns.CHANGE_TYPE.name()); + return sortSpec; + } + + private InternalRow[] toOutputRows(String viewName) { + InternalRow row = newInternalRow(UTF8String.fromString(viewName)); + return new InternalRow[] {row}; + } + + @Override + public String description() { + return "CreateChangelogViewProcedure"; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/ExpireSnapshotsProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/ExpireSnapshotsProcedure.java new file mode 100644 index 000000000000..9d2fc7e467cf --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/ExpireSnapshotsProcedure.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.ExpireSnapshots; +import org.apache.iceberg.io.SupportsBulkOperations; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.actions.ExpireSnapshotsSparkAction; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.iceberg.util.DateTimeUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A procedure that expires snapshots in a table. + * + * @see SparkActions#expireSnapshots(Table) + */ +public class ExpireSnapshotsProcedure extends BaseProcedure { + + private static final Logger LOG = LoggerFactory.getLogger(ExpireSnapshotsProcedure.class); + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.optional("older_than", DataTypes.TimestampType), + ProcedureParameter.optional("retain_last", DataTypes.IntegerType), + ProcedureParameter.optional("max_concurrent_deletes", DataTypes.IntegerType), + ProcedureParameter.optional("stream_results", DataTypes.BooleanType), + ProcedureParameter.optional("snapshot_ids", DataTypes.createArrayType(DataTypes.LongType)) + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("deleted_data_files_count", DataTypes.LongType, true, Metadata.empty()), + new StructField( + "deleted_position_delete_files_count", DataTypes.LongType, true, Metadata.empty()), + new StructField( + "deleted_equality_delete_files_count", DataTypes.LongType, true, Metadata.empty()), + new StructField( + "deleted_manifest_files_count", DataTypes.LongType, true, Metadata.empty()), + new StructField( + "deleted_manifest_lists_count", DataTypes.LongType, true, Metadata.empty()), + new StructField( + "deleted_statistics_files_count", DataTypes.LongType, true, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected ExpireSnapshotsProcedure doBuild() { + return new ExpireSnapshotsProcedure(tableCatalog()); + } + }; + } + + private ExpireSnapshotsProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + @SuppressWarnings("checkstyle:CyclomaticComplexity") + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = 
toIdentifier(args.getString(0), PARAMETERS[0].name()); + Long olderThanMillis = args.isNullAt(1) ? null : DateTimeUtil.microsToMillis(args.getLong(1)); + Integer retainLastNum = args.isNullAt(2) ? null : args.getInt(2); + Integer maxConcurrentDeletes = args.isNullAt(3) ? null : args.getInt(3); + Boolean streamResult = args.isNullAt(4) ? null : args.getBoolean(4); + long[] snapshotIds = args.isNullAt(5) ? null : args.getArray(5).toLongArray(); + + Preconditions.checkArgument( + maxConcurrentDeletes == null || maxConcurrentDeletes > 0, + "max_concurrent_deletes should have value > 0, value: %s", + maxConcurrentDeletes); + + return modifyIcebergTable( + tableIdent, + table -> { + ExpireSnapshots action = actions().expireSnapshots(table); + + if (olderThanMillis != null) { + action.expireOlderThan(olderThanMillis); + } + + if (retainLastNum != null) { + action.retainLast(retainLastNum); + } + + if (maxConcurrentDeletes != null) { + if (table.io() instanceof SupportsBulkOperations) { + LOG.warn( + "max_concurrent_deletes only works with FileIOs that do not support bulk deletes. This" + + "table is currently using {} which supports bulk deletes so the parameter will be ignored. " + + "See that IO's documentation to learn how to adjust parallelism for that particular " + + "IO's bulk delete.", + table.io().getClass().getName()); + } else { + + action.executeDeleteWith(executorService(maxConcurrentDeletes, "expire-snapshots")); + } + } + + if (snapshotIds != null) { + for (long snapshotId : snapshotIds) { + action.expireSnapshotId(snapshotId); + } + } + + if (streamResult != null) { + action.option( + ExpireSnapshotsSparkAction.STREAM_RESULTS, Boolean.toString(streamResult)); + } + + ExpireSnapshots.Result result = action.execute(); + + return toOutputRows(result); + }); + } + + private InternalRow[] toOutputRows(ExpireSnapshots.Result result) { + InternalRow row = + newInternalRow( + result.deletedDataFilesCount(), + result.deletedPositionDeleteFilesCount(), + result.deletedEqualityDeleteFilesCount(), + result.deletedManifestsCount(), + result.deletedManifestListsCount(), + result.deletedStatisticsFilesCount()); + return new InternalRow[] {row}; + } + + @Override + public String description() { + return "ExpireSnapshotProcedure"; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java new file mode 100644 index 000000000000..aaa6d2cb238d --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.procedures; + +import java.util.Map; +import org.apache.iceberg.actions.MigrateTable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.actions.MigrateTableSparkAction; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import scala.runtime.BoxedUnit; + +class MigrateTableProcedure extends BaseProcedure { + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.optional("properties", STRING_MAP), + ProcedureParameter.optional("drop_backup", DataTypes.BooleanType) + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("migrated_files_count", DataTypes.LongType, false, Metadata.empty()) + }); + + private MigrateTableProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected MigrateTableProcedure doBuild() { + return new MigrateTableProcedure(tableCatalog()); + } + }; + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + String tableName = args.getString(0); + Preconditions.checkArgument( + tableName != null && !tableName.isEmpty(), + "Cannot handle an empty identifier for argument table"); + + Map properties = Maps.newHashMap(); + if (!args.isNullAt(1)) { + args.getMap(1) + .foreach( + DataTypes.StringType, + DataTypes.StringType, + (k, v) -> { + properties.put(k.toString(), v.toString()); + return BoxedUnit.UNIT; + }); + } + + boolean dropBackup = args.isNullAt(2) ? false : args.getBoolean(2); + + MigrateTableSparkAction migrateTableSparkAction = + SparkActions.get().migrateTable(tableName).tableProperties(properties); + + MigrateTable.Result result; + if (dropBackup) { + result = migrateTableSparkAction.dropBackup().execute(); + } else { + result = migrateTableSparkAction.execute(); + } + + return new InternalRow[] {newInternalRow(result.migratedDataFilesCount())}; + } + + @Override + public String description() { + return "MigrateTableProcedure"; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java new file mode 100644 index 000000000000..42e4d8ba0603 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import java.lang.reflect.Array; +import java.util.Map; +import java.util.function.BiFunction; +import org.apache.commons.lang3.StringUtils; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.Spark3Util.CatalogAndIdentifier; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; + +/** A class that abstracts common logic for working with input to a procedure. */ +class ProcedureInput { + + private static final DataType STRING_ARRAY = DataTypes.createArrayType(DataTypes.StringType); + private static final DataType STRING_MAP = + DataTypes.createMapType(DataTypes.StringType, DataTypes.StringType); + + private final SparkSession spark; + private final TableCatalog catalog; + private final Map paramOrdinals; + private final InternalRow args; + + ProcedureInput( + SparkSession spark, TableCatalog catalog, ProcedureParameter[] params, InternalRow args) { + this.spark = spark; + this.catalog = catalog; + this.paramOrdinals = computeParamOrdinals(params); + this.args = args; + } + + public boolean isProvided(ProcedureParameter param) { + int ordinal = ordinal(param); + return !args.isNullAt(ordinal); + } + + public Boolean asBoolean(ProcedureParameter param, Boolean defaultValue) { + validateParamType(param, DataTypes.BooleanType); + int ordinal = ordinal(param); + return args.isNullAt(ordinal) ? defaultValue : (Boolean) args.getBoolean(ordinal); + } + + public long asLong(ProcedureParameter param) { + Long value = asLong(param, null); + Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name()); + return value; + } + + public Long asLong(ProcedureParameter param, Long defaultValue) { + validateParamType(param, DataTypes.LongType); + int ordinal = ordinal(param); + return args.isNullAt(ordinal) ? defaultValue : (Long) args.getLong(ordinal); + } + + public String asString(ProcedureParameter param) { + String value = asString(param, null); + Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name()); + return value; + } + + public String asString(ProcedureParameter param, String defaultValue) { + validateParamType(param, DataTypes.StringType); + int ordinal = ordinal(param); + return args.isNullAt(ordinal) ? 
defaultValue : args.getString(ordinal); + } + + public String[] asStringArray(ProcedureParameter param) { + String[] value = asStringArray(param, null); + Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name()); + return value; + } + + public String[] asStringArray(ProcedureParameter param, String[] defaultValue) { + validateParamType(param, STRING_ARRAY); + return array( + param, + (array, ordinal) -> array.getUTF8String(ordinal).toString(), + String.class, + defaultValue); + } + + @SuppressWarnings("unchecked") + private T[] array( + ProcedureParameter param, + BiFunction convertElement, + Class elementClass, + T[] defaultValue) { + + int ordinal = ordinal(param); + + if (args.isNullAt(ordinal)) { + return defaultValue; + } + + ArrayData arrayData = args.getArray(ordinal); + + T[] convertedArray = (T[]) Array.newInstance(elementClass, arrayData.numElements()); + + for (int index = 0; index < arrayData.numElements(); index++) { + convertedArray[index] = convertElement.apply(arrayData, index); + } + + return convertedArray; + } + + public Map asStringMap( + ProcedureParameter param, Map defaultValue) { + validateParamType(param, STRING_MAP); + return map( + param, + (keys, ordinal) -> keys.getUTF8String(ordinal).toString(), + (values, ordinal) -> values.getUTF8String(ordinal).toString(), + defaultValue); + } + + private Map map( + ProcedureParameter param, + BiFunction convertKey, + BiFunction convertValue, + Map defaultValue) { + + int ordinal = ordinal(param); + + if (args.isNullAt(ordinal)) { + return defaultValue; + } + + MapData mapData = args.getMap(ordinal); + + Map convertedMap = Maps.newHashMap(); + + for (int index = 0; index < mapData.numElements(); index++) { + K convertedKey = convertKey.apply(mapData.keyArray(), index); + V convertedValue = convertValue.apply(mapData.valueArray(), index); + convertedMap.put(convertedKey, convertedValue); + } + + return convertedMap; + } + + public Identifier ident(ProcedureParameter param) { + CatalogAndIdentifier catalogAndIdent = catalogAndIdent(param, catalog); + + Preconditions.checkArgument( + catalogAndIdent.catalog().equals(catalog), + "Cannot run procedure in catalog '%s': '%s' is a table in catalog '%s'", + catalog.name(), + catalogAndIdent.identifier(), + catalogAndIdent.catalog().name()); + + return catalogAndIdent.identifier(); + } + + public Identifier ident(ProcedureParameter param, CatalogPlugin defaultCatalog) { + CatalogAndIdentifier catalogAndIdent = catalogAndIdent(param, defaultCatalog); + return catalogAndIdent.identifier(); + } + + private CatalogAndIdentifier catalogAndIdent( + ProcedureParameter param, CatalogPlugin defaultCatalog) { + + String identAsString = asString(param); + + Preconditions.checkArgument( + StringUtils.isNotBlank(identAsString), + "Cannot handle an empty identifier for parameter '%s'", + param.name()); + + String desc = String.format("identifier for parameter '%s'", param.name()); + return Spark3Util.catalogAndIdentifier(desc, spark, identAsString, defaultCatalog); + } + + private int ordinal(ProcedureParameter param) { + return paramOrdinals.get(param.name()); + } + + private Map computeParamOrdinals(ProcedureParameter[] params) { + Map ordinals = Maps.newHashMap(); + + for (int index = 0; index < params.length; index++) { + String paramName = params[index].name(); + + Preconditions.checkArgument( + !ordinals.containsKey(paramName), + "Detected multiple parameters named as '%s'", + paramName); + + ordinals.put(paramName, index); + } + + return ordinals; + } + + 
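ProcedureInput wraps the positional InternalRow arguments so individual procedures do not have to repeat ordinal lookups and type checks. A rough usage sketch, not taken from this patch, assuming a hypothetical PARAMETERS array that declares a required string, an optional boolean, and an optional string map in that order:

    package org.apache.iceberg.spark.procedures;

    import java.util.Map;
    import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.catalyst.InternalRow;
    import org.apache.spark.sql.connector.catalog.TableCatalog;
    import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter;

    class ProcedureInputSketch {
      // params[0..2] are assumed to be a required string, an optional boolean,
      // and an optional string map; ProcedureInput validates the declared types
      static void readArgs(
          SparkSession spark, TableCatalog catalog, ProcedureParameter[] params, InternalRow args) {
        ProcedureInput input = new ProcedureInput(spark, catalog, params, args);
        String table = input.asString(params[0]);
        Boolean flag = input.asBoolean(params[1], false);
        Map<String, String> props = input.asStringMap(params[2], ImmutableMap.of());
        System.out.printf("table=%s flag=%s props=%s%n", table, flag, props);
      }
    }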
private void validateParamType(ProcedureParameter param, DataType expectedDataType) { + Preconditions.checkArgument( + expectedDataType.sameType(param.dataType()), + "Parameter '%s' must be of type %s", + param.name(), + expectedDataType.catalogString()); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/PublishChangesProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/PublishChangesProcedure.java new file mode 100644 index 000000000000..eb6c762ed51e --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/PublishChangesProcedure.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import java.util.Optional; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.iceberg.util.WapUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A procedure that applies changes in a snapshot created within a Write-Audit-Publish workflow with + * a wap_id and creates a new snapshot which will be set as the current snapshot in a table. + * + *
<p>
Note: this procedure invalidates all cached Spark plans that reference the affected + * table. + * + * @see org.apache.iceberg.ManageSnapshots#cherrypick(long) + */ +class PublishChangesProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.required("wap_id", DataTypes.StringType) + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("source_snapshot_id", DataTypes.LongType, false, Metadata.empty()), + new StructField("current_snapshot_id", DataTypes.LongType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new Builder() { + @Override + protected PublishChangesProcedure doBuild() { + return new PublishChangesProcedure(tableCatalog()); + } + }; + } + + private PublishChangesProcedure(TableCatalog catalog) { + super(catalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + String wapId = args.getString(1); + + return modifyIcebergTable( + tableIdent, + table -> { + Optional wapSnapshot = + Optional.ofNullable( + Iterables.find( + table.snapshots(), + snapshot -> wapId.equals(WapUtil.stagedWapId(snapshot)), + null)); + if (!wapSnapshot.isPresent()) { + throw new ValidationException(String.format("Cannot apply unknown WAP ID '%s'", wapId)); + } + + long wapSnapshotId = wapSnapshot.get().snapshotId(); + table.manageSnapshots().cherrypick(wapSnapshotId).commit(); + + Snapshot currentSnapshot = table.currentSnapshot(); + + InternalRow outputRow = newInternalRow(wapSnapshotId, currentSnapshot.snapshotId()); + return new InternalRow[] {outputRow}; + }); + } + + @Override + public String description() { + return "ApplyWapChangesProcedure"; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/RegisterTableProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/RegisterTableProcedure.java new file mode 100644 index 000000000000..857949e052c8 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/RegisterTableProcedure.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
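The publish_changes procedure cherry-picks the snapshot staged under a given WAP ID and makes it the current snapshot. A minimal invocation sketch, outside this diff, with spark_catalog / db.sample as placeholder names:

    import org.apache.spark.sql.SparkSession;

    class PublishChangesExample {
      // placeholder names; the snapshot must already have been staged with this WAP ID
      static void publish(SparkSession spark, String wapId) {
        spark.sql(
            "CALL spark_catalog.system.publish_changes("
                + "table => 'db.sample', wap_id => '" + wapId + "')");
      }
    }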
+ */ +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.iceberg.spark.source.HasIcebergCatalog; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +class RegisterTableProcedure extends BaseProcedure { + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.required("metadata_file", DataTypes.StringType) + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("current_snapshot_id", DataTypes.LongType, true, Metadata.empty()), + new StructField("total_records_count", DataTypes.LongType, true, Metadata.empty()), + new StructField("total_data_files_count", DataTypes.LongType, true, Metadata.empty()) + }); + + private RegisterTableProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected RegisterTableProcedure doBuild() { + return new RegisterTableProcedure(tableCatalog()); + } + }; + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + TableIdentifier tableName = + Spark3Util.identifierToTableIdentifier(toIdentifier(args.getString(0), "table")); + String metadataFile = args.getString(1); + Preconditions.checkArgument( + tableCatalog() instanceof HasIcebergCatalog, + "Cannot use Register Table in a non-Iceberg catalog"); + Preconditions.checkArgument( + metadataFile != null && !metadataFile.isEmpty(), + "Cannot handle an empty argument metadata_file"); + + Catalog icebergCatalog = ((HasIcebergCatalog) tableCatalog()).icebergCatalog(); + Table table = icebergCatalog.registerTable(tableName, metadataFile); + Long currentSnapshotId = null; + Long totalDataFiles = null; + Long totalRecords = null; + + Snapshot currentSnapshot = table.currentSnapshot(); + if (currentSnapshot != null) { + currentSnapshotId = currentSnapshot.snapshotId(); + totalDataFiles = + Long.parseLong(currentSnapshot.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP)); + totalRecords = + Long.parseLong(currentSnapshot.summary().get(SnapshotSummary.TOTAL_RECORDS_PROP)); + } + + return new InternalRow[] {newInternalRow(currentSnapshotId, totalRecords, totalDataFiles)}; + } + + @Override + public String description() { + return "RegisterTableProcedure"; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/RemoveOrphanFilesProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/RemoveOrphanFilesProcedure.java new file mode 100644 index 000000000000..6e66ea2629b8 --- /dev/null +++ 
b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/RemoveOrphanFilesProcedure.java @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.DeleteOrphanFiles; +import org.apache.iceberg.actions.DeleteOrphanFiles.PrefixMismatchMode; +import org.apache.iceberg.io.SupportsBulkOperations; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.actions.DeleteOrphanFilesSparkAction; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.iceberg.util.DateTimeUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.runtime.BoxedUnit; + +/** + * A procedure that removes orphan files in a table. 
+ * + * @see SparkActions#deleteOrphanFiles(Table) + */ +public class RemoveOrphanFilesProcedure extends BaseProcedure { + private static final Logger LOG = LoggerFactory.getLogger(RemoveOrphanFilesProcedure.class); + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.optional("older_than", DataTypes.TimestampType), + ProcedureParameter.optional("location", DataTypes.StringType), + ProcedureParameter.optional("dry_run", DataTypes.BooleanType), + ProcedureParameter.optional("max_concurrent_deletes", DataTypes.IntegerType), + ProcedureParameter.optional("file_list_view", DataTypes.StringType), + ProcedureParameter.optional("equal_schemes", STRING_MAP), + ProcedureParameter.optional("equal_authorities", STRING_MAP), + ProcedureParameter.optional("prefix_mismatch_mode", DataTypes.StringType), + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("orphan_file_location", DataTypes.StringType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected RemoveOrphanFilesProcedure doBuild() { + return new RemoveOrphanFilesProcedure(tableCatalog()); + } + }; + } + + private RemoveOrphanFilesProcedure(TableCatalog catalog) { + super(catalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + @SuppressWarnings("checkstyle:CyclomaticComplexity") + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + Long olderThanMillis = args.isNullAt(1) ? null : DateTimeUtil.microsToMillis(args.getLong(1)); + String location = args.isNullAt(2) ? null : args.getString(2); + boolean dryRun = args.isNullAt(3) ? false : args.getBoolean(3); + Integer maxConcurrentDeletes = args.isNullAt(4) ? null : args.getInt(4); + String fileListView = args.isNullAt(5) ? null : args.getString(5); + + Preconditions.checkArgument( + maxConcurrentDeletes == null || maxConcurrentDeletes > 0, + "max_concurrent_deletes should have value > 0, value: %s", + maxConcurrentDeletes); + + Map<String, String> equalSchemes = Maps.newHashMap(); + if (!args.isNullAt(6)) { + args.getMap(6) + .foreach( + DataTypes.StringType, + DataTypes.StringType, + (k, v) -> { + equalSchemes.put(k.toString(), v.toString()); + return BoxedUnit.UNIT; + }); + } + + Map<String, String> equalAuthorities = Maps.newHashMap(); + if (!args.isNullAt(7)) { + args.getMap(7) + .foreach( + DataTypes.StringType, + DataTypes.StringType, + (k, v) -> { + equalAuthorities.put(k.toString(), v.toString()); + return BoxedUnit.UNIT; + }); + } + + PrefixMismatchMode prefixMismatchMode = + args.isNullAt(8) ? 
null : PrefixMismatchMode.fromString(args.getString(8)); + + return withIcebergTable( + tableIdent, + table -> { + DeleteOrphanFilesSparkAction action = actions().deleteOrphanFiles(table); + + if (olderThanMillis != null) { + boolean isTesting = Boolean.parseBoolean(spark().conf().get("spark.testing", "false")); + if (!isTesting) { + validateInterval(olderThanMillis); + } + action.olderThan(olderThanMillis); + } + + if (location != null) { + action.location(location); + } + + if (dryRun) { + action.deleteWith(file -> {}); + } + + if (maxConcurrentDeletes != null) { + if (table.io() instanceof SupportsBulkOperations) { + LOG.warn( + "max_concurrent_deletes only works with FileIOs that do not support bulk deletes. This " + + "table is currently using {} which supports bulk deletes so the parameter will be ignored. " + + "See that IO's documentation to learn how to adjust parallelism for that particular " + + "IO's bulk delete.", + table.io().getClass().getName()); + } else { + + action.executeDeleteWith(executorService(maxConcurrentDeletes, "remove-orphans")); + } + } + + if (fileListView != null) { + action.compareToFileList(spark().table(fileListView)); + } + + action.equalSchemes(equalSchemes); + action.equalAuthorities(equalAuthorities); + + if (prefixMismatchMode != null) { + action.prefixMismatchMode(prefixMismatchMode); + } + + DeleteOrphanFiles.Result result = action.execute(); + + return toOutputRows(result); + }); + } + + private InternalRow[] toOutputRows(DeleteOrphanFiles.Result result) { + Iterable<String> orphanFileLocations = result.orphanFileLocations(); + + int orphanFileLocationsCount = Iterables.size(orphanFileLocations); + InternalRow[] rows = new InternalRow[orphanFileLocationsCount]; + + int index = 0; + for (String fileLocation : orphanFileLocations) { + rows[index] = newInternalRow(UTF8String.fromString(fileLocation)); + index++; + } + + return rows; + } + + private void validateInterval(long olderThanMillis) { + long intervalMillis = System.currentTimeMillis() - olderThanMillis; + if (intervalMillis < TimeUnit.DAYS.toMillis(1)) { + throw new IllegalArgumentException( + "Cannot remove orphan files with an interval less than 24 hours. Executing this " + + "procedure with a short interval may corrupt the table if other operations are happening " + + "at the same time. If you are absolutely confident that no concurrent operations will be " + + "affected by removing orphan files with such a short interval, you can use the Action API " + + "to remove orphan files with an arbitrary interval."); + } + } + + @Override + public String description() { + return "RemoveOrphanFilesProcedure"; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteDataFilesProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteDataFilesProcedure.java new file mode 100644 index 000000000000..1aea61e74785 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteDataFilesProcedure.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
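A minimal invocation sketch for remove_orphan_files, for illustration only; older_than and dry_run correspond to the optional parameters declared above, and all names are placeholders:

    import org.apache.spark.sql.SparkSession;

    class RemoveOrphanFilesExample {
      // placeholder names; dry_run => true only lists candidates, nothing is deleted
      static void dryRun(SparkSession spark) {
        spark
            .sql(
                "CALL spark_catalog.system.remove_orphan_files("
                    + "table => 'db.sample', "
                    + "older_than => TIMESTAMP '2023-04-01 00:00:00', "
                    + "dry_run => true)")
            .show(); // one orphan_file_location row per candidate file
      }
    }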
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.RewriteDataFiles; +import org.apache.iceberg.expressions.NamedReference; +import org.apache.iceberg.expressions.Zorder; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.ExtendedParser; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.execution.datasources.SparkExpressionConverter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import scala.runtime.BoxedUnit; + +/** + * A procedure that rewrites datafiles in a table. 
+ * + * @see org.apache.iceberg.spark.actions.SparkActions#rewriteDataFiles(Table) + */ +class RewriteDataFilesProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.optional("strategy", DataTypes.StringType), + ProcedureParameter.optional("sort_order", DataTypes.StringType), + ProcedureParameter.optional("options", STRING_MAP), + ProcedureParameter.optional("where", DataTypes.StringType) + }; + + // counts are not nullable since the action result is never null + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField( + "rewritten_data_files_count", DataTypes.IntegerType, false, Metadata.empty()), + new StructField( + "added_data_files_count", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("rewritten_bytes_count", DataTypes.LongType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new Builder() { + @Override + protected RewriteDataFilesProcedure doBuild() { + return new RewriteDataFilesProcedure(tableCatalog()); + } + }; + } + + private RewriteDataFilesProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + + return modifyIcebergTable( + tableIdent, + table -> { + String quotedFullIdentifier = + Spark3Util.quotedFullIdentifier(tableCatalog().name(), tableIdent); + RewriteDataFiles action = actions().rewriteDataFiles(table); + + String strategy = args.isNullAt(1) ? null : args.getString(1); + String sortOrderString = args.isNullAt(2) ? null : args.getString(2); + + if (strategy != null || sortOrderString != null) { + action = checkAndApplyStrategy(action, strategy, sortOrderString, table.schema()); + } + + if (!args.isNullAt(3)) { + action = checkAndApplyOptions(args, action); + } + + String where = args.isNullAt(4) ? 
null : args.getString(4); + + action = checkAndApplyFilter(action, where, quotedFullIdentifier); + + RewriteDataFiles.Result result = action.execute(); + + return toOutputRows(result); + }); + } + + private RewriteDataFiles checkAndApplyFilter( + RewriteDataFiles action, String where, String tableName) { + if (where != null) { + try { + Expression expression = + SparkExpressionConverter.collectResolvedSparkExpression(spark(), tableName, where); + return action.filter(SparkExpressionConverter.convertToIcebergExpression(expression)); + } catch (AnalysisException e) { + throw new IllegalArgumentException("Cannot parse predicates in where option: " + where); + } + } + return action; + } + + private RewriteDataFiles checkAndApplyOptions(InternalRow args, RewriteDataFiles action) { + Map options = Maps.newHashMap(); + args.getMap(3) + .foreach( + DataTypes.StringType, + DataTypes.StringType, + (k, v) -> { + options.put(k.toString(), v.toString()); + return BoxedUnit.UNIT; + }); + return action.options(options); + } + + private RewriteDataFiles checkAndApplyStrategy( + RewriteDataFiles action, String strategy, String sortOrderString, Schema schema) { + List zOrderTerms = Lists.newArrayList(); + List sortOrderFields = Lists.newArrayList(); + if (sortOrderString != null) { + ExtendedParser.parseSortOrder(spark(), sortOrderString) + .forEach( + field -> { + if (field.term() instanceof Zorder) { + zOrderTerms.add((Zorder) field.term()); + } else { + sortOrderFields.add(field); + } + }); + + if (!zOrderTerms.isEmpty() && !sortOrderFields.isEmpty()) { + // TODO: we need to allow this in future when SparkAction has handling for this. + throw new IllegalArgumentException( + "Cannot mix identity sort columns and a Zorder sort expression: " + sortOrderString); + } + } + + // caller of this function ensures that between strategy and sortOrder, at least one of them is + // not null. + if (strategy == null || strategy.equalsIgnoreCase("sort")) { + if (!zOrderTerms.isEmpty()) { + String[] columnNames = + zOrderTerms.stream() + .flatMap(zOrder -> zOrder.refs().stream().map(NamedReference::name)) + .toArray(String[]::new); + return action.zOrder(columnNames); + } else if (!sortOrderFields.isEmpty()) { + return action.sort(buildSortOrder(sortOrderFields, schema)); + } else { + return action.sort(); + } + } + if (strategy.equalsIgnoreCase("binpack")) { + RewriteDataFiles rewriteDataFiles = action.binPack(); + if (sortOrderString != null) { + // calling below method to throw the error as user has set both binpack strategy and sort + // order + return rewriteDataFiles.sort(buildSortOrder(sortOrderFields, schema)); + } + return rewriteDataFiles; + } else { + throw new IllegalArgumentException( + "unsupported strategy: " + strategy + ". 
Only binpack or sort is supported"); + } + } + + private SortOrder buildSortOrder( + List rawOrderFields, Schema schema) { + SortOrder.Builder builder = SortOrder.builderFor(schema); + rawOrderFields.forEach( + rawField -> builder.sortBy(rawField.term(), rawField.direction(), rawField.nullOrder())); + return builder.build(); + } + + private InternalRow[] toOutputRows(RewriteDataFiles.Result result) { + int rewrittenDataFilesCount = result.rewrittenDataFilesCount(); + long rewrittenBytesCount = result.rewrittenBytesCount(); + int addedDataFilesCount = result.addedDataFilesCount(); + InternalRow row = + newInternalRow(rewrittenDataFilesCount, addedDataFilesCount, rewrittenBytesCount); + return new InternalRow[] {row}; + } + + @Override + public String description() { + return "RewriteDataFilesProcedure"; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteManifestsProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteManifestsProcedure.java new file mode 100644 index 000000000000..c8becc7e5a0f --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteManifestsProcedure.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.RewriteManifests; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.actions.RewriteManifestsSparkAction; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A procedure that rewrites manifests in a table. + * + *
<p>
Note: this procedure invalidates all cached Spark plans that reference the affected + * table. + * + * @see SparkActions#rewriteManifests(Table) () + */ +class RewriteManifestsProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.optional("use_caching", DataTypes.BooleanType) + }; + + // counts are not nullable since the action result is never null + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField( + "rewritten_manifests_count", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("added_manifests_count", DataTypes.IntegerType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected RewriteManifestsProcedure doBuild() { + return new RewriteManifestsProcedure(tableCatalog()); + } + }; + } + + private RewriteManifestsProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + Boolean useCaching = args.isNullAt(1) ? null : args.getBoolean(1); + + return modifyIcebergTable( + tableIdent, + table -> { + RewriteManifestsSparkAction action = actions().rewriteManifests(table); + + if (useCaching != null) { + action.option(RewriteManifestsSparkAction.USE_CACHING, useCaching.toString()); + } + + RewriteManifests.Result result = action.execute(); + + return toOutputRows(result); + }); + } + + private InternalRow[] toOutputRows(RewriteManifests.Result result) { + int rewrittenManifestsCount = Iterables.size(result.rewrittenManifests()); + int addedManifestsCount = Iterables.size(result.addedManifests()); + InternalRow row = newInternalRow(rewrittenManifestsCount, addedManifestsCount); + return new InternalRow[] {row}; + } + + @Override + public String description() { + return "RewriteManifestsProcedure"; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/RollbackToSnapshotProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/RollbackToSnapshotProcedure.java new file mode 100644 index 000000000000..49cc1a5ceae3 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/RollbackToSnapshotProcedure.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
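Both rewrite procedures are typically driven from SQL. A sketch outside this diff, with placeholder names; the sort_order and where strings follow the formats handled by ExtendedParser and SparkExpressionConverter in the code above:

    import org.apache.spark.sql.SparkSession;

    class RewriteExamples {
      // placeholder table name; sort strategy with an explicit sort order and a row filter
      static void rewriteDataFiles(SparkSession spark) {
        spark.sql(
            "CALL spark_catalog.system.rewrite_data_files("
                + "table => 'db.sample', "
                + "strategy => 'sort', "
                + "sort_order => 'id ASC NULLS LAST', "
                + "where => 'id > 1000')");
      }

      // rewrites small manifests into fewer, larger ones
      static void rewriteManifests(SparkSession spark) {
        spark.sql(
            "CALL spark_catalog.system.rewrite_manifests(table => 'db.sample', use_caching => false)");
      }
    }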
+ */ +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A procedure that rolls back a table to a specific snapshot ID. + * + *
<p>
Note: this procedure invalidates all cached Spark plans that reference the affected + * table. + * + * @see org.apache.iceberg.ManageSnapshots#rollbackTo(long) + */ +class RollbackToSnapshotProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.required("snapshot_id", DataTypes.LongType) + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("previous_snapshot_id", DataTypes.LongType, false, Metadata.empty()), + new StructField("current_snapshot_id", DataTypes.LongType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + public RollbackToSnapshotProcedure doBuild() { + return new RollbackToSnapshotProcedure(tableCatalog()); + } + }; + } + + private RollbackToSnapshotProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + long snapshotId = args.getLong(1); + + return modifyIcebergTable( + tableIdent, + table -> { + Snapshot previousSnapshot = table.currentSnapshot(); + + table.manageSnapshots().rollbackTo(snapshotId).commit(); + + InternalRow outputRow = newInternalRow(previousSnapshot.snapshotId(), snapshotId); + return new InternalRow[] {outputRow}; + }); + } + + @Override + public String description() { + return "RollbackToSnapshotProcedure"; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/RollbackToTimestampProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/RollbackToTimestampProcedure.java new file mode 100644 index 000000000000..059725f0c152 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/RollbackToTimestampProcedure.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
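As the call body above shows, rollback_to_snapshot is a thin wrapper over the ManageSnapshots API, so the same operation can be performed directly against a Table handle. A sketch, not part of this patch:

    import org.apache.iceberg.Table;

    class RollbackToSnapshotExample {
      // equivalent to CALL spark_catalog.system.rollback_to_snapshot(table => ..., snapshot_id => ...)
      static void rollback(Table table, long snapshotId) {
        table.manageSnapshots().rollbackTo(snapshotId).commit();
      }
    }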
+ */ +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.iceberg.util.DateTimeUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A procedure that rolls back a table to a given point in time. + * + *
<p>
Note: this procedure invalidates all cached Spark plans that reference the affected + * table. + * + * @see org.apache.iceberg.ManageSnapshots#rollbackToTime(long) + */ +class RollbackToTimestampProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.required("timestamp", DataTypes.TimestampType) + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("previous_snapshot_id", DataTypes.LongType, false, Metadata.empty()), + new StructField("current_snapshot_id", DataTypes.LongType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected RollbackToTimestampProcedure doBuild() { + return new RollbackToTimestampProcedure(tableCatalog()); + } + }; + } + + private RollbackToTimestampProcedure(TableCatalog catalog) { + super(catalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + // timestamps in Spark have microsecond precision so this conversion is lossy + long timestampMillis = DateTimeUtil.microsToMillis(args.getLong(1)); + + return modifyIcebergTable( + tableIdent, + table -> { + Snapshot previousSnapshot = table.currentSnapshot(); + + table.manageSnapshots().rollbackToTime(timestampMillis).commit(); + + Snapshot currentSnapshot = table.currentSnapshot(); + + InternalRow outputRow = + newInternalRow(previousSnapshot.snapshotId(), currentSnapshot.snapshotId()); + return new InternalRow[] {outputRow}; + }); + } + + @Override + public String description() { + return "RollbackToTimestampProcedure"; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java new file mode 100644 index 000000000000..f8f8049c22b6 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
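A minimal invocation sketch for rollback_to_timestamp, for illustration only; note the comment in the code above about timestamps being truncated from microseconds to milliseconds:

    import org.apache.spark.sql.SparkSession;

    class RollbackToTimestampExample {
      // placeholder names; rolls back to the last snapshot committed before the timestamp
      static void rollback(SparkSession spark) {
        spark.sql(
            "CALL spark_catalog.system.rollback_to_timestamp("
                + "table => 'db.sample', timestamp => TIMESTAMP '2023-04-01 00:00:00')");
      }
    }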
+ */ +package org.apache.iceberg.spark.procedures; + +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A procedure that sets the current snapshot in a table. + * + *
<p>
Note: this procedure invalidates all cached Spark plans that reference the affected + * table. + * + * @see org.apache.iceberg.ManageSnapshots#setCurrentSnapshot(long) + */ +class SetCurrentSnapshotProcedure extends BaseProcedure { + + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.required("snapshot_id", DataTypes.LongType) + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("previous_snapshot_id", DataTypes.LongType, true, Metadata.empty()), + new StructField("current_snapshot_id", DataTypes.LongType, false, Metadata.empty()) + }); + + public static ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected SetCurrentSnapshotProcedure doBuild() { + return new SetCurrentSnapshotProcedure(tableCatalog()); + } + }; + } + + private SetCurrentSnapshotProcedure(TableCatalog catalog) { + super(catalog); + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); + long snapshotId = args.getLong(1); + + return modifyIcebergTable( + tableIdent, + table -> { + Snapshot previousSnapshot = table.currentSnapshot(); + Long previousSnapshotId = previousSnapshot != null ? previousSnapshot.snapshotId() : null; + + table.manageSnapshots().setCurrentSnapshot(snapshotId).commit(); + + InternalRow outputRow = newInternalRow(previousSnapshotId, snapshotId); + return new InternalRow[] {outputRow}; + }); + } + + @Override + public String description() { + return "SetCurrentSnapshotProcedure"; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java new file mode 100644 index 000000000000..7a015a51e8ed --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
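A minimal invocation sketch for set_current_snapshot, outside this diff, with placeholder names:

    import org.apache.spark.sql.SparkSession;

    class SetCurrentSnapshotExample {
      // placeholder names; unlike rollback, the target snapshot need not be an ancestor
      static void setSnapshot(SparkSession spark, long snapshotId) {
        spark.sql(
            "CALL spark_catalog.system.set_current_snapshot("
                + "table => 'db.sample', snapshot_id => " + snapshotId + ")");
      }
    }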
+ */ +package org.apache.iceberg.spark.procedures; + +import java.util.Map; +import org.apache.iceberg.actions.SnapshotTable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import scala.runtime.BoxedUnit; + +class SnapshotTableProcedure extends BaseProcedure { + private static final ProcedureParameter[] PARAMETERS = + new ProcedureParameter[] { + ProcedureParameter.required("source_table", DataTypes.StringType), + ProcedureParameter.required("table", DataTypes.StringType), + ProcedureParameter.optional("location", DataTypes.StringType), + ProcedureParameter.optional("properties", STRING_MAP) + }; + + private static final StructType OUTPUT_TYPE = + new StructType( + new StructField[] { + new StructField("imported_files_count", DataTypes.LongType, false, Metadata.empty()) + }); + + private SnapshotTableProcedure(TableCatalog tableCatalog) { + super(tableCatalog); + } + + public static SparkProcedures.ProcedureBuilder builder() { + return new BaseProcedure.Builder() { + @Override + protected SnapshotTableProcedure doBuild() { + return new SnapshotTableProcedure(tableCatalog()); + } + }; + } + + @Override + public ProcedureParameter[] parameters() { + return PARAMETERS; + } + + @Override + public StructType outputType() { + return OUTPUT_TYPE; + } + + @Override + public InternalRow[] call(InternalRow args) { + String source = args.getString(0); + Preconditions.checkArgument( + source != null && !source.isEmpty(), + "Cannot handle an empty identifier for argument source_table"); + String dest = args.getString(1); + Preconditions.checkArgument( + dest != null && !dest.isEmpty(), "Cannot handle an empty identifier for argument table"); + String snapshotLocation = args.isNullAt(2) ? null : args.getString(2); + + Map properties = Maps.newHashMap(); + if (!args.isNullAt(3)) { + args.getMap(3) + .foreach( + DataTypes.StringType, + DataTypes.StringType, + (k, v) -> { + properties.put(k.toString(), v.toString()); + return BoxedUnit.UNIT; + }); + } + + Preconditions.checkArgument( + !source.equals(dest), + "Cannot create a snapshot with the same name as the source of the snapshot."); + SnapshotTable action = SparkActions.get().snapshotTable(source).as(dest); + + if (snapshotLocation != null) { + action.tableLocation(snapshotLocation); + } + + SnapshotTable.Result result = action.tableProperties(properties).execute(); + return new InternalRow[] {newInternalRow(result.importedDataFilesCount())}; + } + + @Override + public String description() { + return "SnapshotTableProcedure"; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SparkProcedures.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SparkProcedures.java new file mode 100644 index 000000000000..8ee3a9550194 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SparkProcedures.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
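A minimal invocation sketch for the snapshot procedure, for illustration only; source_table, table, and location mirror the parameters declared above, and all names are placeholders:

    import org.apache.spark.sql.SparkSession;

    class SnapshotTableExample {
      // placeholder names and location; the new table reads the source table's existing data files
      static void snapshot(SparkSession spark) {
        spark.sql(
            "CALL spark_catalog.system.snapshot("
                + "source_table => 'db.sample', "
                + "table => 'db.sample_iceberg', "
                + "location => '/tmp/warehouse/db/sample_iceberg')");
      }
    }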
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.procedures; + +import java.util.Locale; +import java.util.Map; +import java.util.function.Supplier; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.iceberg.catalog.Procedure; + +public class SparkProcedures { + + private static final Map<String, Supplier<ProcedureBuilder>> BUILDERS = initProcedureBuilders(); + + private SparkProcedures() {} + + public static ProcedureBuilder newBuilder(String name) { + // procedure resolution is case insensitive to match the existing Spark behavior for functions + Supplier<ProcedureBuilder> builderSupplier = BUILDERS.get(name.toLowerCase(Locale.ROOT)); + return builderSupplier != null ? builderSupplier.get() : null; + } + + private static Map<String, Supplier<ProcedureBuilder>> initProcedureBuilders() { + ImmutableMap.Builder<String, Supplier<ProcedureBuilder>> mapBuilder = ImmutableMap.builder(); + mapBuilder.put("rollback_to_snapshot", RollbackToSnapshotProcedure::builder); + mapBuilder.put("rollback_to_timestamp", RollbackToTimestampProcedure::builder); + mapBuilder.put("set_current_snapshot", SetCurrentSnapshotProcedure::builder); + mapBuilder.put("cherrypick_snapshot", CherrypickSnapshotProcedure::builder); + mapBuilder.put("rewrite_data_files", RewriteDataFilesProcedure::builder); + mapBuilder.put("rewrite_manifests", RewriteManifestsProcedure::builder); + mapBuilder.put("remove_orphan_files", RemoveOrphanFilesProcedure::builder); + mapBuilder.put("expire_snapshots", ExpireSnapshotsProcedure::builder); + mapBuilder.put("migrate", MigrateTableProcedure::builder); + mapBuilder.put("snapshot", SnapshotTableProcedure::builder); + mapBuilder.put("add_files", AddFilesProcedure::builder); + mapBuilder.put("ancestors_of", AncestorsOfProcedure::builder); + mapBuilder.put("register_table", RegisterTableProcedure::builder); + mapBuilder.put("publish_changes", PublishChangesProcedure::builder); + mapBuilder.put("create_changelog_view", CreateChangelogViewProcedure::builder); + return mapBuilder.build(); + } + + public interface ProcedureBuilder { + ProcedureBuilder withTableCatalog(TableCatalog tableCatalog); + + Procedure build(); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/BaseBatchReader.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/BaseBatchReader.java new file mode 100644 index 000000000000..c05b694a60dc --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/BaseBatchReader.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
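The registry above is how CALL statements are resolved to procedure implementations; lookup is on the lower-cased name and unknown names yield null. A sketch of how a resolved builder is used, not part of this patch:

    import org.apache.iceberg.spark.procedures.SparkProcedures;
    import org.apache.spark.sql.connector.catalog.TableCatalog;
    import org.apache.spark.sql.connector.iceberg.catalog.Procedure;

    class ProcedureLookupSketch {
      // returns null for names that are not registered; matching is case-insensitive
      static Procedure resolve(TableCatalog catalog, String name) {
        SparkProcedures.ProcedureBuilder builder = SparkProcedures.newBuilder(name);
        return builder == null ? null : builder.withTableCatalog(catalog).build();
      }
    }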
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import java.util.Set; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkOrcReaders; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkParquetReaders; +import org.apache.iceberg.types.TypeUtil; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +abstract class BaseBatchReader extends BaseReader { + private final int batchSize; + + BaseBatchReader( + Table table, + ScanTaskGroup taskGroup, + Schema tableSchema, + Schema expectedSchema, + boolean caseSensitive, + int batchSize) { + super(table, taskGroup, tableSchema, expectedSchema, caseSensitive); + this.batchSize = batchSize; + } + + protected CloseableIterable newBatchIterable( + InputFile inputFile, + FileFormat format, + long start, + long length, + Expression residual, + Map idToConstant, + SparkDeleteFilter deleteFilter) { + switch (format) { + case PARQUET: + return newParquetIterable(inputFile, start, length, residual, idToConstant, deleteFilter); + + case ORC: + return newOrcIterable(inputFile, start, length, residual, idToConstant); + + default: + throw new UnsupportedOperationException( + "Format: " + format + " not supported for batched reads"); + } + } + + private CloseableIterable newParquetIterable( + InputFile inputFile, + long start, + long length, + Expression residual, + Map idToConstant, + SparkDeleteFilter deleteFilter) { + // get required schema if there are deletes + Schema requiredSchema = deleteFilter != null ? deleteFilter.requiredSchema() : expectedSchema(); + + return Parquet.read(inputFile) + .project(requiredSchema) + .split(start, length) + .createBatchedReaderFunc( + fileSchema -> + VectorizedSparkParquetReaders.buildReader( + requiredSchema, fileSchema, idToConstant, deleteFilter)) + .recordsPerBatch(batchSize) + .filter(residual) + .caseSensitive(caseSensitive()) + // Spark eagerly consumes the batches. So the underlying memory allocated could be reused + // without worrying about subsequent reads clobbering over each other. This improves + // read performance as every batch read doesn't have to pay the cost of allocating memory. 
+ .reuseContainers() + .withNameMapping(nameMapping()) + .build(); + } + + private CloseableIterable newOrcIterable( + InputFile inputFile, + long start, + long length, + Expression residual, + Map idToConstant) { + Set constantFieldIds = idToConstant.keySet(); + Set metadataFieldIds = MetadataColumns.metadataFieldIds(); + Sets.SetView constantAndMetadataFieldIds = + Sets.union(constantFieldIds, metadataFieldIds); + Schema schemaWithoutConstantAndMetadataFields = + TypeUtil.selectNot(expectedSchema(), constantAndMetadataFieldIds); + + return ORC.read(inputFile) + .project(schemaWithoutConstantAndMetadataFields) + .split(start, length) + .createBatchedReaderFunc( + fileSchema -> + VectorizedSparkOrcReaders.buildReader(expectedSchema(), fileSchema, idToConstant)) + .recordsPerBatch(batchSize) + .filter(residual) + .caseSensitive(caseSensitive()) + .withNameMapping(nameMapping()) + .build(); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/BaseReader.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/BaseReader.java new file mode 100644 index 000000000000..4fb838202c88 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/BaseReader.java @@ -0,0 +1,283 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
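Whether BaseBatchReader is used at all depends on the table's vectorization settings. A sketch, for illustration only, assuming the standard Iceberg table properties read.parquet.vectorization.enabled and read.parquet.vectorization.batch-size (verify the names against the configuration docs before relying on them):

    import org.apache.spark.sql.SparkSession;

    class VectorizationConfigSketch {
      // property names are assumed from Iceberg's documented read options; table name is a placeholder
      static void configure(SparkSession spark) {
        spark.sql(
            "ALTER TABLE spark_catalog.db.sample SET TBLPROPERTIES ("
                + "'read.parquet.vectorization.enabled' = 'true', "
                + "'read.parquet.vectorization.batch-size' = '10000')");
      }
    }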
+ */ +package org.apache.iceberg.spark.source; + +import java.io.Closeable; +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.avro.generic.GenericData; +import org.apache.avro.util.Utf8; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.ContentScanTask; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.DeleteFilter; +import org.apache.iceberg.deletes.DeleteCounter; +import org.apache.iceberg.encryption.EncryptedFiles; +import org.apache.iceberg.encryption.EncryptedInputFile; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.mapping.NameMapping; +import org.apache.iceberg.mapping.NameMappingParser; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.types.Types.StructType; +import org.apache.iceberg.util.ByteBuffers; +import org.apache.iceberg.util.PartitionUtil; +import org.apache.spark.rdd.InputFileBlockHolder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Base class of Spark readers. + * + * @param is the Java class returned by this reader whose objects contain one or more rows. + */ +abstract class BaseReader implements Closeable { + private static final Logger LOG = LoggerFactory.getLogger(BaseReader.class); + + private final Table table; + private final Schema tableSchema; + private final Schema expectedSchema; + private final boolean caseSensitive; + private final NameMapping nameMapping; + private final ScanTaskGroup taskGroup; + private final Iterator tasks; + private final DeleteCounter counter; + + private Map lazyInputFiles; + private CloseableIterator currentIterator; + private T current = null; + private TaskT currentTask = null; + + BaseReader( + Table table, + ScanTaskGroup taskGroup, + Schema tableSchema, + Schema expectedSchema, + boolean caseSensitive) { + this.table = table; + this.taskGroup = taskGroup; + this.tasks = taskGroup.tasks().iterator(); + this.currentIterator = CloseableIterator.empty(); + this.tableSchema = tableSchema; + this.expectedSchema = expectedSchema; + this.caseSensitive = caseSensitive; + String nameMappingString = table.properties().get(TableProperties.DEFAULT_NAME_MAPPING); + this.nameMapping = + nameMappingString != null ? 
NameMappingParser.fromJson(nameMappingString) : null; + this.counter = new DeleteCounter(); + } + + protected abstract CloseableIterator open(TaskT task); + + protected abstract Stream> referencedFiles(TaskT task); + + protected Schema expectedSchema() { + return expectedSchema; + } + + protected boolean caseSensitive() { + return caseSensitive; + } + + protected NameMapping nameMapping() { + return nameMapping; + } + + protected Table table() { + return table; + } + + protected DeleteCounter counter() { + return counter; + } + + public boolean next() throws IOException { + try { + while (true) { + if (currentIterator.hasNext()) { + this.current = currentIterator.next(); + return true; + } else if (tasks.hasNext()) { + this.currentIterator.close(); + this.currentTask = tasks.next(); + this.currentIterator = open(currentTask); + } else { + this.currentIterator.close(); + return false; + } + } + } catch (IOException | RuntimeException e) { + if (currentTask != null && !currentTask.isDataTask()) { + String filePaths = + referencedFiles(currentTask) + .map(file -> file.path().toString()) + .collect(Collectors.joining(", ")); + LOG.error("Error reading file(s): {}", filePaths, e); + } + throw e; + } + } + + public T get() { + return current; + } + + @Override + public void close() throws IOException { + InputFileBlockHolder.unset(); + + // close the current iterator + this.currentIterator.close(); + + // exhaust the task iterator + while (tasks.hasNext()) { + tasks.next(); + } + } + + protected InputFile getInputFile(String location) { + return inputFiles().get(location); + } + + private Map inputFiles() { + if (lazyInputFiles == null) { + Stream encryptedFiles = + taskGroup.tasks().stream().flatMap(this::referencedFiles).map(this::toEncryptedInputFile); + + // decrypt with the batch call to avoid multiple RPCs to a key server, if possible + Iterable decryptedFiles = table.encryption().decrypt(encryptedFiles::iterator); + + Map files = Maps.newHashMapWithExpectedSize(taskGroup.tasks().size()); + decryptedFiles.forEach(decrypted -> files.putIfAbsent(decrypted.location(), decrypted)); + this.lazyInputFiles = ImmutableMap.copyOf(files); + } + + return lazyInputFiles; + } + + private EncryptedInputFile toEncryptedInputFile(ContentFile file) { + InputFile inputFile = table.io().newInputFile(file.path().toString()); + return EncryptedFiles.encryptedInput(inputFile, file.keyMetadata()); + } + + protected Map constantsMap(ContentScanTask task, Schema readSchema) { + if (readSchema.findField(MetadataColumns.PARTITION_COLUMN_ID) != null) { + StructType partitionType = Partitioning.partitionType(table); + return PartitionUtil.constantsMap(task, partitionType, BaseReader::convertConstant); + } else { + return PartitionUtil.constantsMap(task, BaseReader::convertConstant); + } + } + + protected static Object convertConstant(Type type, Object value) { + if (value == null) { + return null; + } + + switch (type.typeId()) { + case DECIMAL: + return Decimal.apply((BigDecimal) value); + case STRING: + if (value instanceof Utf8) { + Utf8 utf8 = (Utf8) value; + return UTF8String.fromBytes(utf8.getBytes(), 0, utf8.getByteLength()); + } + return UTF8String.fromString(value.toString()); + case FIXED: + if (value instanceof byte[]) { + return value; + } else if (value instanceof GenericData.Fixed) { + return ((GenericData.Fixed) value).bytes(); + } + return ByteBuffers.toByteArray((ByteBuffer) value); + case BINARY: + return ByteBuffers.toByteArray((ByteBuffer) value); + case STRUCT: + StructType structType = 
(StructType) type; + + if (structType.fields().isEmpty()) { + return new GenericInternalRow(); + } + + List fields = structType.fields(); + Object[] values = new Object[fields.size()]; + StructLike struct = (StructLike) value; + + for (int index = 0; index < fields.size(); index++) { + NestedField field = fields.get(index); + Type fieldType = field.type(); + values[index] = + convertConstant(fieldType, struct.get(index, fieldType.typeId().javaClass())); + } + + return new GenericInternalRow(values); + default: + } + return value; + } + + protected class SparkDeleteFilter extends DeleteFilter { + private final InternalRowWrapper asStructLike; + + SparkDeleteFilter(String filePath, List deletes, DeleteCounter counter) { + super(filePath, deletes, tableSchema, expectedSchema, counter); + this.asStructLike = new InternalRowWrapper(SparkSchemaUtil.convert(requiredSchema())); + } + + @Override + protected StructLike asStructLike(InternalRow row) { + return asStructLike.wrap(row); + } + + @Override + protected InputFile getInputFile(String location) { + return BaseReader.this.getInputFile(location); + } + + @Override + protected void markRowDeleted(InternalRow row) { + if (!row.getBoolean(columnIsDeletedPosition())) { + row.setBoolean(columnIsDeletedPosition(), true); + counter().increment(); + } + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/BaseRowReader.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/BaseRowReader.java new file mode 100644 index 000000000000..927084caea1c --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/BaseRowReader.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
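The next()/open() loop in BaseReader above is an iterator-of-iterators: drain the iterator for the currently open file, then open the next scan task in the group until the group is exhausted. Below is a self-contained sketch of that control flow with plain Java types and hypothetical names; the real class additionally closes each iterator and logs the referenced files on failure.

    import java.util.Collections;
    import java.util.Iterator;
    import java.util.function.Function;

    // Illustrative only: the advance loop used by BaseReader, reduced to plain Java types.
    class NestedIteratorSketch<TASK, ROW> {
      private final Iterator<TASK> tasks;
      private final Function<TASK, Iterator<ROW>> open; // stands in for BaseReader#open(TaskT)
      private Iterator<ROW> current = Collections.emptyIterator();
      private ROW row;

      NestedIteratorSketch(Iterator<TASK> tasks, Function<TASK, Iterator<ROW>> open) {
        this.tasks = tasks;
        this.open = open;
      }

      boolean next() {
        while (true) {
          if (current.hasNext()) {
            row = current.next();                 // a row from the currently open file
            return true;
          } else if (tasks.hasNext()) {
            current = open.apply(tasks.next());   // move on to the next scan task
          } else {
            return false;                         // task group exhausted
          }
        }
      }

      ROW get() {
        return row;
      }
    }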
+ */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.data.SparkAvroReader; +import org.apache.iceberg.spark.data.SparkOrcReader; +import org.apache.iceberg.spark.data.SparkParquetReaders; +import org.apache.iceberg.types.TypeUtil; +import org.apache.spark.sql.catalyst.InternalRow; + +abstract class BaseRowReader extends BaseReader { + BaseRowReader( + Table table, + ScanTaskGroup taskGroup, + Schema tableSchema, + Schema expectedSchema, + boolean caseSensitive) { + super(table, taskGroup, tableSchema, expectedSchema, caseSensitive); + } + + protected CloseableIterable newIterable( + InputFile file, + FileFormat format, + long start, + long length, + Expression residual, + Schema projection, + Map idToConstant) { + switch (format) { + case PARQUET: + return newParquetIterable(file, start, length, residual, projection, idToConstant); + + case AVRO: + return newAvroIterable(file, start, length, projection, idToConstant); + + case ORC: + return newOrcIterable(file, start, length, residual, projection, idToConstant); + + default: + throw new UnsupportedOperationException("Cannot read unknown format: " + format); + } + } + + private CloseableIterable newAvroIterable( + InputFile file, long start, long length, Schema projection, Map idToConstant) { + return Avro.read(file) + .reuseContainers() + .project(projection) + .split(start, length) + .createReaderFunc(readSchema -> new SparkAvroReader(projection, readSchema, idToConstant)) + .withNameMapping(nameMapping()) + .build(); + } + + private CloseableIterable newParquetIterable( + InputFile file, + long start, + long length, + Expression residual, + Schema readSchema, + Map idToConstant) { + return Parquet.read(file) + .reuseContainers() + .split(start, length) + .project(readSchema) + .createReaderFunc( + fileSchema -> SparkParquetReaders.buildReader(readSchema, fileSchema, idToConstant)) + .filter(residual) + .caseSensitive(caseSensitive()) + .withNameMapping(nameMapping()) + .build(); + } + + private CloseableIterable newOrcIterable( + InputFile file, + long start, + long length, + Expression residual, + Schema readSchema, + Map idToConstant) { + Schema readSchemaWithoutConstantAndMetadataFields = + TypeUtil.selectNot( + readSchema, Sets.union(idToConstant.keySet(), MetadataColumns.metadataFieldIds())); + + return ORC.read(file) + .project(readSchemaWithoutConstantAndMetadataFields) + .split(start, length) + .createReaderFunc( + readOrcSchema -> new SparkOrcReader(readSchema, readOrcSchema, idToConstant)) + .filter(residual) + .caseSensitive(caseSensitive()) + .withNameMapping(nameMapping()) + .build(); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java new file mode 100644 index 000000000000..389ad1d5a2d9 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java 
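Both readers above prune constant and metadata fields out of the ORC projection, because those values are injected from idToConstant rather than read from the file. The pruning step on its own, as a hedged sketch using the same TypeUtil, MetadataColumns, and Sets calls as the patch (method and parameter names are placeholders):

    import java.util.Map;
    import java.util.Set;
    import org.apache.iceberg.MetadataColumns;
    import org.apache.iceberg.Schema;
    import org.apache.iceberg.relocated.com.google.common.collect.Sets;
    import org.apache.iceberg.types.TypeUtil;

    class OrcProjectionSketch {
      // Illustrative: drop constant and metadata column ids before handing the schema to ORC.
      static Schema orcProjection(Schema expectedSchema, Map<Integer, ?> idToConstant) {
        Set<Integer> constantFieldIds = idToConstant.keySet();
        Set<Integer> metadataFieldIds = MetadataColumns.metadataFieldIds();
        return TypeUtil.selectNot(expectedSchema, Sets.union(constantFieldIds, metadataFieldIds));
      }
    }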
@@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import java.util.stream.Stream; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.source.metrics.TaskNumDeletes; +import org.apache.iceberg.spark.source.metrics.TaskNumSplits; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.rdd.InputFileBlockHolder; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class BatchDataReader extends BaseBatchReader + implements PartitionReader { + + private static final Logger LOG = LoggerFactory.getLogger(BatchDataReader.class); + + private final long numSplits; + + BatchDataReader(SparkInputPartition partition, int batchSize) { + this( + partition.table(), + partition.taskGroup(), + SnapshotUtil.schemaFor(partition.table(), partition.branch()), + partition.expectedSchema(), + partition.isCaseSensitive(), + batchSize); + } + + BatchDataReader( + Table table, + ScanTaskGroup taskGroup, + Schema tableSchema, + Schema expectedSchema, + boolean caseSensitive, + int size) { + super(table, taskGroup, tableSchema, expectedSchema, caseSensitive, size); + + numSplits = taskGroup.tasks().size(); + LOG.debug("Reading {} file split(s) for table {}", numSplits, table.name()); + } + + @Override + public CustomTaskMetric[] currentMetricsValues() { + return new CustomTaskMetric[] { + new TaskNumSplits(numSplits), new TaskNumDeletes(counter().get()) + }; + } + + @Override + protected Stream> referencedFiles(FileScanTask task) { + return Stream.concat(Stream.of(task.file()), task.deletes().stream()); + } + + @Override + protected CloseableIterator open(FileScanTask task) { + String filePath = task.file().path().toString(); + LOG.debug("Opening data file {}", filePath); + + // update the current file for Spark's filename() function + InputFileBlockHolder.set(filePath, task.start(), task.length()); + + Map idToConstant = constantsMap(task, expectedSchema()); + + InputFile inputFile = getInputFile(filePath); + Preconditions.checkNotNull(inputFile, "Could not find InputFile associated with FileScanTask"); + + SparkDeleteFilter deleteFilter = + task.deletes().isEmpty() + ? 
null + : new SparkDeleteFilter(filePath, task.deletes(), counter()); + + return newBatchIterable( + inputFile, + task.file().format(), + task.start(), + task.length(), + task.residual(), + idToConstant, + deleteFilter) + .iterator(); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/ChangelogRowReader.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/ChangelogRowReader.java new file mode 100644 index 000000000000..572f955884a3 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/ChangelogRowReader.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; +import org.apache.iceberg.AddedRowsScanTask; +import org.apache.iceberg.ChangelogScanTask; +import org.apache.iceberg.ChangelogUtil; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.ContentScanTask; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.DeletedDataFileScanTask; +import org.apache.iceberg.DeletedRowsScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.rdd.InputFileBlockHolder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.expressions.JoinedRow; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.unsafe.types.UTF8String; + +class ChangelogRowReader extends BaseRowReader + implements PartitionReader { + + ChangelogRowReader(SparkInputPartition partition) { + this( + partition.table(), + partition.taskGroup(), + SnapshotUtil.schemaFor(partition.table(), partition.branch()), + partition.expectedSchema(), + partition.isCaseSensitive()); + } + + ChangelogRowReader( + Table table, + ScanTaskGroup taskGroup, + Schema tableSchema, + Schema expectedSchema, + boolean caseSensitive) { + super( + table, + taskGroup, + tableSchema, + ChangelogUtil.dropChangelogMetadata(expectedSchema), + caseSensitive); + } + + @Override + protected CloseableIterator open(ChangelogScanTask task) { + JoinedRow cdcRow = new JoinedRow(); + + cdcRow.withRight(changelogMetadata(task)); + + CloseableIterable rows = openChangelogScanTask(task); + CloseableIterable cdcRows = CloseableIterable.transform(rows, cdcRow::withLeft); + + 
return cdcRows.iterator(); + } + + private static InternalRow changelogMetadata(ChangelogScanTask task) { + InternalRow metadataRow = new GenericInternalRow(3); + + metadataRow.update(0, UTF8String.fromString(task.operation().name())); + metadataRow.update(1, task.changeOrdinal()); + metadataRow.update(2, task.commitSnapshotId()); + + return metadataRow; + } + + private CloseableIterable openChangelogScanTask(ChangelogScanTask task) { + if (task instanceof AddedRowsScanTask) { + return openAddedRowsScanTask((AddedRowsScanTask) task); + + } else if (task instanceof DeletedRowsScanTask) { + throw new UnsupportedOperationException("Deleted rows scan task is not supported yet"); + + } else if (task instanceof DeletedDataFileScanTask) { + return openDeletedDataFileScanTask((DeletedDataFileScanTask) task); + + } else { + throw new IllegalArgumentException( + "Unsupported changelog scan task type: " + task.getClass().getName()); + } + } + + CloseableIterable openAddedRowsScanTask(AddedRowsScanTask task) { + String filePath = task.file().path().toString(); + SparkDeleteFilter deletes = new SparkDeleteFilter(filePath, task.deletes(), counter()); + return deletes.filter(rows(task, deletes.requiredSchema())); + } + + private CloseableIterable openDeletedDataFileScanTask(DeletedDataFileScanTask task) { + String filePath = task.file().path().toString(); + SparkDeleteFilter deletes = new SparkDeleteFilter(filePath, task.existingDeletes(), counter()); + return deletes.filter(rows(task, deletes.requiredSchema())); + } + + private CloseableIterable rows(ContentScanTask task, Schema readSchema) { + Map idToConstant = constantsMap(task, readSchema); + + String filePath = task.file().path().toString(); + + // update the current file for Spark's filename() function + InputFileBlockHolder.set(filePath, task.start(), task.length()); + + InputFile location = getInputFile(filePath); + Preconditions.checkNotNull(location, "Could not find InputFile"); + return newIterable( + location, + task.file().format(), + task.start(), + task.length(), + task.residual(), + readSchema, + idToConstant); + } + + @Override + protected Stream> referencedFiles(ChangelogScanTask task) { + if (task instanceof AddedRowsScanTask) { + return addedRowsScanTaskFiles((AddedRowsScanTask) task); + + } else if (task instanceof DeletedRowsScanTask) { + throw new UnsupportedOperationException("Deleted rows scan task is not supported yet"); + + } else if (task instanceof DeletedDataFileScanTask) { + return deletedDataFileScanTaskFiles((DeletedDataFileScanTask) task); + + } else { + throw new IllegalArgumentException( + "Unsupported changelog scan task type: " + task.getClass().getName()); + } + } + + private static Stream> deletedDataFileScanTaskFiles(DeletedDataFileScanTask task) { + DataFile file = task.file(); + List existingDeletes = task.existingDeletes(); + return Stream.concat(Stream.of(file), existingDeletes.stream()); + } + + private static Stream> addedRowsScanTaskFiles(AddedRowsScanTask task) { + DataFile file = task.file(); + List deletes = task.deletes(); + return Stream.concat(Stream.of(file), deletes.stream()); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/EqualityDeleteRowReader.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/EqualityDeleteRowReader.java new file mode 100644 index 000000000000..f5b98a5a43bd --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/EqualityDeleteRowReader.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.spark.rdd.InputFileBlockHolder; +import org.apache.spark.sql.catalyst.InternalRow; + +public class EqualityDeleteRowReader extends RowDataReader { + public EqualityDeleteRowReader( + CombinedScanTask task, + Table table, + Schema tableSchema, + Schema expectedSchema, + boolean caseSensitive) { + super(table, task, tableSchema, expectedSchema, caseSensitive); + } + + @Override + protected CloseableIterator open(FileScanTask task) { + SparkDeleteFilter matches = + new SparkDeleteFilter(task.file().path().toString(), task.deletes(), counter()); + + // schema or rows returned by readers + Schema requiredSchema = matches.requiredSchema(); + Map idToConstant = constantsMap(task, expectedSchema()); + DataFile file = task.file(); + + // update the current file for Spark's filename() function + InputFileBlockHolder.set(file.path().toString(), task.start(), task.length()); + + return matches.findEqualityDeleteRows(open(task, requiredSchema, idToConstant)).iterator(); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/HasIcebergCatalog.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/HasIcebergCatalog.java new file mode 100644 index 000000000000..37e0c4dfcdb6 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/HasIcebergCatalog.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
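ChangelogRowReader above attaches three changelog metadata values (operation name, change ordinal, commit snapshot id) to every data row by keeping them on the right side of a JoinedRow. A small standalone sketch of that composition with made-up values; only the row classes and the three-column layout come from the patch.

    import org.apache.spark.sql.catalyst.InternalRow;
    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
    import org.apache.spark.sql.catalyst.expressions.JoinedRow;
    import org.apache.spark.unsafe.types.UTF8String;

    class ChangelogMetadataSketch {
      public static void main(String[] args) {
        // metadata columns: operation, change ordinal, commit snapshot id (values are hypothetical)
        InternalRow metadata = new GenericInternalRow(3);
        metadata.update(0, UTF8String.fromString("INSERT"));
        metadata.update(1, 0);
        metadata.update(2, 1234567890L);

        JoinedRow cdcRow = new JoinedRow();
        cdcRow.withRight(metadata); // fixed for the whole scan task

        InternalRow dataRow = new GenericInternalRow(new Object[] {UTF8String.fromString("a"), 1});
        InternalRow combined = cdcRow.withLeft(dataRow); // data columns followed by metadata columns

        System.out.println(combined.numFields()); // 5
      }
    }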
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.catalog.Catalog; +import org.apache.spark.sql.connector.catalog.TableCatalog; + +public interface HasIcebergCatalog extends TableCatalog { + + /** + * Returns the underlying {@link org.apache.iceberg.catalog.Catalog} backing this Spark Catalog + */ + Catalog icebergCatalog(); +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/IcebergSource.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/IcebergSource.java new file mode 100644 index 000000000000..8975c7f32db1 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/IcebergSource.java @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Stream; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.PathIdentifier; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkCachedTableCatalog; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.SparkTableCache; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.connector.catalog.CatalogManager; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.SupportsCatalogOptions; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.sources.DataSourceRegister; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * The IcebergSource loads/writes tables with format "iceberg". It can load paths and tables. + * + *

How paths/tables are loaded when using spark.read().format("iceberg").load(table):
 + *
 + *   table = "file:///path/to/table"            -> loads a HadoopTable at the given path
 + *   table = "tablename"                        -> loads currentCatalog.currentNamespace.tablename
 + *   table = "catalog.tablename"                -> loads "tablename" from the specified catalog
 + *   table = "namespace.tablename"              -> loads "namespace.tablename" from the current catalog
 + *   table = "catalog.namespace.tablename"      -> loads "namespace.tablename" from the specified catalog
 + *   table = "namespace1.namespace2.tablename"  -> loads "namespace1.namespace2.tablename" from the current catalog
 + *
The above list is in order of priority. For example: a matching catalog will take priority + * over any namespace resolution. + */ +public class IcebergSource implements DataSourceRegister, SupportsCatalogOptions { + private static final String DEFAULT_CATALOG_NAME = "default_iceberg"; + private static final String DEFAULT_CACHE_CATALOG_NAME = "default_cache_iceberg"; + private static final String DEFAULT_CATALOG = "spark.sql.catalog." + DEFAULT_CATALOG_NAME; + private static final String DEFAULT_CACHE_CATALOG = + "spark.sql.catalog." + DEFAULT_CACHE_CATALOG_NAME; + private static final String AT_TIMESTAMP = "at_timestamp_"; + private static final String SNAPSHOT_ID = "snapshot_id_"; + private static final String BRANCH_PREFIX = "branch_"; + private static final String TAG_PREFIX = "tag_"; + private static final String[] EMPTY_NAMESPACE = new String[0]; + + private static final SparkTableCache TABLE_CACHE = SparkTableCache.get(); + + @Override + public String shortName() { + return "iceberg"; + } + + @Override + public StructType inferSchema(CaseInsensitiveStringMap options) { + return null; + } + + @Override + public Transform[] inferPartitioning(CaseInsensitiveStringMap options) { + return getTable(null, null, options).partitioning(); + } + + @Override + public boolean supportsExternalMetadata() { + return true; + } + + @Override + public Table getTable(StructType schema, Transform[] partitioning, Map options) { + Spark3Util.CatalogAndIdentifier catalogIdentifier = + catalogAndIdentifier(new CaseInsensitiveStringMap(options)); + CatalogPlugin catalog = catalogIdentifier.catalog(); + Identifier ident = catalogIdentifier.identifier(); + + try { + if (catalog instanceof TableCatalog) { + return ((TableCatalog) catalog).loadTable(ident); + } + } catch (NoSuchTableException e) { + // throwing an iceberg NoSuchTableException because the Spark one is typed and cant be thrown + // from this interface + throw new org.apache.iceberg.exceptions.NoSuchTableException( + e, "Cannot find table for %s.", ident); + } + + // throwing an iceberg NoSuchTableException because the Spark one is typed and cant be thrown + // from this interface + throw new org.apache.iceberg.exceptions.NoSuchTableException( + "Cannot find table for %s.", ident); + } + + private Spark3Util.CatalogAndIdentifier catalogAndIdentifier(CaseInsensitiveStringMap options) { + Preconditions.checkArgument( + options.containsKey(SparkReadOptions.PATH), "Cannot open table: path is not set"); + SparkSession spark = SparkSession.active(); + setupDefaultSparkCatalogs(spark); + String path = options.get(SparkReadOptions.PATH); + + Long snapshotId = propertyAsLong(options, SparkReadOptions.SNAPSHOT_ID); + Long asOfTimestamp = propertyAsLong(options, SparkReadOptions.AS_OF_TIMESTAMP); + String branch = options.get(SparkReadOptions.BRANCH); + String tag = options.get(SparkReadOptions.TAG); + Preconditions.checkArgument( + Stream.of(snapshotId, asOfTimestamp, branch, tag).filter(Objects::nonNull).count() <= 1, + "Can specify only one of snapshot-id (%s), as-of-timestamp (%s), branch (%s), tag (%s)", + snapshotId, + asOfTimestamp, + branch, + tag); + + String selector = null; + + if (snapshotId != null) { + selector = SNAPSHOT_ID + snapshotId; + } + + if (asOfTimestamp != null) { + selector = AT_TIMESTAMP + asOfTimestamp; + } + + if (branch != null) { + selector = BRANCH_PREFIX + branch; + } + + if (tag != null) { + selector = TAG_PREFIX + tag; + } + + CatalogManager catalogManager = spark.sessionState().catalogManager(); + + if 
(TABLE_CACHE.contains(path)) { + return new Spark3Util.CatalogAndIdentifier( + catalogManager.catalog(DEFAULT_CACHE_CATALOG_NAME), + Identifier.of(EMPTY_NAMESPACE, pathWithSelector(path, selector))); + } else if (path.contains("/")) { + // contains a path. Return iceberg default catalog and a PathIdentifier + return new Spark3Util.CatalogAndIdentifier( + catalogManager.catalog(DEFAULT_CATALOG_NAME), + new PathIdentifier(pathWithSelector(path, selector))); + } + + final Spark3Util.CatalogAndIdentifier catalogAndIdentifier = + Spark3Util.catalogAndIdentifier("path or identifier", spark, path); + + Identifier ident = identifierWithSelector(catalogAndIdentifier.identifier(), selector); + if (catalogAndIdentifier.catalog().name().equals("spark_catalog") + && !(catalogAndIdentifier.catalog() instanceof SparkSessionCatalog)) { + // catalog is a session catalog but does not support Iceberg. Use Iceberg instead. + return new Spark3Util.CatalogAndIdentifier( + catalogManager.catalog(DEFAULT_CATALOG_NAME), ident); + } else { + return new Spark3Util.CatalogAndIdentifier(catalogAndIdentifier.catalog(), ident); + } + } + + private String pathWithSelector(String path, String selector) { + return (selector == null) ? path : path + "#" + selector; + } + + private Identifier identifierWithSelector(Identifier ident, String selector) { + if (selector == null) { + return ident; + } else { + String[] namespace = ident.namespace(); + String[] ns = Arrays.copyOf(namespace, namespace.length + 1); + ns[namespace.length] = ident.name(); + return Identifier.of(ns, selector); + } + } + + @Override + public Identifier extractIdentifier(CaseInsensitiveStringMap options) { + return catalogAndIdentifier(options).identifier(); + } + + @Override + public String extractCatalog(CaseInsensitiveStringMap options) { + return catalogAndIdentifier(options).catalog().name(); + } + + @Override + public Optional extractTimeTravelVersion(CaseInsensitiveStringMap options) { + return Optional.ofNullable( + PropertyUtil.propertyAsString(options, SparkReadOptions.VERSION_AS_OF, null)); + } + + @Override + public Optional extractTimeTravelTimestamp(CaseInsensitiveStringMap options) { + return Optional.ofNullable( + PropertyUtil.propertyAsString(options, SparkReadOptions.TIMESTAMP_AS_OF, null)); + } + + private static Long propertyAsLong(CaseInsensitiveStringMap options, String property) { + String value = options.get(property); + if (value != null) { + return Long.parseLong(value); + } + + return null; + } + + private static void setupDefaultSparkCatalogs(SparkSession spark) { + if (!spark.conf().contains(DEFAULT_CATALOG)) { + ImmutableMap config = + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "cache-enabled", "false" // the source should not use a cache + ); + spark.conf().set(DEFAULT_CATALOG, SparkCatalog.class.getName()); + config.forEach((key, value) -> spark.conf().set(DEFAULT_CATALOG + "." 
+ key, value)); + } + + if (!spark.conf().contains(DEFAULT_CACHE_CATALOG)) { + spark.conf().set(DEFAULT_CACHE_CATALOG, SparkCachedTableCatalog.class.getName()); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/InternalRowWrapper.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/InternalRowWrapper.java new file mode 100644 index 000000000000..524266f6f83a --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/InternalRowWrapper.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.nio.ByteBuffer; +import java.util.function.BiFunction; +import java.util.stream.Stream; +import org.apache.iceberg.StructLike; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * Class to adapt a Spark {@code InternalRow} to Iceberg {@link StructLike} for uses like {@link + * org.apache.iceberg.PartitionKey#partition(StructLike)} + */ +class InternalRowWrapper implements StructLike { + private final DataType[] types; + private final BiFunction[] getters; + private InternalRow row = null; + + @SuppressWarnings("unchecked") + InternalRowWrapper(StructType rowType) { + this.types = Stream.of(rowType.fields()).map(StructField::dataType).toArray(DataType[]::new); + this.getters = Stream.of(types).map(InternalRowWrapper::getter).toArray(BiFunction[]::new); + } + + InternalRowWrapper wrap(InternalRow internalRow) { + this.row = internalRow; + return this; + } + + @Override + public int size() { + return types.length; + } + + @Override + public T get(int pos, Class javaClass) { + if (row.isNullAt(pos)) { + return null; + } else if (getters[pos] != null) { + return javaClass.cast(getters[pos].apply(row, pos)); + } + + return javaClass.cast(row.get(pos, types[pos])); + } + + @Override + public void set(int pos, T value) { + row.update(pos, value); + } + + private static BiFunction getter(DataType type) { + if (type instanceof StringType) { + return (row, pos) -> row.getUTF8String(pos).toString(); + } else if (type instanceof DecimalType) { + DecimalType decimal = (DecimalType) type; + return (row, pos) -> + row.getDecimal(pos, decimal.precision(), decimal.scale()).toJavaBigDecimal(); + } else if (type instanceof BinaryType) { + return (row, pos) -> ByteBuffer.wrap(row.getBinary(pos)); + } else if (type instanceof StructType) { + StructType structType = (StructType) type; + InternalRowWrapper nestedWrapper = new 
InternalRowWrapper(structType); + return (row, pos) -> nestedWrapper.wrap(row.getStruct(pos, structType.size())); + } + + return null; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/PositionDeletesRowReader.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/PositionDeletesRowReader.java new file mode 100644 index 000000000000..4b847474153c --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/PositionDeletesRowReader.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.PositionDeletesScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.ExpressionUtil; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.primitives.Ints; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.rdd.InputFileBlockHolder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class PositionDeletesRowReader extends BaseRowReader + implements PartitionReader { + + private static final Logger LOG = LoggerFactory.getLogger(PositionDeletesRowReader.class); + + PositionDeletesRowReader(SparkInputPartition partition) { + this( + partition.table(), + partition.taskGroup(), + SnapshotUtil.schemaFor(partition.table(), partition.branch()), + partition.expectedSchema(), + partition.isCaseSensitive()); + } + + PositionDeletesRowReader( + Table table, + ScanTaskGroup taskGroup, + Schema tableSchema, + Schema expectedSchema, + boolean caseSensitive) { + + super(table, taskGroup, tableSchema, expectedSchema, caseSensitive); + + int numSplits = taskGroup.tasks().size(); + LOG.debug("Reading {} position delete file split(s) for table {}", numSplits, table.name()); + } + + @Override + protected Stream> referencedFiles(PositionDeletesScanTask task) { + return Stream.of(task.file()); + } + + @Override + protected CloseableIterator open(PositionDeletesScanTask task) { + String filePath = task.file().path().toString(); + LOG.debug("Opening position delete file {}", filePath); + + // update the current file for Spark's filename() function + InputFileBlockHolder.set(filePath, task.start(), task.length()); 
+ + InputFile inputFile = getInputFile(task.file().path().toString()); + Preconditions.checkNotNull(inputFile, "Could not find InputFile associated with %s", task); + + // select out constant fields when pushing down filter to row reader + Map idToConstant = constantsMap(task, expectedSchema()); + Set nonConstantFieldIds = nonConstantFieldIds(idToConstant); + Expression residualWithoutConstants = + ExpressionUtil.extractByIdInclusive( + task.residual(), expectedSchema(), caseSensitive(), Ints.toArray(nonConstantFieldIds)); + + return newIterable( + inputFile, + task.file().format(), + task.start(), + task.length(), + residualWithoutConstants, + expectedSchema(), + idToConstant) + .iterator(); + } + + private Set nonConstantFieldIds(Map idToConstant) { + Set fields = expectedSchema().idToName().keySet(); + return fields.stream() + .filter(id -> expectedSchema().findField(id).type().isPrimitiveType()) + .filter(id -> !idToConstant.containsKey(id)) + .collect(Collectors.toSet()); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java new file mode 100644 index 000000000000..9356f62f3593 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
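Relating the IcebergSource option handling shown earlier to the user-facing API: a read may set at most one of snapshot-id, as-of-timestamp, branch, or tag, and IcebergSource folds the chosen selector into the identifier it hands to the catalog. A hedged usage sketch follows; the option keys are assumed to be the string values behind the SparkReadOptions constants referenced in the patch, and the table name is made up.

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;

    class IcebergReadOptionsSketch {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().master("local[2]").getOrCreate();

        // Read a branch of a (hypothetical) table; "branch" corresponds to SparkReadOptions.BRANCH.
        Dataset<Row> branchRead =
            spark.read().format("iceberg").option("branch", "audit").load("db.events");

        // Time travel by snapshot id; only one of snapshot-id / as-of-timestamp / branch / tag
        // may be set, otherwise IcebergSource rejects the options.
        Dataset<Row> snapshotRead =
            spark.read().format("iceberg").option("snapshot-id", "1234567890").load("db.events");

        branchRead.show();
        snapshotRead.show();
      }
    }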
+ */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import java.util.stream.Stream; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.DataTask; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.source.metrics.TaskNumDeletes; +import org.apache.iceberg.spark.source.metrics.TaskNumSplits; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.rdd.InputFileBlockHolder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class RowDataReader extends BaseRowReader implements PartitionReader { + private static final Logger LOG = LoggerFactory.getLogger(RowDataReader.class); + + private final long numSplits; + + RowDataReader(SparkInputPartition partition) { + this( + partition.table(), + partition.taskGroup(), + SnapshotUtil.schemaFor(partition.table(), partition.branch()), + partition.expectedSchema(), + partition.isCaseSensitive()); + } + + RowDataReader( + Table table, + ScanTaskGroup taskGroup, + Schema tableSchema, + Schema expectedSchema, + boolean caseSensitive) { + + super(table, taskGroup, tableSchema, expectedSchema, caseSensitive); + + numSplits = taskGroup.tasks().size(); + LOG.debug("Reading {} file split(s) for table {}", numSplits, table.name()); + } + + @Override + public CustomTaskMetric[] currentMetricsValues() { + return new CustomTaskMetric[] { + new TaskNumSplits(numSplits), new TaskNumDeletes(counter().get()) + }; + } + + @Override + protected Stream> referencedFiles(FileScanTask task) { + return Stream.concat(Stream.of(task.file()), task.deletes().stream()); + } + + @Override + protected CloseableIterator open(FileScanTask task) { + String filePath = task.file().path().toString(); + LOG.debug("Opening data file {}", filePath); + SparkDeleteFilter deleteFilter = new SparkDeleteFilter(filePath, task.deletes(), counter()); + + // schema or rows returned by readers + Schema requiredSchema = deleteFilter.requiredSchema(); + Map idToConstant = constantsMap(task, requiredSchema); + + // update the current file for Spark's filename() function + InputFileBlockHolder.set(filePath, task.start(), task.length()); + + return deleteFilter.filter(open(task, requiredSchema, idToConstant)).iterator(); + } + + protected CloseableIterable open( + FileScanTask task, Schema readSchema, Map idToConstant) { + if (task.isDataTask()) { + return newDataIterable(task.asDataTask(), readSchema); + } else { + InputFile inputFile = getInputFile(task.file().path().toString()); + Preconditions.checkNotNull( + inputFile, "Could not find InputFile associated with FileScanTask"); + return newIterable( + inputFile, + task.file().format(), + task.start(), + task.length(), + task.residual(), + readSchema, + idToConstant); + } + } + + private CloseableIterable newDataIterable(DataTask task, Schema readSchema) { + StructInternalRow row = new StructInternalRow(readSchema.asStruct()); + return CloseableIterable.transform(task.asDataTask().rows(), row::setStruct); + } +} diff --git 
a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SerializableTableWithSize.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SerializableTableWithSize.java new file mode 100644 index 000000000000..e3b81cea7cd1 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SerializableTableWithSize.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.BaseMetadataTable; +import org.apache.iceberg.SerializableTable; +import org.apache.iceberg.Table; +import org.apache.spark.util.KnownSizeEstimation; + +/** + * This class provides a serializable table with a known size estimate. Spark calls its + * SizeEstimator class when broadcasting variables and this can be an expensive operation, so + * providing a known size estimate allows that operation to be skipped. + */ +public class SerializableTableWithSize extends SerializableTable implements KnownSizeEstimation { + + private static final long SIZE_ESTIMATE = 32_768L; + + protected SerializableTableWithSize(Table table) { + super(table); + } + + @Override + public long estimatedSize() { + return SIZE_ESTIMATE; + } + + public static Table copyOf(Table table) { + if (table instanceof BaseMetadataTable) { + return new SerializableMetadataTableWithSize((BaseMetadataTable) table); + } else { + return new SerializableTableWithSize(table); + } + } + + public static class SerializableMetadataTableWithSize extends SerializableMetadataTable + implements KnownSizeEstimation { + + protected SerializableMetadataTableWithSize(BaseMetadataTable metadataTable) { + super(metadataTable); + } + + @Override + public long estimatedSize() { + return SIZE_ESTIMATE; + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkAppenderFactory.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkAppenderFactory.java new file mode 100644 index 000000000000..6372edde0782 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkAppenderFactory.java @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
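SerializableTableWithSize above exists so that broadcasting table metadata can skip Spark's generic SizeEstimator walk. A minimal sketch of the broadcast pattern it is built for, with placeholder variables; SparkBatch below broadcasts the table in essentially this way.

    import org.apache.iceberg.Table;
    import org.apache.iceberg.spark.source.SerializableTableWithSize;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.broadcast.Broadcast;

    class TableBroadcastSketch {
      // Illustrative: wrap the table so executors receive a serializable, size-annotated copy.
      static Broadcast<Table> broadcastTable(JavaSparkContext sparkContext, Table table) {
        Table serializableTable = SerializableTableWithSize.copyOf(table);
        return sparkContext.broadcast(serializableTable);
      }
    }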
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Map; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.MetricsConfig; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.deletes.EqualityDeleteWriter; +import org.apache.iceberg.deletes.PositionDeleteWriter; +import org.apache.iceberg.encryption.EncryptedOutputFile; +import org.apache.iceberg.exceptions.RuntimeIOException; +import org.apache.iceberg.io.DataWriter; +import org.apache.iceberg.io.DeleteSchemaUtil; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.io.FileAppenderFactory; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.SparkAvroWriter; +import org.apache.iceberg.spark.data.SparkOrcWriter; +import org.apache.iceberg.spark.data.SparkParquetWriters; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +class SparkAppenderFactory implements FileAppenderFactory { + private final Map properties; + private final Schema writeSchema; + private final StructType dsSchema; + private final PartitionSpec spec; + private final int[] equalityFieldIds; + private final Schema eqDeleteRowSchema; + private final Schema posDeleteRowSchema; + + private StructType eqDeleteSparkType = null; + private StructType posDeleteSparkType = null; + + SparkAppenderFactory( + Map properties, + Schema writeSchema, + StructType dsSchema, + PartitionSpec spec, + int[] equalityFieldIds, + Schema eqDeleteRowSchema, + Schema posDeleteRowSchema) { + this.properties = properties; + this.writeSchema = writeSchema; + this.dsSchema = dsSchema; + this.spec = spec; + this.equalityFieldIds = equalityFieldIds; + this.eqDeleteRowSchema = eqDeleteRowSchema; + this.posDeleteRowSchema = posDeleteRowSchema; + } + + static Builder builderFor(Table table, Schema writeSchema, StructType dsSchema) { + return new Builder(table, writeSchema, dsSchema); + } + + static class Builder { + private final Table table; + private final Schema writeSchema; + private final StructType dsSchema; + private PartitionSpec spec; + private int[] equalityFieldIds; + private Schema eqDeleteRowSchema; + private Schema posDeleteRowSchema; + + Builder(Table table, Schema writeSchema, StructType dsSchema) { + this.table = table; + this.spec = table.spec(); + this.writeSchema = writeSchema; + this.dsSchema = dsSchema; + } + + Builder spec(PartitionSpec newSpec) { + this.spec = newSpec; + return this; + } + + Builder equalityFieldIds(int[] newEqualityFieldIds) { + this.equalityFieldIds = newEqualityFieldIds; + return this; + } + + Builder eqDeleteRowSchema(Schema newEqDeleteRowSchema) { + 
this.eqDeleteRowSchema = newEqDeleteRowSchema; + return this; + } + + Builder posDelRowSchema(Schema newPosDelRowSchema) { + this.posDeleteRowSchema = newPosDelRowSchema; + return this; + } + + SparkAppenderFactory build() { + Preconditions.checkNotNull(table, "Table must not be null"); + Preconditions.checkNotNull(writeSchema, "Write Schema must not be null"); + Preconditions.checkNotNull(dsSchema, "DS Schema must not be null"); + if (equalityFieldIds != null) { + Preconditions.checkNotNull( + eqDeleteRowSchema, + "Equality Field Ids and Equality Delete Row Schema" + " must be set together"); + } + if (eqDeleteRowSchema != null) { + Preconditions.checkNotNull( + equalityFieldIds, + "Equality Field Ids and Equality Delete Row Schema" + " must be set together"); + } + + return new SparkAppenderFactory( + table.properties(), + writeSchema, + dsSchema, + spec, + equalityFieldIds, + eqDeleteRowSchema, + posDeleteRowSchema); + } + } + + private StructType lazyEqDeleteSparkType() { + if (eqDeleteSparkType == null) { + Preconditions.checkNotNull(eqDeleteRowSchema, "Equality delete row schema shouldn't be null"); + this.eqDeleteSparkType = SparkSchemaUtil.convert(eqDeleteRowSchema); + } + return eqDeleteSparkType; + } + + private StructType lazyPosDeleteSparkType() { + if (posDeleteSparkType == null) { + Preconditions.checkNotNull( + posDeleteRowSchema, "Position delete row schema shouldn't be null"); + this.posDeleteSparkType = SparkSchemaUtil.convert(posDeleteRowSchema); + } + return posDeleteSparkType; + } + + @Override + public FileAppender newAppender(OutputFile file, FileFormat fileFormat) { + MetricsConfig metricsConfig = MetricsConfig.fromProperties(properties); + try { + switch (fileFormat) { + case PARQUET: + return Parquet.write(file) + .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(dsSchema, msgType)) + .setAll(properties) + .metricsConfig(metricsConfig) + .schema(writeSchema) + .overwrite() + .build(); + + case AVRO: + return Avro.write(file) + .createWriterFunc(ignored -> new SparkAvroWriter(dsSchema)) + .setAll(properties) + .schema(writeSchema) + .overwrite() + .build(); + + case ORC: + return ORC.write(file) + .createWriterFunc(SparkOrcWriter::new) + .setAll(properties) + .metricsConfig(metricsConfig) + .schema(writeSchema) + .overwrite() + .build(); + + default: + throw new UnsupportedOperationException("Cannot write unknown format: " + fileFormat); + } + } catch (IOException e) { + throw new RuntimeIOException(e); + } + } + + @Override + public DataWriter newDataWriter( + EncryptedOutputFile file, FileFormat format, StructLike partition) { + return new DataWriter<>( + newAppender(file.encryptingOutputFile(), format), + format, + file.encryptingOutputFile().location(), + spec, + partition, + file.keyMetadata()); + } + + @Override + public EqualityDeleteWriter newEqDeleteWriter( + EncryptedOutputFile file, FileFormat format, StructLike partition) { + Preconditions.checkState( + equalityFieldIds != null && equalityFieldIds.length > 0, + "Equality field ids shouldn't be null or empty when creating equality-delete writer"); + Preconditions.checkNotNull( + eqDeleteRowSchema, + "Equality delete row schema shouldn't be null when creating equality-delete writer"); + + try { + switch (format) { + case PARQUET: + return Parquet.writeDeletes(file.encryptingOutputFile()) + .createWriterFunc( + msgType -> SparkParquetWriters.buildWriter(lazyEqDeleteSparkType(), msgType)) + .overwrite() + .rowSchema(eqDeleteRowSchema) + .withSpec(spec) + .withPartition(partition) + 
.equalityFieldIds(equalityFieldIds) + .withKeyMetadata(file.keyMetadata()) + .buildEqualityWriter(); + + case AVRO: + return Avro.writeDeletes(file.encryptingOutputFile()) + .createWriterFunc(ignored -> new SparkAvroWriter(lazyEqDeleteSparkType())) + .overwrite() + .rowSchema(eqDeleteRowSchema) + .withSpec(spec) + .withPartition(partition) + .equalityFieldIds(equalityFieldIds) + .withKeyMetadata(file.keyMetadata()) + .buildEqualityWriter(); + + case ORC: + return ORC.writeDeletes(file.encryptingOutputFile()) + .createWriterFunc(SparkOrcWriter::new) + .overwrite() + .rowSchema(eqDeleteRowSchema) + .withSpec(spec) + .withPartition(partition) + .equalityFieldIds(equalityFieldIds) + .withKeyMetadata(file.keyMetadata()) + .buildEqualityWriter(); + + default: + throw new UnsupportedOperationException( + "Cannot write equality-deletes for unsupported file format: " + format); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to create new equality delete writer", e); + } + } + + @Override + public PositionDeleteWriter newPosDeleteWriter( + EncryptedOutputFile file, FileFormat format, StructLike partition) { + try { + switch (format) { + case PARQUET: + StructType sparkPosDeleteSchema = + SparkSchemaUtil.convert(DeleteSchemaUtil.posDeleteSchema(posDeleteRowSchema)); + return Parquet.writeDeletes(file.encryptingOutputFile()) + .createWriterFunc( + msgType -> SparkParquetWriters.buildWriter(sparkPosDeleteSchema, msgType)) + .overwrite() + .rowSchema(posDeleteRowSchema) + .withSpec(spec) + .withPartition(partition) + .withKeyMetadata(file.keyMetadata()) + .transformPaths(path -> UTF8String.fromString(path.toString())) + .buildPositionWriter(); + + case AVRO: + return Avro.writeDeletes(file.encryptingOutputFile()) + .createWriterFunc(ignored -> new SparkAvroWriter(lazyPosDeleteSparkType())) + .overwrite() + .rowSchema(posDeleteRowSchema) + .withSpec(spec) + .withPartition(partition) + .withKeyMetadata(file.keyMetadata()) + .buildPositionWriter(); + + case ORC: + return ORC.writeDeletes(file.encryptingOutputFile()) + .createWriterFunc(SparkOrcWriter::new) + .overwrite() + .rowSchema(posDeleteRowSchema) + .withSpec(spec) + .withPartition(partition) + .withKeyMetadata(file.keyMetadata()) + .transformPaths(path -> UTF8String.fromString(path.toString())) + .buildPositionWriter(); + + default: + throw new UnsupportedOperationException( + "Cannot write pos-deletes for unsupported file format: " + format); + } + + } catch (IOException e) { + throw new UncheckedIOException("Failed to create new equality delete writer", e); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java new file mode 100644 index 000000000000..63aef25ba9b1 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
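A minimal usage sketch of the appender factory above (not part of the patch), assuming the caller supplies the target Table, the Iceberg write Schema, the matching Spark StructType, an OutputFile, and a row; the InternalRow element type of the appender is an assumption based on the Spark writers the factory configures.

package org.apache.iceberg.spark.source;

import java.io.IOException;
import org.apache.iceberg.FileFormat;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Table;
import org.apache.iceberg.io.FileAppender;
import org.apache.iceberg.io.OutputFile;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.StructType;

class SparkAppenderFactoryUsageSketch {
  // Hypothetical helper: build the factory and append a single row as a Parquet data file.
  static void appendOneRow(
      Table table, Schema writeSchema, StructType dsSchema, OutputFile file, InternalRow row)
      throws IOException {
    SparkAppenderFactory factory =
        SparkAppenderFactory.builderFor(table, writeSchema, dsSchema)
            .spec(table.spec())
            .build();

    // newAppender(...) dispatches on the file format; PARQUET routes to SparkParquetWriters
    try (FileAppender<InternalRow> appender = factory.newAppender(file, FileFormat.PARQUET)) {
      appender.add(row);
    }
  }
}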
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import java.util.Objects; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SchemaParser; +import org.apache.iceberg.Table; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.Tasks; +import org.apache.iceberg.util.ThreadPools; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; + +class SparkBatch implements Batch { + + private final JavaSparkContext sparkContext; + private final Table table; + private final String branch; + private final SparkReadConf readConf; + private final Types.StructType groupingKeyType; + private final List> taskGroups; + private final Schema expectedSchema; + private final boolean caseSensitive; + private final boolean localityEnabled; + private final int scanHashCode; + + SparkBatch( + JavaSparkContext sparkContext, + Table table, + SparkReadConf readConf, + Types.StructType groupingKeyType, + List> taskGroups, + Schema expectedSchema, + int scanHashCode) { + this.sparkContext = sparkContext; + this.table = table; + this.branch = readConf.branch(); + this.readConf = readConf; + this.groupingKeyType = groupingKeyType; + this.taskGroups = taskGroups; + this.expectedSchema = expectedSchema; + this.caseSensitive = readConf.caseSensitive(); + this.localityEnabled = readConf.localityEnabled(); + this.scanHashCode = scanHashCode; + } + + @Override + public InputPartition[] planInputPartitions() { + // broadcast the table metadata as input partitions will be sent to executors + Broadcast
<Table>
tableBroadcast = + sparkContext.broadcast(SerializableTableWithSize.copyOf(table)); + String expectedSchemaString = SchemaParser.toJson(expectedSchema); + + InputPartition[] partitions = new InputPartition[taskGroups.size()]; + + Tasks.range(partitions.length) + .stopOnFailure() + .executeWith(localityEnabled ? ThreadPools.getWorkerPool() : null) + .run( + index -> + partitions[index] = + new SparkInputPartition( + groupingKeyType, + taskGroups.get(index), + tableBroadcast, + branch, + expectedSchemaString, + caseSensitive, + localityEnabled)); + + return partitions; + } + + @Override + public PartitionReaderFactory createReaderFactory() { + if (useParquetBatchReads()) { + int batchSize = readConf.parquetBatchSize(); + return new SparkColumnarReaderFactory(batchSize); + + } else if (useOrcBatchReads()) { + int batchSize = readConf.orcBatchSize(); + return new SparkColumnarReaderFactory(batchSize); + + } else { + return new SparkRowReaderFactory(); + } + } + + // conditions for using Parquet batch reads: + // - Parquet vectorization is enabled + // - at least one column is projected + // - only primitives are projected + // - all tasks are of FileScanTask type and read only Parquet files + private boolean useParquetBatchReads() { + return readConf.parquetVectorizationEnabled() + && expectedSchema.columns().size() > 0 + && expectedSchema.columns().stream().allMatch(c -> c.type().isPrimitiveType()) + && taskGroups.stream().allMatch(this::supportsParquetBatchReads); + } + + private boolean supportsParquetBatchReads(ScanTask task) { + if (task instanceof ScanTaskGroup) { + ScanTaskGroup taskGroup = (ScanTaskGroup) task; + return taskGroup.tasks().stream().allMatch(this::supportsParquetBatchReads); + + } else if (task.isFileScanTask() && !task.isDataTask()) { + FileScanTask fileScanTask = task.asFileScanTask(); + return fileScanTask.file().format() == FileFormat.PARQUET; + + } else { + return false; + } + } + + // conditions for using ORC batch reads: + // - ORC vectorization is enabled + // - all tasks are of type FileScanTask and read only ORC files with no delete files + private boolean useOrcBatchReads() { + return readConf.orcVectorizationEnabled() + && taskGroups.stream().allMatch(this::supportsOrcBatchReads); + } + + private boolean supportsOrcBatchReads(ScanTask task) { + if (task instanceof ScanTaskGroup) { + ScanTaskGroup taskGroup = (ScanTaskGroup) task; + return taskGroup.tasks().stream().allMatch(this::supportsOrcBatchReads); + + } else if (task.isFileScanTask() && !task.isDataTask()) { + FileScanTask fileScanTask = task.asFileScanTask(); + return fileScanTask.file().format() == FileFormat.ORC && fileScanTask.deletes().isEmpty(); + + } else { + return false; + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + SparkBatch that = (SparkBatch) o; + return table.name().equals(that.table.name()) && scanHashCode == that.scanHashCode; + } + + @Override + public int hashCode() { + return Objects.hash(table.name(), scanHashCode); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java new file mode 100644 index 000000000000..dd493fbc5097 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or 
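The planInputPartitions() implementation above fills the partition array in parallel with Iceberg's Tasks utility when locality is enabled. A small standalone sketch of that idiom (not part of the patch; the slot contents are placeholders):

package org.apache.iceberg.spark.source;

import org.apache.iceberg.util.Tasks;
import org.apache.iceberg.util.ThreadPools;

class ParallelPartitionFillSketch {
  // Fill one slot per index, optionally on the shared worker pool, as planInputPartitions() does.
  static String[] fill(int numSlots, boolean parallel) {
    String[] slots = new String[numSlots];

    Tasks.range(numSlots)
        .stopOnFailure()
        .executeWith(parallel ? ThreadPools.getWorkerPool() : null)
        .run(index -> slots[index] = "partition-" + index);

    return slots;
  }
}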
more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.iceberg.PartitionField; +import org.apache.iceberg.PartitionScanTask; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Scan; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Binder; +import org.apache.iceberg.expressions.Evaluator; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.ExpressionUtil; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.expressions.Projections; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkFilters; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering; +import org.apache.spark.sql.sources.Filter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class SparkBatchQueryScan extends SparkPartitioningAwareScan + implements SupportsRuntimeFiltering { + + private static final Logger LOG = LoggerFactory.getLogger(SparkBatchQueryScan.class); + + private final Long snapshotId; + private final Long startSnapshotId; + private final Long endSnapshotId; + private final Long asOfTimestamp; + private final String tag; + private final List runtimeFilterExpressions; + + SparkBatchQueryScan( + SparkSession spark, + Table table, + Scan> scan, + SparkReadConf readConf, + Schema expectedSchema, + List filters) { + + super(spark, table, scan, readConf, expectedSchema, filters); + + this.snapshotId = readConf.snapshotId(); + this.startSnapshotId = readConf.startSnapshotId(); + this.endSnapshotId = readConf.endSnapshotId(); + this.asOfTimestamp = readConf.asOfTimestamp(); + this.tag = readConf.tag(); + this.runtimeFilterExpressions = Lists.newArrayList(); + } + + Long snapshotId() { + return snapshotId; + } + + @Override + protected Class taskJavaClass() { + return PartitionScanTask.class; + } + + @Override + public NamedReference[] filterAttributes() { + Set 
partitionFieldSourceIds = Sets.newHashSet(); + + for (PartitionSpec spec : specs()) { + for (PartitionField field : spec.fields()) { + partitionFieldSourceIds.add(field.sourceId()); + } + } + + Map quotedNameById = SparkSchemaUtil.indexQuotedNameById(expectedSchema()); + + // the optimizer will look for an equality condition with filter attributes in a join + // as the scan has been already planned, filtering can only be done on projected attributes + // that's why only partition source fields that are part of the read schema can be reported + + return partitionFieldSourceIds.stream() + .filter(fieldId -> expectedSchema().findField(fieldId) != null) + .map(fieldId -> Spark3Util.toNamedReference(quotedNameById.get(fieldId))) + .toArray(NamedReference[]::new); + } + + @Override + public void filter(Filter[] filters) { + Expression runtimeFilterExpr = convertRuntimeFilters(filters); + + if (runtimeFilterExpr != Expressions.alwaysTrue()) { + Map evaluatorsBySpecId = Maps.newHashMap(); + + for (PartitionSpec spec : specs()) { + Expression inclusiveExpr = + Projections.inclusive(spec, caseSensitive()).project(runtimeFilterExpr); + Evaluator inclusive = new Evaluator(spec.partitionType(), inclusiveExpr); + evaluatorsBySpecId.put(spec.specId(), inclusive); + } + + List filteredTasks = + tasks().stream() + .filter( + task -> { + Evaluator evaluator = evaluatorsBySpecId.get(task.spec().specId()); + return evaluator.eval(task.partition()); + }) + .collect(Collectors.toList()); + + LOG.info( + "{} of {} task(s) for table {} matched runtime filter {}", + filteredTasks.size(), + tasks().size(), + table().name(), + ExpressionUtil.toSanitizedString(runtimeFilterExpr)); + + // don't invalidate tasks if the runtime filter had no effect to avoid planning splits again + if (filteredTasks.size() < tasks().size()) { + resetTasks(filteredTasks); + } + + // save the evaluated filter for equals/hashCode + runtimeFilterExpressions.add(runtimeFilterExpr); + } + } + + // at this moment, Spark can only pass IN filters for a single attribute + // if there are multiple filter attributes, Spark will pass two separate IN filters + private Expression convertRuntimeFilters(Filter[] filters) { + Expression runtimeFilterExpr = Expressions.alwaysTrue(); + + for (Filter filter : filters) { + Expression expr = SparkFilters.convert(filter); + if (expr != null) { + try { + Binder.bind(expectedSchema().asStruct(), expr, caseSensitive()); + runtimeFilterExpr = Expressions.and(runtimeFilterExpr, expr); + } catch (ValidationException e) { + LOG.warn("Failed to bind {} to expected schema, skipping runtime filter", expr, e); + } + } else { + LOG.warn("Unsupported runtime filter {}", filter); + } + } + + return runtimeFilterExpr; + } + + @Override + public Statistics estimateStatistics() { + if (scan() == null) { + return estimateStatistics(null); + + } else if (snapshotId != null) { + Snapshot snapshot = table().snapshot(snapshotId); + return estimateStatistics(snapshot); + + } else if (asOfTimestamp != null) { + long snapshotIdAsOfTime = SnapshotUtil.snapshotIdAsOfTime(table(), asOfTimestamp); + Snapshot snapshot = table().snapshot(snapshotIdAsOfTime); + return estimateStatistics(snapshot); + + } else if (branch() != null) { + Snapshot snapshot = table().snapshot(branch()); + return estimateStatistics(snapshot); + + } else if (tag != null) { + Snapshot snapshot = table().snapshot(tag); + return estimateStatistics(snapshot); + + } else { + Snapshot snapshot = table().currentSnapshot(); + return estimateStatistics(snapshot); + } + } + + 
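A condensed sketch of the conversion step above (not part of the patch): a Spark IN filter on a reported attribute is translated to an Iceberg expression, bound against the read schema, and ANDed into the runtime filter. The column name and values in the example are hypothetical.

package org.apache.iceberg.spark.source;

import org.apache.iceberg.Schema;
import org.apache.iceberg.expressions.Binder;
import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.expressions.Expressions;
import org.apache.iceberg.spark.SparkFilters;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.In;

class RuntimeFilterSketch {
  // Convert Spark runtime filters into a single Iceberg expression, skipping anything
  // that cannot be converted (SparkFilters.convert returns null in that case).
  static Expression toIcebergFilter(Schema readSchema, Filter[] filters, boolean caseSensitive) {
    Expression result = Expressions.alwaysTrue();

    for (Filter filter : filters) {
      Expression expr = SparkFilters.convert(filter);
      if (expr != null) {
        Binder.bind(readSchema.asStruct(), expr, caseSensitive); // validates the reference
        result = Expressions.and(result, expr);
      }
    }

    return result;
  }

  // Example: the kind of filter Spark pushes for dynamic partition pruning on dept_id.
  static Expression example(Schema readSchema) {
    Filter[] filters = {new In("dept_id", new Object[] {1, 2, 3})};
    return toIcebergFilter(readSchema, filters, true);
  }
}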
@Override + @SuppressWarnings("checkstyle:CyclomaticComplexity") + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + SparkBatchQueryScan that = (SparkBatchQueryScan) o; + return table().name().equals(that.table().name()) + && Objects.equals(branch(), that.branch()) + && readSchema().equals(that.readSchema()) // compare Spark schemas to ignore field ids + && filterExpressions().toString().equals(that.filterExpressions().toString()) + && runtimeFilterExpressions.toString().equals(that.runtimeFilterExpressions.toString()) + && Objects.equals(snapshotId, that.snapshotId) + && Objects.equals(startSnapshotId, that.startSnapshotId) + && Objects.equals(endSnapshotId, that.endSnapshotId) + && Objects.equals(asOfTimestamp, that.asOfTimestamp) + && Objects.equals(tag, that.tag); + } + + @Override + public int hashCode() { + return Objects.hash( + table().name(), + branch(), + readSchema(), + filterExpressions().toString(), + runtimeFilterExpressions.toString(), + snapshotId, + startSnapshotId, + endSnapshotId, + asOfTimestamp, + tag); + } + + @Override + public String toString() { + return String.format( + "IcebergScan(table=%s, branch=%s, type=%s, filters=%s, runtimeFilters=%s, caseSensitive=%s)", + table(), + branch(), + expectedSchema().asStruct(), + filterExpressions(), + runtimeFilterExpressions, + caseSensitive()); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkChangelogScan.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkChangelogScan.java new file mode 100644 index 000000000000..54fdd186d473 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkChangelogScan.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import org.apache.iceberg.ChangelogScanTask; +import org.apache.iceberg.IncrementalChangelogScan; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.connector.read.SupportsReportStatistics; +import org.apache.spark.sql.types.StructType; + +class SparkChangelogScan implements Scan, SupportsReportStatistics { + + private static final Types.StructType EMPTY_GROUPING_KEY_TYPE = Types.StructType.of(); + + private final JavaSparkContext sparkContext; + private final Table table; + private final IncrementalChangelogScan scan; + private final SparkReadConf readConf; + private final Schema expectedSchema; + private final List filters; + private final Long startSnapshotId; + private final Long endSnapshotId; + private final boolean readTimestampWithoutZone; + + // lazy variables + private List> taskGroups = null; + private StructType expectedSparkType = null; + + SparkChangelogScan( + SparkSession spark, + Table table, + IncrementalChangelogScan scan, + SparkReadConf readConf, + Schema expectedSchema, + List filters) { + + SparkSchemaUtil.validateMetadataColumnReferences(table.schema(), expectedSchema); + + this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + this.table = table; + this.scan = scan; + this.readConf = readConf; + this.expectedSchema = expectedSchema; + this.filters = filters != null ? 
filters : Collections.emptyList(); + this.startSnapshotId = readConf.startSnapshotId(); + this.endSnapshotId = readConf.endSnapshotId(); + this.readTimestampWithoutZone = readConf.handleTimestampWithoutZone(); + } + + @Override + public Statistics estimateStatistics() { + long rowsCount = taskGroups().stream().mapToLong(ScanTaskGroup::estimatedRowsCount).sum(); + long sizeInBytes = SparkSchemaUtil.estimateSize(readSchema(), rowsCount); + return new Stats(sizeInBytes, rowsCount); + } + + @Override + public StructType readSchema() { + if (expectedSparkType == null) { + Preconditions.checkArgument( + readTimestampWithoutZone || !SparkUtil.hasTimestampWithoutZone(expectedSchema), + SparkUtil.TIMESTAMP_WITHOUT_TIMEZONE_ERROR); + + this.expectedSparkType = SparkSchemaUtil.convert(expectedSchema); + } + + return expectedSparkType; + } + + @Override + public Batch toBatch() { + return new SparkBatch( + sparkContext, + table, + readConf, + EMPTY_GROUPING_KEY_TYPE, + taskGroups(), + expectedSchema, + hashCode()); + } + + private List> taskGroups() { + if (taskGroups == null) { + try (CloseableIterable> groups = scan.planTasks()) { + this.taskGroups = Lists.newArrayList(groups); + } catch (IOException e) { + throw new UncheckedIOException("Failed to close changelog scan: " + scan, e); + } + } + + return taskGroups; + } + + @Override + public String description() { + return String.format( + "%s [fromSnapshotId=%d, toSnapshotId=%d, filters=%s]", + table, startSnapshotId, endSnapshotId, Spark3Util.describe(filters)); + } + + @Override + public String toString() { + return String.format( + "IcebergChangelogScan(table=%s, type=%s, fromSnapshotId=%d, toSnapshotId=%d, filters=%s)", + table, + expectedSchema.asStruct(), + startSnapshotId, + endSnapshotId, + Spark3Util.describe(filters)); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + SparkChangelogScan that = (SparkChangelogScan) o; + return table.name().equals(that.table.name()) + && readSchema().equals(that.readSchema()) // compare Spark schemas to ignore field IDs + && filters.toString().equals(that.filters.toString()) + && Objects.equals(startSnapshotId, that.startSnapshotId) + && Objects.equals(endSnapshotId, that.endSnapshotId); + } + + @Override + public int hashCode() { + return Objects.hash( + table.name(), readSchema(), filters.toString(), startSnapshotId, endSnapshotId); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkChangelogTable.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkChangelogTable.java new file mode 100644 index 000000000000..61611a08c4d4 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkChangelogTable.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
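The changelog scan above plans its task groups lazily and materializes them exactly once; a short sketch of that pattern and of the row-count estimate derived from it (not part of the patch):

package org.apache.iceberg.spark.source;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;
import org.apache.iceberg.ChangelogScanTask;
import org.apache.iceberg.IncrementalChangelogScan;
import org.apache.iceberg.ScanTaskGroup;
import org.apache.iceberg.io.CloseableIterable;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;

class ChangelogPlanningSketch {
  // Materialize the task groups once; planTasks() returns a CloseableIterable that must be closed.
  static List<ScanTaskGroup<ChangelogScanTask>> planOnce(IncrementalChangelogScan scan) {
    try (CloseableIterable<ScanTaskGroup<ChangelogScanTask>> groups = scan.planTasks()) {
      return Lists.newArrayList(groups);
    } catch (IOException e) {
      throw new UncheckedIOException("Failed to close changelog scan: " + scan, e);
    }
  }

  // Row count reported to Spark: the sum of the planned groups' estimates.
  static long estimatedRowCount(List<ScanTaskGroup<ChangelogScanTask>> groups) {
    return groups.stream().mapToLong(ScanTaskGroup::estimatedRowsCount).sum();
  }
}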
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Set; +import org.apache.iceberg.ChangelogUtil; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.catalog.MetadataColumn; +import org.apache.spark.sql.connector.catalog.SupportsMetadataColumns; +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +public class SparkChangelogTable implements Table, SupportsRead, SupportsMetadataColumns { + + public static final String TABLE_NAME = "changes"; + + private static final Set CAPABILITIES = + ImmutableSet.of(TableCapability.BATCH_READ); + + private final org.apache.iceberg.Table icebergTable; + private final boolean refreshEagerly; + + private SparkSession lazySpark = null; + private StructType lazyTableSparkType = null; + private Schema lazyChangelogSchema = null; + + public SparkChangelogTable(org.apache.iceberg.Table icebergTable, boolean refreshEagerly) { + this.icebergTable = icebergTable; + this.refreshEagerly = refreshEagerly; + } + + @Override + public String name() { + return icebergTable.name() + "." 
+ TABLE_NAME; + } + + @Override + public StructType schema() { + if (lazyTableSparkType == null) { + this.lazyTableSparkType = SparkSchemaUtil.convert(changelogSchema()); + } + + return lazyTableSparkType; + } + + @Override + public Set capabilities() { + return CAPABILITIES; + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + if (refreshEagerly) { + icebergTable.refresh(); + } + + return new SparkScanBuilder(spark(), icebergTable, changelogSchema(), options) { + @Override + public Scan build() { + return buildChangelogScan(); + } + }; + } + + private Schema changelogSchema() { + if (lazyChangelogSchema == null) { + this.lazyChangelogSchema = ChangelogUtil.changelogSchema(icebergTable.schema()); + } + + return lazyChangelogSchema; + } + + private SparkSession spark() { + if (lazySpark == null) { + this.lazySpark = SparkSession.active(); + } + + return lazySpark; + } + + @Override + public MetadataColumn[] metadataColumns() { + DataType sparkPartitionType = SparkSchemaUtil.convert(Partitioning.partitionType(icebergTable)); + return new MetadataColumn[] { + new SparkMetadataColumn(MetadataColumns.SPEC_ID.name(), DataTypes.IntegerType, false), + new SparkMetadataColumn(MetadataColumns.PARTITION_COLUMN_NAME, sparkPartitionType, true), + new SparkMetadataColumn(MetadataColumns.FILE_PATH.name(), DataTypes.StringType, false), + new SparkMetadataColumn(MetadataColumns.ROW_POSITION.name(), DataTypes.LongType, false), + new SparkMetadataColumn(MetadataColumns.IS_DELETED.name(), DataTypes.BooleanType, false) + }; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkCleanupUtil.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkCleanupUtil.java new file mode 100644 index 000000000000..a103a5003222 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkCleanupUtil.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.exceptions.NotFoundException; +import org.apache.iceberg.io.BulkDeletionFailureException; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.SupportsBulkOperations; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.util.Tasks; +import org.apache.iceberg.util.ThreadPools; +import org.apache.spark.TaskContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** A utility for cleaning up written but not committed files. 
*/ +class SparkCleanupUtil { + + private static final Logger LOG = LoggerFactory.getLogger(SparkCleanupUtil.class); + + private static final int DELETE_NUM_RETRIES = 3; + private static final int DELETE_MIN_RETRY_WAIT_MS = 100; // 100 ms + private static final int DELETE_MAX_RETRY_WAIT_MS = 30 * 1000; // 30 seconds + private static final int DELETE_TOTAL_RETRY_TIME_MS = 2 * 60 * 1000; // 2 minutes + + private SparkCleanupUtil() {} + + /** + * Attempts to delete as many files produced by a task as possible. + * + *
<p>
Note this method will log Spark task info and is supposed to be called only on executors. + * Use {@link #deleteFiles(String, FileIO, List)} to delete files on the driver. + * + * @param io a {@link FileIO} instance used for deleting files + * @param files a list of files to delete + */ + public static void deleteTaskFiles(FileIO io, List> files) { + deleteFiles(taskInfo(), io, files); + } + + // the format matches what Spark uses for internal logging + private static String taskInfo() { + TaskContext taskContext = TaskContext.get(); + if (taskContext == null) { + return "unknown task"; + } else { + return String.format( + "partition %d (task %d, attempt %d, stage %d.%d)", + taskContext.partitionId(), + taskContext.taskAttemptId(), + taskContext.attemptNumber(), + taskContext.stageId(), + taskContext.stageAttemptNumber()); + } + } + + /** + * Attempts to delete as many given files as possible. + * + * @param context a helpful description of the operation invoking this method + * @param io a {@link FileIO} instance used for deleting files + * @param files a list of files to delete + */ + public static void deleteFiles(String context, FileIO io, List> files) { + List paths = Lists.transform(files, file -> file.path().toString()); + deletePaths(context, io, paths); + } + + private static void deletePaths(String context, FileIO io, List paths) { + if (io instanceof SupportsBulkOperations) { + SupportsBulkOperations bulkIO = (SupportsBulkOperations) io; + bulkDelete(context, bulkIO, paths); + } else { + delete(context, io, paths); + } + } + + private static void bulkDelete(String context, SupportsBulkOperations io, List paths) { + try { + io.deleteFiles(paths); + LOG.info("Deleted {} file(s) using bulk deletes ({})", paths.size(), context); + + } catch (BulkDeletionFailureException e) { + int deletedFilesCount = paths.size() - e.numberFailedObjects(); + LOG.warn( + "Deleted only {} of {} file(s) using bulk deletes ({})", + deletedFilesCount, + paths.size(), + context); + } + } + + private static void delete(String context, FileIO io, List paths) { + AtomicInteger deletedFilesCount = new AtomicInteger(0); + + Tasks.foreach(paths) + .executeWith(ThreadPools.getWorkerPool()) + .stopRetryOn(NotFoundException.class) + .suppressFailureWhenFinished() + .onFailure((path, exc) -> LOG.warn("Failed to delete {} ({})", path, context, exc)) + .retry(DELETE_NUM_RETRIES) + .exponentialBackoff( + DELETE_MIN_RETRY_WAIT_MS, + DELETE_MAX_RETRY_WAIT_MS, + DELETE_TOTAL_RETRY_TIME_MS, + 2 /* exponential */) + .run( + path -> { + io.deleteFile(path); + deletedFilesCount.incrementAndGet(); + }); + + if (deletedFilesCount.get() < paths.size()) { + LOG.warn("Deleted only {} of {} file(s) ({})", deletedFilesCount, paths.size(), context); + } else { + LOG.info("Deleted {} file(s) ({})", paths.size(), context); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkColumnarReaderFactory.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkColumnarReaderFactory.java new file mode 100644 index 000000000000..655e20a50e11 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkColumnarReaderFactory.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
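A usage sketch of the cleanup helper above (not part of the patch): on abort, a writer can hand its uncommitted files to deleteTaskFiles, which uses bulk deletion when the FileIO supports it and otherwise falls back to retried per-file deletes. The DataFile list here is assumed to come from the task's write result.

package org.apache.iceberg.spark.source;

import java.util.List;
import org.apache.iceberg.DataFile;
import org.apache.iceberg.io.FileIO;

class CleanupUsageSketch {
  // Hypothetical abort path: drop whatever this task wrote but did not commit.
  static void abortTask(FileIO io, List<DataFile> writtenButUncommitted) {
    SparkCleanupUtil.deleteTaskFiles(io, writtenButUncommitted);
  }
}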
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +class SparkColumnarReaderFactory implements PartitionReaderFactory { + private final int batchSize; + + SparkColumnarReaderFactory(int batchSize) { + Preconditions.checkArgument(batchSize > 1, "Batch size must be > 1"); + this.batchSize = batchSize; + } + + @Override + public PartitionReader createReader(InputPartition inputPartition) { + throw new UnsupportedOperationException("Row-based reads are not supported"); + } + + @Override + public PartitionReader createColumnarReader(InputPartition inputPartition) { + Preconditions.checkArgument( + inputPartition instanceof SparkInputPartition, + "Unknown input partition type: %s", + inputPartition.getClass().getName()); + + SparkInputPartition partition = (SparkInputPartition) inputPartition; + + if (partition.allTasksOfType(FileScanTask.class)) { + return new BatchDataReader(partition, batchSize); + + } else { + throw new UnsupportedOperationException( + "Unsupported task group for columnar reads: " + partition.taskGroup()); + } + } + + @Override + public boolean supportColumnarReads(InputPartition inputPartition) { + return true; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java new file mode 100644 index 000000000000..4fca05345a2e --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE; + +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Table; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.RowLevelOperation; +import org.apache.spark.sql.connector.write.RowLevelOperationInfo; +import org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +class SparkCopyOnWriteOperation implements RowLevelOperation { + + private final SparkSession spark; + private final Table table; + private final String branch; + private final Command command; + private final IsolationLevel isolationLevel; + + // lazy vars + private ScanBuilder lazyScanBuilder; + private Scan configuredScan; + private WriteBuilder lazyWriteBuilder; + + SparkCopyOnWriteOperation( + SparkSession spark, + Table table, + String branch, + RowLevelOperationInfo info, + IsolationLevel isolationLevel) { + this.spark = spark; + this.table = table; + this.branch = branch; + this.command = info.command(); + this.isolationLevel = isolationLevel; + } + + @Override + public Command command() { + return command; + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + if (lazyScanBuilder == null) { + lazyScanBuilder = + new SparkScanBuilder(spark, table, branch, options) { + @Override + public Scan build() { + Scan scan = super.buildCopyOnWriteScan(); + SparkCopyOnWriteOperation.this.configuredScan = scan; + return scan; + } + }; + } + + return lazyScanBuilder; + } + + @Override + public WriteBuilder newWriteBuilder(LogicalWriteInfo info) { + if (lazyWriteBuilder == null) { + SparkWriteBuilder writeBuilder = new SparkWriteBuilder(spark, table, branch, info); + lazyWriteBuilder = writeBuilder.overwriteFiles(configuredScan, command, isolationLevel); + } + + return lazyWriteBuilder; + } + + @Override + public NamedReference[] requiredMetadataAttributes() { + NamedReference file = Expressions.column(MetadataColumns.FILE_PATH.name()); + NamedReference pos = Expressions.column(MetadataColumns.ROW_POSITION.name()); + + if (command == DELETE || command == UPDATE) { + return new NamedReference[] {file, pos}; + } else { + return new NamedReference[] {file}; + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java new file mode 100644 index 000000000000..d978b81e67bd --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
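A sketch of the copy-on-write handshake above (not part of the patch): Spark builds the scan first, and the write builder is created afterwards so the operation can hand the configured scan to the writer for conflict detection. All inputs are assumed to come from Spark's row-level operation planning.

package org.apache.iceberg.spark.source;

import org.apache.spark.sql.connector.read.Scan;
import org.apache.spark.sql.connector.write.LogicalWriteInfo;
import org.apache.spark.sql.connector.write.RowLevelOperation;
import org.apache.spark.sql.connector.write.WriteBuilder;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;

class CopyOnWriteWiringSketch {
  // Mirrors the call order assumed above: building the scan first lets the operation
  // capture it before newWriteBuilder(...) configures the overwrite.
  static WriteBuilder plan(
      RowLevelOperation operation, CaseInsensitiveStringMap options, LogicalWriteInfo info) {
    Scan configuredScan = operation.newScanBuilder(options).build(); // captured by the operation
    return operation.newWriteBuilder(info);
  }
}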
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.iceberg.BatchScan; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.In; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class SparkCopyOnWriteScan extends SparkPartitioningAwareScan + implements SupportsRuntimeFiltering { + + private static final Logger LOG = LoggerFactory.getLogger(SparkCopyOnWriteScan.class); + + private final Snapshot snapshot; + private Set filteredLocations = null; + + SparkCopyOnWriteScan( + SparkSession spark, + Table table, + SparkReadConf readConf, + Schema expectedSchema, + List filters) { + this(spark, table, null, null, readConf, expectedSchema, filters); + } + + SparkCopyOnWriteScan( + SparkSession spark, + Table table, + BatchScan scan, + Snapshot snapshot, + SparkReadConf readConf, + Schema expectedSchema, + List filters) { + + super(spark, table, scan, readConf, expectedSchema, filters); + + this.snapshot = snapshot; + + if (scan == null) { + this.filteredLocations = Collections.emptySet(); + } + } + + Long snapshotId() { + return snapshot != null ? snapshot.snapshotId() : null; + } + + @Override + protected Class taskJavaClass() { + return FileScanTask.class; + } + + @Override + public Statistics estimateStatistics() { + return estimateStatistics(snapshot); + } + + public NamedReference[] filterAttributes() { + NamedReference file = Expressions.column(MetadataColumns.FILE_PATH.name()); + return new NamedReference[] {file}; + } + + @Override + public void filter(Filter[] filters) { + Preconditions.checkState( + Objects.equals(snapshotId(), currentSnapshotId()), + "Runtime file filtering is not possible: the table has been concurrently modified. " + + "Row-level operation scan snapshot ID: %s, current table snapshot ID: %s. 
" + + "If multiple threads modify the table, use independent Spark sessions in each thread.", + snapshotId(), + currentSnapshotId()); + + for (Filter filter : filters) { + // Spark can only pass In filters at the moment + if (filter instanceof In + && ((In) filter).attribute().equalsIgnoreCase(MetadataColumns.FILE_PATH.name())) { + In in = (In) filter; + + Set fileLocations = Sets.newHashSet(); + for (Object value : in.values()) { + fileLocations.add((String) value); + } + + // Spark may call this multiple times for UPDATEs with subqueries + // as such cases are rewritten using UNION and the same scan on both sides + // so filter files only if it is beneficial + if (filteredLocations == null || fileLocations.size() < filteredLocations.size()) { + this.filteredLocations = fileLocations; + List filteredTasks = + tasks().stream() + .filter(file -> fileLocations.contains(file.file().path().toString())) + .collect(Collectors.toList()); + + LOG.info( + "{} of {} task(s) for table {} matched runtime file filter with {} location(s)", + filteredTasks.size(), + tasks().size(), + table().name(), + fileLocations.size()); + + resetTasks(filteredTasks); + } + } else { + LOG.warn("Unsupported runtime filter {}", filter); + } + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + SparkCopyOnWriteScan that = (SparkCopyOnWriteScan) o; + return table().name().equals(that.table().name()) + && readSchema().equals(that.readSchema()) // compare Spark schemas to ignore field ids + && filterExpressions().toString().equals(that.filterExpressions().toString()) + && Objects.equals(snapshotId(), that.snapshotId()) + && Objects.equals(filteredLocations, that.filteredLocations); + } + + @Override + public int hashCode() { + return Objects.hash( + table().name(), + readSchema(), + filterExpressions().toString(), + snapshotId(), + filteredLocations); + } + + @Override + public String toString() { + return String.format( + "IcebergCopyOnWriteScan(table=%s, type=%s, filters=%s, caseSensitive=%s)", + table(), expectedSchema().asStruct(), filterExpressions(), caseSensitive()); + } + + private Long currentSnapshotId() { + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table(), branch()); + return currentSnapshot != null ? currentSnapshot.snapshotId() : null; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java new file mode 100644 index 000000000000..d90dc5dafc59 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.MetadataColumns.DELETE_FILE_ROW_FIELD_NAME; +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT_DEFAULT; +import static org.apache.iceberg.TableProperties.DELETE_DEFAULT_FILE_FORMAT; + +import java.util.Map; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.data.BaseFileWriterFactory; +import org.apache.iceberg.io.DeleteSchemaUtil; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.SparkAvroWriter; +import org.apache.iceberg.spark.data.SparkOrcWriter; +import org.apache.iceberg.spark.data.SparkParquetWriters; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +class SparkFileWriterFactory extends BaseFileWriterFactory { + private StructType dataSparkType; + private StructType equalityDeleteSparkType; + private StructType positionDeleteSparkType; + + SparkFileWriterFactory( + Table table, + FileFormat dataFileFormat, + Schema dataSchema, + StructType dataSparkType, + SortOrder dataSortOrder, + FileFormat deleteFileFormat, + int[] equalityFieldIds, + Schema equalityDeleteRowSchema, + StructType equalityDeleteSparkType, + SortOrder equalityDeleteSortOrder, + Schema positionDeleteRowSchema, + StructType positionDeleteSparkType) { + + super( + table, + dataFileFormat, + dataSchema, + dataSortOrder, + deleteFileFormat, + equalityFieldIds, + equalityDeleteRowSchema, + equalityDeleteSortOrder, + positionDeleteRowSchema); + + this.dataSparkType = dataSparkType; + this.equalityDeleteSparkType = equalityDeleteSparkType; + this.positionDeleteSparkType = positionDeleteSparkType; + } + + static Builder builderFor(Table table) { + return new Builder(table); + } + + @Override + protected void configureDataWrite(Avro.DataWriteBuilder builder) { + builder.createWriterFunc(ignored -> new SparkAvroWriter(dataSparkType())); + } + + @Override + protected void configureEqualityDelete(Avro.DeleteWriteBuilder builder) { + builder.createWriterFunc(ignored -> new SparkAvroWriter(equalityDeleteSparkType())); + } + + @Override + protected void configurePositionDelete(Avro.DeleteWriteBuilder builder) { + boolean withRow = + positionDeleteSparkType().getFieldIndex(DELETE_FILE_ROW_FIELD_NAME).isDefined(); + if (withRow) { + // SparkAvroWriter accepts just the Spark type of the row ignoring the path and pos + StructField rowField = positionDeleteSparkType().apply(DELETE_FILE_ROW_FIELD_NAME); + StructType positionDeleteRowSparkType = (StructType) rowField.dataType(); + builder.createWriterFunc(ignored -> new SparkAvroWriter(positionDeleteRowSparkType)); + } + } + + @Override + protected void configureDataWrite(Parquet.DataWriteBuilder builder) { + builder.createWriterFunc(msgType -> SparkParquetWriters.buildWriter(dataSparkType(), msgType)); + } + + @Override + protected void configureEqualityDelete(Parquet.DeleteWriteBuilder builder) { + 
builder.createWriterFunc( + msgType -> SparkParquetWriters.buildWriter(equalityDeleteSparkType(), msgType)); + } + + @Override + protected void configurePositionDelete(Parquet.DeleteWriteBuilder builder) { + builder.createWriterFunc( + msgType -> SparkParquetWriters.buildWriter(positionDeleteSparkType(), msgType)); + builder.transformPaths(path -> UTF8String.fromString(path.toString())); + } + + @Override + protected void configureDataWrite(ORC.DataWriteBuilder builder) { + builder.createWriterFunc(SparkOrcWriter::new); + } + + @Override + protected void configureEqualityDelete(ORC.DeleteWriteBuilder builder) { + builder.createWriterFunc(SparkOrcWriter::new); + } + + @Override + protected void configurePositionDelete(ORC.DeleteWriteBuilder builder) { + builder.createWriterFunc(SparkOrcWriter::new); + builder.transformPaths(path -> UTF8String.fromString(path.toString())); + } + + private StructType dataSparkType() { + if (dataSparkType == null) { + Preconditions.checkNotNull(dataSchema(), "Data schema must not be null"); + this.dataSparkType = SparkSchemaUtil.convert(dataSchema()); + } + + return dataSparkType; + } + + private StructType equalityDeleteSparkType() { + if (equalityDeleteSparkType == null) { + Preconditions.checkNotNull( + equalityDeleteRowSchema(), "Equality delete schema must not be null"); + this.equalityDeleteSparkType = SparkSchemaUtil.convert(equalityDeleteRowSchema()); + } + + return equalityDeleteSparkType; + } + + private StructType positionDeleteSparkType() { + if (positionDeleteSparkType == null) { + // wrap the optional row schema into the position delete schema containing path and position + Schema positionDeleteSchema = DeleteSchemaUtil.posDeleteSchema(positionDeleteRowSchema()); + this.positionDeleteSparkType = SparkSchemaUtil.convert(positionDeleteSchema); + } + + return positionDeleteSparkType; + } + + static class Builder { + private final Table table; + private FileFormat dataFileFormat; + private Schema dataSchema; + private StructType dataSparkType; + private SortOrder dataSortOrder; + private FileFormat deleteFileFormat; + private int[] equalityFieldIds; + private Schema equalityDeleteRowSchema; + private StructType equalityDeleteSparkType; + private SortOrder equalityDeleteSortOrder; + private Schema positionDeleteRowSchema; + private StructType positionDeleteSparkType; + + Builder(Table table) { + this.table = table; + + Map properties = table.properties(); + + String dataFileFormatName = + properties.getOrDefault(DEFAULT_FILE_FORMAT, DEFAULT_FILE_FORMAT_DEFAULT); + this.dataFileFormat = FileFormat.fromString(dataFileFormatName); + + String deleteFileFormatName = + properties.getOrDefault(DELETE_DEFAULT_FILE_FORMAT, dataFileFormatName); + this.deleteFileFormat = FileFormat.fromString(deleteFileFormatName); + } + + Builder dataFileFormat(FileFormat newDataFileFormat) { + this.dataFileFormat = newDataFileFormat; + return this; + } + + Builder dataSchema(Schema newDataSchema) { + this.dataSchema = newDataSchema; + return this; + } + + Builder dataSparkType(StructType newDataSparkType) { + this.dataSparkType = newDataSparkType; + return this; + } + + Builder dataSortOrder(SortOrder newDataSortOrder) { + this.dataSortOrder = newDataSortOrder; + return this; + } + + Builder deleteFileFormat(FileFormat newDeleteFileFormat) { + this.deleteFileFormat = newDeleteFileFormat; + return this; + } + + Builder equalityFieldIds(int[] newEqualityFieldIds) { + this.equalityFieldIds = newEqualityFieldIds; + return this; + } + + Builder equalityDeleteRowSchema(Schema 
newEqualityDeleteRowSchema) { + this.equalityDeleteRowSchema = newEqualityDeleteRowSchema; + return this; + } + + Builder equalityDeleteSparkType(StructType newEqualityDeleteSparkType) { + this.equalityDeleteSparkType = newEqualityDeleteSparkType; + return this; + } + + Builder equalityDeleteSortOrder(SortOrder newEqualityDeleteSortOrder) { + this.equalityDeleteSortOrder = newEqualityDeleteSortOrder; + return this; + } + + Builder positionDeleteRowSchema(Schema newPositionDeleteRowSchema) { + this.positionDeleteRowSchema = newPositionDeleteRowSchema; + return this; + } + + Builder positionDeleteSparkType(StructType newPositionDeleteSparkType) { + this.positionDeleteSparkType = newPositionDeleteSparkType; + return this; + } + + SparkFileWriterFactory build() { + boolean noEqualityDeleteConf = equalityFieldIds == null && equalityDeleteRowSchema == null; + boolean fullEqualityDeleteConf = equalityFieldIds != null && equalityDeleteRowSchema != null; + Preconditions.checkArgument( + noEqualityDeleteConf || fullEqualityDeleteConf, + "Equality field IDs and equality delete row schema must be set together"); + + return new SparkFileWriterFactory( + table, + dataFileFormat, + dataSchema, + dataSparkType, + dataSortOrder, + deleteFileFormat, + equalityFieldIds, + equalityDeleteRowSchema, + equalityDeleteSparkType, + equalityDeleteSortOrder, + positionDeleteRowSchema, + positionDeleteSparkType); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java new file mode 100644 index 000000000000..0394b691e152 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.io.Serializable; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SchemaParser; +import org.apache.iceberg.Table; +import org.apache.iceberg.hadoop.HadoopInputFile; +import org.apache.iceberg.hadoop.Util; +import org.apache.iceberg.types.Types; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.read.HasPartitionKey; +import org.apache.spark.sql.connector.read.InputPartition; + +class SparkInputPartition implements InputPartition, HasPartitionKey, Serializable { + private final Types.StructType groupingKeyType; + private final ScanTaskGroup taskGroup; + private final Broadcast
<Table>
tableBroadcast; + private final String branch; + private final String expectedSchemaString; + private final boolean caseSensitive; + + private transient Schema expectedSchema = null; + private transient String[] preferredLocations = null; + + SparkInputPartition( + Types.StructType groupingKeyType, + ScanTaskGroup taskGroup, + Broadcast
tableBroadcast, + String branch, + String expectedSchemaString, + boolean caseSensitive, + boolean localityPreferred) { + this.groupingKeyType = groupingKeyType; + this.taskGroup = taskGroup; + this.tableBroadcast = tableBroadcast; + this.branch = branch; + this.expectedSchemaString = expectedSchemaString; + this.caseSensitive = caseSensitive; + if (localityPreferred) { + Table table = tableBroadcast.value(); + this.preferredLocations = Util.blockLocations(table.io(), taskGroup); + } else { + this.preferredLocations = HadoopInputFile.NO_LOCATION_PREFERENCE; + } + } + + @Override + public String[] preferredLocations() { + return preferredLocations; + } + + @Override + public InternalRow partitionKey() { + return new StructInternalRow(groupingKeyType).setStruct(taskGroup.groupingKey()); + } + + @SuppressWarnings("unchecked") + public ScanTaskGroup taskGroup() { + return (ScanTaskGroup) taskGroup; + } + + public boolean allTasksOfType(Class javaClass) { + return taskGroup.tasks().stream().allMatch(javaClass::isInstance); + } + + public Table table() { + return tableBroadcast.value(); + } + + public String branch() { + return branch; + } + + public boolean isCaseSensitive() { + return caseSensitive; + } + + public Schema expectedSchema() { + if (expectedSchema == null) { + this.expectedSchema = SchemaParser.fromJson(expectedSchemaString); + } + + return expectedSchema; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkLocalScan.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkLocalScan.java new file mode 100644 index 000000000000..c2f9707775dd --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkLocalScan.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
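For orientation, here is a hedged, package-local sketch of how planned task groups get wrapped into the SparkInputPartition class above; it mirrors what the micro-batch planner later in this diff does. The helper name and the assumption that sparkContext, table, branch, and expectedSchema are in scope are illustrative, not part of the change.

// Illustrative only: wrap planned task groups into Spark InputPartitions.
// Mirrors SparkMicroBatchStream.planInputPartitions further down in this diff.
static InputPartition[] toInputPartitions(
    JavaSparkContext sparkContext,
    Table table,
    List<ScanTaskGroup<FileScanTask>> taskGroups,
    String branch,
    Schema expectedSchema,
    boolean caseSensitive,
    boolean localityPreferred) {
  Broadcast<Table> tableBroadcast =
      sparkContext.broadcast(SerializableTableWithSize.copyOf(table));
  String schemaJson = SchemaParser.toJson(expectedSchema);
  InputPartition[] partitions = new InputPartition[taskGroups.size()];
  for (int i = 0; i < taskGroups.size(); i++) {
    partitions[i] =
        new SparkInputPartition(
            Types.StructType.of(), // empty grouping key, as in the streaming path
            taskGroups.get(i),
            tableBroadcast,
            branch,
            schemaJson,
            caseSensitive,
            localityPreferred);
  }
  return partitions;
}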
+ */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.read.LocalScan; +import org.apache.spark.sql.types.StructType; + +class SparkLocalScan implements LocalScan { + + private final Table table; + private final StructType readSchema; + private final InternalRow[] rows; + private final List filterExpressions; + + SparkLocalScan( + Table table, StructType readSchema, InternalRow[] rows, List filterExpressions) { + this.table = table; + this.readSchema = readSchema; + this.rows = rows; + this.filterExpressions = filterExpressions; + } + + @Override + public InternalRow[] rows() { + return rows; + } + + @Override + public StructType readSchema() { + return readSchema; + } + + @Override + public String description() { + return String.format("%s [filters=%s]", table, Spark3Util.describe(filterExpressions)); + } + + @Override + public String toString() { + return String.format( + "IcebergLocalScan(table=%s, type=%s, filters=%s)", + table, SparkSchemaUtil.convert(readSchema).asStruct(), filterExpressions); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkMetadataColumn.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkMetadataColumn.java new file mode 100644 index 000000000000..94f87c28741d --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkMetadataColumn.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
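A hedged usage sketch for the SparkLocalScan above: when a query's result can be fully computed on the driver, the pre-built rows are handed to Spark through the LocalScan interface and no distributed job is launched. The row values, the GenericInternalRow usage, and the assumption that table, readSchema, and filters are in scope are placeholders for illustration.

// Illustrative only: serve driver-side rows through Spark's LocalScan interface.
InternalRow[] precomputedRows =
    new InternalRow[] {
      new GenericInternalRow(new Object[] {UTF8String.fromString("append"), 3L})
    };
LocalScan localScan = new SparkLocalScan(table, readSchema, precomputedRows, filters);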
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.spark.sql.connector.catalog.MetadataColumn; +import org.apache.spark.sql.types.DataType; + +public class SparkMetadataColumn implements MetadataColumn { + + private final String name; + private final DataType dataType; + private final boolean isNullable; + + public SparkMetadataColumn(String name, DataType dataType, boolean isNullable) { + this.name = name; + this.dataType = dataType; + this.isNullable = isNullable; + } + + @Override + public String name() { + return name; + } + + @Override + public DataType dataType() { + return dataType; + } + + @Override + public boolean isNullable() { + return isNullable; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java new file mode 100644 index 000000000000..6e03dd69a850 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
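As a hedged illustration of the SparkMetadataColumn class above, Iceberg metadata columns such as _spec_id and _file can be described to Spark like this; the nullability flags and the choice of columns are illustrative, since the actual set a table exposes is defined elsewhere in this module.

// Illustrative only: describe two Iceberg metadata columns to Spark.
MetadataColumn specId =
    new SparkMetadataColumn(MetadataColumns.SPEC_ID.name(), DataTypes.IntegerType, true);
MetadataColumn filePath =
    new SparkMetadataColumn(MetadataColumns.FILE_PATH.name(), DataTypes.StringType, false);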
+ */ +package org.apache.iceberg.spark.source; + +import java.io.BufferedWriter; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.io.UncheckedIOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Locale; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.DataOperations; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.MicroBatches; +import org.apache.iceberg.MicroBatches.MicroBatch; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SchemaParser; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.iceberg.util.TableScanUtil; +import org.apache.iceberg.util.Tasks; +import org.apache.iceberg.util.ThreadPools; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.connector.read.streaming.MicroBatchStream; +import org.apache.spark.sql.connector.read.streaming.Offset; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SparkMicroBatchStream implements MicroBatchStream { + private static final Joiner SLASH = Joiner.on("/"); + private static final Logger LOG = LoggerFactory.getLogger(SparkMicroBatchStream.class); + private static final Types.StructType EMPTY_GROUPING_KEY_TYPE = Types.StructType.of(); + + private final Table table; + private final String branch; + private final boolean caseSensitive; + private final String expectedSchema; + private final Broadcast
tableBroadcast; + private final Long splitSize; + private final Integer splitLookback; + private final Long splitOpenFileCost; + private final boolean localityPreferred; + private final StreamingOffset initialOffset; + private final boolean skipDelete; + private final boolean skipOverwrite; + private final Long fromTimestamp; + + SparkMicroBatchStream( + JavaSparkContext sparkContext, + Table table, + SparkReadConf readConf, + Schema expectedSchema, + String checkpointLocation) { + this.table = table; + this.branch = readConf.branch(); + this.caseSensitive = readConf.caseSensitive(); + this.expectedSchema = SchemaParser.toJson(expectedSchema); + this.localityPreferred = readConf.localityEnabled(); + this.tableBroadcast = sparkContext.broadcast(SerializableTableWithSize.copyOf(table)); + this.splitSize = readConf.splitSize(); + this.splitLookback = readConf.splitLookback(); + this.splitOpenFileCost = readConf.splitOpenFileCost(); + this.fromTimestamp = readConf.streamFromTimestamp(); + + InitialOffsetStore initialOffsetStore = + new InitialOffsetStore(table, checkpointLocation, fromTimestamp); + this.initialOffset = initialOffsetStore.initialOffset(); + + this.skipDelete = readConf.streamingSkipDeleteSnapshots(); + this.skipOverwrite = readConf.streamingSkipOverwriteSnapshots(); + } + + @Override + public Offset latestOffset() { + table.refresh(); + if (table.currentSnapshot() == null) { + return StreamingOffset.START_OFFSET; + } + + if (table.currentSnapshot().timestampMillis() < fromTimestamp) { + return StreamingOffset.START_OFFSET; + } + + Snapshot latestSnapshot = table.currentSnapshot(); + long addedFilesCount = + PropertyUtil.propertyAsLong(latestSnapshot.summary(), SnapshotSummary.ADDED_FILES_PROP, -1); + // if the latest snapshot summary doesn't contain SnapshotSummary.ADDED_FILES_PROP, + // iterate through addedDataFiles to compute addedFilesCount + addedFilesCount = + addedFilesCount == -1 + ? Iterables.size(latestSnapshot.addedDataFiles(table.io())) + : addedFilesCount; + + return new StreamingOffset(latestSnapshot.snapshotId(), addedFilesCount, false); + } + + @Override + public InputPartition[] planInputPartitions(Offset start, Offset end) { + Preconditions.checkArgument( + end instanceof StreamingOffset, "Invalid end offset: %s is not a StreamingOffset", end); + Preconditions.checkArgument( + start instanceof StreamingOffset, + "Invalid start offset: %s is not a StreamingOffset", + start); + + if (end.equals(StreamingOffset.START_OFFSET)) { + return new InputPartition[0]; + } + + StreamingOffset endOffset = (StreamingOffset) end; + StreamingOffset startOffset = (StreamingOffset) start; + + List fileScanTasks = planFiles(startOffset, endOffset); + + CloseableIterable splitTasks = + TableScanUtil.splitFiles(CloseableIterable.withNoopClose(fileScanTasks), splitSize); + List combinedScanTasks = + Lists.newArrayList( + TableScanUtil.planTasks(splitTasks, splitSize, splitLookback, splitOpenFileCost)); + + InputPartition[] partitions = new InputPartition[combinedScanTasks.size()]; + + Tasks.range(partitions.length) + .stopOnFailure() + .executeWith(localityPreferred ? 
ThreadPools.getWorkerPool() : null) + .run( + index -> + partitions[index] = + new SparkInputPartition( + EMPTY_GROUPING_KEY_TYPE, + combinedScanTasks.get(index), + tableBroadcast, + branch, + expectedSchema, + caseSensitive, + localityPreferred)); + + return partitions; + } + + @Override + public PartitionReaderFactory createReaderFactory() { + return new SparkRowReaderFactory(); + } + + @Override + public Offset initialOffset() { + return initialOffset; + } + + @Override + public Offset deserializeOffset(String json) { + return StreamingOffset.fromJson(json); + } + + @Override + public void commit(Offset end) {} + + @Override + public void stop() {} + + private List planFiles(StreamingOffset startOffset, StreamingOffset endOffset) { + List fileScanTasks = Lists.newArrayList(); + StreamingOffset batchStartOffset = + StreamingOffset.START_OFFSET.equals(startOffset) + ? determineStartingOffset(table, fromTimestamp) + : startOffset; + + StreamingOffset currentOffset = null; + + do { + if (currentOffset == null) { + currentOffset = batchStartOffset; + } else { + Snapshot snapshotAfter = SnapshotUtil.snapshotAfter(table, currentOffset.snapshotId()); + currentOffset = new StreamingOffset(snapshotAfter.snapshotId(), 0L, false); + } + + Snapshot snapshot = table.snapshot(currentOffset.snapshotId()); + + if (snapshot == null) { + throw new IllegalStateException( + String.format( + "Cannot load current offset at snapshot %d, the snapshot was expired or removed", + currentOffset.snapshotId())); + } + + if (!shouldProcess(snapshot)) { + LOG.debug("Skipping snapshot: {} of table {}", currentOffset.snapshotId(), table.name()); + continue; + } + + MicroBatch latestMicroBatch = + MicroBatches.from(table.snapshot(currentOffset.snapshotId()), table.io()) + .caseSensitive(caseSensitive) + .specsById(table.specs()) + .generate( + currentOffset.position(), Long.MAX_VALUE, currentOffset.shouldScanAllFiles()); + + fileScanTasks.addAll(latestMicroBatch.tasks()); + } while (currentOffset.snapshotId() != endOffset.snapshotId()); + + return fileScanTasks; + } + + private boolean shouldProcess(Snapshot snapshot) { + String op = snapshot.operation(); + switch (op) { + case DataOperations.APPEND: + return true; + case DataOperations.REPLACE: + return false; + case DataOperations.DELETE: + Preconditions.checkState( + skipDelete, + "Cannot process delete snapshot: %s, to ignore deletes, set %s=true", + snapshot.snapshotId(), + SparkReadOptions.STREAMING_SKIP_DELETE_SNAPSHOTS); + return false; + case DataOperations.OVERWRITE: + Preconditions.checkState( + skipOverwrite, + "Cannot process overwrite snapshot: %s, to ignore overwrites, set %s=true", + snapshot.snapshotId(), + SparkReadOptions.STREAMING_SKIP_OVERWRITE_SNAPSHOTS); + return false; + default: + throw new IllegalStateException( + String.format( + "Cannot process unknown snapshot operation: %s (snapshot id %s)", + op.toLowerCase(Locale.ROOT), snapshot.snapshotId())); + } + } + + private static StreamingOffset determineStartingOffset(Table table, Long fromTimestamp) { + if (table.currentSnapshot() == null) { + return StreamingOffset.START_OFFSET; + } + + if (fromTimestamp == null) { + // match existing behavior and start from the oldest snapshot + return new StreamingOffset(SnapshotUtil.oldestAncestor(table).snapshotId(), 0, false); + } + + if (table.currentSnapshot().timestampMillis() < fromTimestamp) { + return StreamingOffset.START_OFFSET; + } + + try { + Snapshot snapshot = SnapshotUtil.oldestAncestorAfter(table, fromTimestamp); + if (snapshot != null) { + 
return new StreamingOffset(snapshot.snapshotId(), 0, false); + } else { + return StreamingOffset.START_OFFSET; + } + } catch (IllegalStateException e) { + // could not determine the first snapshot after the timestamp. use the oldest ancestor instead + return new StreamingOffset(SnapshotUtil.oldestAncestor(table).snapshotId(), 0, false); + } + } + + private static class InitialOffsetStore { + private final Table table; + private final FileIO io; + private final String initialOffsetLocation; + private final Long fromTimestamp; + + InitialOffsetStore(Table table, String checkpointLocation, Long fromTimestamp) { + this.table = table; + this.io = table.io(); + this.initialOffsetLocation = SLASH.join(checkpointLocation, "offsets/0"); + this.fromTimestamp = fromTimestamp; + } + + public StreamingOffset initialOffset() { + InputFile inputFile = io.newInputFile(initialOffsetLocation); + if (inputFile.exists()) { + return readOffset(inputFile); + } + + table.refresh(); + StreamingOffset offset = determineStartingOffset(table, fromTimestamp); + + OutputFile outputFile = io.newOutputFile(initialOffsetLocation); + writeOffset(offset, outputFile); + + return offset; + } + + private void writeOffset(StreamingOffset offset, OutputFile file) { + try (OutputStream outputStream = file.create()) { + BufferedWriter writer = + new BufferedWriter(new OutputStreamWriter(outputStream, StandardCharsets.UTF_8)); + writer.write(offset.json()); + writer.flush(); + } catch (IOException ioException) { + throw new UncheckedIOException( + String.format("Failed writing offset to: %s", initialOffsetLocation), ioException); + } + } + + private StreamingOffset readOffset(InputFile file) { + try (InputStream in = file.newStream()) { + return StreamingOffset.fromJson(in); + } catch (IOException ioException) { + throw new UncheckedIOException( + String.format("Failed reading offset from: %s", initialOffsetLocation), ioException); + } + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitionedFanoutWriter.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitionedFanoutWriter.java new file mode 100644 index 000000000000..f17cd260f928 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitionedFanoutWriter.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
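For context on the SparkMicroBatchStream above, here is a hedged end-user sketch of the path it serves: reading an Iceberg table as a Spark Structured Streaming source. The checkpoint location passed by Spark is where InitialOffsetStore persists offsets/0; the table name, timestamp, checkpoint path, and option strings below are illustrative and should be checked against the released Iceberg docs.

// Illustrative only: consume an Iceberg table as a micro-batch stream on Spark 3.4.
Dataset<Row> changes =
    spark
        .readStream()
        .format("iceberg")
        .option("stream-from-timestamp", String.valueOf(startTimestampMillis))
        .option("streaming-skip-delete-snapshots", "true")
        .load("db.events");

changes
    .writeStream()
    .format("console")
    .option("checkpointLocation", "/tmp/checkpoints/db_events")
    .start();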
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionKey; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileAppenderFactory; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.io.PartitionedFanoutWriter; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; + +public class SparkPartitionedFanoutWriter extends PartitionedFanoutWriter { + private final PartitionKey partitionKey; + private final InternalRowWrapper internalRowWrapper; + + public SparkPartitionedFanoutWriter( + PartitionSpec spec, + FileFormat format, + FileAppenderFactory appenderFactory, + OutputFileFactory fileFactory, + FileIO io, + long targetFileSize, + Schema schema, + StructType sparkSchema) { + super(spec, format, appenderFactory, fileFactory, io, targetFileSize); + this.partitionKey = new PartitionKey(spec, schema); + this.internalRowWrapper = new InternalRowWrapper(sparkSchema); + } + + @Override + protected PartitionKey partition(InternalRow row) { + partitionKey.partition(internalRowWrapper.wrap(row)); + return partitionKey; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitionedWriter.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitionedWriter.java new file mode 100644 index 000000000000..a86091644360 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitionedWriter.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
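A hedged sketch of how the SparkPartitionedFanoutWriter above is typically driven: each incoming row is routed by partition(row) to an open file for that partition, and the completed data files are collected after close(). The appender and output-file factories are assumed to be built elsewhere (for example, by the writer factories in this module).

// Illustrative only: write InternalRows with one open appender per partition.
static DataFile[] writePartitioned(
    PartitionSpec spec,
    FileFormat format,
    FileAppenderFactory<InternalRow> appenderFactory,
    OutputFileFactory fileFactory,
    FileIO io,
    long targetFileSizeBytes,
    Schema schema,
    StructType sparkSchema,
    Iterable<InternalRow> rows)
    throws IOException {
  SparkPartitionedFanoutWriter writer =
      new SparkPartitionedFanoutWriter(
          spec, format, appenderFactory, fileFactory, io, targetFileSizeBytes, schema, sparkSchema);
  for (InternalRow row : rows) {
    writer.write(row); // partition(row) selects the target partition key
  }
  writer.close(); // closes every open partition file
  return writer.dataFiles();
}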
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionKey; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileAppenderFactory; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.io.PartitionedWriter; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; + +public class SparkPartitionedWriter extends PartitionedWriter { + private final PartitionKey partitionKey; + private final InternalRowWrapper internalRowWrapper; + + public SparkPartitionedWriter( + PartitionSpec spec, + FileFormat format, + FileAppenderFactory appenderFactory, + OutputFileFactory fileFactory, + FileIO io, + long targetFileSize, + Schema schema, + StructType sparkSchema) { + super(spec, format, appenderFactory, fileFactory, io, targetFileSize); + this.partitionKey = new PartitionKey(spec, schema); + this.internalRowWrapper = new InternalRowWrapper(sparkSchema); + } + + @Override + protected PartitionKey partition(InternalRow row) { + partitionKey.partition(internalRowWrapper.wrap(row)); + return partitionKey; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitioningAwareScan.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitioningAwareScan.java new file mode 100644 index 000000000000..4c7a02543abe --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitioningAwareScan.java @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.apache.iceberg.BaseScanTaskGroup; +import org.apache.iceberg.PartitionField; +import org.apache.iceberg.PartitionScanTask; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Scan; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.types.Types.StructType; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.iceberg.util.TableScanUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.connector.read.SupportsReportPartitioning; +import org.apache.spark.sql.connector.read.partitioning.KeyGroupedPartitioning; +import org.apache.spark.sql.connector.read.partitioning.Partitioning; +import org.apache.spark.sql.connector.read.partitioning.UnknownPartitioning; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +abstract class SparkPartitioningAwareScan extends SparkScan + implements SupportsReportPartitioning { + + private static final Logger LOG = LoggerFactory.getLogger(SparkPartitioningAwareScan.class); + + private final Scan> scan; + private final boolean preserveDataGrouping; + + private Set specs = null; // lazy cache of scanned specs + private List tasks = null; // lazy cache of uncombined tasks + private List> taskGroups = null; // lazy cache of task groups + private StructType groupingKeyType = null; // lazy cache of the grouping key type + private Transform[] groupingKeyTransforms = null; // lazy cache of grouping key transforms + private StructLikeSet groupingKeys = null; // lazy cache of grouping keys + + SparkPartitioningAwareScan( + SparkSession spark, + Table table, + Scan> scan, + SparkReadConf readConf, + Schema expectedSchema, + List filters) { + + super(spark, table, readConf, expectedSchema, filters); + + this.scan = scan; + this.preserveDataGrouping = readConf.preserveDataGrouping(); + + if (scan == null) { + this.specs = Collections.emptySet(); + this.tasks = Collections.emptyList(); + this.taskGroups = Collections.emptyList(); + } + } + + protected abstract Class taskJavaClass(); + + protected Scan> scan() { + return scan; + } + + @Override + public Partitioning outputPartitioning() { + if (groupingKeyType().fields().isEmpty()) { + LOG.info( + "Reporting UnknownPartitioning with {} partition(s) for table {}", + taskGroups().size(), + table().name()); + return new UnknownPartitioning(taskGroups().size()); + } else { + LOG.info( + "Reporting KeyGroupedPartitioning by {} with {} partition(s) for table {}", + groupingKeyTransforms(), + taskGroups().size(), + table().name()); + return new KeyGroupedPartitioning(groupingKeyTransforms(), taskGroups().size()); + } + } + + 
@Override + protected StructType groupingKeyType() { + if (groupingKeyType == null) { + if (preserveDataGrouping) { + this.groupingKeyType = computeGroupingKeyType(); + } else { + this.groupingKeyType = StructType.of(); + } + } + + return groupingKeyType; + } + + private StructType computeGroupingKeyType() { + return org.apache.iceberg.Partitioning.groupingKeyType(expectedSchema(), specs()); + } + + private Transform[] groupingKeyTransforms() { + if (groupingKeyTransforms == null) { + Map fieldsById = indexFieldsById(specs()); + + List groupingKeyFields = + groupingKeyType().fields().stream() + .map(field -> fieldsById.get(field.fieldId())) + .collect(Collectors.toList()); + + Schema schema = SnapshotUtil.schemaFor(table(), branch()); + this.groupingKeyTransforms = Spark3Util.toTransforms(schema, groupingKeyFields); + } + + return groupingKeyTransforms; + } + + private Map indexFieldsById(Iterable specIterable) { + Map fieldsById = Maps.newHashMap(); + + for (PartitionSpec spec : specIterable) { + for (PartitionField field : spec.fields()) { + fieldsById.putIfAbsent(field.fieldId(), field); + } + } + + return fieldsById; + } + + protected Set specs() { + if (specs == null) { + // avoid calling equals/hashCode on specs as those methods are relatively expensive + IntStream specIds = tasks().stream().mapToInt(task -> task.spec().specId()).distinct(); + this.specs = specIds.mapToObj(id -> table().specs().get(id)).collect(Collectors.toSet()); + } + + return specs; + } + + protected synchronized List tasks() { + if (tasks == null) { + try (CloseableIterable taskIterable = scan.planFiles()) { + List plannedTasks = Lists.newArrayList(); + + for (ScanTask task : taskIterable) { + ValidationException.check( + taskJavaClass().isInstance(task), + "Unsupported task type, expected a subtype of %s: %", + taskJavaClass().getName(), + task.getClass().getName()); + + plannedTasks.add(taskJavaClass().cast(task)); + } + + this.tasks = plannedTasks; + } catch (IOException e) { + throw new UncheckedIOException("Failed to close scan: " + scan, e); + } + } + + return tasks; + } + + @Override + protected synchronized List> taskGroups() { + if (taskGroups == null) { + if (groupingKeyType().fields().isEmpty()) { + CloseableIterable> plannedTaskGroups = + TableScanUtil.planTaskGroups( + CloseableIterable.withNoopClose(tasks()), + scan.targetSplitSize(), + scan.splitLookback(), + scan.splitOpenFileCost()); + this.taskGroups = Lists.newArrayList(plannedTaskGroups); + + LOG.debug( + "Planned {} task group(s) without data grouping for table {}", + taskGroups.size(), + table().name()); + + } else { + List> plannedTaskGroups = + TableScanUtil.planTaskGroups( + tasks(), + scan.targetSplitSize(), + scan.splitLookback(), + scan.splitOpenFileCost(), + groupingKeyType()); + StructLikeSet plannedGroupingKeys = collectGroupingKeys(plannedTaskGroups); + + LOG.debug( + "Planned {} task group(s) with {} grouping key type and {} unique grouping key(s) for table {}", + plannedTaskGroups.size(), + groupingKeyType(), + plannedGroupingKeys.size(), + table().name()); + + // task groups may be planned multiple times because of runtime filtering + // the number of task groups may change but the set of grouping keys must stay same + // if grouping keys are not null, this planning happens after runtime filtering + // so an empty task group must be added for each filtered out grouping key + + if (groupingKeys == null) { + this.taskGroups = plannedTaskGroups; + this.groupingKeys = plannedGroupingKeys; + + } else { + StructLikeSet 
missingGroupingKeys = StructLikeSet.create(groupingKeyType()); + + for (StructLike groupingKey : groupingKeys) { + if (!plannedGroupingKeys.contains(groupingKey)) { + missingGroupingKeys.add(groupingKey); + } + } + + LOG.debug( + "{} grouping key(s) were filtered out at runtime for table {}", + missingGroupingKeys.size(), + table().name()); + + for (StructLike groupingKey : missingGroupingKeys) { + plannedTaskGroups.add(new BaseScanTaskGroup<>(groupingKey, Collections.emptyList())); + } + + this.taskGroups = plannedTaskGroups; + } + } + } + + return taskGroups; + } + + // only task groups can be reset while resetting tasks + // the set of scanned specs, grouping key type, grouping keys must never change + protected void resetTasks(List filteredTasks) { + this.taskGroups = null; + this.tasks = filteredTasks; + } + + private StructLikeSet collectGroupingKeys(Iterable> taskGroupIterable) { + StructLikeSet keys = StructLikeSet.create(groupingKeyType()); + + for (ScanTaskGroup taskGroup : taskGroupIterable) { + keys.add(taskGroup.groupingKey()); + } + + return keys; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeletesRewrite.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeletesRewrite.java new file mode 100644 index 000000000000..0aebb6bdb2fd --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeletesRewrite.java @@ -0,0 +1,413 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
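A hedged configuration note for the SparkPartitioningAwareScan above: it only reports KeyGroupedPartitioning when data grouping is preserved (readConf.preserveDataGrouping()) and Spark 3.4's v2 bucketing is enabled, which is what storage-partitioned joins build on. The exact property names below are assumptions drawn from the Iceberg and Spark configs involved; verify them against the released documentation.

// Illustrative only: enable key-grouped (storage-partitioned) planning.
spark.conf().set("spark.sql.sources.v2.bucketing.enabled", "true");
// Assumed Iceberg session conf behind SparkReadConf.preserveDataGrouping():
spark.conf().set("spark.sql.iceberg.planning.preserve-data-grouping", "true");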
+ */ +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.PositionDeletesTable; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.io.ClusteredPositionDeleteWriter; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.PositionDeletesRewriteCoordinator; +import org.apache.iceberg.spark.ScanTaskSetManager; +import org.apache.iceberg.spark.SparkWriteConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.write.BatchWrite; +import org.apache.spark.sql.connector.write.DataWriter; +import org.apache.spark.sql.connector.write.DataWriterFactory; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.PhysicalWriteInfo; +import org.apache.spark.sql.connector.write.Write; +import org.apache.spark.sql.connector.write.WriterCommitMessage; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * {@link Write} class for rewriting position delete files from Spark. Responsible for creating + * {@link PositionDeleteBatchWrite}. + * + *
<p>
This class is meant to be used for an action to rewrite position delete files. Hence, it + * assumes all position deletes to rewrite have come from {@link ScanTaskSetManager} and that all + * have the same partition spec id and partition values. + */ +public class SparkPositionDeletesRewrite implements Write { + + private final JavaSparkContext sparkContext; + private final Table table; + private final String queryId; + private final FileFormat format; + private final long targetFileSize; + private final Schema writeSchema; + private final StructType dsSchema; + private final String fileSetId; + private final int specId; + private final StructLike partition; + + /** + * Constructs a {@link SparkPositionDeletesRewrite}. + * + * @param spark Spark session + * @param table instance of {@link PositionDeletesTable} + * @param writeConf Spark write config + * @param writeInfo Spark write info + * @param writeSchema Iceberg output schema + * @param dsSchema schema of original incoming position deletes dataset + * @param specId spec id of position deletes + * @param partition partition value of position deletes + */ + SparkPositionDeletesRewrite( + SparkSession spark, + Table table, + SparkWriteConf writeConf, + LogicalWriteInfo writeInfo, + Schema writeSchema, + StructType dsSchema, + int specId, + StructLike partition) { + this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + this.table = table; + this.queryId = writeInfo.queryId(); + this.format = writeConf.deleteFileFormat(); + this.targetFileSize = writeConf.targetDeleteFileSize(); + this.writeSchema = writeSchema; + this.dsSchema = dsSchema; + this.fileSetId = writeConf.rewrittenFileSetId(); + this.specId = specId; + this.partition = partition; + } + + @Override + public BatchWrite toBatch() { + return new PositionDeleteBatchWrite(); + } + + /** {@link BatchWrite} class for rewriting position deletes files from Spark */ + class PositionDeleteBatchWrite implements BatchWrite { + + @Override + public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) { + // broadcast the table metadata as the writer factory will be sent to executors + Broadcast
<Table>
tableBroadcast = + sparkContext.broadcast(SerializableTableWithSize.copyOf(table)); + return new PositionDeletesWriterFactory( + tableBroadcast, + queryId, + format, + targetFileSize, + writeSchema, + dsSchema, + specId, + partition); + } + + @Override + public void commit(WriterCommitMessage[] messages) { + PositionDeletesRewriteCoordinator coordinator = PositionDeletesRewriteCoordinator.get(); + coordinator.stageRewrite(table, fileSetId, ImmutableSet.copyOf(files(messages))); + } + + @Override + public void abort(WriterCommitMessage[] messages) { + SparkCleanupUtil.deleteFiles("job abort", table.io(), files(messages)); + } + + private List files(WriterCommitMessage[] messages) { + List files = Lists.newArrayList(); + + for (WriterCommitMessage message : messages) { + if (message != null) { + DeleteTaskCommit taskCommit = (DeleteTaskCommit) message; + files.addAll(Arrays.asList(taskCommit.files())); + } + } + + return files; + } + } + + /** + * Writer factory for position deletes metadata table. Responsible for creating {@link + * DeleteWriter}. + * + *
<p>
This writer is meant to be used for an action to rewrite delete files. Hence, it makes an + * assumption that all incoming deletes belong to the same partition, and that incoming dataset is + * from {@link ScanTaskSetManager}. + */ + static class PositionDeletesWriterFactory implements DataWriterFactory { + private final Broadcast
<Table>
tableBroadcast; + private final String queryId; + private final FileFormat format; + private final Long targetFileSize; + private final Schema writeSchema; + private final StructType dsSchema; + private final int specId; + private final StructLike partition; + + PositionDeletesWriterFactory( + Broadcast
tableBroadcast, + String queryId, + FileFormat format, + long targetFileSize, + Schema writeSchema, + StructType dsSchema, + int specId, + StructLike partition) { + this.tableBroadcast = tableBroadcast; + this.queryId = queryId; + this.format = format; + this.targetFileSize = targetFileSize; + this.writeSchema = writeSchema; + this.dsSchema = dsSchema; + this.specId = specId; + this.partition = partition; + } + + @Override + public DataWriter createWriter(int partitionId, long taskId) { + Table table = tableBroadcast.value(); + + OutputFileFactory deleteFileFactory = + OutputFileFactory.builderFor(table, partitionId, taskId) + .format(format) + .operationId(queryId) + .suffix("deletes") + .build(); + + Schema positionDeleteRowSchema = positionDeleteRowSchema(); + StructType deleteSparkType = deleteSparkType(); + StructType deleteSparkTypeWithoutRow = deleteSparkTypeWithoutRow(); + + SparkFileWriterFactory writerFactoryWithRow = + SparkFileWriterFactory.builderFor(table) + .deleteFileFormat(format) + .positionDeleteRowSchema(positionDeleteRowSchema) + .positionDeleteSparkType(deleteSparkType) + .build(); + SparkFileWriterFactory writerFactoryWithoutRow = + SparkFileWriterFactory.builderFor(table) + .deleteFileFormat(format) + .positionDeleteSparkType(deleteSparkTypeWithoutRow) + .build(); + + return new DeleteWriter( + table, + writerFactoryWithRow, + writerFactoryWithoutRow, + deleteFileFactory, + targetFileSize, + dsSchema, + specId, + partition); + } + + private Schema positionDeleteRowSchema() { + return new Schema( + writeSchema + .findField(MetadataColumns.DELETE_FILE_ROW_FIELD_NAME) + .type() + .asStructType() + .fields()); + } + + private StructType deleteSparkType() { + return new StructType( + new StructField[] { + dsSchema.apply(MetadataColumns.DELETE_FILE_PATH.name()), + dsSchema.apply(MetadataColumns.DELETE_FILE_POS.name()), + dsSchema.apply(MetadataColumns.DELETE_FILE_ROW_FIELD_NAME) + }); + } + + private StructType deleteSparkTypeWithoutRow() { + return new StructType( + new StructField[] { + dsSchema.apply(MetadataColumns.DELETE_FILE_PATH.name()), + dsSchema.apply(MetadataColumns.DELETE_FILE_POS.name()), + }); + } + } + + /** + * Writer for position deletes metadata table. + * + *
<p>
Iceberg specifies delete files schema as having either 'row' as a required field, or omits + * 'row' altogether. This is to ensure accuracy of delete file statistics on 'row' column. Hence, + * this writer, if receiving source position deletes with null and non-null rows, redirects rows + * with null 'row' to one file writer, and non-null 'row' to another file writer. + * + *
<p>
This writer is meant to be used for an action to rewrite delete files. Hence, it makes an + * assumption that all incoming deletes belong to the same partition. + */ + private static class DeleteWriter implements DataWriter { + private final SparkFileWriterFactory writerFactoryWithRow; + private final SparkFileWriterFactory writerFactoryWithoutRow; + private final OutputFileFactory deleteFileFactory; + private final long targetFileSize; + private final PositionDelete positionDelete; + private final FileIO io; + private final PartitionSpec spec; + private final int fileOrdinal; + private final int positionOrdinal; + private final int rowOrdinal; + private final int rowSize; + private final StructLike partition; + + private ClusteredPositionDeleteWriter writerWithRow; + private ClusteredPositionDeleteWriter writerWithoutRow; + private boolean closed = false; + + /** + * Constructs a {@link DeleteWriter}. + * + * @param table position deletes metadata table + * @param writerFactoryWithRow writer factory for deletes with non-null 'row' + * @param writerFactoryWithoutRow writer factory for deletes with null 'row' + * @param deleteFileFactory delete file factory + * @param targetFileSize target file size + * @param dsSchema schema of incoming dataset of position deletes + * @param specId partition spec id of incoming position deletes. All incoming partition deletes + * are required to have the same spec id. + * @param partition partition value of incoming position delete. All incoming partition deletes + * are required to have the same partition. + */ + DeleteWriter( + Table table, + SparkFileWriterFactory writerFactoryWithRow, + SparkFileWriterFactory writerFactoryWithoutRow, + OutputFileFactory deleteFileFactory, + long targetFileSize, + StructType dsSchema, + int specId, + StructLike partition) { + this.deleteFileFactory = deleteFileFactory; + this.targetFileSize = targetFileSize; + this.writerFactoryWithRow = writerFactoryWithRow; + this.writerFactoryWithoutRow = writerFactoryWithoutRow; + this.positionDelete = PositionDelete.create(); + this.io = table.io(); + this.spec = table.specs().get(specId); + this.partition = partition; + + this.fileOrdinal = dsSchema.fieldIndex(MetadataColumns.DELETE_FILE_PATH.name()); + this.positionOrdinal = dsSchema.fieldIndex(MetadataColumns.DELETE_FILE_POS.name()); + + this.rowOrdinal = dsSchema.fieldIndex(MetadataColumns.DELETE_FILE_ROW_FIELD_NAME); + DataType type = dsSchema.apply(MetadataColumns.DELETE_FILE_ROW_FIELD_NAME).dataType(); + Preconditions.checkArgument( + type instanceof StructType, "Expected row as struct type but was %s", type); + this.rowSize = ((StructType) type).size(); + } + + @Override + public void write(InternalRow record) throws IOException { + String file = record.getString(fileOrdinal); + long position = record.getLong(positionOrdinal); + InternalRow row = record.getStruct(rowOrdinal, rowSize); + if (row != null) { + positionDelete.set(file, position, row); + lazyWriterWithRow().write(positionDelete, spec, partition); + } else { + positionDelete.set(file, position, null); + lazyWriterWithoutRow().write(positionDelete, spec, partition); + } + } + + @Override + public WriterCommitMessage commit() throws IOException { + close(); + return new DeleteTaskCommit(allDeleteFiles()); + } + + @Override + public void abort() throws IOException { + close(); + SparkCleanupUtil.deleteTaskFiles(io, allDeleteFiles()); + } + + @Override + public void close() throws IOException { + if (!closed) { + if (writerWithRow != null) { + writerWithRow.close(); 
+ } + if (writerWithoutRow != null) { + writerWithoutRow.close(); + } + this.closed = true; + } + } + + private ClusteredPositionDeleteWriter lazyWriterWithRow() { + if (writerWithRow == null) { + this.writerWithRow = + new ClusteredPositionDeleteWriter<>( + writerFactoryWithRow, deleteFileFactory, io, targetFileSize); + } + return writerWithRow; + } + + private ClusteredPositionDeleteWriter lazyWriterWithoutRow() { + if (writerWithoutRow == null) { + this.writerWithoutRow = + new ClusteredPositionDeleteWriter<>( + writerFactoryWithoutRow, deleteFileFactory, io, targetFileSize); + } + return writerWithoutRow; + } + + private List allDeleteFiles() { + List allDeleteFiles = Lists.newArrayList(); + if (writerWithRow != null) { + allDeleteFiles.addAll(writerWithRow.result().deleteFiles()); + } + if (writerWithoutRow != null) { + allDeleteFiles.addAll(writerWithoutRow.result().deleteFiles()); + } + return allDeleteFiles; + } + } + + public static class DeleteTaskCommit implements WriterCommitMessage { + private final DeleteFile[] taskFiles; + + DeleteTaskCommit(List deleteFiles) { + this.taskFiles = deleteFiles.toArray(new DeleteFile[0]); + } + + DeleteFile[] files() { + return taskFiles; + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeletesRewriteBuilder.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeletesRewriteBuilder.java new file mode 100644 index 000000000000..cc5c987fc4cd --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeletesRewriteBuilder.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.iceberg.ContentScanTask; +import org.apache.iceberg.PositionDeletesScanTask; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.ScanTaskSetManager; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.spark.SparkWriteConf; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.Write; +import org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.types.StructType; + +/** + * Builder class for rewrites of position delete files from Spark. 
Responsible for creating {@link + * SparkPositionDeletesRewrite}. + * + *
<p>
This class is meant to be used for an action to rewrite delete files. Hence, it makes an + * assumption that all incoming deletes belong to the same partition, and that incoming dataset is + * from {@link ScanTaskSetManager}. + */ +public class SparkPositionDeletesRewriteBuilder implements WriteBuilder { + + private final SparkSession spark; + private final Table table; + private final SparkWriteConf writeConf; + private final LogicalWriteInfo writeInfo; + private final StructType dsSchema; + private final Schema writeSchema; + + SparkPositionDeletesRewriteBuilder( + SparkSession spark, Table table, String branch, LogicalWriteInfo info) { + this.spark = spark; + this.table = table; + this.writeConf = new SparkWriteConf(spark, table, branch, info.options()); + this.writeInfo = info; + this.dsSchema = info.schema(); + this.writeSchema = SparkSchemaUtil.convert(table.schema(), dsSchema, writeConf.caseSensitive()); + } + + @Override + public Write build() { + String fileSetId = writeConf.rewrittenFileSetId(); + boolean handleTimestampWithoutZone = writeConf.handleTimestampWithoutZone(); + + Preconditions.checkArgument( + fileSetId != null, "Can only write to %s via actions", table.name()); + Preconditions.checkArgument( + handleTimestampWithoutZone || !SparkUtil.hasTimestampWithoutZone(table.schema()), + SparkUtil.TIMESTAMP_WITHOUT_TIMEZONE_ERROR); + + // all files of rewrite group have same partition and spec id + ScanTaskSetManager taskSetManager = ScanTaskSetManager.get(); + List tasks = taskSetManager.fetchTasks(table, fileSetId); + Preconditions.checkArgument( + tasks != null && tasks.size() > 0, "No scan tasks found for %s", fileSetId); + + int specId = specId(fileSetId, tasks); + StructLike partition = partition(fileSetId, tasks); + + return new SparkPositionDeletesRewrite( + spark, table, writeConf, writeInfo, writeSchema, dsSchema, specId, partition); + } + + private int specId(String fileSetId, List tasks) { + Set specIds = tasks.stream().map(t -> t.spec().specId()).collect(Collectors.toSet()); + Preconditions.checkArgument( + specIds.size() == 1, + "All scan tasks of %s are expected to have same spec id, but got %s", + fileSetId, + Joiner.on(",").join(specIds)); + return tasks.get(0).spec().specId(); + } + + private StructLike partition(String fileSetId, List tasks) { + StructLikeSet partitions = StructLikeSet.create(tasks.get(0).spec().partitionType()); + tasks.stream().map(ContentScanTask::partition).forEach(partitions::add); + Preconditions.checkArgument( + partitions.size() == 1, + "All scan tasks of %s are expected to have the same partition, but got %s", + fileSetId, + Joiner.on(",").join(partitions)); + return tasks.get(0).partition(); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaOperation.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaOperation.java new file mode 100644 index 000000000000..8acd87d3cbac --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaOperation.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Table; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.write.DeltaWriteBuilder; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.RowLevelOperation; +import org.apache.spark.sql.connector.write.RowLevelOperationInfo; +import org.apache.spark.sql.connector.write.SupportsDelta; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +class SparkPositionDeltaOperation implements RowLevelOperation, SupportsDelta { + + private final SparkSession spark; + private final Table table; + private final String branch; + private final Command command; + private final IsolationLevel isolationLevel; + + // lazy vars + private ScanBuilder lazyScanBuilder; + private Scan configuredScan; + private DeltaWriteBuilder lazyWriteBuilder; + + SparkPositionDeltaOperation( + SparkSession spark, + Table table, + String branch, + RowLevelOperationInfo info, + IsolationLevel isolationLevel) { + this.spark = spark; + this.table = table; + this.branch = branch; + this.command = info.command(); + this.isolationLevel = isolationLevel; + } + + @Override + public Command command() { + return command; + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + if (lazyScanBuilder == null) { + this.lazyScanBuilder = + new SparkScanBuilder(spark, table, branch, options) { + @Override + public Scan build() { + Scan scan = super.buildMergeOnReadScan(); + SparkPositionDeltaOperation.this.configuredScan = scan; + return scan; + } + }; + } + + return lazyScanBuilder; + } + + @Override + public DeltaWriteBuilder newWriteBuilder(LogicalWriteInfo info) { + if (lazyWriteBuilder == null) { + // don't validate the scan is not null as if the condition evaluates to false, + // the optimizer replaces the original scan relation with a local relation + lazyWriteBuilder = + new SparkPositionDeltaWriteBuilder( + spark, table, branch, command, configuredScan, isolationLevel, info); + } + + return lazyWriteBuilder; + } + + @Override + public NamedReference[] requiredMetadataAttributes() { + NamedReference specId = Expressions.column(MetadataColumns.SPEC_ID.name()); + NamedReference partition = Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME); + return new NamedReference[] {specId, partition}; + } + + @Override + public NamedReference[] rowId() { + NamedReference file = Expressions.column(MetadataColumns.FILE_PATH.name()); + NamedReference pos = Expressions.column(MetadataColumns.ROW_POSITION.name()); + return new NamedReference[] {file, pos}; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWrite.java 
b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWrite.java new file mode 100644 index 000000000000..74d46339eed3 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWrite.java @@ -0,0 +1,731 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.IsolationLevel.SERIALIZABLE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE; + +import java.io.IOException; +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.PartitionKey; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.RowDelta; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.SnapshotUpdate; +import org.apache.iceberg.Table; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.BasePositionDeltaWriter; +import org.apache.iceberg.io.ClusteredDataWriter; +import org.apache.iceberg.io.ClusteredPositionDeleteWriter; +import org.apache.iceberg.io.DataWriteResult; +import org.apache.iceberg.io.DeleteWriteResult; +import org.apache.iceberg.io.FanoutDataWriter; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.io.PartitioningWriter; +import org.apache.iceberg.io.PositionDeltaWriter; +import org.apache.iceberg.io.WriteResult; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.CommitMetadata; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkWriteConf; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.CharSequenceSet; +import org.apache.iceberg.util.StructProjection; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import 
org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.connector.write.DeltaBatchWrite; +import org.apache.spark.sql.connector.write.DeltaWrite; +import org.apache.spark.sql.connector.write.DeltaWriter; +import org.apache.spark.sql.connector.write.DeltaWriterFactory; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.PhysicalWriteInfo; +import org.apache.spark.sql.connector.write.RequiresDistributionAndOrdering; +import org.apache.spark.sql.connector.write.RowLevelOperation.Command; +import org.apache.spark.sql.connector.write.WriterCommitMessage; +import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class SparkPositionDeltaWrite implements DeltaWrite, RequiresDistributionAndOrdering { + + private static final Logger LOG = LoggerFactory.getLogger(SparkPositionDeltaWrite.class); + + private final JavaSparkContext sparkContext; + private final Table table; + private final Command command; + private final SparkBatchQueryScan scan; + private final IsolationLevel isolationLevel; + private final Context context; + private final String applicationId; + private final boolean wapEnabled; + private final String wapId; + private final String branch; + private final Map extraSnapshotMetadata; + private final Distribution requiredDistribution; + private final SortOrder[] requiredOrdering; + + private boolean cleanupOnAbort = true; + + SparkPositionDeltaWrite( + SparkSession spark, + Table table, + Command command, + SparkBatchQueryScan scan, + IsolationLevel isolationLevel, + SparkWriteConf writeConf, + LogicalWriteInfo info, + Schema dataSchema, + Distribution requiredDistribution, + SortOrder[] requiredOrdering) { + this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + this.table = table; + this.command = command; + this.scan = scan; + this.isolationLevel = isolationLevel; + this.context = new Context(dataSchema, writeConf, info); + this.applicationId = spark.sparkContext().applicationId(); + this.wapEnabled = writeConf.wapEnabled(); + this.wapId = writeConf.wapId(); + this.branch = writeConf.branch(); + this.extraSnapshotMetadata = writeConf.extraSnapshotMetadata(); + this.requiredDistribution = requiredDistribution; + this.requiredOrdering = requiredOrdering; + } + + @Override + public Distribution requiredDistribution() { + return requiredDistribution; + } + + @Override + public SortOrder[] requiredOrdering() { + return requiredOrdering; + } + + @Override + public DeltaBatchWrite toBatch() { + return new PositionDeltaBatchWrite(); + } + + private class PositionDeltaBatchWrite implements DeltaBatchWrite { + + @Override + public DeltaWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) { + // broadcast the table metadata as the writer factory will be sent to executors + Broadcast
<Table>
tableBroadcast = + sparkContext.broadcast(SerializableTableWithSize.copyOf(table)); + return new PositionDeltaWriteFactory(tableBroadcast, command, context); + } + + @Override + public void commit(WriterCommitMessage[] messages) { + RowDelta rowDelta = table.newRowDelta(); + + CharSequenceSet referencedDataFiles = CharSequenceSet.empty(); + + int addedDataFilesCount = 0; + int addedDeleteFilesCount = 0; + + for (WriterCommitMessage message : messages) { + DeltaTaskCommit taskCommit = (DeltaTaskCommit) message; + + for (DataFile dataFile : taskCommit.dataFiles()) { + rowDelta.addRows(dataFile); + addedDataFilesCount += 1; + } + + for (DeleteFile deleteFile : taskCommit.deleteFiles()) { + rowDelta.addDeletes(deleteFile); + addedDeleteFilesCount += 1; + } + + referencedDataFiles.addAll(Arrays.asList(taskCommit.referencedDataFiles())); + } + + // the scan may be null if the optimizer replaces it with an empty relation + // no validation is needed in this case as the command is independent of the table state + if (scan != null) { + Expression conflictDetectionFilter = conflictDetectionFilter(scan); + rowDelta.conflictDetectionFilter(conflictDetectionFilter); + + rowDelta.validateDataFilesExist(referencedDataFiles); + + if (scan.snapshotId() != null) { + // set the read snapshot ID to check only snapshots that happened after the table was read + // otherwise, the validation will go through all snapshots present in the table + rowDelta.validateFromSnapshot(scan.snapshotId()); + } + + if (command == UPDATE || command == MERGE) { + rowDelta.validateDeletedFiles(); + rowDelta.validateNoConflictingDeleteFiles(); + } + + if (isolationLevel == SERIALIZABLE) { + rowDelta.validateNoConflictingDataFiles(); + } + + String commitMsg = + String.format( + "position delta with %d data files and %d delete files " + + "(scanSnapshotId: %d, conflictDetectionFilter: %s, isolationLevel: %s)", + addedDataFilesCount, + addedDeleteFilesCount, + scan.snapshotId(), + conflictDetectionFilter, + isolationLevel); + commitOperation(rowDelta, commitMsg); + + } else { + String commitMsg = + String.format( + "position delta with %d data files and %d delete files (no validation required)", + addedDataFilesCount, addedDeleteFilesCount); + commitOperation(rowDelta, commitMsg); + } + } + + private Expression conflictDetectionFilter(SparkBatchQueryScan queryScan) { + Expression filter = Expressions.alwaysTrue(); + + for (Expression expr : queryScan.filterExpressions()) { + filter = Expressions.and(filter, expr); + } + + return filter; + } + + @Override + public void abort(WriterCommitMessage[] messages) { + if (cleanupOnAbort) { + SparkCleanupUtil.deleteFiles("job abort", table.io(), files(messages)); + } else { + LOG.warn("Skipping cleanup of written files"); + } + } + + private List> files(WriterCommitMessage[] messages) { + List> files = Lists.newArrayList(); + + for (WriterCommitMessage message : messages) { + if (message != null) { + DeltaTaskCommit taskCommit = (DeltaTaskCommit) message; + files.addAll(Arrays.asList(taskCommit.dataFiles())); + files.addAll(Arrays.asList(taskCommit.deleteFiles())); + } + } + + return files; + } + + private void commitOperation(SnapshotUpdate operation, String description) { + LOG.info("Committing {} to table {}", description, table); + if (applicationId != null) { + operation.set("spark.app.id", applicationId); + } + + extraSnapshotMetadata.forEach(operation::set); + + if (!CommitMetadata.commitProperties().isEmpty()) { + CommitMetadata.commitProperties().forEach(operation::set); + } + + if 
(wapEnabled && wapId != null) { + // write-audit-publish is enabled for this table and job + // stage the changes without changing the current snapshot + operation.set(SnapshotSummary.STAGED_WAP_ID_PROP, wapId); + operation.stageOnly(); + } + + if (branch != null) { + operation.toBranch(branch); + } + + try { + long start = System.currentTimeMillis(); + operation.commit(); // abort is automatically called if this fails + long duration = System.currentTimeMillis() - start; + LOG.info("Committed in {} ms", duration); + } catch (CommitStateUnknownException commitStateUnknownException) { + cleanupOnAbort = false; + throw commitStateUnknownException; + } + } + } + + public static class DeltaTaskCommit implements WriterCommitMessage { + private final DataFile[] dataFiles; + private final DeleteFile[] deleteFiles; + private final CharSequence[] referencedDataFiles; + + DeltaTaskCommit(WriteResult result) { + this.dataFiles = result.dataFiles(); + this.deleteFiles = result.deleteFiles(); + this.referencedDataFiles = result.referencedDataFiles(); + } + + DeltaTaskCommit(DeleteWriteResult result) { + this.dataFiles = new DataFile[0]; + this.deleteFiles = result.deleteFiles().toArray(new DeleteFile[0]); + this.referencedDataFiles = result.referencedDataFiles().toArray(new CharSequence[0]); + } + + DataFile[] dataFiles() { + return dataFiles; + } + + DeleteFile[] deleteFiles() { + return deleteFiles; + } + + CharSequence[] referencedDataFiles() { + return referencedDataFiles; + } + } + + private static class PositionDeltaWriteFactory implements DeltaWriterFactory { + private final Broadcast
tableBroadcast; + private final Command command; + private final Context context; + + PositionDeltaWriteFactory(Broadcast
tableBroadcast, Command command, Context context) { + this.tableBroadcast = tableBroadcast; + this.command = command; + this.context = context; + } + + @Override + public DeltaWriter createWriter(int partitionId, long taskId) { + Table table = tableBroadcast.value(); + + OutputFileFactory dataFileFactory = + OutputFileFactory.builderFor(table, partitionId, taskId) + .format(context.dataFileFormat()) + .operationId(context.queryId()) + .build(); + OutputFileFactory deleteFileFactory = + OutputFileFactory.builderFor(table, partitionId, taskId) + .format(context.deleteFileFormat()) + .operationId(context.queryId()) + .suffix("deletes") + .build(); + + SparkFileWriterFactory writerFactory = + SparkFileWriterFactory.builderFor(table) + .dataFileFormat(context.dataFileFormat()) + .dataSchema(context.dataSchema()) + .dataSparkType(context.dataSparkType()) + .deleteFileFormat(context.deleteFileFormat()) + .positionDeleteSparkType(context.deleteSparkType()) + .build(); + + if (command == DELETE) { + return new DeleteOnlyDeltaWriter(table, writerFactory, deleteFileFactory, context); + + } else if (table.spec().isUnpartitioned()) { + return new UnpartitionedDeltaWriter( + table, writerFactory, dataFileFactory, deleteFileFactory, context); + + } else { + return new PartitionedDeltaWriter( + table, writerFactory, dataFileFactory, deleteFileFactory, context); + } + } + } + + private abstract static class BaseDeltaWriter implements DeltaWriter { + + protected InternalRowWrapper initPartitionRowWrapper(Types.StructType partitionType) { + StructType sparkPartitionType = (StructType) SparkSchemaUtil.convert(partitionType); + return new InternalRowWrapper(sparkPartitionType); + } + + protected Map buildPartitionProjections( + Types.StructType partitionType, Map specs) { + Map partitionProjections = Maps.newHashMap(); + specs.forEach( + (specID, spec) -> + partitionProjections.put( + specID, StructProjection.create(partitionType, spec.partitionType()))); + return partitionProjections; + } + } + + private static class DeleteOnlyDeltaWriter extends BaseDeltaWriter { + private final ClusteredPositionDeleteWriter delegate; + private final PositionDelete positionDelete; + private final FileIO io; + private final Map specs; + private final InternalRowWrapper partitionRowWrapper; + private final Map partitionProjections; + private final int specIdOrdinal; + private final int partitionOrdinal; + private final int fileOrdinal; + private final int positionOrdinal; + + private boolean closed = false; + + DeleteOnlyDeltaWriter( + Table table, + SparkFileWriterFactory writerFactory, + OutputFileFactory deleteFileFactory, + Context context) { + + this.delegate = + new ClusteredPositionDeleteWriter<>( + writerFactory, deleteFileFactory, table.io(), context.targetDeleteFileSize()); + this.positionDelete = PositionDelete.create(); + this.io = table.io(); + this.specs = table.specs(); + + Types.StructType partitionType = Partitioning.partitionType(table); + this.partitionRowWrapper = initPartitionRowWrapper(partitionType); + this.partitionProjections = buildPartitionProjections(partitionType, specs); + + this.specIdOrdinal = context.metadataSparkType().fieldIndex(MetadataColumns.SPEC_ID.name()); + this.partitionOrdinal = + context.metadataSparkType().fieldIndex(MetadataColumns.PARTITION_COLUMN_NAME); + this.fileOrdinal = context.deleteSparkType().fieldIndex(MetadataColumns.FILE_PATH.name()); + this.positionOrdinal = + context.deleteSparkType().fieldIndex(MetadataColumns.ROW_POSITION.name()); + } + + @Override + public void 
delete(InternalRow metadata, InternalRow id) throws IOException { + int specId = metadata.getInt(specIdOrdinal); + PartitionSpec spec = specs.get(specId); + + InternalRow partition = metadata.getStruct(partitionOrdinal, partitionRowWrapper.size()); + StructProjection partitionProjection = partitionProjections.get(specId); + partitionProjection.wrap(partitionRowWrapper.wrap(partition)); + + String file = id.getString(fileOrdinal); + long position = id.getLong(positionOrdinal); + positionDelete.set(file, position, null); + delegate.write(positionDelete, spec, partitionProjection); + } + + @Override + public void update(InternalRow metadata, InternalRow id, InternalRow row) { + throw new UnsupportedOperationException( + this.getClass().getName() + " does not implement update"); + } + + @Override + public void insert(InternalRow row) throws IOException { + throw new UnsupportedOperationException( + this.getClass().getName() + " does not implement insert"); + } + + @Override + public WriterCommitMessage commit() throws IOException { + close(); + + DeleteWriteResult result = delegate.result(); + return new DeltaTaskCommit(result); + } + + @Override + public void abort() throws IOException { + close(); + + DeleteWriteResult result = delegate.result(); + SparkCleanupUtil.deleteTaskFiles(io, result.deleteFiles()); + } + + @Override + public void close() throws IOException { + if (!closed) { + delegate.close(); + this.closed = true; + } + } + } + + @SuppressWarnings("checkstyle:VisibilityModifier") + private abstract static class DeleteAndDataDeltaWriter extends BaseDeltaWriter { + protected final PositionDeltaWriter delegate; + private final FileIO io; + private final Map specs; + private final InternalRowWrapper deletePartitionRowWrapper; + private final Map deletePartitionProjections; + private final int specIdOrdinal; + private final int partitionOrdinal; + private final int fileOrdinal; + private final int positionOrdinal; + + private boolean closed = false; + + DeleteAndDataDeltaWriter( + Table table, + SparkFileWriterFactory writerFactory, + OutputFileFactory dataFileFactory, + OutputFileFactory deleteFileFactory, + Context context) { + this.delegate = + new BasePositionDeltaWriter<>( + newInsertWriter(table, writerFactory, dataFileFactory, context), + newUpdateWriter(table, writerFactory, dataFileFactory, context), + newDeleteWriter(table, writerFactory, deleteFileFactory, context)); + this.io = table.io(); + this.specs = table.specs(); + + Types.StructType partitionType = Partitioning.partitionType(table); + this.deletePartitionRowWrapper = initPartitionRowWrapper(partitionType); + this.deletePartitionProjections = buildPartitionProjections(partitionType, specs); + + this.specIdOrdinal = context.metadataSparkType().fieldIndex(MetadataColumns.SPEC_ID.name()); + this.partitionOrdinal = + context.metadataSparkType().fieldIndex(MetadataColumns.PARTITION_COLUMN_NAME); + this.fileOrdinal = context.deleteSparkType().fieldIndex(MetadataColumns.FILE_PATH.name()); + this.positionOrdinal = + context.deleteSparkType().fieldIndex(MetadataColumns.ROW_POSITION.name()); + } + + @Override + public void delete(InternalRow meta, InternalRow id) throws IOException { + int specId = meta.getInt(specIdOrdinal); + PartitionSpec spec = specs.get(specId); + + InternalRow partition = meta.getStruct(partitionOrdinal, deletePartitionRowWrapper.size()); + StructProjection partitionProjection = deletePartitionProjections.get(specId); + partitionProjection.wrap(deletePartitionRowWrapper.wrap(partition)); + + String file = 
id.getString(fileOrdinal); + long position = id.getLong(positionOrdinal); + delegate.delete(file, position, spec, partitionProjection); + } + + @Override + public WriterCommitMessage commit() throws IOException { + close(); + + WriteResult result = delegate.result(); + return new DeltaTaskCommit(result); + } + + @Override + public void abort() throws IOException { + close(); + + WriteResult result = delegate.result(); + SparkCleanupUtil.deleteTaskFiles(io, files(result)); + } + + private List> files(WriteResult result) { + List> files = Lists.newArrayList(); + files.addAll(Arrays.asList(result.dataFiles())); + files.addAll(Arrays.asList(result.deleteFiles())); + return files; + } + + @Override + public void close() throws IOException { + if (!closed) { + delegate.close(); + this.closed = true; + } + } + + private PartitioningWriter newInsertWriter( + Table table, + SparkFileWriterFactory writerFactory, + OutputFileFactory fileFactory, + Context context) { + long targetFileSize = context.targetDataFileSize(); + + if (table.spec().isPartitioned() && context.fanoutWriterEnabled()) { + return new FanoutDataWriter<>(writerFactory, fileFactory, table.io(), targetFileSize); + } else { + return new ClusteredDataWriter<>(writerFactory, fileFactory, table.io(), targetFileSize); + } + } + + private PartitioningWriter newUpdateWriter( + Table table, + SparkFileWriterFactory writerFactory, + OutputFileFactory fileFactory, + Context context) { + long targetFileSize = context.targetDataFileSize(); + + if (table.spec().isPartitioned()) { + // use a fanout writer for partitioned tables to write updates as they may be out of order + return new FanoutDataWriter<>(writerFactory, fileFactory, table.io(), targetFileSize); + } else { + return new ClusteredDataWriter<>(writerFactory, fileFactory, table.io(), targetFileSize); + } + } + + private ClusteredPositionDeleteWriter newDeleteWriter( + Table table, + SparkFileWriterFactory writerFactory, + OutputFileFactory fileFactory, + Context context) { + long targetFileSize = context.targetDeleteFileSize(); + return new ClusteredPositionDeleteWriter<>( + writerFactory, fileFactory, table.io(), targetFileSize); + } + } + + private static class UnpartitionedDeltaWriter extends DeleteAndDataDeltaWriter { + private final PartitionSpec dataSpec; + + UnpartitionedDeltaWriter( + Table table, + SparkFileWriterFactory writerFactory, + OutputFileFactory dataFileFactory, + OutputFileFactory deleteFileFactory, + Context context) { + super(table, writerFactory, dataFileFactory, deleteFileFactory, context); + this.dataSpec = table.spec(); + } + + @Override + public void update(InternalRow meta, InternalRow id, InternalRow row) throws IOException { + delete(meta, id); + delegate.update(row, dataSpec, null); + } + + @Override + public void insert(InternalRow row) throws IOException { + delegate.insert(row, dataSpec, null); + } + } + + private static class PartitionedDeltaWriter extends DeleteAndDataDeltaWriter { + private final PartitionSpec dataSpec; + private final PartitionKey dataPartitionKey; + private final InternalRowWrapper internalRowDataWrapper; + + PartitionedDeltaWriter( + Table table, + SparkFileWriterFactory writerFactory, + OutputFileFactory dataFileFactory, + OutputFileFactory deleteFileFactory, + Context context) { + super(table, writerFactory, dataFileFactory, deleteFileFactory, context); + + this.dataSpec = table.spec(); + this.dataPartitionKey = new PartitionKey(dataSpec, context.dataSchema()); + this.internalRowDataWrapper = new 
InternalRowWrapper(context.dataSparkType()); + } + + @Override + public void update(InternalRow meta, InternalRow id, InternalRow row) throws IOException { + delete(meta, id); + dataPartitionKey.partition(internalRowDataWrapper.wrap(row)); + delegate.update(row, dataSpec, dataPartitionKey); + } + + @Override + public void insert(InternalRow row) throws IOException { + dataPartitionKey.partition(internalRowDataWrapper.wrap(row)); + delegate.insert(row, dataSpec, dataPartitionKey); + } + } + + // a serializable helper class for common parameters required to configure writers + private static class Context implements Serializable { + private final Schema dataSchema; + private final StructType dataSparkType; + private final FileFormat dataFileFormat; + private final long targetDataFileSize; + private final StructType deleteSparkType; + private final StructType metadataSparkType; + private final FileFormat deleteFileFormat; + private final long targetDeleteFileSize; + private final boolean fanoutWriterEnabled; + private final String queryId; + + Context(Schema dataSchema, SparkWriteConf writeConf, LogicalWriteInfo info) { + this.dataSchema = dataSchema; + this.dataSparkType = info.schema(); + this.dataFileFormat = writeConf.dataFileFormat(); + this.targetDataFileSize = writeConf.targetDataFileSize(); + this.deleteSparkType = info.rowIdSchema().get(); + this.deleteFileFormat = writeConf.deleteFileFormat(); + this.targetDeleteFileSize = writeConf.targetDeleteFileSize(); + this.metadataSparkType = info.metadataSchema().get(); + this.fanoutWriterEnabled = writeConf.fanoutWriterEnabled(); + this.queryId = info.queryId(); + } + + Schema dataSchema() { + return dataSchema; + } + + StructType dataSparkType() { + return dataSparkType; + } + + FileFormat dataFileFormat() { + return dataFileFormat; + } + + long targetDataFileSize() { + return targetDataFileSize; + } + + StructType deleteSparkType() { + return deleteSparkType; + } + + StructType metadataSparkType() { + return metadataSparkType; + } + + FileFormat deleteFileFormat() { + return deleteFileFormat; + } + + long targetDeleteFileSize() { + return targetDeleteFileSize; + } + + boolean fanoutWriterEnabled() { + return fanoutWriterEnabled; + } + + String queryId() { + return queryId; + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWriteBuilder.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWriteBuilder.java new file mode 100644 index 000000000000..8f8f64ba3157 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWriteBuilder.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
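For readers following the merge-on-read write path above, here is a minimal sketch of the delete-then-insert flow that PartitionedDeltaWriter#update performs; the helper class, parameter names, and the idea of passing the old row's file/position explicitly are hypothetical, while the delegate calls mirror the ones shown in the diff rather than adding new API.

import java.io.IOException;
import org.apache.iceberg.PartitionKey;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.StructLike;
import org.apache.iceberg.io.PositionDeltaWriter;
import org.apache.spark.sql.catalyst.InternalRow;

// Hypothetical helper; in the real writer all of these values come from the task context.
final class UpdateFlowSketch {
  private UpdateFlowSketch() {}

  static void updateAsDelta(
      PositionDeltaWriter<InternalRow> delegate,
      PartitionSpec deleteSpec, // spec of the file that holds the old row copy
      StructLike deletePartition, // its partition, projected from the _partition metadata
      PartitionSpec dataSpec, // current table spec used for new data files
      PartitionKey dataPartitionKey, // reused partition key for new rows
      StructLike wrappedNewRow, // the updated row wrapped as a StructLike
      InternalRow row, // the updated row image to write
      String oldFile, // _file of the old row copy
      long oldPos) // _pos of the old row copy
      throws IOException {
    // 1. position-delete the previous copy of the row
    delegate.delete(oldFile, oldPos, deleteSpec, deletePartition);
    // 2. derive the partition of the updated row from its data columns
    dataPartitionKey.partition(wrappedNewRow);
    // 3. write the updated row into the derived partition as the new copy
    delegate.update(row, dataSpec, dataPartitionKey);
  }
}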
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.SparkDistributionAndOrderingUtil; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.spark.SparkWriteConf; +import org.apache.iceberg.types.TypeUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.write.DeltaWrite; +import org.apache.spark.sql.connector.write.DeltaWriteBuilder; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.RowLevelOperation.Command; +import org.apache.spark.sql.types.StructType; + +class SparkPositionDeltaWriteBuilder implements DeltaWriteBuilder { + + private static final Schema EXPECTED_ROW_ID_SCHEMA = + new Schema(MetadataColumns.FILE_PATH, MetadataColumns.ROW_POSITION); + + private final SparkSession spark; + private final Table table; + private final Command command; + private final SparkBatchQueryScan scan; + private final IsolationLevel isolationLevel; + private final SparkWriteConf writeConf; + private final LogicalWriteInfo info; + private final boolean handleTimestampWithoutZone; + private final boolean checkNullability; + private final boolean checkOrdering; + + SparkPositionDeltaWriteBuilder( + SparkSession spark, + Table table, + String branch, + Command command, + Scan scan, + IsolationLevel isolationLevel, + LogicalWriteInfo info) { + this.spark = spark; + this.table = table; + this.command = command; + this.scan = (SparkBatchQueryScan) scan; + this.isolationLevel = isolationLevel; + this.writeConf = new SparkWriteConf(spark, table, branch, info.options()); + this.info = info; + this.handleTimestampWithoutZone = writeConf.handleTimestampWithoutZone(); + this.checkNullability = writeConf.checkNullability(); + this.checkOrdering = writeConf.checkOrdering(); + } + + @Override + public DeltaWrite build() { + Preconditions.checkArgument( + handleTimestampWithoutZone || !SparkUtil.hasTimestampWithoutZone(table.schema()), + SparkUtil.TIMESTAMP_WITHOUT_TIMEZONE_ERROR); + + Schema dataSchema = dataSchema(); + + validateRowIdSchema(); + validateMetadataSchema(); + SparkUtil.validatePartitionTransforms(table.spec()); + + Distribution distribution = + SparkDistributionAndOrderingUtil.buildPositionDeltaDistribution( + table, command, distributionMode()); + SortOrder[] ordering = + SparkDistributionAndOrderingUtil.buildPositionDeltaOrdering(table, command); + + return new SparkPositionDeltaWrite( + spark, + table, + command, + scan, + isolationLevel, + writeConf, + info, + dataSchema, + distribution, + ordering); + } + + private Schema dataSchema() { + if (info.schema() == null || info.schema().isEmpty()) { + return null; + } else { + Schema dataSchema = SparkSchemaUtil.convert(table.schema(), info.schema()); + validateSchema("data", table.schema(), dataSchema); + return dataSchema; + } + } + + private void validateRowIdSchema() { + Preconditions.checkArgument(info.rowIdSchema().isPresent(), "Row ID schema must be set"); + StructType rowIdSparkType = info.rowIdSchema().get(); + 
Schema rowIdSchema = SparkSchemaUtil.convert(EXPECTED_ROW_ID_SCHEMA, rowIdSparkType); + validateSchema("row ID", EXPECTED_ROW_ID_SCHEMA, rowIdSchema); + } + + private void validateMetadataSchema() { + Preconditions.checkArgument(info.metadataSchema().isPresent(), "Metadata schema must be set"); + Schema expectedMetadataSchema = + new Schema( + MetadataColumns.SPEC_ID, + MetadataColumns.metadataColumn(table, MetadataColumns.PARTITION_COLUMN_NAME)); + StructType metadataSparkType = info.metadataSchema().get(); + Schema metadataSchema = SparkSchemaUtil.convert(expectedMetadataSchema, metadataSparkType); + validateSchema("metadata", expectedMetadataSchema, metadataSchema); + } + + private void validateSchema(String context, Schema expected, Schema actual) { + TypeUtil.validateSchema(context, expected, actual, checkNullability, checkOrdering); + } + + private DistributionMode distributionMode() { + switch (command) { + case DELETE: + return writeConf.deleteDistributionMode(); + case UPDATE: + return writeConf.updateDistributionMode(); + case MERGE: + return writeConf.positionDeltaMergeDistributionMode(); + default: + throw new IllegalArgumentException("Unexpected command: " + command); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkRowLevelOperationBuilder.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkRowLevelOperationBuilder.java new file mode 100644 index 000000000000..b113bd9b25af --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkRowLevelOperationBuilder.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
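As a hedged usage note for the row-id and metadata schemas validated above: they correspond to Iceberg's _file/_pos and _spec_id/_partition metadata columns, which can be inspected directly from Spark. The catalog and table identifiers below are hypothetical and assume an Iceberg catalog is configured.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class MetadataColumnsExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().getOrCreate();
    // _file/_pos identify the row being deleted; _spec_id/_partition locate its spec and partition
    Dataset<Row> ids =
        spark.sql("SELECT _spec_id, _partition, _file, _pos FROM demo_catalog.db.sample_table");
    ids.show(false);
  }
}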
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.TableProperties.DELETE_ISOLATION_LEVEL; +import static org.apache.iceberg.TableProperties.DELETE_ISOLATION_LEVEL_DEFAULT; +import static org.apache.iceberg.TableProperties.DELETE_MODE; +import static org.apache.iceberg.TableProperties.DELETE_MODE_DEFAULT; +import static org.apache.iceberg.TableProperties.MERGE_ISOLATION_LEVEL; +import static org.apache.iceberg.TableProperties.MERGE_ISOLATION_LEVEL_DEFAULT; +import static org.apache.iceberg.TableProperties.MERGE_MODE; +import static org.apache.iceberg.TableProperties.MERGE_MODE_DEFAULT; +import static org.apache.iceberg.TableProperties.UPDATE_ISOLATION_LEVEL; +import static org.apache.iceberg.TableProperties.UPDATE_ISOLATION_LEVEL_DEFAULT; +import static org.apache.iceberg.TableProperties.UPDATE_MODE; +import static org.apache.iceberg.TableProperties.UPDATE_MODE_DEFAULT; + +import java.util.Map; +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Table; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.write.RowLevelOperation; +import org.apache.spark.sql.connector.write.RowLevelOperation.Command; +import org.apache.spark.sql.connector.write.RowLevelOperationBuilder; +import org.apache.spark.sql.connector.write.RowLevelOperationInfo; + +class SparkRowLevelOperationBuilder implements RowLevelOperationBuilder { + + private final SparkSession spark; + private final Table table; + private final String branch; + private final RowLevelOperationInfo info; + private final RowLevelOperationMode mode; + private final IsolationLevel isolationLevel; + + SparkRowLevelOperationBuilder( + SparkSession spark, Table table, String branch, RowLevelOperationInfo info) { + this.spark = spark; + this.table = table; + this.branch = branch; + this.info = info; + this.mode = mode(table.properties(), info.command()); + this.isolationLevel = isolationLevel(table.properties(), info.command()); + } + + @Override + public RowLevelOperation build() { + switch (mode) { + case COPY_ON_WRITE: + return new SparkCopyOnWriteOperation(spark, table, branch, info, isolationLevel); + case MERGE_ON_READ: + return new SparkPositionDeltaOperation(spark, table, branch, info, isolationLevel); + default: + throw new IllegalArgumentException("Unsupported operation mode: " + mode); + } + } + + private RowLevelOperationMode mode(Map properties, Command command) { + String modeName; + + switch (command) { + case DELETE: + modeName = properties.getOrDefault(DELETE_MODE, DELETE_MODE_DEFAULT); + break; + case UPDATE: + modeName = properties.getOrDefault(UPDATE_MODE, UPDATE_MODE_DEFAULT); + break; + case MERGE: + modeName = properties.getOrDefault(MERGE_MODE, MERGE_MODE_DEFAULT); + break; + default: + throw new IllegalArgumentException("Unsupported command: " + command); + } + + return RowLevelOperationMode.fromName(modeName); + } + + private IsolationLevel isolationLevel(Map properties, Command command) { + String levelName; + + switch (command) { + case DELETE: + levelName = properties.getOrDefault(DELETE_ISOLATION_LEVEL, DELETE_ISOLATION_LEVEL_DEFAULT); + break; + case UPDATE: + levelName = properties.getOrDefault(UPDATE_ISOLATION_LEVEL, UPDATE_ISOLATION_LEVEL_DEFAULT); + break; + case MERGE: + levelName = properties.getOrDefault(MERGE_ISOLATION_LEVEL, MERGE_ISOLATION_LEVEL_DEFAULT); + break; + default: + throw new IllegalArgumentException("Unsupported command: " + command); + } + + return 
IsolationLevel.fromName(levelName); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkRowReaderFactory.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkRowReaderFactory.java new file mode 100644 index 000000000000..23699aeb167c --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkRowReaderFactory.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.ChangelogScanTask; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.PositionDeletesScanTask; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +class SparkRowReaderFactory implements PartitionReaderFactory { + + SparkRowReaderFactory() {} + + @Override + public PartitionReader createReader(InputPartition inputPartition) { + Preconditions.checkArgument( + inputPartition instanceof SparkInputPartition, + "Unknown input partition type: %s", + inputPartition.getClass().getName()); + + SparkInputPartition partition = (SparkInputPartition) inputPartition; + + if (partition.allTasksOfType(FileScanTask.class)) { + return new RowDataReader(partition); + + } else if (partition.allTasksOfType(ChangelogScanTask.class)) { + return new ChangelogRowReader(partition); + + } else if (partition.allTasksOfType(PositionDeletesScanTask.class)) { + return new PositionDeletesRowReader(partition); + + } else { + throw new UnsupportedOperationException( + "Unsupported task group for row-based reads: " + partition.taskGroup()); + } + } + + @Override + public PartitionReader createColumnarReader(InputPartition inputPartition) { + throw new UnsupportedOperationException("Columnar reads are not supported"); + } + + @Override + public boolean supportColumnarReads(InputPartition inputPartition) { + return false; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java new file mode 100644 index 000000000000..06fc4a07a0eb --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
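A hedged usage sketch for the mode selection above: SparkRowLevelOperationBuilder routes a command to the copy-on-write or position-delta (merge-on-read) operation based on the write.delete.mode, write.update.mode, and write.merge.mode table properties imported there. The catalog and table names below are hypothetical.

import org.apache.spark.sql.SparkSession;

public class RowLevelModeExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().getOrCreate();
    // switch DELETE for this table onto the merge-on-read (position delta) path
    spark.sql(
        "ALTER TABLE demo_catalog.db.sample_table "
            + "SET TBLPROPERTIES ('write.delete.mode' = 'merge-on-read')");
    // this DELETE is now planned through SparkPositionDeltaOperation
    spark.sql("DELETE FROM demo_catalog.db.sample_table WHERE id = 1");
  }
}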
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.spark.source.metrics.NumDeletes; +import org.apache.iceberg.spark.source.metrics.NumSplits; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.metric.CustomMetric; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.connector.read.SupportsReportStatistics; +import org.apache.spark.sql.connector.read.streaming.MicroBatchStream; +import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +abstract class SparkScan implements Scan, SupportsReportStatistics { + private static final Logger LOG = LoggerFactory.getLogger(SparkScan.class); + + private final JavaSparkContext sparkContext; + private final Table table; + private final SparkReadConf readConf; + private final boolean caseSensitive; + private final Schema expectedSchema; + private final List filterExpressions; + private final boolean readTimestampWithoutZone; + private final String branch; + + // lazy variables + private StructType readSchema; + + SparkScan( + SparkSession spark, + Table table, + SparkReadConf readConf, + Schema expectedSchema, + List filters) { + Schema snapshotSchema = SnapshotUtil.schemaFor(table, readConf.branch()); + SparkSchemaUtil.validateMetadataColumnReferences(snapshotSchema, expectedSchema); + + this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + this.table = table; + this.readConf = readConf; + this.caseSensitive = readConf.caseSensitive(); + this.expectedSchema = expectedSchema; + this.filterExpressions = filters != null ? 
filters : Collections.emptyList(); + this.readTimestampWithoutZone = readConf.handleTimestampWithoutZone(); + this.branch = readConf.branch(); + } + + protected Table table() { + return table; + } + + protected String branch() { + return branch; + } + + protected boolean caseSensitive() { + return caseSensitive; + } + + protected Schema expectedSchema() { + return expectedSchema; + } + + protected List filterExpressions() { + return filterExpressions; + } + + protected Types.StructType groupingKeyType() { + return Types.StructType.of(); + } + + protected abstract List> taskGroups(); + + @Override + public Batch toBatch() { + return new SparkBatch( + sparkContext, table, readConf, groupingKeyType(), taskGroups(), expectedSchema, hashCode()); + } + + @Override + public MicroBatchStream toMicroBatchStream(String checkpointLocation) { + return new SparkMicroBatchStream( + sparkContext, table, readConf, expectedSchema, checkpointLocation); + } + + @Override + public StructType readSchema() { + if (readSchema == null) { + Preconditions.checkArgument( + readTimestampWithoutZone || !SparkUtil.hasTimestampWithoutZone(expectedSchema), + SparkUtil.TIMESTAMP_WITHOUT_TIMEZONE_ERROR); + this.readSchema = SparkSchemaUtil.convert(expectedSchema); + } + return readSchema; + } + + @Override + public Statistics estimateStatistics() { + return estimateStatistics(SnapshotUtil.latestSnapshot(table, branch)); + } + + protected Statistics estimateStatistics(Snapshot snapshot) { + // its a fresh table, no data + if (snapshot == null) { + return new Stats(0L, 0L); + } + + // estimate stats using snapshot summary only for partitioned tables + // (metadata tables are unpartitioned) + if (!table.spec().isUnpartitioned() && filterExpressions.isEmpty()) { + LOG.debug( + "Using snapshot {} metadata to estimate statistics for table {}", + snapshot.snapshotId(), + table.name()); + long totalRecords = totalRecords(snapshot); + return new Stats(SparkSchemaUtil.estimateSize(readSchema(), totalRecords), totalRecords); + } + + long rowsCount = taskGroups().stream().mapToLong(ScanTaskGroup::estimatedRowsCount).sum(); + long sizeInBytes = SparkSchemaUtil.estimateSize(readSchema(), rowsCount); + return new Stats(sizeInBytes, rowsCount); + } + + private long totalRecords(Snapshot snapshot) { + Map summary = snapshot.summary(); + return PropertyUtil.propertyAsLong(summary, SnapshotSummary.TOTAL_RECORDS_PROP, Long.MAX_VALUE); + } + + @Override + public String description() { + String groupingKeyFieldNamesAsString = + groupingKeyType().fields().stream() + .map(Types.NestedField::name) + .collect(Collectors.joining(", ")); + + return String.format( + "%s (branch=%s) [filters=%s, groupedBy=%s]", + table(), branch(), Spark3Util.describe(filterExpressions), groupingKeyFieldNamesAsString); + } + + @Override + public CustomMetric[] supportedCustomMetrics() { + return new CustomMetric[] {new NumSplits(), new NumDeletes()}; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java new file mode 100644 index 000000000000..23cd8524b3c8 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java @@ -0,0 +1,656 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
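To illustrate how SparkScan resolves which snapshot to read, a minimal branch-read sketch follows; the "branch" read option corresponds to SparkReadOptions.BRANCH referenced above, while the table identifier and branch name are hypothetical.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class BranchReadExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().getOrCreate();
    // read the head of a named branch instead of the current main snapshot
    Dataset<Row> branchRows =
        spark
            .read()
            .format("iceberg")
            .option("branch", "audit-branch")
            .load("demo_catalog.db.sample_table");
    branchRows.show();
  }
}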
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.BatchScan; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.IncrementalAppendScan; +import org.apache.iceberg.IncrementalChangelogScan; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.MetricsConfig; +import org.apache.iceberg.MetricsModes; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TableScan; +import org.apache.iceberg.expressions.AggregateEvaluator; +import org.apache.iceberg.expressions.Binder; +import org.apache.iceberg.expressions.BoundAggregate; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.ExpressionUtil; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkAggregates; +import org.apache.iceberg.spark.SparkFilters; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc; +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.connector.read.SupportsPushDownAggregates; +import org.apache.spark.sql.connector.read.SupportsPushDownFilters; +import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns; +import org.apache.spark.sql.connector.read.SupportsReportStatistics; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SparkScanBuilder + implements ScanBuilder, + SupportsPushDownAggregates, + SupportsPushDownFilters, + SupportsPushDownRequiredColumns, + SupportsReportStatistics { + + private static final Logger 
LOG = LoggerFactory.getLogger(SparkScanBuilder.class); + private static final Filter[] NO_FILTERS = new Filter[0]; + private StructType pushedAggregateSchema; + private Scan localScan; + + private final SparkSession spark; + private final Table table; + private final CaseInsensitiveStringMap options; + private final SparkReadConf readConf; + private final List metaColumns = Lists.newArrayList(); + + private Schema schema = null; + private boolean caseSensitive; + private List filterExpressions = null; + private Filter[] pushedFilters = NO_FILTERS; + + SparkScanBuilder( + SparkSession spark, + Table table, + String branch, + Schema schema, + CaseInsensitiveStringMap options) { + this.spark = spark; + this.table = table; + this.schema = schema; + this.options = options; + this.readConf = new SparkReadConf(spark, table, branch, options); + this.caseSensitive = readConf.caseSensitive(); + } + + SparkScanBuilder(SparkSession spark, Table table, CaseInsensitiveStringMap options) { + this(spark, table, table.schema(), options); + } + + SparkScanBuilder( + SparkSession spark, Table table, String branch, CaseInsensitiveStringMap options) { + this(spark, table, branch, SnapshotUtil.schemaFor(table, branch), options); + } + + SparkScanBuilder( + SparkSession spark, Table table, Schema schema, CaseInsensitiveStringMap options) { + this(spark, table, null, schema, options); + } + + private Expression filterExpression() { + if (filterExpressions != null) { + return filterExpressions.stream().reduce(Expressions.alwaysTrue(), Expressions::and); + } + return Expressions.alwaysTrue(); + } + + public SparkScanBuilder caseSensitive(boolean isCaseSensitive) { + this.caseSensitive = isCaseSensitive; + return this; + } + + @Override + public Filter[] pushFilters(Filter[] filters) { + // there are 3 kinds of filters: + // (1) filters that can be pushed down completely and don't have to evaluated by Spark + // (e.g. filters that select entire partitions) + // (2) filters that can be pushed down partially and require record-level filtering in Spark + // (e.g. filters that may select some but not necessarily all rows in a file) + // (3) filters that can't be pushed down at all and have to be evaluated by Spark + // (e.g. 
unsupported filters) + // filters (1) and (2) are used prune files during job planning in Iceberg + // filters (2) and (3) form a set of post scan filters and must be evaluated by Spark + + List expressions = Lists.newArrayListWithExpectedSize(filters.length); + List pushableFilters = Lists.newArrayListWithExpectedSize(filters.length); + List postScanFilters = Lists.newArrayListWithExpectedSize(filters.length); + + for (Filter filter : filters) { + try { + Expression expr = SparkFilters.convert(filter); + + if (expr != null) { + // try binding the expression to ensure it can be pushed down + Binder.bind(schema.asStruct(), expr, caseSensitive); + expressions.add(expr); + pushableFilters.add(filter); + } + + if (expr == null || !ExpressionUtil.selectsPartitions(expr, table, caseSensitive)) { + postScanFilters.add(filter); + } else { + LOG.info("Evaluating completely on Iceberg side: {}", filter); + } + + } catch (Exception e) { + LOG.warn("Failed to check if {} can be pushed down: {}", filter, e.getMessage()); + postScanFilters.add(filter); + } + } + + this.filterExpressions = expressions; + this.pushedFilters = pushableFilters.toArray(new Filter[0]); + + return postScanFilters.toArray(new Filter[0]); + } + + @Override + public Filter[] pushedFilters() { + return pushedFilters; + } + + @Override + public boolean pushAggregation(Aggregation aggregation) { + if (!canPushDownAggregation(aggregation)) { + return false; + } + + AggregateEvaluator aggregateEvaluator; + List> expressions = + Lists.newArrayListWithExpectedSize(aggregation.aggregateExpressions().length); + + for (AggregateFunc aggregateFunc : aggregation.aggregateExpressions()) { + try { + Expression expr = SparkAggregates.convert(aggregateFunc); + if (expr != null) { + Expression bound = Binder.bind(schema.asStruct(), expr, caseSensitive); + expressions.add((BoundAggregate) bound); + } else { + LOG.info( + "Skipping aggregate pushdown: AggregateFunc {} can't be converted to iceberg expression", + aggregateFunc); + return false; + } + } catch (IllegalArgumentException e) { + LOG.info("Skipping aggregate pushdown: Bind failed for AggregateFunc {}", aggregateFunc, e); + return false; + } + } + + aggregateEvaluator = AggregateEvaluator.create(expressions); + + if (!metricsModeSupportsAggregatePushDown(aggregateEvaluator.aggregates())) { + return false; + } + + TableScan scan = table.newScan().includeColumnStats(); + Snapshot snapshot = readSnapshot(); + if (snapshot == null) { + LOG.info("Skipping aggregate pushdown: table snapshot is null"); + return false; + } + scan = scan.useSnapshot(snapshot.snapshotId()); + scan = configureSplitPlanning(scan); + scan = scan.filter(filterExpression()); + + try (CloseableIterable fileScanTasks = scan.planFiles()) { + List tasks = ImmutableList.copyOf(fileScanTasks); + for (FileScanTask task : tasks) { + if (!task.deletes().isEmpty()) { + LOG.info("Skipping aggregate pushdown: detected row level deletes"); + return false; + } + + aggregateEvaluator.update(task.file()); + } + } catch (IOException e) { + LOG.info("Skipping aggregate pushdown: ", e); + return false; + } + + if (!aggregateEvaluator.allAggregatorsValid()) { + return false; + } + + pushedAggregateSchema = + SparkSchemaUtil.convert(new Schema(aggregateEvaluator.resultType().fields())); + InternalRow[] pushedAggregateRows = new InternalRow[1]; + StructLike structLike = aggregateEvaluator.result(); + pushedAggregateRows[0] = + new StructInternalRow(aggregateEvaluator.resultType()).setStruct(structLike); + localScan = + new 
SparkLocalScan(table, pushedAggregateSchema, pushedAggregateRows, filterExpressions); + + return true; + } + + private boolean canPushDownAggregation(Aggregation aggregation) { + if (!(table instanceof BaseTable)) { + return false; + } + + if (!readConf.aggregatePushDownEnabled()) { + return false; + } + + // If group by expression is the same as the partition, the statistics information can still + // be used to calculate min/max/count, will enable aggregate push down in next phase. + // TODO: enable aggregate push down for partition col group by expression + if (aggregation.groupByExpressions().length > 0) { + LOG.info("Skipping aggregate pushdown: group by aggregation push down is not supported"); + return false; + } + + return true; + } + + private Snapshot readSnapshot() { + Snapshot snapshot; + if (readConf.snapshotId() != null) { + snapshot = table.snapshot(readConf.snapshotId()); + } else { + snapshot = SnapshotUtil.latestSnapshot(table, readConf.branch()); + } + + return snapshot; + } + + private boolean metricsModeSupportsAggregatePushDown(List> aggregates) { + MetricsConfig config = MetricsConfig.forTable(table); + for (BoundAggregate aggregate : aggregates) { + String colName = aggregate.columnName(); + if (!colName.equals("*")) { + MetricsModes.MetricsMode mode = config.columnMode(colName); + if (mode instanceof MetricsModes.None) { + LOG.info("Skipping aggregate pushdown: No metrics for column {}", colName); + return false; + } else if (mode instanceof MetricsModes.Counts) { + if (aggregate.op() == Expression.Operation.MAX + || aggregate.op() == Expression.Operation.MIN) { + LOG.info( + "Skipping aggregate pushdown: Cannot produce min or max from count for column {}", + colName); + return false; + } + } else if (mode instanceof MetricsModes.Truncate) { + // lower_bounds and upper_bounds may be truncated, so disable push down + if (aggregate.type().typeId() == Type.TypeID.STRING) { + if (aggregate.op() == Expression.Operation.MAX + || aggregate.op() == Expression.Operation.MIN) { + LOG.info( + "Skipping aggregate pushdown: Cannot produce min or max from truncated values for column {}", + colName); + return false; + } + } + } + } + } + + return true; + } + + @Override + public void pruneColumns(StructType requestedSchema) { + StructType requestedProjection = + new StructType( + Stream.of(requestedSchema.fields()) + .filter(field -> MetadataColumns.nonMetadataColumn(field.name())) + .toArray(StructField[]::new)); + + // the projection should include all columns that will be returned, including those only used in + // filters + this.schema = + SparkSchemaUtil.prune(schema, requestedProjection, filterExpression(), caseSensitive); + + Stream.of(requestedSchema.fields()) + .map(StructField::name) + .filter(MetadataColumns::isMetadataColumn) + .distinct() + .forEach(metaColumns::add); + } + + private Schema schemaWithMetadataColumns() { + // metadata columns + List fields = + metaColumns.stream() + .distinct() + .map(name -> MetadataColumns.metadataColumn(table, name)) + .collect(Collectors.toList()); + Schema meta = new Schema(fields); + + // schema or rows returned by readers + return TypeUtil.join(schema, meta); + } + + @Override + public Scan build() { + if (localScan != null) { + return localScan; + } else { + return buildBatchScan(); + } + } + + private Scan buildBatchScan() { + Long snapshotId = readConf.snapshotId(); + Long asOfTimestamp = readConf.asOfTimestamp(); + String branch = readConf.branch(); + String tag = readConf.tag(); + + Preconditions.checkArgument( + 
snapshotId == null || asOfTimestamp == null, + "Cannot set both %s and %s to select which table snapshot to scan", + SparkReadOptions.SNAPSHOT_ID, + SparkReadOptions.AS_OF_TIMESTAMP); + + Long startSnapshotId = readConf.startSnapshotId(); + Long endSnapshotId = readConf.endSnapshotId(); + + if (snapshotId != null || asOfTimestamp != null) { + Preconditions.checkArgument( + startSnapshotId == null && endSnapshotId == null, + "Cannot set %s and %s for incremental scans when either %s or %s is set", + SparkReadOptions.START_SNAPSHOT_ID, + SparkReadOptions.END_SNAPSHOT_ID, + SparkReadOptions.SNAPSHOT_ID, + SparkReadOptions.AS_OF_TIMESTAMP); + } + + Preconditions.checkArgument( + startSnapshotId != null || endSnapshotId == null, + "Cannot set only %s for incremental scans. Please, set %s too.", + SparkReadOptions.END_SNAPSHOT_ID, + SparkReadOptions.START_SNAPSHOT_ID); + + Long startTimestamp = readConf.startTimestamp(); + Long endTimestamp = readConf.endTimestamp(); + Preconditions.checkArgument( + startTimestamp == null && endTimestamp == null, + "Cannot set %s or %s for incremental scans and batch scan. They are only valid for " + + "changelog scans.", + SparkReadOptions.START_TIMESTAMP, + SparkReadOptions.END_TIMESTAMP); + + if (startSnapshotId != null) { + return buildIncrementalAppendScan(startSnapshotId, endSnapshotId); + } else { + return buildBatchScan(snapshotId, asOfTimestamp, branch, tag); + } + } + + private Scan buildBatchScan(Long snapshotId, Long asOfTimestamp, String branch, String tag) { + Schema expectedSchema = schemaWithMetadataColumns(); + + BatchScan scan = + table + .newBatchScan() + .caseSensitive(caseSensitive) + .filter(filterExpression()) + .project(expectedSchema); + + if (snapshotId != null) { + scan = scan.useSnapshot(snapshotId); + } + + if (asOfTimestamp != null) { + scan = scan.asOfTime(asOfTimestamp); + } + + if (branch != null) { + scan = scan.useRef(branch); + } + + if (tag != null) { + scan = scan.useRef(tag); + } + + scan = configureSplitPlanning(scan); + + return new SparkBatchQueryScan(spark, table, scan, readConf, expectedSchema, filterExpressions); + } + + private Scan buildIncrementalAppendScan(long startSnapshotId, Long endSnapshotId) { + Schema expectedSchema = schemaWithMetadataColumns(); + + IncrementalAppendScan scan = + table + .newIncrementalAppendScan() + .fromSnapshotExclusive(startSnapshotId) + .caseSensitive(caseSensitive) + .filter(filterExpression()) + .project(expectedSchema); + + if (endSnapshotId != null) { + scan = scan.toSnapshot(endSnapshotId); + } + + scan = configureSplitPlanning(scan); + + return new SparkBatchQueryScan(spark, table, scan, readConf, expectedSchema, filterExpressions); + } + + public Scan buildChangelogScan() { + Preconditions.checkArgument( + readConf.snapshotId() == null + && readConf.asOfTimestamp() == null + && readConf.branch() == null + && readConf.tag() == null, + "Cannot set neither %s, %s, %s and %s for changelogs", + SparkReadOptions.SNAPSHOT_ID, + SparkReadOptions.AS_OF_TIMESTAMP, + SparkReadOptions.BRANCH, + SparkReadOptions.TAG); + + Long startSnapshotId = readConf.startSnapshotId(); + Long endSnapshotId = readConf.endSnapshotId(); + Long startTimestamp = readConf.startTimestamp(); + Long endTimestamp = readConf.endTimestamp(); + + Preconditions.checkArgument( + !(startSnapshotId != null && startTimestamp != null), + "Cannot set both %s and %s for changelogs", + SparkReadOptions.START_SNAPSHOT_ID, + SparkReadOptions.START_TIMESTAMP); + + Preconditions.checkArgument( + !(endSnapshotId != null && 
endTimestamp != null), + "Cannot set both %s and %s for changelogs", + SparkReadOptions.END_SNAPSHOT_ID, + SparkReadOptions.END_TIMESTAMP); + + if (startTimestamp != null && endTimestamp != null) { + Preconditions.checkArgument( + startTimestamp < endTimestamp, + "Cannot set %s to be greater than %s for changelogs", + SparkReadOptions.START_TIMESTAMP, + SparkReadOptions.END_TIMESTAMP); + } + + if (startTimestamp != null) { + startSnapshotId = getStartSnapshotId(startTimestamp); + } + + if (endTimestamp != null) { + endSnapshotId = SnapshotUtil.snapshotIdAsOfTime(table, endTimestamp); + } + + Schema expectedSchema = schemaWithMetadataColumns(); + + IncrementalChangelogScan scan = + table + .newIncrementalChangelogScan() + .caseSensitive(caseSensitive) + .filter(filterExpression()) + .project(expectedSchema); + + if (startSnapshotId != null) { + scan = scan.fromSnapshotExclusive(startSnapshotId); + } + + if (endSnapshotId != null) { + scan = scan.toSnapshot(endSnapshotId); + } + + scan = configureSplitPlanning(scan); + + return new SparkChangelogScan(spark, table, scan, readConf, expectedSchema, filterExpressions); + } + + private Long getStartSnapshotId(Long startTimestamp) { + Snapshot oldestSnapshotAfter = SnapshotUtil.oldestAncestorAfter(table, startTimestamp); + Preconditions.checkArgument( + oldestSnapshotAfter != null, + "Cannot find a snapshot older than %s for table %s", + startTimestamp, + table.name()); + + if (oldestSnapshotAfter.timestampMillis() == startTimestamp) { + return oldestSnapshotAfter.snapshotId(); + } else { + return oldestSnapshotAfter.parentId(); + } + } + + public Scan buildMergeOnReadScan() { + Preconditions.checkArgument( + readConf.snapshotId() == null && readConf.asOfTimestamp() == null && readConf.tag() == null, + "Cannot set time travel options %s, %s, %s for row-level command scans", + SparkReadOptions.SNAPSHOT_ID, + SparkReadOptions.AS_OF_TIMESTAMP, + SparkReadOptions.TAG); + + Preconditions.checkArgument( + readConf.startSnapshotId() == null && readConf.endSnapshotId() == null, + "Cannot set incremental scan options %s and %s for row-level command scans", + SparkReadOptions.START_SNAPSHOT_ID, + SparkReadOptions.END_SNAPSHOT_ID); + + Snapshot snapshot = SnapshotUtil.latestSnapshot(table, readConf.branch()); + + if (snapshot == null) { + return new SparkBatchQueryScan( + spark, table, null, readConf, schemaWithMetadataColumns(), filterExpressions); + } + + // remember the current snapshot ID for commit validation + long snapshotId = snapshot.snapshotId(); + + CaseInsensitiveStringMap adjustedOptions = + Spark3Util.setOption(SparkReadOptions.SNAPSHOT_ID, Long.toString(snapshotId), options); + SparkReadConf adjustedReadConf = + new SparkReadConf(spark, table, readConf.branch(), adjustedOptions); + + Schema expectedSchema = schemaWithMetadataColumns(); + + BatchScan scan = + table + .newBatchScan() + .useSnapshot(snapshotId) + .caseSensitive(caseSensitive) + .filter(filterExpression()) + .project(expectedSchema); + + scan = configureSplitPlanning(scan); + + return new SparkBatchQueryScan( + spark, table, scan, adjustedReadConf, expectedSchema, filterExpressions); + } + + public Scan buildCopyOnWriteScan() { + Snapshot snapshot = SnapshotUtil.latestSnapshot(table, readConf.branch()); + + if (snapshot == null) { + return new SparkCopyOnWriteScan( + spark, table, readConf, schemaWithMetadataColumns(), filterExpressions); + } + + Schema expectedSchema = schemaWithMetadataColumns(); + + BatchScan scan = + table + .newBatchScan() + 
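+            // pin the scan to the snapshot captured above; SparkCopyOnWriteScan keeps that snapshot to validate conflicts at commit time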
.useSnapshot(snapshot.snapshotId()) + .ignoreResiduals() + .caseSensitive(caseSensitive) + .filter(filterExpression()) + .project(expectedSchema); + + scan = configureSplitPlanning(scan); + + return new SparkCopyOnWriteScan( + spark, table, scan, snapshot, readConf, expectedSchema, filterExpressions); + } + + private > T configureSplitPlanning(T scan) { + T configuredScan = scan; + + Long splitSize = readConf.splitSizeOption(); + if (splitSize != null) { + configuredScan = configuredScan.option(TableProperties.SPLIT_SIZE, String.valueOf(splitSize)); + } + + Integer splitLookback = readConf.splitLookbackOption(); + if (splitLookback != null) { + configuredScan = + configuredScan.option(TableProperties.SPLIT_LOOKBACK, String.valueOf(splitLookback)); + } + + Long splitOpenFileCost = readConf.splitOpenFileCostOption(); + if (splitOpenFileCost != null) { + configuredScan = + configuredScan.option( + TableProperties.SPLIT_OPEN_FILE_COST, String.valueOf(splitOpenFileCost)); + } + + return configuredScan; + } + + @Override + public Statistics estimateStatistics() { + return ((SparkScan) build()).estimateStatistics(); + } + + @Override + public StructType readSchema() { + return build().readSchema(); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScan.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScan.java new file mode 100644 index 000000000000..89b184c91c51 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScan.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import java.util.Objects; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.ScanTaskSetManager; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.util.TableScanUtil; +import org.apache.spark.sql.SparkSession; + +class SparkStagedScan extends SparkScan { + + private final String taskSetId; + private final long splitSize; + private final int splitLookback; + private final long openFileCost; + + private List> taskGroups = null; // lazy cache of tasks + + SparkStagedScan(SparkSession spark, Table table, SparkReadConf readConf) { + super(spark, table, readConf, table.schema(), ImmutableList.of()); + + this.taskSetId = readConf.scanTaskSetId(); + this.splitSize = readConf.splitSize(); + this.splitLookback = readConf.splitLookback(); + this.openFileCost = readConf.splitOpenFileCost(); + } + + @Override + protected List> taskGroups() { + if (taskGroups == null) { + ScanTaskSetManager taskSetManager = ScanTaskSetManager.get(); + List tasks = taskSetManager.fetchTasks(table(), taskSetId); + ValidationException.check( + tasks != null, + "Task set manager has no tasks for table %s with task set ID %s", + table(), + taskSetId); + + this.taskGroups = TableScanUtil.planTaskGroups(tasks, splitSize, splitLookback, openFileCost); + } + return taskGroups; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + + if (other == null || getClass() != other.getClass()) { + return false; + } + + SparkStagedScan that = (SparkStagedScan) other; + return table().name().equals(that.table().name()) + && Objects.equals(taskSetId, that.taskSetId) + && Objects.equals(splitSize, that.splitSize) + && Objects.equals(splitLookback, that.splitLookback) + && Objects.equals(openFileCost, that.openFileCost); + } + + @Override + public int hashCode() { + return Objects.hash(table().name(), taskSetId, splitSize, splitSize, openFileCost); + } + + @Override + public String toString() { + return String.format( + "IcebergStagedScan(table=%s, type=%s, taskSetID=%s, caseSensitive=%s)", + table(), expectedSchema().asStruct(), taskSetId, caseSensitive()); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScanBuilder.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScanBuilder.java new file mode 100644 index 000000000000..37bbea42e5b1 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScanBuilder.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.Table; +import org.apache.iceberg.spark.SparkReadConf; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +class SparkStagedScanBuilder implements ScanBuilder { + + private final SparkSession spark; + private final Table table; + private final SparkReadConf readConf; + + SparkStagedScanBuilder(SparkSession spark, Table table, CaseInsensitiveStringMap options) { + this.spark = spark; + this.table = table; + this.readConf = new SparkReadConf(spark, table, options); + } + + @Override + public Scan build() { + return new SparkStagedScan(spark, table, readConf); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java new file mode 100644 index 000000000000..d84528451348 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java @@ -0,0 +1,427 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.TableProperties.CURRENT_SNAPSHOT_ID; +import static org.apache.iceberg.TableProperties.FORMAT_VERSION; + +import java.io.IOException; +import java.util.Map; +import java.util.Set; +import org.apache.iceberg.BaseMetadataTable; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFiles; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.PositionDeletesTable; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TableScan; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Evaluator; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.ExpressionUtil; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.expressions.Projections; +import org.apache.iceberg.expressions.StrictMetricsEvaluator; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkFilters; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.catalog.MetadataColumn; +import org.apache.spark.sql.connector.catalog.SupportsDelete; +import org.apache.spark.sql.connector.catalog.SupportsMetadataColumns; +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations; +import org.apache.spark.sql.connector.catalog.SupportsWrite; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.RowLevelOperationBuilder; +import org.apache.spark.sql.connector.write.RowLevelOperationInfo; +import org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SparkTable + implements org.apache.spark.sql.connector.catalog.Table, + SupportsRead, + SupportsWrite, + SupportsDelete, + SupportsRowLevelOperations, + SupportsMetadataColumns { + + private static final Logger LOG = LoggerFactory.getLogger(SparkTable.class); + + private static final Set RESERVED_PROPERTIES = + ImmutableSet.of( + "provider", + 
"format", + CURRENT_SNAPSHOT_ID, + "location", + FORMAT_VERSION, + "sort-order", + "identifier-fields"); + private static final Set CAPABILITIES = + ImmutableSet.of( + TableCapability.BATCH_READ, + TableCapability.BATCH_WRITE, + TableCapability.MICRO_BATCH_READ, + TableCapability.STREAMING_WRITE, + TableCapability.OVERWRITE_BY_FILTER, + TableCapability.OVERWRITE_DYNAMIC); + private static final Set CAPABILITIES_WITH_ACCEPT_ANY_SCHEMA = + ImmutableSet.builder() + .addAll(CAPABILITIES) + .add(TableCapability.ACCEPT_ANY_SCHEMA) + .build(); + + private final Table icebergTable; + private final Long snapshotId; + private final boolean refreshEagerly; + private final Set capabilities; + private String branch; + private StructType lazyTableSchema = null; + private SparkSession lazySpark = null; + + public SparkTable(Table icebergTable, boolean refreshEagerly) { + this(icebergTable, (Long) null, refreshEagerly); + } + + public SparkTable(Table icebergTable, String branch, boolean refreshEagerly) { + this(icebergTable, refreshEagerly); + this.branch = branch; + ValidationException.check( + branch == null + || SnapshotRef.MAIN_BRANCH.equals(branch) + || icebergTable.snapshot(branch) != null, + "Cannot use branch (does not exist): %s", + branch); + } + + public SparkTable(Table icebergTable, Long snapshotId, boolean refreshEagerly) { + this.icebergTable = icebergTable; + this.snapshotId = snapshotId; + this.refreshEagerly = refreshEagerly; + + boolean acceptAnySchema = + PropertyUtil.propertyAsBoolean( + icebergTable.properties(), + TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA, + TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA_DEFAULT); + this.capabilities = acceptAnySchema ? CAPABILITIES_WITH_ACCEPT_ANY_SCHEMA : CAPABILITIES; + } + + private SparkSession sparkSession() { + if (lazySpark == null) { + this.lazySpark = SparkSession.active(); + } + + return lazySpark; + } + + public Table table() { + return icebergTable; + } + + @Override + public String name() { + return icebergTable.toString(); + } + + public Long snapshotId() { + return snapshotId; + } + + public SparkTable copyWithSnapshotId(long newSnapshotId) { + return new SparkTable(icebergTable, newSnapshotId, refreshEagerly); + } + + public SparkTable copyWithBranch(String targetBranch) { + return new SparkTable(icebergTable, targetBranch, refreshEagerly); + } + + private Schema snapshotSchema() { + if (icebergTable instanceof BaseMetadataTable) { + return icebergTable.schema(); + } else if (branch != null) { + return SnapshotUtil.schemaFor(icebergTable, branch); + } else { + return SnapshotUtil.schemaFor(icebergTable, snapshotId, null); + } + } + + @Override + public StructType schema() { + if (lazyTableSchema == null) { + this.lazyTableSchema = SparkSchemaUtil.convert(snapshotSchema()); + } + + return lazyTableSchema; + } + + @Override + public Transform[] partitioning() { + return Spark3Util.toTransforms(icebergTable.spec()); + } + + @Override + public Map properties() { + ImmutableMap.Builder propsBuilder = ImmutableMap.builder(); + + String fileFormat = + icebergTable + .properties() + .getOrDefault( + TableProperties.DEFAULT_FILE_FORMAT, TableProperties.DEFAULT_FILE_FORMAT_DEFAULT); + propsBuilder.put("format", "iceberg/" + fileFormat); + propsBuilder.put("provider", "iceberg"); + String currentSnapshotId = + icebergTable.currentSnapshot() != null + ? 
String.valueOf(icebergTable.currentSnapshot().snapshotId()) + : "none"; + propsBuilder.put(CURRENT_SNAPSHOT_ID, currentSnapshotId); + propsBuilder.put("location", icebergTable.location()); + + if (icebergTable instanceof BaseTable) { + TableOperations ops = ((BaseTable) icebergTable).operations(); + propsBuilder.put(FORMAT_VERSION, String.valueOf(ops.current().formatVersion())); + } + + if (!icebergTable.sortOrder().isUnsorted()) { + propsBuilder.put("sort-order", Spark3Util.describe(icebergTable.sortOrder())); + } + + Set identifierFields = icebergTable.schema().identifierFieldNames(); + if (!identifierFields.isEmpty()) { + propsBuilder.put("identifier-fields", "[" + String.join(",", identifierFields) + "]"); + } + + icebergTable.properties().entrySet().stream() + .filter(entry -> !RESERVED_PROPERTIES.contains(entry.getKey())) + .forEach(propsBuilder::put); + + return propsBuilder.build(); + } + + @Override + public Set capabilities() { + return capabilities; + } + + @Override + public MetadataColumn[] metadataColumns() { + DataType sparkPartitionType = SparkSchemaUtil.convert(Partitioning.partitionType(table())); + return new MetadataColumn[] { + new SparkMetadataColumn(MetadataColumns.SPEC_ID.name(), DataTypes.IntegerType, false), + new SparkMetadataColumn(MetadataColumns.PARTITION_COLUMN_NAME, sparkPartitionType, true), + new SparkMetadataColumn(MetadataColumns.FILE_PATH.name(), DataTypes.StringType, false), + new SparkMetadataColumn(MetadataColumns.ROW_POSITION.name(), DataTypes.LongType, false), + new SparkMetadataColumn(MetadataColumns.IS_DELETED.name(), DataTypes.BooleanType, false) + }; + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + if (options.containsKey(SparkReadOptions.SCAN_TASK_SET_ID)) { + return new SparkStagedScanBuilder(sparkSession(), icebergTable, options); + } + + if (refreshEagerly) { + icebergTable.refresh(); + } + + CaseInsensitiveStringMap scanOptions = + branch != null ? 
options : addSnapshotId(options, snapshotId); + return new SparkScanBuilder( + sparkSession(), icebergTable, branch, snapshotSchema(), scanOptions); + } + + @Override + public WriteBuilder newWriteBuilder(LogicalWriteInfo info) { + Preconditions.checkArgument( + snapshotId == null, "Cannot write to table at a specific snapshot: %s", snapshotId); + + if (icebergTable instanceof PositionDeletesTable) { + return new SparkPositionDeletesRewriteBuilder(sparkSession(), icebergTable, branch, info); + } else { + return new SparkWriteBuilder(sparkSession(), icebergTable, branch, info); + } + } + + @Override + public RowLevelOperationBuilder newRowLevelOperationBuilder(RowLevelOperationInfo info) { + return new SparkRowLevelOperationBuilder(sparkSession(), icebergTable, branch, info); + } + + @Override + public boolean canDeleteWhere(Filter[] filters) { + Preconditions.checkArgument( + snapshotId == null, "Cannot delete from table at a specific snapshot: %s", snapshotId); + + Expression deleteExpr = Expressions.alwaysTrue(); + + for (Filter filter : filters) { + Expression expr = SparkFilters.convert(filter); + if (expr != null) { + deleteExpr = Expressions.and(deleteExpr, expr); + } else { + return false; + } + } + + return canDeleteUsingMetadata(deleteExpr); + } + + // a metadata delete is possible iff matching files can be deleted entirely + private boolean canDeleteUsingMetadata(Expression deleteExpr) { + boolean caseSensitive = SparkUtil.caseSensitive(sparkSession()); + + if (ExpressionUtil.selectsPartitions(deleteExpr, table(), caseSensitive)) { + return true; + } + + TableScan scan = + table() + .newScan() + .filter(deleteExpr) + .caseSensitive(caseSensitive) + .includeColumnStats() + .ignoreResiduals(); + + if (branch != null) { + scan.useRef(branch); + } + + try (CloseableIterable tasks = scan.planFiles()) { + Map evaluators = Maps.newHashMap(); + StrictMetricsEvaluator metricsEvaluator = + new StrictMetricsEvaluator(SnapshotUtil.schemaFor(table(), branch), deleteExpr); + + return Iterables.all( + tasks, + task -> { + DataFile file = task.file(); + PartitionSpec spec = task.spec(); + Evaluator evaluator = + evaluators.computeIfAbsent( + spec.specId(), + specId -> + new Evaluator( + spec.partitionType(), Projections.strict(spec).project(deleteExpr))); + return evaluator.eval(file.partition()) || metricsEvaluator.eval(file); + }); + + } catch (IOException ioe) { + LOG.warn("Failed to close task iterable", ioe); + return false; + } + } + + @Override + public void deleteWhere(Filter[] filters) { + Expression deleteExpr = SparkFilters.convert(filters); + + if (deleteExpr == Expressions.alwaysFalse()) { + LOG.info("Skipping the delete operation as the condition is always false"); + return; + } + + DeleteFiles deleteFiles = + icebergTable + .newDelete() + .set("spark.app.id", sparkSession().sparkContext().applicationId()) + .deleteFromRowFilter(deleteExpr); + + if (branch != null) { + deleteFiles.toBranch(branch); + } + + deleteFiles.commit(); + } + + @Override + public String toString() { + return icebergTable.toString(); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } else if (other == null || getClass() != other.getClass()) { + return false; + } + + // use only name in order to correctly invalidate Spark cache + SparkTable that = (SparkTable) other; + return icebergTable.name().equals(that.icebergTable.name()); + } + + @Override + public int hashCode() { + // use only name in order to correctly invalidate Spark cache + return 
icebergTable.name().hashCode(); + } + + private static CaseInsensitiveStringMap addSnapshotId( + CaseInsensitiveStringMap options, Long snapshotId) { + if (snapshotId != null) { + String snapshotIdFromOptions = options.get(SparkReadOptions.SNAPSHOT_ID); + String value = snapshotId.toString(); + Preconditions.checkArgument( + snapshotIdFromOptions == null || snapshotIdFromOptions.equals(value), + "Cannot override snapshot ID more than once: %s", + snapshotIdFromOptions); + + Map scanOptions = Maps.newHashMap(); + scanOptions.putAll(options.asCaseSensitiveMap()); + scanOptions.put(SparkReadOptions.SNAPSHOT_ID, value); + scanOptions.remove(SparkReadOptions.AS_OF_TIMESTAMP); + scanOptions.remove(SparkReadOptions.BRANCH); + scanOptions.remove(SparkReadOptions.TAG); + + return new CaseInsensitiveStringMap(scanOptions); + } + + return options; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java new file mode 100644 index 000000000000..a080fcead13b --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java @@ -0,0 +1,776 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.IsolationLevel.SERIALIZABLE; +import static org.apache.iceberg.IsolationLevel.SNAPSHOT; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.OverwriteFiles; +import org.apache.iceberg.PartitionKey; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.ReplacePartitions; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.SnapshotUpdate; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.ClusteredDataWriter; +import org.apache.iceberg.io.DataWriteResult; +import org.apache.iceberg.io.FanoutDataWriter; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.FileWriter; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.io.PartitioningWriter; +import org.apache.iceberg.io.RollingDataWriter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.CommitMetadata; +import org.apache.iceberg.spark.FileRewriteCoordinator; +import org.apache.iceberg.spark.SparkWriteConf; +import org.apache.spark.TaskContext; +import org.apache.spark.TaskContext$; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.executor.OutputMetrics; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.connector.write.BatchWrite; +import org.apache.spark.sql.connector.write.DataWriter; +import org.apache.spark.sql.connector.write.DataWriterFactory; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.PhysicalWriteInfo; +import org.apache.spark.sql.connector.write.RequiresDistributionAndOrdering; +import org.apache.spark.sql.connector.write.Write; +import org.apache.spark.sql.connector.write.WriterCommitMessage; +import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory; +import org.apache.spark.sql.connector.write.streaming.StreamingWrite; +import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +abstract class SparkWrite implements Write, RequiresDistributionAndOrdering { + private static final Logger LOG = LoggerFactory.getLogger(SparkWrite.class); + + private final JavaSparkContext sparkContext; + private final SparkWriteConf writeConf; + private final Table table; + private final String queryId; + private final FileFormat format; + private final String applicationId; + private final boolean wapEnabled; + private final String wapId; + private final int outputSpecId; + private final String branch; + private final long 
targetFileSize; + private final Schema writeSchema; + private final StructType dsSchema; + private final Map extraSnapshotMetadata; + private final boolean partitionedFanoutEnabled; + private final Distribution requiredDistribution; + private final SortOrder[] requiredOrdering; + + private boolean cleanupOnAbort = true; + + SparkWrite( + SparkSession spark, + Table table, + SparkWriteConf writeConf, + LogicalWriteInfo writeInfo, + String applicationId, + Schema writeSchema, + StructType dsSchema, + Distribution requiredDistribution, + SortOrder[] requiredOrdering) { + this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + this.table = table; + this.writeConf = writeConf; + this.queryId = writeInfo.queryId(); + this.format = writeConf.dataFileFormat(); + this.applicationId = applicationId; + this.wapEnabled = writeConf.wapEnabled(); + this.wapId = writeConf.wapId(); + this.branch = writeConf.branch(); + this.targetFileSize = writeConf.targetDataFileSize(); + this.writeSchema = writeSchema; + this.dsSchema = dsSchema; + this.extraSnapshotMetadata = writeConf.extraSnapshotMetadata(); + this.partitionedFanoutEnabled = writeConf.fanoutWriterEnabled(); + this.requiredDistribution = requiredDistribution; + this.requiredOrdering = requiredOrdering; + this.outputSpecId = writeConf.outputSpecId(); + } + + @Override + public Distribution requiredDistribution() { + return requiredDistribution; + } + + @Override + public SortOrder[] requiredOrdering() { + return requiredOrdering; + } + + BatchWrite asBatchAppend() { + return new BatchAppend(); + } + + BatchWrite asDynamicOverwrite() { + return new DynamicOverwrite(); + } + + BatchWrite asOverwriteByFilter(Expression overwriteExpr) { + return new OverwriteByFilter(overwriteExpr); + } + + BatchWrite asCopyOnWriteOperation(SparkCopyOnWriteScan scan, IsolationLevel isolationLevel) { + return new CopyOnWriteOperation(scan, isolationLevel); + } + + BatchWrite asRewrite(String fileSetID) { + return new RewriteFiles(fileSetID); + } + + StreamingWrite asStreamingAppend() { + return new StreamingAppend(); + } + + StreamingWrite asStreamingOverwrite() { + return new StreamingOverwrite(); + } + + // the writer factory works for both batch and streaming + private WriterFactory createWriterFactory() { + // broadcast the table metadata as the writer factory will be sent to executors + Broadcast
tableBroadcast = + sparkContext.broadcast(SerializableTableWithSize.copyOf(table)); + return new WriterFactory( + tableBroadcast, + queryId, + format, + outputSpecId, + targetFileSize, + writeSchema, + dsSchema, + partitionedFanoutEnabled); + } + + private void commitOperation(SnapshotUpdate operation, String description) { + LOG.info("Committing {} to table {}", description, table); + if (applicationId != null) { + operation.set("spark.app.id", applicationId); + } + + if (!extraSnapshotMetadata.isEmpty()) { + extraSnapshotMetadata.forEach(operation::set); + } + + if (!CommitMetadata.commitProperties().isEmpty()) { + CommitMetadata.commitProperties().forEach(operation::set); + } + + if (wapEnabled && wapId != null) { + // write-audit-publish is enabled for this table and job + // stage the changes without changing the current snapshot + operation.set(SnapshotSummary.STAGED_WAP_ID_PROP, wapId); + operation.stageOnly(); + } + + if (branch != null) { + operation.toBranch(branch); + } + + try { + long start = System.currentTimeMillis(); + operation.commit(); // abort is automatically called if this fails + long duration = System.currentTimeMillis() - start; + LOG.info("Committed in {} ms", duration); + } catch (CommitStateUnknownException commitStateUnknownException) { + cleanupOnAbort = false; + throw commitStateUnknownException; + } + } + + private void abort(WriterCommitMessage[] messages) { + if (cleanupOnAbort) { + SparkCleanupUtil.deleteFiles("job abort", table.io(), files(messages)); + } else { + LOG.warn("Skipping cleanup of written files"); + } + } + + private List files(WriterCommitMessage[] messages) { + List files = Lists.newArrayList(); + + for (WriterCommitMessage message : messages) { + if (message != null) { + TaskCommit taskCommit = (TaskCommit) message; + files.addAll(Arrays.asList(taskCommit.files())); + } + } + + return files; + } + + @Override + public String toString() { + return String.format("IcebergWrite(table=%s, format=%s)", table, format); + } + + private abstract class BaseBatchWrite implements BatchWrite { + @Override + public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) { + return createWriterFactory(); + } + + @Override + public void abort(WriterCommitMessage[] messages) { + SparkWrite.this.abort(messages); + } + + @Override + public String toString() { + return String.format("IcebergBatchWrite(table=%s, format=%s)", table, format); + } + } + + private class BatchAppend extends BaseBatchWrite { + @Override + public void commit(WriterCommitMessage[] messages) { + AppendFiles append = table.newAppend(); + + int numFiles = 0; + for (DataFile file : files(messages)) { + numFiles += 1; + append.appendFile(file); + } + + commitOperation(append, String.format("append with %d new data files", numFiles)); + } + } + + private class DynamicOverwrite extends BaseBatchWrite { + @Override + public void commit(WriterCommitMessage[] messages) { + List files = files(messages); + + if (files.isEmpty()) { + LOG.info("Dynamic overwrite is empty, skipping commit"); + return; + } + + ReplacePartitions dynamicOverwrite = table.newReplacePartitions(); + IsolationLevel isolationLevel = writeConf.isolationLevel(); + Long validateFromSnapshotId = writeConf.validateFromSnapshotId(); + + if (isolationLevel != null && validateFromSnapshotId != null) { + dynamicOverwrite.validateFromSnapshot(validateFromSnapshotId); + } + + if (isolationLevel == SERIALIZABLE) { + dynamicOverwrite.validateNoConflictingData(); + dynamicOverwrite.validateNoConflictingDeletes(); + + } else 
if (isolationLevel == SNAPSHOT) { + dynamicOverwrite.validateNoConflictingDeletes(); + } + + int numFiles = 0; + for (DataFile file : files) { + numFiles += 1; + dynamicOverwrite.addFile(file); + } + + commitOperation( + dynamicOverwrite, + String.format("dynamic partition overwrite with %d new data files", numFiles)); + } + } + + private class OverwriteByFilter extends BaseBatchWrite { + private final Expression overwriteExpr; + + private OverwriteByFilter(Expression overwriteExpr) { + this.overwriteExpr = overwriteExpr; + } + + @Override + public void commit(WriterCommitMessage[] messages) { + OverwriteFiles overwriteFiles = table.newOverwrite(); + overwriteFiles.overwriteByRowFilter(overwriteExpr); + + int numFiles = 0; + for (DataFile file : files(messages)) { + numFiles += 1; + overwriteFiles.addFile(file); + } + + IsolationLevel isolationLevel = writeConf.isolationLevel(); + Long validateFromSnapshotId = writeConf.validateFromSnapshotId(); + + if (isolationLevel != null && validateFromSnapshotId != null) { + overwriteFiles.validateFromSnapshot(validateFromSnapshotId); + } + + if (isolationLevel == SERIALIZABLE) { + overwriteFiles.validateNoConflictingDeletes(); + overwriteFiles.validateNoConflictingData(); + + } else if (isolationLevel == SNAPSHOT) { + overwriteFiles.validateNoConflictingDeletes(); + } + + String commitMsg = + String.format("overwrite by filter %s with %d new data files", overwriteExpr, numFiles); + commitOperation(overwriteFiles, commitMsg); + } + } + + private class CopyOnWriteOperation extends BaseBatchWrite { + private final SparkCopyOnWriteScan scan; + private final IsolationLevel isolationLevel; + + private CopyOnWriteOperation(SparkCopyOnWriteScan scan, IsolationLevel isolationLevel) { + this.scan = scan; + this.isolationLevel = isolationLevel; + } + + private List overwrittenFiles() { + if (scan == null) { + return ImmutableList.of(); + } else { + return scan.tasks().stream().map(FileScanTask::file).collect(Collectors.toList()); + } + } + + private Expression conflictDetectionFilter() { + // the list of filter expressions may be empty but is never null + List scanFilterExpressions = scan.filterExpressions(); + + Expression filter = Expressions.alwaysTrue(); + + for (Expression expr : scanFilterExpressions) { + filter = Expressions.and(filter, expr); + } + + return filter; + } + + @Override + public void commit(WriterCommitMessage[] messages) { + OverwriteFiles overwriteFiles = table.newOverwrite(); + + List overwrittenFiles = overwrittenFiles(); + int numOverwrittenFiles = overwrittenFiles.size(); + for (DataFile overwrittenFile : overwrittenFiles) { + overwriteFiles.deleteFile(overwrittenFile); + } + + int numAddedFiles = 0; + for (DataFile file : files(messages)) { + numAddedFiles += 1; + overwriteFiles.addFile(file); + } + + // the scan may be null if the optimizer replaces it with an empty relation (e.g. 
false cond) + // no validation is needed in this case as the command does not depend on the table state + if (scan != null) { + if (isolationLevel == SERIALIZABLE) { + commitWithSerializableIsolation(overwriteFiles, numOverwrittenFiles, numAddedFiles); + } else if (isolationLevel == SNAPSHOT) { + commitWithSnapshotIsolation(overwriteFiles, numOverwrittenFiles, numAddedFiles); + } else { + throw new IllegalArgumentException("Unsupported isolation level: " + isolationLevel); + } + + } else { + commitOperation( + overwriteFiles, + String.format("overwrite with %d new data files (no validation)", numAddedFiles)); + } + } + + private void commitWithSerializableIsolation( + OverwriteFiles overwriteFiles, int numOverwrittenFiles, int numAddedFiles) { + Long scanSnapshotId = scan.snapshotId(); + if (scanSnapshotId != null) { + overwriteFiles.validateFromSnapshot(scanSnapshotId); + } + + Expression conflictDetectionFilter = conflictDetectionFilter(); + overwriteFiles.conflictDetectionFilter(conflictDetectionFilter); + overwriteFiles.validateNoConflictingData(); + overwriteFiles.validateNoConflictingDeletes(); + + String commitMsg = + String.format( + "overwrite of %d data files with %d new data files, scanSnapshotId: %d, conflictDetectionFilter: %s", + numOverwrittenFiles, numAddedFiles, scanSnapshotId, conflictDetectionFilter); + commitOperation(overwriteFiles, commitMsg); + } + + private void commitWithSnapshotIsolation( + OverwriteFiles overwriteFiles, int numOverwrittenFiles, int numAddedFiles) { + Long scanSnapshotId = scan.snapshotId(); + if (scanSnapshotId != null) { + overwriteFiles.validateFromSnapshot(scanSnapshotId); + } + + Expression conflictDetectionFilter = conflictDetectionFilter(); + overwriteFiles.conflictDetectionFilter(conflictDetectionFilter); + overwriteFiles.validateNoConflictingDeletes(); + + String commitMsg = + String.format( + "overwrite of %d data files with %d new data files", + numOverwrittenFiles, numAddedFiles); + commitOperation(overwriteFiles, commitMsg); + } + } + + private class RewriteFiles extends BaseBatchWrite { + private final String fileSetID; + + private RewriteFiles(String fileSetID) { + this.fileSetID = fileSetID; + } + + @Override + public void commit(WriterCommitMessage[] messages) { + FileRewriteCoordinator coordinator = FileRewriteCoordinator.get(); + coordinator.stageRewrite(table, fileSetID, ImmutableSet.copyOf(files(messages))); + } + } + + private abstract class BaseStreamingWrite implements StreamingWrite { + private static final String QUERY_ID_PROPERTY = "spark.sql.streaming.queryId"; + private static final String EPOCH_ID_PROPERTY = "spark.sql.streaming.epochId"; + + protected abstract String mode(); + + @Override + public StreamingDataWriterFactory createStreamingWriterFactory(PhysicalWriteInfo info) { + return createWriterFactory(); + } + + @Override + public final void commit(long epochId, WriterCommitMessage[] messages) { + LOG.info("Committing epoch {} for query {} in {} mode", epochId, queryId, mode()); + + table.refresh(); + + Long lastCommittedEpochId = findLastCommittedEpochId(); + if (lastCommittedEpochId != null && epochId <= lastCommittedEpochId) { + LOG.info("Skipping epoch {} for query {} as it was already committed", epochId, queryId); + return; + } + + doCommit(epochId, messages); + } + + protected abstract void doCommit(long epochId, WriterCommitMessage[] messages); + + protected void commit(SnapshotUpdate snapshotUpdate, long epochId, String description) { + snapshotUpdate.set(QUERY_ID_PROPERTY, queryId); + 
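+      // also record the epoch ID in the snapshot summary so findLastCommittedEpochId() can skip epochs that were already committed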
snapshotUpdate.set(EPOCH_ID_PROPERTY, Long.toString(epochId)); + commitOperation(snapshotUpdate, description); + } + + private Long findLastCommittedEpochId() { + Snapshot snapshot = table.currentSnapshot(); + Long lastCommittedEpochId = null; + while (snapshot != null) { + Map summary = snapshot.summary(); + String snapshotQueryId = summary.get(QUERY_ID_PROPERTY); + if (queryId.equals(snapshotQueryId)) { + lastCommittedEpochId = Long.valueOf(summary.get(EPOCH_ID_PROPERTY)); + break; + } + Long parentSnapshotId = snapshot.parentId(); + snapshot = parentSnapshotId != null ? table.snapshot(parentSnapshotId) : null; + } + return lastCommittedEpochId; + } + + @Override + public void abort(long epochId, WriterCommitMessage[] messages) { + SparkWrite.this.abort(messages); + } + + @Override + public String toString() { + return String.format("IcebergStreamingWrite(table=%s, format=%s)", table, format); + } + } + + private class StreamingAppend extends BaseStreamingWrite { + @Override + protected String mode() { + return "append"; + } + + @Override + protected void doCommit(long epochId, WriterCommitMessage[] messages) { + AppendFiles append = table.newFastAppend(); + int numFiles = 0; + for (DataFile file : files(messages)) { + append.appendFile(file); + numFiles++; + } + commit(append, epochId, String.format("streaming append with %d new data files", numFiles)); + } + } + + private class StreamingOverwrite extends BaseStreamingWrite { + @Override + protected String mode() { + return "complete"; + } + + @Override + public void doCommit(long epochId, WriterCommitMessage[] messages) { + OverwriteFiles overwriteFiles = table.newOverwrite(); + overwriteFiles.overwriteByRowFilter(Expressions.alwaysTrue()); + int numFiles = 0; + for (DataFile file : files(messages)) { + overwriteFiles.addFile(file); + numFiles++; + } + commit( + overwriteFiles, + epochId, + String.format("streaming complete overwrite with %d new data files", numFiles)); + } + } + + public static class TaskCommit implements WriterCommitMessage { + private final DataFile[] taskFiles; + + TaskCommit(DataFile[] taskFiles) { + this.taskFiles = taskFiles; + } + + // Reports bytesWritten and recordsWritten to the Spark output metrics. + // Can only be called in executor. + void reportOutputMetrics() { + long bytesWritten = 0L; + long recordsWritten = 0L; + for (DataFile dataFile : taskFiles) { + bytesWritten += dataFile.fileSizeInBytes(); + recordsWritten += dataFile.recordCount(); + } + + TaskContext taskContext = TaskContext$.MODULE$.get(); + if (taskContext != null) { + OutputMetrics outputMetrics = taskContext.taskMetrics().outputMetrics(); + outputMetrics.setBytesWritten(bytesWritten); + outputMetrics.setRecordsWritten(recordsWritten); + } + } + + DataFile[] files() { + return taskFiles; + } + } + + private static class WriterFactory implements DataWriterFactory, StreamingDataWriterFactory { + private final Broadcast
tableBroadcast; + private final FileFormat format; + private final int outputSpecId; + private final long targetFileSize; + private final Schema writeSchema; + private final StructType dsSchema; + private final boolean partitionedFanoutEnabled; + private final String queryId; + + protected WriterFactory( + Broadcast
tableBroadcast, + String queryId, + FileFormat format, + int outputSpecId, + long targetFileSize, + Schema writeSchema, + StructType dsSchema, + boolean partitionedFanoutEnabled) { + this.tableBroadcast = tableBroadcast; + this.format = format; + this.outputSpecId = outputSpecId; + this.targetFileSize = targetFileSize; + this.writeSchema = writeSchema; + this.dsSchema = dsSchema; + this.partitionedFanoutEnabled = partitionedFanoutEnabled; + this.queryId = queryId; + } + + @Override + public DataWriter createWriter(int partitionId, long taskId) { + return createWriter(partitionId, taskId, 0); + } + + @Override + public DataWriter createWriter(int partitionId, long taskId, long epochId) { + Table table = tableBroadcast.value(); + PartitionSpec spec = table.specs().get(outputSpecId); + FileIO io = table.io(); + + OutputFileFactory fileFactory = + OutputFileFactory.builderFor(table, partitionId, taskId) + .format(format) + .operationId(queryId) + .build(); + SparkFileWriterFactory writerFactory = + SparkFileWriterFactory.builderFor(table) + .dataFileFormat(format) + .dataSchema(writeSchema) + .dataSparkType(dsSchema) + .build(); + + if (spec.isUnpartitioned()) { + return new UnpartitionedDataWriter(writerFactory, fileFactory, io, spec, targetFileSize); + + } else { + return new PartitionedDataWriter( + writerFactory, + fileFactory, + io, + spec, + writeSchema, + dsSchema, + targetFileSize, + partitionedFanoutEnabled); + } + } + } + + private static class UnpartitionedDataWriter implements DataWriter { + private final FileWriter delegate; + private final FileIO io; + + private UnpartitionedDataWriter( + SparkFileWriterFactory writerFactory, + OutputFileFactory fileFactory, + FileIO io, + PartitionSpec spec, + long targetFileSize) { + this.delegate = + new RollingDataWriter<>(writerFactory, fileFactory, io, targetFileSize, spec, null); + this.io = io; + } + + @Override + public void write(InternalRow record) throws IOException { + delegate.write(record); + } + + @Override + public WriterCommitMessage commit() throws IOException { + close(); + + DataWriteResult result = delegate.result(); + TaskCommit taskCommit = new TaskCommit(result.dataFiles().toArray(new DataFile[0])); + taskCommit.reportOutputMetrics(); + return taskCommit; + } + + @Override + public void abort() throws IOException { + close(); + + DataWriteResult result = delegate.result(); + SparkCleanupUtil.deleteTaskFiles(io, result.dataFiles()); + } + + @Override + public void close() throws IOException { + delegate.close(); + } + } + + private static class PartitionedDataWriter implements DataWriter { + private final PartitioningWriter delegate; + private final FileIO io; + private final PartitionSpec spec; + private final PartitionKey partitionKey; + private final InternalRowWrapper internalRowWrapper; + + private PartitionedDataWriter( + SparkFileWriterFactory writerFactory, + OutputFileFactory fileFactory, + FileIO io, + PartitionSpec spec, + Schema dataSchema, + StructType dataSparkType, + long targetFileSize, + boolean fanoutEnabled) { + if (fanoutEnabled) { + this.delegate = new FanoutDataWriter<>(writerFactory, fileFactory, io, targetFileSize); + } else { + this.delegate = new ClusteredDataWriter<>(writerFactory, fileFactory, io, targetFileSize); + } + this.io = io; + this.spec = spec; + this.partitionKey = new PartitionKey(spec, dataSchema); + this.internalRowWrapper = new InternalRowWrapper(dataSparkType); + } + + @Override + public void write(InternalRow row) throws IOException { + 
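+      // evaluate the partition key for this row before routing it to the partitioning writer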
partitionKey.partition(internalRowWrapper.wrap(row)); + delegate.write(row, spec, partitionKey); + } + + @Override + public WriterCommitMessage commit() throws IOException { + close(); + + DataWriteResult result = delegate.result(); + TaskCommit taskCommit = new TaskCommit(result.dataFiles().toArray(new DataFile[0])); + taskCommit.reportOutputMetrics(); + return taskCommit; + } + + @Override + public void abort() throws IOException { + close(); + + DataWriteResult result = delegate.result(); + SparkCleanupUtil.deleteTaskFiles(io, result.dataFiles()); + } + + @Override + public void close() throws IOException { + delegate.close(); + } + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java new file mode 100644 index 000000000000..133ca45b4603 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.UpdateSchema; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkDistributionAndOrderingUtil; +import org.apache.iceberg.spark.SparkFilters; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.spark.SparkWriteConf; +import org.apache.iceberg.types.TypeUtil; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.distributions.Distributions; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.write.BatchWrite; +import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.RowLevelOperation.Command; +import org.apache.spark.sql.connector.write.SupportsDynamicOverwrite; +import org.apache.spark.sql.connector.write.SupportsOverwrite; +import org.apache.spark.sql.connector.write.Write; +import org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.connector.write.streaming.StreamingWrite; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class SparkWriteBuilder implements WriteBuilder, SupportsDynamicOverwrite, SupportsOverwrite { + private static final Logger LOG = LoggerFactory.getLogger(SparkWriteBuilder.class); + private static final SortOrder[] NO_ORDERING = new SortOrder[0]; + + private final SparkSession spark; + private final Table table; + private final SparkWriteConf writeConf; + private final LogicalWriteInfo writeInfo; + private final StructType dsSchema; + private final String overwriteMode; + private final String rewrittenFileSetId; + private final boolean handleTimestampWithoutZone; + private final boolean useTableDistributionAndOrdering; + private boolean overwriteDynamic = false; + private boolean overwriteByFilter = false; + private Expression overwriteExpr = null; + private boolean overwriteFiles = false; + private SparkCopyOnWriteScan copyOnWriteScan = null; + private Command copyOnWriteCommand = null; + private IsolationLevel copyOnWriteIsolationLevel = null; + + SparkWriteBuilder(SparkSession spark, Table table, String branch, LogicalWriteInfo info) { + this.spark = spark; + this.table = table; + this.writeConf = new SparkWriteConf(spark, table, branch, info.options()); + this.writeInfo = info; + this.dsSchema = info.schema(); + this.overwriteMode = writeConf.overwriteMode(); + this.rewrittenFileSetId = writeConf.rewrittenFileSetId(); + this.handleTimestampWithoutZone = writeConf.handleTimestampWithoutZone(); + this.useTableDistributionAndOrdering = writeConf.useTableDistributionAndOrdering(); + } + + public WriteBuilder overwriteFiles(Scan scan, Command command, IsolationLevel isolationLevel) { + Preconditions.checkState(!overwriteByFilter, "Cannot overwrite individual files and by filter"); + Preconditions.checkState( + !overwriteDynamic, "Cannot overwrite individual 
files and dynamically"); + Preconditions.checkState( + rewrittenFileSetId == null, "Cannot overwrite individual files and rewrite"); + + this.overwriteFiles = true; + this.copyOnWriteScan = (SparkCopyOnWriteScan) scan; + this.copyOnWriteCommand = command; + this.copyOnWriteIsolationLevel = isolationLevel; + return this; + } + + @Override + public WriteBuilder overwriteDynamicPartitions() { + Preconditions.checkState( + !overwriteByFilter, "Cannot overwrite dynamically and by filter: %s", overwriteExpr); + Preconditions.checkState(!overwriteFiles, "Cannot overwrite individual files and dynamically"); + Preconditions.checkState( + rewrittenFileSetId == null, "Cannot overwrite dynamically and rewrite"); + + this.overwriteDynamic = true; + return this; + } + + @Override + public WriteBuilder overwrite(Filter[] filters) { + Preconditions.checkState( + !overwriteFiles, "Cannot overwrite individual files and using filters"); + Preconditions.checkState(rewrittenFileSetId == null, "Cannot overwrite and rewrite"); + + this.overwriteExpr = SparkFilters.convert(filters); + if (overwriteExpr == Expressions.alwaysTrue() && "dynamic".equals(overwriteMode)) { + // use the write option to override truncating the table. use dynamic overwrite instead. + this.overwriteDynamic = true; + } else { + Preconditions.checkState( + !overwriteDynamic, "Cannot overwrite dynamically and by filter: %s", overwriteExpr); + this.overwriteByFilter = true; + } + return this; + } + + @Override + public Write build() { + // Validate + Preconditions.checkArgument( + handleTimestampWithoutZone || !SparkUtil.hasTimestampWithoutZone(table.schema()), + SparkUtil.TIMESTAMP_WITHOUT_TIMEZONE_ERROR); + + Schema writeSchema = validateOrMergeWriteSchema(table, dsSchema, writeConf); + SparkUtil.validatePartitionTransforms(table.spec()); + + // Get application id + String appId = spark.sparkContext().applicationId(); + + Distribution distribution; + SortOrder[] ordering; + + if (useTableDistributionAndOrdering) { + if (Spark3Util.extensionsEnabled(spark) || allIdentityTransforms(table.spec())) { + distribution = buildRequiredDistribution(); + ordering = buildRequiredOrdering(distribution); + } else { + LOG.warn( + "Skipping distribution/ordering: extensions are disabled and spec contains unsupported transforms"); + distribution = Distributions.unspecified(); + ordering = NO_ORDERING; + } + } else { + LOG.info("Skipping distribution/ordering: disabled per job configuration"); + distribution = Distributions.unspecified(); + ordering = NO_ORDERING; + } + + return new SparkWrite( + spark, table, writeConf, writeInfo, appId, writeSchema, dsSchema, distribution, ordering) { + + @Override + public BatchWrite toBatch() { + if (rewrittenFileSetId != null) { + return asRewrite(rewrittenFileSetId); + } else if (overwriteByFilter) { + return asOverwriteByFilter(overwriteExpr); + } else if (overwriteDynamic) { + return asDynamicOverwrite(); + } else if (overwriteFiles) { + return asCopyOnWriteOperation(copyOnWriteScan, copyOnWriteIsolationLevel); + } else { + return asBatchAppend(); + } + } + + @Override + public StreamingWrite toStreaming() { + Preconditions.checkState( + !overwriteDynamic, "Unsupported streaming operation: dynamic partition overwrite"); + Preconditions.checkState( + !overwriteByFilter || overwriteExpr == Expressions.alwaysTrue(), + "Unsupported streaming operation: overwrite by filter: %s", + overwriteExpr); + Preconditions.checkState( + rewrittenFileSetId == null, "Unsupported streaming operation: rewrite"); + + if 
(overwriteByFilter) { + return asStreamingOverwrite(); + } else { + return asStreamingAppend(); + } + } + }; + } + + private Distribution buildRequiredDistribution() { + if (overwriteFiles) { + DistributionMode distributionMode = copyOnWriteDistributionMode(); + return SparkDistributionAndOrderingUtil.buildCopyOnWriteDistribution( + table, copyOnWriteCommand, distributionMode); + } else { + DistributionMode distributionMode = writeConf.distributionMode(); + return SparkDistributionAndOrderingUtil.buildRequiredDistribution(table, distributionMode); + } + } + + private DistributionMode copyOnWriteDistributionMode() { + switch (copyOnWriteCommand) { + case DELETE: + return writeConf.deleteDistributionMode(); + case UPDATE: + return writeConf.updateDistributionMode(); + case MERGE: + return writeConf.copyOnWriteMergeDistributionMode(); + default: + throw new IllegalArgumentException("Unexpected command: " + copyOnWriteCommand); + } + } + + private SortOrder[] buildRequiredOrdering(Distribution requiredDistribution) { + if (overwriteFiles) { + return SparkDistributionAndOrderingUtil.buildCopyOnWriteOrdering( + table, copyOnWriteCommand, requiredDistribution); + } else { + return SparkDistributionAndOrderingUtil.buildRequiredOrdering(table, requiredDistribution); + } + } + + private boolean allIdentityTransforms(PartitionSpec spec) { + return spec.fields().stream().allMatch(field -> field.transform().isIdentity()); + } + + private static Schema validateOrMergeWriteSchema( + Table table, StructType dsSchema, SparkWriteConf writeConf) { + Schema writeSchema; + boolean caseSensitive = writeConf.caseSensitive(); + if (writeConf.mergeSchema()) { + // convert the dataset schema and assign fresh ids for new fields + Schema newSchema = + SparkSchemaUtil.convertWithFreshIds(table.schema(), dsSchema, caseSensitive); + + // update the table to get final id assignments and validate the changes + UpdateSchema update = + table.updateSchema().caseSensitive(caseSensitive).unionByNameWith(newSchema); + Schema mergedSchema = update.apply(); + + // reconvert the dsSchema without assignment to use the ids assigned by UpdateSchema + writeSchema = SparkSchemaUtil.convert(mergedSchema, dsSchema, caseSensitive); + + TypeUtil.validateWriteSchema( + mergedSchema, writeSchema, writeConf.checkNullability(), writeConf.checkOrdering()); + + // if the validation passed, update the table schema + update.commit(); + } else { + writeSchema = SparkSchemaUtil.convert(table.schema(), dsSchema, caseSensitive); + TypeUtil.validateWriteSchema( + table.schema(), writeSchema, writeConf.checkNullability(), writeConf.checkOrdering()); + } + + return writeSchema; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/StagedSparkTable.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/StagedSparkTable.java new file mode 100644 index 000000000000..b92c02d2b536 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/StagedSparkTable.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
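The SparkWriteBuilder above only selects among append, dynamic overwrite, overwrite-by-filter, copy-on-write, and rewrite outputs; which path is taken is driven by the DataFrameWriterV2 calls made on the Spark side. A minimal usage sketch (the catalog, table, and column names are assumptions, not part of this change):

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.lit;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

class WriteModeExamples {
  static void write(Dataset<Row> df) throws Exception {
    // routed to overwriteDynamicPartitions(): replaces only the partitions present in df
    df.writeTo("local.db.events").overwritePartitions();

    // routed to overwrite(Filter[]): replaces rows matching the filter expression
    df.writeTo("local.db.events").overwrite(col("day").equalTo(lit("2023-04-01")));

    // plain append, handled by asBatchAppend()
    df.writeTo("local.db.events").append();
  }
}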
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.Transaction; +import org.apache.spark.sql.connector.catalog.StagedTable; + +public class StagedSparkTable extends SparkTable implements StagedTable { + private final Transaction transaction; + + public StagedSparkTable(Transaction transaction) { + super(transaction.table(), false); + this.transaction = transaction; + } + + @Override + public void commitStagedChanges() { + transaction.commitTransaction(); + } + + @Override + public void abortStagedChanges() { + // TODO: clean up + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/Stats.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/Stats.java new file mode 100644 index 000000000000..ddf6ca834d9b --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/Stats.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.OptionalLong; +import org.apache.spark.sql.connector.read.Statistics; + +class Stats implements Statistics { + private final OptionalLong sizeInBytes; + private final OptionalLong numRows; + + Stats(long sizeInBytes, long numRows) { + this.sizeInBytes = OptionalLong.of(sizeInBytes); + this.numRows = OptionalLong.of(numRows); + } + + @Override + public OptionalLong sizeInBytes() { + return sizeInBytes; + } + + @Override + public OptionalLong numRows() { + return numRows; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/StreamingOffset.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/StreamingOffset.java new file mode 100644 index 000000000000..f2088deb1ee3 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/StreamingOffset.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
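For reference, the Stats value object introduced above simply carries the totals that an Iceberg scan reports to Spark's planner (typically via SupportsReportStatistics). A tiny sketch, assumed to live in the same package because the class is package-private; the numbers are arbitrary:

package org.apache.iceberg.spark.source;

import org.apache.spark.sql.connector.read.Statistics;

public class StatsExample {
  public static void main(String[] args) {
    // 10 MiB of data spread over 50,000 rows (illustrative values only)
    Statistics stats = new Stats(10L * 1024 * 1024, 50_000L);
    System.out.println(stats.sizeInBytes().getAsLong()); // 10485760
    System.out.println(stats.numRows().getAsLong());     // 50000
  }
}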
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.JsonNode; +import java.io.IOException; +import java.io.InputStream; +import java.io.StringWriter; +import java.io.UncheckedIOException; +import org.apache.iceberg.relocated.com.google.common.base.Objects; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.util.JsonUtil; +import org.apache.spark.sql.connector.read.streaming.Offset; + +class StreamingOffset extends Offset { + static final StreamingOffset START_OFFSET = new StreamingOffset(-1L, -1, false); + + private static final int CURR_VERSION = 1; + private static final String VERSION = "version"; + private static final String SNAPSHOT_ID = "snapshot_id"; + private static final String POSITION = "position"; + private static final String SCAN_ALL_FILES = "scan_all_files"; + + private final long snapshotId; + private final long position; + private final boolean scanAllFiles; + + /** + * An implementation of Spark Structured Streaming Offset, to track the current processed files of + * Iceberg table. + * + * @param snapshotId The current processed snapshot id. + * @param position The position of last scanned file in snapshot. + * @param scanAllFiles whether to scan all files in a snapshot; for example, to read all data when + * starting a stream. 
+ */ + StreamingOffset(long snapshotId, long position, boolean scanAllFiles) { + this.snapshotId = snapshotId; + this.position = position; + this.scanAllFiles = scanAllFiles; + } + + static StreamingOffset fromJson(String json) { + Preconditions.checkNotNull(json, "Cannot parse StreamingOffset JSON: null"); + + try { + JsonNode node = JsonUtil.mapper().readValue(json, JsonNode.class); + return fromJsonNode(node); + } catch (IOException e) { + throw new UncheckedIOException( + String.format("Failed to parse StreamingOffset from JSON string %s", json), e); + } + } + + static StreamingOffset fromJson(InputStream inputStream) { + Preconditions.checkNotNull(inputStream, "Cannot parse StreamingOffset from inputStream: null"); + + JsonNode node; + try { + node = JsonUtil.mapper().readValue(inputStream, JsonNode.class); + } catch (IOException e) { + throw new UncheckedIOException("Failed to read StreamingOffset from json", e); + } + + return fromJsonNode(node); + } + + @Override + public String json() { + StringWriter writer = new StringWriter(); + try { + JsonGenerator generator = JsonUtil.factory().createGenerator(writer); + generator.writeStartObject(); + generator.writeNumberField(VERSION, CURR_VERSION); + generator.writeNumberField(SNAPSHOT_ID, snapshotId); + generator.writeNumberField(POSITION, position); + generator.writeBooleanField(SCAN_ALL_FILES, scanAllFiles); + generator.writeEndObject(); + generator.flush(); + + } catch (IOException e) { + throw new UncheckedIOException("Failed to write StreamingOffset to json", e); + } + + return writer.toString(); + } + + long snapshotId() { + return snapshotId; + } + + long position() { + return position; + } + + boolean shouldScanAllFiles() { + return scanAllFiles; + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof StreamingOffset) { + StreamingOffset offset = (StreamingOffset) obj; + return offset.snapshotId == snapshotId + && offset.position == position + && offset.scanAllFiles == scanAllFiles; + } else { + return false; + } + } + + @Override + public int hashCode() { + return Objects.hashCode(snapshotId, position, scanAllFiles); + } + + @Override + public String toString() { + return String.format( + "Streaming Offset[%d: position (%d) scan_all_files (%b)]", + snapshotId, position, scanAllFiles); + } + + private static StreamingOffset fromJsonNode(JsonNode node) { + // The version of StreamingOffset. The offset was created with a version number + // used to validate when deserializing from json string. + int version = JsonUtil.getInt(VERSION, node); + Preconditions.checkArgument( + version == CURR_VERSION, + "This version of Iceberg source only supports version %s. Version %s is not supported.", + CURR_VERSION, + version); + + long snapshotId = JsonUtil.getLong(SNAPSHOT_ID, node); + int position = JsonUtil.getInt(POSITION, node); + boolean shouldScanAllFiles = JsonUtil.getBool(SCAN_ALL_FILES, node); + + return new StreamingOffset(snapshotId, position, shouldScanAllFiles); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/StructInternalRow.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/StructInternalRow.java new file mode 100644 index 000000000000..f67013f8c457 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/StructInternalRow.java @@ -0,0 +1,379 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
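Structured Streaming persists the offset above as JSON, so a useful sanity check is that json() and fromJson() round-trip. A minimal sketch, placed in the same package since the class is package-private; the snapshot id and position are arbitrary:

package org.apache.iceberg.spark.source;

public class StreamingOffsetRoundTrip {
  public static void main(String[] args) {
    StreamingOffset offset = new StreamingOffset(8109378978677716719L, 12, false);

    // produces {"version":1,"snapshot_id":8109378978677716719,"position":12,"scan_all_files":false}
    String json = offset.json();

    StreamingOffset restored = StreamingOffset.fromJson(json);
    System.out.println(offset.equals(restored)); // true
  }
}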
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.time.Instant; +import java.time.LocalDate; +import java.time.OffsetDateTime; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.BiConsumer; +import java.util.function.Function; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.ByteBuffers; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +class StructInternalRow extends InternalRow { + private final Types.StructType type; + private StructLike struct; + + StructInternalRow(Types.StructType type) { + this.type = type; + } + + private StructInternalRow(Types.StructType type, StructLike struct) { + this.type = type; + this.struct = struct; + } + + public StructInternalRow setStruct(StructLike newStruct) { + this.struct = newStruct; + return this; + } + + @Override + public int numFields() { + return struct.size(); + } + + @Override + public void setNullAt(int i) { + throw new UnsupportedOperationException("StructInternalRow is read-only"); + } + + @Override + public void update(int i, Object value) { + throw new UnsupportedOperationException("StructInternalRow is read-only"); + } + + @Override + public InternalRow copy() { + return this; + } + + @Override + public boolean isNullAt(int ordinal) { + return struct.get(ordinal, Object.class) == null; + } + + @Override + public boolean 
getBoolean(int ordinal) { + return struct.get(ordinal, Boolean.class); + } + + @Override + public byte getByte(int ordinal) { + return (byte) (int) struct.get(ordinal, Integer.class); + } + + @Override + public short getShort(int ordinal) { + return (short) (int) struct.get(ordinal, Integer.class); + } + + @Override + public int getInt(int ordinal) { + Object integer = struct.get(ordinal, Object.class); + + if (integer instanceof Integer) { + return (int) integer; + } else if (integer instanceof LocalDate) { + return (int) ((LocalDate) integer).toEpochDay(); + } else { + throw new IllegalStateException( + "Unknown type for int field. Type name: " + integer.getClass().getName()); + } + } + + @Override + public long getLong(int ordinal) { + Object longVal = struct.get(ordinal, Object.class); + + if (longVal instanceof Long) { + return (long) longVal; + } else if (longVal instanceof OffsetDateTime) { + return Duration.between(Instant.EPOCH, (OffsetDateTime) longVal).toNanos() / 1000; + } else if (longVal instanceof LocalDate) { + return ((LocalDate) longVal).toEpochDay(); + } else { + throw new IllegalStateException( + "Unknown type for long field. Type name: " + longVal.getClass().getName()); + } + } + + @Override + public float getFloat(int ordinal) { + return struct.get(ordinal, Float.class); + } + + @Override + public double getDouble(int ordinal) { + return struct.get(ordinal, Double.class); + } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + return isNullAt(ordinal) ? null : getDecimalInternal(ordinal, precision, scale); + } + + private Decimal getDecimalInternal(int ordinal, int precision, int scale) { + return Decimal.apply(struct.get(ordinal, BigDecimal.class)); + } + + @Override + public UTF8String getUTF8String(int ordinal) { + return isNullAt(ordinal) ? null : getUTF8StringInternal(ordinal); + } + + private UTF8String getUTF8StringInternal(int ordinal) { + CharSequence seq = struct.get(ordinal, CharSequence.class); + return UTF8String.fromString(seq.toString()); + } + + @Override + public byte[] getBinary(int ordinal) { + return isNullAt(ordinal) ? null : getBinaryInternal(ordinal); + } + + private byte[] getBinaryInternal(int ordinal) { + Object bytes = struct.get(ordinal, Object.class); + + // should only be either ByteBuffer or byte[] + if (bytes instanceof ByteBuffer) { + return ByteBuffers.toByteArray((ByteBuffer) bytes); + } else if (bytes instanceof byte[]) { + return (byte[]) bytes; + } else { + throw new IllegalStateException( + "Unknown type for binary field. Type name: " + bytes.getClass().getName()); + } + } + + @Override + public CalendarInterval getInterval(int ordinal) { + throw new UnsupportedOperationException("Unsupported type: interval"); + } + + @Override + public InternalRow getStruct(int ordinal, int numFields) { + return isNullAt(ordinal) ? null : getStructInternal(ordinal, numFields); + } + + private InternalRow getStructInternal(int ordinal, int numFields) { + return new StructInternalRow( + type.fields().get(ordinal).type().asStructType(), struct.get(ordinal, StructLike.class)); + } + + @Override + public ArrayData getArray(int ordinal) { + return isNullAt(ordinal) ? null : getArrayInternal(ordinal); + } + + private ArrayData getArrayInternal(int ordinal) { + return collectionToArrayData( + type.fields().get(ordinal).type().asListType().elementType(), + struct.get(ordinal, Collection.class)); + } + + @Override + public MapData getMap(int ordinal) { + return isNullAt(ordinal) ? 
null : getMapInternal(ordinal);
+  }
+
+  private MapData getMapInternal(int ordinal) {
+    return mapToMapData(
+        type.fields().get(ordinal).type().asMapType(), struct.get(ordinal, Map.class));
+  }
+
+  @Override
+  @SuppressWarnings("checkstyle:CyclomaticComplexity")
+  public Object get(int ordinal, DataType dataType) {
+    if (isNullAt(ordinal)) {
+      return null;
+    }
+
+    if (dataType instanceof IntegerType) {
+      return getInt(ordinal);
+    } else if (dataType instanceof LongType) {
+      return getLong(ordinal);
+    } else if (dataType instanceof StringType) {
+      return getUTF8StringInternal(ordinal);
+    } else if (dataType instanceof FloatType) {
+      return getFloat(ordinal);
+    } else if (dataType instanceof DoubleType) {
+      return getDouble(ordinal);
+    } else if (dataType instanceof DecimalType) {
+      DecimalType decimalType = (DecimalType) dataType;
+      return getDecimalInternal(ordinal, decimalType.precision(), decimalType.scale());
+    } else if (dataType instanceof BinaryType) {
+      return getBinaryInternal(ordinal);
+    } else if (dataType instanceof StructType) {
+      return getStructInternal(ordinal, ((StructType) dataType).size());
+    } else if (dataType instanceof ArrayType) {
+      return getArrayInternal(ordinal);
+    } else if (dataType instanceof MapType) {
+      return getMapInternal(ordinal);
+    } else if (dataType instanceof BooleanType) {
+      return getBoolean(ordinal);
+    } else if (dataType instanceof ByteType) {
+      return getByte(ordinal);
+    } else if (dataType instanceof ShortType) {
+      return getShort(ordinal);
+    } else if (dataType instanceof DateType) {
+      return getInt(ordinal);
+    } else if (dataType instanceof TimestampType) {
+      return getLong(ordinal);
+    }
+    return null;
+  }
+
+  private MapData mapToMapData(Types.MapType mapType, Map<?, ?> map) {
+    // make a defensive copy to ensure entries do not change
+    List<Map.Entry<?, ?>> entries = ImmutableList.copyOf(map.entrySet());
+    return new ArrayBasedMapData(
+        collectionToArrayData(mapType.keyType(), Lists.transform(entries, Map.Entry::getKey)),
+        collectionToArrayData(mapType.valueType(), Lists.transform(entries, Map.Entry::getValue)));
+  }
+
+  private ArrayData collectionToArrayData(Type elementType, Collection<?> values) {
+    switch (elementType.typeId()) {
+      case BOOLEAN:
+      case INTEGER:
+      case DATE:
+      case TIME:
+      case LONG:
+      case TIMESTAMP:
+      case FLOAT:
+      case DOUBLE:
+        return fillArray(values, array -> (pos, value) -> array[pos] = value);
+      case STRING:
+        return fillArray(
+            values,
+            array ->
+                (BiConsumer<Integer, CharSequence>)
+                    (pos, seq) -> array[pos] = UTF8String.fromString(seq.toString()));
+      case FIXED:
+      case BINARY:
+        return fillArray(
+            values,
+            array ->
+                (BiConsumer<Integer, ByteBuffer>)
+                    (pos, buf) -> array[pos] = ByteBuffers.toByteArray(buf));
+      case DECIMAL:
+        return fillArray(
+            values,
+            array ->
+                (BiConsumer<Integer, BigDecimal>) (pos, dec) -> array[pos] = Decimal.apply(dec));
+      case STRUCT:
+        return fillArray(
+            values,
+            array ->
+                (BiConsumer<Integer, StructLike>)
+                    (pos, tuple) ->
+                        array[pos] = new StructInternalRow(elementType.asStructType(), tuple));
+      case LIST:
+        return fillArray(
+            values,
+            array ->
+                (BiConsumer<Integer, Collection<?>>)
+                    (pos, list) ->
+                        array[pos] =
+                            collectionToArrayData(elementType.asListType().elementType(), list));
+      case MAP:
+        return fillArray(
+            values,
+            array ->
+                (BiConsumer<Integer, Map<?, ?>>)
+                    (pos, map) -> array[pos] = mapToMapData(elementType.asMapType(), map));
+      default:
+        throw new UnsupportedOperationException("Unsupported array element type: " + elementType);
+    }
+  }
+
+  @SuppressWarnings("unchecked")
+  private <T> GenericArrayData fillArray(
+      Collection<?> values, Function<Object[], BiConsumer<Integer, T>> makeSetter) {
+    Object[] array = new Object[values.size()];
+
BiConsumer setter = makeSetter.apply(array); + + int index = 0; + for (Object value : values) { + if (value == null) { + array[index] = null; + } else { + setter.accept(index, (T) value); + } + + index += 1; + } + + return new GenericArrayData(array); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + + if (other == null || getClass() != other.getClass()) { + return false; + } + + StructInternalRow that = (StructInternalRow) other; + return type.equals(that.type) && struct.equals(that.struct); + } + + @Override + public int hashCode() { + return Objects.hash(type, struct); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/metrics/NumDeletes.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/metrics/NumDeletes.java new file mode 100644 index 000000000000..000499874ba5 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/metrics/NumDeletes.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import java.text.NumberFormat; +import org.apache.spark.sql.connector.metric.CustomMetric; + +public class NumDeletes implements CustomMetric { + + public static final String DISPLAY_STRING = "number of row deletes applied"; + + @Override + public String name() { + return "numDeletes"; + } + + @Override + public String description() { + return DISPLAY_STRING; + } + + @Override + public String aggregateTaskMetrics(long[] taskMetrics) { + long sum = initialValue; + for (long taskMetric : taskMetrics) { + sum += taskMetric; + } + + return NumberFormat.getIntegerInstance().format(sum); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/metrics/NumSplits.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/metrics/NumSplits.java new file mode 100644 index 000000000000..41d7c1e8db71 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/metrics/NumSplits.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
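To make the StructInternalRow adapter above concrete, here is a hedged sketch that wraps an Iceberg record and reads it back through Spark's InternalRow accessors. It assumes the iceberg-data module's GenericRecord is available as a concrete StructLike and sits in the same package because the class is package-private:

package org.apache.iceberg.spark.source;

import org.apache.iceberg.Schema;
import org.apache.iceberg.data.GenericRecord;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.catalyst.InternalRow;

public class StructInternalRowExample {
  public static void main(String[] args) {
    Schema schema =
        new Schema(
            Types.NestedField.required(1, "id", Types.LongType.get()),
            Types.NestedField.optional(2, "name", Types.StringType.get()));

    GenericRecord record = GenericRecord.create(schema);
    record.setField("id", 42L);
    record.setField("name", "iceberg");

    // the row is a view over the record; values are converted on access
    InternalRow row = new StructInternalRow(schema.asStruct()).setStruct(record);
    System.out.println(row.getLong(0));       // 42
    System.out.println(row.getUTF8String(1)); // iceberg
    System.out.println(row.isNullAt(1));      // false
  }
}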
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import java.text.NumberFormat; +import org.apache.spark.sql.connector.metric.CustomMetric; + +public class NumSplits implements CustomMetric { + + @Override + public String name() { + return "numSplits"; + } + + @Override + public String description() { + return "number of file splits read"; + } + + @Override + public String aggregateTaskMetrics(long[] taskMetrics) { + long sum = initialValue; + for (long taskMetric : taskMetrics) { + sum += taskMetric; + } + + return NumberFormat.getIntegerInstance().format(sum); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskNumDeletes.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskNumDeletes.java new file mode 100644 index 000000000000..8c734ba9f022 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskNumDeletes.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.spark.sql.connector.metric.CustomTaskMetric; + +public class TaskNumDeletes implements CustomTaskMetric { + private final long value; + + public TaskNumDeletes(long value) { + this.value = value; + } + + @Override + public String name() { + return "numDeletes"; + } + + @Override + public long value() { + return value; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskNumSplits.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskNumSplits.java new file mode 100644 index 000000000000..d8cbc4db05bb --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/metrics/TaskNumSplits.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
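These metric classes come in pairs: the Task* classes carry a per-task value back to the driver, and the driver-side CustomMetric with the matching name aggregates and formats the values for the SQL UI. An illustrative sketch (the values are arbitrary):

import org.apache.iceberg.spark.source.metrics.NumDeletes;
import org.apache.iceberg.spark.source.metrics.TaskNumDeletes;

class MetricsExample {
  public static void main(String[] args) {
    // reported from a task: 42 delete rows were applied
    TaskNumDeletes taskMetric = new TaskNumDeletes(42L);
    System.out.println(taskMetric.name() + " = " + taskMetric.value());

    // aggregated on the driver by matching metric name
    NumDeletes driverMetric = new NumDeletes();
    System.out.println(driverMetric.aggregateTaskMetrics(new long[] {42L, 58L})); // 100
  }
}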
+ */ +package org.apache.iceberg.spark.source.metrics; + +import org.apache.spark.sql.connector.metric.CustomTaskMetric; + +public class TaskNumSplits implements CustomTaskMetric { + private final long value; + + public TaskNumSplits(long value) { + this.value = value; + } + + @Override + public String name() { + return "numSplits"; + } + + @Override + public long value() { + return value; + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/spark/sql/catalyst/analysis/NoSuchProcedureException.java b/spark/v3.4/spark/src/main/java/org/apache/spark/sql/catalyst/analysis/NoSuchProcedureException.java new file mode 100644 index 000000000000..2a89ac73e2c6 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/spark/sql/catalyst/analysis/NoSuchProcedureException.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.catalyst.analysis; + +import org.apache.spark.QueryContext; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.connector.catalog.Identifier; +import scala.Option; +import scala.collection.immutable.Map$; + +public class NoSuchProcedureException extends AnalysisException { + public NoSuchProcedureException(Identifier ident) { + super( + "Procedure " + ident + " not found", + Option.empty(), + Option.empty(), + Option.empty(), + Option.empty(), + Option.empty(), + Map$.MODULE$.empty(), + new QueryContext[0]); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/Procedure.java b/spark/v3.4/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/Procedure.java new file mode 100644 index 000000000000..11f215ba040a --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/Procedure.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.connector.iceberg.catalog; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; + +/** An interface representing a stored procedure available for execution. */ +public interface Procedure { + /** Returns the input parameters of this procedure. */ + ProcedureParameter[] parameters(); + + /** Returns the type of rows produced by this procedure. */ + StructType outputType(); + + /** + * Executes this procedure. + * + *
<p>
Spark will align the provided arguments according to the input parameters defined in {@link + * #parameters()} either by position or by name before execution. + * + *
<p>
Implementations may provide a summary of execution by returning one or many rows as a + * result. The schema of output rows must match the defined output type in {@link #outputType()}. + * + * @param args input arguments + * @return the result of executing this procedure with the given arguments + */ + InternalRow[] call(InternalRow args); + + /** Returns the description of this procedure. */ + default String description() { + return this.getClass().toString(); + } +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureCatalog.java b/spark/v3.4/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureCatalog.java new file mode 100644 index 000000000000..2cee97ee5938 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureCatalog.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.connector.iceberg.catalog; + +import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.Identifier; + +/** + * A catalog API for working with stored procedures. + * + *
<p>
Implementations should implement this interface if they expose stored procedures that can be + * called via CALL statements. + */ +public interface ProcedureCatalog extends CatalogPlugin { + /** + * Load a {@link Procedure stored procedure} by {@link Identifier identifier}. + * + * @param ident a stored procedure identifier + * @return the stored procedure's metadata + * @throws NoSuchProcedureException if there is no matching stored procedure + */ + Procedure loadProcedure(Identifier ident) throws NoSuchProcedureException; +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureParameter.java b/spark/v3.4/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureParameter.java new file mode 100644 index 000000000000..e1e84b2597f3 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureParameter.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.connector.iceberg.catalog; + +import org.apache.spark.sql.types.DataType; + +/** An input parameter of a {@link Procedure stored procedure}. */ +public interface ProcedureParameter { + + /** + * Creates a required input parameter. + * + * @param name the name of the parameter + * @param dataType the type of the parameter + * @return the constructed stored procedure parameter + */ + static ProcedureParameter required(String name, DataType dataType) { + return new ProcedureParameterImpl(name, dataType, true); + } + + /** + * Creates an optional input parameter. + * + * @param name the name of the parameter. + * @param dataType the type of the parameter. + * @return the constructed optional stored procedure parameter + */ + static ProcedureParameter optional(String name, DataType dataType) { + return new ProcedureParameterImpl(name, dataType, false); + } + + /** Returns the name of this parameter. */ + String name(); + + /** Returns the type of this parameter. */ + DataType dataType(); + + /** Returns true if this parameter is required. */ + boolean required(); +} diff --git a/spark/v3.4/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureParameterImpl.java b/spark/v3.4/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureParameterImpl.java new file mode 100644 index 000000000000..c59951e24330 --- /dev/null +++ b/spark/v3.4/spark/src/main/java/org/apache/spark/sql/connector/iceberg/catalog/ProcedureParameterImpl.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
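As an illustration of how these interfaces fit together, here is a hedged sketch of a made-up procedure (not one shipped in this change) that declares its parameters, its output schema, and a call() implementation:

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.connector.iceberg.catalog.Procedure;
import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.UTF8String;

public class EchoProcedure implements Procedure {

  private static final ProcedureParameter[] PARAMETERS =
      new ProcedureParameter[] {
        ProcedureParameter.required("message", DataTypes.StringType),
        ProcedureParameter.optional("repeat", DataTypes.IntegerType)
      };

  private static final StructType OUTPUT_TYPE =
      new StructType(
          new StructField[] {
            new StructField("echoed", DataTypes.StringType, false, Metadata.empty())
          });

  @Override
  public ProcedureParameter[] parameters() {
    return PARAMETERS;
  }

  @Override
  public StructType outputType() {
    return OUTPUT_TYPE;
  }

  @Override
  public InternalRow[] call(InternalRow args) {
    // arguments arrive aligned to parameters(): position 0 is message, position 1 is repeat
    UTF8String message = args.getUTF8String(0);
    int repeat = args.isNullAt(1) ? 1 : args.getInt(1);

    InternalRow[] rows = new InternalRow[repeat];
    for (int i = 0; i < repeat; i++) {
      rows[i] = new GenericInternalRow(new Object[] {message});
    }
    return rows;
  }
}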
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.connector.iceberg.catalog; + +import java.util.Objects; +import org.apache.spark.sql.types.DataType; + +/** A {@link ProcedureParameter} implementation. */ +class ProcedureParameterImpl implements ProcedureParameter { + private final String name; + private final DataType dataType; + private final boolean required; + + ProcedureParameterImpl(String name, DataType dataType, boolean required) { + this.name = name; + this.dataType = dataType; + this.required = required; + } + + @Override + public String name() { + return name; + } + + @Override + public DataType dataType() { + return dataType; + } + + @Override + public boolean required() { + return required; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } else if (other == null || getClass() != other.getClass()) { + return false; + } + + ProcedureParameterImpl that = (ProcedureParameterImpl) other; + return required == that.required + && Objects.equals(name, that.name) + && Objects.equals(dataType, that.dataType); + } + + @Override + public int hashCode() { + return Objects.hash(name, dataType, required); + } + + @Override + public String toString() { + return String.format( + "ProcedureParameter(name='%s', type=%s, required=%b)", name, dataType, required); + } +} diff --git a/spark/v3.4/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/v3.4/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 000000000000..01a6c4e0670d --- /dev/null +++ b/spark/v3.4/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +org.apache.iceberg.spark.source.IcebergSource diff --git a/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpressions.scala b/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpressions.scala new file mode 100644 index 000000000000..dffac82af791 --- /dev/null +++ b/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpressions.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.nio.ByteBuffer +import java.nio.CharBuffer +import java.nio.charset.StandardCharsets +import java.util.function +import org.apache.iceberg.spark.SparkSchemaUtil +import org.apache.iceberg.transforms.Transform +import org.apache.iceberg.transforms.Transforms +import org.apache.iceberg.types.Type +import org.apache.iceberg.types.Types +import org.apache.iceberg.util.ByteBuffers +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.types.AbstractDataType +import org.apache.spark.sql.types.BinaryType +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.TimestampType +import org.apache.spark.unsafe.types.UTF8String + +abstract class IcebergTransformExpression + extends UnaryExpression with CodegenFallback with NullIntolerant { + + @transient lazy val icebergInputType: Type = SparkSchemaUtil.convert(child.dataType) +} + +abstract class IcebergTimeTransform + extends IcebergTransformExpression with ImplicitCastInputTypes { + + def transform: function.Function[Any, Integer] + + override protected def nullSafeEval(value: Any): Any = { + transform(value).toInt + } + + override def dataType: DataType = IntegerType + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) +} + +case class IcebergYearTransform(child: Expression) + extends IcebergTimeTransform { + + @transient lazy val transform: function.Function[Any, Integer] = Transforms.year[Any]().bind(icebergInputType) + + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(child = newChild) + } +} + +case class IcebergMonthTransform(child: Expression) + extends IcebergTimeTransform { + + @transient lazy val transform: function.Function[Any, Integer] = Transforms.month[Any]().bind(icebergInputType) + + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(child = newChild) + } +} + +case class IcebergDayTransform(child: Expression) + extends IcebergTimeTransform { + + @transient lazy 
val transform: function.Function[Any, Integer] = Transforms.day[Any]().bind(icebergInputType) + + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(child = newChild) + } +} + +case class IcebergHourTransform(child: Expression) + extends IcebergTimeTransform { + + @transient lazy val transform: function.Function[Any, Integer] = Transforms.hour[Any]().bind(icebergInputType) + + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(child = newChild) + } +} + +case class IcebergBucketTransform(numBuckets: Int, child: Expression) extends IcebergTransformExpression { + + @transient lazy val bucketFunc: Any => Int = child.dataType match { + case _: DecimalType => + val t = Transforms.bucket[Any](numBuckets).bind(icebergInputType) + d: Any => t(d.asInstanceOf[Decimal].toJavaBigDecimal).toInt + case _: StringType => + // the spec requires that the hash of a string is equal to the hash of its UTF-8 encoded bytes + // TODO: pass bytes without the copy out of the InternalRow + val t = Transforms.bucket[ByteBuffer](numBuckets).bind(Types.BinaryType.get()) + s: Any => t(ByteBuffer.wrap(s.asInstanceOf[UTF8String].getBytes)).toInt + case _ => + val t = Transforms.bucket[Any](numBuckets).bind(icebergInputType) + a: Any => t(a).toInt + } + + override protected def nullSafeEval(value: Any): Any = { + bucketFunc(value) + } + + override def dataType: DataType = IntegerType + + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(child = newChild) + } +} + +case class IcebergTruncateTransform(child: Expression, width: Int) extends IcebergTransformExpression { + + @transient lazy val truncateFunc: Any => Any = child.dataType match { + case _: DecimalType => + val t = Transforms.truncate[java.math.BigDecimal](width).bind(icebergInputType) + d: Any => Decimal.apply(t(d.asInstanceOf[Decimal].toJavaBigDecimal)) + case _: StringType => + val t = Transforms.truncate[CharSequence](width).bind(icebergInputType) + s: Any => { + val charSequence = t(StandardCharsets.UTF_8.decode(ByteBuffer.wrap(s.asInstanceOf[UTF8String].getBytes))) + val bb = StandardCharsets.UTF_8.encode(CharBuffer.wrap(charSequence)); + UTF8String.fromBytes(ByteBuffers.toByteArray(bb)) + } + case _: BinaryType => + val t = Transforms.truncate[ByteBuffer](width).bind(icebergInputType) + s: Any => ByteBuffers.toByteArray(t(ByteBuffer.wrap(s.asInstanceOf[Array[Byte]]))) + case _ => + val t = Transforms.truncate[Any](width).bind(icebergInputType) + a: Any => t(a) + } + + override protected def nullSafeEval(value: Any): Any = { + truncateFunc(value) + } + + override def dataType: DataType = child.dataType + + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(child = newChild) + } +} diff --git a/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetWriteDistributionAndOrdering.scala b/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetWriteDistributionAndOrdering.scala new file mode 100644 index 000000000000..0a0234cdfe34 --- /dev/null +++ b/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetWriteDistributionAndOrdering.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
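The Catalyst expressions above are thin wrappers over Iceberg's own transform functions; binding a transform to an Iceberg type yields a plain java.util.function.Function, exactly as the lazy vals above do. A rough sketch of the underlying calls (the bucket count, width, and input values are arbitrary):

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import org.apache.iceberg.transforms.Transforms;
import org.apache.iceberg.types.Types;

class TransformFunctionsExample {
  public static void main(String[] args) {
    // bucketing a string hashes its UTF-8 bytes, mirroring IcebergBucketTransform above
    int bucket =
        Transforms.<ByteBuffer>bucket(16)
            .bind(Types.BinaryType.get())
            .apply(ByteBuffer.wrap("iceberg".getBytes(StandardCharsets.UTF_8)));

    // truncate(3) keeps the first three characters of a string
    CharSequence truncated =
        Transforms.<CharSequence>truncate(3).bind(Types.StringType.get()).apply("iceberg");

    System.out.println(bucket);
    System.out.println(truncated); // ice
  }
}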
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.iceberg.DistributionMode +import org.apache.iceberg.NullOrder +import org.apache.iceberg.SortDirection +import org.apache.iceberg.expressions.Term +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits + +case class SetWriteDistributionAndOrdering( + table: Seq[String], + distributionMode: DistributionMode, + sortOrder: Seq[(Term, SortDirection, NullOrder)]) extends LeafCommand { + + import CatalogV2Implicits._ + + override lazy val output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + val order = sortOrder.map { + case (term, direction, nullOrder) => s"$term $direction $nullOrder" + }.mkString(", ") + s"SetWriteDistributionAndOrdering ${table.quoted} $distributionMode $order" + } +} diff --git a/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SortOrderParserUtil.scala b/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SortOrderParserUtil.scala new file mode 100644 index 000000000000..bf19ef8a2167 --- /dev/null +++ b/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SortOrderParserUtil.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
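For context, this plan node is produced by the Iceberg SQL extensions' ALTER TABLE ... WRITE syntax rather than by any built-in Spark statement. A hedged usage sketch (the table name is an assumption, and the Iceberg extensions must be enabled on the session):

import org.apache.spark.sql.SparkSession;

class SetWriteOrderExample {
  static void run(SparkSession spark) {
    // distribute by partition on write
    spark.sql("ALTER TABLE local.db.events WRITE DISTRIBUTED BY PARTITION");
    // order data within files on write
    spark.sql("ALTER TABLE local.db.events WRITE ORDERED BY category ASC NULLS LAST, id DESC");
  }
}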
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.iceberg.NullOrder +import org.apache.iceberg.Schema +import org.apache.iceberg.SortDirection +import org.apache.iceberg.SortOrder +import org.apache.iceberg.expressions.Term + +class SortOrderParserUtil { + + def collectSortOrder(tableSchema:Schema, sortOrder: Seq[(Term, SortDirection, NullOrder)]): SortOrder = { + val orderBuilder = SortOrder.builderFor(tableSchema) + sortOrder.foreach { + case (term, SortDirection.ASC, nullOrder) => + orderBuilder.asc(term, nullOrder) + case (term, SortDirection.DESC, nullOrder) => + orderBuilder.desc(term, nullOrder) + } + orderBuilder.build(); + } +} diff --git a/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/utils/DistributionAndOrderingUtils.scala b/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/utils/DistributionAndOrderingUtils.scala new file mode 100644 index 000000000000..94b6f651a0df --- /dev/null +++ b/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/utils/DistributionAndOrderingUtils.scala @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
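The SortOrderParserUtil helper above just feeds parsed (term, direction, null order) triples into Iceberg's SortOrder builder. The equivalent direct builder calls, shown against a made-up schema:

import org.apache.iceberg.NullOrder;
import org.apache.iceberg.Schema;
import org.apache.iceberg.SortOrder;
import org.apache.iceberg.expressions.Expressions;
import org.apache.iceberg.types.Types;

class SortOrderBuilderExample {
  public static void main(String[] args) {
    Schema schema =
        new Schema(
            Types.NestedField.required(1, "id", Types.LongType.get()),
            Types.NestedField.required(2, "ts", Types.TimestampType.withZone()));

    // sort by day(ts) ascending with nulls first, then id descending
    SortOrder order =
        SortOrder.builderFor(schema)
            .asc(Expressions.day("ts"), NullOrder.NULLS_FIRST)
            .desc("id")
            .build();

    System.out.println(order);
  }
}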
+ */ + +package org.apache.spark.sql.catalyst.utils + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.IcebergBucketTransform +import org.apache.spark.sql.catalyst.expressions.IcebergDayTransform +import org.apache.spark.sql.catalyst.expressions.IcebergHourTransform +import org.apache.spark.sql.catalyst.expressions.IcebergMonthTransform +import org.apache.spark.sql.catalyst.expressions.IcebergTruncateTransform +import org.apache.spark.sql.catalyst.expressions.IcebergYearTransform +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.RepartitionByExpression +import org.apache.spark.sql.catalyst.plans.logical.Sort +import org.apache.spark.sql.connector.distributions.ClusteredDistribution +import org.apache.spark.sql.connector.distributions.Distribution +import org.apache.spark.sql.connector.distributions.OrderedDistribution +import org.apache.spark.sql.connector.distributions.UnspecifiedDistribution +import org.apache.spark.sql.connector.expressions.ApplyTransform +import org.apache.spark.sql.connector.expressions.BucketTransform +import org.apache.spark.sql.connector.expressions.DaysTransform +import org.apache.spark.sql.connector.expressions.Expression +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.HoursTransform +import org.apache.spark.sql.connector.expressions.IdentityTransform +import org.apache.spark.sql.connector.expressions.Literal +import org.apache.spark.sql.connector.expressions.MonthsTransform +import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.expressions.NullOrdering +import org.apache.spark.sql.connector.expressions.SortDirection +import org.apache.spark.sql.connector.expressions.SortOrder +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.expressions.YearsTransform +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.IntegerType +import scala.collection.compat.immutable.ArraySeq + +object DistributionAndOrderingUtils { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + def prepareQuery( + requiredDistribution: Distribution, + requiredOrdering: Array[SortOrder], + query: LogicalPlan, + conf: SQLConf): LogicalPlan = { + + val resolver = conf.resolver + + val distribution = requiredDistribution match { + case d: OrderedDistribution => + d.ordering.map(e => toCatalyst(e, query, resolver)) + case d: ClusteredDistribution => + d.clustering.map(e => toCatalyst(e, query, resolver)) + case _: UnspecifiedDistribution => + Array.empty[catalyst.expressions.Expression] + } + + val queryWithDistribution = if (distribution.nonEmpty) { + // the conversion to catalyst expressions above produces SortOrder expressions + // for OrderedDistribution and generic expressions for ClusteredDistribution + // this allows RepartitionByExpression to pick either range or hash partitioning + RepartitionByExpression(distribution.toSeq, query, None) + } else { + query + } + + val ordering = requiredOrdering + .map(e => toCatalyst(e, query, resolver).asInstanceOf[catalyst.expressions.SortOrder]) + + val queryWithDistributionAndOrdering = if (ordering.nonEmpty) { + 
Sort(ArraySeq.unsafeWrapArray(ordering), global = false, queryWithDistribution) + } else { + queryWithDistribution + } + + queryWithDistributionAndOrdering + } + + private def toCatalyst( + expr: Expression, + query: LogicalPlan, + resolver: Resolver): catalyst.expressions.Expression = { + + // we cannot perform the resolution in the analyzer since we need to optimize expressions + // in nodes like OverwriteByExpression before constructing a logical write + def resolve(parts: Seq[String]): NamedExpression = { + query.resolve(parts, resolver) match { + case Some(attr) => + attr + case None => + val ref = parts.quoted + throw new AnalysisException(s"Cannot resolve '$ref' using ${query.output}") + } + } + + expr match { + case s: SortOrder => + val catalystChild = toCatalyst(s.expression(), query, resolver) + catalyst.expressions.SortOrder(catalystChild, toCatalyst(s.direction), toCatalyst(s.nullOrdering), Seq.empty) + case it: IdentityTransform => + resolve(ArraySeq.unsafeWrapArray(it.ref.fieldNames)) + case BucketTransform(numBuckets, ref) => + IcebergBucketTransform(numBuckets, resolve(ArraySeq.unsafeWrapArray(ref.fieldNames))) + case TruncateTransform(ref, width) => + IcebergTruncateTransform(resolve(ArraySeq.unsafeWrapArray(ref.fieldNames)), width) + case yt: YearsTransform => + IcebergYearTransform(resolve(ArraySeq.unsafeWrapArray(yt.ref.fieldNames))) + case mt: MonthsTransform => + IcebergMonthTransform(resolve(ArraySeq.unsafeWrapArray(mt.ref.fieldNames))) + case dt: DaysTransform => + IcebergDayTransform(resolve(ArraySeq.unsafeWrapArray(dt.ref.fieldNames))) + case ht: HoursTransform => + IcebergHourTransform(resolve(ArraySeq.unsafeWrapArray(ht.ref.fieldNames))) + case ref: FieldReference => + resolve(ArraySeq.unsafeWrapArray(ref.fieldNames)) + case _ => + throw new RuntimeException(s"$expr is not currently supported") + + } + } + + private def toCatalyst(direction: SortDirection): catalyst.expressions.SortDirection = { + direction match { + case SortDirection.ASCENDING => catalyst.expressions.Ascending + case SortDirection.DESCENDING => catalyst.expressions.Descending + } + } + + private def toCatalyst(nullOrdering: NullOrdering): catalyst.expressions.NullOrdering = { + nullOrdering match { + case NullOrdering.NULLS_FIRST => catalyst.expressions.NullsFirst + case NullOrdering.NULLS_LAST => catalyst.expressions.NullsLast + } + } + + private object BucketTransform { + def unapply(transform: Transform): Option[(Int, FieldReference)] = transform match { + case bt: BucketTransform => bt.columns match { + case Seq(nf: NamedReference) => + Some(bt.numBuckets.value(), FieldReference(ArraySeq.unsafeWrapArray(nf.fieldNames()))) + case _ => + None + } + case _ => None + } + } + + private object Lit { + def unapply[T](literal: Literal[T]): Some[(T, DataType)] = { + Some((literal.value, literal.dataType)) + } + } + + private object TruncateTransform { + def unapply(transform: Transform): Option[(FieldReference, Int)] = transform match { + case at @ ApplyTransform(name, _) if name.equalsIgnoreCase("truncate") => at.args match { + case Seq(nf: NamedReference, Lit(value: Int, IntegerType)) => + Some(FieldReference(ArraySeq.unsafeWrapArray(nf.fieldNames())), value) + case Seq(Lit(value: Int, IntegerType), nf: NamedReference) => + Some(FieldReference(ArraySeq.unsafeWrapArray(nf.fieldNames())), value) + case _ => + None + } + case _ => None + } + } +} diff --git a/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/utils/PlanUtils.scala 
b/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/utils/PlanUtils.scala new file mode 100644 index 000000000000..aa9e9c553346 --- /dev/null +++ b/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/catalyst/utils/PlanUtils.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.utils + +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import scala.annotation.tailrec + +object PlanUtils { + @tailrec + def isIcebergRelation(plan: LogicalPlan): Boolean = { + def isIcebergTable(relation: DataSourceV2Relation): Boolean = relation.table match { + case _: SparkTable => true + case _ => false + } + + plan match { + case s: SubqueryAlias => isIcebergRelation(s.child) + case r: DataSourceV2Relation => isIcebergTable(r) + case _ => false + } + } +} diff --git a/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/execution/datasources/SparkExpressionConverter.scala b/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/execution/datasources/SparkExpressionConverter.scala new file mode 100644 index 000000000000..554fa7f66dc6 --- /dev/null +++ b/spark/v3.4/spark/src/main/scala/org/apache/spark/sql/execution/datasources/SparkExpressionConverter.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
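A minimal sketch of how the DistributionAndOrderingUtils.prepareQuery helper shown earlier might be invoked. It assumes the Scala object's static forwarders are callable from Java; the wrapper class, the column names ("date", "id"), and the query/conf arguments are placeholders, not part of this patch.

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.utils.DistributionAndOrderingUtils;
import org.apache.spark.sql.connector.distributions.Distribution;
import org.apache.spark.sql.connector.distributions.Distributions;
import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.connector.expressions.Expressions;
import org.apache.spark.sql.connector.expressions.SortDirection;
import org.apache.spark.sql.connector.expressions.SortOrder;
import org.apache.spark.sql.internal.SQLConf;

class PrepareQuerySketch {
  // Clusters incoming rows by the "date" column and sorts rows within each task by "id".
  static LogicalPlan prepare(LogicalPlan query, SQLConf conf) {
    Distribution distribution =
        Distributions.clustered(new Expression[] {Expressions.identity("date")});
    SortOrder[] ordering =
        new SortOrder[] {Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING)};
    // ClusteredDistribution becomes a hash RepartitionByExpression, OrderedDistribution a range
    // repartition; a non-global Sort is layered on top whenever an ordering is requested.
    return DistributionAndOrderingUtils.prepareQuery(distribution, ordering, query, conf);
  }
}

Likewise, a small usage sketch for PlanUtils.isIcebergRelation above, again assuming the Scala object's static forwarder; the table name is a placeholder.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.utils.PlanUtils;

class IcebergRelationCheckSketch {
  // True only when the analyzed plan (possibly wrapped in SubqueryAlias nodes)
  // bottoms out in a DataSourceV2Relation backed by Iceberg's SparkTable.
  static boolean readsIcebergTable(SparkSession spark) {
    Dataset<Row> df = spark.table("db.sample_table");
    return PlanUtils.isIcebergRelation(df.queryExecution().analyzed());
  }
}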
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.iceberg.spark.SparkFilters
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.catalyst.plans.logical.LeafNode
+
+object SparkExpressionConverter {
+
+  def convertToIcebergExpression(sparkExpression: Expression): org.apache.iceberg.expressions.Expression = {
+    // Currently, it is a double conversion as we are converting Spark expression to Spark filter
+    // and then converting Spark filter to Iceberg expression.
+    // But these two conversions already exist and are well tested, so we are going with this approach.
+    SparkFilters.convert(DataSourceStrategy.translateFilter(sparkExpression, supportNestedPredicatePushdown = true).get)
+  }
+
+  @throws[AnalysisException]
+  def collectResolvedSparkExpression(session: SparkSession, tableName: String, where: String): Expression = {
+    val tableAttrs = session.table(tableName).queryExecution.analyzed.output
+    val unresolvedExpression = session.sessionState.sqlParser.parseExpression(where)
+    val filter = Filter(unresolvedExpression, DummyRelation(tableAttrs))
+    val optimizedLogicalPlan = session.sessionState.executePlan(filter).optimizedPlan
+    optimizedLogicalPlan.collectFirst {
+      case filter: Filter => filter.condition
+    }.getOrElse(throw new AnalysisException("Failed to find filter expression"))
+  }
+
+  case class DummyRelation(output: Seq[Attribute]) extends LeafNode
+}
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/KryoHelpers.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/KryoHelpers.java
new file mode 100644
index 000000000000..6d88aaa11813
--- /dev/null
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/KryoHelpers.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
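A minimal sketch of the call pattern that SparkExpressionConverter above is built for: resolve a free-form WHERE string against a table, then convert the optimized catalyst predicate into an Iceberg filter. The table name and predicate are placeholders, and the Java call assumes the Scala object's static forwarders.

import org.apache.iceberg.expressions.Expression;
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.execution.datasources.SparkExpressionConverter;

class WhereClauseSketch {
  // Resolves "id > 100" against the table's schema and converts the resulting
  // catalyst predicate into an org.apache.iceberg.expressions.Expression.
  static Expression toIcebergFilter(SparkSession spark) throws AnalysisException {
    org.apache.spark.sql.catalyst.expressions.Expression resolved =
        SparkExpressionConverter.collectResolvedSparkExpression(spark, "db.sample_table", "id > 100");
    return SparkExpressionConverter.convertToIcebergExpression(resolved);
  }
}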
+ */ +package org.apache.iceberg; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import org.apache.spark.SparkConf; +import org.apache.spark.serializer.KryoSerializer; + +public class KryoHelpers { + + private KryoHelpers() {} + + @SuppressWarnings("unchecked") + public static T roundTripSerialize(T obj) throws IOException { + Kryo kryo = new KryoSerializer(new SparkConf()).newKryo(); + + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + + try (Output out = new Output(new ObjectOutputStream(bytes))) { + kryo.writeClassAndObject(out, obj); + } + + try (Input in = + new Input(new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray())))) { + return (T) kryo.readClassAndObject(in); + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/TaskCheckHelper.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/TaskCheckHelper.java new file mode 100644 index 000000000000..c44bacf149b5 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/TaskCheckHelper.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
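A short example of how a test might lean on the KryoHelpers utility above; the PartitionSpec value is arbitrary and only meant to show the round-trip pattern.

import org.apache.iceberg.KryoHelpers;
import org.apache.iceberg.PartitionSpec;
import org.junit.Assert;
import org.junit.Test;

public class KryoHelpersUsageSketch {
  @Test
  public void roundTripsAnIcebergValueObject() throws Exception {
    PartitionSpec spec = PartitionSpec.unpartitioned();
    // Serialization through Spark's KryoSerializer should preserve equality.
    PartitionSpec copy = KryoHelpers.roundTripSerialize(spec);
    Assert.assertEquals(spec, copy);
  }
}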
+ */ +package org.apache.iceberg; + +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; +import org.junit.Assert; + +public final class TaskCheckHelper { + private TaskCheckHelper() {} + + public static void assertEquals( + ScanTaskGroup expected, ScanTaskGroup actual) { + List expectedTasks = getFileScanTasksInFilePathOrder(expected); + List actualTasks = getFileScanTasksInFilePathOrder(actual); + + Assert.assertEquals( + "The number of file scan tasks should match", expectedTasks.size(), actualTasks.size()); + + for (int i = 0; i < expectedTasks.size(); i++) { + FileScanTask expectedTask = expectedTasks.get(i); + FileScanTask actualTask = actualTasks.get(i); + assertEquals(expectedTask, actualTask); + } + } + + public static void assertEquals(FileScanTask expected, FileScanTask actual) { + assertEquals(expected.file(), actual.file()); + + // PartitionSpec implements its own equals method + Assert.assertEquals("PartitionSpec doesn't match", expected.spec(), actual.spec()); + + Assert.assertEquals("starting position doesn't match", expected.start(), actual.start()); + + Assert.assertEquals( + "the number of bytes to scan doesn't match", expected.start(), actual.start()); + + // simplify comparison on residual expression via comparing toString + Assert.assertEquals( + "Residual expression doesn't match", + expected.residual().toString(), + actual.residual().toString()); + } + + public static void assertEquals(DataFile expected, DataFile actual) { + Assert.assertEquals("Should match the serialized record path", expected.path(), actual.path()); + Assert.assertEquals( + "Should match the serialized record format", expected.format(), actual.format()); + Assert.assertEquals( + "Should match the serialized record partition", + expected.partition().get(0, Object.class), + actual.partition().get(0, Object.class)); + Assert.assertEquals( + "Should match the serialized record count", expected.recordCount(), actual.recordCount()); + Assert.assertEquals( + "Should match the serialized record size", + expected.fileSizeInBytes(), + actual.fileSizeInBytes()); + Assert.assertEquals( + "Should match the serialized record value counts", + expected.valueCounts(), + actual.valueCounts()); + Assert.assertEquals( + "Should match the serialized record null value counts", + expected.nullValueCounts(), + actual.nullValueCounts()); + Assert.assertEquals( + "Should match the serialized record lower bounds", + expected.lowerBounds(), + actual.lowerBounds()); + Assert.assertEquals( + "Should match the serialized record upper bounds", + expected.upperBounds(), + actual.upperBounds()); + Assert.assertEquals( + "Should match the serialized record key metadata", + expected.keyMetadata(), + actual.keyMetadata()); + Assert.assertEquals( + "Should match the serialized record offsets", + expected.splitOffsets(), + actual.splitOffsets()); + Assert.assertEquals( + "Should match the serialized record offsets", expected.keyMetadata(), actual.keyMetadata()); + } + + private static List getFileScanTasksInFilePathOrder( + ScanTaskGroup taskGroup) { + return taskGroup.tasks().stream() + // use file path + start position to differentiate the tasks + .sorted(Comparator.comparing(o -> o.file().path().toString() + "##" + o.start())) + .collect(Collectors.toList()); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/TestDataFileSerialization.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/TestDataFileSerialization.java new file mode 100644 index 000000000000..33b5316b72b7 
--- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/TestDataFileSerialization.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg; + +import static org.apache.iceberg.TaskCheckHelper.assertEquals; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Map; +import java.util.UUID; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkParquetWriters; +import org.apache.iceberg.types.Types; +import org.apache.spark.SparkConf; +import org.apache.spark.serializer.KryoSerializer; +import org.apache.spark.sql.catalyst.InternalRow; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestDataFileSerialization { + + private static final Schema DATE_SCHEMA = + new Schema( + required(1, "id", Types.LongType.get()), + optional(2, "data", Types.StringType.get()), + required(3, "date", Types.StringType.get()), + optional(4, "double", Types.DoubleType.get())); + + private static final PartitionSpec PARTITION_SPEC = + PartitionSpec.builderFor(DATE_SCHEMA).identity("date").build(); + + private static final Map VALUE_COUNTS = Maps.newHashMap(); + private static final Map NULL_VALUE_COUNTS = Maps.newHashMap(); + private static final Map NAN_VALUE_COUNTS = Maps.newHashMap(); + private static final Map LOWER_BOUNDS = Maps.newHashMap(); + private static final Map UPPER_BOUNDS = Maps.newHashMap(); + + static { + VALUE_COUNTS.put(1, 5L); + VALUE_COUNTS.put(2, 3L); + VALUE_COUNTS.put(4, 2L); + NULL_VALUE_COUNTS.put(1, 0L); + NULL_VALUE_COUNTS.put(2, 2L); + NAN_VALUE_COUNTS.put(4, 1L); + LOWER_BOUNDS.put(1, longToBuffer(0L)); + UPPER_BOUNDS.put(1, longToBuffer(4L)); + } + + private static final DataFile DATA_FILE = + DataFiles.builder(PARTITION_SPEC) + .withPath("/path/to/data-1.parquet") + .withFileSizeInBytes(1234) 
+ .withPartitionPath("date=2018-06-08") + .withMetrics( + new Metrics( + 5L, + null, + VALUE_COUNTS, + NULL_VALUE_COUNTS, + NAN_VALUE_COUNTS, + LOWER_BOUNDS, + UPPER_BOUNDS)) + .withSplitOffsets(ImmutableList.of(4L)) + .withEncryptionKeyMetadata(ByteBuffer.allocate(4).putInt(34)) + .withSortOrder(SortOrder.unsorted()) + .build(); + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + @Test + public void testDataFileKryoSerialization() throws Exception { + File data = temp.newFile(); + Assert.assertTrue(data.delete()); + Kryo kryo = new KryoSerializer(new SparkConf()).newKryo(); + + try (Output out = new Output(new FileOutputStream(data))) { + kryo.writeClassAndObject(out, DATA_FILE); + kryo.writeClassAndObject(out, DATA_FILE.copy()); + } + + try (Input in = new Input(new FileInputStream(data))) { + for (int i = 0; i < 2; i += 1) { + Object obj = kryo.readClassAndObject(in); + Assertions.assertThat(obj).as("Should be a DataFile").isInstanceOf(DataFile.class); + assertEquals(DATA_FILE, (DataFile) obj); + } + } + } + + @Test + public void testDataFileJavaSerialization() throws Exception { + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + try (ObjectOutputStream out = new ObjectOutputStream(bytes)) { + out.writeObject(DATA_FILE); + out.writeObject(DATA_FILE.copy()); + } + + try (ObjectInputStream in = + new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) { + for (int i = 0; i < 2; i += 1) { + Object obj = in.readObject(); + Assertions.assertThat(obj).as("Should be a DataFile").isInstanceOf(DataFile.class); + assertEquals(DATA_FILE, (DataFile) obj); + } + } + } + + @Test + public void testParquetWriterSplitOffsets() throws IOException { + Iterable records = RandomData.generateSpark(DATE_SCHEMA, 1, 33L); + File parquetFile = + new File(temp.getRoot(), FileFormat.PARQUET.addExtension(UUID.randomUUID().toString())); + FileAppender writer = + Parquet.write(Files.localOutput(parquetFile)) + .schema(DATE_SCHEMA) + .createWriterFunc( + msgType -> + SparkParquetWriters.buildWriter(SparkSchemaUtil.convert(DATE_SCHEMA), msgType)) + .build(); + try { + writer.addAll(records); + } finally { + writer.close(); + } + + Kryo kryo = new KryoSerializer(new SparkConf()).newKryo(); + File dataFile = temp.newFile(); + try (Output out = new Output(new FileOutputStream(dataFile))) { + kryo.writeClassAndObject(out, writer.splitOffsets()); + } + try (Input in = new Input(new FileInputStream(dataFile))) { + kryo.readClassAndObject(in); + } + } + + private static ByteBuffer longToBuffer(long value) { + return ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putLong(0, value); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/TestFileIOSerialization.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/TestFileIOSerialization.java new file mode 100644 index 000000000000..c6f491ece5ad --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/TestFileIOSerialization.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.hadoop.HadoopFileIO; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.source.SerializableTableWithSize; +import org.apache.iceberg.types.Types; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestFileIOSerialization { + + private static final Configuration CONF = new Configuration(); + private static final HadoopTables TABLES = new HadoopTables(CONF); + + private static final Schema SCHEMA = + new Schema( + required(1, "id", Types.LongType.get()), + optional(2, "data", Types.StringType.get()), + required(3, "date", Types.StringType.get()), + optional(4, "double", Types.DoubleType.get())); + + private static final PartitionSpec SPEC = + PartitionSpec.builderFor(SCHEMA).identity("date").build(); + + private static final SortOrder SORT_ORDER = SortOrder.builderFor(SCHEMA).asc("id").build(); + + static { + CONF.set("k1", "v1"); + CONF.set("k2", "v2"); + } + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + private Table table; + + @Before + public void initTable() throws IOException { + Map props = ImmutableMap.of("k1", "v1", "k2", "v2"); + + File tableLocation = temp.newFolder(); + Assert.assertTrue(tableLocation.delete()); + + this.table = TABLES.create(SCHEMA, SPEC, SORT_ORDER, props, tableLocation.toString()); + } + + @Test + public void testHadoopFileIOKryoSerialization() throws IOException { + FileIO io = table.io(); + Configuration expectedConf = ((HadoopFileIO) io).conf(); + + Table serializableTable = SerializableTableWithSize.copyOf(table); + FileIO deserializedIO = KryoHelpers.roundTripSerialize(serializableTable.io()); + Configuration actualConf = ((HadoopFileIO) deserializedIO).conf(); + + Assert.assertEquals("Conf pairs must match", toMap(expectedConf), toMap(actualConf)); + Assert.assertEquals("Conf values must be present", "v1", actualConf.get("k1")); + Assert.assertEquals("Conf values must be present", "v2", actualConf.get("k2")); + } + + @Test + public void testHadoopFileIOJavaSerialization() throws IOException, ClassNotFoundException { + FileIO io = table.io(); + Configuration expectedConf = ((HadoopFileIO) io).conf(); + + Table serializableTable = SerializableTableWithSize.copyOf(table); + FileIO deserializedIO = TestHelpers.roundTripSerialize(serializableTable.io()); + Configuration actualConf = ((HadoopFileIO) deserializedIO).conf(); + + Assert.assertEquals("Conf pairs must match", toMap(expectedConf), toMap(actualConf)); + Assert.assertEquals("Conf values must be present", "v1", actualConf.get("k1")); + 
Assert.assertEquals("Conf values must be present", "v2", actualConf.get("k2")); + } + + private Map toMap(Configuration conf) { + Map map = Maps.newHashMapWithExpectedSize(conf.size()); + conf.forEach(entry -> map.put(entry.getKey(), entry.getValue())); + return map; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/TestHadoopMetricsContextSerialization.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/TestHadoopMetricsContextSerialization.java new file mode 100644 index 000000000000..92d233e129e2 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/TestHadoopMetricsContextSerialization.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg; + +import java.io.IOException; +import org.apache.iceberg.hadoop.HadoopMetricsContext; +import org.apache.iceberg.io.FileIOMetricsContext; +import org.apache.iceberg.metrics.MetricsContext; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.junit.Test; + +public class TestHadoopMetricsContextSerialization { + + @Test + public void testHadoopMetricsContextKryoSerialization() throws IOException { + MetricsContext metricsContext = new HadoopMetricsContext("s3"); + + metricsContext.initialize(Maps.newHashMap()); + + MetricsContext deserializedMetricContext = KryoHelpers.roundTripSerialize(metricsContext); + // statistics are properly re-initialized post de-serialization + deserializedMetricContext + .counter(FileIOMetricsContext.WRITE_BYTES, MetricsContext.Unit.BYTES) + .increment(); + } + + @Test + public void testHadoopMetricsContextJavaSerialization() + throws IOException, ClassNotFoundException { + MetricsContext metricsContext = new HadoopMetricsContext("s3"); + + metricsContext.initialize(Maps.newHashMap()); + + MetricsContext deserializedMetricContext = TestHelpers.roundTripSerialize(metricsContext); + // statistics are properly re-initialized post de-serialization + deserializedMetricContext + .counter(FileIOMetricsContext.WRITE_BYTES, MetricsContext.Unit.BYTES) + .increment(); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/TestManifestFileSerialization.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/TestManifestFileSerialization.java new file mode 100644 index 000000000000..92a646d3861b --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/TestManifestFileSerialization.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.ManifestFile.PartitionFieldSummary; +import org.apache.iceberg.hadoop.HadoopFileIO; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.types.Types; +import org.apache.spark.SparkConf; +import org.apache.spark.serializer.KryoSerializer; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestManifestFileSerialization { + + private static final Schema SCHEMA = + new Schema( + required(1, "id", Types.LongType.get()), + optional(2, "data", Types.StringType.get()), + required(3, "date", Types.StringType.get()), + required(4, "double", Types.DoubleType.get())); + + private static final PartitionSpec SPEC = + PartitionSpec.builderFor(SCHEMA).identity("double").build(); + + private static final DataFile FILE_A = + DataFiles.builder(SPEC) + .withPath("/path/to/data-1.parquet") + .withFileSizeInBytes(0) + .withPartition(TestHelpers.Row.of(1D)) + .withPartitionPath("double=1") + .withMetrics( + new Metrics( + 5L, + null, // no column sizes + ImmutableMap.of(1, 5L, 2, 3L), // value count + ImmutableMap.of(1, 0L, 2, 2L), // null count + ImmutableMap.of(), // nan count + ImmutableMap.of(1, longToBuffer(0L)), // lower bounds + ImmutableMap.of(1, longToBuffer(4L)) // upper bounds + )) + .build(); + + private static final DataFile FILE_B = + DataFiles.builder(SPEC) + .withPath("/path/to/data-2.parquet") + .withFileSizeInBytes(0) + .withPartition(TestHelpers.Row.of(Double.NaN)) + .withPartitionPath("double=NaN") + .withMetrics( + new Metrics( + 1L, + null, // no column sizes + ImmutableMap.of(1, 1L, 4, 1L), // value count + ImmutableMap.of(1, 0L, 2, 0L), // null count + ImmutableMap.of(4, 1L), // nan count + ImmutableMap.of(1, longToBuffer(0L)), // lower bounds + ImmutableMap.of(1, longToBuffer(1L)) // upper bounds + )) + .build(); + + private static final FileIO FILE_IO = new HadoopFileIO(new Configuration()); + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + @Test + public void testManifestFileKryoSerialization() throws IOException { + File data = temp.newFile(); + 
Assert.assertTrue(data.delete()); + + Kryo kryo = new KryoSerializer(new SparkConf()).newKryo(); + + ManifestFile manifest = writeManifest(FILE_A, FILE_B); + + try (Output out = new Output(new FileOutputStream(data))) { + kryo.writeClassAndObject(out, manifest); + kryo.writeClassAndObject(out, manifest.copy()); + kryo.writeClassAndObject(out, GenericManifestFile.copyOf(manifest).build()); + } + + try (Input in = new Input(new FileInputStream(data))) { + for (int i = 0; i < 3; i += 1) { + Object obj = kryo.readClassAndObject(in); + Assertions.assertThat(obj).as("Should be a ManifestFile").isInstanceOf(ManifestFile.class); + checkManifestFile(manifest, (ManifestFile) obj); + } + } + } + + @Test + public void testManifestFileJavaSerialization() throws Exception { + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + + ManifestFile manifest = writeManifest(FILE_A, FILE_B); + + try (ObjectOutputStream out = new ObjectOutputStream(bytes)) { + out.writeObject(manifest); + out.writeObject(manifest.copy()); + out.writeObject(GenericManifestFile.copyOf(manifest).build()); + } + + try (ObjectInputStream in = + new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) { + for (int i = 0; i < 3; i += 1) { + Object obj = in.readObject(); + Assertions.assertThat(obj).as("Should be a ManifestFile").isInstanceOf(ManifestFile.class); + checkManifestFile(manifest, (ManifestFile) obj); + } + } + } + + private void checkManifestFile(ManifestFile expected, ManifestFile actual) { + Assert.assertEquals("Path must match", expected.path(), actual.path()); + Assert.assertEquals("Length must match", expected.length(), actual.length()); + Assert.assertEquals("Spec id must match", expected.partitionSpecId(), actual.partitionSpecId()); + Assert.assertEquals("Snapshot id must match", expected.snapshotId(), actual.snapshotId()); + Assert.assertEquals( + "Added files flag must match", expected.hasAddedFiles(), actual.hasAddedFiles()); + Assert.assertEquals( + "Added files count must match", expected.addedFilesCount(), actual.addedFilesCount()); + Assert.assertEquals( + "Added rows count must match", expected.addedRowsCount(), actual.addedRowsCount()); + Assert.assertEquals( + "Existing files flag must match", expected.hasExistingFiles(), actual.hasExistingFiles()); + Assert.assertEquals( + "Existing files count must match", + expected.existingFilesCount(), + actual.existingFilesCount()); + Assert.assertEquals( + "Existing rows count must match", expected.existingRowsCount(), actual.existingRowsCount()); + Assert.assertEquals( + "Deleted files flag must match", expected.hasDeletedFiles(), actual.hasDeletedFiles()); + Assert.assertEquals( + "Deleted files count must match", expected.deletedFilesCount(), actual.deletedFilesCount()); + Assert.assertEquals( + "Deleted rows count must match", expected.deletedRowsCount(), actual.deletedRowsCount()); + + PartitionFieldSummary expectedPartition = expected.partitions().get(0); + PartitionFieldSummary actualPartition = actual.partitions().get(0); + + Assert.assertEquals( + "Null flag in partition must match", + expectedPartition.containsNull(), + actualPartition.containsNull()); + Assert.assertEquals( + "NaN flag in partition must match", + expectedPartition.containsNaN(), + actualPartition.containsNaN()); + Assert.assertEquals( + "Lower bounds in partition must match", + expectedPartition.lowerBound(), + actualPartition.lowerBound()); + Assert.assertEquals( + "Upper bounds in partition must match", + expectedPartition.upperBound(), + 
actualPartition.upperBound()); + } + + private ManifestFile writeManifest(DataFile... files) throws IOException { + File manifestFile = temp.newFile("input.m0.avro"); + Assert.assertTrue(manifestFile.delete()); + OutputFile outputFile = FILE_IO.newOutputFile(manifestFile.getCanonicalPath()); + + ManifestWriter writer = ManifestFiles.write(SPEC, outputFile); + try { + for (DataFile file : files) { + writer.add(file); + } + } finally { + writer.close(); + } + + return writer.toManifestFile(); + } + + private static ByteBuffer longToBuffer(long value) { + return ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putLong(0, value); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/TestScanTaskSerialization.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/TestScanTaskSerialization.java new file mode 100644 index 000000000000..5e5d657eab56 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/TestScanTaskSerialization.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg; + +import static org.apache.iceberg.types.Types.NestedField.optional; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.nio.file.Files; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.spark.source.ThreeColumnRecord; +import org.apache.iceberg.types.Types; +import org.apache.spark.SparkConf; +import org.apache.spark.serializer.KryoSerializer; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestScanTaskSerialization extends SparkTestBase { + + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + private static final Schema SCHEMA = + new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private String tableLocation = null; + + @Before + public void setupTableLocation() throws Exception { + File tableDir = temp.newFolder(); + this.tableLocation = tableDir.toURI().toString(); + } + + @Test + public void testBaseCombinedScanTaskKryoSerialization() throws Exception { + BaseCombinedScanTask scanTask = prepareBaseCombinedScanTaskForSerDeTest(); + + File data = temp.newFile(); + Assert.assertTrue(data.delete()); + Kryo kryo = new KryoSerializer(new SparkConf()).newKryo(); + + try (Output out = new Output(new FileOutputStream(data))) { + kryo.writeClassAndObject(out, scanTask); + } + + try (Input in = new Input(new FileInputStream(data))) { + Object obj = kryo.readClassAndObject(in); + Assertions.assertThat(obj) + .as("Should be a BaseCombinedScanTask") + .isInstanceOf(BaseCombinedScanTask.class); + TaskCheckHelper.assertEquals(scanTask, (BaseCombinedScanTask) obj); + } + } + + @Test + public void testBaseCombinedScanTaskJavaSerialization() throws Exception { + BaseCombinedScanTask scanTask = prepareBaseCombinedScanTaskForSerDeTest(); + + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + try (ObjectOutputStream out = new ObjectOutputStream(bytes)) { + out.writeObject(scanTask); + } + + try (ObjectInputStream in = + new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) { + Object obj = in.readObject(); + Assertions.assertThat(obj) + .as("Should be a BaseCombinedScanTask") + .isInstanceOf(BaseCombinedScanTask.class); + TaskCheckHelper.assertEquals(scanTask, (BaseCombinedScanTask) obj); + } + } + + @Test + @SuppressWarnings("unchecked") + public void testBaseScanTaskGroupKryoSerialization() throws Exception { + BaseScanTaskGroup taskGroup = prepareBaseScanTaskGroupForSerDeTest(); + + 
Assert.assertTrue("Task group can't be empty", taskGroup.tasks().size() > 0); + + File data = temp.newFile(); + Assert.assertTrue(data.delete()); + Kryo kryo = new KryoSerializer(new SparkConf()).newKryo(); + + try (Output out = new Output(Files.newOutputStream(data.toPath()))) { + kryo.writeClassAndObject(out, taskGroup); + } + + try (Input in = new Input(Files.newInputStream(data.toPath()))) { + Object obj = kryo.readClassAndObject(in); + Assertions.assertThat(obj) + .as("should be a BaseScanTaskGroup") + .isInstanceOf(BaseScanTaskGroup.class); + TaskCheckHelper.assertEquals(taskGroup, (BaseScanTaskGroup) obj); + } + } + + @Test + @SuppressWarnings("unchecked") + public void testBaseScanTaskGroupJavaSerialization() throws Exception { + BaseScanTaskGroup taskGroup = prepareBaseScanTaskGroupForSerDeTest(); + + Assert.assertTrue("Task group can't be empty", taskGroup.tasks().size() > 0); + + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + try (ObjectOutputStream out = new ObjectOutputStream(bytes)) { + out.writeObject(taskGroup); + } + + try (ObjectInputStream in = + new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) { + Object obj = in.readObject(); + Assertions.assertThat(obj) + .as("should be a BaseScanTaskGroup") + .isInstanceOf(BaseScanTaskGroup.class); + TaskCheckHelper.assertEquals(taskGroup, (BaseScanTaskGroup) obj); + } + } + + private BaseCombinedScanTask prepareBaseCombinedScanTaskForSerDeTest() { + Table table = initTable(); + CloseableIterable tasks = table.newScan().planFiles(); + return new BaseCombinedScanTask(Lists.newArrayList(tasks)); + } + + private BaseScanTaskGroup prepareBaseScanTaskGroupForSerDeTest() { + Table table = initTable(); + CloseableIterable tasks = table.newScan().planFiles(); + return new BaseScanTaskGroup<>(ImmutableList.copyOf(tasks)); + } + + private Table initTable() { + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + List records1 = + Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), new ThreeColumnRecord(1, "BBBBBBBBBB", "BBBB")); + writeRecords(records1); + + List records2 = + Lists.newArrayList( + new ThreeColumnRecord(2, "CCCCCCCCCC", "CCCC"), + new ThreeColumnRecord(2, "DDDDDDDDDD", "DDDD")); + writeRecords(records2); + + table.refresh(); + + return table; + } + + private void writeRecords(List records) { + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + writeDF(df); + } + + private void writeDF(Dataset df) { + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/TestTableSerialization.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/TestTableSerialization.java new file mode 100644 index 000000000000..b134aacac0d7 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/TestTableSerialization.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.source.SerializableTableWithSize; +import org.apache.iceberg.types.Types; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class TestTableSerialization { + + public TestTableSerialization(String isObjectStoreEnabled) { + this.isObjectStoreEnabled = isObjectStoreEnabled; + } + + @Parameterized.Parameters(name = "isObjectStoreEnabled = {0}") + public static Object[] parameters() { + return new Object[] {"true", "false"}; + } + + private static final HadoopTables TABLES = new HadoopTables(); + + private final String isObjectStoreEnabled; + + private static final Schema SCHEMA = + new Schema( + required(1, "id", Types.LongType.get()), + optional(2, "data", Types.StringType.get()), + required(3, "date", Types.StringType.get()), + optional(4, "double", Types.DoubleType.get())); + + private static final PartitionSpec SPEC = + PartitionSpec.builderFor(SCHEMA).identity("date").build(); + + private static final SortOrder SORT_ORDER = SortOrder.builderFor(SCHEMA).asc("id").build(); + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + private Table table; + + @Before + public void initTable() throws IOException { + Map props = + ImmutableMap.of("k1", "v1", TableProperties.OBJECT_STORE_ENABLED, isObjectStoreEnabled); + + File tableLocation = temp.newFolder(); + Assert.assertTrue(tableLocation.delete()); + + this.table = TABLES.create(SCHEMA, SPEC, SORT_ORDER, props, tableLocation.toString()); + } + + @Test + public void testSerializableTableKryoSerialization() throws IOException { + Table serializableTable = SerializableTableWithSize.copyOf(table); + TestHelpers.assertSerializedAndLoadedMetadata( + table, KryoHelpers.roundTripSerialize(serializableTable)); + } + + @Test + public void testSerializableMetadataTableKryoSerialization() throws IOException { + for (MetadataTableType type : MetadataTableType.values()) { + TableOperations ops = ((HasTableOperations) table).operations(); + Table metadataTable = + MetadataTableUtils.createMetadataTableInstance(ops, table.name(), "meta", type); + Table serializableMetadataTable = SerializableTableWithSize.copyOf(metadataTable); + + TestHelpers.assertSerializedAndLoadedMetadata( + metadataTable, KryoHelpers.roundTripSerialize(serializableMetadataTable)); + } + } + + @Test + public void testSerializableTransactionTableKryoSerialization() throws IOException { + Transaction txn = table.newTransaction(); + + txn.updateProperties().set("k1", "v1").commit(); + + Table txnTable = txn.table(); + Table serializableTxnTable = 
SerializableTableWithSize.copyOf(txnTable); + + TestHelpers.assertSerializedMetadata( + txnTable, KryoHelpers.roundTripSerialize(serializableTxnTable)); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/ValidationHelpers.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/ValidationHelpers.java new file mode 100644 index 000000000000..70ab04f0a080 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/ValidationHelpers.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.assertj.core.api.Assertions; + +public class ValidationHelpers { + + private ValidationHelpers() {} + + public static List dataSeqs(Long... seqs) { + return Arrays.asList(seqs); + } + + public static List fileSeqs(Long... seqs) { + return Arrays.asList(seqs); + } + + public static List snapshotIds(Long... ids) { + return Arrays.asList(ids); + } + + public static List files(ContentFile... 
files) { + return Arrays.stream(files).map(file -> file.path().toString()).collect(Collectors.toList()); + } + + public static void validateDataManifest( + Table table, + ManifestFile manifest, + List dataSeqs, + List fileSeqs, + List snapshotIds, + List files) { + + List actualDataSeqs = Lists.newArrayList(); + List actualFileSeqs = Lists.newArrayList(); + List actualSnapshotIds = Lists.newArrayList(); + List actualFiles = Lists.newArrayList(); + + for (ManifestEntry entry : ManifestFiles.read(manifest, table.io()).entries()) { + actualDataSeqs.add(entry.dataSequenceNumber()); + actualFileSeqs.add(entry.fileSequenceNumber()); + actualSnapshotIds.add(entry.snapshotId()); + actualFiles.add(entry.file().path().toString()); + } + + assertSameElements("data seqs", actualDataSeqs, dataSeqs); + assertSameElements("file seqs", actualFileSeqs, fileSeqs); + assertSameElements("snapshot IDs", actualSnapshotIds, snapshotIds); + assertSameElements("files", actualFiles, files); + } + + private static void assertSameElements(String context, List actual, List expected) { + String errorMessage = String.format("%s must match", context); + Assertions.assertThat(actual).as(errorMessage).hasSameElementsAs(expected); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SparkCatalogConfig.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SparkCatalogConfig.java new file mode 100644 index 000000000000..fc18ed3bb174 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SparkCatalogConfig.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; + +public enum SparkCatalogConfig { + HIVE( + "testhive", + SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default")), + HADOOP( + "testhadoop", + SparkCatalog.class.getName(), + ImmutableMap.of("type", "hadoop", "cache-enabled", "false")), + SPARK( + "spark_catalog", + SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "parquet-enabled", "true", + "cache-enabled", + "false" // Spark will delete tables using v1, leaving the cache out of sync + )); + + private final String catalogName; + private final String implementation; + private final Map properties; + + SparkCatalogConfig(String catalogName, String implementation, Map properties) { + this.catalogName = catalogName; + this.implementation = implementation; + this.properties = properties; + } + + public String catalogName() { + return catalogName; + } + + public String implementation() { + return implementation; + } + + public Map properties() { + return properties; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SparkCatalogTestBase.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SparkCatalogTestBase.java new file mode 100644 index 000000000000..89323c26100c --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SparkCatalogTestBase.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
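An illustrative sketch of how the SparkCatalogConfig entries above are typically turned into Spark session configuration; the test bases do this wiring themselves, so the helper below is hypothetical.

import java.util.Map;
import org.apache.iceberg.spark.SparkCatalogConfig;
import org.apache.spark.sql.SparkSession;

class CatalogConfigSketch {
  // Registers e.g. "spark.sql.catalog.testhive" plus its "type"/"default-namespace" options.
  static void register(SparkSession spark, SparkCatalogConfig config) {
    String prefix = "spark.sql.catalog." + config.catalogName();
    spark.conf().set(prefix, config.implementation());
    for (Map.Entry<String, String> property : config.properties().entrySet()) {
      spark.conf().set(prefix + "." + property.getKey(), property.getValue());
    }
  }
}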
+ */ +package org.apache.iceberg.spark; + +import java.util.Map; +import org.junit.Rule; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public abstract class SparkCatalogTestBase extends SparkTestBaseWithCatalog { + + // these parameters are broken out to avoid changes that need to modify lots of test suites + @Parameterized.Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + SparkCatalogConfig.HIVE.properties() + }, + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties() + }, + { + SparkCatalogConfig.SPARK.catalogName(), + SparkCatalogConfig.SPARK.implementation(), + SparkCatalogConfig.SPARK.properties() + } + }; + } + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + public SparkCatalogTestBase(SparkCatalogConfig config) { + super(config); + } + + public SparkCatalogTestBase( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SparkTestBase.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SparkTestBase.java new file mode 100644 index 000000000000..82b36528996f --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SparkTestBase.java @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.net.URI; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.List; +import java.util.Map; +import java.util.TimeZone; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.hive.HiveCatalog; +import org.apache.iceberg.hive.TestHiveMetastore; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.execution.QueryExecution; +import org.apache.spark.sql.execution.SparkPlan; +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.util.QueryExecutionListener; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; + +public abstract class SparkTestBase extends SparkTestHelperBase { + + protected static TestHiveMetastore metastore = null; + protected static HiveConf hiveConf = null; + protected static SparkSession spark = null; + protected static JavaSparkContext sparkContext = null; + protected static HiveCatalog catalog = null; + + @BeforeClass + public static void startMetastoreAndSpark() { + SparkTestBase.metastore = new TestHiveMetastore(); + metastore.start(); + SparkTestBase.hiveConf = metastore.hiveConf(); + + SparkTestBase.spark = + SparkSession.builder() + .master("local[2]") + .config(SQLConf.PARTITION_OVERWRITE_MODE().key(), "dynamic") + .config("spark.hadoop." + METASTOREURIS.varname, hiveConf.get(METASTOREURIS.varname)) + .config("spark.sql.legacy.respectNullabilityInTextDatasetConversion", "true") + .enableHiveSupport() + .getOrCreate(); + + SparkTestBase.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + + SparkTestBase.catalog = + (HiveCatalog) + CatalogUtil.loadCatalog( + HiveCatalog.class.getName(), "hive", ImmutableMap.of(), hiveConf); + + try { + catalog.createNamespace(Namespace.of("default")); + } catch (AlreadyExistsException ignored) { + // the default namespace already exists. 
ignore the create error + } + } + + @AfterClass + public static void stopMetastoreAndSpark() throws Exception { + SparkTestBase.catalog = null; + if (metastore != null) { + metastore.stop(); + SparkTestBase.metastore = null; + } + if (spark != null) { + spark.stop(); + SparkTestBase.spark = null; + SparkTestBase.sparkContext = null; + } + } + + protected long waitUntilAfter(long timestampMillis) { + long current = System.currentTimeMillis(); + while (current <= timestampMillis) { + current = System.currentTimeMillis(); + } + return current; + } + + protected List sql(String query, Object... args) { + List rows = spark.sql(String.format(query, args)).collectAsList(); + if (rows.size() < 1) { + return ImmutableList.of(); + } + + return rowsToJava(rows); + } + + protected Object scalarSql(String query, Object... args) { + List rows = sql(query, args); + Assert.assertEquals("Scalar SQL should return one row", 1, rows.size()); + Object[] row = Iterables.getOnlyElement(rows); + Assert.assertEquals("Scalar SQL should return one value", 1, row.length); + return row[0]; + } + + protected Object[] row(Object... values) { + return values; + } + + protected static String dbPath(String dbName) { + return metastore.getDatabasePath(dbName); + } + + protected void withUnavailableFiles(Iterable> files, Action action) { + Iterable fileLocations = Iterables.transform(files, file -> file.path().toString()); + withUnavailableLocations(fileLocations, action); + } + + private void move(String location, String newLocation) { + Path path = Paths.get(URI.create(location)); + Path tempPath = Paths.get(URI.create(newLocation)); + + try { + Files.move(path, tempPath); + } catch (IOException e) { + throw new UncheckedIOException("Failed to move: " + location, e); + } + } + + protected void withUnavailableLocations(Iterable locations, Action action) { + for (String location : locations) { + move(location, location + "_temp"); + } + + try { + action.invoke(); + } finally { + for (String location : locations) { + move(location + "_temp", location); + } + } + } + + protected void withDefaultTimeZone(String zoneId, Action action) { + TimeZone currentZone = TimeZone.getDefault(); + try { + TimeZone.setDefault(TimeZone.getTimeZone(zoneId)); + action.invoke(); + } finally { + TimeZone.setDefault(currentZone); + } + } + + protected void withSQLConf(Map conf, Action action) { + SQLConf sqlConf = SQLConf.get(); + + Map currentConfValues = Maps.newHashMap(); + conf.keySet() + .forEach( + confKey -> { + if (sqlConf.contains(confKey)) { + String currentConfValue = sqlConf.getConfString(confKey); + currentConfValues.put(confKey, currentConfValue); + } + }); + + conf.forEach( + (confKey, confValue) -> { + if (SQLConf.isStaticConfigKey(confKey)) { + throw new RuntimeException("Cannot modify the value of a static config: " + confKey); + } + sqlConf.setConfString(confKey, confValue); + }); + + try { + action.invoke(); + } finally { + conf.forEach( + (confKey, confValue) -> { + if (currentConfValues.containsKey(confKey)) { + sqlConf.setConfString(confKey, currentConfValues.get(confKey)); + } else { + sqlConf.unsetConf(confKey); + } + }); + } + } + + protected Dataset jsonToDF(String schema, String... records) { + Dataset jsonDF = spark.createDataset(ImmutableList.copyOf(records), Encoders.STRING()); + return spark.read().schema(schema).json(jsonDF); + } + + protected void append(String table, String... 
jsonRecords) { + try { + String schema = spark.table(table).schema().toDDL(); + Dataset df = jsonToDF(schema, jsonRecords); + df.coalesce(1).writeTo(table).append(); + } catch (NoSuchTableException e) { + throw new RuntimeException("Failed to write data", e); + } + } + + protected String tablePropsAsString(Map tableProps) { + StringBuilder stringBuilder = new StringBuilder(); + + for (Map.Entry property : tableProps.entrySet()) { + if (stringBuilder.length() > 0) { + stringBuilder.append(", "); + } + stringBuilder.append(String.format("'%s' '%s'", property.getKey(), property.getValue())); + } + + return stringBuilder.toString(); + } + + protected SparkPlan executeAndKeepPlan(String query, Object... args) { + return executeAndKeepPlan(() -> sql(query, args)); + } + + protected SparkPlan executeAndKeepPlan(Action action) { + AtomicReference executedPlanRef = new AtomicReference<>(); + + QueryExecutionListener listener = + new QueryExecutionListener() { + @Override + public void onSuccess(String funcName, QueryExecution qe, long durationNs) { + executedPlanRef.set(qe.executedPlan()); + } + + @Override + public void onFailure(String funcName, QueryExecution qe, Exception exception) {} + }; + + spark.listenerManager().register(listener); + + action.invoke(); + + try { + spark.sparkContext().listenerBus().waitUntilEmpty(); + } catch (TimeoutException e) { + throw new RuntimeException("Timeout while waiting for processing events", e); + } + + SparkPlan executedPlan = executedPlanRef.get(); + if (executedPlan instanceof AdaptiveSparkPlanExec) { + return ((AdaptiveSparkPlanExec) executedPlan).executedPlan(); + } else { + return executedPlan; + } + } + + @FunctionalInterface + protected interface Action { + void invoke(); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SparkTestBaseWithCatalog.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SparkTestBaseWithCatalog.java new file mode 100644 index 000000000000..00a9339cb743 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SparkTestBaseWithCatalog.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import java.io.File; +import java.io.IOException; +import java.util.Map; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.SupportsNamespaces; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.hadoop.HadoopCatalog; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.rules.TemporaryFolder; + +public abstract class SparkTestBaseWithCatalog extends SparkTestBase { + private static File warehouse = null; + + @BeforeClass + public static void createWarehouse() throws IOException { + SparkTestBaseWithCatalog.warehouse = File.createTempFile("warehouse", null); + Assert.assertTrue(warehouse.delete()); + } + + @AfterClass + public static void dropWarehouse() throws IOException { + if (warehouse != null && warehouse.exists()) { + Path warehousePath = new Path(warehouse.getAbsolutePath()); + FileSystem fs = warehousePath.getFileSystem(hiveConf); + Assert.assertTrue("Failed to delete " + warehousePath, fs.delete(warehousePath, true)); + } + } + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + protected final String catalogName; + protected final Catalog validationCatalog; + protected final SupportsNamespaces validationNamespaceCatalog; + protected final TableIdentifier tableIdent = TableIdentifier.of(Namespace.of("default"), "table"); + protected final String tableName; + + public SparkTestBaseWithCatalog() { + this(SparkCatalogConfig.HADOOP); + } + + public SparkTestBaseWithCatalog(SparkCatalogConfig config) { + this(config.catalogName(), config.implementation(), config.properties()); + } + + public SparkTestBaseWithCatalog( + String catalogName, String implementation, Map config) { + this.catalogName = catalogName; + this.validationCatalog = + catalogName.equals("testhadoop") + ? new HadoopCatalog(spark.sessionState().newHadoopConf(), "file:" + warehouse) + : catalog; + this.validationNamespaceCatalog = (SupportsNamespaces) validationCatalog; + + spark.conf().set("spark.sql.catalog." + catalogName, implementation); + config.forEach( + (key, value) -> spark.conf().set("spark.sql.catalog." + catalogName + "." + key, value)); + + if (config.get("type").equalsIgnoreCase("hadoop")) { + spark.conf().set("spark.sql.catalog." + catalogName + ".warehouse", "file:" + warehouse); + } + + this.tableName = + (catalogName.equals("spark_catalog") ? "" : catalogName + ".") + "default.table"; + + sql("CREATE NAMESPACE IF NOT EXISTS default"); + } + + protected String tableName(String name) { + return (catalogName.equals("spark_catalog") ? "" : catalogName + ".") + "default." + name; + } + + protected String commitTarget() { + return tableName; + } + + protected String selectTarget() { + return tableName; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SparkTestHelperBase.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SparkTestHelperBase.java new file mode 100644 index 000000000000..97484702cad6 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SparkTestHelperBase.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark;
+
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import org.apache.spark.sql.Row;
+import org.junit.Assert;
+
+public class SparkTestHelperBase {
+  protected static final Object ANY = new Object();
+
+  protected List<Object[]> rowsToJava(List<Row> rows) {
+    return rows.stream().map(this::toJava).collect(Collectors.toList());
+  }
+
+  private Object[] toJava(Row row) {
+    return IntStream.range(0, row.size())
+        .mapToObj(
+            pos -> {
+              if (row.isNullAt(pos)) {
+                return null;
+              }
+
+              Object value = row.get(pos);
+              if (value instanceof Row) {
+                return toJava((Row) value);
+              } else if (value instanceof scala.collection.Seq) {
+                return row.getList(pos);
+              } else if (value instanceof scala.collection.Map) {
+                return row.getJavaMap(pos);
+              } else {
+                return value;
+              }
+            })
+        .toArray(Object[]::new);
+  }
+
+  protected void assertEquals(
+      String context, List<Object[]> expectedRows, List<Object[]> actualRows) {
+    Assert.assertEquals(
+        context + ": number of results should match", expectedRows.size(), actualRows.size());
+    for (int row = 0; row < expectedRows.size(); row += 1) {
+      Object[] expected = expectedRows.get(row);
+      Object[] actual = actualRows.get(row);
+      Assert.assertEquals("Number of columns should match", expected.length, actual.length);
+      for (int col = 0; col < actualRows.get(row).length; col += 1) {
+        String newContext = String.format("%s: row %d col %d", context, row + 1, col + 1);
+        assertEquals(newContext, expected, actual);
+      }
+    }
+  }
+
+  protected void assertEquals(String context, Object[] expectedRow, Object[] actualRow) {
+    Assert.assertEquals("Number of columns should match", expectedRow.length, actualRow.length);
+    for (int col = 0; col < actualRow.length; col += 1) {
+      Object expectedValue = expectedRow[col];
+      Object actualValue = actualRow[col];
+      if (expectedValue != null && expectedValue.getClass().isArray()) {
+        String newContext = String.format("%s (nested col %d)", context, col + 1);
+        if (expectedValue instanceof byte[]) {
+          Assert.assertArrayEquals(newContext, (byte[]) expectedValue, (byte[]) actualValue);
+        } else {
+          assertEquals(newContext, (Object[]) expectedValue, (Object[]) actualValue);
+        }
+      } else if (expectedValue != ANY) {
+        Assert.assertEquals(context + " contents should match", expectedValue, actualValue);
+      }
+    }
+  }
+}
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestChangelogIterator.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestChangelogIterator.java
new file mode 100644
index 000000000000..b75ba47db120
--- /dev/null
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestChangelogIterator.java
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import org.apache.iceberg.ChangelogOperation; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.junit.Assert; +import org.junit.Test; + +public class TestChangelogIterator extends SparkTestHelperBase { + private static final String DELETE = ChangelogOperation.DELETE.name(); + private static final String INSERT = ChangelogOperation.INSERT.name(); + private static final String UPDATE_BEFORE = ChangelogOperation.UPDATE_BEFORE.name(); + private static final String UPDATE_AFTER = ChangelogOperation.UPDATE_AFTER.name(); + + private static final StructType SCHEMA = + new StructType( + new StructField[] { + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("name", DataTypes.StringType, false, Metadata.empty()), + new StructField("data", DataTypes.StringType, true, Metadata.empty()), + new StructField( + MetadataColumns.CHANGE_TYPE.name(), DataTypes.StringType, false, Metadata.empty()) + }); + private static final String[] IDENTIFIER_FIELDS = new String[] {"id", "name"}; + + private enum RowType { + DELETED, + INSERTED, + CARRY_OVER, + UPDATED + } + + @Test + public void testIterator() { + List permutations = Lists.newArrayList(); + // generate 24 permutations + permute( + Arrays.asList(RowType.DELETED, RowType.INSERTED, RowType.CARRY_OVER, RowType.UPDATED), + 0, + permutations); + Assert.assertEquals(24, permutations.size()); + + for (Object[] permutation : permutations) { + validate(permutation); + } + } + + private void validate(Object[] permutation) { + List rows = Lists.newArrayList(); + List expectedRows = Lists.newArrayList(); + for (int i = 0; i < permutation.length; i++) { + rows.addAll(toOriginalRows((RowType) permutation[i], i)); + expectedRows.addAll(toExpectedRows((RowType) permutation[i], i)); + } + + Iterator iterator = ChangelogIterator.create(rows.iterator(), SCHEMA, IDENTIFIER_FIELDS); + List result = Lists.newArrayList(iterator); + assertEquals("Rows should match", expectedRows, rowsToJava(result)); + } + + private List toOriginalRows(RowType rowType, int index) { + switch (rowType) { + case DELETED: + return Lists.newArrayList( + new GenericRowWithSchema(new Object[] {index, "b", "data", DELETE}, null)); + case INSERTED: + return Lists.newArrayList( + new GenericRowWithSchema(new Object[] {index, "c", "data", INSERT}, null)); + case CARRY_OVER: + return Lists.newArrayList( + new GenericRowWithSchema(new Object[] {index, "d", "data", DELETE}, null), + 
new GenericRowWithSchema(new Object[] {index, "d", "data", INSERT}, null));
+      case UPDATED:
+        return Lists.newArrayList(
+            new GenericRowWithSchema(new Object[] {index, "a", "data", DELETE}, null),
+            new GenericRowWithSchema(new Object[] {index, "a", "new_data", INSERT}, null));
+      default:
+        throw new IllegalArgumentException("Unknown row type: " + rowType);
+    }
+  }
+
+  private List<Object[]> toExpectedRows(RowType rowType, int order) {
+    switch (rowType) {
+      case DELETED:
+        List<Object[]> rows = Lists.newArrayList();
+        rows.add(new Object[] {order, "b", "data", DELETE});
+        return rows;
+      case INSERTED:
+        List<Object[]> insertedRows = Lists.newArrayList();
+        insertedRows.add(new Object[] {order, "c", "data", INSERT});
+        return insertedRows;
+      case CARRY_OVER:
+        return Lists.newArrayList();
+      case UPDATED:
+        return Lists.newArrayList(
+            new Object[] {order, "a", "data", UPDATE_BEFORE},
+            new Object[] {order, "a", "new_data", UPDATE_AFTER});
+      default:
+        throw new IllegalArgumentException("Unknown row type: " + rowType);
+    }
+  }
+
+  private void permute(List<RowType> arr, int start, List<Object[]> pm) {
+    for (int i = start; i < arr.size(); i++) {
+      Collections.swap(arr, i, start);
+      permute(arr, start + 1, pm);
+      Collections.swap(arr, start, i);
+    }
+    if (start == arr.size() - 1) {
+      pm.add(arr.toArray());
+    }
+  }
+
+  @Test
+  public void testRowsWithNullValue() {
+    final List<Row> rowsWithNull =
+        Lists.newArrayList(
+            new GenericRowWithSchema(new Object[] {2, null, null, DELETE}, null),
+            new GenericRowWithSchema(new Object[] {3, null, null, INSERT}, null),
+            new GenericRowWithSchema(new Object[] {4, null, null, DELETE}, null),
+            new GenericRowWithSchema(new Object[] {4, null, null, INSERT}, null),
+            // mixed null and non-null value in non-identifier columns
+            new GenericRowWithSchema(new Object[] {5, null, null, DELETE}, null),
+            new GenericRowWithSchema(new Object[] {5, null, "data", INSERT}, null),
+            // mixed null and non-null value in identifier columns
+            new GenericRowWithSchema(new Object[] {6, null, null, DELETE}, null),
+            new GenericRowWithSchema(new Object[] {6, "name", null, INSERT}, null));
+
+    Iterator<Row> iterator =
+        ChangelogIterator.create(rowsWithNull.iterator(), SCHEMA, IDENTIFIER_FIELDS);
+    List<Row> result = Lists.newArrayList(iterator);
+
+    assertEquals(
+        "Rows should match",
+        Lists.newArrayList(
+            new Object[] {2, null, null, DELETE},
+            new Object[] {3, null, null, INSERT},
+            new Object[] {5, null, null, UPDATE_BEFORE},
+            new Object[] {5, null, "data", UPDATE_AFTER},
+            new Object[] {6, null, null, DELETE},
+            new Object[] {6, "name", null, INSERT}),
+        rowsToJava(result));
+  }
+
+  @Test
+  public void testUpdatedRowsWithDuplication() {
+    List<Row> rowsWithDuplication =
+        Lists.newArrayList(
+            // next two rows are identical
+            new GenericRowWithSchema(new Object[] {1, "a", "data", DELETE}, null),
+            new GenericRowWithSchema(new Object[] {1, "a", "data", DELETE}, null),
+            // next two rows are identical
+            new GenericRowWithSchema(new Object[] {1, "a", "new_data", INSERT}, null),
+            new GenericRowWithSchema(new Object[] {1, "a", "new_data", INSERT}, null),
+            // next two rows are identical
+            new GenericRowWithSchema(new Object[] {4, "d", "data", DELETE}, null),
+            new GenericRowWithSchema(new Object[] {4, "d", "data", DELETE}, null),
+            // next two rows are identical
+            new GenericRowWithSchema(new Object[] {4, "d", "data", INSERT}, null),
+            new GenericRowWithSchema(new Object[] {4, "d", "data", INSERT}, null));
+
+    Iterator<Row> iterator =
+        ChangelogIterator.create(rowsWithDuplication.iterator(), SCHEMA, IDENTIFIER_FIELDS);
+    List<Row> result =
Lists.newArrayList(iterator); + + assertEquals( + "Duplicate rows should not be removed", + Lists.newArrayList( + new Object[] {1, "a", "data", DELETE}, + new Object[] {1, "a", "data", UPDATE_BEFORE}, + new Object[] {1, "a", "new_data", UPDATE_AFTER}, + new Object[] {1, "a", "new_data", INSERT}, + new Object[] {4, "d", "data", DELETE}, + new Object[] {4, "d", "data", INSERT}), + rowsToJava(result)); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestFileRewriteCoordinator.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestFileRewriteCoordinator.java new file mode 100644 index 000000000000..a0e231e863dd --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestFileRewriteCoordinator.java @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestFileRewriteCoordinator extends SparkCatalogTestBase { + + public TestFileRewriteCoordinator( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testBinPackRewrite() throws NoSuchTableException, IOException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + + Dataset df = newDF(1000); + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should produce 4 snapshots", 4, Iterables.size(table.snapshots())); + + Dataset fileDF = + spark.read().format("iceberg").load(tableName(tableIdent.name() + ".files")); + List fileSizes = 
fileDF.select("file_size_in_bytes").as(Encoders.LONG()).collectAsList(); + long avgFileSize = fileSizes.stream().mapToLong(i -> i).sum() / fileSizes.size(); + + try (CloseableIterable fileScanTasks = table.newScan().planFiles()) { + String fileSetID = UUID.randomUUID().toString(); + + ScanTaskSetManager taskSetManager = ScanTaskSetManager.get(); + taskSetManager.stageTasks(table, fileSetID, Lists.newArrayList(fileScanTasks)); + + // read and pack original 4 files into 2 splits + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.SPLIT_SIZE, Long.toString(avgFileSize * 2)) + .option(SparkReadOptions.FILE_OPEN_COST, "0") + .load(tableName); + + // write the packed data into new files where each split becomes a new file + scanDF + .writeTo(tableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + + // commit the rewrite + FileRewriteCoordinator rewriteCoordinator = FileRewriteCoordinator.get(); + Set rewrittenFiles = + taskSetManager.fetchTasks(table, fileSetID).stream() + .map(t -> t.asFileScanTask().file()) + .collect(Collectors.toSet()); + Set addedFiles = rewriteCoordinator.fetchNewDataFiles(table, fileSetID); + table.newRewrite().rewriteFiles(rewrittenFiles, addedFiles).commit(); + } + + table.refresh(); + + Map summary = table.currentSnapshot().summary(); + Assert.assertEquals("Deleted files count must match", "4", summary.get("deleted-data-files")); + Assert.assertEquals("Added files count must match", "2", summary.get("added-data-files")); + + Object rowCount = scalarSql("SELECT count(*) FROM %s", tableName); + Assert.assertEquals("Row count must match", 4000L, rowCount); + } + + @Test + public void testSortRewrite() throws NoSuchTableException, IOException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + + Dataset df = newDF(1000); + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should produce 4 snapshots", 4, Iterables.size(table.snapshots())); + + try (CloseableIterable fileScanTasks = table.newScan().planFiles()) { + String fileSetID = UUID.randomUUID().toString(); + + ScanTaskSetManager taskSetManager = ScanTaskSetManager.get(); + taskSetManager.stageTasks(table, fileSetID, Lists.newArrayList(fileScanTasks)); + + // read original 4 files as 4 splits + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.SPLIT_SIZE, "134217728") + .option(SparkReadOptions.FILE_OPEN_COST, "134217728") + .load(tableName); + + // make sure we disable AQE and set the number of shuffle partitions as the target num files + ImmutableMap sqlConf = + ImmutableMap.of( + "spark.sql.shuffle.partitions", "2", + "spark.sql.adaptive.enabled", "false"); + + withSQLConf( + sqlConf, + () -> { + try { + // write new files with sorted records + scanDF + .sort("id") + .writeTo(tableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + } catch (NoSuchTableException e) { + throw new RuntimeException("Could not replace files", e); + } + }); + + // commit the rewrite + FileRewriteCoordinator rewriteCoordinator = FileRewriteCoordinator.get(); + Set rewrittenFiles = + taskSetManager.fetchTasks(table, fileSetID).stream() + .map(t -> 
t.asFileScanTask().file()) + .collect(Collectors.toSet()); + Set addedFiles = rewriteCoordinator.fetchNewDataFiles(table, fileSetID); + table.newRewrite().rewriteFiles(rewrittenFiles, addedFiles).commit(); + } + + table.refresh(); + + Map summary = table.currentSnapshot().summary(); + Assert.assertEquals("Deleted files count must match", "4", summary.get("deleted-data-files")); + Assert.assertEquals("Added files count must match", "2", summary.get("added-data-files")); + + Object rowCount = scalarSql("SELECT count(*) FROM %s", tableName); + Assert.assertEquals("Row count must match", 4000L, rowCount); + } + + @Test + public void testCommitMultipleRewrites() throws NoSuchTableException, IOException { + sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName); + + Dataset df = newDF(1000); + + // add first two files + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + + String firstFileSetID = UUID.randomUUID().toString(); + long firstFileSetSnapshotId = table.currentSnapshot().snapshotId(); + + ScanTaskSetManager taskSetManager = ScanTaskSetManager.get(); + + try (CloseableIterable tasks = table.newScan().planFiles()) { + // stage first 2 files for compaction + taskSetManager.stageTasks(table, firstFileSetID, Lists.newArrayList(tasks)); + } + + // add two more files + df.coalesce(1).writeTo(tableName).append(); + df.coalesce(1).writeTo(tableName).append(); + + table.refresh(); + + String secondFileSetID = UUID.randomUUID().toString(); + + try (CloseableIterable tasks = + table.newScan().appendsAfter(firstFileSetSnapshotId).planFiles()) { + // stage 2 more files for compaction + taskSetManager.stageTasks(table, secondFileSetID, Lists.newArrayList(tasks)); + } + + ImmutableSet fileSetIDs = ImmutableSet.of(firstFileSetID, secondFileSetID); + + for (String fileSetID : fileSetIDs) { + // read and pack 2 files into 1 split + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.SPLIT_SIZE, Long.MAX_VALUE) + .load(tableName); + + // write the combined data as one file + scanDF + .writeTo(tableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + } + + // commit both rewrites at the same time + FileRewriteCoordinator rewriteCoordinator = FileRewriteCoordinator.get(); + Set rewrittenFiles = + fileSetIDs.stream() + .flatMap(fileSetID -> taskSetManager.fetchTasks(table, fileSetID).stream()) + .map(t -> t.asFileScanTask().file()) + .collect(Collectors.toSet()); + Set addedFiles = + fileSetIDs.stream() + .flatMap(fileSetID -> rewriteCoordinator.fetchNewDataFiles(table, fileSetID).stream()) + .collect(Collectors.toSet()); + table.newRewrite().rewriteFiles(rewrittenFiles, addedFiles).commit(); + + table.refresh(); + + Assert.assertEquals("Should produce 5 snapshots", 5, Iterables.size(table.snapshots())); + + Map summary = table.currentSnapshot().summary(); + Assert.assertEquals("Deleted files count must match", "4", summary.get("deleted-data-files")); + Assert.assertEquals("Added files count must match", "2", summary.get("added-data-files")); + + Object rowCount = scalarSql("SELECT count(*) FROM %s", tableName); + Assert.assertEquals("Row count must match", 4000L, rowCount); + } + + private Dataset newDF(int numRecords) { + List data = Lists.newArrayListWithExpectedSize(numRecords); + for (int index = 0; index < numRecords; index++) { + data.add(new SimpleRecord(index, 
Integer.toString(index))); + } + return spark.createDataFrame(data, SimpleRecord.class); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestFunctionCatalog.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestFunctionCatalog.java new file mode 100644 index 000000000000..1db0fa41f7c6 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestFunctionCatalog.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.IcebergBuild; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.functions.IcebergVersionFunction; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException; +import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; +import org.apache.spark.sql.connector.catalog.FunctionCatalog; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.assertj.core.api.Assertions; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestFunctionCatalog extends SparkTestBaseWithCatalog { + private static final String[] EMPTY_NAMESPACE = new String[] {}; + private static final String[] SYSTEM_NAMESPACE = new String[] {"system"}; + private static final String[] DEFAULT_NAMESPACE = new String[] {"default"}; + private static final String[] DB_NAMESPACE = new String[] {"db"}; + private final FunctionCatalog asFunctionCatalog; + + public TestFunctionCatalog() { + this.asFunctionCatalog = castToFunctionCatalog(catalogName); + } + + @Before + public void createDefaultNamespace() { + sql("CREATE NAMESPACE IF NOT EXISTS %s", catalogName + ".default"); + } + + @After + public void dropDefaultNamespace() { + sql("DROP NAMESPACE IF EXISTS %s", catalogName + ".default"); + } + + @Test + public void testListFunctionsViaCatalog() throws NoSuchNamespaceException { + Assertions.assertThat(asFunctionCatalog.listFunctions(EMPTY_NAMESPACE)) + .anyMatch(func -> "iceberg_version".equals(func.name())); + + Assertions.assertThat(asFunctionCatalog.listFunctions(SYSTEM_NAMESPACE)) + .anyMatch(func -> "iceberg_version".equals(func.name())); + + Assert.assertArrayEquals( + "Listing functions in an existing namespace that's not system should not throw", + new Identifier[0], + asFunctionCatalog.listFunctions(DEFAULT_NAMESPACE)); + + AssertHelpers.assertThrows( + "Listing functions in a namespace that does not exist should throw", + NoSuchNamespaceException.class, + "The schema `db` cannot be found", + () -> 
asFunctionCatalog.listFunctions(DB_NAMESPACE)); + } + + @Test + public void testLoadFunctions() throws NoSuchFunctionException { + for (String[] namespace : ImmutableList.of(EMPTY_NAMESPACE, SYSTEM_NAMESPACE)) { + Identifier identifier = Identifier.of(namespace, "iceberg_version"); + UnboundFunction func = asFunctionCatalog.loadFunction(identifier); + + Assertions.assertThat(func) + .isNotNull() + .isInstanceOf(UnboundFunction.class) + .isExactlyInstanceOf(IcebergVersionFunction.class); + } + + AssertHelpers.assertThrows( + "Cannot load a function if it's not used with the system namespace or the empty namespace", + NoSuchFunctionException.class, + "The function default.iceberg_version cannot be found", + () -> asFunctionCatalog.loadFunction(Identifier.of(DEFAULT_NAMESPACE, "iceberg_version"))); + + Identifier undefinedFunction = Identifier.of(SYSTEM_NAMESPACE, "undefined_function"); + AssertHelpers.assertThrows( + "Cannot load a function that does not exist", + NoSuchFunctionException.class, + "The function system.undefined_function cannot be found", + () -> asFunctionCatalog.loadFunction(undefinedFunction)); + + AssertHelpers.assertThrows( + "Using an undefined function from SQL should fail analysis", + AnalysisException.class, + "Cannot resolve function", + () -> sql("SELECT undefined_function(1, 2)")); + } + + @Test + public void testCallingFunctionInSQLEndToEnd() { + String buildVersion = IcebergBuild.version(); + + Assert.assertEquals( + "Should be able to use the Iceberg version function from the fully qualified system namespace", + buildVersion, + scalarSql("SELECT %s.system.iceberg_version()", catalogName)); + + Assert.assertEquals( + "Should be able to use the Iceberg version function when fully qualified without specifying a namespace", + buildVersion, + scalarSql("SELECT %s.iceberg_version()", catalogName)); + + sql("USE %s", catalogName); + + Assert.assertEquals( + "Should be able to call iceberg_version from system namespace without fully qualified name when using Iceberg catalog", + buildVersion, + scalarSql("SELECT system.iceberg_version()")); + + Assert.assertEquals( + "Should be able to call iceberg_version from empty namespace without fully qualified name when using Iceberg catalog", + buildVersion, + scalarSql("SELECT iceberg_version()")); + } + + private FunctionCatalog castToFunctionCatalog(String name) { + return (FunctionCatalog) spark.sessionState().catalogManager().catalog(name); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSpark3Util.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSpark3Util.java new file mode 100644 index 000000000000..96dc2c29eb7f --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSpark3Util.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark;
+
+import static org.apache.iceberg.NullOrder.NULLS_FIRST;
+import static org.apache.iceberg.NullOrder.NULLS_LAST;
+import static org.apache.iceberg.types.Types.NestedField.optional;
+import static org.apache.iceberg.types.Types.NestedField.required;
+
+import org.apache.iceberg.CachingCatalog;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.SortOrderParser;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.catalog.Catalog;
+import org.apache.iceberg.types.Types;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestSpark3Util extends SparkTestBase {
+  @Test
+  public void testDescribeSortOrder() {
+    Schema schema =
+        new Schema(
+            required(1, "data", Types.StringType.get()),
+            required(2, "time", Types.TimestampType.withoutZone()));
+
+    Assert.assertEquals(
+        "Sort order isn't correct.",
+        "data DESC NULLS FIRST",
+        Spark3Util.describe(buildSortOrder("Identity", schema, 1)));
+    Assert.assertEquals(
+        "Sort order isn't correct.",
+        "bucket(1, data) DESC NULLS FIRST",
+        Spark3Util.describe(buildSortOrder("bucket[1]", schema, 1)));
+    Assert.assertEquals(
+        "Sort order isn't correct.",
+        "truncate(data, 3) DESC NULLS FIRST",
+        Spark3Util.describe(buildSortOrder("truncate[3]", schema, 1)));
+    Assert.assertEquals(
+        "Sort order isn't correct.",
+        "years(time) DESC NULLS FIRST",
+        Spark3Util.describe(buildSortOrder("year", schema, 2)));
+    Assert.assertEquals(
+        "Sort order isn't correct.",
+        "months(time) DESC NULLS FIRST",
+        Spark3Util.describe(buildSortOrder("month", schema, 2)));
+    Assert.assertEquals(
+        "Sort order isn't correct.",
+        "days(time) DESC NULLS FIRST",
+        Spark3Util.describe(buildSortOrder("day", schema, 2)));
+    Assert.assertEquals(
+        "Sort order isn't correct.",
+        "hours(time) DESC NULLS FIRST",
+        Spark3Util.describe(buildSortOrder("hour", schema, 2)));
+    Assert.assertEquals(
+        "Sort order isn't correct.",
+        "unknown(data) DESC NULLS FIRST",
+        Spark3Util.describe(buildSortOrder("unknown", schema, 1)));
+
+    // multiple sort orders
+    SortOrder multiOrder =
+        SortOrder.builderFor(schema).asc("time", NULLS_FIRST).asc("data", NULLS_LAST).build();
+    Assert.assertEquals(
+        "Sort order isn't correct.",
+        "time ASC NULLS FIRST, data ASC NULLS LAST",
+        Spark3Util.describe(multiOrder));
+  }
+
+  @Test
+  public void testDescribeSchema() {
+    Schema schema =
+        new Schema(
+            required(1, "data", Types.ListType.ofRequired(2, Types.StringType.get())),
+            optional(
+                3,
+                "pairs",
+                Types.MapType.ofOptional(4, 5, Types.StringType.get(), Types.LongType.get())),
+            required(6, "time", Types.TimestampType.withoutZone()));
+
+    Assert.assertEquals(
+        "Schema description isn't correct.",
+        "struct<data: list<string> not null,pairs: map<string, bigint>,time: timestamp not null>",
+        Spark3Util.describe(schema));
+  }
+
+  @Test
+  public void testLoadIcebergTable() throws Exception {
+    spark.conf().set("spark.sql.catalog.hive", SparkCatalog.class.getName());
+    spark.conf().set("spark.sql.catalog.hive.type", "hive");
+    spark.conf().set("spark.sql.catalog.hive.default-namespace", "default");
+
+    String tableFullName = "hive.default.tbl";
+    sql("CREATE TABLE %s (c1 bigint, c2 string, c3 string) USING iceberg", tableFullName);
+
+    Table table = Spark3Util.loadIcebergTable(spark, tableFullName);
+    Assert.assertTrue(table.name().equals(tableFullName));
+  }
+
+  @Test
+  public void testLoadIcebergCatalog() throws Exception {
+
spark.conf().set("spark.sql.catalog.test_cat", SparkCatalog.class.getName()); + spark.conf().set("spark.sql.catalog.test_cat.type", "hive"); + Catalog catalog = Spark3Util.loadIcebergCatalog(spark, "test_cat"); + Assert.assertTrue( + "Should retrieve underlying catalog class", catalog instanceof CachingCatalog); + } + + private SortOrder buildSortOrder(String transform, Schema schema, int sourceId) { + String jsonString = + "{\n" + + " \"order-id\" : 10,\n" + + " \"fields\" : [ {\n" + + " \"transform\" : \"" + + transform + + "\",\n" + + " \"source-id\" : " + + sourceId + + ",\n" + + " \"direction\" : \"desc\",\n" + + " \"null-order\" : \"nulls-first\"\n" + + " } ]\n" + + "}"; + + return SortOrderParser.fromJson(schema, jsonString); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkCachedTableCatalog.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkCachedTableCatalog.java new file mode 100644 index 000000000000..23e8717fb8c3 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkCachedTableCatalog.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestSparkCachedTableCatalog extends SparkTestBaseWithCatalog { + + private static final SparkTableCache TABLE_CACHE = SparkTableCache.get(); + + @BeforeClass + public static void setupCachedTableCatalog() { + spark.conf().set("spark.sql.catalog.testcache", SparkCachedTableCatalog.class.getName()); + } + + @AfterClass + public static void unsetCachedTableCatalog() { + spark.conf().unset("spark.sql.catalog.testcache"); + } + + public TestSparkCachedTableCatalog() { + super(SparkCatalogConfig.HIVE); + } + + @Test + public void testTimeTravel() { + sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + + table.refresh(); + Snapshot firstSnapshot = table.currentSnapshot(); + waitUntilAfter(firstSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (2, 'hr')", tableName); + + table.refresh(); + Snapshot secondSnapshot = table.currentSnapshot(); + waitUntilAfter(secondSnapshot.timestampMillis()); + + sql("INSERT INTO TABLE %s VALUES (3, 'hr')", tableName); + + table.refresh(); + + try { + TABLE_CACHE.add("key", table); + + assertEquals( + "Should have expected rows in 3rd snapshot", + ImmutableList.of(row(1, "hr"), row(2, "hr"), row(3, "hr")), + sql("SELECT * FROM testcache.key ORDER BY id")); + + assertEquals( + "Should have expected rows in 2nd snapshot", + ImmutableList.of(row(1, "hr"), row(2, "hr")), + sql( + "SELECT * FROM testcache.`key#at_timestamp_%s` ORDER BY id", + secondSnapshot.timestampMillis())); + + assertEquals( + "Should have expected rows in 1st snapshot", + ImmutableList.of(row(1, "hr")), + sql( + "SELECT * FROM testcache.`key#snapshot_id_%d` ORDER BY id", + firstSnapshot.snapshotId())); + + } finally { + TABLE_CACHE.remove("key"); + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkCatalogOperations.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkCatalogOperations.java new file mode 100644 index 000000000000..0836271a7c22 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkCatalogOperations.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import java.util.Map; +import org.apache.iceberg.Schema; +import org.apache.iceberg.catalog.Catalog; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableChange; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestSparkCatalogOperations extends SparkCatalogTestBase { + public TestSparkCatalogOperations( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Before + public void createTable() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testAlterTable() throws NoSuchTableException { + BaseCatalog catalog = (BaseCatalog) spark.sessionState().catalogManager().catalog(catalogName); + Identifier identifier = Identifier.of(tableIdent.namespace().levels(), tableIdent.name()); + + String fieldName = "location"; + String propsKey = "note"; + String propsValue = "jazz"; + Table table = + catalog.alterTable( + identifier, + TableChange.addColumn(new String[] {fieldName}, DataTypes.StringType, true), + TableChange.setProperty(propsKey, propsValue)); + + Assert.assertNotNull("Should return updated table", table); + + StructField expectedField = DataTypes.createStructField(fieldName, DataTypes.StringType, true); + Assert.assertEquals( + "Adding a column to a table should return the updated table with the new column", + table.schema().fields()[2], + expectedField); + + Assert.assertTrue( + "Adding a property to a table should return the updated table with the new property", + table.properties().containsKey(propsKey)); + Assert.assertEquals( + "Altering a table to add a new property should add the correct value", + propsValue, + table.properties().get(propsKey)); + } + + @Test + public void testInvalidateTable() { + // load table to CachingCatalog + sql("SELECT count(1) FROM %s", tableName); + + // recreate table from another catalog or program + Catalog anotherCatalog = validationCatalog; + Schema schema = anotherCatalog.loadTable(tableIdent).schema(); + anotherCatalog.dropTable(tableIdent); + anotherCatalog.createTable(tableIdent, schema); + + // invalidate and reload table + sql("REFRESH TABLE %s", tableName); + sql("SELECT count(1) FROM %s", tableName); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkDistributionAndOrderingUtil.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkDistributionAndOrderingUtil.java new file mode 100644 index 000000000000..c6b1eaeceb4c --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkDistributionAndOrderingUtil.java @@ -0,0 +1,2295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.apache.iceberg.TableProperties.DELETE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.MERGE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.UPDATE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_HASH; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_NONE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_RANGE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE; +import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE; + +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.distributions.Distributions; +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.SortDirection; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.connector.write.RowLevelOperation.Command; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestSparkDistributionAndOrderingUtil extends SparkTestBaseWithCatalog { + + private static final Distribution UNSPECIFIED_DISTRIBUTION = Distributions.unspecified(); + private static final Distribution FILE_CLUSTERED_DISTRIBUTION = + Distributions.clustered( + new Expression[] {Expressions.column(MetadataColumns.FILE_PATH.name())}); + private static final Distribution SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION = + Distributions.clustered( + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME) + }); + + private static final SortOrder[] EMPTY_ORDERING = new SortOrder[] {}; + private static final SortOrder[] FILE_POSITION_ORDERING = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING) + }; + private static final SortOrder[] SPEC_ID_PARTITION_FILE_ORDERING = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING) + }; + private static final SortOrder[] SPEC_ID_PARTITION_FILE_POSITION_ORDERING = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), 
SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING) + }; + + @After + public void dropTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testDefaultWriteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + checkWriteDistributionAndOrdering(table, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @Test + public void testHashWriteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + checkWriteDistributionAndOrdering(table, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @Test + public void testRangeWriteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + checkWriteDistributionAndOrdering(table, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @Test + public void testDefaultWriteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + } + + @Test + public void testHashWriteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkWriteDistributionAndOrdering(table, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testRangeWriteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + } + + 
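+  // =============================================================
+  // Distribution and ordering for writes to partitioned tables
+  // (summary of the expectations verified by the tests below)
+  // =============================================================
+  //
+  // PARTITIONED, UNORDERED
+  // -------------------------------------------------------------------------
+  // write mode is NOT SET -> CLUSTER BY partition expressions + LOCALLY ORDER BY partition expressions
+  // write mode is HASH -> CLUSTER BY partition expressions + LOCALLY ORDER BY partition expressions
+  // write mode is RANGE -> ORDER BY partition expressions
+  //
+  // PARTITIONED, ORDERED
+  // -------------------------------------------------------------------------
+  // write mode is NOT SET -> ORDER BY partition expressions, table sort order
+  // write mode is HASH -> CLUSTER BY partition expressions + LOCALLY ORDER BY partition expressions, table sort order
+  // write mode is RANGE -> ORDER BY partition expressions, table sort order
+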
@Test + public void testDefaultWritePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Expression[] expectedClustering = + new Expression[] {Expressions.identity("date"), Expressions.days("ts")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + } + + @Test + public void testHashWritePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + Expression[] expectedClustering = + new Expression[] {Expressions.identity("date"), Expressions.days("ts")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + } + + @Test + public void testRangeWritePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + } + + @Test + public void testDefaultWritePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().desc("id").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + } + + @Test + public void testHashWritePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + table.replaceSortOrder().asc("id").commit(); + + Expression[] expectedClustering = + new Expression[] 
{Expressions.identity("date"), Expressions.bucket(8, "data")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + } + + @Test + public void testRangeWritePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkWriteDistributionAndOrdering(table, expectedDistribution, expectedOrdering); + } + + // ============================================================= + // Distribution and ordering for copy-on-write DELETE operations + // ============================================================= + // + // UNPARTITIONED UNORDERED + // ------------------------------------------------------------------------- + // delete mode is NOT SET -> CLUSTER BY _file + LOCALLY ORDER BY _file, _pos + // delete mode is NONE -> unspecified distribution + empty ordering + // delete mode is HASH -> CLUSTER BY _file + LOCALLY ORDER BY _file, _pos + // delete mode is RANGE -> ORDER BY _file, _pos + // + // UNPARTITIONED ORDERED BY id, data + // ------------------------------------------------------------------------- + // delete mode is NOT SET -> CLUSTER BY _file + LOCALLY ORDER BY id, data + // delete mode is NONE -> unspecified distribution + LOCALLY ORDER BY id, data + // delete mode is HASH -> CLUSTER BY _file + LOCALLY ORDER BY id, data + // delete mode is RANGE -> ORDER BY id, data + // + // PARTITIONED BY date, days(ts) UNORDERED + // ------------------------------------------------------------------------- + // delete mode is NOT SET -> CLUSTER BY _file + LOCALLY ORDER BY date, days(ts), _file, _pos + // delete mode is NONE -> unspecified distribution + LOCALLY ORDERED BY date, days(ts) + // delete mode is HASH -> CLUSTER BY _file + LOCALLY ORDER BY date, days(ts), _file, _pos + // delete mode is RANGE -> ORDER BY date, days(ts), _file, _pos + // + // PARTITIONED BY date ORDERED BY id + // ------------------------------------------------------------------------- + // delete mode is NOT SET -> CLUSTER BY _file + LOCALLY ORDER BY date, id + // delete mode is NONE -> unspecified distribution + LOCALLY ORDERED BY date, id + // delete mode is HASH -> CLUSTER BY _file + LOCALLY ORDER BY date, id + // delete mode is RANGE -> ORDERED BY date, id + + @Test + public void testDefaultCopyOnWriteDeleteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, FILE_CLUSTERED_DISTRIBUTION, FILE_POSITION_ORDERING); + } + + @Test + public void testNoneCopyOnWriteDeleteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", 
tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @Test + public void testHashCopyOnWriteDeleteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, FILE_CLUSTERED_DISTRIBUTION, FILE_POSITION_ORDERING); + } + + @Test + public void testRangeCopyOnWriteDeleteUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + Distribution expectedDistribution = Distributions.ordered(FILE_POSITION_ORDERING); + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, expectedDistribution, FILE_POSITION_ORDERING); + } + + @Test + public void testDefaultCopyOnWriteDeleteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testNoneCopyOnWriteDeleteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashCopyOnWriteDeleteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testRangeCopyOnWriteDeleteUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + 
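+    // with a range delete mode, copy-on-write DELETE is expected to request a globally
+    // ordered distribution matching the table sort order defined below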
table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, expectedDistribution, expectedOrdering); + } + + @Test + public void testDefaultCopyOnWriteDeletePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testNoneCopyOnWriteDeletePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashCopyOnWriteDeletePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testRangeCopyOnWriteDeletePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), 
SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, expectedDistribution, expectedOrdering); + } + + @Test + public void testDefaultCopyOnWriteDeletePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().desc("id").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testNoneCopyOnWriteDeletePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + table.replaceSortOrder().desc("id").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashCopyOnWriteDeletePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + table.replaceSortOrder().asc("id").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, DELETE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testRangeCopyOnWriteDeletePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + table.replaceSortOrder().asc("id").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, DELETE, expectedDistribution, expectedOrdering); + } + + // ============================================================= + // Distribution and ordering for copy-on-write UPDATE operations + // 
============================================================= + // + // UNPARTITIONED UNORDERED + // ------------------------------------------------------------------------- + // update mode is NOT SET -> CLUSTER BY _file + LOCALLY ORDER BY _file, _pos + // update mode is NONE -> unspecified distribution + empty ordering + // update mode is HASH -> CLUSTER BY _file + LOCALLY ORDER BY _file, _pos + // update mode is RANGE -> ORDER BY _file, _pos + // + // UNPARTITIONED ORDERED BY id, data + // ------------------------------------------------------------------------- + // update mode is NOT SET -> CLUSTER BY _file + LOCALLY ORDER BY id, data + // update mode is NONE -> unspecified distribution + LOCALLY ORDER BY id, data + // update mode is HASH -> CLUSTER BY _file + LOCALLY ORDER BY id, data + // update mode is RANGE -> ORDER BY id, data + // + // PARTITIONED BY date, days(ts) UNORDERED + // ------------------------------------------------------------------------- + // update mode is NOT SET -> CLUSTER BY _file + LOCALLY ORDER BY date, days(ts), _file, _pos + // update mode is NONE -> unspecified distribution + LOCALLY ORDERED BY date, days(ts) + // update mode is HASH -> CLUSTER BY _file + LOCALLY ORDER BY date, days(ts), _file, _pos + // update mode is RANGE -> ORDER BY date, days(ts), _file, _pos + // + // PARTITIONED BY date ORDERED BY id + // ------------------------------------------------------------------------- + // update mode is NOT SET -> CLUSTER BY _file + LOCALLY ORDER BY date, id + // update mode is NONE -> unspecified distribution + LOCALLY ORDERED BY date, id + // update mode is HASH -> CLUSTER BY _file + LOCALLY ORDER BY date, id + // update mode is RANGE -> ORDERED BY date, id + + @Test + public void testDefaultCopyOnWriteUpdateUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, FILE_CLUSTERED_DISTRIBUTION, FILE_POSITION_ORDERING); + } + + @Test + public void testNoneCopyOnWriteUpdateUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @Test + public void testHashCopyOnWriteUpdateUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, FILE_CLUSTERED_DISTRIBUTION, FILE_POSITION_ORDERING); + } + + @Test + public void testRangeCopyOnWriteUpdateUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + Distribution expectedDistribution = Distributions.ordered(FILE_POSITION_ORDERING); + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, expectedDistribution, FILE_POSITION_ORDERING); + } + + @Test + public void testDefaultCopyOnWriteUpdateUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, 
data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testNoneCopyOnWriteUpdateUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashCopyOnWriteUpdateUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testRangeCopyOnWriteUpdateUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, expectedDistribution, expectedOrdering); + } + + @Test + public void testDefaultCopyOnWriteUpdatePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testNoneCopyOnWriteUpdatePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data 
STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashCopyOnWriteUpdatePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testRangeCopyOnWriteUpdatePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, expectedDistribution, expectedOrdering); + } + + @Test + public void testDefaultCopyOnWriteUpdatePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().desc("id").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testNoneCopyOnWriteUpdatePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + table.replaceSortOrder().desc("id").commit(); + + 
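+    // NONE update mode requests no distribution, but rows should still be locally
+    // ordered by the partition column and the table sort order (date, id DESC)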
SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashCopyOnWriteUpdatePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + table.replaceSortOrder().asc("id").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, UPDATE, FILE_CLUSTERED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testRangeCopyOnWriteUpdatePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + table.replaceSortOrder().asc("id").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, UPDATE, expectedDistribution, expectedOrdering); + } + + // ============================================================= + // Distribution and ordering for copy-on-write MERGE operations + // ============================================================= + // + // UNPARTITIONED UNORDERED + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> use write distribution and ordering + // merge mode is NONE -> unspecified distribution + empty ordering + // merge mode is HASH -> unspecified distribution + empty ordering + // merge mode is RANGE -> unspecified distribution + empty ordering + // + // UNPARTITIONED ORDERED BY id, data + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> use write distribution and ordering + // merge mode is NONE -> unspecified distribution + LOCALLY ORDER BY id, data + // merge mode is HASH -> unspecified distribution + LOCALLY ORDER BY id, data + // merge mode is RANGE -> ORDER BY id, data + // + // PARTITIONED BY date, days(ts) UNORDERED + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> CLUSTER BY date, days(ts) + LOCALLY ORDER BY date, days(ts) + // merge mode is NONE -> unspecified distribution + LOCALLY ORDERED BY date, days(ts) + // merge mode is HASH -> CLUSTER BY date, days(ts) + LOCALLY ORDER BY date, days(ts) + // merge mode is RANGE -> ORDER BY date, days(ts) + // + // PARTITIONED BY date ORDERED BY id + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> CLUSTER BY 
date + LOCALLY ORDER BY date, id + // merge mode is NONE -> unspecified distribution + LOCALLY ORDERED BY date, id + // merge mode is HASH -> CLUSTER BY date + LOCALLY ORDER BY date, id + // merge mode is RANGE -> ORDERED BY date, id + + @Test + public void testDefaultCopyOnWriteMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @Test + public void testNoneCopyOnWriteMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @Test + public void testHashCopyOnWriteMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @Test + public void testRangeCopyOnWriteMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, UNSPECIFIED_DISTRIBUTION, EMPTY_ORDERING); + } + + @Test + public void testDefaultCopyOnWriteMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testNoneCopyOnWriteMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashCopyOnWriteMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] 
expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testRangeCopyOnWriteMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testDefaultCopyOnWriteMergePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Expression[] expectedClustering = + new Expression[] {Expressions.identity("date"), Expressions.days("ts")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testNoneCopyOnWriteMergePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashCopyOnWriteMergePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + Expression[] expectedClustering = + new Expression[] {Expressions.identity("date"), Expressions.days("ts")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testRangeCopyOnWriteMergePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data 
STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testDefaultCopyOnWriteMergePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().desc("id").commit(); + + Expression[] expectedClustering = new Expression[] {Expressions.identity("date")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testNoneCopyOnWriteMergePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + table.replaceSortOrder().desc("id").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering( + table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashCopyOnWriteMergePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + table.replaceSortOrder().asc("id").commit(); + + Expression[] expectedClustering = + new Expression[] {Expressions.identity("date"), Expressions.bucket(8, "data")}; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testRangeCopyOnWriteMergePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + 
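+    // range merge mode on a partitioned, sorted table: expect ORDER BY date, id
+    // as described in the copy-on-write MERGE summary above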
table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + table.replaceSortOrder().asc("id").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + Distribution expectedDistribution = Distributions.ordered(expectedOrdering); + + checkCopyOnWriteDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + // =================================================================================== + // Distribution and ordering for merge-on-read DELETE operations with position deletes + // =================================================================================== + // + // delete mode is NOT SET -> CLUSTER BY _spec_id, _partition + LOCALLY ORDER BY _spec_id, + // _partition, _file, _pos + // delete mode is NONE -> unspecified distribution + LOCALLY ORDER BY _spec_id, _partition, _file, + // _pos + // delete mode is HASH -> CLUSTER BY _spec_id, _partition + LOCALLY ORDER BY _spec_id, _partition, + // _file, _pos + // delete mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, _file + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos + + @Test + public void testDefaultPositionDeltaDeleteUnpartitionedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + checkPositionDeltaDistributionAndOrdering( + table, + DELETE, + SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, + SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testNonePositionDeltaDeleteUnpartitionedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, UNSPECIFIED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testHashPositionDeltaDeleteUnpartitionedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + checkPositionDeltaDistributionAndOrdering( + table, + DELETE, + SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, + SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testRangePositionDeltaDeleteUnpartitionedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + Distribution expectedDistribution = Distributions.ordered(SPEC_ID_PARTITION_FILE_ORDERING); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, expectedDistribution, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testDefaultPositionDeltaDeletePartitionedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + checkPositionDeltaDistributionAndOrdering( + table, + DELETE, + SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, + 
SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testNonePositionDeltaDeletePartitionedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, UNSPECIFIED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testHashPositionDeltaDeletePartitionedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + checkPositionDeltaDistributionAndOrdering( + table, + DELETE, + SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, + SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testRangePositionDeltaDeletePartitionedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + Distribution expectedDistribution = Distributions.ordered(SPEC_ID_PARTITION_FILE_ORDERING); + + checkPositionDeltaDistributionAndOrdering( + table, DELETE, expectedDistribution, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + // =================================================================================== + // Distribution and ordering for merge-on-read UPDATE operations with position deletes + // =================================================================================== + // + // update mode is NOT SET -> CLUSTER BY _spec_id, _partition + LOCALLY ORDER BY _spec_id, + // _partition, _file, _pos + // update mode is NONE -> unspecified distribution + LOCALLY ORDER BY _spec_id, _partition, _file, + // _pos + // update mode is HASH -> CLUSTER BY _spec_id, _partition + LOCALLY ORDER BY _spec_id, _partition, + // _file, _pos + // update mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, _file + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos + + @Test + public void testDefaultPositionDeltaUpdateUnpartitionedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + checkPositionDeltaDistributionAndOrdering( + table, + UPDATE, + SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, + SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testNonePositionDeltaUpdateUnpartitionedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, UNSPECIFIED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testHashPositionDeltaUpdateUnpartitionedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + 
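+    // hash update mode for merge-on-read: position deltas are clustered by
+    // _spec_id and _partition and locally ordered by _spec_id, _partition, _file, _pos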
table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + checkPositionDeltaDistributionAndOrdering( + table, + UPDATE, + SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, + SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testRangePositionDeltaUpdateUnpartitionedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + Distribution expectedDistribution = Distributions.ordered(SPEC_ID_PARTITION_FILE_ORDERING); + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, expectedDistribution, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testDefaultPositionDeltaUpdatePartitionedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + checkPositionDeltaDistributionAndOrdering( + table, + UPDATE, + SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, + SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testNonePositionDeltaUpdatePartitionedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, UNSPECIFIED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testHashPositionDeltaUpdatePartitionedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + checkPositionDeltaDistributionAndOrdering( + table, + UPDATE, + SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, + SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testRangePositionDeltaUpdatePartitionedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + Distribution expectedDistribution = Distributions.ordered(SPEC_ID_PARTITION_FILE_ORDERING); + + checkPositionDeltaDistributionAndOrdering( + table, UPDATE, expectedDistribution, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + // ================================================================================== + // Distribution and ordering for merge-on-read MERGE operations with position deletes + // ================================================================================== + // + // IMPORTANT: metadata columns like _spec_id and _partition are NULL for new rows + // + // UNPARTITIONED UNORDERED + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> CLUSTER BY _spec_id, _partition, _file + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos + // merge mode is 
NONE -> unspecified distribution + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos + // merge mode is HASH -> CLUSTER BY _spec_id, _partition, _file + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos + // merge mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, _file + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos + // + // UNPARTITIONED ORDERED BY id, data + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> CLUSTER BY _spec_id, _partition, _file + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, id, data + // merge mode is NONE -> unspecified distribution + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, id, data + // merge mode is HASH -> CLUSTER BY _spec_id, _partition, _file + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, id, data + // merge mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, _file, id, data + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, id, data + // + // PARTITIONED BY date, days(ts) UNORDERED + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> CLUSTER BY _spec_id, _partition, date, days(ts) + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, date, days(ts) + // merge mode is NONE -> unspecified distribution + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, date, days(ts) + // merge mode is HASH -> CLUSTER BY _spec_id, _partition, date, days(ts) + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, date, days(ts) + // merge mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, _file, date, days(ts) + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, date, days(ts) + // + // PARTITIONED BY date ORDERED BY id + // ------------------------------------------------------------------------- + // merge mode is NOT SET -> CLUSTER BY _spec_id, _partition, date + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, date, id + // merge mode is NONE -> unspecified distribution + + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, date, id + // merge mode is HASH -> CLUSTER BY _spec_id, _partition, date + + // LOCALLY ORDER BY _spec_id, _partition, _file, _pos, date, id + // merge mode is RANGE -> RANGE DISTRIBUTE BY _spec_id, _partition, _file, date, id + // LOCALLY ORDERED BY _spec_id, _partition, _file, _pos, date, id + + @Test + public void testDefaultPositionDeltaMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Expression[] expectedClustering = + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.column(MetadataColumns.FILE_PATH.name()) + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + checkPositionDeltaDistributionAndOrdering( + table, MERGE, expectedDistribution, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testNonePositionDeltaMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + checkPositionDeltaDistributionAndOrdering( + table, MERGE, UNSPECIFIED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void 
testHashPositionDeltaMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + Expression[] expectedClustering = + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.column(MetadataColumns.FILE_PATH.name()) + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + checkPositionDeltaDistributionAndOrdering( + table, MERGE, expectedDistribution, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testRangePositionDeltaMergeUnpartitionedUnsortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + SortOrder[] expectedDistributionOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING) + }; + Distribution expectedDistribution = Distributions.ordered(expectedDistributionOrdering); + + checkPositionDeltaDistributionAndOrdering( + table, MERGE, expectedDistribution, SPEC_ID_PARTITION_FILE_POSITION_ORDERING); + } + + @Test + public void testDefaultPositionDeltaMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + Expression[] expectedClustering = + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.column(MetadataColumns.FILE_PATH.name()) + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testNonePositionDeltaMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + 
Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering( + table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashPositionDeltaMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + Expression[] expectedClustering = + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.column(MetadataColumns.FILE_PATH.name()) + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testRangePositionDeltaMergeUnpartitionedSortedTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + table.replaceSortOrder().asc("id").asc("data").commit(); + + SortOrder[] expectedDistributionOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), SortDirection.ASCENDING) + }; + Distribution expectedDistribution = Distributions.ordered(expectedDistributionOrdering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("data"), 
SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testDefaultPositionDeltaMergePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + Expression[] expectedClustering = + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.identity("date"), + Expressions.days("ts") + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testNonePositionDeltaMergePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering( + table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testHashPositionDeltaMergePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + Expression[] expectedClustering = + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.identity("date"), + Expressions.days("ts") + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + 
Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testRangePositionDeltaMergePartitionedUnsortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + SortOrder[] expectedDistributionOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + Distribution expectedDistribution = Distributions.ordered(expectedDistributionOrdering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.days("ts"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testNonePositionDeltaMergePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit(); + + table.replaceSortOrder().desc("id").commit(); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.DESCENDING) + }; + + checkPositionDeltaDistributionAndOrdering( + table, MERGE, UNSPECIFIED_DISTRIBUTION, expectedOrdering); + } + + @Test + public void testDefaultPositionDeltaMergePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, 
data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.replaceSortOrder().asc("id").commit(); + + Expression[] expectedClustering = + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.identity("date"), + Expressions.bucket(8, "data") + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testHashPositionDeltaMergePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, bucket(8, data))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit(); + + table.replaceSortOrder().asc("id").commit(); + + Expression[] expectedClustering = + new Expression[] { + Expressions.column(MetadataColumns.SPEC_ID.name()), + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), + Expressions.identity("date"), + Expressions.bucket(8, "data") + }; + Distribution expectedDistribution = Distributions.clustered(expectedClustering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + @Test + public void testRangePositionDeltaMergePartitionedSortedTable() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE).commit(); + + table.replaceSortOrder().asc("id").commit(); + + SortOrder[] expectedDistributionOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + 
Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + Distribution expectedDistribution = Distributions.ordered(expectedDistributionOrdering); + + SortOrder[] expectedOrdering = + new SortOrder[] { + Expressions.sort( + Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING), + Expressions.sort( + Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING), + Expressions.sort(Expressions.column("id"), SortDirection.ASCENDING) + }; + + checkPositionDeltaDistributionAndOrdering(table, MERGE, expectedDistribution, expectedOrdering); + } + + private void checkWriteDistributionAndOrdering( + Table table, Distribution expectedDistribution, SortOrder[] expectedOrdering) { + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + DistributionMode distributionMode = writeConf.distributionMode(); + Distribution distribution = + SparkDistributionAndOrderingUtil.buildRequiredDistribution(table, distributionMode); + Assert.assertEquals("Distribution must match", expectedDistribution, distribution); + + SortOrder[] ordering = + SparkDistributionAndOrderingUtil.buildRequiredOrdering(table, distribution); + Assert.assertArrayEquals("Ordering must match", expectedOrdering, ordering); + } + + private void checkCopyOnWriteDistributionAndOrdering( + Table table, + Command command, + Distribution expectedDistribution, + SortOrder[] expectedOrdering) { + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + + DistributionMode mode = copyOnWriteDistributionMode(command, writeConf); + + Distribution distribution = + SparkDistributionAndOrderingUtil.buildCopyOnWriteDistribution(table, command, mode); + Assert.assertEquals("Distribution must match", expectedDistribution, distribution); + + SortOrder[] ordering = + SparkDistributionAndOrderingUtil.buildCopyOnWriteOrdering(table, command, distribution); + Assert.assertArrayEquals("Ordering must match", expectedOrdering, ordering); + } + + private DistributionMode copyOnWriteDistributionMode(Command command, SparkWriteConf writeConf) { + switch (command) { + case DELETE: + return writeConf.deleteDistributionMode(); + case UPDATE: + return writeConf.updateDistributionMode(); + case MERGE: + return writeConf.copyOnWriteMergeDistributionMode(); + default: + throw new IllegalArgumentException("Unexpected command: " + command); + } + } + + private void checkPositionDeltaDistributionAndOrdering( + Table table, + Command command, + Distribution expectedDistribution, + SortOrder[] expectedOrdering) { + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + + DistributionMode mode = positionDeltaDistributionMode(command, writeConf); + + Distribution distribution = + SparkDistributionAndOrderingUtil.buildPositionDeltaDistribution(table, command, mode); + Assert.assertEquals("Distribution must match", expectedDistribution, distribution); + + SortOrder[] ordering = + SparkDistributionAndOrderingUtil.buildPositionDeltaOrdering(table, command); + Assert.assertArrayEquals("Ordering must match", expectedOrdering, 
ordering); + } + + private DistributionMode positionDeltaDistributionMode( + Command command, SparkWriteConf writeConf) { + switch (command) { + case DELETE: + return writeConf.deleteDistributionMode(); + case UPDATE: + return writeConf.updateDistributionMode(); + case MERGE: + return writeConf.positionDeltaMergeDistributionMode(); + default: + throw new IllegalArgumentException("Unexpected command: " + command); + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkFilters.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkFilters.java new file mode 100644 index 000000000000..2e56b6aa91b0 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkFilters.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.sql.Date; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.time.temporal.ChronoUnit; +import java.util.Map; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.spark.sql.sources.And; +import org.apache.spark.sql.sources.EqualNullSafe; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.GreaterThanOrEqual; +import org.apache.spark.sql.sources.In; +import org.apache.spark.sql.sources.IsNotNull; +import org.apache.spark.sql.sources.IsNull; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.LessThanOrEqual; +import org.apache.spark.sql.sources.Not; +import org.junit.Assert; +import org.junit.Test; + +public class TestSparkFilters { + + @Test + public void testQuotedAttributes() { + Map<String, String> attrMap = Maps.newHashMap(); + attrMap.put("id", "id"); + attrMap.put("`i.d`", "i.d"); + attrMap.put("`i``d`", "i`d"); + attrMap.put("`d`.b.`dd```", "d.b.dd`"); + attrMap.put("a.`aa```.c", "a.aa`.c"); + + attrMap.forEach( + (quoted, unquoted) -> { + IsNull isNull = IsNull.apply(quoted); + Expression expectedIsNull = Expressions.isNull(unquoted); + Expression actualIsNull = SparkFilters.convert(isNull); + Assert.assertEquals( + "IsNull must match", expectedIsNull.toString(), actualIsNull.toString()); + + IsNotNull isNotNull = IsNotNull.apply(quoted); + Expression expectedIsNotNull = Expressions.notNull(unquoted); + Expression actualIsNotNull = SparkFilters.convert(isNotNull); + Assert.assertEquals( + "IsNotNull must match", expectedIsNotNull.toString(), actualIsNotNull.toString()); + + LessThan lt = LessThan.apply(quoted, 1); + Expression expectedLt = Expressions.lessThan(unquoted, 1); + Expression actualLt = 
SparkFilters.convert(lt); + Assert.assertEquals("LessThan must match", expectedLt.toString(), actualLt.toString()); + + LessThanOrEqual ltEq = LessThanOrEqual.apply(quoted, 1); + Expression expectedLtEq = Expressions.lessThanOrEqual(unquoted, 1); + Expression actualLtEq = SparkFilters.convert(ltEq); + Assert.assertEquals( + "LessThanOrEqual must match", expectedLtEq.toString(), actualLtEq.toString()); + + GreaterThan gt = GreaterThan.apply(quoted, 1); + Expression expectedGt = Expressions.greaterThan(unquoted, 1); + Expression actualGt = SparkFilters.convert(gt); + Assert.assertEquals("GreaterThan must match", expectedGt.toString(), actualGt.toString()); + + GreaterThanOrEqual gtEq = GreaterThanOrEqual.apply(quoted, 1); + Expression expectedGtEq = Expressions.greaterThanOrEqual(unquoted, 1); + Expression actualGtEq = SparkFilters.convert(gtEq); + Assert.assertEquals( + "GreaterThanOrEqual must match", expectedGtEq.toString(), actualGtEq.toString()); + + EqualTo eq = EqualTo.apply(quoted, 1); + Expression expectedEq = Expressions.equal(unquoted, 1); + Expression actualEq = SparkFilters.convert(eq); + Assert.assertEquals("EqualTo must match", expectedEq.toString(), actualEq.toString()); + + EqualNullSafe eqNullSafe = EqualNullSafe.apply(quoted, 1); + Expression expectedEqNullSafe = Expressions.equal(unquoted, 1); + Expression actualEqNullSafe = SparkFilters.convert(eqNullSafe); + Assert.assertEquals( + "EqualNullSafe must match", + expectedEqNullSafe.toString(), + actualEqNullSafe.toString()); + + In in = In.apply(quoted, new Integer[] {1}); + Expression expectedIn = Expressions.in(unquoted, 1); + Expression actualIn = SparkFilters.convert(in); + Assert.assertEquals("In must match", expectedIn.toString(), actualIn.toString()); + }); + } + + @Test + public void testTimestampFilterConversion() { + Instant instant = Instant.parse("2018-10-18T00:00:57.907Z"); + Timestamp timestamp = Timestamp.from(instant); + long epochMicros = ChronoUnit.MICROS.between(Instant.EPOCH, instant); + + Expression instantExpression = SparkFilters.convert(GreaterThan.apply("x", instant)); + Expression timestampExpression = SparkFilters.convert(GreaterThan.apply("x", timestamp)); + Expression rawExpression = Expressions.greaterThan("x", epochMicros); + + Assert.assertEquals( + "Generated Timestamp expression should be correct", + rawExpression.toString(), + timestampExpression.toString()); + Assert.assertEquals( + "Generated Instant expression should be correct", + rawExpression.toString(), + instantExpression.toString()); + } + + @Test + public void testDateFilterConversion() { + LocalDate localDate = LocalDate.parse("2018-10-18"); + Date date = Date.valueOf(localDate); + long epochDay = localDate.toEpochDay(); + + Expression localDateExpression = SparkFilters.convert(GreaterThan.apply("x", localDate)); + Expression dateExpression = SparkFilters.convert(GreaterThan.apply("x", date)); + Expression rawExpression = Expressions.greaterThan("x", epochDay); + + Assert.assertEquals( + "Generated localdate expression should be correct", + rawExpression.toString(), + localDateExpression.toString()); + + Assert.assertEquals( + "Generated date expression should be correct", + rawExpression.toString(), + dateExpression.toString()); + } + + @Test + public void testNestedInInsideNot() { + Not filter = + Not.apply(And.apply(EqualTo.apply("col1", 1), In.apply("col2", new Integer[] {1, 2}))); + Expression converted = SparkFilters.convert(filter); + Assert.assertNull("Expression should not be converted", converted); + } + + @Test + 
public void testNotIn() { + Not filter = Not.apply(In.apply("col", new Integer[] {1, 2})); + Expression actual = SparkFilters.convert(filter); + Expression expected = + Expressions.and(Expressions.notNull("col"), Expressions.notIn("col", 1, 2)); + Assert.assertEquals("Expressions should match", expected.toString(), actual.toString()); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkSchemaUtil.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkSchemaUtil.java new file mode 100644 index 000000000000..259f7c3dd789 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkSchemaUtil.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.apache.iceberg.types.Types.NestedField.optional; + +import java.io.IOException; +import java.util.List; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.expressions.AttributeReference; +import org.apache.spark.sql.catalyst.expressions.MetadataAttribute; +import org.apache.spark.sql.types.StructType; +import org.junit.Assert; +import org.junit.Test; + +public class TestSparkSchemaUtil { + private static final Schema TEST_SCHEMA = + new Schema( + optional(1, "id", Types.IntegerType.get()), optional(2, "data", Types.StringType.get())); + + private static final Schema TEST_SCHEMA_WITH_METADATA_COLS = + new Schema( + optional(1, "id", Types.IntegerType.get()), + optional(2, "data", Types.StringType.get()), + MetadataColumns.FILE_PATH, + MetadataColumns.ROW_POSITION); + + @Test + public void testEstimateSizeMaxValue() throws IOException { + Assert.assertEquals( + "estimateSize returns Long max value", + Long.MAX_VALUE, + SparkSchemaUtil.estimateSize(null, Long.MAX_VALUE)); + } + + @Test + public void testEstimateSizeWithOverflow() throws IOException { + long tableSize = + SparkSchemaUtil.estimateSize(SparkSchemaUtil.convert(TEST_SCHEMA), Long.MAX_VALUE - 1); + Assert.assertEquals("estimateSize handles overflow", Long.MAX_VALUE, tableSize); + } + + @Test + public void testEstimateSize() throws IOException { + long tableSize = SparkSchemaUtil.estimateSize(SparkSchemaUtil.convert(TEST_SCHEMA), 1); + Assert.assertEquals("estimateSize matches with expected approximation", 24, tableSize); + } + + @Test + public void testSchemaConversionWithMetaDataColumnSchema() { + StructType structType = SparkSchemaUtil.convert(TEST_SCHEMA_WITH_METADATA_COLS); + List<AttributeReference> attrRefs = + scala.collection.JavaConverters.seqAsJavaList(structType.toAttributes()); + for (AttributeReference attrRef : attrRefs) { + if (MetadataColumns.isMetadataColumn(attrRef.name())) { + 
Assert.assertTrue( + "metadata columns should have __metadata_col in attribute metadata", + MetadataAttribute.unapply(attrRef).isDefined()); + } else { + Assert.assertFalse( + "non metadata columns should not have __metadata_col in attribute metadata", + MetadataAttribute.unapply(attrRef).isDefined()); + } + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkSessionCatalog.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkSessionCatalog.java new file mode 100644 index 000000000000..82a2fb473360 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkSessionCatalog.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestSparkSessionCatalog extends SparkTestBase { + private final String envHmsUriKey = "spark.hadoop." 
+ METASTOREURIS.varname; + private final String catalogHmsUriKey = "spark.sql.catalog.spark_catalog.uri"; + private final String hmsUri = hiveConf.get(METASTOREURIS.varname); + + @BeforeClass + public static void setUpCatalog() { + spark + .conf() + .set("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog"); + spark.conf().set("spark.sql.catalog.spark_catalog.type", "hive"); + } + + @Before + public void setupHmsUri() { + spark.sessionState().catalogManager().reset(); + spark.conf().set(envHmsUriKey, hmsUri); + spark.conf().set(catalogHmsUriKey, hmsUri); + } + + @Test + public void testValidateHmsUri() { + // HMS uris match + Assert.assertTrue( + spark + .sessionState() + .catalogManager() + .v2SessionCatalog() + .defaultNamespace()[0] + .equals("default")); + + // HMS uris doesn't match + spark.sessionState().catalogManager().reset(); + String catalogHmsUri = "RandomString"; + spark.conf().set(envHmsUriKey, hmsUri); + spark.conf().set(catalogHmsUriKey, catalogHmsUri); + IllegalArgumentException exception = + Assert.assertThrows( + IllegalArgumentException.class, + () -> spark.sessionState().catalogManager().v2SessionCatalog()); + String errorMessage = + String.format( + "Inconsistent Hive metastore URIs: %s (Spark session) != %s (spark_catalog)", + hmsUri, catalogHmsUri); + Assert.assertEquals(errorMessage, exception.getMessage()); + + // no env HMS uri, only catalog HMS uri + spark.sessionState().catalogManager().reset(); + spark.conf().set(catalogHmsUriKey, hmsUri); + spark.conf().unset(envHmsUriKey); + Assert.assertTrue( + spark + .sessionState() + .catalogManager() + .v2SessionCatalog() + .defaultNamespace()[0] + .equals("default")); + + // no catalog HMS uri, only env HMS uri + spark.sessionState().catalogManager().reset(); + spark.conf().set(envHmsUriKey, hmsUri); + spark.conf().unset(catalogHmsUriKey); + Assert.assertTrue( + spark + .sessionState() + .catalogManager() + .v2SessionCatalog() + .defaultNamespace()[0] + .equals("default")); + } + + @Test + public void testLoadFunction() { + String functionClass = "org.apache.hadoop.hive.ql.udf.generic.GenericUDFUpper"; + + // load permanent UDF in Hive via FunctionCatalog + spark.sql(String.format("CREATE FUNCTION perm_upper AS '%s'", functionClass)); + Assert.assertEquals("Load permanent UDF in Hive", "XYZ", scalarSql("SELECT perm_upper('xyz')")); + + // load temporary UDF in Hive via FunctionCatalog + spark.sql(String.format("CREATE TEMPORARY FUNCTION temp_upper AS '%s'", functionClass)); + Assert.assertEquals("Load temporary UDF in Hive", "XYZ", scalarSql("SELECT temp_upper('xyz')")); + + // TODO: fix loading Iceberg built-in functions in SessionCatalog + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkTableUtil.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkTableUtil.java new file mode 100644 index 000000000000..1e51caadd0e9 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkTableUtil.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.io.IOException; +import java.util.Map; +import org.apache.iceberg.KryoHelpers; +import org.apache.iceberg.MetricsConfig; +import org.apache.iceberg.MetricsModes; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkTableUtil.SparkPartition; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import org.junit.Test; + +public class TestSparkTableUtil { + @Test + public void testSparkPartitionOKryoSerialization() throws IOException { + Map values = ImmutableMap.of("id", "2"); + String uri = "s3://bucket/table/data/id=2"; + String format = "parquet"; + SparkPartition sparkPartition = new SparkPartition(values, uri, format); + + SparkPartition deserialized = KryoHelpers.roundTripSerialize(sparkPartition); + Assertions.assertThat(sparkPartition).isEqualTo(deserialized); + } + + @Test + public void testSparkPartitionJavaSerialization() throws IOException, ClassNotFoundException { + Map values = ImmutableMap.of("id", "2"); + String uri = "s3://bucket/table/data/id=2"; + String format = "parquet"; + SparkPartition sparkPartition = new SparkPartition(values, uri, format); + + SparkPartition deserialized = TestHelpers.roundTripSerialize(sparkPartition); + Assertions.assertThat(sparkPartition).isEqualTo(deserialized); + } + + @Test + public void testMetricsConfigKryoSerialization() throws Exception { + Map metricsConfig = + ImmutableMap.of( + TableProperties.DEFAULT_WRITE_METRICS_MODE, + "counts", + TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX + "col1", + "full", + TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX + "col2", + "truncate(16)"); + + MetricsConfig config = MetricsConfig.fromProperties(metricsConfig); + MetricsConfig deserialized = KryoHelpers.roundTripSerialize(config); + + Assert.assertEquals( + MetricsModes.Full.get().toString(), deserialized.columnMode("col1").toString()); + Assert.assertEquals( + MetricsModes.Truncate.withLength(16).toString(), + deserialized.columnMode("col2").toString()); + Assert.assertEquals( + MetricsModes.Counts.get().toString(), deserialized.columnMode("col3").toString()); + } + + @Test + public void testMetricsConfigJavaSerialization() throws Exception { + Map metricsConfig = + ImmutableMap.of( + TableProperties.DEFAULT_WRITE_METRICS_MODE, + "counts", + TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX + "col1", + "full", + TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX + "col2", + "truncate(16)"); + + MetricsConfig config = MetricsConfig.fromProperties(metricsConfig); + MetricsConfig deserialized = TestHelpers.roundTripSerialize(config); + + Assert.assertEquals( + MetricsModes.Full.get().toString(), deserialized.columnMode("col1").toString()); + Assert.assertEquals( + MetricsModes.Truncate.withLength(16).toString(), + deserialized.columnMode("col2").toString()); + Assert.assertEquals( + MetricsModes.Counts.get().toString(), deserialized.columnMode("col3").toString()); + } +} diff --git 
a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java new file mode 100644 index 000000000000..4c8a32fa41a4 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark; + +import java.time.Instant; +import java.time.LocalDate; +import java.time.temporal.ChronoUnit; +import java.util.Map; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.spark.sql.connector.expressions.FieldReference; +import org.apache.spark.sql.connector.expressions.LiteralValue; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.filter.And; +import org.apache.spark.sql.connector.expressions.filter.Not; +import org.apache.spark.sql.connector.expressions.filter.Or; +import org.apache.spark.sql.connector.expressions.filter.Predicate; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.unsafe.types.UTF8String; +import org.junit.Assert; +import org.junit.Test; + +public class TestSparkV2Filters { + + @Test + public void testV2Filters() { + Map<String, String> attrMap = Maps.newHashMap(); + attrMap.put("id", "id"); + attrMap.put("`i.d`", "i.d"); + attrMap.put("`i``d`", "i`d"); + attrMap.put("`d`.b.`dd```", "d.b.dd`"); + attrMap.put("a.`aa```.c", "a.aa`.c"); + + attrMap.forEach( + (quoted, unquoted) -> { + NamedReference namedReference = FieldReference.apply(quoted); + org.apache.spark.sql.connector.expressions.Expression[] attrOnly = + new org.apache.spark.sql.connector.expressions.Expression[] {namedReference}; + + LiteralValue value = new LiteralValue(1, DataTypes.IntegerType); + org.apache.spark.sql.connector.expressions.Expression[] attrAndValue = + new org.apache.spark.sql.connector.expressions.Expression[] {namedReference, value}; + org.apache.spark.sql.connector.expressions.Expression[] valueAndAttr = + new org.apache.spark.sql.connector.expressions.Expression[] {value, namedReference}; + + Predicate isNull = new Predicate("IS_NULL", attrOnly); + Expression expectedIsNull = Expressions.isNull(unquoted); + Expression actualIsNull = SparkV2Filters.convert(isNull); + Assert.assertEquals( + "IsNull must match", expectedIsNull.toString(), actualIsNull.toString()); + + Predicate isNotNull = new Predicate("IS_NOT_NULL", attrOnly); + Expression expectedIsNotNull = Expressions.notNull(unquoted); + Expression actualIsNotNull = SparkV2Filters.convert(isNotNull); + Assert.assertEquals( + "IsNotNull must match", 
expectedIsNotNull.toString(), actualIsNotNull.toString()); + + Predicate lt1 = new Predicate("<", attrAndValue); + Expression expectedLt1 = Expressions.lessThan(unquoted, 1); + Expression actualLt1 = SparkV2Filters.convert(lt1); + Assert.assertEquals("LessThan must match", expectedLt1.toString(), actualLt1.toString()); + + Predicate lt2 = new Predicate("<", valueAndAttr); + Expression expectedLt2 = Expressions.greaterThan(unquoted, 1); + Expression actualLt2 = SparkV2Filters.convert(lt2); + Assert.assertEquals("LessThan must match", expectedLt2.toString(), actualLt2.toString()); + + Predicate ltEq1 = new Predicate("<=", attrAndValue); + Expression expectedLtEq1 = Expressions.lessThanOrEqual(unquoted, 1); + Expression actualLtEq1 = SparkV2Filters.convert(ltEq1); + Assert.assertEquals( + "LessThanOrEqual must match", expectedLtEq1.toString(), actualLtEq1.toString()); + + Predicate ltEq2 = new Predicate("<=", valueAndAttr); + Expression expectedLtEq2 = Expressions.greaterThanOrEqual(unquoted, 1); + Expression actualLtEq2 = SparkV2Filters.convert(ltEq2); + Assert.assertEquals( + "LessThanOrEqual must match", expectedLtEq2.toString(), actualLtEq2.toString()); + + Predicate gt1 = new Predicate(">", attrAndValue); + Expression expectedGt1 = Expressions.greaterThan(unquoted, 1); + Expression actualGt1 = SparkV2Filters.convert(gt1); + Assert.assertEquals( + "GreaterThan must match", expectedGt1.toString(), actualGt1.toString()); + + Predicate gt2 = new Predicate(">", valueAndAttr); + Expression expectedGt2 = Expressions.lessThan(unquoted, 1); + Expression actualGt2 = SparkV2Filters.convert(gt2); + Assert.assertEquals( + "GreaterThan must match", expectedGt2.toString(), actualGt2.toString()); + + Predicate gtEq1 = new Predicate(">=", attrAndValue); + Expression expectedGtEq1 = Expressions.greaterThanOrEqual(unquoted, 1); + Expression actualGtEq1 = SparkV2Filters.convert(gtEq1); + Assert.assertEquals( + "GreaterThanOrEqual must match", expectedGtEq1.toString(), actualGtEq1.toString()); + + Predicate gtEq2 = new Predicate(">=", valueAndAttr); + Expression expectedGtEq2 = Expressions.lessThanOrEqual(unquoted, 1); + Expression actualGtEq2 = SparkV2Filters.convert(gtEq2); + Assert.assertEquals( + "GreaterThanOrEqual must match", expectedGtEq2.toString(), actualGtEq2.toString()); + + Predicate eq1 = new Predicate("=", attrAndValue); + Expression expectedEq1 = Expressions.equal(unquoted, 1); + Expression actualEq1 = SparkV2Filters.convert(eq1); + Assert.assertEquals("EqualTo must match", expectedEq1.toString(), actualEq1.toString()); + + Predicate eq2 = new Predicate("=", valueAndAttr); + Expression expectedEq2 = Expressions.equal(unquoted, 1); + Expression actualEq2 = SparkV2Filters.convert(eq2); + Assert.assertEquals("EqualTo must match", expectedEq2.toString(), actualEq2.toString()); + + Predicate eqNullSafe1 = new Predicate("<=>", attrAndValue); + Expression expectedEqNullSafe1 = Expressions.equal(unquoted, 1); + Expression actualEqNullSafe1 = SparkV2Filters.convert(eqNullSafe1); + Assert.assertEquals( + "EqualNullSafe must match", + expectedEqNullSafe1.toString(), + actualEqNullSafe1.toString()); + + Predicate eqNullSafe2 = new Predicate("<=>", valueAndAttr); + Expression expectedEqNullSafe2 = Expressions.equal(unquoted, 1); + Expression actualEqNullSafe2 = SparkV2Filters.convert(eqNullSafe2); + Assert.assertEquals( + "EqualNullSafe must match", + expectedEqNullSafe2.toString(), + actualEqNullSafe2.toString()); + + LiteralValue str = + new LiteralValue(UTF8String.fromString("iceberg"), 
DataTypes.StringType); + org.apache.spark.sql.connector.expressions.Expression[] attrAndStr = + new org.apache.spark.sql.connector.expressions.Expression[] {namedReference, str}; + Predicate startsWith = new Predicate("STARTS_WITH", attrAndStr); + Expression expectedStartsWith = Expressions.startsWith(unquoted, "iceberg"); + Expression actualStartsWith = SparkV2Filters.convert(startsWith); + Assert.assertEquals( + "StartsWith must match", expectedStartsWith.toString(), actualStartsWith.toString()); + + Predicate in = new Predicate("IN", attrAndValue); + Expression expectedIn = Expressions.in(unquoted, 1); + Expression actualIn = SparkV2Filters.convert(in); + Assert.assertEquals("In must match", expectedIn.toString(), actualIn.toString()); + + Predicate and = new And(lt1, eq1); + Expression expectedAnd = Expressions.and(expectedLt1, expectedEq1); + Expression actualAnd = SparkV2Filters.convert(and); + Assert.assertEquals("And must match", expectedAnd.toString(), actualAnd.toString()); + + org.apache.spark.sql.connector.expressions.Expression[] attrAndAttr = + new org.apache.spark.sql.connector.expressions.Expression[] { + namedReference, namedReference + }; + Predicate invalid = new Predicate("<", attrAndAttr); + Predicate andWithInvalidLeft = new And(invalid, eq1); + Expression convertedAnd = SparkV2Filters.convert(andWithInvalidLeft); + Assert.assertEquals("And must match", convertedAnd, null); + + Predicate or = new Or(lt1, eq1); + Expression expectedOr = Expressions.or(expectedLt1, expectedEq1); + Expression actualOr = SparkV2Filters.convert(or); + Assert.assertEquals("Or must match", expectedOr.toString(), actualOr.toString()); + + Predicate orWithInvalidLeft = new Or(invalid, eq1); + Expression convertedOr = SparkV2Filters.convert(orWithInvalidLeft); + Assert.assertEquals("Or must match", convertedOr, null); + + Predicate not = new Not(lt1); + Expression expectedNot = Expressions.not(expectedLt1); + Expression actualNot = SparkV2Filters.convert(not); + Assert.assertEquals("Not must match", expectedNot.toString(), actualNot.toString()); + }); + } + + @Test + public void testTimestampFilterConversion() { + Instant instant = Instant.parse("2018-10-18T00:00:57.907Z"); + long epochMicros = ChronoUnit.MICROS.between(Instant.EPOCH, instant); + + NamedReference namedReference = FieldReference.apply("x"); + LiteralValue ts = new LiteralValue(epochMicros, DataTypes.TimestampType); + org.apache.spark.sql.connector.expressions.Expression[] attrAndValue = + new org.apache.spark.sql.connector.expressions.Expression[] {namedReference, ts}; + + Predicate predicate = new Predicate(">", attrAndValue); + Expression tsExpression = SparkV2Filters.convert(predicate); + Expression rawExpression = Expressions.greaterThan("x", epochMicros); + + Assert.assertEquals( + "Generated Timestamp expression should be correct", + rawExpression.toString(), + tsExpression.toString()); + } + + @Test + public void testDateFilterConversion() { + LocalDate localDate = LocalDate.parse("2018-10-18"); + long epochDay = localDate.toEpochDay(); + + NamedReference namedReference = FieldReference.apply("x"); + LiteralValue ts = new LiteralValue(epochDay, DataTypes.DateType); + org.apache.spark.sql.connector.expressions.Expression[] attrAndValue = + new org.apache.spark.sql.connector.expressions.Expression[] {namedReference, ts}; + + Predicate predicate = new Predicate(">", attrAndValue); + Expression dateExpression = SparkV2Filters.convert(predicate); + Expression rawExpression = Expressions.greaterThan("x", epochDay); + + 
Assert.assertEquals( + "Generated date expression should be correct", + rawExpression.toString(), + dateExpression.toString()); + } + + @Test + public void testNestedInInsideNot() { + NamedReference namedReference1 = FieldReference.apply("col1"); + LiteralValue v1 = new LiteralValue(1, DataTypes.IntegerType); + LiteralValue v2 = new LiteralValue(2, DataTypes.IntegerType); + org.apache.spark.sql.connector.expressions.Expression[] attrAndValue1 = + new org.apache.spark.sql.connector.expressions.Expression[] {namedReference1, v1}; + Predicate equal = new Predicate("=", attrAndValue1); + + NamedReference namedReference2 = FieldReference.apply("col2"); + org.apache.spark.sql.connector.expressions.Expression[] attrAndValue2 = + new org.apache.spark.sql.connector.expressions.Expression[] {namedReference2, v1, v2}; + Predicate in = new Predicate("IN", attrAndValue2); + + Not filter = new Not(new And(equal, in)); + Expression converted = SparkV2Filters.convert(filter); + Assert.assertNull("Expression should not be converted", converted); + } + + @Test + public void testNotIn() { + NamedReference namedReference = FieldReference.apply("col"); + LiteralValue v1 = new LiteralValue(1, DataTypes.IntegerType); + LiteralValue v2 = new LiteralValue(2, DataTypes.IntegerType); + org.apache.spark.sql.connector.expressions.Expression[] attrAndValue = + new org.apache.spark.sql.connector.expressions.Expression[] {namedReference, v1, v2}; + + Predicate in = new Predicate("IN", attrAndValue); + Not not = new Not(in); + + Expression actual = SparkV2Filters.convert(not); + Expression expected = + Expressions.and(Expressions.notNull("col"), Expressions.notIn("col", 1, 2)); + Assert.assertEquals("Expressions should match", expected.toString(), actual.toString()); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkValueConverter.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkValueConverter.java new file mode 100644 index 000000000000..7f00c7edd8a9 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkValueConverter.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.junit.Assert; +import org.junit.Test; + +public class TestSparkValueConverter { + @Test + public void testSparkNullMapConvert() { + Schema schema = + new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional( + 5, + "locations", + Types.MapType.ofOptional( + 6, + 7, + Types.StringType.get(), + Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()), + Types.NestedField.required(2, "long", Types.FloatType.get()))))); + + assertCorrectNullConversion(schema); + } + + @Test + public void testSparkNullListConvert() { + Schema schema = + new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional( + 5, "locations", Types.ListType.ofOptional(6, Types.StringType.get()))); + + assertCorrectNullConversion(schema); + } + + @Test + public void testSparkNullStructConvert() { + Schema schema = + new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional( + 5, + "location", + Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()), + Types.NestedField.required(2, "long", Types.FloatType.get())))); + + assertCorrectNullConversion(schema); + } + + @Test + public void testSparkNullPrimitiveConvert() { + Schema schema = + new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(5, "location", Types.StringType.get())); + assertCorrectNullConversion(schema); + } + + private void assertCorrectNullConversion(Schema schema) { + Row sparkRow = RowFactory.create(1, null); + Record record = GenericRecord.create(schema); + record.set(0, 1); + Assert.assertEquals( + "Round-trip conversion should produce original value", + record, + SparkValueConverter.convert(schema, sparkRow)); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java new file mode 100644 index 000000000000..d5ebc2aeab67 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import static org.apache.iceberg.TableProperties.DELETE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.MERGE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.UPDATE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_HASH; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_NONE; +import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_RANGE; + +import java.util.Map; +import org.apache.iceberg.DistributionMode; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestSparkWriteConf extends SparkTestBaseWithCatalog { + + @Before + public void before() { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date, days(ts))", + tableName); + } + + @After + public void after() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testSparkWriteConfDistributionDefault() { + Table table = validationCatalog.loadTable(tableIdent); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + + Assert.assertEquals(DistributionMode.HASH, writeConf.distributionMode()); + Assert.assertEquals(DistributionMode.HASH, writeConf.deleteDistributionMode()); + Assert.assertEquals(DistributionMode.HASH, writeConf.updateDistributionMode()); + Assert.assertEquals(DistributionMode.HASH, writeConf.copyOnWriteMergeDistributionMode()); + Assert.assertEquals(DistributionMode.HASH, writeConf.positionDeltaMergeDistributionMode()); + } + + @Test + public void testSparkWriteConfDistributionModeWithWriteOption() { + Table table = validationCatalog.loadTable(tableIdent); + + Map writeOptions = + ImmutableMap.of(SparkWriteOptions.DISTRIBUTION_MODE, DistributionMode.NONE.modeName()); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table, writeOptions); + Assert.assertEquals(DistributionMode.NONE, writeConf.distributionMode()); + Assert.assertEquals(DistributionMode.NONE, writeConf.deleteDistributionMode()); + Assert.assertEquals(DistributionMode.NONE, writeConf.updateDistributionMode()); + Assert.assertEquals(DistributionMode.NONE, writeConf.copyOnWriteMergeDistributionMode()); + Assert.assertEquals(DistributionMode.NONE, writeConf.positionDeltaMergeDistributionMode()); + } + + @Test + public void testSparkWriteConfDistributionModeWithSessionConfig() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.DISTRIBUTION_MODE, DistributionMode.NONE.modeName()), + () -> { + Table table = validationCatalog.loadTable(tableIdent); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + + Assert.assertEquals(DistributionMode.NONE, writeConf.distributionMode()); + Assert.assertEquals(DistributionMode.NONE, writeConf.deleteDistributionMode()); + Assert.assertEquals(DistributionMode.NONE, writeConf.updateDistributionMode()); + Assert.assertEquals(DistributionMode.NONE, writeConf.copyOnWriteMergeDistributionMode()); + Assert.assertEquals( + DistributionMode.NONE, writeConf.positionDeltaMergeDistributionMode()); + }); + } + + @Test + public void testSparkWriteConfDistributionModeWithTableProperties() { + Table table = validationCatalog.loadTable(tableIdent); + + table + 
.updateProperties() + .set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE) + .commit(); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + Assert.assertEquals(DistributionMode.NONE, writeConf.distributionMode()); + Assert.assertEquals(DistributionMode.NONE, writeConf.deleteDistributionMode()); + Assert.assertEquals(DistributionMode.NONE, writeConf.updateDistributionMode()); + Assert.assertEquals(DistributionMode.NONE, writeConf.copyOnWriteMergeDistributionMode()); + Assert.assertEquals(DistributionMode.NONE, writeConf.positionDeltaMergeDistributionMode()); + } + + @Test + public void testSparkWriteConfDistributionModeWithTblPropAndSessionConfig() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.DISTRIBUTION_MODE, DistributionMode.NONE.modeName()), + () -> { + Table table = validationCatalog.loadTable(tableIdent); + + table + .updateProperties() + .set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + // session config overwrite the table properties + Assert.assertEquals(DistributionMode.NONE, writeConf.distributionMode()); + Assert.assertEquals(DistributionMode.NONE, writeConf.deleteDistributionMode()); + Assert.assertEquals(DistributionMode.NONE, writeConf.updateDistributionMode()); + Assert.assertEquals(DistributionMode.NONE, writeConf.copyOnWriteMergeDistributionMode()); + Assert.assertEquals( + DistributionMode.NONE, writeConf.positionDeltaMergeDistributionMode()); + }); + } + + @Test + public void testSparkWriteConfDistributionModeWithWriteOptionAndSessionConfig() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.DISTRIBUTION_MODE, DistributionMode.RANGE.modeName()), + () -> { + Table table = validationCatalog.loadTable(tableIdent); + + Map writeOptions = + ImmutableMap.of( + SparkWriteOptions.DISTRIBUTION_MODE, DistributionMode.NONE.modeName()); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table, writeOptions); + // write options overwrite the session config + Assert.assertEquals(DistributionMode.NONE, writeConf.distributionMode()); + Assert.assertEquals(DistributionMode.NONE, writeConf.deleteDistributionMode()); + Assert.assertEquals(DistributionMode.NONE, writeConf.updateDistributionMode()); + Assert.assertEquals(DistributionMode.NONE, writeConf.copyOnWriteMergeDistributionMode()); + Assert.assertEquals( + DistributionMode.NONE, writeConf.positionDeltaMergeDistributionMode()); + }); + } + + @Test + public void testSparkWriteConfDistributionModeWithEverything() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.DISTRIBUTION_MODE, DistributionMode.RANGE.modeName()), + () -> { + Table table = validationCatalog.loadTable(tableIdent); + + Map writeOptions = + ImmutableMap.of( + SparkWriteOptions.DISTRIBUTION_MODE, DistributionMode.NONE.modeName()); + + table + .updateProperties() + .set(WRITE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .set(DELETE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .set(MERGE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + + 
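+          // At this point the three sources disagree: the table properties say hash, the session
+          // config says range, and the write option says none. The assertions below confirm the
+          // expected precedence: write option > session config > table property > table default
+          // (e.g. from user code: df.writeTo(tableName).option("distribution-mode", "none").append()).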
SparkWriteConf writeConf = new SparkWriteConf(spark, table, writeOptions); + // write options take the highest priority + Assert.assertEquals(DistributionMode.NONE, writeConf.distributionMode()); + Assert.assertEquals(DistributionMode.NONE, writeConf.deleteDistributionMode()); + Assert.assertEquals(DistributionMode.NONE, writeConf.updateDistributionMode()); + Assert.assertEquals(DistributionMode.NONE, writeConf.copyOnWriteMergeDistributionMode()); + Assert.assertEquals( + DistributionMode.NONE, writeConf.positionDeltaMergeDistributionMode()); + }); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestCreateActions.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestCreateActions.java new file mode 100644 index 000000000000..96950e8227f3 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestCreateActions.java @@ -0,0 +1,1038 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.io.File; +import java.io.FilenameFilter; +import java.io.IOException; +import java.net.URI; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.commons.io.FileUtils; +import org.apache.commons.io.filefilter.TrueFileFilter; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.MigrateTable; +import org.apache.iceberg.actions.SnapshotTable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.iceberg.types.Types; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.util.HadoopInputFile; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.MessageTypeParser; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException; +import 
org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.catalog.CatalogTable; +import org.apache.spark.sql.catalyst.catalog.CatalogTablePartition; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runners.Parameterized; +import scala.Option; +import scala.Some; +import scala.collection.JavaConverters; +import scala.collection.Seq; + +public class TestCreateActions extends SparkCatalogTestBase { + private static final String CREATE_PARTITIONED_PARQUET = + "CREATE TABLE %s (id INT, data STRING) " + "using parquet PARTITIONED BY (id) LOCATION '%s'"; + private static final String CREATE_PARQUET = + "CREATE TABLE %s (id INT, data STRING) " + "using parquet LOCATION '%s'"; + private static final String CREATE_HIVE_EXTERNAL_PARQUET = + "CREATE EXTERNAL TABLE %s (data STRING) " + + "PARTITIONED BY (id INT) STORED AS parquet LOCATION '%s'"; + private static final String CREATE_HIVE_PARQUET = + "CREATE TABLE %s (data STRING) " + "PARTITIONED BY (id INT) STORED AS parquet"; + + private static final String NAMESPACE = "default"; + + @Parameterized.Parameters(name = "Catalog Name {0} - Options {2}") + public static Object[][] parameters() { + return new Object[][] { + new Object[] { + "spark_catalog", + SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "parquet-enabled", "true", + "cache-enabled", + "false" // Spark will delete tables using v1, leaving the cache out of sync + ) + }, + new Object[] { + "spark_catalog", + SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", "hadoop", + "default-namespace", "default", + "parquet-enabled", "true", + "cache-enabled", + "false" // Spark will delete tables using v1, leaving the cache out of sync + ) + }, + new Object[] { + "testhive", + SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default") + }, + new Object[] { + "testhadoop", + SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hadoop", + "default-namespace", "default") + } + }; + } + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private String baseTableName = "baseTable"; + private File tableDir; + private String tableLocation; + private final String type; + private final TableCatalog catalog; + + public TestCreateActions(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + this.catalog = (TableCatalog) spark.sessionState().catalogManager().catalog(catalogName); + this.type = config.get("type"); + } + + @Before + public void before() { + try { + this.tableDir = temp.newFolder(); + } catch (IOException e) { + throw new RuntimeException(e); + } + this.tableLocation = tableDir.toURI().toString(); + + spark.conf().set("hive.exec.dynamic.partition", "true"); + spark.conf().set("hive.exec.dynamic.partition.mode", "nonstrict"); + spark.conf().set("spark.sql.parquet.writeLegacyFormat", false); + 
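+    // writeLegacyFormat=false makes Spark write standard three-level Parquet lists; individual
+    // tests flip this flag to true when they need to exercise the legacy two-level layout.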
spark.conf().set("spark.sql.parquet.writeLegacyFormat", false); + spark.sql(String.format("DROP TABLE IF EXISTS %s", baseTableName)); + + List expected = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + df.select("id", "data") + .orderBy("data") + .write() + .mode("append") + .option("path", tableLocation) + .saveAsTable(baseTableName); + } + + @After + public void after() throws IOException { + // Drop the hive table. + spark.sql(String.format("DROP TABLE IF EXISTS %s", baseTableName)); + } + + @Test + public void testMigratePartitioned() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + Assume.assumeTrue( + "Can only migrate from Spark Session Catalog", catalog.name().equals("spark_catalog")); + String source = sourceName("test_migrate_partitioned_table"); + String dest = source; + createSourceTable(CREATE_PARTITIONED_PARQUET, source); + assertMigratedFileCount(SparkActions.get().migrateTable(source), source, dest); + } + + @Test + public void testPartitionedTableWithUnRecoveredPartitions() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + Assume.assumeTrue( + "Can only migrate from Spark Session Catalog", catalog.name().equals("spark_catalog")); + String source = sourceName("test_unrecovered_partitions"); + String dest = source; + File location = temp.newFolder(); + sql(CREATE_PARTITIONED_PARQUET, source, location); + + // Data generation and partition addition + spark + .range(5) + .selectExpr("id", "cast(id as STRING) as data") + .write() + .partitionBy("id") + .mode(SaveMode.Overwrite) + .parquet(location.toURI().toString()); + sql("ALTER TABLE %s ADD PARTITION(id=0)", source); + + assertMigratedFileCount(SparkActions.get().migrateTable(source), source, dest); + } + + @Test + public void testPartitionedTableWithCustomPartitions() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + Assume.assumeTrue( + "Can only migrate from Spark Session Catalog", catalog.name().equals("spark_catalog")); + String source = sourceName("test_custom_parts"); + String dest = source; + File tblLocation = temp.newFolder(); + File partitionDataLoc = temp.newFolder(); + + // Data generation and partition addition + spark.sql(String.format(CREATE_PARTITIONED_PARQUET, source, tblLocation)); + spark + .range(10) + .selectExpr("cast(id as STRING) as data") + .write() + .mode(SaveMode.Overwrite) + .parquet(partitionDataLoc.toURI().toString()); + sql( + "ALTER TABLE %s ADD PARTITION(id=0) LOCATION '%s'", + source, partitionDataLoc.toURI().toString()); + assertMigratedFileCount(SparkActions.get().migrateTable(source), source, dest); + } + + @Test + public void testAddColumnOnMigratedTableAtEnd() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + Assume.assumeTrue( + "Can only migrate from Spark Session Catalog", catalog.name().equals("spark_catalog")); + String source = sourceName("test_add_column_migrated_table"); + String dest = source; + createSourceTable(CREATE_PARQUET, source); + List expected1 = sql("select *, null from %s order by id", source); + List expected2 = sql("select *, null, null from %s order by id", source); + + // migrate table + SparkActions.get().migrateTable(source).execute(); + SparkTable sparkTable = loadTable(dest); + Table table = 
sparkTable.table(); + + // test column addition on migrated table + Schema beforeSchema = table.schema(); + String newCol1 = "newCol1"; + sparkTable.table().updateSchema().addColumn(newCol1, Types.IntegerType.get()).commit(); + Schema afterSchema = table.schema(); + Assert.assertNull(beforeSchema.findField(newCol1)); + Assert.assertNotNull(afterSchema.findField(newCol1)); + + // reads should succeed without any exceptions + List results1 = sql("select * from %s order by id", dest); + Assert.assertTrue(results1.size() > 0); + assertEquals("Output must match", results1, expected1); + + String newCol2 = "newCol2"; + sql("ALTER TABLE %s ADD COLUMN %s INT", dest, newCol2); + StructType schema = spark.table(dest).schema(); + Assert.assertTrue(Arrays.asList(schema.fieldNames()).contains(newCol2)); + + // reads should succeed without any exceptions + List results2 = sql("select * from %s order by id", dest); + Assert.assertTrue(results2.size() > 0); + assertEquals("Output must match", results2, expected2); + } + + @Test + public void testAddColumnOnMigratedTableAtMiddle() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + Assume.assumeTrue( + "Can only migrate from Spark Session Catalog", catalog.name().equals("spark_catalog")); + String source = sourceName("test_add_column_migrated_table_middle"); + String dest = source; + createSourceTable(CREATE_PARQUET, source); + + // migrate table + SparkActions.get().migrateTable(source).execute(); + SparkTable sparkTable = loadTable(dest); + Table table = sparkTable.table(); + List expected = sql("select id, null, data from %s order by id", source); + + // test column addition on migrated table + Schema beforeSchema = table.schema(); + String newCol1 = "newCol"; + sparkTable + .table() + .updateSchema() + .addColumn("newCol", Types.IntegerType.get()) + .moveAfter(newCol1, "id") + .commit(); + Schema afterSchema = table.schema(); + Assert.assertNull(beforeSchema.findField(newCol1)); + Assert.assertNotNull(afterSchema.findField(newCol1)); + + // reads should succeed + List results = sql("select * from %s order by id", dest); + Assert.assertTrue(results.size() > 0); + assertEquals("Output must match", results, expected); + } + + @Test + public void removeColumnsAtEnd() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + Assume.assumeTrue( + "Can only migrate from Spark Session Catalog", catalog.name().equals("spark_catalog")); + String source = sourceName("test_remove_column_migrated_table"); + String dest = source; + + String colName1 = "newCol1"; + String colName2 = "newCol2"; + File location = temp.newFolder(); + spark + .range(10) + .selectExpr("cast(id as INT)", "CAST(id as INT) " + colName1, "CAST(id as INT) " + colName2) + .write() + .mode(SaveMode.Overwrite) + .saveAsTable(dest); + List expected1 = sql("select id, %s from %s order by id", colName1, source); + List expected2 = sql("select id from %s order by id", source); + + // migrate table + SparkActions.get().migrateTable(source).execute(); + SparkTable sparkTable = loadTable(dest); + Table table = sparkTable.table(); + + // test column removal on migrated table + Schema beforeSchema = table.schema(); + sparkTable.table().updateSchema().deleteColumn(colName1).commit(); + Schema afterSchema = table.schema(); + Assert.assertNotNull(beforeSchema.findField(colName1)); + Assert.assertNull(afterSchema.findField(colName1)); + + // reads should succeed without any exceptions + List results1 
= sql("select * from %s order by id", dest); + Assert.assertTrue(results1.size() > 0); + assertEquals("Output must match", expected1, results1); + + sql("ALTER TABLE %s DROP COLUMN %s", dest, colName2); + StructType schema = spark.table(dest).schema(); + Assert.assertFalse(Arrays.asList(schema.fieldNames()).contains(colName2)); + + // reads should succeed without any exceptions + List results2 = sql("select * from %s order by id", dest); + Assert.assertTrue(results2.size() > 0); + assertEquals("Output must match", expected2, results2); + } + + @Test + public void removeColumnFromMiddle() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + Assume.assumeTrue( + "Can only migrate from Spark Session Catalog", catalog.name().equals("spark_catalog")); + String source = sourceName("test_remove_column_migrated_table_from_middle"); + String dest = source; + String dropColumnName = "col1"; + + spark + .range(10) + .selectExpr( + "cast(id as INT)", "CAST(id as INT) as " + dropColumnName, "CAST(id as INT) as col2") + .write() + .mode(SaveMode.Overwrite) + .saveAsTable(dest); + List expected = sql("select id, col2 from %s order by id", source); + + // migrate table + SparkActions.get().migrateTable(source).execute(); + + // drop column + sql("ALTER TABLE %s DROP COLUMN %s", dest, "col1"); + StructType schema = spark.table(dest).schema(); + Assert.assertFalse(Arrays.asList(schema.fieldNames()).contains(dropColumnName)); + + // reads should return same output as that of non-iceberg table + List results = sql("select * from %s order by id", dest); + Assert.assertTrue(results.size() > 0); + assertEquals("Output must match", expected, results); + } + + @Test + public void testMigrateUnpartitioned() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + Assume.assumeTrue( + "Can only migrate from Spark Session Catalog", catalog.name().equals("spark_catalog")); + String source = sourceName("test_migrate_unpartitioned_table"); + String dest = source; + createSourceTable(CREATE_PARQUET, source); + assertMigratedFileCount(SparkActions.get().migrateTable(source), source, dest); + } + + @Test + public void testSnapshotPartitioned() throws Exception { + Assume.assumeTrue( + "Cannot snapshot with arbitrary location in a hadoop based catalog", + !type.equals("hadoop")); + File location = temp.newFolder(); + String source = sourceName("test_snapshot_partitioned_table"); + String dest = destName("iceberg_snapshot_partitioned"); + createSourceTable(CREATE_PARTITIONED_PARQUET, source); + assertSnapshotFileCount( + SparkActions.get().snapshotTable(source).as(dest).tableLocation(location.toString()), + source, + dest); + assertIsolatedSnapshot(source, dest); + } + + @Test + public void testSnapshotUnpartitioned() throws Exception { + Assume.assumeTrue( + "Cannot snapshot with arbitrary location in a hadoop based catalog", + !type.equals("hadoop")); + File location = temp.newFolder(); + String source = sourceName("test_snapshot_unpartitioned_table"); + String dest = destName("iceberg_snapshot_unpartitioned"); + createSourceTable(CREATE_PARQUET, source); + assertSnapshotFileCount( + SparkActions.get().snapshotTable(source).as(dest).tableLocation(location.toString()), + source, + dest); + assertIsolatedSnapshot(source, dest); + } + + @Test + public void testSnapshotHiveTable() throws Exception { + Assume.assumeTrue( + "Cannot snapshot with arbitrary location in a hadoop based catalog", + !type.equals("hadoop")); + 
File location = temp.newFolder(); + String source = sourceName("snapshot_hive_table"); + String dest = destName("iceberg_snapshot_hive_table"); + createSourceTable(CREATE_HIVE_EXTERNAL_PARQUET, source); + assertSnapshotFileCount( + SparkActions.get().snapshotTable(source).as(dest).tableLocation(location.toString()), + source, + dest); + assertIsolatedSnapshot(source, dest); + } + + @Test + public void testMigrateHiveTable() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + String source = sourceName("migrate_hive_table"); + String dest = source; + createSourceTable(CREATE_HIVE_EXTERNAL_PARQUET, source); + assertMigratedFileCount(SparkActions.get().migrateTable(source), source, dest); + } + + @Test + public void testSnapshotManagedHiveTable() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + File location = temp.newFolder(); + String source = sourceName("snapshot_managed_hive_table"); + String dest = destName("iceberg_snapshot_managed_hive_table"); + createSourceTable(CREATE_HIVE_PARQUET, source); + assertSnapshotFileCount( + SparkActions.get().snapshotTable(source).as(dest).tableLocation(location.toString()), + source, + dest); + assertIsolatedSnapshot(source, dest); + } + + @Test + public void testMigrateManagedHiveTable() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + File location = temp.newFolder(); + String source = sourceName("migrate_managed_hive_table"); + String dest = destName("iceberg_migrate_managed_hive_table"); + createSourceTable(CREATE_HIVE_PARQUET, source); + assertSnapshotFileCount( + SparkActions.get().snapshotTable(source).as(dest).tableLocation(location.toString()), + source, + dest); + } + + @Test + public void testProperties() throws Exception { + String source = sourceName("test_properties_table"); + String dest = destName("iceberg_properties"); + Map props = Maps.newHashMap(); + props.put("city", "New Orleans"); + props.put("note", "Jazz"); + createSourceTable(CREATE_PARQUET, source); + for (Map.Entry keyValue : props.entrySet()) { + spark.sql( + String.format( + "ALTER TABLE %s SET TBLPROPERTIES (\"%s\" = \"%s\")", + source, keyValue.getKey(), keyValue.getValue())); + } + assertSnapshotFileCount( + SparkActions.get().snapshotTable(source).as(dest).tableProperty("dogs", "sundance"), + source, + dest); + SparkTable table = loadTable(dest); + + Map expectedProps = Maps.newHashMap(); + expectedProps.putAll(props); + expectedProps.put("dogs", "sundance"); + + for (Map.Entry entry : expectedProps.entrySet()) { + Assert.assertTrue( + "Created table missing property " + entry.getKey(), + table.properties().containsKey(entry.getKey())); + Assert.assertEquals( + "Property value is not the expected value", + entry.getValue(), + table.properties().get(entry.getKey())); + } + } + + @Test + public void testSparkTableReservedProperties() throws Exception { + String destTableName = "iceberg_reserved_properties"; + String source = sourceName("test_reserved_properties_table"); + String dest = destName(destTableName); + createSourceTable(CREATE_PARQUET, source); + assertSnapshotFileCount(SparkActions.get().snapshotTable(source).as(dest), source, dest); + SparkTable table = loadTable(dest); + // set sort orders + table.table().replaceSortOrder().asc("id").desc("data").commit(); + + String[] keys = {"provider", "format", "current-snapshot-id", "location", "sort-order"}; + + for (String entry : keys) { + 
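+      // every reserved property listed above should be surfaced through the SparkTable
+      // properties map after the snapshot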
Assert.assertTrue( + "Created table missing reserved property " + entry, + table.properties().containsKey(entry)); + } + + Assert.assertEquals("Unexpected provider", "iceberg", table.properties().get("provider")); + Assert.assertEquals("Unexpected format", "iceberg/parquet", table.properties().get("format")); + Assert.assertNotEquals( + "No current-snapshot-id found", "none", table.properties().get("current-snapshot-id")); + Assert.assertTrue( + "Location isn't correct", table.properties().get("location").endsWith(destTableName)); + + Assert.assertEquals("Unexpected format-version", "1", table.properties().get("format-version")); + table.table().updateProperties().set("format-version", "2").commit(); + Assert.assertEquals("Unexpected format-version", "2", table.properties().get("format-version")); + + Assert.assertEquals( + "Sort-order isn't correct", + "id ASC NULLS FIRST, data DESC NULLS LAST", + table.properties().get("sort-order")); + Assert.assertNull( + "Identifier fields should be null", table.properties().get("identifier-fields")); + + table + .table() + .updateSchema() + .allowIncompatibleChanges() + .requireColumn("id") + .setIdentifierFields("id") + .commit(); + Assert.assertEquals( + "Identifier fields aren't correct", "[id]", table.properties().get("identifier-fields")); + } + + @Test + public void testSnapshotDefaultLocation() throws Exception { + String source = sourceName("test_snapshot_default"); + String dest = destName("iceberg_snapshot_default"); + createSourceTable(CREATE_PARTITIONED_PARQUET, source); + assertSnapshotFileCount(SparkActions.get().snapshotTable(source).as(dest), source, dest); + assertIsolatedSnapshot(source, dest); + } + + @Test + public void schemaEvolutionTestWithSparkAPI() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + Assume.assumeTrue( + "Can only migrate from Spark Session Catalog", catalog.name().equals("spark_catalog")); + + File location = temp.newFolder(); + String tblName = sourceName("schema_evolution_test"); + + // Data generation and partition addition + spark + .range(0, 5) + .selectExpr("CAST(id as INT) as col0", "CAST(id AS FLOAT) col2", "CAST(id AS LONG) col3") + .write() + .mode(SaveMode.Append) + .parquet(location.toURI().toString()); + Dataset rowDataset = + spark + .range(6, 10) + .selectExpr( + "CAST(id as INT) as col0", + "CAST(id AS STRING) col1", + "CAST(id AS FLOAT) col2", + "CAST(id AS LONG) col3"); + rowDataset.write().mode(SaveMode.Append).parquet(location.toURI().toString()); + spark + .read() + .schema(rowDataset.schema()) + .parquet(location.toURI().toString()) + .write() + .saveAsTable(tblName); + List expectedBeforeAddColumn = sql("SELECT * FROM %s ORDER BY col0", tblName); + List expectedAfterAddColumn = + sql("SELECT col0, null, col1, col2, col3 FROM %s ORDER BY col0", tblName); + + // Migrate table + SparkActions.get().migrateTable(tblName).execute(); + + // check if iceberg and non-iceberg output + List afterMigarteBeforeAddResults = sql("SELECT * FROM %s ORDER BY col0", tblName); + assertEquals("Output must match", expectedBeforeAddColumn, afterMigarteBeforeAddResults); + + // Update schema and check output correctness + SparkTable sparkTable = loadTable(tblName); + sparkTable + .table() + .updateSchema() + .addColumn("newCol", Types.IntegerType.get()) + .moveAfter("newCol", "col0") + .commit(); + List afterMigarteAfterAddResults = sql("SELECT * FROM %s ORDER BY col0", tblName); + assertEquals("Output must match", expectedAfterAddColumn, 
afterMigarteAfterAddResults); + } + + @Test + public void schemaEvolutionTestWithSparkSQL() throws Exception { + Assume.assumeTrue("Cannot migrate to a hadoop based catalog", !type.equals("hadoop")); + Assume.assumeTrue( + "Can only migrate from Spark Session Catalog", catalog.name().equals("spark_catalog")); + String tblName = sourceName("schema_evolution_test_sql"); + + // Data generation and partition addition + spark + .range(0, 5) + .selectExpr("CAST(id as INT) col0", "CAST(id AS FLOAT) col1", "CAST(id AS STRING) col2") + .write() + .mode(SaveMode.Append) + .saveAsTable(tblName); + sql("ALTER TABLE %s ADD COLUMN col3 INT", tblName); + spark + .range(6, 10) + .selectExpr( + "CAST(id AS INT) col0", + "CAST(id AS FLOAT) col1", + "CAST(id AS STRING) col2", + "CAST(id AS INT) col3") + .registerTempTable("tempdata"); + sql("INSERT INTO TABLE %s SELECT * FROM tempdata", tblName); + List expectedBeforeAddColumn = sql("SELECT * FROM %s ORDER BY col0", tblName); + List expectedAfterAddColumn = + sql("SELECT col0, null, col1, col2, col3 FROM %s ORDER BY col0", tblName); + + // Migrate table + SparkActions.get().migrateTable(tblName).execute(); + + // check if iceberg and non-iceberg output + List afterMigarteBeforeAddResults = sql("SELECT * FROM %s ORDER BY col0", tblName); + assertEquals("Output must match", expectedBeforeAddColumn, afterMigarteBeforeAddResults); + + // Update schema and check output correctness + SparkTable sparkTable = loadTable(tblName); + sparkTable + .table() + .updateSchema() + .addColumn("newCol", Types.IntegerType.get()) + .moveAfter("newCol", "col0") + .commit(); + List afterMigarteAfterAddResults = sql("SELECT * FROM %s ORDER BY col0", tblName); + assertEquals("Output must match", expectedAfterAddColumn, afterMigarteAfterAddResults); + } + + @Test + public void testHiveStyleThreeLevelList() throws Exception { + threeLevelList(true); + } + + @Test + public void testThreeLevelList() throws Exception { + threeLevelList(false); + } + + @Test + public void testHiveStyleThreeLevelListWithNestedStruct() throws Exception { + threeLevelListWithNestedStruct(true); + } + + @Test + public void testThreeLevelListWithNestedStruct() throws Exception { + threeLevelListWithNestedStruct(false); + } + + @Test + public void testHiveStyleThreeLevelLists() throws Exception { + threeLevelLists(true); + } + + @Test + public void testThreeLevelLists() throws Exception { + threeLevelLists(false); + } + + @Test + public void testHiveStyleStructOfThreeLevelLists() throws Exception { + structOfThreeLevelLists(true); + } + + @Test + public void testStructOfThreeLevelLists() throws Exception { + structOfThreeLevelLists(false); + } + + @Test + public void testTwoLevelList() throws IOException { + spark.conf().set("spark.sql.parquet.writeLegacyFormat", true); + + String tableName = sourceName("testTwoLevelList"); + File location = temp.newFolder(); + + StructType sparkSchema = + new StructType( + new StructField[] { + new StructField( + "col1", + new ArrayType( + new StructType( + new StructField[] { + new StructField("col2", DataTypes.IntegerType, false, Metadata.empty()) + }), + false), + true, + Metadata.empty()) + }); + + // even though this list looks like three level list, it is actually a 2-level list where the + // items are + // structs with 1 field. 
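+    // For comparison, the standard three-level encoding of the same column would wrap the
+    // element in an intermediate repeated "list" group (sketch for illustration only; element
+    // nullability depends on containsNull):
+    //   optional group col1 (LIST) {
+    //     repeated group list {
+    //       required group element {
+    //         required int32 col2;
+    //       }
+    //     }
+    //   }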
+ String expectedParquetSchema = + "message spark_schema {\n" + + " optional group col1 (LIST) {\n" + + " repeated group array {\n" + + " required int32 col2;\n" + + " }\n" + + " }\n" + + "}\n"; + + // generate parquet file with required schema + List testData = Collections.singletonList("{\"col1\": [{\"col2\": 1}]}"); + spark + .read() + .schema(sparkSchema) + .json(JavaSparkContext.fromSparkContext(spark.sparkContext()).parallelize(testData)) + .coalesce(1) + .write() + .format("parquet") + .mode(SaveMode.Append) + .save(location.getPath()); + + File parquetFile = + Arrays.stream( + Preconditions.checkNotNull( + location.listFiles( + new FilenameFilter() { + @Override + public boolean accept(File dir, String name) { + return name.endsWith("parquet"); + } + }))) + .findAny() + .get(); + + // verify generated parquet file has expected schema + ParquetFileReader pqReader = + ParquetFileReader.open( + HadoopInputFile.fromPath( + new Path(parquetFile.getPath()), spark.sessionState().newHadoopConf())); + MessageType schema = pqReader.getFooter().getFileMetaData().getSchema(); + Assert.assertEquals(MessageTypeParser.parseMessageType(expectedParquetSchema), schema); + + // create sql table on top of it + sql( + "CREATE EXTERNAL TABLE %s (col1 ARRAY>)" + + " STORED AS parquet" + + " LOCATION '%s'", + tableName, location); + List expected = sql("select array(struct(1))"); + + // migrate table + SparkActions.get().migrateTable(tableName).execute(); + + // check migrated table is returning expected result + List results = sql("SELECT * FROM %s", tableName); + Assert.assertTrue(results.size() > 0); + assertEquals("Output must match", expected, results); + } + + private void threeLevelList(boolean useLegacyMode) throws Exception { + spark.conf().set("spark.sql.parquet.writeLegacyFormat", useLegacyMode); + + String tableName = sourceName(String.format("threeLevelList_%s", useLegacyMode)); + File location = temp.newFolder(); + sql( + "CREATE TABLE %s (col1 ARRAY>)" + " STORED AS parquet" + " LOCATION '%s'", + tableName, location); + + int testValue = 12345; + sql("INSERT INTO %s VALUES (ARRAY(STRUCT(%s)))", tableName, testValue); + List expected = sql(String.format("SELECT * FROM %s", tableName)); + + // migrate table + SparkActions.get().migrateTable(tableName).execute(); + + // check migrated table is returning expected result + List results = sql("SELECT * FROM %s", tableName); + Assert.assertTrue(results.size() > 0); + assertEquals("Output must match", expected, results); + } + + private void threeLevelListWithNestedStruct(boolean useLegacyMode) throws Exception { + spark.conf().set("spark.sql.parquet.writeLegacyFormat", useLegacyMode); + + String tableName = + sourceName(String.format("threeLevelListWithNestedStruct_%s", useLegacyMode)); + File location = temp.newFolder(); + sql( + "CREATE TABLE %s (col1 ARRAY>>)" + + " STORED AS parquet" + + " LOCATION '%s'", + tableName, location); + + int testValue = 12345; + sql("INSERT INTO %s VALUES (ARRAY(STRUCT(STRUCT(%s))))", tableName, testValue); + List expected = sql(String.format("SELECT * FROM %s", tableName)); + + // migrate table + SparkActions.get().migrateTable(tableName).execute(); + + // check migrated table is returning expected result + List results = sql("SELECT * FROM %s", tableName); + Assert.assertTrue(results.size() > 0); + assertEquals("Output must match", expected, results); + } + + private void threeLevelLists(boolean useLegacyMode) throws Exception { + spark.conf().set("spark.sql.parquet.writeLegacyFormat", useLegacyMode); + + String 
tableName = sourceName(String.format("threeLevelLists_%s", useLegacyMode)); + File location = temp.newFolder(); + sql( + "CREATE TABLE %s (col1 ARRAY>, col3 ARRAY>)" + + " STORED AS parquet" + + " LOCATION '%s'", + tableName, location); + + int testValue1 = 12345; + int testValue2 = 987654; + sql( + "INSERT INTO %s VALUES (ARRAY(STRUCT(%s)), ARRAY(STRUCT(%s)))", + tableName, testValue1, testValue2); + List expected = sql(String.format("SELECT * FROM %s", tableName)); + + // migrate table + SparkActions.get().migrateTable(tableName).execute(); + + // check migrated table is returning expected result + List results = sql("SELECT * FROM %s", tableName); + Assert.assertTrue(results.size() > 0); + assertEquals("Output must match", expected, results); + } + + private void structOfThreeLevelLists(boolean useLegacyMode) throws Exception { + spark.conf().set("spark.sql.parquet.writeLegacyFormat", useLegacyMode); + + String tableName = sourceName(String.format("structOfThreeLevelLists_%s", useLegacyMode)); + File location = temp.newFolder(); + sql( + "CREATE TABLE %s (col1 STRUCT>>)" + + " STORED AS parquet" + + " LOCATION '%s'", + tableName, location); + + int testValue1 = 12345; + sql("INSERT INTO %s VALUES (STRUCT(STRUCT(ARRAY(STRUCT(%s)))))", tableName, testValue1); + List expected = sql(String.format("SELECT * FROM %s", tableName)); + + // migrate table + SparkActions.get().migrateTable(tableName).execute(); + + // check migrated table is returning expected result + List results = sql("SELECT * FROM %s", tableName); + Assert.assertTrue(results.size() > 0); + assertEquals("Output must match", expected, results); + } + + private SparkTable loadTable(String name) throws NoSuchTableException, ParseException { + return (SparkTable) + catalog.loadTable(Spark3Util.catalogAndIdentifier(spark, name).identifier()); + } + + private CatalogTable loadSessionTable(String name) + throws NoSuchTableException, NoSuchDatabaseException, ParseException { + Identifier identifier = Spark3Util.catalogAndIdentifier(spark, name).identifier(); + Some namespace = Some.apply(identifier.namespace()[0]); + return spark + .sessionState() + .catalog() + .getTableMetadata(new TableIdentifier(identifier.name(), namespace)); + } + + private void createSourceTable(String createStatement, String tableName) + throws IOException, NoSuchTableException, NoSuchDatabaseException, ParseException { + File location = temp.newFolder(); + spark.sql(String.format(createStatement, tableName, location)); + CatalogTable table = loadSessionTable(tableName); + Seq partitionColumns = table.partitionColumnNames(); + String format = table.provider().get(); + spark + .table(baseTableName) + .write() + .mode(SaveMode.Append) + .format(format) + .partitionBy(partitionColumns.toSeq()) + .saveAsTable(tableName); + } + + // Counts the number of files in the source table, makes sure the same files exist in the + // destination table + private void assertMigratedFileCount(MigrateTable migrateAction, String source, String dest) + throws NoSuchTableException, NoSuchDatabaseException, ParseException { + long expectedFiles = expectedFilesCount(source); + MigrateTable.Result migratedFiles = migrateAction.execute(); + validateTables(source, dest); + Assert.assertEquals( + "Expected number of migrated files", expectedFiles, migratedFiles.migratedDataFilesCount()); + } + + // Counts the number of files in the source table, makes sure the same files exist in the + // destination table + private void assertSnapshotFileCount(SnapshotTable snapshotTable, String source, 
String dest) + throws NoSuchTableException, NoSuchDatabaseException, ParseException { + long expectedFiles = expectedFilesCount(source); + SnapshotTable.Result snapshotTableResult = snapshotTable.execute(); + validateTables(source, dest); + Assert.assertEquals( + "Expected number of imported snapshot files", + expectedFiles, + snapshotTableResult.importedDataFilesCount()); + } + + private void validateTables(String source, String dest) + throws NoSuchTableException, ParseException { + List expected = spark.table(source).collectAsList(); + SparkTable destTable = loadTable(dest); + Assert.assertEquals( + "Provider should be iceberg", + "iceberg", + destTable.properties().get(TableCatalog.PROP_PROVIDER)); + List actual = spark.table(dest).collectAsList(); + Assert.assertTrue( + String.format( + "Rows in migrated table did not match\nExpected :%s rows \nFound :%s", + expected, actual), + expected.containsAll(actual) && actual.containsAll(expected)); + } + + private long expectedFilesCount(String source) + throws NoSuchDatabaseException, NoSuchTableException, ParseException { + CatalogTable sourceTable = loadSessionTable(source); + List uris; + if (sourceTable.partitionColumnNames().size() == 0) { + uris = Lists.newArrayList(); + uris.add(sourceTable.location()); + } else { + Seq catalogTablePartitionSeq = + spark + .sessionState() + .catalog() + .listPartitions(sourceTable.identifier(), Option.apply(null)); + uris = + JavaConverters.seqAsJavaList(catalogTablePartitionSeq).stream() + .map(CatalogTablePartition::location) + .collect(Collectors.toList()); + } + return uris.stream() + .flatMap( + uri -> + FileUtils.listFiles( + Paths.get(uri).toFile(), TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE) + .stream()) + .filter(file -> !file.toString().endsWith("crc") && !file.toString().contains("_SUCCESS")) + .count(); + } + + // Insert records into the destination, makes sure those records exist and source table is + // unchanged + private void assertIsolatedSnapshot(String source, String dest) { + List expected = spark.sql(String.format("SELECT * FROM %s", source)).collectAsList(); + + List extraData = Lists.newArrayList(new SimpleRecord(4, "d")); + Dataset df = spark.createDataFrame(extraData, SimpleRecord.class); + df.write().format("iceberg").mode("append").saveAsTable(dest); + + List result = spark.sql(String.format("SELECT * FROM %s", source)).collectAsList(); + Assert.assertEquals( + "No additional rows should be added to the original table", expected.size(), result.size()); + + List snapshot = + spark + .sql(String.format("SELECT * FROM %s WHERE id = 4 AND data = 'd'", dest)) + .collectAsList(); + Assert.assertEquals("Added row not found in snapshot", 1, snapshot.size()); + } + + private String sourceName(String source) { + return NAMESPACE + "." + catalog.name() + "_" + type + "_" + source; + } + + private String destName(String dest) { + if (catalog.name().equals("spark_catalog")) { + return NAMESPACE + "." + catalog.name() + "_" + type + "_" + dest; + } else { + return catalog.name() + "." + NAMESPACE + "." 
+ catalog.name() + "_" + type + "_" + dest; + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestDeleteReachableFilesAction.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestDeleteReachableFilesAction.java new file mode 100644 index 000000000000..154e940519d1 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestDeleteReachableFilesAction.java @@ -0,0 +1,394 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.types.Types.NestedField.optional; + +import java.io.File; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.StreamSupport; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileMetadata; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.actions.ActionsProvider; +import org.apache.iceberg.actions.DeleteOrphanFiles; +import org.apache.iceberg.actions.DeleteReachableFiles; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.types.Types; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestDeleteReachableFilesAction extends SparkTestBase { + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + private static final Schema SCHEMA = + new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + private static final int SHUFFLE_PARTITIONS = 2; + + private static final PartitionSpec SPEC = PartitionSpec.builderFor(SCHEMA).identity("c1").build(); + + static 
final DataFile FILE_A = + DataFiles.builder(SPEC) + .withPath("/path/to/data-a.parquet") + .withFileSizeInBytes(10) + .withPartition(TestHelpers.Row.of(0)) + .withRecordCount(1) + .build(); + static final DataFile FILE_B = + DataFiles.builder(SPEC) + .withPath("/path/to/data-b.parquet") + .withFileSizeInBytes(10) + .withPartition(TestHelpers.Row.of(1)) + .withRecordCount(1) + .build(); + static final DataFile FILE_C = + DataFiles.builder(SPEC) + .withPath("/path/to/data-c.parquet") + .withFileSizeInBytes(10) + .withPartition(TestHelpers.Row.of(2)) + .withRecordCount(1) + .build(); + static final DataFile FILE_D = + DataFiles.builder(SPEC) + .withPath("/path/to/data-d.parquet") + .withFileSizeInBytes(10) + .withPartition(TestHelpers.Row.of(3)) + .withRecordCount(1) + .build(); + static final DeleteFile FILE_A_POS_DELETES = + FileMetadata.deleteFileBuilder(SPEC) + .ofPositionDeletes() + .withPath("/path/to/data-a-pos-deletes.parquet") + .withFileSizeInBytes(10) + .withPartition(TestHelpers.Row.of(0)) + .withRecordCount(1) + .build(); + static final DeleteFile FILE_A_EQ_DELETES = + FileMetadata.deleteFileBuilder(SPEC) + .ofEqualityDeletes() + .withPath("/path/to/data-a-eq-deletes.parquet") + .withFileSizeInBytes(10) + .withPartition(TestHelpers.Row.of(0)) + .withRecordCount(1) + .build(); + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private Table table; + + @Before + public void setupTableLocation() throws Exception { + File tableDir = temp.newFolder(); + String tableLocation = tableDir.toURI().toString(); + this.table = TABLES.create(SCHEMA, SPEC, Maps.newHashMap(), tableLocation); + spark.conf().set("spark.sql.shuffle.partitions", SHUFFLE_PARTITIONS); + } + + private void checkRemoveFilesResults( + long expectedDatafiles, + long expectedPosDeleteFiles, + long expectedEqDeleteFiles, + long expectedManifestsDeleted, + long expectedManifestListsDeleted, + long expectedOtherFilesDeleted, + DeleteReachableFiles.Result results) { + Assert.assertEquals( + "Incorrect number of manifest files deleted", + expectedManifestsDeleted, + results.deletedManifestsCount()); + Assert.assertEquals( + "Incorrect number of datafiles deleted", + expectedDatafiles, + results.deletedDataFilesCount()); + Assert.assertEquals( + "Incorrect number of position delete files deleted", + expectedPosDeleteFiles, + results.deletedPositionDeleteFilesCount()); + Assert.assertEquals( + "Incorrect number of equality delete files deleted", + expectedEqDeleteFiles, + results.deletedEqualityDeleteFilesCount()); + Assert.assertEquals( + "Incorrect number of manifest lists deleted", + expectedManifestListsDeleted, + results.deletedManifestListsCount()); + Assert.assertEquals( + "Incorrect number of other lists deleted", + expectedOtherFilesDeleted, + results.deletedOtherFilesCount()); + } + + @Test + public void dataFilesCleanupWithParallelTasks() { + table.newFastAppend().appendFile(FILE_A).commit(); + + table.newFastAppend().appendFile(FILE_B).commit(); + + table.newRewrite().rewriteFiles(ImmutableSet.of(FILE_B), ImmutableSet.of(FILE_D)).commit(); + + table.newRewrite().rewriteFiles(ImmutableSet.of(FILE_A), ImmutableSet.of(FILE_C)).commit(); + + Set deletedFiles = ConcurrentHashMap.newKeySet(); + Set deleteThreads = ConcurrentHashMap.newKeySet(); + AtomicInteger deleteThreadsIndex = new AtomicInteger(0); + + DeleteReachableFiles.Result result = + sparkActions() + .deleteReachableFiles(metadataLocation(table)) + .io(table.io()) + .executeDeleteWith( + Executors.newFixedThreadPool( + 4, + runnable -> { + Thread 
thread = new Thread(runnable); + thread.setName("remove-files-" + deleteThreadsIndex.getAndIncrement()); + thread.setDaemon( + true); // daemon threads will be terminated abruptly when the JVM exits + return thread; + })) + .deleteWith( + s -> { + deleteThreads.add(Thread.currentThread().getName()); + deletedFiles.add(s); + }) + .execute(); + + // Verifies that the delete methods ran in the threads created by the provided ExecutorService + // ThreadFactory + Assert.assertEquals( + deleteThreads, + Sets.newHashSet("remove-files-0", "remove-files-1", "remove-files-2", "remove-files-3")); + + Lists.newArrayList(FILE_A, FILE_B, FILE_C, FILE_D) + .forEach( + file -> + Assert.assertTrue( + "FILE_A should be deleted", deletedFiles.contains(FILE_A.path().toString()))); + checkRemoveFilesResults(4L, 0, 0, 6L, 4L, 6, result); + } + + @Test + public void testWithExpiringDanglingStageCommit() { + table.location(); + // `A` commit + table.newAppend().appendFile(FILE_A).commit(); + + // `B` staged commit + table.newAppend().appendFile(FILE_B).stageOnly().commit(); + + // `C` commit + table.newAppend().appendFile(FILE_C).commit(); + + DeleteReachableFiles.Result result = + sparkActions().deleteReachableFiles(metadataLocation(table)).io(table.io()).execute(); + + checkRemoveFilesResults(3L, 0, 0, 3L, 3L, 5, result); + } + + @Test + public void testRemoveFileActionOnEmptyTable() { + DeleteReachableFiles.Result result = + sparkActions().deleteReachableFiles(metadataLocation(table)).io(table.io()).execute(); + + checkRemoveFilesResults(0, 0, 0, 0, 0, 2, result); + } + + @Test + public void testRemoveFilesActionWithReducedVersionsTable() { + table.updateProperties().set(TableProperties.METADATA_PREVIOUS_VERSIONS_MAX, "2").commit(); + table.newAppend().appendFile(FILE_A).commit(); + + table.newAppend().appendFile(FILE_B).commit(); + + table.newAppend().appendFile(FILE_B).commit(); + + table.newAppend().appendFile(FILE_C).commit(); + + table.newAppend().appendFile(FILE_D).commit(); + + DeleteReachableFiles baseRemoveFilesSparkAction = + sparkActions().deleteReachableFiles(metadataLocation(table)).io(table.io()); + DeleteReachableFiles.Result result = baseRemoveFilesSparkAction.execute(); + + checkRemoveFilesResults(4, 0, 0, 5, 5, 8, result); + } + + @Test + public void testRemoveFilesAction() { + table.newAppend().appendFile(FILE_A).commit(); + + table.newAppend().appendFile(FILE_B).commit(); + + DeleteReachableFiles baseRemoveFilesSparkAction = + sparkActions().deleteReachableFiles(metadataLocation(table)).io(table.io()); + checkRemoveFilesResults(2, 0, 0, 2, 2, 4, baseRemoveFilesSparkAction.execute()); + } + + @Test + public void testPositionDeleteFiles() { + table.updateProperties().set(TableProperties.FORMAT_VERSION, "2").commit(); + + table.newAppend().appendFile(FILE_A).commit(); + + table.newAppend().appendFile(FILE_B).commit(); + + table.newRowDelta().addDeletes(FILE_A_POS_DELETES).commit(); + + DeleteReachableFiles baseRemoveFilesSparkAction = + sparkActions().deleteReachableFiles(metadataLocation(table)).io(table.io()); + checkRemoveFilesResults(2, 1, 0, 3, 3, 6, baseRemoveFilesSparkAction.execute()); + } + + @Test + public void testEqualityDeleteFiles() { + table.updateProperties().set(TableProperties.FORMAT_VERSION, "2").commit(); + + table.newAppend().appendFile(FILE_A).commit(); + + table.newAppend().appendFile(FILE_B).commit(); + + table.newRowDelta().addDeletes(FILE_A_EQ_DELETES).commit(); + + DeleteReachableFiles baseRemoveFilesSparkAction = + 
sparkActions().deleteReachableFiles(metadataLocation(table)).io(table.io()); + checkRemoveFilesResults(2, 0, 1, 3, 3, 6, baseRemoveFilesSparkAction.execute()); + } + + @Test + public void testRemoveFilesActionWithDefaultIO() { + table.newAppend().appendFile(FILE_A).commit(); + + table.newAppend().appendFile(FILE_B).commit(); + + // IO not set explicitly on removeReachableFiles action + // IO defaults to HadoopFileIO + DeleteReachableFiles baseRemoveFilesSparkAction = + sparkActions().deleteReachableFiles(metadataLocation(table)); + checkRemoveFilesResults(2, 0, 0, 2, 2, 4, baseRemoveFilesSparkAction.execute()); + } + + @Test + public void testUseLocalIterator() { + table.newFastAppend().appendFile(FILE_A).commit(); + + table.newOverwrite().deleteFile(FILE_A).addFile(FILE_B).commit(); + + table.newFastAppend().appendFile(FILE_C).commit(); + + int jobsBefore = spark.sparkContext().dagScheduler().nextJobId().get(); + + withSQLConf( + ImmutableMap.of("spark.sql.adaptive.enabled", "false"), + () -> { + DeleteReachableFiles.Result results = + sparkActions() + .deleteReachableFiles(metadataLocation(table)) + .io(table.io()) + .option("stream-results", "true") + .execute(); + + int jobsAfter = spark.sparkContext().dagScheduler().nextJobId().get(); + int totalJobsRun = jobsAfter - jobsBefore; + + checkRemoveFilesResults(3L, 0, 0, 4L, 3L, 5, results); + + Assert.assertEquals( + "Expected total jobs to be equal to total number of shuffle partitions", + totalJobsRun, + SHUFFLE_PARTITIONS); + }); + } + + @Test + public void testIgnoreMetadataFilesNotFound() { + table.updateProperties().set(TableProperties.METADATA_PREVIOUS_VERSIONS_MAX, "1").commit(); + + table.newAppend().appendFile(FILE_A).commit(); + // There are three metadata json files at this point + DeleteOrphanFiles.Result result = + sparkActions().deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + Assert.assertEquals("Should delete 1 file", 1, Iterables.size(result.orphanFileLocations())); + Assert.assertTrue( + "Should remove v1 file", + StreamSupport.stream(result.orphanFileLocations().spliterator(), false) + .anyMatch(file -> file.contains("v1.metadata.json"))); + + DeleteReachableFiles baseRemoveFilesSparkAction = + sparkActions().deleteReachableFiles(metadataLocation(table)).io(table.io()); + DeleteReachableFiles.Result res = baseRemoveFilesSparkAction.execute(); + + checkRemoveFilesResults(1, 0, 0, 1, 1, 4, res); + } + + @Test + public void testEmptyIOThrowsException() { + DeleteReachableFiles baseRemoveFilesSparkAction = + sparkActions().deleteReachableFiles(metadataLocation(table)).io(null); + AssertHelpers.assertThrows( + "FileIO can't be null in DeleteReachableFiles action", + IllegalArgumentException.class, + "File IO cannot be null", + baseRemoveFilesSparkAction::execute); + } + + @Test + public void testRemoveFilesActionWhenGarbageCollectionDisabled() { + table.updateProperties().set(TableProperties.GC_ENABLED, "false").commit(); + + AssertHelpers.assertThrows( + "Should complain about removing files when GC is disabled", + ValidationException.class, + "Cannot delete files: GC is disabled (deleting files may corrupt other tables)", + () -> sparkActions().deleteReachableFiles(metadataLocation(table)).execute()); + } + + private String metadataLocation(Table tbl) { + return ((HasTableOperations) tbl).operations().current().metadataFileLocation(); + } + + private ActionsProvider sparkActions() { + return SparkActions.get(); + } +} diff --git 
a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestExpireSnapshotsAction.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestExpireSnapshotsAction.java new file mode 100644 index 000000000000..25e6f6486aae --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestExpireSnapshotsAction.java @@ -0,0 +1,1354 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.types.Types.NestedField.optional; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileMetadata; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.ReachableFileUtil; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.ExpireSnapshots; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestExpireSnapshotsAction extends SparkTestBase { + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + private static final Schema SCHEMA = + new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + private static final int SHUFFLE_PARTITIONS = 2; + + 
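+ // Shared fixture for the expiration tests below: an identity-partitioned spec plus four data
+ // files (FILE_A..FILE_D) and one position-delete and one equality-delete file for FILE_A.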
private static final PartitionSpec SPEC = PartitionSpec.builderFor(SCHEMA).identity("c1").build(); + + static final DataFile FILE_A = + DataFiles.builder(SPEC) + .withPath("/path/to/data-a.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=0") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DataFile FILE_B = + DataFiles.builder(SPEC) + .withPath("/path/to/data-b.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=1") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DataFile FILE_C = + DataFiles.builder(SPEC) + .withPath("/path/to/data-c.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=2") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DataFile FILE_D = + DataFiles.builder(SPEC) + .withPath("/path/to/data-d.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=3") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DeleteFile FILE_A_POS_DELETES = + FileMetadata.deleteFileBuilder(SPEC) + .ofPositionDeletes() + .withPath("/path/to/data-a-pos-deletes.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=0") // easy way to set partition data for now + .withRecordCount(1) + .build(); + static final DeleteFile FILE_A_EQ_DELETES = + FileMetadata.deleteFileBuilder(SPEC) + .ofEqualityDeletes() + .withPath("/path/to/data-a-eq-deletes.parquet") + .withFileSizeInBytes(10) + .withPartitionPath("c1=0") // easy way to set partition data for now + .withRecordCount(1) + .build(); + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private File tableDir; + private String tableLocation; + private Table table; + + @Before + public void setupTableLocation() throws Exception { + this.tableDir = temp.newFolder(); + this.tableLocation = tableDir.toURI().toString(); + this.table = TABLES.create(SCHEMA, SPEC, Maps.newHashMap(), tableLocation); + spark.conf().set("spark.sql.shuffle.partitions", SHUFFLE_PARTITIONS); + } + + private Long rightAfterSnapshot() { + return rightAfterSnapshot(table.currentSnapshot().snapshotId()); + } + + private Long rightAfterSnapshot(long snapshotId) { + Long end = System.currentTimeMillis(); + while (end <= table.snapshot(snapshotId).timestampMillis()) { + end = System.currentTimeMillis(); + } + return end; + } + + private void checkExpirationResults( + long expectedDatafiles, + long expectedPosDeleteFiles, + long expectedEqDeleteFiles, + long expectedManifestsDeleted, + long expectedManifestListsDeleted, + ExpireSnapshots.Result results) { + + Assert.assertEquals( + "Incorrect number of manifest files deleted", + expectedManifestsDeleted, + results.deletedManifestsCount()); + Assert.assertEquals( + "Incorrect number of datafiles deleted", + expectedDatafiles, + results.deletedDataFilesCount()); + Assert.assertEquals( + "Incorrect number of pos deletefiles deleted", + expectedPosDeleteFiles, + results.deletedPositionDeleteFilesCount()); + Assert.assertEquals( + "Incorrect number of eq deletefiles deleted", + expectedEqDeleteFiles, + results.deletedEqualityDeleteFilesCount()); + Assert.assertEquals( + "Incorrect number of manifest lists deleted", + expectedManifestListsDeleted, + results.deletedManifestListsCount()); + } + + @Test + public void testFilesCleaned() throws Exception { + table.newFastAppend().appendFile(FILE_A).commit(); + + table.newOverwrite().deleteFile(FILE_A).addFile(FILE_B).commit(); + + table.newFastAppend().appendFile(FILE_C).commit(); + 
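+ // Expiring everything older than `end` should leave a single snapshot; the assertion below
+ // expects 1 data file, 1 manifest and 2 manifest lists to have been removed.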
+ long end = rightAfterSnapshot(); + + ExpireSnapshots.Result results = + SparkActions.get().expireSnapshots(table).expireOlderThan(end).execute(); + + Assert.assertEquals( + "Table does not have 1 snapshot after expiration", 1, Iterables.size(table.snapshots())); + + checkExpirationResults(1L, 0L, 0L, 1L, 2L, results); + } + + @Test + public void dataFilesCleanupWithParallelTasks() throws IOException { + + table.newFastAppend().appendFile(FILE_A).commit(); + + table.newFastAppend().appendFile(FILE_B).commit(); + + table.newRewrite().rewriteFiles(ImmutableSet.of(FILE_B), ImmutableSet.of(FILE_D)).commit(); + + table.newRewrite().rewriteFiles(ImmutableSet.of(FILE_A), ImmutableSet.of(FILE_C)).commit(); + + long t4 = rightAfterSnapshot(); + + Set deletedFiles = Sets.newHashSet(); + Set deleteThreads = ConcurrentHashMap.newKeySet(); + AtomicInteger deleteThreadsIndex = new AtomicInteger(0); + + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .executeDeleteWith( + Executors.newFixedThreadPool( + 4, + runnable -> { + Thread thread = new Thread(runnable); + thread.setName("remove-snapshot-" + deleteThreadsIndex.getAndIncrement()); + thread.setDaemon( + true); // daemon threads will be terminated abruptly when the JVM exits + return thread; + })) + .expireOlderThan(t4) + .deleteWith( + s -> { + deleteThreads.add(Thread.currentThread().getName()); + deletedFiles.add(s); + }) + .execute(); + + // Verifies that the delete methods ran in the threads created by the provided ExecutorService + // ThreadFactory + Assert.assertEquals( + deleteThreads, + Sets.newHashSet( + "remove-snapshot-0", "remove-snapshot-1", "remove-snapshot-2", "remove-snapshot-3")); + + Assert.assertTrue("FILE_A should be deleted", deletedFiles.contains(FILE_A.path().toString())); + Assert.assertTrue("FILE_B should be deleted", deletedFiles.contains(FILE_B.path().toString())); + + checkExpirationResults(2L, 0L, 0L, 3L, 3L, result); + } + + @Test + public void testNoFilesDeletedWhenNoSnapshotsExpired() throws Exception { + table.newFastAppend().appendFile(FILE_A).commit(); + + ExpireSnapshots.Result results = SparkActions.get().expireSnapshots(table).execute(); + checkExpirationResults(0L, 0L, 0L, 0L, 0L, results); + } + + @Test + public void testCleanupRepeatedOverwrites() throws Exception { + table.newFastAppend().appendFile(FILE_A).commit(); + + for (int i = 0; i < 10; i++) { + table.newOverwrite().deleteFile(FILE_A).addFile(FILE_B).commit(); + + table.newOverwrite().deleteFile(FILE_B).addFile(FILE_A).commit(); + } + + long end = rightAfterSnapshot(); + ExpireSnapshots.Result results = + SparkActions.get().expireSnapshots(table).expireOlderThan(end).execute(); + checkExpirationResults(1L, 0L, 0L, 39L, 20L, results); + } + + @Test + public void testRetainLastWithExpireOlderThan() { + table + .newAppend() + .appendFile(FILE_A) // data_bucket=0 + .commit(); + long firstSnapshotId = table.currentSnapshot().snapshotId(); + long t1 = System.currentTimeMillis(); + while (t1 <= table.currentSnapshot().timestampMillis()) { + t1 = System.currentTimeMillis(); + } + + table + .newAppend() + .appendFile(FILE_B) // data_bucket=1 + .commit(); + + table + .newAppend() + .appendFile(FILE_C) // data_bucket=2 + .commit(); + + long t3 = rightAfterSnapshot(); + + // Retain last 2 snapshots + SparkActions.get().expireSnapshots(table).expireOlderThan(t3).retainLast(2).execute(); + + Assert.assertEquals( + "Should have two snapshots.", 2, Lists.newArrayList(table.snapshots()).size()); + Assert.assertEquals( + "First 
snapshot should not present.", null, table.snapshot(firstSnapshotId)); + } + + @Test + public void testExpireTwoSnapshotsById() throws Exception { + table + .newAppend() + .appendFile(FILE_A) // data_bucket=0 + .commit(); + long firstSnapshotId = table.currentSnapshot().snapshotId(); + + table + .newAppend() + .appendFile(FILE_B) // data_bucket=1 + .commit(); + + long secondSnapshotID = table.currentSnapshot().snapshotId(); + + table + .newAppend() + .appendFile(FILE_C) // data_bucket=2 + .commit(); + + // Retain last 2 snapshots + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireSnapshotId(firstSnapshotId) + .expireSnapshotId(secondSnapshotID) + .execute(); + + Assert.assertEquals( + "Should have one snapshots.", 1, Lists.newArrayList(table.snapshots()).size()); + Assert.assertEquals( + "First snapshot should not present.", null, table.snapshot(firstSnapshotId)); + Assert.assertEquals( + "Second snapshot should not be present.", null, table.snapshot(secondSnapshotID)); + + checkExpirationResults(0L, 0L, 0L, 0L, 2L, result); + } + + @Test + public void testRetainLastWithExpireById() { + table + .newAppend() + .appendFile(FILE_A) // data_bucket=0 + .commit(); + long firstSnapshotId = table.currentSnapshot().snapshotId(); + + table + .newAppend() + .appendFile(FILE_B) // data_bucket=1 + .commit(); + + table + .newAppend() + .appendFile(FILE_C) // data_bucket=2 + .commit(); + + // Retain last 3 snapshots, but explicitly remove the first snapshot + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireSnapshotId(firstSnapshotId) + .retainLast(3) + .execute(); + + Assert.assertEquals( + "Should have two snapshots.", 2, Lists.newArrayList(table.snapshots()).size()); + Assert.assertEquals( + "First snapshot should not present.", null, table.snapshot(firstSnapshotId)); + checkExpirationResults(0L, 0L, 0L, 0L, 1L, result); + } + + @Test + public void testRetainLastWithTooFewSnapshots() { + table + .newAppend() + .appendFile(FILE_A) // data_bucket=0 + .appendFile(FILE_B) // data_bucket=1 + .commit(); + long firstSnapshotId = table.currentSnapshot().snapshotId(); + + table + .newAppend() + .appendFile(FILE_C) // data_bucket=2 + .commit(); + + long t2 = rightAfterSnapshot(); + + // Retain last 3 snapshots + ExpireSnapshots.Result result = + SparkActions.get().expireSnapshots(table).expireOlderThan(t2).retainLast(3).execute(); + + Assert.assertEquals( + "Should have two snapshots", 2, Lists.newArrayList(table.snapshots()).size()); + Assert.assertEquals( + "First snapshot should still present", + firstSnapshotId, + table.snapshot(firstSnapshotId).snapshotId()); + checkExpirationResults(0L, 0L, 0L, 0L, 0L, result); + } + + @Test + public void testRetainLastKeepsExpiringSnapshot() { + table + .newAppend() + .appendFile(FILE_A) // data_bucket=0 + .commit(); + + table + .newAppend() + .appendFile(FILE_B) // data_bucket=1 + .commit(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + table + .newAppend() + .appendFile(FILE_C) // data_bucket=2 + .commit(); + + table + .newAppend() + .appendFile(FILE_D) // data_bucket=3 + .commit(); + + // Retain last 2 snapshots and expire older than t3 + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(secondSnapshot.timestampMillis()) + .retainLast(2) + .execute(); + + Assert.assertEquals( + "Should have three snapshots.", 3, Lists.newArrayList(table.snapshots()).size()); + Assert.assertNotNull( + "Second snapshot should present.", 
table.snapshot(secondSnapshot.snapshotId())); + checkExpirationResults(0L, 0L, 0L, 0L, 1L, result); + } + + @Test + public void testExpireSnapshotsWithDisabledGarbageCollection() { + table.updateProperties().set(TableProperties.GC_ENABLED, "false").commit(); + + table.newAppend().appendFile(FILE_A).commit(); + + AssertHelpers.assertThrows( + "Should complain about expiring snapshots", + ValidationException.class, + "Cannot expire snapshots: GC is disabled", + () -> SparkActions.get().expireSnapshots(table)); + } + + @Test + public void testExpireOlderThanMultipleCalls() { + table + .newAppend() + .appendFile(FILE_A) // data_bucket=0 + .commit(); + + table + .newAppend() + .appendFile(FILE_B) // data_bucket=1 + .commit(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + table + .newAppend() + .appendFile(FILE_C) // data_bucket=2 + .commit(); + + Snapshot thirdSnapshot = table.currentSnapshot(); + + // Retain last 2 snapshots and expire older than t3 + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(secondSnapshot.timestampMillis()) + .expireOlderThan(thirdSnapshot.timestampMillis()) + .execute(); + + Assert.assertEquals( + "Should have one snapshots.", 1, Lists.newArrayList(table.snapshots()).size()); + Assert.assertNull( + "Second snapshot should not present.", table.snapshot(secondSnapshot.snapshotId())); + checkExpirationResults(0L, 0L, 0L, 0L, 2L, result); + } + + @Test + public void testRetainLastMultipleCalls() { + table + .newAppend() + .appendFile(FILE_A) // data_bucket=0 + .commit(); + + table + .newAppend() + .appendFile(FILE_B) // data_bucket=1 + .commit(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + table + .newAppend() + .appendFile(FILE_C) // data_bucket=2 + .commit(); + + long t3 = rightAfterSnapshot(); + + // Retain last 2 snapshots and expire older than t3 + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(t3) + .retainLast(2) + .retainLast(1) + .execute(); + + Assert.assertEquals( + "Should have one snapshots.", 1, Lists.newArrayList(table.snapshots()).size()); + Assert.assertNull( + "Second snapshot should not present.", table.snapshot(secondSnapshot.snapshotId())); + checkExpirationResults(0L, 0L, 0L, 0L, 2L, result); + } + + @Test + public void testRetainZeroSnapshots() { + AssertHelpers.assertThrows( + "Should fail retain 0 snapshots " + "because number of snapshots to retain cannot be zero", + IllegalArgumentException.class, + "Number of snapshots to retain must be at least 1, cannot be: 0", + () -> SparkActions.get().expireSnapshots(table).retainLast(0).execute()); + } + + @Test + public void testScanExpiredManifestInValidSnapshotAppend() { + table.newAppend().appendFile(FILE_A).appendFile(FILE_B).commit(); + + table.newOverwrite().addFile(FILE_C).deleteFile(FILE_A).commit(); + + table.newAppend().appendFile(FILE_D).commit(); + + long t3 = rightAfterSnapshot(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(t3) + .deleteWith(deletedFiles::add) + .execute(); + + Assert.assertTrue("FILE_A should be deleted", deletedFiles.contains(FILE_A.path().toString())); + checkExpirationResults(1L, 0L, 0L, 1L, 2L, result); + } + + @Test + public void testScanExpiredManifestInValidSnapshotFastAppend() { + table + .updateProperties() + .set(TableProperties.MANIFEST_MERGE_ENABLED, "true") + .set(TableProperties.MANIFEST_MIN_MERGE_COUNT, "1") + .commit(); + + 
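+ // Manifest merging is forced (min merge count 1) so that earlier manifests get rewritten and
+ // become removable once the older snapshots expire; the test expects FILE_A and one manifest
+ // to be deleted.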
table.newAppend().appendFile(FILE_A).appendFile(FILE_B).commit(); + + table.newOverwrite().addFile(FILE_C).deleteFile(FILE_A).commit(); + + table.newFastAppend().appendFile(FILE_D).commit(); + + long t3 = rightAfterSnapshot(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(t3) + .deleteWith(deletedFiles::add) + .execute(); + + Assert.assertTrue("FILE_A should be deleted", deletedFiles.contains(FILE_A.path().toString())); + checkExpirationResults(1L, 0L, 0L, 1L, 2L, result); + } + + /** + * Test on table below, and expiring the staged commit `B` using `expireOlderThan` API. Table: A - + * C ` B (staged) + */ + @Test + public void testWithExpiringDanglingStageCommit() { + // `A` commit + table.newAppend().appendFile(FILE_A).commit(); + + // `B` staged commit + table.newAppend().appendFile(FILE_B).stageOnly().commit(); + + TableMetadata base = ((BaseTable) table).operations().current(); + Snapshot snapshotA = base.snapshots().get(0); + Snapshot snapshotB = base.snapshots().get(1); + + // `C` commit + table.newAppend().appendFile(FILE_C).commit(); + + Set deletedFiles = Sets.newHashSet(); + + // Expire all commits including dangling staged snapshot. + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .deleteWith(deletedFiles::add) + .expireOlderThan(snapshotB.timestampMillis() + 1) + .execute(); + + checkExpirationResults(1L, 0L, 0L, 1L, 2L, result); + + Set expectedDeletes = Sets.newHashSet(); + expectedDeletes.add(snapshotA.manifestListLocation()); + + // Files should be deleted of dangling staged snapshot + snapshotB + .addedDataFiles(table.io()) + .forEach( + i -> { + expectedDeletes.add(i.path().toString()); + }); + + // ManifestList should be deleted too + expectedDeletes.add(snapshotB.manifestListLocation()); + snapshotB + .dataManifests(table.io()) + .forEach( + file -> { + // Only the manifest of B should be deleted. 
+ if (file.snapshotId() == snapshotB.snapshotId()) { + expectedDeletes.add(file.path()); + } + }); + Assert.assertSame( + "Files deleted count should be expected", expectedDeletes.size(), deletedFiles.size()); + // Take the diff + expectedDeletes.removeAll(deletedFiles); + Assert.assertTrue("Exactly same files should be deleted", expectedDeletes.isEmpty()); + } + + /** + * Expire cherry-pick the commit as shown below, when `B` is in table's current state Table: A - B + * - C <--current snapshot `- D (source=B) + */ + @Test + public void testWithCherryPickTableSnapshot() { + // `A` commit + table.newAppend().appendFile(FILE_A).commit(); + Snapshot snapshotA = table.currentSnapshot(); + + // `B` commit + Set deletedAFiles = Sets.newHashSet(); + table.newOverwrite().addFile(FILE_B).deleteFile(FILE_A).deleteWith(deletedAFiles::add).commit(); + Assert.assertTrue("No files should be physically deleted", deletedAFiles.isEmpty()); + + // pick the snapshot 'B` + Snapshot snapshotB = table.currentSnapshot(); + + // `C` commit to let cherry-pick take effect, and avoid fast-forward of `B` with cherry-pick + table.newAppend().appendFile(FILE_C).commit(); + Snapshot snapshotC = table.currentSnapshot(); + + // Move the table back to `A` + table.manageSnapshots().setCurrentSnapshot(snapshotA.snapshotId()).commit(); + + // Generate A -> `D (B)` + table.manageSnapshots().cherrypick(snapshotB.snapshotId()).commit(); + Snapshot snapshotD = table.currentSnapshot(); + + // Move the table back to `C` + table.manageSnapshots().setCurrentSnapshot(snapshotC.snapshotId()).commit(); + List deletedFiles = Lists.newArrayList(); + + // Expire `C` + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .deleteWith(deletedFiles::add) + .expireOlderThan(snapshotC.timestampMillis() + 1) + .execute(); + + // Make sure no dataFiles are deleted for the B, C, D snapshot + Lists.newArrayList(snapshotB, snapshotC, snapshotD) + .forEach( + i -> { + i.addedDataFiles(table.io()) + .forEach( + item -> { + Assert.assertFalse(deletedFiles.contains(item.path().toString())); + }); + }); + + checkExpirationResults(1L, 0L, 0L, 2L, 2L, result); + } + + /** + * Test on table below, and expiring `B` which is not in current table state. 1) Expire `B` 2) All + * commit Table: A - C - D (B) ` B (staged) + */ + @Test + public void testWithExpiringStagedThenCherrypick() { + // `A` commit + table.newAppend().appendFile(FILE_A).commit(); + + // `B` commit + table.newAppend().appendFile(FILE_B).stageOnly().commit(); + + // pick the snapshot that's staged but not committed + TableMetadata base = ((BaseTable) table).operations().current(); + Snapshot snapshotB = base.snapshots().get(1); + + // `C` commit to let cherry-pick take effect, and avoid fast-forward of `B` with cherry-pick + table.newAppend().appendFile(FILE_C).commit(); + + // `D (B)` cherry-pick commit + table.manageSnapshots().cherrypick(snapshotB.snapshotId()).commit(); + + base = ((BaseTable) table).operations().current(); + Snapshot snapshotD = base.snapshots().get(3); + + List deletedFiles = Lists.newArrayList(); + + // Expire `B` commit. 
+ ExpireSnapshots.Result firstResult = + SparkActions.get() + .expireSnapshots(table) + .deleteWith(deletedFiles::add) + .expireSnapshotId(snapshotB.snapshotId()) + .execute(); + + // Make sure no dataFiles are deleted for the staged snapshot + Lists.newArrayList(snapshotB) + .forEach( + i -> { + i.addedDataFiles(table.io()) + .forEach( + item -> { + Assert.assertFalse(deletedFiles.contains(item.path().toString())); + }); + }); + checkExpirationResults(0L, 0L, 0L, 1L, 1L, firstResult); + + // Expire all snapshots including cherry-pick + ExpireSnapshots.Result secondResult = + SparkActions.get() + .expireSnapshots(table) + .deleteWith(deletedFiles::add) + .expireOlderThan(table.currentSnapshot().timestampMillis() + 1) + .execute(); + + // Make sure no dataFiles are deleted for the staged and cherry-pick + Lists.newArrayList(snapshotB, snapshotD) + .forEach( + i -> { + i.addedDataFiles(table.io()) + .forEach( + item -> { + Assert.assertFalse(deletedFiles.contains(item.path().toString())); + }); + }); + checkExpirationResults(0L, 0L, 0L, 0L, 2L, secondResult); + } + + @Test + public void testExpireOlderThan() { + table.newAppend().appendFile(FILE_A).commit(); + + Snapshot firstSnapshot = table.currentSnapshot(); + + rightAfterSnapshot(); + + table.newAppend().appendFile(FILE_B).commit(); + + long snapshotId = table.currentSnapshot().snapshotId(); + + long tAfterCommits = rightAfterSnapshot(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(tAfterCommits) + .deleteWith(deletedFiles::add) + .execute(); + + Assert.assertEquals( + "Expire should not change current snapshot", + snapshotId, + table.currentSnapshot().snapshotId()); + Assert.assertNull( + "Expire should remove the oldest snapshot", table.snapshot(firstSnapshot.snapshotId())); + Assert.assertEquals( + "Should remove only the expired manifest list location", + Sets.newHashSet(firstSnapshot.manifestListLocation()), + deletedFiles); + + checkExpirationResults(0, 0, 0, 0, 1, result); + } + + @Test + public void testExpireOlderThanWithDelete() { + table.newAppend().appendFile(FILE_A).commit(); + + Snapshot firstSnapshot = table.currentSnapshot(); + Assert.assertEquals( + "Should create one manifest", 1, firstSnapshot.allManifests(table.io()).size()); + + rightAfterSnapshot(); + + table.newDelete().deleteFile(FILE_A).commit(); + + Snapshot secondSnapshot = table.currentSnapshot(); + Assert.assertEquals( + "Should create replace manifest with a rewritten manifest", + 1, + secondSnapshot.allManifests(table.io()).size()); + + table.newAppend().appendFile(FILE_B).commit(); + + rightAfterSnapshot(); + + long snapshotId = table.currentSnapshot().snapshotId(); + + long tAfterCommits = rightAfterSnapshot(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(tAfterCommits) + .deleteWith(deletedFiles::add) + .execute(); + + Assert.assertEquals( + "Expire should not change current snapshot", + snapshotId, + table.currentSnapshot().snapshotId()); + Assert.assertNull( + "Expire should remove the oldest snapshot", table.snapshot(firstSnapshot.snapshotId())); + Assert.assertNull( + "Expire should remove the second oldest snapshot", + table.snapshot(secondSnapshot.snapshotId())); + + Assert.assertEquals( + "Should remove expired manifest lists and deleted data file", + Sets.newHashSet( + firstSnapshot.manifestListLocation(), // snapshot expired + firstSnapshot + 
.allManifests(table.io()) + .get(0) + .path(), // manifest was rewritten for delete + secondSnapshot.manifestListLocation(), // snapshot expired + secondSnapshot + .allManifests(table.io()) + .get(0) + .path(), // manifest contained only deletes, was dropped + FILE_A.path()), // deleted + deletedFiles); + + checkExpirationResults(1, 0, 0, 2, 2, result); + } + + @Test + public void testExpireOlderThanWithDeleteInMergedManifests() { + // merge every commit + table.updateProperties().set(TableProperties.MANIFEST_MIN_MERGE_COUNT, "0").commit(); + + table.newAppend().appendFile(FILE_A).appendFile(FILE_B).commit(); + + Snapshot firstSnapshot = table.currentSnapshot(); + Assert.assertEquals( + "Should create one manifest", 1, firstSnapshot.allManifests(table.io()).size()); + + rightAfterSnapshot(); + + table + .newDelete() + .deleteFile(FILE_A) // FILE_B is still in the dataset + .commit(); + + Snapshot secondSnapshot = table.currentSnapshot(); + Assert.assertEquals( + "Should replace manifest with a rewritten manifest", + 1, + secondSnapshot.allManifests(table.io()).size()); + + table + .newFastAppend() // do not merge to keep the last snapshot's manifest valid + .appendFile(FILE_C) + .commit(); + + rightAfterSnapshot(); + + long snapshotId = table.currentSnapshot().snapshotId(); + + long tAfterCommits = rightAfterSnapshot(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(tAfterCommits) + .deleteWith(deletedFiles::add) + .execute(); + + Assert.assertEquals( + "Expire should not change current snapshot", + snapshotId, + table.currentSnapshot().snapshotId()); + Assert.assertNull( + "Expire should remove the oldest snapshot", table.snapshot(firstSnapshot.snapshotId())); + Assert.assertNull( + "Expire should remove the second oldest snapshot", + table.snapshot(secondSnapshot.snapshotId())); + + Assert.assertEquals( + "Should remove expired manifest lists and deleted data file", + Sets.newHashSet( + firstSnapshot.manifestListLocation(), // snapshot expired + firstSnapshot + .allManifests(table.io()) + .get(0) + .path(), // manifest was rewritten for delete + secondSnapshot.manifestListLocation(), // snapshot expired + FILE_A.path()), // deleted + deletedFiles); + + checkExpirationResults(1, 0, 0, 1, 2, result); + } + + @Test + public void testExpireOlderThanWithRollback() { + // merge every commit + table.updateProperties().set(TableProperties.MANIFEST_MIN_MERGE_COUNT, "0").commit(); + + table.newAppend().appendFile(FILE_A).appendFile(FILE_B).commit(); + + Snapshot firstSnapshot = table.currentSnapshot(); + Assert.assertEquals( + "Should create one manifest", 1, firstSnapshot.allManifests(table.io()).size()); + + rightAfterSnapshot(); + + table.newDelete().deleteFile(FILE_B).commit(); + + Snapshot secondSnapshot = table.currentSnapshot(); + Set secondSnapshotManifests = + Sets.newHashSet(secondSnapshot.allManifests(table.io())); + secondSnapshotManifests.removeAll(firstSnapshot.allManifests(table.io())); + Assert.assertEquals( + "Should add one new manifest for append", 1, secondSnapshotManifests.size()); + + table.manageSnapshots().rollbackTo(firstSnapshot.snapshotId()).commit(); + + long tAfterCommits = rightAfterSnapshot(secondSnapshot.snapshotId()); + + long snapshotId = table.currentSnapshot().snapshotId(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(tAfterCommits) + .deleteWith(deletedFiles::add) + 
.execute(); + + Assert.assertEquals( + "Expire should not change current snapshot", + snapshotId, + table.currentSnapshot().snapshotId()); + Assert.assertNotNull( + "Expire should keep the oldest snapshot, current", + table.snapshot(firstSnapshot.snapshotId())); + Assert.assertNull( + "Expire should remove the orphaned snapshot", table.snapshot(secondSnapshot.snapshotId())); + + Assert.assertEquals( + "Should remove expired manifest lists and reverted appended data file", + Sets.newHashSet( + secondSnapshot.manifestListLocation(), // snapshot expired + Iterables.getOnlyElement(secondSnapshotManifests) + .path()), // manifest is no longer referenced + deletedFiles); + + checkExpirationResults(0, 0, 0, 1, 1, result); + } + + @Test + public void testExpireOlderThanWithRollbackAndMergedManifests() { + table.newAppend().appendFile(FILE_A).commit(); + + Snapshot firstSnapshot = table.currentSnapshot(); + Assert.assertEquals( + "Should create one manifest", 1, firstSnapshot.allManifests(table.io()).size()); + + rightAfterSnapshot(); + + table.newAppend().appendFile(FILE_B).commit(); + + Snapshot secondSnapshot = table.currentSnapshot(); + Set secondSnapshotManifests = + Sets.newHashSet(secondSnapshot.allManifests(table.io())); + secondSnapshotManifests.removeAll(firstSnapshot.allManifests(table.io())); + Assert.assertEquals( + "Should add one new manifest for append", 1, secondSnapshotManifests.size()); + + table.manageSnapshots().rollbackTo(firstSnapshot.snapshotId()).commit(); + + long tAfterCommits = rightAfterSnapshot(secondSnapshot.snapshotId()); + + long snapshotId = table.currentSnapshot().snapshotId(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(tAfterCommits) + .deleteWith(deletedFiles::add) + .execute(); + + Assert.assertEquals( + "Expire should not change current snapshot", + snapshotId, + table.currentSnapshot().snapshotId()); + Assert.assertNotNull( + "Expire should keep the oldest snapshot, current", + table.snapshot(firstSnapshot.snapshotId())); + Assert.assertNull( + "Expire should remove the orphaned snapshot", table.snapshot(secondSnapshot.snapshotId())); + + Assert.assertEquals( + "Should remove expired manifest lists and reverted appended data file", + Sets.newHashSet( + secondSnapshot.manifestListLocation(), // snapshot expired + Iterables.getOnlyElement(secondSnapshotManifests) + .path(), // manifest is no longer referenced + FILE_B.path()), // added, but rolled back + deletedFiles); + + checkExpirationResults(1, 0, 0, 1, 1, result); + } + + @Test + public void testExpireOlderThanWithDeleteFile() { + table + .updateProperties() + .set(TableProperties.FORMAT_VERSION, "2") + .set(TableProperties.MANIFEST_MERGE_ENABLED, "false") + .commit(); + + // Add Data File + table.newAppend().appendFile(FILE_A).commit(); + Snapshot firstSnapshot = table.currentSnapshot(); + + // Add POS Delete + table.newRowDelta().addDeletes(FILE_A_POS_DELETES).commit(); + Snapshot secondSnapshot = table.currentSnapshot(); + + // Add EQ Delete + table.newRowDelta().addDeletes(FILE_A_EQ_DELETES).commit(); + Snapshot thirdSnapshot = table.currentSnapshot(); + + // Move files to DELETED + table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit(); + Snapshot fourthSnapshot = table.currentSnapshot(); + + long afterAllDeleted = rightAfterSnapshot(); + + table.newAppend().appendFile(FILE_B).commit(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshots.Result result = + SparkActions.get() 
+ .expireSnapshots(table) + .expireOlderThan(afterAllDeleted) + .deleteWith(deletedFiles::add) + .execute(); + + Set expectedDeletes = + Sets.newHashSet( + firstSnapshot.manifestListLocation(), + secondSnapshot.manifestListLocation(), + thirdSnapshot.manifestListLocation(), + fourthSnapshot.manifestListLocation(), + FILE_A.path().toString(), + FILE_A_POS_DELETES.path().toString(), + FILE_A_EQ_DELETES.path().toString()); + + expectedDeletes.addAll( + thirdSnapshot.allManifests(table.io()).stream() + .map(ManifestFile::path) + .map(CharSequence::toString) + .collect(Collectors.toSet())); + // Delete operation (fourth snapshot) generates new manifest files + expectedDeletes.addAll( + fourthSnapshot.allManifests(table.io()).stream() + .map(ManifestFile::path) + .map(CharSequence::toString) + .collect(Collectors.toSet())); + + Assert.assertEquals( + "Should remove expired manifest lists and deleted data file", + expectedDeletes, + deletedFiles); + + checkExpirationResults(1, 1, 1, 6, 4, result); + } + + @Test + public void testExpireOnEmptyTable() { + Set deletedFiles = Sets.newHashSet(); + + // table has no data, testing ExpireSnapshots should not fail with no snapshot + ExpireSnapshots.Result result = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(System.currentTimeMillis()) + .deleteWith(deletedFiles::add) + .execute(); + + checkExpirationResults(0, 0, 0, 0, 0, result); + } + + @Test + public void testExpireAction() { + table.newAppend().appendFile(FILE_A).commit(); + + Snapshot firstSnapshot = table.currentSnapshot(); + + rightAfterSnapshot(); + + table.newAppend().appendFile(FILE_B).commit(); + + long snapshotId = table.currentSnapshot().snapshotId(); + + long tAfterCommits = rightAfterSnapshot(); + + Set deletedFiles = Sets.newHashSet(); + + ExpireSnapshotsSparkAction action = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(tAfterCommits) + .deleteWith(deletedFiles::add); + Dataset pendingDeletes = action.expireFiles(); + + List pending = pendingDeletes.collectAsList(); + + Assert.assertEquals( + "Should not change current snapshot", snapshotId, table.currentSnapshot().snapshotId()); + Assert.assertNull( + "Should remove the oldest snapshot", table.snapshot(firstSnapshot.snapshotId())); + + Assert.assertEquals("Pending deletes should contain one row", 1, pending.size()); + Assert.assertEquals( + "Pending delete should be the expired manifest list location", + firstSnapshot.manifestListLocation(), + pending.get(0).getPath()); + Assert.assertEquals( + "Pending delete should be a manifest list", "Manifest List", pending.get(0).getType()); + + Assert.assertEquals("Should not delete any files", 0, deletedFiles.size()); + + Assert.assertEquals( + "Multiple calls to expire should return the same count of deleted files", + pendingDeletes.count(), + action.expireFiles().count()); + } + + @Test + public void testUseLocalIterator() { + table.newFastAppend().appendFile(FILE_A).commit(); + + table.newOverwrite().deleteFile(FILE_A).addFile(FILE_B).commit(); + + table.newFastAppend().appendFile(FILE_C).commit(); + + long end = rightAfterSnapshot(); + + int jobsBeforeStreamResults = spark.sparkContext().dagScheduler().nextJobId().get(); + + withSQLConf( + ImmutableMap.of("spark.sql.adaptive.enabled", "false"), + () -> { + ExpireSnapshots.Result results = + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(end) + .option("stream-results", "true") + .execute(); + + int jobsAfterStreamResults = spark.sparkContext().dagScheduler().nextJobId().get(); + int 
jobsRunDuringStreamResults = jobsAfterStreamResults - jobsBeforeStreamResults; + + checkExpirationResults(1L, 0L, 0L, 1L, 2L, results); + + Assert.assertEquals( + "Expected total number of jobs with stream-results should match the expected number", + 4L, + jobsRunDuringStreamResults); + }); + } + + @Test + public void testExpireAfterExecute() { + table + .newAppend() + .appendFile(FILE_A) // data_bucket=0 + .commit(); + + rightAfterSnapshot(); + + table + .newAppend() + .appendFile(FILE_B) // data_bucket=1 + .commit(); + + table + .newAppend() + .appendFile(FILE_C) // data_bucket=2 + .commit(); + + long t3 = rightAfterSnapshot(); + + ExpireSnapshotsSparkAction action = SparkActions.get().expireSnapshots(table); + + action.expireOlderThan(t3).retainLast(2); + + ExpireSnapshots.Result result = action.execute(); + checkExpirationResults(0L, 0L, 0L, 0L, 1L, result); + + List typedExpiredFiles = action.expireFiles().collectAsList(); + Assert.assertEquals("Expired results must match", 1, typedExpiredFiles.size()); + + List untypedExpiredFiles = action.expireFiles().collectAsList(); + Assert.assertEquals("Expired results must match", 1, untypedExpiredFiles.size()); + } + + @Test + public void testExpireFileDeletionMostExpired() { + textExpireAllCheckFilesDeleted(5, 2); + } + + @Test + public void testExpireFileDeletionMostRetained() { + textExpireAllCheckFilesDeleted(2, 5); + } + + public void textExpireAllCheckFilesDeleted(int dataFilesExpired, int dataFilesRetained) { + // Add data files to be expired + Set dataFiles = Sets.newHashSet(); + for (int i = 0; i < dataFilesExpired; i++) { + DataFile df = + DataFiles.builder(SPEC) + .withPath(String.format("/path/to/data-expired-%d.parquet", i)) + .withFileSizeInBytes(10) + .withPartitionPath("c1=1") + .withRecordCount(1) + .build(); + dataFiles.add(df.path().toString()); + table.newFastAppend().appendFile(df).commit(); + } + + // Delete them all, these will be deleted on expire snapshot + table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit(); + // Clears "DELETED" manifests + table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit(); + + Set manifestsBefore = TestHelpers.reachableManifestPaths(table); + + // Add data files to be retained, which are not deleted. 
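+ // (they remain reachable from the current snapshot, so expiration must leave them in place)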
+ for (int i = 0; i < dataFilesRetained; i++) { + DataFile df = + DataFiles.builder(SPEC) + .withPath(String.format("/path/to/data-retained-%d.parquet", i)) + .withFileSizeInBytes(10) + .withPartitionPath("c1=1") + .withRecordCount(1) + .build(); + table.newFastAppend().appendFile(df).commit(); + } + + long end = rightAfterSnapshot(); + + Set expectedDeletes = Sets.newHashSet(); + expectedDeletes.addAll(ReachableFileUtil.manifestListLocations(table)); + // all snapshot manifest lists except current will be deleted + expectedDeletes.remove(table.currentSnapshot().manifestListLocation()); + expectedDeletes.addAll( + manifestsBefore); // new manifests are reachable from current snapshot and not deleted + expectedDeletes.addAll( + dataFiles); // new data files are reachable from current snapshot and not deleted + + Set deletedFiles = Sets.newHashSet(); + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(end) + .deleteWith(deletedFiles::add) + .execute(); + + Assert.assertEquals( + "All reachable files before expiration should be deleted", expectedDeletes, deletedFiles); + } + + @Test + public void testExpireSomeCheckFilesDeleted() { + + table.newAppend().appendFile(FILE_A).commit(); + + table.newAppend().appendFile(FILE_B).commit(); + + table.newAppend().appendFile(FILE_C).commit(); + + table.newDelete().deleteFile(FILE_A).commit(); + + long after = rightAfterSnapshot(); + waitUntilAfter(after); + + table.newAppend().appendFile(FILE_D).commit(); + + table.newDelete().deleteFile(FILE_B).commit(); + + Set deletedFiles = Sets.newHashSet(); + SparkActions.get() + .expireSnapshots(table) + .expireOlderThan(after) + .deleteWith(deletedFiles::add) + .execute(); + + // C, D should be retained (live) + // B should be retained (previous snapshot points to it) + // A should be deleted + Assert.assertTrue(deletedFiles.contains(FILE_A.path().toString())); + Assert.assertFalse(deletedFiles.contains(FILE_B.path().toString())); + Assert.assertFalse(deletedFiles.contains(FILE_C.path().toString())); + Assert.assertFalse(deletedFiles.contains(FILE_D.path().toString())); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction.java new file mode 100644 index 000000000000..536dd5febbaa --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction.java @@ -0,0 +1,1091 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.types.Types.NestedField.optional; + +import java.io.File; +import java.io.IOException; +import java.net.URI; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.sql.Timestamp; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Files; +import org.apache.iceberg.GenericBlobMetadata; +import org.apache.iceberg.GenericStatisticsFile; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.StatisticsFile; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.Transaction; +import org.apache.iceberg.actions.DeleteOrphanFiles; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.HadoopCatalog; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.hadoop.HiddenPathFilter; +import org.apache.iceberg.puffin.Blob; +import org.apache.iceberg.puffin.Puffin; +import org.apache.iceberg.puffin.PuffinWriter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.spark.actions.DeleteOrphanFilesSparkAction.StringToFileURI; +import org.apache.iceberg.spark.source.FilePathLastModifiedRecord; +import org.apache.iceberg.spark.source.ThreeColumnRecord; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public abstract class TestRemoveOrphanFilesAction extends SparkTestBase { + + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + protected static final Schema SCHEMA = + new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + protected static final PartitionSpec SPEC = + PartitionSpec.builderFor(SCHEMA).truncate("c2", 2).identity("c3").build(); + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + private File tableDir = null; + 
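+ // Shared fixture for the orphan-file tests: the class is abstract so concrete subclasses can
+ // reuse these scenarios; each test writes to a fresh location created in setupTableLocation().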
protected String tableLocation = null; + + @Before + public void setupTableLocation() throws Exception { + this.tableDir = temp.newFolder(); + this.tableLocation = tableDir.toURI().toString(); + } + + @Test + public void testDryRun() throws IOException, InterruptedException { + Table table = + TABLES.create(SCHEMA, PartitionSpec.unpartitioned(), Maps.newHashMap(), tableLocation); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + List validFiles = + spark + .read() + .format("iceberg") + .load(tableLocation + "#files") + .select("file_path") + .as(Encoders.STRING()) + .collectAsList(); + Assert.assertEquals("Should be 2 valid files", 2, validFiles.size()); + + df.write().mode("append").parquet(tableLocation + "/data"); + + Path dataPath = new Path(tableLocation + "/data"); + FileSystem fs = dataPath.getFileSystem(spark.sessionState().newHadoopConf()); + List allFiles = + Arrays.stream(fs.listStatus(dataPath, HiddenPathFilter.get())) + .filter(FileStatus::isFile) + .map(file -> file.getPath().toString()) + .collect(Collectors.toList()); + Assert.assertEquals("Should be 3 files", 3, allFiles.size()); + + List invalidFiles = Lists.newArrayList(allFiles); + invalidFiles.removeAll(validFiles); + Assert.assertEquals("Should be 1 invalid file", 1, invalidFiles.size()); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result1 = + actions.deleteOrphanFiles(table).deleteWith(s -> {}).execute(); + Assert.assertTrue( + "Default olderThan interval should be safe", + Iterables.isEmpty(result1.orphanFileLocations())); + + DeleteOrphanFiles.Result result2 = + actions + .deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis()) + .deleteWith(s -> {}) + .execute(); + Assert.assertEquals("Action should find 1 file", invalidFiles, result2.orphanFileLocations()); + Assert.assertTrue("Invalid file should be present", fs.exists(new Path(invalidFiles.get(0)))); + + DeleteOrphanFiles.Result result3 = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + Assert.assertEquals("Action should delete 1 file", invalidFiles, result3.orphanFileLocations()); + Assert.assertFalse( + "Invalid file should not be present", fs.exists(new Path(invalidFiles.get(0)))); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(records); + expectedRecords.addAll(records); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + Assert.assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testAllValidFilesAreKept() throws IOException, InterruptedException { + Table table = TABLES.create(SCHEMA, SPEC, Maps.newHashMap(), tableLocation); + + List records1 = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + Dataset df1 = spark.createDataFrame(records1, ThreeColumnRecord.class).coalesce(1); + + // original append + df1.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + List records2 = + Lists.newArrayList(new ThreeColumnRecord(2, "AAAAAAAAAA", "AAAA")); + Dataset df2 = 
spark.createDataFrame(records2, ThreeColumnRecord.class).coalesce(1); + + // dynamic partition overwrite + df2.select("c1", "c2", "c3").write().format("iceberg").mode("overwrite").save(tableLocation); + + // second append + df2.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + List snapshots = Lists.newArrayList(table.snapshots()); + + List snapshotFiles1 = snapshotFiles(snapshots.get(0).snapshotId()); + Assert.assertEquals(1, snapshotFiles1.size()); + + List snapshotFiles2 = snapshotFiles(snapshots.get(1).snapshotId()); + Assert.assertEquals(1, snapshotFiles2.size()); + + List snapshotFiles3 = snapshotFiles(snapshots.get(2).snapshotId()); + Assert.assertEquals(2, snapshotFiles3.size()); + + df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data"); + df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data/c2_trunc=AA"); + df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data/c2_trunc=AA/c3=AAAA"); + df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data/invalid/invalid"); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + Assert.assertEquals("Should delete 4 files", 4, Iterables.size(result.orphanFileLocations())); + + Path dataPath = new Path(tableLocation + "/data"); + FileSystem fs = dataPath.getFileSystem(spark.sessionState().newHadoopConf()); + + for (String fileLocation : snapshotFiles1) { + Assert.assertTrue("All snapshot files must remain", fs.exists(new Path(fileLocation))); + } + + for (String fileLocation : snapshotFiles2) { + Assert.assertTrue("All snapshot files must remain", fs.exists(new Path(fileLocation))); + } + + for (String fileLocation : snapshotFiles3) { + Assert.assertTrue("All snapshot files must remain", fs.exists(new Path(fileLocation))); + } + } + + @Test + public void orphanedFileRemovedWithParallelTasks() throws InterruptedException, IOException { + Table table = TABLES.create(SCHEMA, SPEC, Maps.newHashMap(), tableLocation); + + List records1 = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + Dataset df1 = spark.createDataFrame(records1, ThreeColumnRecord.class).coalesce(1); + + // original append + df1.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + List records2 = + Lists.newArrayList(new ThreeColumnRecord(2, "AAAAAAAAAA", "AAAA")); + Dataset df2 = spark.createDataFrame(records2, ThreeColumnRecord.class).coalesce(1); + + // dynamic partition overwrite + df2.select("c1", "c2", "c3").write().format("iceberg").mode("overwrite").save(tableLocation); + + // second append + df2.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data"); + df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data/c2_trunc=AA"); + df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data/c2_trunc=AA/c3=AAAA"); + df2.coalesce(1).write().mode("append").parquet(tableLocation + "/data/invalid/invalid"); + + waitUntilAfter(System.currentTimeMillis()); + + Set deletedFiles = Sets.newHashSet(); + Set deleteThreads = ConcurrentHashMap.newKeySet(); + AtomicInteger deleteThreadsIndex = new AtomicInteger(0); + + ExecutorService executorService = + Executors.newFixedThreadPool( + 4, + runnable -> { + Thread thread = new Thread(runnable); + 
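+ // Name and daemonize each worker so the assertion further down can check that deletes ran on
+ // this executor's threads ("remove-orphan-0" .. "remove-orphan-3").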
thread.setName("remove-orphan-" + deleteThreadsIndex.getAndIncrement()); + thread.setDaemon(true); + return thread; + }); + + DeleteOrphanFiles.Result result = + SparkActions.get() + .deleteOrphanFiles(table) + .executeDeleteWith(executorService) + .olderThan(System.currentTimeMillis() + 5000) // Ensure all orphan files are selected + .deleteWith( + file -> { + deleteThreads.add(Thread.currentThread().getName()); + deletedFiles.add(file); + }) + .execute(); + + // Verifies that the delete methods ran in the threads created by the provided ExecutorService + // ThreadFactory + Assert.assertEquals( + deleteThreads, + Sets.newHashSet( + "remove-orphan-0", "remove-orphan-1", "remove-orphan-2", "remove-orphan-3")); + + Assert.assertEquals("Should delete 4 files", 4, deletedFiles.size()); + } + + @Test + public void testWapFilesAreKept() throws InterruptedException { + Map props = Maps.newHashMap(); + props.put(TableProperties.WRITE_AUDIT_PUBLISH_ENABLED, "true"); + Table table = TABLES.create(SCHEMA, SPEC, props, tableLocation); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + + // normal write + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + spark.conf().set(SparkSQLProperties.WAP_ID, "1"); + + // wap write + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + Assert.assertEquals("Should not return data from the staged snapshot", records, actualRecords); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + Assert.assertTrue( + "Should not delete any files", Iterables.isEmpty(result.orphanFileLocations())); + } + + @Test + public void testMetadataFolderIsIntact() throws InterruptedException { + // write data directly to the table location + Map props = Maps.newHashMap(); + props.put(TableProperties.WRITE_DATA_LOCATION, tableLocation); + Table table = TABLES.create(SCHEMA, SPEC, props, tableLocation); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + df.write().mode("append").parquet(tableLocation + "/c2_trunc=AA/c3=AAAA"); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + Assert.assertEquals("Should delete 1 file", 1, Iterables.size(result.orphanFileLocations())); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + Assert.assertEquals("Rows must match", records, actualRecords); + } + + @Test + public void testOlderThanTimestamp() throws InterruptedException { + Table table = TABLES.create(SCHEMA, SPEC, Maps.newHashMap(), tableLocation); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + Dataset df = 
spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + df.write().mode("append").parquet(tableLocation + "/data/c2_trunc=AA/c3=AAAA"); + df.write().mode("append").parquet(tableLocation + "/data/c2_trunc=AA/c3=AAAA"); + + waitUntilAfter(System.currentTimeMillis()); + + long timestamp = System.currentTimeMillis(); + + waitUntilAfter(System.currentTimeMillis() + 1000L); + + df.write().mode("append").parquet(tableLocation + "/data/c2_trunc=AA/c3=AAAA"); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = + actions.deleteOrphanFiles(table).olderThan(timestamp).execute(); + + Assert.assertEquals( + "Should delete only 2 files", 2, Iterables.size(result.orphanFileLocations())); + } + + @Test + public void testRemoveUnreachableMetadataVersionFiles() throws InterruptedException { + Map props = Maps.newHashMap(); + props.put(TableProperties.WRITE_DATA_LOCATION, tableLocation); + props.put(TableProperties.METADATA_PREVIOUS_VERSIONS_MAX, "1"); + Table table = TABLES.create(SCHEMA, SPEC, props, tableLocation); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + Assert.assertEquals("Should delete 1 file", 1, Iterables.size(result.orphanFileLocations())); + Assert.assertTrue( + "Should remove v1 file", + StreamSupport.stream(result.orphanFileLocations().spliterator(), false) + .anyMatch(file -> file.contains("v1.metadata.json"))); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(records); + expectedRecords.addAll(records); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + Assert.assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testManyTopLevelPartitions() throws InterruptedException { + Table table = TABLES.create(SCHEMA, SPEC, Maps.newHashMap(), tableLocation); + + List records = Lists.newArrayList(); + for (int i = 0; i < 100; i++) { + records.add(new ThreeColumnRecord(i, String.valueOf(i), String.valueOf(i))); + } + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + Assert.assertTrue( + "Should not delete any files", Iterables.isEmpty(result.orphanFileLocations())); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + Assert.assertEquals("Rows must match", records, actualRecords); + } + + @Test + public void testManyLeafPartitions() throws InterruptedException { + Table table = TABLES.create(SCHEMA, SPEC, 
Maps.newHashMap(), tableLocation); + + List records = Lists.newArrayList(); + for (int i = 0; i < 100; i++) { + records.add(new ThreeColumnRecord(i, String.valueOf(i % 3), String.valueOf(i))); + } + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + Assert.assertTrue( + "Should not delete any files", Iterables.isEmpty(result.orphanFileLocations())); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + Assert.assertEquals("Rows must match", records, actualRecords); + } + + @Test + public void testHiddenPartitionPaths() throws InterruptedException { + Schema schema = + new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "_c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + PartitionSpec spec = PartitionSpec.builderFor(schema).truncate("_c2", 2).identity("c3").build(); + Table table = TABLES.create(schema, spec, Maps.newHashMap(), tableLocation); + + StructType structType = + new StructType() + .add("c1", DataTypes.IntegerType) + .add("_c2", DataTypes.StringType) + .add("c3", DataTypes.StringType); + List records = Lists.newArrayList(RowFactory.create(1, "AAAAAAAAAA", "AAAA")); + Dataset df = spark.createDataFrame(records, structType).coalesce(1); + + df.select("c1", "_c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + df.write().mode("append").parquet(tableLocation + "/data/_c2_trunc=AA/c3=AAAA"); + df.write().mode("append").parquet(tableLocation + "/data/_c2_trunc=AA/c3=AAAA"); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + Assert.assertEquals("Should delete 2 files", 2, Iterables.size(result.orphanFileLocations())); + } + + @Test + public void testHiddenPartitionPathsWithPartitionEvolution() throws InterruptedException { + Schema schema = + new Schema( + optional(1, "_c1", Types.IntegerType.get()), + optional(2, "_c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + PartitionSpec spec = PartitionSpec.builderFor(schema).truncate("_c2", 2).build(); + Table table = TABLES.create(schema, spec, Maps.newHashMap(), tableLocation); + + StructType structType = + new StructType() + .add("_c1", DataTypes.IntegerType) + .add("_c2", DataTypes.StringType) + .add("c3", DataTypes.StringType); + List records = Lists.newArrayList(RowFactory.create(1, "AAAAAAAAAA", "AAAA")); + Dataset df = spark.createDataFrame(records, structType).coalesce(1); + + df.select("_c1", "_c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + df.write().mode("append").parquet(tableLocation + "/data/_c2_trunc=AA"); + + table.updateSpec().addField("_c1").commit(); + + df.write().mode("append").parquet(tableLocation + "/data/_c2_trunc=AA/_c1=1"); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + Assert.assertEquals("Should delete 2 
files", 2, Iterables.size(result.orphanFileLocations())); + } + + @Test + public void testHiddenPathsStartingWithPartitionNamesAreIgnored() + throws InterruptedException, IOException { + Schema schema = + new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "_c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + PartitionSpec spec = PartitionSpec.builderFor(schema).truncate("_c2", 2).identity("c3").build(); + Table table = TABLES.create(schema, spec, Maps.newHashMap(), tableLocation); + + StructType structType = + new StructType() + .add("c1", DataTypes.IntegerType) + .add("_c2", DataTypes.StringType) + .add("c3", DataTypes.StringType); + List records = Lists.newArrayList(RowFactory.create(1, "AAAAAAAAAA", "AAAA")); + Dataset df = spark.createDataFrame(records, structType).coalesce(1); + + df.select("c1", "_c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + Path dataPath = new Path(tableLocation + "/data"); + FileSystem fs = dataPath.getFileSystem(spark.sessionState().newHadoopConf()); + Path pathToFileInHiddenFolder = new Path(dataPath, "_c2_trunc/file.txt"); + fs.createNewFile(pathToFileInHiddenFolder); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + Assert.assertEquals("Should delete 0 files", 0, Iterables.size(result.orphanFileLocations())); + Assert.assertTrue(fs.exists(pathToFileInHiddenFolder)); + } + + private List snapshotFiles(long snapshotId) { + return spark + .read() + .format("iceberg") + .option("snapshot-id", snapshotId) + .load(tableLocation + "#files") + .select("file_path") + .as(Encoders.STRING()) + .collectAsList(); + } + + @Test + public void testRemoveOrphanFilesWithRelativeFilePath() throws IOException, InterruptedException { + Table table = + TABLES.create( + SCHEMA, PartitionSpec.unpartitioned(), Maps.newHashMap(), tableDir.getAbsolutePath()); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save(tableDir.getAbsolutePath()); + + List validFiles = + spark + .read() + .format("iceberg") + .load(tableLocation + "#files") + .select("file_path") + .as(Encoders.STRING()) + .collectAsList(); + Assert.assertEquals("Should be 1 valid files", 1, validFiles.size()); + String validFile = validFiles.get(0); + + df.write().mode("append").parquet(tableLocation + "/data"); + + Path dataPath = new Path(tableLocation + "/data"); + FileSystem fs = dataPath.getFileSystem(spark.sessionState().newHadoopConf()); + List allFiles = + Arrays.stream(fs.listStatus(dataPath, HiddenPathFilter.get())) + .filter(FileStatus::isFile) + .map(file -> file.getPath().toString()) + .collect(Collectors.toList()); + Assert.assertEquals("Should be 2 files", 2, allFiles.size()); + + List invalidFiles = Lists.newArrayList(allFiles); + invalidFiles.removeIf(file -> file.contains(validFile)); + Assert.assertEquals("Should be 1 invalid file", 1, invalidFiles.size()); + + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + DeleteOrphanFiles.Result result = + actions + .deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis()) + .deleteWith(s -> {}) + .execute(); + Assert.assertEquals("Action should find 1 file", 
invalidFiles, result.orphanFileLocations()); + Assert.assertTrue("Invalid file should be present", fs.exists(new Path(invalidFiles.get(0)))); + } + + @Test + public void testRemoveOrphanFilesWithHadoopCatalog() throws InterruptedException { + HadoopCatalog catalog = new HadoopCatalog(new Configuration(), tableLocation); + String namespaceName = "testDb"; + String tableName = "testTb"; + + Namespace namespace = Namespace.of(namespaceName); + TableIdentifier tableIdentifier = TableIdentifier.of(namespace, tableName); + Table table = + catalog.createTable( + tableIdentifier, SCHEMA, PartitionSpec.unpartitioned(), Maps.newHashMap()); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(table.location()); + + df.write().mode("append").parquet(table.location() + "/data"); + + waitUntilAfter(System.currentTimeMillis()); + + table.refresh(); + + DeleteOrphanFiles.Result result = + SparkActions.get().deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + + Assert.assertEquals( + "Should delete only 1 files", 1, Iterables.size(result.orphanFileLocations())); + + Dataset resultDF = spark.read().format("iceberg").load(table.location()); + List actualRecords = + resultDF.as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + Assert.assertEquals("Rows must match", records, actualRecords); + } + + @Test + public void testHiveCatalogTable() throws IOException { + Table table = + catalog.createTable( + TableIdentifier.of("default", "hivetestorphan"), + SCHEMA, + SPEC, + tableLocation, + Maps.newHashMap()); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .mode("append") + .save("default.hivetestorphan"); + + String location = table.location().replaceFirst("file:", ""); + new File(location + "/data/trashfile").createNewFile(); + + DeleteOrphanFiles.Result result = + SparkActions.get() + .deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis() + 1000) + .execute(); + Assert.assertTrue( + "trash file should be removed", + StreamSupport.stream(result.orphanFileLocations().spliterator(), false) + .anyMatch(file -> file.contains("file:" + location + "/data/trashfile"))); + } + + @Test + public void testGarbageCollectionDisabled() { + Table table = + TABLES.create(SCHEMA, PartitionSpec.unpartitioned(), Maps.newHashMap(), tableLocation); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + table.updateProperties().set(TableProperties.GC_ENABLED, "false").commit(); + + AssertHelpers.assertThrows( + "Should complain about removing orphan files", + ValidationException.class, + "Cannot delete orphan files: GC is disabled", + () -> SparkActions.get().deleteOrphanFiles(table).execute()); + } + + @Test + public void testCompareToFileList() throws IOException, InterruptedException { + Table table = + TABLES.create(SCHEMA, PartitionSpec.unpartitioned(), Maps.newHashMap(), tableLocation); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + + Dataset 
df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + Path dataPath = new Path(tableLocation + "/data"); + FileSystem fs = dataPath.getFileSystem(spark.sessionState().newHadoopConf()); + List validFiles = + Arrays.stream(fs.listStatus(dataPath, HiddenPathFilter.get())) + .filter(FileStatus::isFile) + .map( + file -> + new FilePathLastModifiedRecord( + file.getPath().toString(), new Timestamp(file.getModificationTime()))) + .collect(Collectors.toList()); + + Assert.assertEquals("Should be 2 valid files", 2, validFiles.size()); + + df.write().mode("append").parquet(tableLocation + "/data"); + + List allFiles = + Arrays.stream(fs.listStatus(dataPath, HiddenPathFilter.get())) + .filter(FileStatus::isFile) + .map( + file -> + new FilePathLastModifiedRecord( + file.getPath().toString(), new Timestamp(file.getModificationTime()))) + .collect(Collectors.toList()); + + Assert.assertEquals("Should be 3 files", 3, allFiles.size()); + + List invalidFiles = Lists.newArrayList(allFiles); + invalidFiles.removeAll(validFiles); + List invalidFilePaths = + invalidFiles.stream() + .map(FilePathLastModifiedRecord::getFilePath) + .collect(Collectors.toList()); + Assert.assertEquals("Should be 1 invalid file", 1, invalidFiles.size()); + + // sleep for 1 second to ensure files will be old enough + waitUntilAfter(System.currentTimeMillis()); + + SparkActions actions = SparkActions.get(); + + Dataset compareToFileList = + spark + .createDataFrame(allFiles, FilePathLastModifiedRecord.class) + .withColumnRenamed("filePath", "file_path") + .withColumnRenamed("lastModified", "last_modified"); + + DeleteOrphanFiles.Result result1 = + actions + .deleteOrphanFiles(table) + .compareToFileList(compareToFileList) + .deleteWith(s -> {}) + .execute(); + Assert.assertTrue( + "Default olderThan interval should be safe", + Iterables.isEmpty(result1.orphanFileLocations())); + + DeleteOrphanFiles.Result result2 = + actions + .deleteOrphanFiles(table) + .compareToFileList(compareToFileList) + .olderThan(System.currentTimeMillis()) + .deleteWith(s -> {}) + .execute(); + Assert.assertEquals( + "Action should find 1 file", invalidFilePaths, result2.orphanFileLocations()); + Assert.assertTrue( + "Invalid file should be present", fs.exists(new Path(invalidFilePaths.get(0)))); + + DeleteOrphanFiles.Result result3 = + actions + .deleteOrphanFiles(table) + .compareToFileList(compareToFileList) + .olderThan(System.currentTimeMillis()) + .execute(); + Assert.assertEquals( + "Action should delete 1 file", invalidFilePaths, result3.orphanFileLocations()); + Assert.assertFalse( + "Invalid file should not be present", fs.exists(new Path(invalidFilePaths.get(0)))); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(records); + expectedRecords.addAll(records); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + Assert.assertEquals("Rows must match", expectedRecords, actualRecords); + + List outsideLocationMockFiles = + Lists.newArrayList(new FilePathLastModifiedRecord("/tmp/mock1", new Timestamp(0L))); + + Dataset compareToFileListWithOutsideLocation = + spark + .createDataFrame(outsideLocationMockFiles, FilePathLastModifiedRecord.class) + .withColumnRenamed("filePath", "file_path") + 
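+ // compareToFileList expects a dataset with `file_path` and `last_modified` columns, which is
+ // why the bean fields are renamed here; the mock entry points outside the table location, so
+ // the action is expected to report nothing for it below.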
.withColumnRenamed("lastModified", "last_modified"); + + DeleteOrphanFiles.Result result4 = + actions + .deleteOrphanFiles(table) + .compareToFileList(compareToFileListWithOutsideLocation) + .deleteWith(s -> {}) + .execute(); + Assert.assertEquals( + "Action should find nothing", Lists.newArrayList(), result4.orphanFileLocations()); + } + + protected long waitUntilAfter(long timestampMillis) { + long current = System.currentTimeMillis(); + while (current <= timestampMillis) { + current = System.currentTimeMillis(); + } + return current; + } + + @Test + public void testRemoveOrphanFilesWithStatisticFiles() throws Exception { + Table table = + TABLES.create( + SCHEMA, + PartitionSpec.unpartitioned(), + ImmutableMap.of(TableProperties.FORMAT_VERSION, "2"), + tableLocation); + + List records = + Lists.newArrayList(new ThreeColumnRecord(1, "AAAAAAAAAA", "AAAA")); + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).coalesce(1); + df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + + table.refresh(); + long snapshotId = table.currentSnapshot().snapshotId(); + long snapshotSequenceNumber = table.currentSnapshot().sequenceNumber(); + + File statsLocation = + new File(new URI(tableLocation)) + .toPath() + .resolve("data") + .resolve("some-stats-file") + .toFile(); + StatisticsFile statisticsFile; + try (PuffinWriter puffinWriter = Puffin.write(Files.localOutput(statsLocation)).build()) { + puffinWriter.add( + new Blob( + "some-blob-type", + ImmutableList.of(1), + snapshotId, + snapshotSequenceNumber, + ByteBuffer.wrap("blob content".getBytes(StandardCharsets.UTF_8)))); + puffinWriter.finish(); + statisticsFile = + new GenericStatisticsFile( + snapshotId, + statsLocation.toString(), + puffinWriter.fileSize(), + puffinWriter.footerSize(), + puffinWriter.writtenBlobsMetadata().stream() + .map(GenericBlobMetadata::from) + .collect(ImmutableList.toImmutableList())); + } + + Transaction transaction = table.newTransaction(); + transaction.updateStatistics().setStatistics(snapshotId, statisticsFile).commit(); + transaction.commitTransaction(); + + SparkActions.get() + .deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis() + 1000) + .execute(); + + Assertions.assertThat(statsLocation.exists()).as("stats file should exist").isTrue(); + Assertions.assertThat(statsLocation.length()) + .as("stats file length") + .isEqualTo(statisticsFile.fileSizeInBytes()); + + transaction = table.newTransaction(); + transaction.updateStatistics().removeStatistics(statisticsFile.snapshotId()).commit(); + transaction.commitTransaction(); + + DeleteOrphanFiles.Result result = + SparkActions.get() + .deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis() + 1000) + .execute(); + Iterable orphanFileLocations = result.orphanFileLocations(); + Assertions.assertThat(orphanFileLocations).as("Should be orphan files").hasSize(1); + Assertions.assertThat(Iterables.getOnlyElement(orphanFileLocations)) + .as("Deleted file") + .isEqualTo(statsLocation.toURI().toString()); + Assertions.assertThat(statsLocation.exists()).as("stats file should be deleted").isFalse(); + } + + @Test + public void testPathsWithExtraSlashes() { + List validFiles = Lists.newArrayList("file:///dir1/dir2/file1"); + List actualFiles = Lists.newArrayList("file:///dir1/////dir2///file1"); + executeTest(validFiles, actualFiles, Lists.newArrayList()); + } + + @Test + public void testPathsWithValidFileHavingNoAuthority() { + List validFiles = Lists.newArrayList("hdfs:///dir1/dir2/file1"); + List 
actualFiles = Lists.newArrayList("hdfs://servicename/dir1/dir2/file1"); + executeTest(validFiles, actualFiles, Lists.newArrayList()); + } + + @Test + public void testPathsWithActualFileHavingNoAuthority() { + List validFiles = Lists.newArrayList("hdfs://servicename/dir1/dir2/file1"); + List actualFiles = Lists.newArrayList("hdfs:///dir1/dir2/file1"); + executeTest(validFiles, actualFiles, Lists.newArrayList()); + } + + @Test + public void testPathsWithEqualSchemes() { + List validFiles = Lists.newArrayList("scheme1://bucket1/dir1/dir2/file1"); + List actualFiles = Lists.newArrayList("scheme2://bucket1/dir1/dir2/file1"); + AssertHelpers.assertThrows( + "Test remove orphan files with equal schemes", + ValidationException.class, + "Conflicting authorities/schemes: [(scheme1, scheme2)]", + () -> + executeTest( + validFiles, + actualFiles, + Lists.newArrayList(), + ImmutableMap.of(), + ImmutableMap.of(), + DeleteOrphanFiles.PrefixMismatchMode.ERROR)); + + Map equalSchemes = Maps.newHashMap(); + equalSchemes.put("scheme1", "scheme"); + equalSchemes.put("scheme2", "scheme"); + executeTest( + validFiles, + actualFiles, + Lists.newArrayList(), + equalSchemes, + ImmutableMap.of(), + DeleteOrphanFiles.PrefixMismatchMode.ERROR); + } + + @Test + public void testPathsWithEqualAuthorities() { + List validFiles = Lists.newArrayList("hdfs://servicename1/dir1/dir2/file1"); + List actualFiles = Lists.newArrayList("hdfs://servicename2/dir1/dir2/file1"); + AssertHelpers.assertThrows( + "Test remove orphan files with equal authorities", + ValidationException.class, + "Conflicting authorities/schemes: [(servicename1, servicename2)]", + () -> + executeTest( + validFiles, + actualFiles, + Lists.newArrayList(), + ImmutableMap.of(), + ImmutableMap.of(), + DeleteOrphanFiles.PrefixMismatchMode.ERROR)); + + Map equalAuthorities = Maps.newHashMap(); + equalAuthorities.put("servicename1", "servicename"); + equalAuthorities.put("servicename2", "servicename"); + executeTest( + validFiles, + actualFiles, + Lists.newArrayList(), + ImmutableMap.of(), + equalAuthorities, + DeleteOrphanFiles.PrefixMismatchMode.ERROR); + } + + @Test + public void testRemoveOrphanFileActionWithDeleteMode() { + List validFiles = Lists.newArrayList("hdfs://servicename1/dir1/dir2/file1"); + List actualFiles = Lists.newArrayList("hdfs://servicename2/dir1/dir2/file1"); + + executeTest( + validFiles, + actualFiles, + Lists.newArrayList("hdfs://servicename2/dir1/dir2/file1"), + ImmutableMap.of(), + ImmutableMap.of(), + DeleteOrphanFiles.PrefixMismatchMode.DELETE); + } + + private void executeTest( + List validFiles, List actualFiles, List expectedOrphanFiles) { + executeTest( + validFiles, + actualFiles, + expectedOrphanFiles, + ImmutableMap.of(), + ImmutableMap.of(), + DeleteOrphanFiles.PrefixMismatchMode.IGNORE); + } + + private void executeTest( + List validFiles, + List actualFiles, + List expectedOrphanFiles, + Map equalSchemes, + Map equalAuthorities, + DeleteOrphanFiles.PrefixMismatchMode mode) { + + StringToFileURI toFileUri = new StringToFileURI(equalSchemes, equalAuthorities); + + Dataset validFileDS = spark.createDataset(validFiles, Encoders.STRING()); + Dataset actualFileDS = spark.createDataset(actualFiles, Encoders.STRING()); + + List orphanFiles = + DeleteOrphanFilesSparkAction.findOrphanFiles( + spark, toFileUri.apply(actualFileDS), toFileUri.apply(validFileDS), mode); + Assert.assertEquals(expectedOrphanFiles, orphanFiles); + } +} diff --git 
a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction3.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction3.java new file mode 100644 index 000000000000..0abfd79d5ddb --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction3.java @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.io.File; +import java.util.Map; +import java.util.stream.StreamSupport; +import org.apache.iceberg.actions.DeleteOrphanFiles; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.expressions.Transform; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestRemoveOrphanFilesAction3 extends TestRemoveOrphanFilesAction { + @Test + public void testSparkCatalogTable() throws Exception { + spark.conf().set("spark.sql.catalog.mycat", "org.apache.iceberg.spark.SparkCatalog"); + spark.conf().set("spark.sql.catalog.mycat.type", "hadoop"); + spark.conf().set("spark.sql.catalog.mycat.warehouse", tableLocation); + SparkCatalog cat = (SparkCatalog) spark.sessionState().catalogManager().catalog("mycat"); + + String[] database = {"default"}; + Identifier id = Identifier.of(database, "table"); + Map options = Maps.newHashMap(); + Transform[] transforms = {}; + cat.createTable(id, SparkSchemaUtil.convert(SCHEMA), transforms, options); + SparkTable table = (SparkTable) cat.loadTable(id); + + spark.sql("INSERT INTO mycat.default.table VALUES (1,1,1)"); + + String location = table.table().location().replaceFirst("file:", ""); + new File(location + "/data/trashfile").createNewFile(); + + DeleteOrphanFiles.Result results = + SparkActions.get() + .deleteOrphanFiles(table.table()) + .olderThan(System.currentTimeMillis() + 1000) + .execute(); + Assert.assertTrue( + "trash file should be removed", + StreamSupport.stream(results.orphanFileLocations().spliterator(), false) + .anyMatch(file -> file.contains("file:" + location + "/data/trashfile"))); + } + + @Test + public void testSparkCatalogNamedHadoopTable() throws Exception { + spark.conf().set("spark.sql.catalog.hadoop", "org.apache.iceberg.spark.SparkCatalog"); + spark.conf().set("spark.sql.catalog.hadoop.type", "hadoop"); + spark.conf().set("spark.sql.catalog.hadoop.warehouse", tableLocation); + SparkCatalog cat = (SparkCatalog) 
spark.sessionState().catalogManager().catalog("hadoop"); + + String[] database = {"default"}; + Identifier id = Identifier.of(database, "table"); + Map options = Maps.newHashMap(); + Transform[] transforms = {}; + cat.createTable(id, SparkSchemaUtil.convert(SCHEMA), transforms, options); + SparkTable table = (SparkTable) cat.loadTable(id); + + spark.sql("INSERT INTO hadoop.default.table VALUES (1,1,1)"); + + String location = table.table().location().replaceFirst("file:", ""); + new File(location + "/data/trashfile").createNewFile(); + + DeleteOrphanFiles.Result results = + SparkActions.get() + .deleteOrphanFiles(table.table()) + .olderThan(System.currentTimeMillis() + 1000) + .execute(); + Assert.assertTrue( + "trash file should be removed", + StreamSupport.stream(results.orphanFileLocations().spliterator(), false) + .anyMatch(file -> file.contains("file:" + location + "/data/trashfile"))); + } + + @Test + public void testSparkCatalogNamedHiveTable() throws Exception { + spark.conf().set("spark.sql.catalog.hive", "org.apache.iceberg.spark.SparkCatalog"); + spark.conf().set("spark.sql.catalog.hive.type", "hadoop"); + spark.conf().set("spark.sql.catalog.hive.warehouse", tableLocation); + SparkCatalog cat = (SparkCatalog) spark.sessionState().catalogManager().catalog("hive"); + + String[] database = {"default"}; + Identifier id = Identifier.of(database, "table"); + Map options = Maps.newHashMap(); + Transform[] transforms = {}; + cat.createTable(id, SparkSchemaUtil.convert(SCHEMA), transforms, options); + SparkTable table = (SparkTable) cat.loadTable(id); + + spark.sql("INSERT INTO hive.default.table VALUES (1,1,1)"); + + String location = table.table().location().replaceFirst("file:", ""); + new File(location + "/data/trashfile").createNewFile(); + + DeleteOrphanFiles.Result results = + SparkActions.get() + .deleteOrphanFiles(table.table()) + .olderThan(System.currentTimeMillis() + 1000) + .execute(); + Assert.assertTrue( + "trash file should be removed", + StreamSupport.stream(results.orphanFileLocations().spliterator(), false) + .anyMatch(file -> file.contains("file:" + location + "/data/trashfile"))); + } + + @Test + public void testSparkSessionCatalogHadoopTable() throws Exception { + spark + .conf() + .set("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog"); + spark.conf().set("spark.sql.catalog.spark_catalog.type", "hadoop"); + spark.conf().set("spark.sql.catalog.spark_catalog.warehouse", tableLocation); + SparkSessionCatalog cat = + (SparkSessionCatalog) spark.sessionState().catalogManager().v2SessionCatalog(); + + String[] database = {"default"}; + Identifier id = Identifier.of(database, "table"); + Map options = Maps.newHashMap(); + Transform[] transforms = {}; + cat.createTable(id, SparkSchemaUtil.convert(SCHEMA), transforms, options); + SparkTable table = (SparkTable) cat.loadTable(id); + + spark.sql("INSERT INTO default.table VALUES (1,1,1)"); + + String location = table.table().location().replaceFirst("file:", ""); + new File(location + "/data/trashfile").createNewFile(); + + DeleteOrphanFiles.Result results = + SparkActions.get() + .deleteOrphanFiles(table.table()) + .olderThan(System.currentTimeMillis() + 1000) + .execute(); + Assert.assertTrue( + "trash file should be removed", + StreamSupport.stream(results.orphanFileLocations().spliterator(), false) + .anyMatch(file -> file.contains("file:" + location + "/data/trashfile"))); + } + + @Test + public void testSparkSessionCatalogHiveTable() throws Exception { + spark + .conf() + 
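+ // Replace the built-in spark_catalog with Iceberg's SparkSessionCatalog backed by Hive so
+ // that the unqualified identifier default.sessioncattest below resolves to an Iceberg table.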
.set("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog"); + spark.conf().set("spark.sql.catalog.spark_catalog.type", "hive"); + SparkSessionCatalog cat = + (SparkSessionCatalog) spark.sessionState().catalogManager().v2SessionCatalog(); + + String[] database = {"default"}; + Identifier id = Identifier.of(database, "sessioncattest"); + Map options = Maps.newHashMap(); + Transform[] transforms = {}; + cat.dropTable(id); + cat.createTable(id, SparkSchemaUtil.convert(SCHEMA), transforms, options); + SparkTable table = (SparkTable) cat.loadTable(id); + + spark.sql("INSERT INTO default.sessioncattest VALUES (1,1,1)"); + + String location = table.table().location().replaceFirst("file:", ""); + new File(location + "/data/trashfile").createNewFile(); + + DeleteOrphanFiles.Result results = + SparkActions.get() + .deleteOrphanFiles(table.table()) + .olderThan(System.currentTimeMillis() + 1000) + .execute(); + Assert.assertTrue( + "trash file should be removed", + StreamSupport.stream(results.orphanFileLocations().spliterator(), false) + .anyMatch(file -> file.contains("file:" + location + "/data/trashfile"))); + } + + @After + public void resetSparkSessionCatalog() throws Exception { + spark.conf().unset("spark.sql.catalog.spark_catalog"); + spark.conf().unset("spark.sql.catalog.spark_catalog.type"); + spark.conf().unset("spark.sql.catalog.spark_catalog.warehouse"); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteDataFilesAction.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteDataFilesAction.java new file mode 100644 index 000000000000..761284bb56ea --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteDataFilesAction.java @@ -0,0 +1,1715 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.TableProperties.COMMIT_NUM_RETRIES; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.apache.spark.sql.functions.current_date; +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doCallRealMethod; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; + +import java.io.File; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.RewriteJobOrder; +import org.apache.iceberg.RowDelta; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.BinPackStrategy; +import org.apache.iceberg.actions.RewriteDataFiles; +import org.apache.iceberg.actions.RewriteDataFiles.Result; +import org.apache.iceberg.actions.RewriteDataFilesCommitManager; +import org.apache.iceberg.actions.RewriteFileGroup; +import org.apache.iceberg.actions.SortStrategy; +import org.apache.iceberg.data.GenericAppenderFactory; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.deletes.PositionDeleteWriter; +import org.apache.iceberg.encryption.EncryptedFiles; +import org.apache.iceberg.encryption.EncryptedOutputFile; +import org.apache.iceberg.encryption.EncryptionKeyMetadata; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Streams; +import org.apache.iceberg.spark.FileRewriteCoordinator; +import org.apache.iceberg.spark.ScanTaskSetManager; +import 
org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.spark.actions.RewriteDataFilesSparkAction.RewriteExecutionContext; +import org.apache.iceberg.spark.source.ThreeColumnRecord; +import org.apache.iceberg.types.Comparators; +import org.apache.iceberg.types.Conversions; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.util.Pair; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.mockito.ArgumentMatcher; +import org.mockito.Mockito; + +public class TestRewriteDataFilesAction extends SparkTestBase { + + private static final int SCALE = 400000; + + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + private static final Schema SCHEMA = + new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private final FileRewriteCoordinator coordinator = FileRewriteCoordinator.get(); + private final ScanTaskSetManager manager = ScanTaskSetManager.get(); + private String tableLocation = null; + + @Before + public void setupTableLocation() throws Exception { + File tableDir = temp.newFolder(); + this.tableLocation = tableDir.toURI().toString(); + } + + private RewriteDataFilesSparkAction basicRewrite(Table table) { + // Always compact regardless of input files + table.refresh(); + return actions().rewriteDataFiles(table).option(BinPackStrategy.MIN_INPUT_FILES, "1"); + } + + @Test + public void testEmptyTable() { + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + Assert.assertNull("Table must be empty", table.currentSnapshot()); + + basicRewrite(table).execute(); + + Assert.assertNull("Table must stay empty", table.currentSnapshot()); + } + + @Test + public void testBinPackUnpartitionedTable() { + Table table = createTable(4); + shouldHaveFiles(table, 4); + List expectedRecords = currentData(); + long dataSizeBefore = testDataSize(table); + + Result result = basicRewrite(table).execute(); + Assert.assertEquals("Action should rewrite 4 data files", 4, result.rewrittenDataFilesCount()); + Assert.assertEquals("Action should add 1 data file", 1, result.addedDataFilesCount()); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + shouldHaveFiles(table, 1); + List actual = currentData(); + + assertEquals("Rows must match", expectedRecords, actual); + } + + @Test + public void testBinPackPartitionedTable() { + Table table = createTablePartitioned(4, 2); + shouldHaveFiles(table, 8); + List expectedRecords = currentData(); + long dataSizeBefore = testDataSize(table); + + Result result = basicRewrite(table).execute(); + Assert.assertEquals("Action should rewrite 8 data files", 8, result.rewrittenDataFilesCount()); + Assert.assertEquals("Action should add 4 data file", 4, result.addedDataFilesCount()); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + shouldHaveFiles(table, 4); + List actualRecords = currentData(); + + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testBinPackWithFilter() { + Table table = createTablePartitioned(4, 
2); + shouldHaveFiles(table, 8); + List expectedRecords = currentData(); + long dataSizeBefore = testDataSize(table); + + Result result = + basicRewrite(table) + .filter(Expressions.equal("c1", 1)) + .filter(Expressions.startsWith("c2", "foo")) + .execute(); + + Assert.assertEquals("Action should rewrite 2 data files", 2, result.rewrittenDataFilesCount()); + Assert.assertEquals("Action should add 1 data file", 1, result.addedDataFilesCount()); + assertThat(result.rewrittenBytesCount()).isGreaterThan(0L).isLessThan(dataSizeBefore); + + shouldHaveFiles(table, 7); + + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testBinPackAfterPartitionChange() { + Table table = createTable(); + + writeRecords(20, SCALE, 20); + shouldHaveFiles(table, 20); + table.updateSpec().addField(Expressions.ref("c1")).commit(); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + RewriteDataFiles.Result result = + basicRewrite(table) + .option(SortStrategy.MIN_INPUT_FILES, "1") + .option( + SortStrategy.MIN_FILE_SIZE_BYTES, Integer.toString(averageFileSize(table) + 1000)) + .option( + RewriteDataFiles.TARGET_FILE_SIZE_BYTES, + Integer.toString(averageFileSize(table) + 1001)) + .execute(); + + Assert.assertEquals( + "Should have 1 fileGroup because all files were not correctly partitioned", + 1, + result.rewriteResults().size()); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + shouldHaveFiles(table, 20); + } + + @Test + public void testBinPackWithDeletes() throws Exception { + Table table = createTablePartitioned(4, 2); + table.updateProperties().set(TableProperties.FORMAT_VERSION, "2").commit(); + shouldHaveFiles(table, 8); + table.refresh(); + + CloseableIterable tasks = table.newScan().planFiles(); + List dataFiles = + Lists.newArrayList(CloseableIterable.transform(tasks, FileScanTask::file)); + int total = (int) dataFiles.stream().mapToLong(ContentFile::recordCount).sum(); + + RowDelta rowDelta = table.newRowDelta(); + // add 1 delete file for data files 0, 1, 2 + for (int i = 0; i < 3; i++) { + writePosDeletesToFile(table, dataFiles.get(i), 1).forEach(rowDelta::addDeletes); + } + + // add 2 delete files for data files 3, 4 + for (int i = 3; i < 5; i++) { + writePosDeletesToFile(table, dataFiles.get(i), 2).forEach(rowDelta::addDeletes); + } + + rowDelta.commit(); + table.refresh(); + List expectedRecords = currentData(); + long dataSizeBefore = testDataSize(table); + + Result result = + actions() + .rewriteDataFiles(table) + // do not include any file based on bin pack file size configs + .option(BinPackStrategy.MIN_FILE_SIZE_BYTES, "0") + .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Long.toString(Long.MAX_VALUE - 1)) + .option(BinPackStrategy.MAX_FILE_SIZE_BYTES, Long.toString(Long.MAX_VALUE)) + .option(BinPackStrategy.DELETE_FILE_THRESHOLD, "2") + .execute(); + Assert.assertEquals("Action should rewrite 2 data files", 2, result.rewrittenDataFilesCount()); + assertThat(result.rewrittenBytesCount()).isGreaterThan(0L).isLessThan(dataSizeBefore); + + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + Assert.assertEquals("7 rows are removed", total - 7, actualRecords.size()); + } + + @Test + public void 
testBinPackWithDeleteAllData() { + Map options = Maps.newHashMap(); + options.put(TableProperties.FORMAT_VERSION, "2"); + Table table = createTablePartitioned(1, 1, 1, options); + shouldHaveFiles(table, 1); + table.refresh(); + + CloseableIterable tasks = table.newScan().planFiles(); + List dataFiles = + Lists.newArrayList(CloseableIterable.transform(tasks, FileScanTask::file)); + int total = (int) dataFiles.stream().mapToLong(ContentFile::recordCount).sum(); + + RowDelta rowDelta = table.newRowDelta(); + // remove all data + writePosDeletesToFile(table, dataFiles.get(0), total).forEach(rowDelta::addDeletes); + + rowDelta.commit(); + table.refresh(); + List expectedRecords = currentData(); + long dataSizeBefore = testDataSize(table); + + Result result = + actions() + .rewriteDataFiles(table) + .option(BinPackStrategy.DELETE_FILE_THRESHOLD, "1") + .execute(); + Assert.assertEquals("Action should rewrite 1 data files", 1, result.rewrittenDataFilesCount()); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + Assert.assertEquals( + "Data manifest should not have existing data file", + 0, + (long) table.currentSnapshot().dataManifests(table.io()).get(0).existingFilesCount()); + Assert.assertEquals( + "Data manifest should have 1 delete data file", + 1L, + (long) table.currentSnapshot().dataManifests(table.io()).get(0).deletedFilesCount()); + Assert.assertEquals( + "Delete manifest added row count should equal total count", + total, + (long) table.currentSnapshot().deleteManifests(table.io()).get(0).addedRowsCount()); + } + + @Test + public void testBinPackWithStartingSequenceNumber() { + Table table = createTablePartitioned(4, 2); + shouldHaveFiles(table, 8); + List expectedRecords = currentData(); + table.updateProperties().set(TableProperties.FORMAT_VERSION, "2").commit(); + table.refresh(); + long oldSequenceNumber = table.currentSnapshot().sequenceNumber(); + long dataSizeBefore = testDataSize(table); + + Result result = + basicRewrite(table).option(RewriteDataFiles.USE_STARTING_SEQUENCE_NUMBER, "true").execute(); + Assert.assertEquals("Action should rewrite 8 data files", 8, result.rewrittenDataFilesCount()); + Assert.assertEquals("Action should add 4 data file", 4, result.addedDataFilesCount()); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + shouldHaveFiles(table, 4); + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + + table.refresh(); + Assert.assertTrue( + "Table sequence number should be incremented", + oldSequenceNumber < table.currentSnapshot().sequenceNumber()); + + Dataset rows = SparkTableUtil.loadMetadataTable(spark, table, MetadataTableType.ENTRIES); + for (Row row : rows.collectAsList()) { + if (row.getInt(0) == 1) { + Assert.assertEquals( + "Expect old sequence number for added entries", oldSequenceNumber, row.getLong(2)); + } + } + } + + @Test + public void testBinPackWithStartingSequenceNumberV1Compatibility() { + Table table = createTablePartitioned(4, 2); + shouldHaveFiles(table, 8); + List expectedRecords = currentData(); + table.refresh(); + long oldSequenceNumber = table.currentSnapshot().sequenceNumber(); + Assert.assertEquals("Table sequence number should be 0", 0, oldSequenceNumber); + long dataSizeBefore = testDataSize(table); + + Result result = + basicRewrite(table).option(RewriteDataFiles.USE_STARTING_SEQUENCE_NUMBER, "true").execute(); + 
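+ // The table was never upgraded to format v2, so sequence numbers are not tracked and the
+ // rewrite committed with USE_STARTING_SEQUENCE_NUMBER should leave every entry at 0, as the
+ // assertions below verify.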
Assert.assertEquals("Action should rewrite 8 data files", 8, result.rewrittenDataFilesCount()); + Assert.assertEquals("Action should add 4 data file", 4, result.addedDataFilesCount()); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + shouldHaveFiles(table, 4); + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + + table.refresh(); + Assert.assertEquals( + "Table sequence number should still be 0", + oldSequenceNumber, + table.currentSnapshot().sequenceNumber()); + + Dataset rows = SparkTableUtil.loadMetadataTable(spark, table, MetadataTableType.ENTRIES); + for (Row row : rows.collectAsList()) { + Assert.assertEquals( + "Expect sequence number 0 for all entries", oldSequenceNumber, row.getLong(2)); + } + } + + @Test + public void testRewriteLargeTableHasResiduals() { + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).build(); + Map options = Maps.newHashMap(); + options.put(TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES, "100"); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + // all records belong to the same partition + List records = Lists.newArrayList(); + for (int i = 0; i < 100; i++) { + records.add(new ThreeColumnRecord(i, String.valueOf(i), String.valueOf(i % 4))); + } + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + writeDF(df); + + List expectedRecords = currentData(); + + table.refresh(); + + CloseableIterable tasks = + table.newScan().ignoreResiduals().filter(Expressions.equal("c3", "0")).planFiles(); + for (FileScanTask task : tasks) { + Assert.assertEquals("Residuals must be ignored", Expressions.alwaysTrue(), task.residual()); + } + + shouldHaveFiles(table, 2); + + long dataSizeBefore = testDataSize(table); + Result result = basicRewrite(table).filter(Expressions.equal("c3", "0")).execute(); + Assert.assertEquals("Action should rewrite 2 data files", 2, result.rewrittenDataFilesCount()); + Assert.assertEquals("Action should add 1 data file", 1, result.addedDataFilesCount()); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + List actualRecords = currentData(); + + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testBinPackSplitLargeFile() { + Table table = createTable(1); + shouldHaveFiles(table, 1); + + List expectedRecords = currentData(); + long targetSize = testDataSize(table) / 2; + + long dataSizeBefore = testDataSize(table); + Result result = + basicRewrite(table) + .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Long.toString(targetSize)) + .option(BinPackStrategy.MAX_FILE_SIZE_BYTES, Long.toString(targetSize * 2 - 2000)) + .execute(); + + Assert.assertEquals("Action should delete 1 data files", 1, result.rewrittenDataFilesCount()); + Assert.assertEquals("Action should add 2 data files", 2, result.addedDataFilesCount()); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + shouldHaveFiles(table, 2); + + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testBinPackCombineMixedFiles() { + Table table = createTable(1); // 400000 + shouldHaveFiles(table, 1); + + // Add one more small file, and one large file + writeRecords(1, SCALE); + writeRecords(1, SCALE * 3); + shouldHaveFiles(table, 3); + + List expectedRecords = currentData(); + + int targetSize = averageFileSize(table); + + long dataSizeBefore = testDataSize(table); + Result result = + basicRewrite(table) + 
.option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(targetSize + 1000)) + .option(BinPackStrategy.MAX_FILE_SIZE_BYTES, Integer.toString(targetSize + 80000)) + .option(BinPackStrategy.MIN_FILE_SIZE_BYTES, Integer.toString(targetSize - 1000)) + .execute(); + + Assert.assertEquals("Action should delete 3 data files", 3, result.rewrittenDataFilesCount()); + // Should Split the big files into 3 pieces, one of which should be combined with the two + // smaller files + Assert.assertEquals("Action should add 3 data files", 3, result.addedDataFilesCount()); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + shouldHaveFiles(table, 3); + + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testBinPackCombineMediumFiles() { + Table table = createTable(4); + shouldHaveFiles(table, 4); + + List expectedRecords = currentData(); + int targetSize = ((int) testDataSize(table) / 3); + // The test is to see if we can combine parts of files to make files of the correct size + + long dataSizeBefore = testDataSize(table); + Result result = + basicRewrite(table) + .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(targetSize)) + .option(BinPackStrategy.MAX_FILE_SIZE_BYTES, Integer.toString((int) (targetSize * 1.8))) + .option( + BinPackStrategy.MIN_FILE_SIZE_BYTES, + Integer.toString(targetSize - 100)) // All files too small + .execute(); + + Assert.assertEquals("Action should delete 4 data files", 4, result.rewrittenDataFilesCount()); + Assert.assertEquals("Action should add 3 data files", 3, result.addedDataFilesCount()); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + shouldHaveFiles(table, 3); + + List actualRecords = currentData(); + assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testPartialProgressEnabled() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + table.updateProperties().set(COMMIT_NUM_RETRIES, "10").commit(); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + // Perform a rewrite but only allow 2 files to be compacted at a time + RewriteDataFiles.Result result = + basicRewrite(table) + .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true") + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "10") + .execute(); + + Assert.assertEquals("Should have 10 fileGroups", result.rewriteResults().size(), 10); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + table.refresh(); + + shouldHaveSnapshots(table, 11); + shouldHaveACleanCache(table); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + } + + @Test + public void testMultipleGroups() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + // Perform a rewrite but only allow 2 files to be compacted at a time + RewriteDataFiles.Result result = + basicRewrite(table) + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(BinPackStrategy.MIN_INPUT_FILES, "1") + .execute(); + + Assert.assertEquals("Should have 10 fileGroups", result.rewriteResults().size(), 10); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + 
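+ // The group size cap of roughly two files splits the 20 input files into 10 rewrite groups,
+ // but without partial progress they are committed in a single rewrite snapshot, which is why
+ // shouldHaveSnapshots expects 2 snapshots in total below.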
table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + } + + @Test + public void testPartialProgressMaxCommits() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + // Perform a rewrite but only allow 2 files to be compacted at a time + RewriteDataFiles.Result result = + basicRewrite(table) + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true") + .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "3") + .execute(); + + Assert.assertEquals("Should have 10 fileGroups", result.rewriteResults().size(), 10); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 4); + shouldHaveACleanCache(table); + } + + @Test + public void testSingleCommitWithRewriteFailure() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + + RewriteDataFilesSparkAction realRewrite = + basicRewrite(table) + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)); + + RewriteDataFilesSparkAction spyRewrite = Mockito.spy(realRewrite); + + // Fail groups 1, 3, and 7 during rewrite + GroupInfoMatcher failGroup = new GroupInfoMatcher(1, 3, 7); + doThrow(new RuntimeException("Rewrite Failed")) + .when(spyRewrite) + .rewriteFiles(any(), argThat(failGroup)); + + AssertHelpers.assertThrows( + "Should fail entire rewrite if part fails", + RuntimeException.class, + () -> spyRewrite.execute()); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 1); + shouldHaveNoOrphans(table); + shouldHaveACleanCache(table); + } + + @Test + public void testSingleCommitWithCommitFailure() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + + RewriteDataFilesSparkAction realRewrite = + basicRewrite(table) + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)); + + RewriteDataFilesSparkAction spyRewrite = spy(realRewrite); + RewriteDataFilesCommitManager util = spy(new RewriteDataFilesCommitManager(table)); + + // Fail to commit + doThrow(new RuntimeException("Commit Failure")).when(util).commitFileGroups(any()); + + doReturn(util).when(spyRewrite).commitManager(table.currentSnapshot().snapshotId()); + + AssertHelpers.assertThrows( + "Should fail entire rewrite if commit fails", + RuntimeException.class, + () -> spyRewrite.execute()); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 1); + shouldHaveNoOrphans(table); + shouldHaveACleanCache(table); + } + + @Test + public void testParallelSingleCommitWithRewriteFailure() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + + RewriteDataFilesSparkAction realRewrite = + basicRewrite(table) + .option( + 
RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(RewriteDataFiles.MAX_CONCURRENT_FILE_GROUP_REWRITES, "3"); + + RewriteDataFilesSparkAction spyRewrite = Mockito.spy(realRewrite); + + // Fail groups 1, 3, and 7 during rewrite + GroupInfoMatcher failGroup = new GroupInfoMatcher(1, 3, 7); + doThrow(new RuntimeException("Rewrite Failed")) + .when(spyRewrite) + .rewriteFiles(any(), argThat(failGroup)); + + AssertHelpers.assertThrows( + "Should fail entire rewrite if part fails", + RuntimeException.class, + () -> spyRewrite.execute()); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 1); + shouldHaveNoOrphans(table); + shouldHaveACleanCache(table); + } + + @Test + public void testPartialProgressWithRewriteFailure() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + RewriteDataFilesSparkAction realRewrite = + basicRewrite(table) + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true") + .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "3"); + + RewriteDataFilesSparkAction spyRewrite = Mockito.spy(realRewrite); + + // Fail groups 1, 3, and 7 during rewrite + GroupInfoMatcher failGroup = new GroupInfoMatcher(1, 3, 7); + doThrow(new RuntimeException("Rewrite Failed")) + .when(spyRewrite) + .rewriteFiles(any(), argThat(failGroup)); + + RewriteDataFiles.Result result = spyRewrite.execute(); + + Assert.assertEquals("Should have 7 fileGroups", result.rewriteResults().size(), 7); + assertThat(result.rewrittenBytesCount()).isGreaterThan(0L).isLessThan(dataSizeBefore); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + // With 10 original groups and Max Commits of 3, we should have commits with 4, 4, and 2. 
+ // removing 3 groups leaves us with only 2 new commits, 4 and 3 + shouldHaveSnapshots(table, 3); + shouldHaveNoOrphans(table); + shouldHaveACleanCache(table); + } + + @Test + public void testParallelPartialProgressWithRewriteFailure() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + RewriteDataFilesSparkAction realRewrite = + basicRewrite(table) + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(RewriteDataFiles.MAX_CONCURRENT_FILE_GROUP_REWRITES, "3") + .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true") + .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "3"); + + RewriteDataFilesSparkAction spyRewrite = Mockito.spy(realRewrite); + + // Fail groups 1, 3, and 7 during rewrite + GroupInfoMatcher failGroup = new GroupInfoMatcher(1, 3, 7); + doThrow(new RuntimeException("Rewrite Failed")) + .when(spyRewrite) + .rewriteFiles(any(), argThat(failGroup)); + + RewriteDataFiles.Result result = spyRewrite.execute(); + + Assert.assertEquals("Should have 7 fileGroups", result.rewriteResults().size(), 7); + assertThat(result.rewrittenBytesCount()).isGreaterThan(0L).isLessThan(dataSizeBefore); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + // With 10 original groups and Max Commits of 3, we should have commits with 4, 4, and 2. + // removing 3 groups leaves us with only 2 new commits, 4 and 3 + shouldHaveSnapshots(table, 3); + shouldHaveNoOrphans(table); + shouldHaveACleanCache(table); + } + + @Test + public void testParallelPartialProgressWithCommitFailure() { + Table table = createTable(20); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + RewriteDataFilesSparkAction realRewrite = + basicRewrite(table) + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .option(RewriteDataFiles.MAX_CONCURRENT_FILE_GROUP_REWRITES, "3") + .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true") + .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "3"); + + RewriteDataFilesSparkAction spyRewrite = spy(realRewrite); + RewriteDataFilesCommitManager util = spy(new RewriteDataFilesCommitManager(table)); + + // First and Third commits work, second does not + doCallRealMethod() + .doThrow(new RuntimeException("Commit Failed")) + .doCallRealMethod() + .when(util) + .commitFileGroups(any()); + + doReturn(util).when(spyRewrite).commitManager(table.currentSnapshot().snapshotId()); + + RewriteDataFiles.Result result = spyRewrite.execute(); + + // Commit 1: 4/4 + Commit 2 failed 0/4 + Commit 3: 2/2 == 6 out of 10 total groups committed + Assert.assertEquals("Should have 6 fileGroups", 6, result.rewriteResults().size()); + assertThat(result.rewrittenBytesCount()).isGreaterThan(0L).isLessThan(dataSizeBefore); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + // Only 2 new commits because we broke one + shouldHaveSnapshots(table, 3); + shouldHaveNoOrphans(table); + shouldHaveACleanCache(table); + } + + @Test + public void testInvalidOptions() { + Table table = createTable(20); + + AssertHelpers.assertThrows( + "No negative values for partial progress max commits", + IllegalArgumentException.class, + () -> +
basicRewrite(table) + .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true") + .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "-5") + .execute()); + + AssertHelpers.assertThrows( + "No negative values for max concurrent groups", + IllegalArgumentException.class, + () -> + basicRewrite(table) + .option(RewriteDataFiles.MAX_CONCURRENT_FILE_GROUP_REWRITES, "-5") + .execute()); + + AssertHelpers.assertThrows( + "No unknown options allowed", + IllegalArgumentException.class, + () -> basicRewrite(table).option("foobarity", "-5").execute()); + + AssertHelpers.assertThrows( + "Cannot set rewrite-job-order to foo", + IllegalArgumentException.class, + () -> basicRewrite(table).option(RewriteDataFiles.REWRITE_JOB_ORDER, "foo").execute()); + } + + @Test + public void testSortMultipleGroups() { + Table table = createTable(20); + shouldHaveFiles(table, 20); + table.replaceSortOrder().asc("c2").commit(); + shouldHaveLastCommitUnsorted(table, "c2"); + int fileSize = averageFileSize(table); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + // Perform a rewrite but only allow 2 files to be compacted at a time + RewriteDataFiles.Result result = + basicRewrite(table) + .sort() + .option(SortStrategy.REWRITE_ALL, "true") + .option( + RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000)) + .execute(); + + Assert.assertEquals("Should have 10 fileGroups", result.rewriteResults().size(), 10); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + } + + @Test + public void testSimpleSort() { + Table table = createTable(20); + shouldHaveFiles(table, 20); + table.replaceSortOrder().asc("c2").commit(); + shouldHaveLastCommitUnsorted(table, "c2"); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + RewriteDataFiles.Result result = + basicRewrite(table) + .sort() + .option(SortStrategy.MIN_INPUT_FILES, "1") + .option(SortStrategy.REWRITE_ALL, "true") + .option( + RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(averageFileSize(table))) + .execute(); + + Assert.assertEquals("Should have 1 fileGroups", result.rewriteResults().size(), 1); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + shouldHaveMultipleFiles(table); + shouldHaveLastCommitSorted(table, "c2"); + } + + @Test + public void testSortAfterPartitionChange() { + Table table = createTable(20); + shouldHaveFiles(table, 20); + table.updateSpec().addField(Expressions.bucket("c1", 4)).commit(); + table.replaceSortOrder().asc("c2").commit(); + shouldHaveLastCommitUnsorted(table, "c2"); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + RewriteDataFiles.Result result = + basicRewrite(table) + .sort() + .option(SortStrategy.MIN_INPUT_FILES, "1") + .option(SortStrategy.REWRITE_ALL, "true") + .option( + RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(averageFileSize(table))) + .execute(); + + Assert.assertEquals( + "Should have 1 fileGroup because all files were not correctly partitioned", + result.rewriteResults().size(), + 1); + 
assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + shouldHaveMultipleFiles(table); + shouldHaveLastCommitSorted(table, "c2"); + } + + @Test + public void testSortCustomSortOrder() { + Table table = createTable(20); + shouldHaveLastCommitUnsorted(table, "c2"); + shouldHaveFiles(table, 20); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + RewriteDataFiles.Result result = + basicRewrite(table) + .sort(SortOrder.builderFor(table.schema()).asc("c2").build()) + .option(SortStrategy.REWRITE_ALL, "true") + .option( + RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(averageFileSize(table))) + .execute(); + + Assert.assertEquals("Should have 1 fileGroups", result.rewriteResults().size(), 1); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + shouldHaveMultipleFiles(table); + shouldHaveLastCommitSorted(table, "c2"); + } + + @Test + public void testSortCustomSortOrderRequiresRepartition() { + int partitions = 4; + Table table = createTable(); + writeRecords(20, SCALE, partitions); + shouldHaveLastCommitUnsorted(table, "c3"); + + // Add a partition column so this requires repartitioning + table.updateSpec().addField("c1").commit(); + // Add a sort order which our repartitioning needs to ignore + table.replaceSortOrder().asc("c2").apply(); + shouldHaveFiles(table, 20); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + RewriteDataFiles.Result result = + basicRewrite(table) + .sort(SortOrder.builderFor(table.schema()).asc("c3").build()) + .option(SortStrategy.REWRITE_ALL, "true") + .option( + RewriteDataFiles.TARGET_FILE_SIZE_BYTES, + Integer.toString(averageFileSize(table) / partitions)) + .execute(); + + Assert.assertEquals("Should have 1 fileGroups", result.rewriteResults().size(), 1); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + shouldHaveMultipleFiles(table); + shouldHaveLastCommitUnsorted(table, "c2"); + shouldHaveLastCommitSorted(table, "c3"); + } + + @Test + public void testAutoSortShuffleOutput() { + Table table = createTable(20); + shouldHaveLastCommitUnsorted(table, "c2"); + shouldHaveFiles(table, 20); + + List originalData = currentData(); + long dataSizeBefore = testDataSize(table); + + RewriteDataFiles.Result result = + basicRewrite(table) + .sort(SortOrder.builderFor(table.schema()).asc("c2").build()) + .option( + SortStrategy.MAX_FILE_SIZE_BYTES, + Integer.toString((averageFileSize(table) / 2) + 2)) + // Divide files in 2 + .option( + RewriteDataFiles.TARGET_FILE_SIZE_BYTES, + Integer.toString(averageFileSize(table) / 2)) + .option(SortStrategy.MIN_INPUT_FILES, "1") + .execute(); + + Assert.assertEquals("Should have 1 fileGroups", result.rewriteResults().size(), 1); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + Assert.assertTrue( + "Should have written 40+ files", + 
Iterables.size(table.currentSnapshot().addedDataFiles(table.io())) >= 40); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + shouldHaveMultipleFiles(table); + shouldHaveLastCommitSorted(table, "c2"); + } + + @Test + public void testCommitStateUnknownException() { + Table table = createTable(20); + shouldHaveFiles(table, 20); + + List originalData = currentData(); + + RewriteDataFilesSparkAction action = basicRewrite(table); + RewriteDataFilesSparkAction spyAction = spy(action); + RewriteDataFilesCommitManager util = spy(new RewriteDataFilesCommitManager(table)); + + doAnswer( + invocationOnMock -> { + invocationOnMock.callRealMethod(); + throw new CommitStateUnknownException(new RuntimeException("Unknown State")); + }) + .when(util) + .commitFileGroups(any()); + + doReturn(util).when(spyAction).commitManager(table.currentSnapshot().snapshotId()); + + AssertHelpers.assertThrows( + "Should propagate CommitStateUnknown Exception", + CommitStateUnknownException.class, + () -> spyAction.execute()); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); // Commit actually Succeeded + } + + @Test + public void testZOrderSort() { + int originalFiles = 20; + Table table = createTable(originalFiles); + shouldHaveLastCommitUnsorted(table, "c2"); + shouldHaveFiles(table, originalFiles); + + List originalData = currentData(); + double originalFilesC2 = percentFilesRequired(table, "c2", "foo23"); + double originalFilesC3 = percentFilesRequired(table, "c3", "bar21"); + double originalFilesC2C3 = + percentFilesRequired(table, new String[] {"c2", "c3"}, new String[] {"foo23", "bar23"}); + + Assert.assertTrue("Should require all files to scan c2", originalFilesC2 > 0.99); + Assert.assertTrue("Should require all files to scan c3", originalFilesC3 > 0.99); + + long dataSizeBefore = testDataSize(table); + RewriteDataFiles.Result result = + basicRewrite(table) + .zOrder("c2", "c3") + .option( + SortStrategy.MAX_FILE_SIZE_BYTES, + Integer.toString((averageFileSize(table) / 2) + 2)) + // Divide files in 2 + .option( + RewriteDataFiles.TARGET_FILE_SIZE_BYTES, + Integer.toString(averageFileSize(table) / 2)) + .option(SortStrategy.MIN_INPUT_FILES, "1") + .execute(); + + Assert.assertEquals("Should have 1 fileGroups", 1, result.rewriteResults().size()); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + int zOrderedFilesTotal = Iterables.size(table.currentSnapshot().addedDataFiles(table.io())); + Assert.assertTrue("Should have written 40+ files", zOrderedFilesTotal >= 40); + + table.refresh(); + + List postRewriteData = currentData(); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + + double filesScannedC2 = percentFilesRequired(table, "c2", "foo23"); + double filesScannedC3 = percentFilesRequired(table, "c3", "bar21"); + double filesScannedC2C3 = + percentFilesRequired(table, new String[] {"c2", "c3"}, new String[] {"foo23", "bar23"}); + + Assert.assertTrue( + "Should have reduced the number of files required for c2", + filesScannedC2 < originalFilesC2); + Assert.assertTrue( + "Should have reduced the number of files required for c3", + filesScannedC3 < originalFilesC3); + Assert.assertTrue( + "Should have reduced the 
number of files required for a c2,c3 predicate", + filesScannedC2C3 < originalFilesC2C3); + } + + @Test + public void testZOrderAllTypesSort() { + Table table = createTypeTestTable(); + shouldHaveFiles(table, 10); + + List originalRaw = + spark.read().format("iceberg").load(tableLocation).sort("longCol").collectAsList(); + List originalData = rowsToJava(originalRaw); + long dataSizeBefore = testDataSize(table); + + // TODO add in UUID when it is supported in Spark + RewriteDataFiles.Result result = + basicRewrite(table) + .zOrder( + "longCol", + "intCol", + "floatCol", + "doubleCol", + "dateCol", + "timestampCol", + "stringCol", + "binaryCol", + "booleanCol") + .option(SortStrategy.MIN_INPUT_FILES, "1") + .option(SortStrategy.REWRITE_ALL, "true") + .execute(); + + Assert.assertEquals("Should have 1 fileGroups", 1, result.rewriteResults().size()); + assertThat(result.rewrittenBytesCount()).isEqualTo(dataSizeBefore); + int zOrderedFilesTotal = Iterables.size(table.currentSnapshot().addedDataFiles(table.io())); + Assert.assertEquals("Should have written 1 file", 1, zOrderedFilesTotal); + + table.refresh(); + + List postRaw = + spark.read().format("iceberg").load(tableLocation).sort("longCol").collectAsList(); + List postRewriteData = rowsToJava(postRaw); + assertEquals("We shouldn't have changed the data", originalData, postRewriteData); + + shouldHaveSnapshots(table, 2); + shouldHaveACleanCache(table); + } + + @Test + public void testInvalidAPIUsage() { + Table table = createTable(1); + + SortOrder sortOrder = SortOrder.builderFor(table.schema()).asc("c2").build(); + + AssertHelpers.assertThrows( + "Should be unable to set Strategy more than once", + IllegalArgumentException.class, + "Must use only one rewriter type", + () -> actions().rewriteDataFiles(table).binPack().sort()); + + AssertHelpers.assertThrows( + "Should be unable to set Strategy more than once", + IllegalArgumentException.class, + "Must use only one rewriter type", + () -> actions().rewriteDataFiles(table).sort(sortOrder).binPack()); + + AssertHelpers.assertThrows( + "Should be unable to set Strategy more than once", + IllegalArgumentException.class, + "Must use only one rewriter type", + () -> actions().rewriteDataFiles(table).sort(sortOrder).binPack()); + } + + @Test + public void testRewriteJobOrderBytesAsc() { + Table table = createTablePartitioned(4, 2); + writeRecords(1, SCALE, 1); + writeRecords(2, SCALE, 2); + writeRecords(3, SCALE, 3); + writeRecords(4, SCALE, 4); + table.updateProperties().set(TableProperties.FORMAT_VERSION, "2").commit(); + + RewriteDataFilesSparkAction basicRewrite = basicRewrite(table).binPack(); + List expected = + toGroupStream(table, basicRewrite) + .mapToLong(RewriteFileGroup::sizeInBytes) + .boxed() + .collect(Collectors.toList()); + + RewriteDataFilesSparkAction jobOrderRewrite = + basicRewrite(table) + .option(RewriteDataFiles.REWRITE_JOB_ORDER, RewriteJobOrder.BYTES_ASC.orderName()) + .binPack(); + List actual = + toGroupStream(table, jobOrderRewrite) + .mapToLong(RewriteFileGroup::sizeInBytes) + .boxed() + .collect(Collectors.toList()); + + expected.sort(Comparator.naturalOrder()); + Assert.assertEquals("Size in bytes order should be ascending", actual, expected); + Collections.reverse(expected); + Assert.assertNotEquals("Size in bytes order should not be descending", actual, expected); + } + + @Test + public void testRewriteJobOrderBytesDesc() { + Table table = createTablePartitioned(4, 2); + writeRecords(1, SCALE, 1); + writeRecords(2, SCALE, 2); + writeRecords(3, SCALE, 3); + 
writeRecords(4, SCALE, 4); + table.updateProperties().set(TableProperties.FORMAT_VERSION, "2").commit(); + + RewriteDataFilesSparkAction basicRewrite = basicRewrite(table).binPack(); + List expected = + toGroupStream(table, basicRewrite) + .mapToLong(RewriteFileGroup::sizeInBytes) + .boxed() + .collect(Collectors.toList()); + + RewriteDataFilesSparkAction jobOrderRewrite = + basicRewrite(table) + .option(RewriteDataFiles.REWRITE_JOB_ORDER, RewriteJobOrder.BYTES_DESC.orderName()) + .binPack(); + List actual = + toGroupStream(table, jobOrderRewrite) + .mapToLong(RewriteFileGroup::sizeInBytes) + .boxed() + .collect(Collectors.toList()); + + expected.sort(Comparator.reverseOrder()); + Assert.assertEquals("Size in bytes order should be descending", actual, expected); + Collections.reverse(expected); + Assert.assertNotEquals("Size in bytes order should not be ascending", actual, expected); + } + + @Test + public void testRewriteJobOrderFilesAsc() { + Table table = createTablePartitioned(4, 2); + writeRecords(1, SCALE, 1); + writeRecords(2, SCALE, 2); + writeRecords(3, SCALE, 3); + writeRecords(4, SCALE, 4); + table.updateProperties().set(TableProperties.FORMAT_VERSION, "2").commit(); + + RewriteDataFilesSparkAction basicRewrite = basicRewrite(table).binPack(); + List expected = + toGroupStream(table, basicRewrite) + .mapToLong(RewriteFileGroup::numFiles) + .boxed() + .collect(Collectors.toList()); + + RewriteDataFilesSparkAction jobOrderRewrite = + basicRewrite(table) + .option(RewriteDataFiles.REWRITE_JOB_ORDER, RewriteJobOrder.FILES_ASC.orderName()) + .binPack(); + List actual = + toGroupStream(table, jobOrderRewrite) + .mapToLong(RewriteFileGroup::numFiles) + .boxed() + .collect(Collectors.toList()); + + expected.sort(Comparator.naturalOrder()); + Assert.assertEquals("Number of files order should be ascending", actual, expected); + Collections.reverse(expected); + Assert.assertNotEquals("Number of files order should not be descending", actual, expected); + } + + @Test + public void testRewriteJobOrderFilesDesc() { + Table table = createTablePartitioned(4, 2); + writeRecords(1, SCALE, 1); + writeRecords(2, SCALE, 2); + writeRecords(3, SCALE, 3); + writeRecords(4, SCALE, 4); + table.updateProperties().set(TableProperties.FORMAT_VERSION, "2").commit(); + + RewriteDataFilesSparkAction basicRewrite = basicRewrite(table).binPack(); + List expected = + toGroupStream(table, basicRewrite) + .mapToLong(RewriteFileGroup::numFiles) + .boxed() + .collect(Collectors.toList()); + + RewriteDataFilesSparkAction jobOrderRewrite = + basicRewrite(table) + .option(RewriteDataFiles.REWRITE_JOB_ORDER, RewriteJobOrder.FILES_DESC.orderName()) + .binPack(); + List actual = + toGroupStream(table, jobOrderRewrite) + .mapToLong(RewriteFileGroup::numFiles) + .boxed() + .collect(Collectors.toList()); + + expected.sort(Comparator.reverseOrder()); + Assert.assertEquals("Number of files order should be descending", actual, expected); + Collections.reverse(expected); + Assert.assertNotEquals("Number of files order should not be ascending", actual, expected); + } + + private Stream toGroupStream(Table table, RewriteDataFilesSparkAction rewrite) { + rewrite.validateAndInitOptions(); + Map>> fileGroupsByPartition = + rewrite.planFileGroups(table.currentSnapshot().snapshotId()); + + return rewrite.toGroupStream( + new RewriteExecutionContext(fileGroupsByPartition), fileGroupsByPartition); + } + + protected List currentData() { + return rowsToJava( + spark.read().format("iceberg").load(tableLocation).sort("c1", "c2", 
"c3").collectAsList()); + } + + protected long testDataSize(Table table) { + return Streams.stream(table.newScan().planFiles()).mapToLong(FileScanTask::length).sum(); + } + + protected void shouldHaveMultipleFiles(Table table) { + table.refresh(); + int numFiles = Iterables.size(table.newScan().planFiles()); + Assert.assertTrue(String.format("Should have multiple files, had %d", numFiles), numFiles > 1); + } + + protected void shouldHaveFiles(Table table, int numExpected) { + table.refresh(); + int numFiles = Iterables.size(table.newScan().planFiles()); + Assert.assertEquals("Did not have the expected number of files", numExpected, numFiles); + } + + protected void shouldHaveSnapshots(Table table, int expectedSnapshots) { + table.refresh(); + int actualSnapshots = Iterables.size(table.snapshots()); + Assert.assertEquals( + "Table did not have the expected number of snapshots", expectedSnapshots, actualSnapshots); + } + + protected void shouldHaveNoOrphans(Table table) { + Assert.assertEquals( + "Should not have found any orphan files", + ImmutableList.of(), + actions() + .deleteOrphanFiles(table) + .olderThan(System.currentTimeMillis()) + .execute() + .orphanFileLocations()); + } + + protected void shouldHaveACleanCache(Table table) { + Assert.assertEquals( + "Should not have any entries in cache", ImmutableSet.of(), cacheContents(table)); + } + + protected void shouldHaveLastCommitSorted(Table table, String column) { + List, Pair>> overlappingFiles = checkForOverlappingFiles(table, column); + + Assert.assertEquals("Found overlapping files", Collections.emptyList(), overlappingFiles); + } + + protected void shouldHaveLastCommitUnsorted(Table table, String column) { + List, Pair>> overlappingFiles = checkForOverlappingFiles(table, column); + + Assert.assertNotEquals("Found no overlapping files", Collections.emptyList(), overlappingFiles); + } + + private Pair boundsOf(DataFile file, NestedField field, Class javaClass) { + int columnId = field.fieldId(); + return Pair.of( + javaClass.cast(Conversions.fromByteBuffer(field.type(), file.lowerBounds().get(columnId))), + javaClass.cast(Conversions.fromByteBuffer(field.type(), file.upperBounds().get(columnId)))); + } + + private List, Pair>> checkForOverlappingFiles( + Table table, String column) { + table.refresh(); + NestedField field = table.schema().caseInsensitiveFindField(column); + Class javaClass = (Class) field.type().typeId().javaClass(); + + Snapshot snapshot = table.currentSnapshot(); + Map> filesByPartition = + Streams.stream(snapshot.addedDataFiles(table.io())) + .collect(Collectors.groupingBy(DataFile::partition)); + + Stream, Pair>> overlaps = + filesByPartition.entrySet().stream() + .flatMap( + entry -> { + List datafiles = entry.getValue(); + Preconditions.checkArgument( + datafiles.size() > 1, + "This test is checking for overlaps in a situation where no overlaps can actually occur because the " + + "partition %s does not contain multiple datafiles", + entry.getKey()); + + List, Pair>> boundComparisons = + Lists.cartesianProduct(datafiles, datafiles).stream() + .filter(tuple -> tuple.get(0) != tuple.get(1)) + .map( + tuple -> + Pair.of( + boundsOf(tuple.get(0), field, javaClass), + boundsOf(tuple.get(1), field, javaClass))) + .collect(Collectors.toList()); + + Comparator comparator = Comparators.forType(field.type().asPrimitiveType()); + + List, Pair>> overlappingFiles = + boundComparisons.stream() + .filter( + filePair -> { + Pair left = filePair.first(); + T lMin = left.first(); + T lMax = left.second(); + Pair right = 
filePair.second(); + T rMin = right.first(); + T rMax = right.second(); + boolean boundsDoNotOverlap = + // Min and Max of a range are greater than or equal to the max + // value of the other range + (comparator.compare(rMax, lMax) >= 0 + && comparator.compare(rMin, lMax) >= 0) + || (comparator.compare(lMax, rMax) >= 0 + && comparator.compare(lMin, rMax) >= 0); + + return !boundsDoNotOverlap; + }) + .collect(Collectors.toList()); + return overlappingFiles.stream(); + }); + + return overlaps.collect(Collectors.toList()); + } + + protected Table createTable() { + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + table + .updateProperties() + .set(TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES, Integer.toString(20 * 1024)) + .commit(); + Assert.assertNull("Table must be empty", table.currentSnapshot()); + return table; + } + + /** + * Create a table with a certain number of files, returns the size of a file + * + * @param files number of files to create + * @return the created table + */ + protected Table createTable(int files) { + Table table = createTable(); + writeRecords(files, SCALE); + return table; + } + + protected Table createTablePartitioned( + int partitions, int files, int numRecords, Map options) { + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("c1").truncate("c2", 2).build(); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + Assert.assertNull("Table must be empty", table.currentSnapshot()); + + writeRecords(files, numRecords, partitions); + return table; + } + + protected Table createTablePartitioned(int partitions, int files) { + return createTablePartitioned(partitions, files, SCALE, Maps.newHashMap()); + } + + private Table createTypeTestTable() { + Schema schema = + new Schema( + required(1, "longCol", Types.LongType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "floatCol", Types.FloatType.get()), + optional(4, "doubleCol", Types.DoubleType.get()), + optional(5, "dateCol", Types.DateType.get()), + optional(6, "timestampCol", Types.TimestampType.withZone()), + optional(7, "stringCol", Types.StringType.get()), + optional(8, "booleanCol", Types.BooleanType.get()), + optional(9, "binaryCol", Types.BinaryType.get())); + + Map options = Maps.newHashMap(); + Table table = TABLES.create(schema, PartitionSpec.unpartitioned(), options, tableLocation); + + spark + .range(0, 10, 1, 10) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("dateCol", date_add(current_date(), 1)) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")) + .withColumn("booleanCol", expr("longCol > 5")) + .withColumn("binaryCol", expr("CAST(longCol AS BINARY)")) + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + return table; + } + + protected int averageFileSize(Table table) { + table.refresh(); + return (int) + Streams.stream(table.newScan().planFiles()) + .mapToLong(FileScanTask::length) + .average() + .getAsDouble(); + } + + private void writeRecords(int files, int numRecords) { + writeRecords(files, numRecords, 0); + } + + private void writeRecords(int files, int numRecords, int partitions) { + List records = Lists.newArrayList(); + int rowDimension = (int) 
Math.ceil(Math.sqrt(numRecords)); + List> data = + IntStream.range(0, rowDimension) + .boxed() + .flatMap(x -> IntStream.range(0, rowDimension).boxed().map(y -> Pair.of(x, y))) + .collect(Collectors.toList()); + Collections.shuffle(data, new Random(42)); + if (partitions > 0) { + data.forEach( + i -> + records.add( + new ThreeColumnRecord( + i.first() % partitions, "foo" + i.first(), "bar" + i.second()))); + } else { + data.forEach( + i -> + records.add(new ThreeColumnRecord(i.first(), "foo" + i.first(), "bar" + i.second()))); + } + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class).repartition(files); + writeDF(df); + } + + private void writeDF(Dataset df) { + df.select("c1", "c2", "c3") + .sortWithinPartitions("c1", "c2") + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + } + + private List writePosDeletesToFile( + Table table, DataFile dataFile, int outputDeleteFiles) { + return writePosDeletes( + table, dataFile.partition(), dataFile.path().toString(), outputDeleteFiles); + } + + private List writePosDeletes( + Table table, StructLike partition, String path, int outputDeleteFiles) { + List results = Lists.newArrayList(); + int rowPosition = 0; + for (int file = 0; file < outputDeleteFiles; file++) { + OutputFile outputFile = + table + .io() + .newOutputFile( + table.locationProvider().newDataLocation(UUID.randomUUID().toString())); + EncryptedOutputFile encryptedOutputFile = + EncryptedFiles.encryptedOutput(outputFile, EncryptionKeyMetadata.EMPTY); + + GenericAppenderFactory appenderFactory = + new GenericAppenderFactory(table.schema(), table.spec(), null, null, null); + PositionDeleteWriter posDeleteWriter = + appenderFactory + .set(TableProperties.DEFAULT_WRITE_METRICS_MODE, "full") + .newPosDeleteWriter(encryptedOutputFile, FileFormat.PARQUET, partition); + + PositionDelete posDelete = PositionDelete.create(); + posDeleteWriter.write(posDelete.set(path, rowPosition, null)); + try { + posDeleteWriter.close(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + + results.add(posDeleteWriter.toDeleteFile()); + rowPosition++; + } + + return results; + } + + private SparkActions actions() { + return SparkActions.get(); + } + + private Set cacheContents(Table table) { + return ImmutableSet.builder() + .addAll(manager.fetchSetIds(table)) + .addAll(coordinator.fetchSetIds(table)) + .build(); + } + + private double percentFilesRequired(Table table, String col, String value) { + return percentFilesRequired(table, new String[] {col}, new String[] {value}); + } + + private double percentFilesRequired(Table table, String[] cols, String[] values) { + Preconditions.checkArgument(cols.length == values.length); + Expression restriction = Expressions.alwaysTrue(); + for (int i = 0; i < cols.length; i++) { + restriction = Expressions.and(restriction, Expressions.equal(cols[i], values[i])); + } + int totalFiles = Iterables.size(table.newScan().planFiles()); + int filteredFiles = Iterables.size(table.newScan().filter(restriction).planFiles()); + return (double) filteredFiles / (double) totalFiles; + } + + class GroupInfoMatcher implements ArgumentMatcher { + private final Set groupIDs; + + GroupInfoMatcher(Integer... 
globalIndex) { + this.groupIDs = ImmutableSet.copyOf(globalIndex); + } + + @Override + public boolean matches(RewriteFileGroup argument) { + return groupIDs.contains(argument.info().globalIndex()); + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteManifestsAction.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteManifestsAction.java new file mode 100644 index 000000000000..4aafb72acef9 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteManifestsAction.java @@ -0,0 +1,583 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import static org.apache.iceberg.ValidationHelpers.dataSeqs; +import static org.apache.iceberg.ValidationHelpers.fileSeqs; +import static org.apache.iceberg.ValidationHelpers.files; +import static org.apache.iceberg.ValidationHelpers.snapshotIds; +import static org.apache.iceberg.ValidationHelpers.validateDataManifest; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.RewriteManifests; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.source.ThreeColumnRecord; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + 
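+/**
+ * Descriptive comment (summarizing the code below): tests for SparkActions#rewriteManifests. Each
+ * case runs twice, with snapshot ID inheritance enabled and disabled via the
+ * snapshotIdInheritanceEnabled parameter declared on this class.
+ */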
+@RunWith(Parameterized.class) +public class TestRewriteManifestsAction extends SparkTestBase { + + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + private static final Schema SCHEMA = + new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + + @Parameterized.Parameters(name = "snapshotIdInheritanceEnabled = {0}") + public static Object[] parameters() { + return new Object[] {"true", "false"}; + } + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private final String snapshotIdInheritanceEnabled; + private String tableLocation = null; + + public TestRewriteManifestsAction(String snapshotIdInheritanceEnabled) { + this.snapshotIdInheritanceEnabled = snapshotIdInheritanceEnabled; + } + + @Before + public void setupTableLocation() throws Exception { + File tableDir = temp.newFolder(); + this.tableLocation = tableDir.toURI().toString(); + } + + @Test + public void testRewriteManifestsEmptyTable() throws IOException { + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + Assert.assertNull("Table must be empty", table.currentSnapshot()); + + SparkActions actions = SparkActions.get(); + + actions + .rewriteManifests(table) + .rewriteIf(manifest -> true) + .stagingLocation(temp.newFolder().toString()) + .execute(); + + Assert.assertNull("Table must stay empty", table.currentSnapshot()); + } + + @Test + public void testRewriteSmallManifestsNonPartitionedTable() { + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + List records1 = + Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), new ThreeColumnRecord(1, "BBBBBBBBBB", "BBBB")); + writeRecords(records1); + + List records2 = + Lists.newArrayList( + new ThreeColumnRecord(2, "CCCCCCCCCC", "CCCC"), + new ThreeColumnRecord(2, "DDDDDDDDDD", "DDDD")); + writeRecords(records2); + + table.refresh(); + + List manifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 2 manifests before rewrite", 2, manifests.size()); + + SparkActions actions = SparkActions.get(); + + RewriteManifests.Result result = + actions.rewriteManifests(table).rewriteIf(manifest -> true).execute(); + + Assert.assertEquals( + "Action should rewrite 2 manifests", 2, Iterables.size(result.rewrittenManifests())); + Assert.assertEquals( + "Action should add 1 manifests", 1, Iterables.size(result.addedManifests())); + + table.refresh(); + + List newManifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 1 manifests after rewrite", 1, newManifests.size()); + + Assert.assertEquals(4, (long) newManifests.get(0).existingFilesCount()); + Assert.assertFalse(newManifests.get(0).hasAddedFiles()); + Assert.assertFalse(newManifests.get(0).hasDeletedFiles()); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(records1); + expectedRecords.addAll(records2); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.sort("c1", "c2").as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + + 
Assert.assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testRewriteManifestsWithCommitStateUnknownException() { + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + List records1 = + Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), new ThreeColumnRecord(1, "BBBBBBBBBB", "BBBB")); + writeRecords(records1); + + List records2 = + Lists.newArrayList( + new ThreeColumnRecord(2, "CCCCCCCCCC", "CCCC"), + new ThreeColumnRecord(2, "DDDDDDDDDD", "DDDD")); + writeRecords(records2); + + table.refresh(); + + List manifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 2 manifests before rewrite", 2, manifests.size()); + + SparkActions actions = SparkActions.get(); + + // create a spy which would throw a CommitStateUnknownException after successful commit. + org.apache.iceberg.RewriteManifests newRewriteManifests = table.rewriteManifests(); + org.apache.iceberg.RewriteManifests spyNewRewriteManifests = spy(newRewriteManifests); + doAnswer( + invocation -> { + newRewriteManifests.commit(); + throw new CommitStateUnknownException(new RuntimeException("Datacenter on Fire")); + }) + .when(spyNewRewriteManifests) + .commit(); + + Table spyTable = spy(table); + when(spyTable.rewriteManifests()).thenReturn(spyNewRewriteManifests); + + AssertHelpers.assertThrowsCause( + "Should throw a Commit State Unknown Exception", + RuntimeException.class, + "Datacenter on Fire", + () -> actions.rewriteManifests(spyTable).rewriteIf(manifest -> true).execute()); + + table.refresh(); + + // table should reflect the changes, since the commit was successful + List newManifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 1 manifests after rewrite", 1, newManifests.size()); + + Assert.assertEquals(4, (long) newManifests.get(0).existingFilesCount()); + Assert.assertFalse(newManifests.get(0).hasAddedFiles()); + Assert.assertFalse(newManifests.get(0).hasDeletedFiles()); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(records1); + expectedRecords.addAll(records2); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.sort("c1", "c2").as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + + Assert.assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testRewriteSmallManifestsPartitionedTable() { + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("c1").truncate("c2", 2).build(); + Map options = Maps.newHashMap(); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + List records1 = + Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), new ThreeColumnRecord(1, "BBBBBBBBBB", "BBBB")); + writeRecords(records1); + + List records2 = + Lists.newArrayList( + new ThreeColumnRecord(2, "CCCCCCCCCC", "CCCC"), + new ThreeColumnRecord(2, "DDDDDDDDDD", "DDDD")); + writeRecords(records2); + + List records3 = + Lists.newArrayList( + new ThreeColumnRecord(3, "EEEEEEEEEE", "EEEE"), + new ThreeColumnRecord(3, "FFFFFFFFFF", "FFFF")); + writeRecords(records3); + + List records4 = + Lists.newArrayList( + new ThreeColumnRecord(4, "GGGGGGGGGG", "GGGG"), + new 
ThreeColumnRecord(4, "HHHHHHHHHG", "HHHH")); + writeRecords(records4); + + table.refresh(); + + List manifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 4 manifests before rewrite", 4, manifests.size()); + + SparkActions actions = SparkActions.get(); + + // we will expect to have 2 manifests with 4 entries in each after rewrite + long manifestEntrySizeBytes = computeManifestEntrySizeBytes(manifests); + long targetManifestSizeBytes = (long) (1.05 * 4 * manifestEntrySizeBytes); + + table + .updateProperties() + .set(TableProperties.MANIFEST_TARGET_SIZE_BYTES, String.valueOf(targetManifestSizeBytes)) + .commit(); + + RewriteManifests.Result result = + actions.rewriteManifests(table).rewriteIf(manifest -> true).execute(); + + Assert.assertEquals( + "Action should rewrite 4 manifests", 4, Iterables.size(result.rewrittenManifests())); + Assert.assertEquals( + "Action should add 2 manifests", 2, Iterables.size(result.addedManifests())); + + table.refresh(); + + List newManifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 2 manifests after rewrite", 2, newManifests.size()); + + Assert.assertEquals(4, (long) newManifests.get(0).existingFilesCount()); + Assert.assertFalse(newManifests.get(0).hasAddedFiles()); + Assert.assertFalse(newManifests.get(0).hasDeletedFiles()); + + Assert.assertEquals(4, (long) newManifests.get(1).existingFilesCount()); + Assert.assertFalse(newManifests.get(1).hasAddedFiles()); + Assert.assertFalse(newManifests.get(1).hasDeletedFiles()); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(records1); + expectedRecords.addAll(records2); + expectedRecords.addAll(records3); + expectedRecords.addAll(records4); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.sort("c1", "c2").as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + + Assert.assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testRewriteImportedManifests() throws IOException { + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("c3").build(); + Map options = Maps.newHashMap(); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + List records = + Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), new ThreeColumnRecord(1, "BBBBBBBBBB", "BBBB")); + File parquetTableDir = temp.newFolder("parquet_table"); + String parquetTableLocation = parquetTableDir.toURI().toString(); + + try { + Dataset inputDF = spark.createDataFrame(records, ThreeColumnRecord.class); + inputDF + .select("c1", "c2", "c3") + .write() + .format("parquet") + .mode("overwrite") + .option("path", parquetTableLocation) + .partitionBy("c3") + .saveAsTable("parquet_table"); + + File stagingDir = temp.newFolder("staging-dir"); + SparkTableUtil.importSparkTable( + spark, new TableIdentifier("parquet_table"), table, stagingDir.toString()); + + // add some more data to create more than one manifest for the rewrite + inputDF.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation); + table.refresh(); + + Snapshot snapshot = table.currentSnapshot(); + + SparkActions actions = SparkActions.get(); + + RewriteManifests.Result result = + actions + .rewriteManifests(table) + .rewriteIf(manifest -> true) + .stagingLocation(temp.newFolder().toString()) + .execute(); + + Assert.assertEquals( 
+ "Action should rewrite all manifests", + snapshot.allManifests(table.io()), + result.rewrittenManifests()); + Assert.assertEquals( + "Action should add 1 manifest", 1, Iterables.size(result.addedManifests())); + + } finally { + spark.sql("DROP TABLE parquet_table"); + } + } + + @Test + public void testRewriteLargeManifestsPartitionedTable() throws IOException { + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("c3").build(); + Map options = Maps.newHashMap(); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + // all records belong to the same partition + List records = Lists.newArrayList(); + for (int i = 0; i < 50; i++) { + records.add(new ThreeColumnRecord(i, String.valueOf(i), "0")); + } + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + // repartition to create separate files + writeDF(df.repartition(50, df.col("c1"))); + + table.refresh(); + + List manifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 1 manifests before rewrite", 1, manifests.size()); + + // set the target manifest size to a small value to force splitting records into multiple files + table + .updateProperties() + .set( + TableProperties.MANIFEST_TARGET_SIZE_BYTES, + String.valueOf(manifests.get(0).length() / 2)) + .commit(); + + SparkActions actions = SparkActions.get(); + + RewriteManifests.Result result = + actions + .rewriteManifests(table) + .rewriteIf(manifest -> true) + .stagingLocation(temp.newFolder().toString()) + .execute(); + + Assert.assertEquals( + "Action should rewrite 1 manifest", 1, Iterables.size(result.rewrittenManifests())); + Assert.assertEquals( + "Action should add 2 manifests", 2, Iterables.size(result.addedManifests())); + + table.refresh(); + + List newManifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 2 manifests after rewrite", 2, newManifests.size()); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.sort("c1", "c2").as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + + Assert.assertEquals("Rows must match", records, actualRecords); + } + + @Test + public void testRewriteManifestsWithPredicate() throws IOException { + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("c1").truncate("c2", 2).build(); + Map options = Maps.newHashMap(); + options.put(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, snapshotIdInheritanceEnabled); + Table table = TABLES.create(SCHEMA, spec, options, tableLocation); + + List records1 = + Lists.newArrayList( + new ThreeColumnRecord(1, null, "AAAA"), new ThreeColumnRecord(1, "BBBBBBBBBB", "BBBB")); + writeRecords(records1); + + writeRecords(records1); + + List records2 = + Lists.newArrayList( + new ThreeColumnRecord(2, "CCCCCCCCCC", "CCCC"), + new ThreeColumnRecord(2, "DDDDDDDDDD", "DDDD")); + writeRecords(records2); + + table.refresh(); + + List manifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 3 manifests before rewrite", 3, manifests.size()); + + SparkActions actions = SparkActions.get(); + + // rewrite only the first manifest without caching + RewriteManifests.Result result = + actions + .rewriteManifests(table) + .rewriteIf( + manifest -> + (manifest.path().equals(manifests.get(0).path()) + || (manifest.path().equals(manifests.get(1).path())))) + .stagingLocation(temp.newFolder().toString()) 
+ .option("use-caching", "false") + .execute(); + + Assert.assertEquals( + "Action should rewrite 2 manifest", 2, Iterables.size(result.rewrittenManifests())); + Assert.assertEquals( + "Action should add 1 manifests", 1, Iterables.size(result.addedManifests())); + + table.refresh(); + + List newManifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 2 manifests after rewrite", 2, newManifests.size()); + + Assert.assertFalse("First manifest must be rewritten", newManifests.contains(manifests.get(0))); + Assert.assertFalse( + "Second manifest must be rewritten", newManifests.contains(manifests.get(1))); + Assert.assertTrue( + "Third manifest must not be rewritten", newManifests.contains(manifests.get(2))); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.add(records1.get(0)); + expectedRecords.add(records1.get(0)); + expectedRecords.add(records1.get(1)); + expectedRecords.add(records1.get(1)); + expectedRecords.addAll(records2); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.sort("c1", "c2").as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + + Assert.assertEquals("Rows must match", expectedRecords, actualRecords); + } + + @Test + public void testRewriteSmallManifestsNonPartitionedV2Table() { + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map properties = ImmutableMap.of(TableProperties.FORMAT_VERSION, "2"); + Table table = TABLES.create(SCHEMA, spec, properties, tableLocation); + + List records1 = Lists.newArrayList(new ThreeColumnRecord(1, null, "AAAA")); + writeRecords(records1); + + table.refresh(); + + Snapshot snapshot1 = table.currentSnapshot(); + DataFile file1 = Iterables.getOnlyElement(snapshot1.addedDataFiles(table.io())); + + List records2 = Lists.newArrayList(new ThreeColumnRecord(2, "CCCC", "CCCC")); + writeRecords(records2); + + table.refresh(); + + Snapshot snapshot2 = table.currentSnapshot(); + DataFile file2 = Iterables.getOnlyElement(snapshot2.addedDataFiles(table.io())); + + List manifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 2 manifests before rewrite", 2, manifests.size()); + + SparkActions actions = SparkActions.get(); + RewriteManifests.Result result = actions.rewriteManifests(table).execute(); + Assert.assertEquals( + "Action should rewrite 2 manifests", 2, Iterables.size(result.rewrittenManifests())); + Assert.assertEquals( + "Action should add 1 manifests", 1, Iterables.size(result.addedManifests())); + + table.refresh(); + + List newManifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 1 manifests after rewrite", 1, newManifests.size()); + + ManifestFile newManifest = Iterables.getOnlyElement(newManifests); + Assert.assertEquals(2, (long) newManifest.existingFilesCount()); + Assert.assertFalse(newManifest.hasAddedFiles()); + Assert.assertFalse(newManifest.hasDeletedFiles()); + + validateDataManifest( + table, + newManifest, + dataSeqs(1L, 2L), + fileSeqs(1L, 2L), + snapshotIds(snapshot1.snapshotId(), snapshot2.snapshotId()), + files(file1, file2)); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(records1); + expectedRecords.addAll(records2); + + Dataset resultDF = spark.read().format("iceberg").load(tableLocation); + List actualRecords = + resultDF.sort("c1", "c2").as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + + Assert.assertEquals("Rows must match", expectedRecords, actualRecords); + } + + 
private void writeRecords(List records) { + Dataset df = spark.createDataFrame(records, ThreeColumnRecord.class); + writeDF(df); + } + + private void writeDF(Dataset df) { + df.select("c1", "c2", "c3") + .write() + .format("iceberg") + .option(SparkWriteOptions.DISTRIBUTION_MODE, TableProperties.WRITE_DISTRIBUTION_MODE_NONE) + .mode("append") + .save(tableLocation); + } + + private long computeManifestEntrySizeBytes(List manifests) { + long totalSize = 0L; + int numEntries = 0; + + for (ManifestFile manifest : manifests) { + totalSize += manifest.length(); + numEntries += + manifest.addedFilesCount() + manifest.existingFilesCount() + manifest.deletedFilesCount(); + } + + return totalSize / numEntries; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestSparkFileRewriter.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestSparkFileRewriter.java new file mode 100644 index 000000000000..6800ffd404ea --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestSparkFileRewriter.java @@ -0,0 +1,396 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.MockFileScanTask; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.actions.SizeBasedDataRewriter; +import org.apache.iceberg.actions.SizeBasedFileRewriter; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.types.Types.IntegerType; +import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.types.Types.StringType; +import org.assertj.core.api.Assertions; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestSparkFileRewriter extends SparkTestBase { + + private static final TableIdentifier TABLE_IDENT = TableIdentifier.of("default", "tbl"); + private static final Schema SCHEMA = + new Schema( + NestedField.required(1, "id", IntegerType.get()), + NestedField.required(2, "dep", StringType.get())); + private static final PartitionSpec SPEC = + PartitionSpec.builderFor(SCHEMA).identity("dep").build(); + private static final SortOrder SORT_ORDER = SortOrder.builderFor(SCHEMA).asc("id").build(); + + @After + public void removeTable() { + catalog.dropTable(TABLE_IDENT); + } + + @Test + public void testBinPackDataSelectFiles() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA); + SparkBinPackDataRewriter rewriter = new SparkBinPackDataRewriter(spark, table); + + checkDataFileSizeFiltering(rewriter); + checkDataFilesDeleteThreshold(rewriter); + checkDataFileGroupWithEnoughFiles(rewriter); + checkDataFileGroupWithEnoughData(rewriter); + checkDataFileGroupWithTooMuchData(rewriter); + } + + @Test + public void testSortDataSelectFiles() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA); + SparkSortDataRewriter rewriter = new SparkSortDataRewriter(spark, table, SORT_ORDER); + + checkDataFileSizeFiltering(rewriter); + checkDataFilesDeleteThreshold(rewriter); + checkDataFileGroupWithEnoughFiles(rewriter); + checkDataFileGroupWithEnoughData(rewriter); + checkDataFileGroupWithTooMuchData(rewriter); + } + + @Test + public void testZOrderDataSelectFiles() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA); + ImmutableList zOrderCols = ImmutableList.of("id"); + SparkZOrderDataRewriter rewriter = new SparkZOrderDataRewriter(spark, table, zOrderCols); + + checkDataFileSizeFiltering(rewriter); + checkDataFilesDeleteThreshold(rewriter); + checkDataFileGroupWithEnoughFiles(rewriter); + checkDataFileGroupWithEnoughData(rewriter); + checkDataFileGroupWithTooMuchData(rewriter); + } + + private void checkDataFileSizeFiltering(SizeBasedDataRewriter rewriter) { + FileScanTask tooSmallTask = new MockFileScanTask(100L); + FileScanTask optimal = new MockFileScanTask(450); + FileScanTask tooBigTask = new MockFileScanTask(1000L); + List tasks = ImmutableList.of(tooSmallTask, optimal, tooBigTask); + + Map options = + ImmutableMap.of( + SizeBasedDataRewriter.MIN_FILE_SIZE_BYTES, "250", + SizeBasedDataRewriter.TARGET_FILE_SIZE_BYTES, "500", + SizeBasedDataRewriter.MAX_FILE_SIZE_BYTES, 
"750", + SizeBasedDataRewriter.DELETE_FILE_THRESHOLD, String.valueOf(Integer.MAX_VALUE)); + rewriter.init(options); + + Iterable> groups = rewriter.planFileGroups(tasks); + Assert.assertEquals("Must have 1 group", 1, Iterables.size(groups)); + List group = Iterables.getOnlyElement(groups); + Assert.assertEquals("Must rewrite 2 files", 2, group.size()); + } + + private void checkDataFilesDeleteThreshold(SizeBasedDataRewriter rewriter) { + FileScanTask tooManyDeletesTask = MockFileScanTask.mockTaskWithDeletes(1000L, 3); + FileScanTask optimalTask = MockFileScanTask.mockTaskWithDeletes(1000L, 1); + List tasks = ImmutableList.of(tooManyDeletesTask, optimalTask); + + Map options = + ImmutableMap.of( + SizeBasedDataRewriter.MIN_FILE_SIZE_BYTES, "1", + SizeBasedDataRewriter.TARGET_FILE_SIZE_BYTES, "2000", + SizeBasedDataRewriter.MAX_FILE_SIZE_BYTES, "5000", + SizeBasedDataRewriter.DELETE_FILE_THRESHOLD, "2"); + rewriter.init(options); + + Iterable> groups = rewriter.planFileGroups(tasks); + Assert.assertEquals("Must have 1 group", 1, Iterables.size(groups)); + List group = Iterables.getOnlyElement(groups); + Assert.assertEquals("Must rewrite 1 file", 1, group.size()); + } + + private void checkDataFileGroupWithEnoughFiles(SizeBasedDataRewriter rewriter) { + List tasks = + ImmutableList.of( + new MockFileScanTask(100L), + new MockFileScanTask(100L), + new MockFileScanTask(100L), + new MockFileScanTask(100L)); + + Map options = + ImmutableMap.of( + SizeBasedDataRewriter.MIN_INPUT_FILES, "3", + SizeBasedDataRewriter.MIN_FILE_SIZE_BYTES, "150", + SizeBasedDataRewriter.TARGET_FILE_SIZE_BYTES, "1000", + SizeBasedDataRewriter.MAX_FILE_SIZE_BYTES, "5000", + SizeBasedDataRewriter.DELETE_FILE_THRESHOLD, String.valueOf(Integer.MAX_VALUE)); + rewriter.init(options); + + Iterable> groups = rewriter.planFileGroups(tasks); + Assert.assertEquals("Must have 1 group", 1, Iterables.size(groups)); + List group = Iterables.getOnlyElement(groups); + Assert.assertEquals("Must rewrite 4 files", 4, group.size()); + } + + private void checkDataFileGroupWithEnoughData(SizeBasedDataRewriter rewriter) { + List tasks = + ImmutableList.of( + new MockFileScanTask(100L), new MockFileScanTask(100L), new MockFileScanTask(100L)); + + Map options = + ImmutableMap.of( + SizeBasedDataRewriter.MIN_INPUT_FILES, "5", + SizeBasedDataRewriter.MIN_FILE_SIZE_BYTES, "200", + SizeBasedDataRewriter.TARGET_FILE_SIZE_BYTES, "250", + SizeBasedDataRewriter.MAX_FILE_SIZE_BYTES, "500", + SizeBasedDataRewriter.DELETE_FILE_THRESHOLD, String.valueOf(Integer.MAX_VALUE)); + rewriter.init(options); + + Iterable> groups = rewriter.planFileGroups(tasks); + Assert.assertEquals("Must have 1 group", 1, Iterables.size(groups)); + List group = Iterables.getOnlyElement(groups); + Assert.assertEquals("Must rewrite 3 files", 3, group.size()); + } + + private void checkDataFileGroupWithTooMuchData(SizeBasedDataRewriter rewriter) { + List tasks = ImmutableList.of(new MockFileScanTask(2000L)); + + Map options = + ImmutableMap.of( + SizeBasedDataRewriter.MIN_INPUT_FILES, "5", + SizeBasedDataRewriter.MIN_FILE_SIZE_BYTES, "200", + SizeBasedDataRewriter.TARGET_FILE_SIZE_BYTES, "250", + SizeBasedDataRewriter.MAX_FILE_SIZE_BYTES, "500", + SizeBasedDataRewriter.DELETE_FILE_THRESHOLD, String.valueOf(Integer.MAX_VALUE)); + rewriter.init(options); + + Iterable> groups = rewriter.planFileGroups(tasks); + Assert.assertEquals("Must have 1 group", 1, Iterables.size(groups)); + List group = Iterables.getOnlyElement(groups); + Assert.assertEquals("Must rewrite big file", 1, 
group.size()); + } + + @Test + public void testInvalidConstructorUsagesSortData() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA); + + Assertions.assertThatThrownBy(() -> new SparkSortDataRewriter(spark, table)) + .hasMessageContaining("Cannot sort data without a valid sort order") + .hasMessageContaining("is unsorted and no sort order is provided"); + + Assertions.assertThatThrownBy(() -> new SparkSortDataRewriter(spark, table, null)) + .hasMessageContaining("Cannot sort data without a valid sort order") + .hasMessageContaining("the provided sort order is null or empty"); + + Assertions.assertThatThrownBy( + () -> new SparkSortDataRewriter(spark, table, SortOrder.unsorted())) + .hasMessageContaining("Cannot sort data without a valid sort order") + .hasMessageContaining("the provided sort order is null or empty"); + } + + @Test + public void testInvalidConstructorUsagesZOrderData() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA, SPEC); + + Assertions.assertThatThrownBy(() -> new SparkZOrderDataRewriter(spark, table, null)) + .hasMessageContaining("Cannot ZOrder when no columns are specified"); + + Assertions.assertThatThrownBy( + () -> new SparkZOrderDataRewriter(spark, table, ImmutableList.of())) + .hasMessageContaining("Cannot ZOrder when no columns are specified"); + + Assertions.assertThatThrownBy( + () -> new SparkZOrderDataRewriter(spark, table, ImmutableList.of("dep"))) + .hasMessageContaining("Cannot ZOrder") + .hasMessageContaining("all columns provided were identity partition columns"); + + Assertions.assertThatThrownBy( + () -> new SparkZOrderDataRewriter(spark, table, ImmutableList.of("DeP"))) + .hasMessageContaining("Cannot ZOrder") + .hasMessageContaining("all columns provided were identity partition columns"); + } + + @Test + public void testBinPackDataValidOptions() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA); + SparkBinPackDataRewriter rewriter = new SparkBinPackDataRewriter(spark, table); + + Assert.assertEquals( + "Rewriter must report all supported options", + ImmutableSet.of( + SparkBinPackDataRewriter.TARGET_FILE_SIZE_BYTES, + SparkBinPackDataRewriter.MIN_FILE_SIZE_BYTES, + SparkBinPackDataRewriter.MAX_FILE_SIZE_BYTES, + SparkBinPackDataRewriter.MIN_INPUT_FILES, + SparkBinPackDataRewriter.REWRITE_ALL, + SparkBinPackDataRewriter.MAX_FILE_GROUP_SIZE_BYTES, + SparkBinPackDataRewriter.DELETE_FILE_THRESHOLD), + rewriter.validOptions()); + } + + @Test + public void testSortDataValidOptions() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA); + SparkSortDataRewriter rewriter = new SparkSortDataRewriter(spark, table, SORT_ORDER); + + Assert.assertEquals( + "Rewriter must report all supported options", + ImmutableSet.of( + SparkSortDataRewriter.TARGET_FILE_SIZE_BYTES, + SparkSortDataRewriter.MIN_FILE_SIZE_BYTES, + SparkSortDataRewriter.MAX_FILE_SIZE_BYTES, + SparkSortDataRewriter.MIN_INPUT_FILES, + SparkSortDataRewriter.REWRITE_ALL, + SparkSortDataRewriter.MAX_FILE_GROUP_SIZE_BYTES, + SparkSortDataRewriter.DELETE_FILE_THRESHOLD, + SparkSortDataRewriter.COMPRESSION_FACTOR), + rewriter.validOptions()); + } + + @Test + public void testZOrderDataValidOptions() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA); + ImmutableList<String> zOrderCols = ImmutableList.of("id"); + SparkZOrderDataRewriter rewriter = new SparkZOrderDataRewriter(spark, table, zOrderCols); + + Assert.assertEquals( + "Rewriter must report all supported options", + ImmutableSet.of( + SparkZOrderDataRewriter.TARGET_FILE_SIZE_BYTES, +
SparkZOrderDataRewriter.MIN_FILE_SIZE_BYTES, + SparkZOrderDataRewriter.MAX_FILE_SIZE_BYTES, + SparkZOrderDataRewriter.MIN_INPUT_FILES, + SparkZOrderDataRewriter.REWRITE_ALL, + SparkZOrderDataRewriter.MAX_FILE_GROUP_SIZE_BYTES, + SparkZOrderDataRewriter.DELETE_FILE_THRESHOLD, + SparkZOrderDataRewriter.COMPRESSION_FACTOR, + SparkZOrderDataRewriter.MAX_OUTPUT_SIZE, + SparkZOrderDataRewriter.VAR_LENGTH_CONTRIBUTION), + rewriter.validOptions()); + } + + @Test + public void testInvalidValuesForBinPackDataOptions() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA); + SparkBinPackDataRewriter rewriter = new SparkBinPackDataRewriter(spark, table); + + validateSizeBasedRewriterOptions(rewriter); + + Map<String, String> invalidDeleteThresholdOptions = + ImmutableMap.of(SizeBasedDataRewriter.DELETE_FILE_THRESHOLD, "-1"); + Assertions.assertThatThrownBy(() -> rewriter.init(invalidDeleteThresholdOptions)) + .hasMessageContaining("'delete-file-threshold' is set to -1 but must be >= 0"); + } + + @Test + public void testInvalidValuesForSortDataOptions() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA); + SparkSortDataRewriter rewriter = new SparkSortDataRewriter(spark, table, SORT_ORDER); + + validateSizeBasedRewriterOptions(rewriter); + + Map<String, String> invalidDeleteThresholdOptions = + ImmutableMap.of(SizeBasedDataRewriter.DELETE_FILE_THRESHOLD, "-1"); + Assertions.assertThatThrownBy(() -> rewriter.init(invalidDeleteThresholdOptions)) + .hasMessageContaining("'delete-file-threshold' is set to -1 but must be >= 0"); + + Map<String, String> invalidCompressionFactorOptions = + ImmutableMap.of(SparkShufflingDataRewriter.COMPRESSION_FACTOR, "0"); + Assertions.assertThatThrownBy(() -> rewriter.init(invalidCompressionFactorOptions)) + .hasMessageContaining("'compression-factor' is set to 0.0 but must be > 0"); + } + + @Test + public void testInvalidValuesForZOrderDataOptions() { + Table table = catalog.createTable(TABLE_IDENT, SCHEMA); + ImmutableList<String> zOrderCols = ImmutableList.of("id"); + SparkZOrderDataRewriter rewriter = new SparkZOrderDataRewriter(spark, table, zOrderCols); + + validateSizeBasedRewriterOptions(rewriter); + + Map<String, String> invalidDeleteThresholdOptions = + ImmutableMap.of(SizeBasedDataRewriter.DELETE_FILE_THRESHOLD, "-1"); + Assertions.assertThatThrownBy(() -> rewriter.init(invalidDeleteThresholdOptions)) + .hasMessageContaining("'delete-file-threshold' is set to -1 but must be >= 0"); + + Map<String, String> invalidCompressionFactorOptions = + ImmutableMap.of(SparkShufflingDataRewriter.COMPRESSION_FACTOR, "0"); + Assertions.assertThatThrownBy(() -> rewriter.init(invalidCompressionFactorOptions)) + .hasMessageContaining("'compression-factor' is set to 0.0 but must be > 0"); + + Map<String, String> invalidMaxOutputOptions = + ImmutableMap.of(SparkZOrderDataRewriter.MAX_OUTPUT_SIZE, "0"); + Assertions.assertThatThrownBy(() -> rewriter.init(invalidMaxOutputOptions)) + .hasMessageContaining("Cannot have the interleaved ZOrder value use less than 1 byte") + .hasMessageContaining("'max-output-size' was set to 0"); + + Map<String, String> invalidVarLengthContributionOptions = + ImmutableMap.of(SparkZOrderDataRewriter.VAR_LENGTH_CONTRIBUTION, "0"); + Assertions.assertThatThrownBy(() -> rewriter.init(invalidVarLengthContributionOptions)) + .hasMessageContaining("Cannot use less than 1 byte for variable length types with ZOrder") + .hasMessageContaining("'var-length-contribution' was set to 0"); + } + + private void validateSizeBasedRewriterOptions(SizeBasedFileRewriter<?, ?> rewriter) { + Map<String, String> invalidTargetSizeOptions = + ImmutableMap.of(SizeBasedFileRewriter.TARGET_FILE_SIZE_BYTES,
"0"); + Assertions.assertThatThrownBy(() -> rewriter.init(invalidTargetSizeOptions)) + .hasMessageContaining("'target-file-size-bytes' is set to 0 but must be > 0"); + + Map invalidMinSizeOptions = + ImmutableMap.of(SizeBasedFileRewriter.MIN_FILE_SIZE_BYTES, "-1"); + Assertions.assertThatThrownBy(() -> rewriter.init(invalidMinSizeOptions)) + .hasMessageContaining("'min-file-size-bytes' is set to -1 but must be >= 0"); + + Map invalidTargetMinSizeOptions = + ImmutableMap.of( + SizeBasedFileRewriter.TARGET_FILE_SIZE_BYTES, "3", + SizeBasedFileRewriter.MIN_FILE_SIZE_BYTES, "5"); + Assertions.assertThatThrownBy(() -> rewriter.init(invalidTargetMinSizeOptions)) + .hasMessageContaining("'target-file-size-bytes' (3) must be > 'min-file-size-bytes' (5)") + .hasMessageContaining("all new files will be smaller than the min threshold"); + + Map invalidTargetMaxSizeOptions = + ImmutableMap.of( + SizeBasedFileRewriter.TARGET_FILE_SIZE_BYTES, "5", + SizeBasedFileRewriter.MAX_FILE_SIZE_BYTES, "3"); + Assertions.assertThatThrownBy(() -> rewriter.init(invalidTargetMaxSizeOptions)) + .hasMessageContaining("'target-file-size-bytes' (5) must be < 'max-file-size-bytes' (3)") + .hasMessageContaining("all new files will be larger than the max threshold"); + + Map invalidMinInputFilesOptions = + ImmutableMap.of(SizeBasedFileRewriter.MIN_INPUT_FILES, "0"); + Assertions.assertThatThrownBy(() -> rewriter.init(invalidMinInputFilesOptions)) + .hasMessageContaining("'min-input-files' is set to 0 but must be > 0"); + + Map invalidMaxFileGroupSizeOptions = + ImmutableMap.of(SizeBasedFileRewriter.MAX_FILE_GROUP_SIZE_BYTES, "0"); + Assertions.assertThatThrownBy(() -> rewriter.init(invalidMaxFileGroupSizeOptions)) + .hasMessageContaining("'max-file-group-size-bytes' is set to 0 but must be > 0"); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java new file mode 100644 index 000000000000..5fd137c5361d --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.types.Types.ListType; +import org.apache.iceberg.types.Types.LongType; +import org.apache.iceberg.types.Types.MapType; +import org.apache.iceberg.types.Types.StructType; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public abstract class AvroDataTest { + + protected abstract void writeAndValidate(Schema schema) throws IOException; + + protected static final StructType SUPPORTED_PRIMITIVES = + StructType.of( + required(100, "id", LongType.get()), + optional(101, "data", Types.StringType.get()), + required(102, "b", Types.BooleanType.get()), + optional(103, "i", Types.IntegerType.get()), + required(104, "l", LongType.get()), + optional(105, "f", Types.FloatType.get()), + required(106, "d", Types.DoubleType.get()), + optional(107, "date", Types.DateType.get()), + required(108, "ts", Types.TimestampType.withZone()), + required(110, "s", Types.StringType.get()), + // required(111, "uuid", Types.UUIDType.get()), + required(112, "fixed", Types.FixedType.ofLength(7)), + optional(113, "bytes", Types.BinaryType.get()), + required(114, "dec_9_0", Types.DecimalType.of(9, 0)), // int encoded + required(115, "dec_11_2", Types.DecimalType.of(11, 2)), // long encoded + required(116, "dec_20_5", Types.DecimalType.of(20, 5)), // requires padding + required(117, "dec_38_10", Types.DecimalType.of(38, 10)) // Spark's maximum precision + ); + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + @Test + public void testSimpleStruct() throws IOException { + writeAndValidate(TypeUtil.assignIncreasingFreshIds(new Schema(SUPPORTED_PRIMITIVES.fields()))); + } + + @Test + public void testStructWithRequiredFields() throws IOException { + writeAndValidate( + TypeUtil.assignIncreasingFreshIds( + new Schema( + Lists.transform(SUPPORTED_PRIMITIVES.fields(), Types.NestedField::asRequired)))); + } + + @Test + public void testStructWithOptionalFields() throws IOException { + writeAndValidate( + TypeUtil.assignIncreasingFreshIds( + new Schema( + Lists.transform(SUPPORTED_PRIMITIVES.fields(), Types.NestedField::asOptional)))); + } + + @Test + public void testNestedStruct() throws IOException { + writeAndValidate( + TypeUtil.assignIncreasingFreshIds(new Schema(required(1, "struct", SUPPORTED_PRIMITIVES)))); + } + + @Test + public void testArray() throws IOException { + Schema schema = + new Schema( + required(0, "id", LongType.get()), + optional(1, "data", ListType.ofOptional(2, Types.StringType.get()))); + + writeAndValidate(schema); + } + + @Test + public void testArrayOfStructs() throws IOException { + Schema schema = + TypeUtil.assignIncreasingFreshIds( + new Schema( + required(0, "id", LongType.get()), + optional(1, "data", ListType.ofOptional(2, SUPPORTED_PRIMITIVES)))); + + writeAndValidate(schema); + } + + @Test + public void 
testMap() throws IOException { + Schema schema = + new Schema( + required(0, "id", LongType.get()), + optional( + 1, + "data", + MapType.ofOptional(2, 3, Types.StringType.get(), Types.StringType.get()))); + + writeAndValidate(schema); + } + + @Test + public void testNumericMapKey() throws IOException { + Schema schema = + new Schema( + required(0, "id", LongType.get()), + optional( + 1, "data", MapType.ofOptional(2, 3, Types.LongType.get(), Types.StringType.get()))); + + writeAndValidate(schema); + } + + @Test + public void testComplexMapKey() throws IOException { + Schema schema = + new Schema( + required(0, "id", LongType.get()), + optional( + 1, + "data", + MapType.ofOptional( + 2, + 3, + Types.StructType.of( + required(4, "i", Types.IntegerType.get()), + optional(5, "s", Types.StringType.get())), + Types.StringType.get()))); + + writeAndValidate(schema); + } + + @Test + public void testMapOfStructs() throws IOException { + Schema schema = + TypeUtil.assignIncreasingFreshIds( + new Schema( + required(0, "id", LongType.get()), + optional( + 1, + "data", + MapType.ofOptional(2, 3, Types.StringType.get(), SUPPORTED_PRIMITIVES)))); + + writeAndValidate(schema); + } + + @Test + public void testMixedTypes() throws IOException { + StructType structType = + StructType.of( + required(0, "id", LongType.get()), + optional( + 1, + "list_of_maps", + ListType.ofOptional( + 2, MapType.ofOptional(3, 4, Types.StringType.get(), SUPPORTED_PRIMITIVES))), + optional( + 5, + "map_of_lists", + MapType.ofOptional( + 6, 7, Types.StringType.get(), ListType.ofOptional(8, SUPPORTED_PRIMITIVES))), + required( + 9, + "list_of_lists", + ListType.ofOptional(10, ListType.ofOptional(11, SUPPORTED_PRIMITIVES))), + required( + 12, + "map_of_maps", + MapType.ofOptional( + 13, + 14, + Types.StringType.get(), + MapType.ofOptional(15, 16, Types.StringType.get(), SUPPORTED_PRIMITIVES))), + required( + 17, + "list_of_struct_of_nested_types", + ListType.ofOptional( + 19, + StructType.of( + Types.NestedField.required( + 20, + "m1", + MapType.ofOptional( + 21, 22, Types.StringType.get(), SUPPORTED_PRIMITIVES)), + Types.NestedField.optional( + 23, "l1", ListType.ofRequired(24, SUPPORTED_PRIMITIVES)), + Types.NestedField.required( + 25, "l2", ListType.ofRequired(26, SUPPORTED_PRIMITIVES)), + Types.NestedField.optional( + 27, + "m2", + MapType.ofOptional( + 28, 29, Types.StringType.get(), SUPPORTED_PRIMITIVES)))))); + + Schema schema = + new Schema( + TypeUtil.assignFreshIds(structType, new AtomicInteger(0)::incrementAndGet) + .asStructType() + .fields()); + + writeAndValidate(schema); + } + + @Test + public void testTimestampWithoutZone() throws IOException { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "true"), + () -> { + Schema schema = + TypeUtil.assignIncreasingFreshIds( + new Schema( + required(0, "id", LongType.get()), + optional(1, "ts_without_zone", Types.TimestampType.withoutZone()))); + + writeAndValidate(schema); + }); + } + + protected void withSQLConf(Map conf, Action action) throws IOException { + SQLConf sqlConf = SQLConf.get(); + + Map currentConfValues = Maps.newHashMap(); + conf.keySet() + .forEach( + confKey -> { + if (sqlConf.contains(confKey)) { + String currentConfValue = sqlConf.getConfString(confKey); + currentConfValues.put(confKey, currentConfValue); + } + }); + + conf.forEach( + (confKey, confValue) -> { + if (SQLConf.isStaticConfigKey(confKey)) { + throw new RuntimeException("Cannot modify the value of a static config: " + confKey); + } + 
sqlConf.setConfString(confKey, confValue); + }); + + try { + action.invoke(); + } finally { + conf.forEach( + (confKey, confValue) -> { + if (currentConfValues.containsKey(confKey)) { + sqlConf.setConfString(confKey, currentConfValues.get(confKey)); + } else { + sqlConf.unsetConf(confKey); + } + }); + } + } + + @FunctionalInterface + protected interface Action { + void invoke() throws IOException; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/GenericsHelpers.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/GenericsHelpers.java new file mode 100644 index 000000000000..a96e3b1f57f5 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/GenericsHelpers.java @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.spark.SparkSchemaUtil.convert; +import static scala.collection.JavaConverters.mapAsJavaMapConverter; +import static scala.collection.JavaConverters.seqAsJavaListConverter; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.ZoneId; +import java.time.ZoneOffset; +import java.time.temporal.ChronoUnit; +import java.util.Collection; +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import scala.collection.Seq; + +public class GenericsHelpers { + private GenericsHelpers() {} + + private static final OffsetDateTime EPOCH = Instant.ofEpochMilli(0L).atOffset(ZoneOffset.UTC); + private static final LocalDate EPOCH_DAY = EPOCH.toLocalDate(); + + public static void assertEqualsSafe(Types.StructType struct, Record expected, Row actual) { + List fields = struct.fields(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i).type(); + + Object expectedValue = expected.get(i); + Object actualValue = actual.get(i); + + assertEqualsSafe(fieldType, expectedValue, actualValue); + } + } + + private static void assertEqualsSafe( + Types.ListType list, Collection expected, List actual) { + Type elementType = 
list.elementType(); + List expectedElements = Lists.newArrayList(expected); + for (int i = 0; i < expectedElements.size(); i += 1) { + Object expectedValue = expectedElements.get(i); + Object actualValue = actual.get(i); + + assertEqualsSafe(elementType, expectedValue, actualValue); + } + } + + private static void assertEqualsSafe(Types.MapType map, Map expected, Map actual) { + Type keyType = map.keyType(); + Type valueType = map.valueType(); + Assert.assertEquals( + "Should have the same number of keys", expected.keySet().size(), actual.keySet().size()); + + for (Object expectedKey : expected.keySet()) { + Object matchingKey = null; + for (Object actualKey : actual.keySet()) { + try { + assertEqualsSafe(keyType, expectedKey, actualKey); + matchingKey = actualKey; + break; + } catch (AssertionError e) { + // failed + } + } + + Assert.assertNotNull("Should have a matching key", matchingKey); + assertEqualsSafe(valueType, expected.get(expectedKey), actual.get(matchingKey)); + } + } + + @SuppressWarnings("unchecked") + private static void assertEqualsSafe(Type type, Object expected, Object actual) { + if (expected == null && actual == null) { + return; + } + + switch (type.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case FLOAT: + case DOUBLE: + Assert.assertEquals("Primitive value should be equal to expected", expected, actual); + break; + case DATE: + Assertions.assertThat(expected) + .as("Should expect a LocalDate") + .isInstanceOf(LocalDate.class); + Assertions.assertThat(actual).as("Should be a Date").isInstanceOf(Date.class); + Assert.assertEquals( + "ISO-8601 date should be equal", expected.toString(), actual.toString()); + break; + case TIMESTAMP: + Assertions.assertThat(actual).as("Should be a Timestamp").isInstanceOf(Timestamp.class); + Timestamp ts = (Timestamp) actual; + // milliseconds from nanos has already been added by getTime + OffsetDateTime actualTs = + EPOCH.plusNanos((ts.getTime() * 1_000_000) + (ts.getNanos() % 1_000_000)); + Types.TimestampType timestampType = (Types.TimestampType) type; + if (timestampType.shouldAdjustToUTC()) { + Assertions.assertThat(expected) + .as("Should expect an OffsetDateTime") + .isInstanceOf(OffsetDateTime.class); + Assert.assertEquals("Timestamp should be equal", expected, actualTs); + } else { + Assertions.assertThat(expected) + .as("Should expect an LocalDateTime") + .isInstanceOf(LocalDateTime.class); + Assert.assertEquals("Timestamp should be equal", expected, actualTs.toLocalDateTime()); + } + break; + case STRING: + Assertions.assertThat(actual).as("Should be a String").isInstanceOf(String.class); + Assert.assertEquals("Strings should be equal", String.valueOf(expected), actual); + break; + case UUID: + Assertions.assertThat(expected).as("Should expect a UUID").isInstanceOf(UUID.class); + Assertions.assertThat(actual).as("Should be a String").isInstanceOf(String.class); + Assert.assertEquals("UUID string representation should match", expected.toString(), actual); + break; + case FIXED: + Assertions.assertThat(expected).as("Should expect a byte[]").isInstanceOf(byte[].class); + Assertions.assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + Assert.assertArrayEquals("Bytes should match", (byte[]) expected, (byte[]) actual); + break; + case BINARY: + Assertions.assertThat(expected) + .as("Should expect a ByteBuffer") + .isInstanceOf(ByteBuffer.class); + Assertions.assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + Assert.assertArrayEquals( + "Bytes should match", ((ByteBuffer) 
expected).array(), (byte[]) actual); + break; + case DECIMAL: + Assertions.assertThat(expected) + .as("Should expect a BigDecimal") + .isInstanceOf(BigDecimal.class); + Assertions.assertThat(actual).as("Should be a BigDecimal").isInstanceOf(BigDecimal.class); + Assert.assertEquals("BigDecimals should be equal", expected, actual); + break; + case STRUCT: + Assertions.assertThat(expected).as("Should expect a Record").isInstanceOf(Record.class); + Assertions.assertThat(actual).as("Should be a Row").isInstanceOf(Row.class); + assertEqualsSafe(type.asNestedType().asStructType(), (Record) expected, (Row) actual); + break; + case LIST: + Assertions.assertThat(expected) + .as("Should expect a Collection") + .isInstanceOf(Collection.class); + Assertions.assertThat(actual).as("Should be a Seq").isInstanceOf(Seq.class); + List asList = seqAsJavaListConverter((Seq) actual).asJava(); + assertEqualsSafe(type.asNestedType().asListType(), (Collection) expected, asList); + break; + case MAP: + Assertions.assertThat(expected).as("Should expect a Collection").isInstanceOf(Map.class); + Assertions.assertThat(actual) + .as("Should be a Map") + .isInstanceOf(scala.collection.Map.class); + Map asMap = + mapAsJavaMapConverter((scala.collection.Map) actual).asJava(); + assertEqualsSafe(type.asNestedType().asMapType(), (Map) expected, asMap); + break; + case TIME: + default: + throw new IllegalArgumentException("Not a supported type: " + type); + } + } + + public static void assertEqualsUnsafe( + Types.StructType struct, Record expected, InternalRow actual) { + List fields = struct.fields(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i).type(); + + Object expectedValue = expected.get(i); + Object actualValue = actual.get(i, convert(fieldType)); + + assertEqualsUnsafe(fieldType, expectedValue, actualValue); + } + } + + private static void assertEqualsUnsafe( + Types.ListType list, Collection expected, ArrayData actual) { + Type elementType = list.elementType(); + List expectedElements = Lists.newArrayList(expected); + for (int i = 0; i < expectedElements.size(); i += 1) { + Object expectedValue = expectedElements.get(i); + Object actualValue = actual.get(i, convert(elementType)); + + assertEqualsUnsafe(elementType, expectedValue, actualValue); + } + } + + private static void assertEqualsUnsafe(Types.MapType map, Map expected, MapData actual) { + Type keyType = map.keyType(); + Type valueType = map.valueType(); + + List> expectedElements = Lists.newArrayList(expected.entrySet()); + ArrayData actualKeys = actual.keyArray(); + ArrayData actualValues = actual.valueArray(); + + for (int i = 0; i < expectedElements.size(); i += 1) { + Map.Entry expectedPair = expectedElements.get(i); + Object actualKey = actualKeys.get(i, convert(keyType)); + Object actualValue = actualValues.get(i, convert(keyType)); + + assertEqualsUnsafe(keyType, expectedPair.getKey(), actualKey); + assertEqualsUnsafe(valueType, expectedPair.getValue(), actualValue); + } + } + + private static void assertEqualsUnsafe(Type type, Object expected, Object actual) { + if (expected == null && actual == null) { + return; + } + + switch (type.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case FLOAT: + case DOUBLE: + Assert.assertEquals("Primitive value should be equal to expected", expected, actual); + break; + case DATE: + Assertions.assertThat(expected) + .as("Should expect a LocalDate") + .isInstanceOf(LocalDate.class); + int expectedDays = (int) ChronoUnit.DAYS.between(EPOCH_DAY, (LocalDate) expected); + 
Assert.assertEquals("Primitive value should be equal to expected", expectedDays, actual); + break; + case TIMESTAMP: + Types.TimestampType timestampType = (Types.TimestampType) type; + if (timestampType.shouldAdjustToUTC()) { + Assertions.assertThat(expected) + .as("Should expect an OffsetDateTime") + .isInstanceOf(OffsetDateTime.class); + long expectedMicros = ChronoUnit.MICROS.between(EPOCH, (OffsetDateTime) expected); + Assert.assertEquals( + "Primitive value should be equal to expected", expectedMicros, actual); + } else { + Assertions.assertThat(expected) + .as("Should expect an LocalDateTime") + .isInstanceOf(LocalDateTime.class); + long expectedMicros = + ChronoUnit.MICROS.between(EPOCH, ((LocalDateTime) expected).atZone(ZoneId.of("UTC"))); + Assert.assertEquals( + "Primitive value should be equal to expected", expectedMicros, actual); + } + break; + case STRING: + Assertions.assertThat(actual).as("Should be a UTF8String").isInstanceOf(UTF8String.class); + Assert.assertEquals("Strings should be equal", expected, actual.toString()); + break; + case UUID: + Assertions.assertThat(expected).as("Should expect a UUID").isInstanceOf(UUID.class); + Assertions.assertThat(actual).as("Should be a UTF8String").isInstanceOf(UTF8String.class); + Assert.assertEquals( + "UUID string representation should match", expected.toString(), actual.toString()); + break; + case FIXED: + Assertions.assertThat(expected).as("Should expect a byte[]").isInstanceOf(byte[].class); + Assertions.assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + Assert.assertArrayEquals("Bytes should match", (byte[]) expected, (byte[]) actual); + break; + case BINARY: + Assertions.assertThat(expected) + .as("Should expect a ByteBuffer") + .isInstanceOf(ByteBuffer.class); + Assertions.assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + Assert.assertArrayEquals( + "Bytes should match", ((ByteBuffer) expected).array(), (byte[]) actual); + break; + case DECIMAL: + Assertions.assertThat(expected) + .as("Should expect a BigDecimal") + .isInstanceOf(BigDecimal.class); + Assertions.assertThat(actual).as("Should be a Decimal").isInstanceOf(Decimal.class); + Assert.assertEquals( + "BigDecimals should be equal", expected, ((Decimal) actual).toJavaBigDecimal()); + break; + case STRUCT: + Assertions.assertThat(expected).as("Should expect a Record").isInstanceOf(Record.class); + Assertions.assertThat(actual) + .as("Should be an InternalRow") + .isInstanceOf(InternalRow.class); + assertEqualsUnsafe( + type.asNestedType().asStructType(), (Record) expected, (InternalRow) actual); + break; + case LIST: + Assertions.assertThat(expected) + .as("Should expect a Collection") + .isInstanceOf(Collection.class); + Assertions.assertThat(actual).as("Should be an ArrayData").isInstanceOf(ArrayData.class); + assertEqualsUnsafe( + type.asNestedType().asListType(), (Collection) expected, (ArrayData) actual); + break; + case MAP: + Assertions.assertThat(expected).as("Should expect a Map").isInstanceOf(Map.class); + Assertions.assertThat(actual) + .as("Should be an ArrayBasedMapData") + .isInstanceOf(MapData.class); + assertEqualsUnsafe(type.asNestedType().asMapType(), (Map) expected, (MapData) actual); + break; + case TIME: + default: + throw new IllegalArgumentException("Not a supported type: " + type); + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java new file mode 100644 index 
000000000000..1c95df8ced12 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java @@ -0,0 +1,368 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.function.Supplier; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericData.Record; +import org.apache.iceberg.Schema; +import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.RandomUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; + +public class RandomData { + + // Default percentage of number of values that are null for optional fields + public static final float DEFAULT_NULL_PERCENTAGE = 0.05f; + + private RandomData() {} + + public static List generateList(Schema schema, int numRecords, long seed) { + RandomDataGenerator generator = new RandomDataGenerator(schema, seed, DEFAULT_NULL_PERCENTAGE); + List records = Lists.newArrayListWithExpectedSize(numRecords); + for (int i = 0; i < numRecords; i += 1) { + records.add((Record) TypeUtil.visit(schema, generator)); + } + + return records; + } + + public static Iterable generateSpark(Schema schema, int numRecords, long seed) { + return () -> + new Iterator() { + private SparkRandomDataGenerator generator = new SparkRandomDataGenerator(seed); + private int count = 0; + + @Override + public boolean hasNext() { + return count < numRecords; + } + + @Override + public InternalRow next() { + if (count >= numRecords) { + throw new NoSuchElementException(); + } + count += 1; + return (InternalRow) TypeUtil.visit(schema, generator); + } + }; + } + + public static Iterable generate(Schema schema, int numRecords, long seed) { + return newIterable( + () -> new RandomDataGenerator(schema, seed, DEFAULT_NULL_PERCENTAGE), schema, numRecords); + } + + public static 
Iterable generate( + Schema schema, int numRecords, long seed, float nullPercentage) { + return newIterable( + () -> new RandomDataGenerator(schema, seed, nullPercentage), schema, numRecords); + } + + public static Iterable generateFallbackData( + Schema schema, int numRecords, long seed, long numDictRecords) { + return newIterable( + () -> new FallbackDataGenerator(schema, seed, numDictRecords), schema, numRecords); + } + + public static Iterable generateDictionaryEncodableData( + Schema schema, int numRecords, long seed, float nullPercentage) { + return newIterable( + () -> new DictionaryEncodedDataGenerator(schema, seed, nullPercentage), schema, numRecords); + } + + private static Iterable newIterable( + Supplier newGenerator, Schema schema, int numRecords) { + return () -> + new Iterator() { + private int count = 0; + private RandomDataGenerator generator = newGenerator.get(); + + @Override + public boolean hasNext() { + return count < numRecords; + } + + @Override + public Record next() { + if (count >= numRecords) { + throw new NoSuchElementException(); + } + count += 1; + return (Record) TypeUtil.visit(schema, generator); + } + }; + } + + private static class RandomDataGenerator extends TypeUtil.CustomOrderSchemaVisitor { + private final Map typeToSchema; + private final Random random; + // Percentage of number of values that are null for optional fields + private final float nullPercentage; + + private RandomDataGenerator(Schema schema, long seed, float nullPercentage) { + Preconditions.checkArgument( + 0.0f <= nullPercentage && nullPercentage <= 1.0f, + "Percentage needs to be in the range (0.0, 1.0)"); + this.nullPercentage = nullPercentage; + this.typeToSchema = AvroSchemaUtil.convertTypes(schema.asStruct(), "test"); + this.random = new Random(seed); + } + + @Override + public Record schema(Schema schema, Supplier structResult) { + return (Record) structResult.get(); + } + + @Override + public Record struct(Types.StructType struct, Iterable fieldResults) { + Record rec = new Record(typeToSchema.get(struct)); + + List values = Lists.newArrayList(fieldResults); + for (int i = 0; i < values.size(); i += 1) { + rec.put(i, values.get(i)); + } + + return rec; + } + + @Override + public Object field(Types.NestedField field, Supplier fieldResult) { + if (field.isOptional() && isNull()) { + return null; + } + return fieldResult.get(); + } + + private boolean isNull() { + return random.nextFloat() < nullPercentage; + } + + @Override + public Object list(Types.ListType list, Supplier elementResult) { + int numElements = random.nextInt(20); + + List result = Lists.newArrayListWithExpectedSize(numElements); + for (int i = 0; i < numElements; i += 1) { + if (list.isElementOptional() && isNull()) { + result.add(null); + } else { + result.add(elementResult.get()); + } + } + + return result; + } + + @Override + public Object map(Types.MapType map, Supplier keyResult, Supplier valueResult) { + int numEntries = random.nextInt(20); + + Map result = Maps.newLinkedHashMap(); + Set keySet = Sets.newHashSet(); + for (int i = 0; i < numEntries; i += 1) { + Object key = keyResult.get(); + // ensure no collisions + while (keySet.contains(key)) { + key = keyResult.get(); + } + + keySet.add(key); + + if (map.isValueOptional() && isNull()) { + result.put(key, null); + } else { + result.put(key, valueResult.get()); + } + } + + return result; + } + + @Override + public Object primitive(Type.PrimitiveType primitive) { + Object result = randomValue(primitive, random); + // For the primitives that Avro needs a 
different type than Spark, fix + // them here. + switch (primitive.typeId()) { + case FIXED: + return new GenericData.Fixed(typeToSchema.get(primitive), (byte[]) result); + case BINARY: + return ByteBuffer.wrap((byte[]) result); + case UUID: + return UUID.nameUUIDFromBytes((byte[]) result); + default: + return result; + } + } + + protected Object randomValue(Type.PrimitiveType primitive, Random rand) { + return RandomUtil.generatePrimitive(primitive, random); + } + } + + private static class SparkRandomDataGenerator extends TypeUtil.CustomOrderSchemaVisitor { + private final Random random; + + private SparkRandomDataGenerator(long seed) { + this.random = new Random(seed); + } + + @Override + public InternalRow schema(Schema schema, Supplier structResult) { + return (InternalRow) structResult.get(); + } + + @Override + public InternalRow struct(Types.StructType struct, Iterable fieldResults) { + List values = Lists.newArrayList(fieldResults); + GenericInternalRow row = new GenericInternalRow(values.size()); + for (int i = 0; i < values.size(); i += 1) { + row.update(i, values.get(i)); + } + + return row; + } + + @Override + public Object field(Types.NestedField field, Supplier fieldResult) { + // return null 5% of the time when the value is optional + if (field.isOptional() && random.nextInt(20) == 1) { + return null; + } + return fieldResult.get(); + } + + @Override + public GenericArrayData list(Types.ListType list, Supplier elementResult) { + int numElements = random.nextInt(20); + Object[] arr = new Object[numElements]; + GenericArrayData result = new GenericArrayData(arr); + + for (int i = 0; i < numElements; i += 1) { + // return null 5% of the time when the value is optional + if (list.isElementOptional() && random.nextInt(20) == 1) { + arr[i] = null; + } else { + arr[i] = elementResult.get(); + } + } + + return result; + } + + @Override + public Object map(Types.MapType map, Supplier keyResult, Supplier valueResult) { + int numEntries = random.nextInt(20); + + Object[] keysArr = new Object[numEntries]; + Object[] valuesArr = new Object[numEntries]; + GenericArrayData keys = new GenericArrayData(keysArr); + GenericArrayData values = new GenericArrayData(valuesArr); + ArrayBasedMapData result = new ArrayBasedMapData(keys, values); + + Set keySet = Sets.newHashSet(); + for (int i = 0; i < numEntries; i += 1) { + Object key = keyResult.get(); + // ensure no collisions + while (keySet.contains(key)) { + key = keyResult.get(); + } + + keySet.add(key); + + keysArr[i] = key; + // return null 5% of the time when the value is optional + if (map.isValueOptional() && random.nextInt(20) == 1) { + valuesArr[i] = null; + } else { + valuesArr[i] = valueResult.get(); + } + } + + return result; + } + + @Override + public Object primitive(Type.PrimitiveType primitive) { + Object obj = RandomUtil.generatePrimitive(primitive, random); + switch (primitive.typeId()) { + case STRING: + return UTF8String.fromString((String) obj); + case DECIMAL: + return Decimal.apply((BigDecimal) obj); + default: + return obj; + } + } + } + + private static class DictionaryEncodedDataGenerator extends RandomDataGenerator { + private DictionaryEncodedDataGenerator(Schema schema, long seed, float nullPercentage) { + super(schema, seed, nullPercentage); + } + + @Override + protected Object randomValue(Type.PrimitiveType primitive, Random random) { + return RandomUtil.generateDictionaryEncodablePrimitive(primitive, random); + } + } + + private static class FallbackDataGenerator extends RandomDataGenerator { + private final 
long dictionaryEncodedRows; + private long rowCount = 0; + + private FallbackDataGenerator(Schema schema, long seed, long numDictionaryEncoded) { + super(schema, seed, DEFAULT_NULL_PERCENTAGE); + this.dictionaryEncodedRows = numDictionaryEncoded; + } + + @Override + protected Object randomValue(Type.PrimitiveType primitive, Random rand) { + this.rowCount += 1; + if (rowCount > dictionaryEncodedRows) { + return RandomUtil.generatePrimitive(primitive, rand); + } else { + return RandomUtil.generateDictionaryEncodablePrimitive(primitive, rand); + } + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java new file mode 100644 index 000000000000..35d16d6f8588 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java @@ -0,0 +1,857 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.spark.SparkSchemaUtil.convert; +import static scala.collection.JavaConverters.mapAsJavaMapConverter; +import static scala.collection.JavaConverters.seqAsJavaListConverter; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.time.OffsetDateTime; +import java.time.ZoneOffset; +import java.time.temporal.ChronoUnit; +import java.util.Collection; +import java.util.Date; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; +import org.apache.arrow.vector.ValueVector; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericData.Record; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileContent; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableScan; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.relocated.com.google.common.collect.Streams; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.vectorized.IcebergArrowColumnVector; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.orc.storage.serde2.io.DateWritable; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import 
org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.unsafe.types.UTF8String; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import scala.collection.Seq; + +public class TestHelpers { + + private TestHelpers() {} + + public static void assertEqualsSafe(Types.StructType struct, List recs, List rows) { + Streams.forEachPair( + recs.stream(), rows.stream(), (rec, row) -> assertEqualsSafe(struct, rec, row)); + } + + public static void assertEqualsSafe(Types.StructType struct, Record rec, Row row) { + List fields = struct.fields(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i).type(); + + Object expectedValue = rec.get(i); + Object actualValue = row.get(i); + + assertEqualsSafe(fieldType, expectedValue, actualValue); + } + } + + public static void assertEqualsBatch( + Types.StructType struct, + Iterator expected, + ColumnarBatch batch, + boolean checkArrowValidityVector) { + for (int rowId = 0; rowId < batch.numRows(); rowId++) { + List fields = struct.fields(); + InternalRow row = batch.getRow(rowId); + Record rec = expected.next(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i).type(); + Object expectedValue = rec.get(i); + Object actualValue = row.isNullAt(i) ? 
null : row.get(i, convert(fieldType)); + assertEqualsUnsafe(fieldType, expectedValue, actualValue); + + if (checkArrowValidityVector) { + ColumnVector columnVector = batch.column(i); + ValueVector arrowVector = + ((IcebergArrowColumnVector) columnVector).vectorAccessor().getVector(); + Assert.assertFalse( + "Nullability doesn't match of " + columnVector.dataType(), + expectedValue == null ^ arrowVector.isNull(rowId)); + } + } + } + } + + private static void assertEqualsSafe(Types.ListType list, Collection expected, List actual) { + Type elementType = list.elementType(); + List expectedElements = Lists.newArrayList(expected); + for (int i = 0; i < expectedElements.size(); i += 1) { + Object expectedValue = expectedElements.get(i); + Object actualValue = actual.get(i); + + assertEqualsSafe(elementType, expectedValue, actualValue); + } + } + + private static void assertEqualsSafe(Types.MapType map, Map expected, Map actual) { + Type keyType = map.keyType(); + Type valueType = map.valueType(); + + for (Object expectedKey : expected.keySet()) { + Object matchingKey = null; + for (Object actualKey : actual.keySet()) { + try { + assertEqualsSafe(keyType, expectedKey, actualKey); + matchingKey = actualKey; + } catch (AssertionError e) { + // failed + } + } + + Assert.assertNotNull("Should have a matching key", matchingKey); + assertEqualsSafe(valueType, expected.get(expectedKey), actual.get(matchingKey)); + } + } + + private static final OffsetDateTime EPOCH = Instant.ofEpochMilli(0L).atOffset(ZoneOffset.UTC); + private static final LocalDate EPOCH_DAY = EPOCH.toLocalDate(); + + @SuppressWarnings("unchecked") + private static void assertEqualsSafe(Type type, Object expected, Object actual) { + if (expected == null && actual == null) { + return; + } + + switch (type.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case FLOAT: + case DOUBLE: + Assert.assertEquals("Primitive value should be equal to expected", expected, actual); + break; + case DATE: + Assertions.assertThat(expected).as("Should be an int").isInstanceOf(Integer.class); + Assertions.assertThat(actual).as("Should be a Date").isInstanceOf(Date.class); + int daysFromEpoch = (Integer) expected; + LocalDate date = ChronoUnit.DAYS.addTo(EPOCH_DAY, daysFromEpoch); + Assert.assertEquals("ISO-8601 date should be equal", date.toString(), actual.toString()); + break; + case TIMESTAMP: + Assertions.assertThat(expected).as("Should be a long").isInstanceOf(Long.class); + Assertions.assertThat(actual).as("Should be a Timestamp").isInstanceOf(Timestamp.class); + Timestamp ts = (Timestamp) actual; + // milliseconds from nanos has already been added by getTime + long tsMicros = (ts.getTime() * 1000) + ((ts.getNanos() / 1000) % 1000); + Assert.assertEquals("Timestamp micros should be equal", expected, tsMicros); + break; + case STRING: + Assertions.assertThat(actual).as("Should be a String").isInstanceOf(String.class); + Assert.assertEquals("Strings should be equal", String.valueOf(expected), actual); + break; + case UUID: + Assertions.assertThat(expected).as("Should expect a UUID").isInstanceOf(UUID.class); + Assertions.assertThat(actual).as("Should be a String").isInstanceOf(String.class); + Assert.assertEquals("UUID string representation should match", expected.toString(), actual); + break; + case FIXED: + Assertions.assertThat(expected) + .as("Should expect a Fixed") + .isInstanceOf(GenericData.Fixed.class); + Assertions.assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + Assert.assertArrayEquals( + "Bytes should 
match", ((GenericData.Fixed) expected).bytes(), (byte[]) actual); + break; + case BINARY: + Assertions.assertThat(expected) + .as("Should expect a ByteBuffer") + .isInstanceOf(ByteBuffer.class); + Assertions.assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + Assert.assertArrayEquals( + "Bytes should match", ((ByteBuffer) expected).array(), (byte[]) actual); + break; + case DECIMAL: + Assertions.assertThat(expected) + .as("Should expect a BigDecimal") + .isInstanceOf(BigDecimal.class); + Assertions.assertThat(actual).as("Should be a BigDecimal").isInstanceOf(BigDecimal.class); + Assert.assertEquals("BigDecimals should be equal", expected, actual); + break; + case STRUCT: + Assertions.assertThat(expected).as("Should expect a Record").isInstanceOf(Record.class); + Assertions.assertThat(actual).as("Should be a Row").isInstanceOf(Row.class); + assertEqualsSafe(type.asNestedType().asStructType(), (Record) expected, (Row) actual); + break; + case LIST: + Assertions.assertThat(expected) + .as("Should expect a Collection") + .isInstanceOf(Collection.class); + Assertions.assertThat(actual).as("Should be a Seq").isInstanceOf(Seq.class); + List asList = seqAsJavaListConverter((Seq) actual).asJava(); + assertEqualsSafe(type.asNestedType().asListType(), (Collection) expected, asList); + break; + case MAP: + Assertions.assertThat(expected).as("Should expect a Collection").isInstanceOf(Map.class); + Assertions.assertThat(actual) + .as("Should be a Map") + .isInstanceOf(scala.collection.Map.class); + Map asMap = + mapAsJavaMapConverter((scala.collection.Map) actual).asJava(); + assertEqualsSafe(type.asNestedType().asMapType(), (Map) expected, asMap); + break; + case TIME: + default: + throw new IllegalArgumentException("Not a supported type: " + type); + } + } + + public static void assertEqualsUnsafe(Types.StructType struct, Record rec, InternalRow row) { + List fields = struct.fields(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i).type(); + + Object expectedValue = rec.get(i); + Object actualValue = row.isNullAt(i) ? 
null : row.get(i, convert(fieldType)); + + assertEqualsUnsafe(fieldType, expectedValue, actualValue); + } + } + + private static void assertEqualsUnsafe( + Types.ListType list, Collection expected, ArrayData actual) { + Type elementType = list.elementType(); + List expectedElements = Lists.newArrayList(expected); + for (int i = 0; i < expectedElements.size(); i += 1) { + Object expectedValue = expectedElements.get(i); + Object actualValue = actual.get(i, convert(elementType)); + + assertEqualsUnsafe(elementType, expectedValue, actualValue); + } + } + + private static void assertEqualsUnsafe(Types.MapType map, Map expected, MapData actual) { + Type keyType = map.keyType(); + Type valueType = map.valueType(); + + List> expectedElements = Lists.newArrayList(expected.entrySet()); + ArrayData actualKeys = actual.keyArray(); + ArrayData actualValues = actual.valueArray(); + + for (int i = 0; i < expectedElements.size(); i += 1) { + Map.Entry expectedPair = expectedElements.get(i); + Object actualKey = actualKeys.get(i, convert(keyType)); + Object actualValue = actualValues.get(i, convert(keyType)); + + assertEqualsUnsafe(keyType, expectedPair.getKey(), actualKey); + assertEqualsUnsafe(valueType, expectedPair.getValue(), actualValue); + } + } + + private static void assertEqualsUnsafe(Type type, Object expected, Object actual) { + if (expected == null && actual == null) { + return; + } + + switch (type.typeId()) { + case LONG: + Assertions.assertThat(actual).as("Should be a long").isInstanceOf(Long.class); + if (expected instanceof Integer) { + Assert.assertEquals("Values didn't match", ((Number) expected).longValue(), actual); + } else { + Assert.assertEquals("Primitive value should be equal to expected", expected, actual); + } + break; + case DOUBLE: + Assertions.assertThat(actual).as("Should be a double").isInstanceOf(Double.class); + if (expected instanceof Float) { + Assert.assertEquals( + "Values didn't match", + Double.doubleToLongBits(((Number) expected).doubleValue()), + Double.doubleToLongBits((double) actual)); + } else { + Assert.assertEquals("Primitive value should be equal to expected", expected, actual); + } + break; + case INTEGER: + case FLOAT: + case BOOLEAN: + case DATE: + case TIMESTAMP: + Assert.assertEquals("Primitive value should be equal to expected", expected, actual); + break; + case STRING: + Assertions.assertThat(actual).as("Should be a UTF8String").isInstanceOf(UTF8String.class); + Assert.assertEquals("Strings should be equal", expected, actual.toString()); + break; + case UUID: + Assertions.assertThat(expected).as("Should expect a UUID").isInstanceOf(UUID.class); + Assertions.assertThat(actual).as("Should be a UTF8String").isInstanceOf(UTF8String.class); + Assert.assertEquals( + "UUID string representation should match", expected.toString(), actual.toString()); + break; + case FIXED: + Assertions.assertThat(expected) + .as("Should expect a Fixed") + .isInstanceOf(GenericData.Fixed.class); + Assertions.assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + Assert.assertArrayEquals( + "Bytes should match", ((GenericData.Fixed) expected).bytes(), (byte[]) actual); + break; + case BINARY: + Assertions.assertThat(expected) + .as("Should expect a ByteBuffer") + .isInstanceOf(ByteBuffer.class); + Assertions.assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); + Assert.assertArrayEquals( + "Bytes should match", ((ByteBuffer) expected).array(), (byte[]) actual); + break; + case DECIMAL: + Assertions.assertThat(expected) + .as("Should 
expect a BigDecimal") + .isInstanceOf(BigDecimal.class); + Assertions.assertThat(actual).as("Should be a Decimal").isInstanceOf(Decimal.class); + Assert.assertEquals( + "BigDecimals should be equal", expected, ((Decimal) actual).toJavaBigDecimal()); + break; + case STRUCT: + Assertions.assertThat(expected).as("Should expect a Record").isInstanceOf(Record.class); + Assertions.assertThat(actual) + .as("Should be an InternalRow") + .isInstanceOf(InternalRow.class); + assertEqualsUnsafe( + type.asNestedType().asStructType(), (Record) expected, (InternalRow) actual); + break; + case LIST: + Assertions.assertThat(expected) + .as("Should expect a Collection") + .isInstanceOf(Collection.class); + Assertions.assertThat(actual).as("Should be an ArrayData").isInstanceOf(ArrayData.class); + assertEqualsUnsafe( + type.asNestedType().asListType(), (Collection) expected, (ArrayData) actual); + break; + case MAP: + Assertions.assertThat(expected).as("Should expect a Map").isInstanceOf(Map.class); + Assertions.assertThat(actual) + .as("Should be an ArrayBasedMapData") + .isInstanceOf(MapData.class); + assertEqualsUnsafe(type.asNestedType().asMapType(), (Map) expected, (MapData) actual); + break; + case TIME: + default: + throw new IllegalArgumentException("Not a supported type: " + type); + } + } + + /** + * Check that the given InternalRow is equivalent to the Row. + * + * @param prefix context for error messages + * @param type the type of the row + * @param expected the expected value of the row + * @param actual the actual value of the row + */ + public static void assertEquals( + String prefix, Types.StructType type, InternalRow expected, Row actual) { + if (expected == null || actual == null) { + Assert.assertEquals(prefix, expected, actual); + } else { + List fields = type.fields(); + for (int c = 0; c < fields.size(); ++c) { + String fieldName = fields.get(c).name(); + Type childType = fields.get(c).type(); + switch (childType.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case FLOAT: + case DOUBLE: + case STRING: + case DECIMAL: + case DATE: + case TIMESTAMP: + Assert.assertEquals( + prefix + "." + fieldName + " - " + childType, + getValue(expected, c, childType), + getPrimitiveValue(actual, c, childType)); + break; + case UUID: + case FIXED: + case BINARY: + assertEqualBytes( + prefix + "." + fieldName, + (byte[]) getValue(expected, c, childType), + (byte[]) actual.get(c)); + break; + case STRUCT: + { + Types.StructType st = (Types.StructType) childType; + assertEquals( + prefix + "." + fieldName, + st, + expected.getStruct(c, st.fields().size()), + actual.getStruct(c)); + break; + } + case LIST: + assertEqualsLists( + prefix + "." + fieldName, + childType.asListType(), + expected.getArray(c), + toList((Seq) actual.get(c))); + break; + case MAP: + assertEqualsMaps( + prefix + "." 
+ fieldName, + childType.asMapType(), + expected.getMap(c), + toJavaMap((scala.collection.Map) actual.getMap(c))); + break; + default: + throw new IllegalArgumentException("Unhandled type " + childType); + } + } + } + } + + private static void assertEqualsLists( + String prefix, Types.ListType type, ArrayData expected, List actual) { + if (expected == null || actual == null) { + Assert.assertEquals(prefix, expected, actual); + } else { + Assert.assertEquals(prefix + " length", expected.numElements(), actual.size()); + Type childType = type.elementType(); + for (int e = 0; e < expected.numElements(); ++e) { + switch (childType.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case FLOAT: + case DOUBLE: + case STRING: + case DECIMAL: + case DATE: + case TIMESTAMP: + Assert.assertEquals( + prefix + ".elem " + e + " - " + childType, + getValue(expected, e, childType), + actual.get(e)); + break; + case UUID: + case FIXED: + case BINARY: + assertEqualBytes( + prefix + ".elem " + e, + (byte[]) getValue(expected, e, childType), + (byte[]) actual.get(e)); + break; + case STRUCT: + { + Types.StructType st = (Types.StructType) childType; + assertEquals( + prefix + ".elem " + e, + st, + expected.getStruct(e, st.fields().size()), + (Row) actual.get(e)); + break; + } + case LIST: + assertEqualsLists( + prefix + ".elem " + e, + childType.asListType(), + expected.getArray(e), + toList((Seq) actual.get(e))); + break; + case MAP: + assertEqualsMaps( + prefix + ".elem " + e, + childType.asMapType(), + expected.getMap(e), + toJavaMap((scala.collection.Map) actual.get(e))); + break; + default: + throw new IllegalArgumentException("Unhandled type " + childType); + } + } + } + } + + private static void assertEqualsMaps( + String prefix, Types.MapType type, MapData expected, Map actual) { + if (expected == null || actual == null) { + Assert.assertEquals(prefix, expected, actual); + } else { + Type keyType = type.keyType(); + Type valueType = type.valueType(); + ArrayData expectedKeyArray = expected.keyArray(); + ArrayData expectedValueArray = expected.valueArray(); + Assert.assertEquals(prefix + " length", expected.numElements(), actual.size()); + for (int e = 0; e < expected.numElements(); ++e) { + Object expectedKey = getValue(expectedKeyArray, e, keyType); + Object actualValue = actual.get(expectedKey); + if (actualValue == null) { + Assert.assertEquals( + prefix + ".key=" + expectedKey + " has null", + true, + expected.valueArray().isNullAt(e)); + } else { + switch (valueType.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case FLOAT: + case DOUBLE: + case STRING: + case DECIMAL: + case DATE: + case TIMESTAMP: + Assert.assertEquals( + prefix + ".key=" + expectedKey + " - " + valueType, + getValue(expectedValueArray, e, valueType), + actual.get(expectedKey)); + break; + case UUID: + case FIXED: + case BINARY: + assertEqualBytes( + prefix + ".key=" + expectedKey, + (byte[]) getValue(expectedValueArray, e, valueType), + (byte[]) actual.get(expectedKey)); + break; + case STRUCT: + { + Types.StructType st = (Types.StructType) valueType; + assertEquals( + prefix + ".key=" + expectedKey, + st, + expectedValueArray.getStruct(e, st.fields().size()), + (Row) actual.get(expectedKey)); + break; + } + case LIST: + assertEqualsLists( + prefix + ".key=" + expectedKey, + valueType.asListType(), + expectedValueArray.getArray(e), + toList((Seq) actual.get(expectedKey))); + break; + case MAP: + assertEqualsMaps( + prefix + ".key=" + expectedKey, + valueType.asMapType(), + expectedValueArray.getMap(e), + 
toJavaMap((scala.collection.Map) actual.get(expectedKey))); + break; + default: + throw new IllegalArgumentException("Unhandled type " + valueType); + } + } + } + } + } + + private static Object getValue(SpecializedGetters container, int ord, Type type) { + if (container.isNullAt(ord)) { + return null; + } + switch (type.typeId()) { + case BOOLEAN: + return container.getBoolean(ord); + case INTEGER: + return container.getInt(ord); + case LONG: + return container.getLong(ord); + case FLOAT: + return container.getFloat(ord); + case DOUBLE: + return container.getDouble(ord); + case STRING: + return container.getUTF8String(ord).toString(); + case BINARY: + case FIXED: + case UUID: + return container.getBinary(ord); + case DATE: + return new DateWritable(container.getInt(ord)).get(); + case TIMESTAMP: + return DateTimeUtils.toJavaTimestamp(container.getLong(ord)); + case DECIMAL: + { + Types.DecimalType dt = (Types.DecimalType) type; + return container.getDecimal(ord, dt.precision(), dt.scale()).toJavaBigDecimal(); + } + case STRUCT: + Types.StructType struct = type.asStructType(); + InternalRow internalRow = container.getStruct(ord, struct.fields().size()); + Object[] data = new Object[struct.fields().size()]; + for (int i = 0; i < data.length; i += 1) { + if (internalRow.isNullAt(i)) { + data[i] = null; + } else { + data[i] = getValue(internalRow, i, struct.fields().get(i).type()); + } + } + return new GenericRow(data); + default: + throw new IllegalArgumentException("Unhandled type " + type); + } + } + + private static Object getPrimitiveValue(Row row, int ord, Type type) { + if (row.isNullAt(ord)) { + return null; + } + switch (type.typeId()) { + case BOOLEAN: + return row.getBoolean(ord); + case INTEGER: + return row.getInt(ord); + case LONG: + return row.getLong(ord); + case FLOAT: + return row.getFloat(ord); + case DOUBLE: + return row.getDouble(ord); + case STRING: + return row.getString(ord); + case BINARY: + case FIXED: + case UUID: + return row.get(ord); + case DATE: + return row.getDate(ord); + case TIMESTAMP: + return row.getTimestamp(ord); + case DECIMAL: + return row.getDecimal(ord); + default: + throw new IllegalArgumentException("Unhandled type " + type); + } + } + + private static Map toJavaMap(scala.collection.Map map) { + return map == null ? null : mapAsJavaMapConverter(map).asJava(); + } + + private static List toList(Seq val) { + return val == null ? 
null : seqAsJavaListConverter(val).asJava(); + } + + private static void assertEqualBytes(String context, byte[] expected, byte[] actual) { + if (expected == null || actual == null) { + Assert.assertEquals(context, expected, actual); + } else { + Assert.assertArrayEquals(context, expected, actual); + } + } + + static void assertEquals(Schema schema, Object expected, Object actual) { + assertEquals("schema", convert(schema), expected, actual); + } + + private static void assertEquals(String context, DataType type, Object expected, Object actual) { + if (expected == null && actual == null) { + return; + } + + if (type instanceof StructType) { + Assertions.assertThat(expected) + .as("Expected should be an InternalRow: " + context) + .isInstanceOf(InternalRow.class); + Assertions.assertThat(actual) + .as("Actual should be an InternalRow: " + context) + .isInstanceOf(InternalRow.class); + assertEquals(context, (StructType) type, (InternalRow) expected, (InternalRow) actual); + + } else if (type instanceof ArrayType) { + Assertions.assertThat(expected) + .as("Expected should be an ArrayData: " + context) + .isInstanceOf(ArrayData.class); + Assertions.assertThat(actual) + .as("Actual should be an ArrayData: " + context) + .isInstanceOf(ArrayData.class); + assertEquals(context, (ArrayType) type, (ArrayData) expected, (ArrayData) actual); + + } else if (type instanceof MapType) { + Assertions.assertThat(expected) + .as("Expected should be a MapData: " + context) + .isInstanceOf(MapData.class); + Assertions.assertThat(actual) + .as("Actual should be a MapData: " + context) + .isInstanceOf(MapData.class); + assertEquals(context, (MapType) type, (MapData) expected, (MapData) actual); + + } else if (type instanceof BinaryType) { + assertEqualBytes(context, (byte[]) expected, (byte[]) actual); + } else { + Assert.assertEquals("Value should match expected: " + context, expected, actual); + } + } + + private static void assertEquals( + String context, StructType struct, InternalRow expected, InternalRow actual) { + Assert.assertEquals("Should have correct number of fields", struct.size(), actual.numFields()); + for (int i = 0; i < actual.numFields(); i += 1) { + StructField field = struct.fields()[i]; + DataType type = field.dataType(); + assertEquals( + context + "." + field.name(), + type, + expected.isNullAt(i) ? null : expected.get(i, type), + actual.isNullAt(i) ? null : actual.get(i, type)); + } + } + + private static void assertEquals( + String context, ArrayType array, ArrayData expected, ArrayData actual) { + Assert.assertEquals( + "Should have the same number of elements", expected.numElements(), actual.numElements()); + DataType type = array.elementType(); + for (int i = 0; i < actual.numElements(); i += 1) { + assertEquals( + context + ".element", + type, + expected.isNullAt(i) ? null : expected.get(i, type), + actual.isNullAt(i) ? null : actual.get(i, type)); + } + } + + private static void assertEquals(String context, MapType map, MapData expected, MapData actual) { + Assert.assertEquals( + "Should have the same number of elements", expected.numElements(), actual.numElements()); + + DataType keyType = map.keyType(); + ArrayData expectedKeys = expected.keyArray(); + ArrayData expectedValues = expected.valueArray(); + + DataType valueType = map.valueType(); + ArrayData actualKeys = actual.keyArray(); + ArrayData actualValues = actual.valueArray(); + + for (int i = 0; i < actual.numElements(); i += 1) { + assertEquals( + context + ".key", + keyType, + expectedKeys.isNullAt(i) ? 
null : expectedKeys.get(i, keyType), + actualKeys.isNullAt(i) ? null : actualKeys.get(i, keyType)); + assertEquals( + context + ".value", + valueType, + expectedValues.isNullAt(i) ? null : expectedValues.get(i, valueType), + actualValues.isNullAt(i) ? null : actualValues.get(i, valueType)); + } + } + + public static List dataManifests(Table table) { + return table.currentSnapshot().dataManifests(table.io()); + } + + public static List deleteManifests(Table table) { + return table.currentSnapshot().deleteManifests(table.io()); + } + + public static Set dataFiles(Table table) { + return dataFiles(table, null); + } + + public static Set dataFiles(Table table, String branch) { + Set dataFiles = Sets.newHashSet(); + TableScan scan = table.newScan(); + if (branch != null) { + scan.useRef(branch); + } + + for (FileScanTask task : scan.planFiles()) { + dataFiles.add(task.file()); + } + + return dataFiles; + } + + public static Set deleteFiles(Table table) { + Set deleteFiles = Sets.newHashSet(); + + for (FileScanTask task : table.newScan().planFiles()) { + deleteFiles.addAll(task.deletes()); + } + + return deleteFiles; + } + + public static Set reachableManifestPaths(Table table) { + return StreamSupport.stream(table.snapshots().spliterator(), false) + .flatMap(s -> s.allManifests(table.io()).stream()) + .map(ManifestFile::path) + .collect(Collectors.toSet()); + } + + public static void asMetadataRecord(GenericData.Record file, FileContent content) { + file.put(0, content.id()); + file.put(3, 0); // specId + } + + public static void asMetadataRecord(GenericData.Record file) { + file.put(0, FileContent.DATA.id()); + file.put(3, 0); // specId + } + + public static Dataset selectNonDerived(Dataset metadataTable) { + StructField[] fields = metadataTable.schema().fields(); + return metadataTable.select( + Stream.of(fields) + .filter(f -> !f.name().equals("readable_metrics")) // derived field + .map(f -> new Column(f.name())) + .toArray(Column[]::new)); + } + + public static Types.StructType nonDerivedSchema(Dataset metadataTable) { + return SparkSchemaUtil.convert(TestHelpers.selectNonDerived(metadataTable).schema()).asStruct(); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestOrcWrite.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestOrcWrite.java new file mode 100644 index 000000000000..1e51a088390e --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestOrcWrite.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.types.Types.NestedField.optional; + +import java.io.File; +import java.io.IOException; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestOrcWrite { + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private static final Schema SCHEMA = + new Schema( + optional(1, "id", Types.IntegerType.get()), optional(2, "data", Types.StringType.get())); + + @Test + public void splitOffsets() throws IOException { + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + Iterable rows = RandomData.generateSpark(SCHEMA, 1, 0L); + FileAppender writer = + ORC.write(Files.localOutput(testFile)) + .createWriterFunc(SparkOrcWriter::new) + .schema(SCHEMA) + .build(); + + writer.addAll(rows); + writer.close(); + Assert.assertNotNull("Split offsets not present", writer.splitOffsets()); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroReader.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroReader.java new file mode 100644 index 000000000000..a4ffc2fea437 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroReader.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import org.apache.avro.generic.GenericData.Record; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.parquet.ParquetAvroValueReaders; +import org.apache.iceberg.parquet.ParquetSchemaUtil; +import org.apache.iceberg.types.Types; +import org.apache.parquet.schema.MessageType; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestParquetAvroReader { + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private static final Schema COMPLEX_SCHEMA = + new Schema( + required(1, "roots", Types.LongType.get()), + optional(3, "lime", Types.ListType.ofRequired(4, Types.DoubleType.get())), + required( + 5, + "strict", + Types.StructType.of( + required(9, "tangerine", Types.StringType.get()), + optional( + 6, + "hopeful", + Types.StructType.of( + required(7, "steel", Types.FloatType.get()), + required(8, "lantern", Types.DateType.get()))), + optional(10, "vehement", Types.LongType.get()))), + optional( + 11, + "metamorphosis", + Types.MapType.ofRequired( + 12, 13, Types.StringType.get(), Types.TimestampType.withoutZone())), + required( + 14, + "winter", + Types.ListType.ofOptional( + 15, + Types.StructType.of( + optional(16, "beet", Types.DoubleType.get()), + required(17, "stamp", Types.TimeType.get()), + optional(18, "wheeze", Types.StringType.get())))), + optional( + 19, + "renovate", + Types.MapType.ofRequired( + 20, + 21, + Types.StringType.get(), + Types.StructType.of( + optional(22, "jumpy", Types.DoubleType.get()), + required(23, "koala", Types.TimeType.get()), + required(24, "couch rope", Types.IntegerType.get())))), + optional(2, "slide", Types.StringType.get())); + + @Ignore + public void testStructSchema() throws IOException { + Schema structSchema = + new Schema( + required(1, "circumvent", Types.LongType.get()), + optional(2, "antarctica", Types.StringType.get()), + optional(3, "fluent", Types.DoubleType.get()), + required( + 4, + "quell", + Types.StructType.of( + required(5, "operator", Types.BooleanType.get()), + optional(6, "fanta", Types.IntegerType.get()), + optional(7, "cable", Types.FloatType.get()))), + required(8, "chimney", Types.TimestampType.withZone()), + required(9, "wool", Types.DateType.get())); + + File testFile = writeTestData(structSchema, 5_000_000, 1059); + // RandomData uses the root record name "test", which must match for records to be equal + MessageType readSchema = ParquetSchemaUtil.convert(structSchema, "test"); + + long sum = 0; + long sumSq = 0; + int warmups = 2; + int trials = 10; + + for (int i = 0; i < warmups + trials; i += 1) { + // clean up as much memory as possible to avoid a large GC during the timed run + System.gc(); + + try (CloseableIterable reader = + Parquet.read(Files.localInput(testFile)) + .project(structSchema) + .createReaderFunc( + fileSchema -> ParquetAvroValueReaders.buildReader(structSchema, readSchema)) + .build()) { + long start = System.currentTimeMillis(); + long val = 0; + long count = 0; + for (Record record : reader) { + // access something to ensure the compiler doesn't 
optimize this away + val ^= (Long) record.get(0); + count += 1; + } + long end = System.currentTimeMillis(); + long duration = end - start; + + if (i >= warmups) { + sum += duration; + sumSq += duration * duration; + } + } + } + + double mean = ((double) sum) / trials; + double stddev = Math.sqrt((((double) sumSq) / trials) - (mean * mean)); + } + + @Ignore + public void testWithOldReadPath() throws IOException { + File testFile = writeTestData(COMPLEX_SCHEMA, 500_000, 1985); + // RandomData uses the root record name "test", which must match for records to be equal + MessageType readSchema = ParquetSchemaUtil.convert(COMPLEX_SCHEMA, "test"); + + for (int i = 0; i < 5; i += 1) { + // clean up as much memory as possible to avoid a large GC during the timed run + System.gc(); + + try (CloseableIterable reader = + Parquet.read(Files.localInput(testFile)).project(COMPLEX_SCHEMA).build()) { + long start = System.currentTimeMillis(); + long val = 0; + long count = 0; + for (Record record : reader) { + // access something to ensure the compiler doesn't optimize this away + val ^= (Long) record.get(0); + count += 1; + } + long end = System.currentTimeMillis(); + } + + // clean up as much memory as possible to avoid a large GC during the timed run + System.gc(); + + try (CloseableIterable reader = + Parquet.read(Files.localInput(testFile)) + .project(COMPLEX_SCHEMA) + .createReaderFunc( + fileSchema -> ParquetAvroValueReaders.buildReader(COMPLEX_SCHEMA, readSchema)) + .build()) { + long start = System.currentTimeMillis(); + long val = 0; + long count = 0; + for (Record record : reader) { + // access something to ensure the compiler doesn't optimize this away + val ^= (Long) record.get(0); + count += 1; + } + long end = System.currentTimeMillis(); + } + } + } + + @Test + public void testCorrectness() throws IOException { + Iterable records = RandomData.generate(COMPLEX_SCHEMA, 50_000, 34139); + + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (FileAppender writer = + Parquet.write(Files.localOutput(testFile)).schema(COMPLEX_SCHEMA).build()) { + writer.addAll(records); + } + + // RandomData uses the root record name "test", which must match for records to be equal + MessageType readSchema = ParquetSchemaUtil.convert(COMPLEX_SCHEMA, "test"); + + // verify that the new read path is correct + try (CloseableIterable reader = + Parquet.read(Files.localInput(testFile)) + .project(COMPLEX_SCHEMA) + .createReaderFunc( + fileSchema -> ParquetAvroValueReaders.buildReader(COMPLEX_SCHEMA, readSchema)) + .reuseContainers() + .build()) { + int recordNum = 0; + Iterator iter = records.iterator(); + for (Record actual : reader) { + Record expected = iter.next(); + Assert.assertEquals("Record " + recordNum + " should match expected", expected, actual); + recordNum += 1; + } + } + } + + private File writeTestData(Schema schema, int numRecords, int seed) throws IOException { + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (FileAppender writer = + Parquet.write(Files.localOutput(testFile)).schema(schema).build()) { + writer.addAll(RandomData.generate(schema, numRecords, seed)); + } + + return testFile; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroWriter.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroWriter.java new file mode 100644 index 000000000000..15c6268da478 --- /dev/null +++ 
b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroWriter.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import org.apache.avro.generic.GenericData.Record; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.parquet.ParquetAvroValueReaders; +import org.apache.iceberg.parquet.ParquetAvroWriter; +import org.apache.iceberg.parquet.ParquetSchemaUtil; +import org.apache.iceberg.types.Types; +import org.apache.parquet.schema.MessageType; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestParquetAvroWriter { + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private static final Schema COMPLEX_SCHEMA = + new Schema( + required(1, "roots", Types.LongType.get()), + optional(3, "lime", Types.ListType.ofRequired(4, Types.DoubleType.get())), + required( + 5, + "strict", + Types.StructType.of( + required(9, "tangerine", Types.StringType.get()), + optional( + 6, + "hopeful", + Types.StructType.of( + required(7, "steel", Types.FloatType.get()), + required(8, "lantern", Types.DateType.get()))), + optional(10, "vehement", Types.LongType.get()))), + optional( + 11, + "metamorphosis", + Types.MapType.ofRequired( + 12, 13, Types.StringType.get(), Types.TimestampType.withoutZone())), + required( + 14, + "winter", + Types.ListType.ofOptional( + 15, + Types.StructType.of( + optional(16, "beet", Types.DoubleType.get()), + required(17, "stamp", Types.TimeType.get()), + optional(18, "wheeze", Types.StringType.get())))), + optional( + 19, + "renovate", + Types.MapType.ofRequired( + 20, + 21, + Types.StringType.get(), + Types.StructType.of( + optional(22, "jumpy", Types.DoubleType.get()), + required(23, "koala", Types.TimeType.get()), + required(24, "couch rope", Types.IntegerType.get())))), + optional(2, "slide", Types.StringType.get())); + + @Test + public void testCorrectness() throws IOException { + Iterable records = RandomData.generate(COMPLEX_SCHEMA, 50_000, 34139); + + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (FileAppender writer = + Parquet.write(Files.localOutput(testFile)) + .schema(COMPLEX_SCHEMA) + .createWriterFunc(ParquetAvroWriter::buildWriter) + .build()) { + writer.addAll(records); + } + + // 
RandomData uses the root record name "test", which must match for records to be equal + MessageType readSchema = ParquetSchemaUtil.convert(COMPLEX_SCHEMA, "test"); + + // verify that the new read path is correct + try (CloseableIterable reader = + Parquet.read(Files.localInput(testFile)) + .project(COMPLEX_SCHEMA) + .createReaderFunc( + fileSchema -> ParquetAvroValueReaders.buildReader(COMPLEX_SCHEMA, readSchema)) + .build()) { + int recordNum = 0; + Iterator iter = records.iterator(); + for (Record actual : reader) { + Record expected = iter.next(); + Assert.assertEquals("Record " + recordNum + " should match expected", expected, actual); + recordNum += 1; + } + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroEnums.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroEnums.java new file mode 100644 index 000000000000..6f05a9ed7c1f --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroEnums.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericData.Record; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.avro.AvroIterable; +import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestSparkAvroEnums { + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + @Test + public void writeAndValidateEnums() throws IOException { + org.apache.avro.Schema avroSchema = + SchemaBuilder.record("root") + .fields() + .name("enumCol") + .type() + .nullable() + .enumeration("testEnum") + .symbols("SYMB1", "SYMB2") + .enumDefault("SYMB2") + .endRecord(); + + org.apache.avro.Schema enumSchema = avroSchema.getField("enumCol").schema().getTypes().get(0); + Record enumRecord1 = new GenericData.Record(avroSchema); + enumRecord1.put("enumCol", new GenericData.EnumSymbol(enumSchema, "SYMB1")); + Record enumRecord2 = new GenericData.Record(avroSchema); + enumRecord2.put("enumCol", new GenericData.EnumSymbol(enumSchema, "SYMB2")); + Record enumRecord3 = new GenericData.Record(avroSchema); // null enum + List expected = ImmutableList.of(enumRecord1, enumRecord2, enumRecord3); + + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (DataFileWriter writer = new DataFileWriter<>(new GenericDatumWriter<>())) { + writer.create(avroSchema, testFile); + writer.append(enumRecord1); + writer.append(enumRecord2); + writer.append(enumRecord3); + } + + Schema schema = new Schema(AvroSchemaUtil.convert(avroSchema).asStructType().fields()); + List rows; + try (AvroIterable reader = + Avro.read(Files.localInput(testFile)) + .createReaderFunc(SparkAvroReader::new) + .project(schema) + .build()) { + rows = Lists.newArrayList(reader); + } + + // Iceberg will return enums as strings, so we compare string values for the enum field + for (int i = 0; i < expected.size(); i += 1) { + String expectedEnumString = + expected.get(i).get("enumCol") == null ? null : expected.get(i).get("enumCol").toString(); + String sparkString = + rows.get(i).getUTF8String(0) == null ? null : rows.get(i).getUTF8String(0).toString(); + Assert.assertEquals(expectedEnumString, sparkString); + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroReader.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroReader.java new file mode 100644 index 000000000000..6d1ef3db3657 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroReader.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.spark.data.TestHelpers.assertEqualsUnsafe; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import org.apache.avro.generic.GenericData.Record; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.avro.AvroIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.Assert; + +public class TestSparkAvroReader extends AvroDataTest { + @Override + protected void writeAndValidate(Schema schema) throws IOException { + List expected = RandomData.generateList(schema, 100, 0L); + + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (FileAppender writer = + Avro.write(Files.localOutput(testFile)).schema(schema).named("test").build()) { + for (Record rec : expected) { + writer.add(rec); + } + } + + List rows; + try (AvroIterable reader = + Avro.read(Files.localInput(testFile)) + .createReaderFunc(SparkAvroReader::new) + .project(schema) + .build()) { + rows = Lists.newArrayList(reader); + } + + for (int i = 0; i < expected.size(); i += 1) { + assertEqualsUnsafe(schema.asStruct(), expected.get(i), rows.get(i)); + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkDateTimes.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkDateTimes.java new file mode 100644 index 000000000000..b31ea8fd277d --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkDateTimes.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data; + +import java.time.ZoneId; +import java.util.TimeZone; +import org.apache.iceberg.expressions.Literal; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.catalyst.util.TimestampFormatter; +import org.junit.Assert; +import org.junit.Test; + +public class TestSparkDateTimes { + @Test + public void testSparkDate() { + // checkSparkDate("1582-10-14"); // -141428 + checkSparkDate("1582-10-15"); // first day of the gregorian calendar + checkSparkDate("1601-08-12"); + checkSparkDate("1801-07-04"); + checkSparkDate("1901-08-12"); + checkSparkDate("1969-12-31"); + checkSparkDate("1970-01-01"); + checkSparkDate("2017-12-25"); + checkSparkDate("2043-08-11"); + checkSparkDate("2111-05-03"); + checkSparkDate("2224-02-29"); + checkSparkDate("3224-10-05"); + } + + public void checkSparkDate(String dateString) { + Literal date = Literal.of(dateString).to(Types.DateType.get()); + String sparkDate = DateTimeUtils.toJavaDate(date.value()).toString(); + Assert.assertEquals("Should be the same date (" + date.value() + ")", dateString, sparkDate); + } + + @Test + public void testSparkTimestamp() { + TimeZone currentTz = TimeZone.getDefault(); + try { + TimeZone.setDefault(TimeZone.getTimeZone("UTC")); + checkSparkTimestamp("1582-10-15T15:51:08.440219+00:00", "1582-10-15 15:51:08.440219"); + checkSparkTimestamp("1970-01-01T00:00:00.000000+00:00", "1970-01-01 00:00:00"); + checkSparkTimestamp("2043-08-11T12:30:01.000001+00:00", "2043-08-11 12:30:01.000001"); + } finally { + TimeZone.setDefault(currentTz); + } + } + + public void checkSparkTimestamp(String timestampString, String sparkRepr) { + Literal ts = Literal.of(timestampString).to(Types.TimestampType.withZone()); + ZoneId zoneId = DateTimeUtils.getZoneId("UTC"); + TimestampFormatter formatter = TimestampFormatter.getFractionFormatter(zoneId); + String sparkTimestamp = formatter.format(ts.value()); + Assert.assertEquals( + "Should be the same timestamp (" + ts.value() + ")", sparkRepr, sparkTimestamp); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReadMetadataColumns.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReadMetadataColumns.java new file mode 100644 index 000000000000..3c9037adc393 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReadMetadataColumns.java @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.Files; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.exceptions.RuntimeIOException; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkOrcReaders; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.orc.OrcConf; +import org.apache.orc.OrcFile; +import org.apache.orc.Reader; +import org.apache.orc.StripeInformation; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.unsafe.types.UTF8String; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class TestSparkOrcReadMetadataColumns { + private static final Schema DATA_SCHEMA = + new Schema( + required(100, "id", Types.LongType.get()), required(101, "data", Types.StringType.get())); + + private static final Schema PROJECTION_SCHEMA = + new Schema( + required(100, "id", Types.LongType.get()), + required(101, "data", Types.StringType.get()), + MetadataColumns.ROW_POSITION, + MetadataColumns.IS_DELETED); + + private static final int NUM_ROWS = 1000; + private static final List DATA_ROWS; + private static final List EXPECTED_ROWS; + + static { + DATA_ROWS = Lists.newArrayListWithCapacity(NUM_ROWS); + for (long i = 0; i < NUM_ROWS; i++) { + InternalRow row = new GenericInternalRow(DATA_SCHEMA.columns().size()); + row.update(0, i); + row.update(1, UTF8String.fromString("str" + i)); + DATA_ROWS.add(row); + } + + EXPECTED_ROWS = Lists.newArrayListWithCapacity(NUM_ROWS); + for (long i = 0; i < NUM_ROWS; i++) { + InternalRow row = new GenericInternalRow(PROJECTION_SCHEMA.columns().size()); + row.update(0, i); + row.update(1, UTF8String.fromString("str" + i)); + row.update(2, i); + row.update(3, false); + EXPECTED_ROWS.add(row); + } + } + + @Parameterized.Parameters(name = "vectorized = {0}") + public static Object[] parameters() { + return new Object[] {false, true}; + } + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private boolean vectorized; + private File testFile; + + public TestSparkOrcReadMetadataColumns(boolean vectorized) { + this.vectorized = vectorized; + } + + @Before + public void writeFile() throws IOException { + testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (FileAppender writer = + ORC.write(Files.localOutput(testFile)) + .createWriterFunc(SparkOrcWriter::new) + .schema(DATA_SCHEMA) + // write 
in such a way that the file contains 10 stripes each with 100 rows + .set("iceberg.orc.vectorbatch.size", "100") + .set(OrcConf.ROWS_BETWEEN_CHECKS.getAttribute(), "100") + .set(OrcConf.STRIPE_SIZE.getAttribute(), "1") + .build()) { + writer.addAll(DATA_ROWS); + } + } + + @Test + public void testReadRowNumbers() throws IOException { + readAndValidate(null, null, null, EXPECTED_ROWS); + } + + @Test + public void testReadRowNumbersWithFilter() throws IOException { + readAndValidate( + Expressions.greaterThanOrEqual("id", 500), null, null, EXPECTED_ROWS.subList(500, 1000)); + } + + @Test + public void testReadRowNumbersWithSplits() throws IOException { + Reader reader; + try { + OrcFile.ReaderOptions readerOptions = + OrcFile.readerOptions(new Configuration()).useUTCTimestamp(true); + reader = OrcFile.createReader(new Path(testFile.toString()), readerOptions); + } catch (IOException ioe) { + throw new RuntimeIOException(ioe, "Failed to open file: %s", testFile); + } + List splitOffsets = + reader.getStripes().stream().map(StripeInformation::getOffset).collect(Collectors.toList()); + List splitLengths = + reader.getStripes().stream().map(StripeInformation::getLength).collect(Collectors.toList()); + + for (int i = 0; i < 10; i++) { + readAndValidate( + null, + splitOffsets.get(i), + splitLengths.get(i), + EXPECTED_ROWS.subList(i * 100, (i + 1) * 100)); + } + } + + private void readAndValidate( + Expression filter, Long splitStart, Long splitLength, List expected) + throws IOException { + Schema projectionWithoutMetadataFields = + TypeUtil.selectNot(PROJECTION_SCHEMA, MetadataColumns.metadataFieldIds()); + CloseableIterable reader = null; + try { + ORC.ReadBuilder builder = + ORC.read(Files.localInput(testFile)).project(projectionWithoutMetadataFields); + + if (vectorized) { + builder = + builder.createBatchedReaderFunc( + readOrcSchema -> + VectorizedSparkOrcReaders.buildReader( + PROJECTION_SCHEMA, readOrcSchema, ImmutableMap.of())); + } else { + builder = + builder.createReaderFunc( + readOrcSchema -> new SparkOrcReader(PROJECTION_SCHEMA, readOrcSchema)); + } + + if (filter != null) { + builder = builder.filter(filter); + } + + if (splitStart != null && splitLength != null) { + builder = builder.split(splitStart, splitLength); + } + + if (vectorized) { + reader = batchesToRows(builder.build()); + } else { + reader = builder.build(); + } + + final Iterator actualRows = reader.iterator(); + final Iterator expectedRows = expected.iterator(); + while (expectedRows.hasNext()) { + Assert.assertTrue("Should have expected number of rows", actualRows.hasNext()); + TestHelpers.assertEquals(PROJECTION_SCHEMA, expectedRows.next(), actualRows.next()); + } + Assert.assertFalse("Should not have extra rows", actualRows.hasNext()); + } finally { + if (reader != null) { + reader.close(); + } + } + } + + private CloseableIterable batchesToRows(CloseableIterable batches) { + return CloseableIterable.combine( + Iterables.concat(Iterables.transform(batches, b -> (Iterable) b::rowIterator)), + batches); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java new file mode 100644 index 000000000000..b23fe729a187 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.spark.data.TestHelpers.assertEquals; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterators; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkOrcReaders; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.junit.Assert; +import org.junit.Test; + +public class TestSparkOrcReader extends AvroDataTest { + @Override + protected void writeAndValidate(Schema schema) throws IOException { + final Iterable expected = RandomData.generateSpark(schema, 100, 0L); + + writeAndValidateRecords(schema, expected); + } + + @Test + public void writeAndValidateRepeatingRecords() throws IOException { + Schema structSchema = + new Schema( + required(100, "id", Types.LongType.get()), + required(101, "data", Types.StringType.get())); + List expectedRepeating = + Collections.nCopies(100, RandomData.generateSpark(structSchema, 1, 0L).iterator().next()); + + writeAndValidateRecords(structSchema, expectedRepeating); + } + + private void writeAndValidateRecords(Schema schema, Iterable expected) + throws IOException { + final File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (FileAppender writer = + ORC.write(Files.localOutput(testFile)) + .createWriterFunc(SparkOrcWriter::new) + .schema(schema) + .build()) { + writer.addAll(expected); + } + + try (CloseableIterable reader = + ORC.read(Files.localInput(testFile)) + .project(schema) + .createReaderFunc(readOrcSchema -> new SparkOrcReader(schema, readOrcSchema)) + .build()) { + final Iterator actualRows = reader.iterator(); + final Iterator expectedRows = expected.iterator(); + while (expectedRows.hasNext()) { + Assert.assertTrue("Should have expected number of rows", actualRows.hasNext()); + assertEquals(schema, expectedRows.next(), actualRows.next()); + } + Assert.assertFalse("Should not have extra rows", actualRows.hasNext()); + } + + try (CloseableIterable reader = + ORC.read(Files.localInput(testFile)) + .project(schema) + .createBatchedReaderFunc( + readOrcSchema -> + VectorizedSparkOrcReaders.buildReader(schema, readOrcSchema, ImmutableMap.of())) + .build()) { + final Iterator actualRows = batchesToRows(reader.iterator()); + final 
Iterator expectedRows = expected.iterator(); + while (expectedRows.hasNext()) { + Assert.assertTrue("Should have expected number of rows", actualRows.hasNext()); + assertEquals(schema, expectedRows.next(), actualRows.next()); + } + Assert.assertFalse("Should not have extra rows", actualRows.hasNext()); + } + } + + private Iterator batchesToRows(Iterator batches) { + return Iterators.concat(Iterators.transform(batches, ColumnarBatch::rowIterator)); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReadMetadataColumns.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReadMetadataColumns.java new file mode 100644 index 000000000000..23d69c467218 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReadMetadataColumns.java @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import java.util.List; +import java.util.Set; +import org.apache.arrow.vector.NullCheckingForGet; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.Files; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.DeleteFilter; +import org.apache.iceberg.deletes.PositionDeleteIndex; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.parquet.ParquetSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkParquetReaders; +import org.apache.iceberg.types.Types; +import org.apache.parquet.ParquetReadOptions; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.ParquetFileWriter; +import org.apache.parquet.hadoop.metadata.BlockMetaData; +import org.apache.parquet.hadoop.util.HadoopInputFile; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.StructType; +import 
org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.unsafe.types.UTF8String; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class TestSparkParquetReadMetadataColumns { + private static final Schema DATA_SCHEMA = + new Schema( + required(100, "id", Types.LongType.get()), required(101, "data", Types.StringType.get())); + + private static final Schema PROJECTION_SCHEMA = + new Schema( + required(100, "id", Types.LongType.get()), + required(101, "data", Types.StringType.get()), + MetadataColumns.ROW_POSITION, + MetadataColumns.IS_DELETED); + + private static final int NUM_ROWS = 1000; + private static final List DATA_ROWS; + private static final List EXPECTED_ROWS; + private static final int NUM_ROW_GROUPS = 10; + private static final int ROWS_PER_SPLIT = NUM_ROWS / NUM_ROW_GROUPS; + private static final int RECORDS_PER_BATCH = ROWS_PER_SPLIT / 10; + + static { + DATA_ROWS = Lists.newArrayListWithCapacity(NUM_ROWS); + for (long i = 0; i < NUM_ROWS; i += 1) { + InternalRow row = new GenericInternalRow(DATA_SCHEMA.columns().size()); + if (i >= NUM_ROWS / 2) { + row.update(0, 2 * i); + } else { + row.update(0, i); + } + row.update(1, UTF8String.fromString("str" + i)); + DATA_ROWS.add(row); + } + + EXPECTED_ROWS = Lists.newArrayListWithCapacity(NUM_ROWS); + for (long i = 0; i < NUM_ROWS; i += 1) { + InternalRow row = new GenericInternalRow(PROJECTION_SCHEMA.columns().size()); + if (i >= NUM_ROWS / 2) { + row.update(0, 2 * i); + } else { + row.update(0, i); + } + row.update(1, UTF8String.fromString("str" + i)); + row.update(2, i); + row.update(3, false); + EXPECTED_ROWS.add(row); + } + } + + @Parameterized.Parameters(name = "vectorized = {0}") + public static Object[][] parameters() { + return new Object[][] {new Object[] {false}, new Object[] {true}}; + } + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private final boolean vectorized; + private File testFile; + + public TestSparkParquetReadMetadataColumns(boolean vectorized) { + this.vectorized = vectorized; + } + + @Before + public void writeFile() throws IOException { + List fileSplits = Lists.newArrayList(); + StructType struct = SparkSchemaUtil.convert(DATA_SCHEMA); + Configuration conf = new Configuration(); + + testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + ParquetFileWriter parquetFileWriter = + new ParquetFileWriter( + conf, + ParquetSchemaUtil.convert(DATA_SCHEMA, "testSchema"), + new Path(testFile.getAbsolutePath())); + + parquetFileWriter.start(); + for (int i = 0; i < NUM_ROW_GROUPS; i += 1) { + File split = temp.newFile(); + Assert.assertTrue("Delete should succeed", split.delete()); + fileSplits.add(new Path(split.getAbsolutePath())); + try (FileAppender writer = + Parquet.write(Files.localOutput(split)) + .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(struct, msgType)) + .schema(DATA_SCHEMA) + .overwrite() + .build()) { + writer.addAll(DATA_ROWS.subList(i * ROWS_PER_SPLIT, (i + 1) * ROWS_PER_SPLIT)); + } + parquetFileWriter.appendFile( + HadoopInputFile.fromPath(new Path(split.getAbsolutePath()), conf)); + } + parquetFileWriter.end( + ParquetFileWriter.mergeMetadataFiles(fileSplits, conf) + .getFileMetaData() + .getKeyValueMetaData()); + } + + @Test + public void testReadRowNumbers() throws IOException 
{ + readAndValidate(null, null, null, EXPECTED_ROWS); + } + + @Test + public void testReadRowNumbersWithDelete() throws IOException { + Assume.assumeTrue(vectorized); + + List expectedRowsAfterDelete = Lists.newArrayList(); + EXPECTED_ROWS.forEach(row -> expectedRowsAfterDelete.add(row.copy())); + // remove row at position 98, 99, 100, 101, 102, this crosses two row groups [0, 100) and [100, + // 200) + for (int i = 98; i <= 102; i++) { + expectedRowsAfterDelete.get(i).update(3, true); + } + + Parquet.ReadBuilder builder = + Parquet.read(Files.localInput(testFile)).project(PROJECTION_SCHEMA); + + DeleteFilter deleteFilter = mock(DeleteFilter.class); + when(deleteFilter.hasPosDeletes()).thenReturn(true); + PositionDeleteIndex deletedRowPos = new CustomizedPositionDeleteIndex(); + deletedRowPos.delete(98, 103); + when(deleteFilter.deletedRowPositions()).thenReturn(deletedRowPos); + + builder.createBatchedReaderFunc( + fileSchema -> + VectorizedSparkParquetReaders.buildReader( + PROJECTION_SCHEMA, + fileSchema, + NullCheckingForGet.NULL_CHECKING_ENABLED, + Maps.newHashMap(), + deleteFilter)); + builder.recordsPerBatch(RECORDS_PER_BATCH); + + validate(expectedRowsAfterDelete, builder); + } + + private class CustomizedPositionDeleteIndex implements PositionDeleteIndex { + private final Set deleteIndex; + + private CustomizedPositionDeleteIndex() { + deleteIndex = Sets.newHashSet(); + } + + @Override + public void delete(long position) { + deleteIndex.add(position); + } + + @Override + public void delete(long posStart, long posEnd) { + for (long l = posStart; l < posEnd; l++) { + delete(l); + } + } + + @Override + public boolean isDeleted(long position) { + return deleteIndex.contains(position); + } + + @Override + public boolean isEmpty() { + return deleteIndex.isEmpty(); + } + } + + @Test + public void testReadRowNumbersWithFilter() throws IOException { + // current iceberg supports row group filter. 
+ for (int i = 1; i < 5; i += 1) { + readAndValidate( + Expressions.and( + Expressions.lessThan("id", NUM_ROWS / 2), + Expressions.greaterThanOrEqual("id", i * ROWS_PER_SPLIT)), + null, + null, + EXPECTED_ROWS.subList(i * ROWS_PER_SPLIT, NUM_ROWS / 2)); + } + } + + @Test + public void testReadRowNumbersWithSplits() throws IOException { + ParquetFileReader fileReader = + new ParquetFileReader( + HadoopInputFile.fromPath(new Path(testFile.getAbsolutePath()), new Configuration()), + ParquetReadOptions.builder().build()); + List rowGroups = fileReader.getRowGroups(); + for (int i = 0; i < NUM_ROW_GROUPS; i += 1) { + readAndValidate( + null, + rowGroups.get(i).getColumns().get(0).getStartingPos(), + rowGroups.get(i).getCompressedSize(), + EXPECTED_ROWS.subList(i * ROWS_PER_SPLIT, (i + 1) * ROWS_PER_SPLIT)); + } + } + + private void readAndValidate( + Expression filter, Long splitStart, Long splitLength, List expected) + throws IOException { + Parquet.ReadBuilder builder = + Parquet.read(Files.localInput(testFile)).project(PROJECTION_SCHEMA); + + if (vectorized) { + builder.createBatchedReaderFunc( + fileSchema -> + VectorizedSparkParquetReaders.buildReader( + PROJECTION_SCHEMA, fileSchema, NullCheckingForGet.NULL_CHECKING_ENABLED)); + builder.recordsPerBatch(RECORDS_PER_BATCH); + } else { + builder = + builder.createReaderFunc( + msgType -> SparkParquetReaders.buildReader(PROJECTION_SCHEMA, msgType)); + } + + if (filter != null) { + builder = builder.filter(filter); + } + + if (splitStart != null && splitLength != null) { + builder = builder.split(splitStart, splitLength); + } + + validate(expected, builder); + } + + private void validate(List expected, Parquet.ReadBuilder builder) + throws IOException { + try (CloseableIterable reader = + vectorized ? batchesToRows(builder.build()) : builder.build()) { + final Iterator actualRows = reader.iterator(); + + for (InternalRow internalRow : expected) { + Assert.assertTrue("Should have expected number of rows", actualRows.hasNext()); + TestHelpers.assertEquals(PROJECTION_SCHEMA, internalRow, actualRows.next()); + } + + Assert.assertFalse("Should not have extra rows", actualRows.hasNext()); + } + } + + private CloseableIterable batchesToRows(CloseableIterable batches) { + return CloseableIterable.combine( + Iterables.concat(Iterables.transform(batches, b -> (Iterable) b::rowIterator)), + batches); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReader.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReader.java new file mode 100644 index 000000000000..024ce3a60c2b --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReader.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.spark.data.TestHelpers.assertEqualsUnsafe; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import org.apache.avro.generic.GenericData; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Files; +import org.apache.iceberg.MetricsConfig; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.data.IcebergGenerics; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.parquet.ParquetUtil; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.parquet.hadoop.ParquetWriter; +import org.apache.parquet.hadoop.api.WriteSupport; +import org.apache.parquet.hadoop.util.HadoopOutputFile; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Test; + +public class TestSparkParquetReader extends AvroDataTest { + @Override + protected void writeAndValidate(Schema schema) throws IOException { + Assume.assumeTrue( + "Parquet Avro cannot write non-string map keys", + null + == TypeUtil.find( + schema, + type -> type.isMapType() && type.asMapType().keyType() != Types.StringType.get())); + + List expected = RandomData.generateList(schema, 100, 0L); + + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (FileAppender writer = + Parquet.write(Files.localOutput(testFile)).schema(schema).named("test").build()) { + writer.addAll(expected); + } + + try (CloseableIterable reader = + Parquet.read(Files.localInput(testFile)) + .project(schema) + .createReaderFunc(type -> SparkParquetReaders.buildReader(schema, type)) + .build()) { + Iterator rows = reader.iterator(); + for (GenericData.Record record : expected) { + Assert.assertTrue("Should have expected number of rows", rows.hasNext()); + assertEqualsUnsafe(schema.asStruct(), record, rows.next()); + } + Assert.assertFalse("Should not have extra rows", rows.hasNext()); + } + } + + protected List rowsFromFile(InputFile inputFile, Schema schema) throws IOException { + try (CloseableIterable reader = + Parquet.read(inputFile) + .project(schema) + .createReaderFunc(type -> SparkParquetReaders.buildReader(schema, type)) + .build()) { + return Lists.newArrayList(reader); + } + } + + protected Table tableFromInputFile(InputFile inputFile, Schema schema) throws IOException { + HadoopTables tables = new HadoopTables(); + Table table = + 
tables.create( + schema, + PartitionSpec.unpartitioned(), + ImmutableMap.of(), + temp.newFolder().getCanonicalPath()); + + table + .newAppend() + .appendFile( + DataFiles.builder(PartitionSpec.unpartitioned()) + .withFormat(FileFormat.PARQUET) + .withInputFile(inputFile) + .withMetrics(ParquetUtil.fileMetrics(inputFile, MetricsConfig.getDefault())) + .withFileSizeInBytes(inputFile.getLength()) + .build()) + .commit(); + + return table; + } + + @Test + public void testInt96TimestampProducedBySparkIsReadCorrectly() throws IOException { + String outputFilePath = + String.format("%s/%s", temp.getRoot().getAbsolutePath(), "parquet_int96.parquet"); + HadoopOutputFile outputFile = + HadoopOutputFile.fromPath( + new org.apache.hadoop.fs.Path(outputFilePath), new Configuration()); + Schema schema = new Schema(required(1, "ts", Types.TimestampType.withZone())); + StructType sparkSchema = + new StructType( + new StructField[] { + new StructField("ts", DataTypes.TimestampType, true, Metadata.empty()) + }); + List rows = Lists.newArrayList(RandomData.generateSpark(schema, 10, 0L)); + + try (ParquetWriter writer = + new NativeSparkWriterBuilder(outputFile) + .set("org.apache.spark.sql.parquet.row.attributes", sparkSchema.json()) + .set("spark.sql.parquet.writeLegacyFormat", "false") + .set("spark.sql.parquet.outputTimestampType", "INT96") + .set("spark.sql.parquet.fieldId.write.enabled", "true") + .build()) { + for (InternalRow row : rows) { + writer.write(row); + } + } + + InputFile parquetInputFile = Files.localInput(outputFilePath); + List readRows = rowsFromFile(parquetInputFile, schema); + Assert.assertEquals(rows.size(), readRows.size()); + Assertions.assertThat(readRows).isEqualTo(rows); + + // Now we try to import that file as an Iceberg table to make sure Iceberg can read + // Int96 end to end. + Table int96Table = tableFromInputFile(parquetInputFile, schema); + List tableRecords = Lists.newArrayList(IcebergGenerics.read(int96Table).build()); + + Assert.assertEquals(rows.size(), tableRecords.size()); + + for (int i = 0; i < tableRecords.size(); i++) { + GenericsHelpers.assertEqualsUnsafe(schema.asStruct(), tableRecords.get(i), rows.get(i)); + } + } + + /** + * Native Spark ParquetWriter.Builder implementation so that we can write timestamps using Spark's + * native ParquetWriteSupport. + */ + private static class NativeSparkWriterBuilder + extends ParquetWriter.Builder { + private final Map config = Maps.newHashMap(); + + NativeSparkWriterBuilder(org.apache.parquet.io.OutputFile path) { + super(path); + } + + public NativeSparkWriterBuilder set(String property, String value) { + this.config.put(property, value); + return self(); + } + + @Override + protected NativeSparkWriterBuilder self() { + return this; + } + + @Override + protected WriteSupport getWriteSupport(Configuration configuration) { + for (Map.Entry entry : config.entrySet()) { + configuration.set(entry.getKey(), entry.getValue()); + } + + return new org.apache.spark.sql.execution.datasources.parquet.ParquetWriteSupport(); + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java new file mode 100644 index 000000000000..261fb8838aa4 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestSparkParquetWriter { + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private static final Schema COMPLEX_SCHEMA = + new Schema( + required(1, "roots", Types.LongType.get()), + optional(3, "lime", Types.ListType.ofRequired(4, Types.DoubleType.get())), + required( + 5, + "strict", + Types.StructType.of( + required(9, "tangerine", Types.StringType.get()), + optional( + 6, + "hopeful", + Types.StructType.of( + required(7, "steel", Types.FloatType.get()), + required(8, "lantern", Types.DateType.get()))), + optional(10, "vehement", Types.LongType.get()))), + optional( + 11, + "metamorphosis", + Types.MapType.ofRequired( + 12, 13, Types.StringType.get(), Types.TimestampType.withZone())), + required( + 14, + "winter", + Types.ListType.ofOptional( + 15, + Types.StructType.of( + optional(16, "beet", Types.DoubleType.get()), + required(17, "stamp", Types.FloatType.get()), + optional(18, "wheeze", Types.StringType.get())))), + optional( + 19, + "renovate", + Types.MapType.ofRequired( + 20, + 21, + Types.StringType.get(), + Types.StructType.of( + optional(22, "jumpy", Types.DoubleType.get()), + required(23, "koala", Types.IntegerType.get()), + required(24, "couch rope", Types.IntegerType.get())))), + optional(2, "slide", Types.StringType.get())); + + @Test + public void testCorrectness() throws IOException { + int numRows = 50_000; + Iterable records = RandomData.generateSpark(COMPLEX_SCHEMA, numRows, 19981); + + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (FileAppender writer = + Parquet.write(Files.localOutput(testFile)) + .schema(COMPLEX_SCHEMA) + .createWriterFunc( + msgType -> + SparkParquetWriters.buildWriter( + SparkSchemaUtil.convert(COMPLEX_SCHEMA), msgType)) + .build()) { + writer.addAll(records); + } + + try (CloseableIterable reader = + Parquet.read(Files.localInput(testFile)) + .project(COMPLEX_SCHEMA) + .createReaderFunc(type -> SparkParquetReaders.buildReader(COMPLEX_SCHEMA, type)) + .build()) { + Iterator expected = records.iterator(); + Iterator rows = reader.iterator(); + for (int i = 
0; i < numRows; i += 1) { + Assert.assertTrue("Should have expected number of rows", rows.hasNext()); + TestHelpers.assertEquals(COMPLEX_SCHEMA, expected.next(), rows.next()); + } + Assert.assertFalse("Should not have extra rows", rows.hasNext()); + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkRecordOrcReaderWriter.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkRecordOrcReaderWriter.java new file mode 100644 index 000000000000..d10e7f5a19e3 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkRecordOrcReaderWriter.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import java.math.BigDecimal; +import java.util.Iterator; +import java.util.List; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.RandomGenericData; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.data.orc.GenericOrcReader; +import org.apache.iceberg.data.orc.GenericOrcWriter; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.Assert; +import org.junit.Test; + +public class TestSparkRecordOrcReaderWriter extends AvroDataTest { + private static final int NUM_RECORDS = 200; + + private void writeAndValidate(Schema schema, List expectedRecords) throws IOException { + final File originalFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", originalFile.delete()); + + // Write few generic records into the original test file. + try (FileAppender writer = + ORC.write(Files.localOutput(originalFile)) + .createWriterFunc(GenericOrcWriter::buildWriter) + .schema(schema) + .build()) { + writer.addAll(expectedRecords); + } + + // Read into spark InternalRow from the original test file. 
+ List internalRows = Lists.newArrayList(); + try (CloseableIterable reader = + ORC.read(Files.localInput(originalFile)) + .project(schema) + .createReaderFunc(readOrcSchema -> new SparkOrcReader(schema, readOrcSchema)) + .build()) { + reader.forEach(internalRows::add); + assertEqualsUnsafe(schema.asStruct(), expectedRecords, reader, expectedRecords.size()); + } + + final File anotherFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", anotherFile.delete()); + + // Write those spark InternalRows into a new file again. + try (FileAppender writer = + ORC.write(Files.localOutput(anotherFile)) + .createWriterFunc(SparkOrcWriter::new) + .schema(schema) + .build()) { + writer.addAll(internalRows); + } + + // Check whether the InternalRows are expected records. + try (CloseableIterable reader = + ORC.read(Files.localInput(anotherFile)) + .project(schema) + .createReaderFunc(readOrcSchema -> new SparkOrcReader(schema, readOrcSchema)) + .build()) { + assertEqualsUnsafe(schema.asStruct(), expectedRecords, reader, expectedRecords.size()); + } + + // Read into iceberg GenericRecord and check again. + try (CloseableIterable reader = + ORC.read(Files.localInput(anotherFile)) + .createReaderFunc(typeDesc -> GenericOrcReader.buildReader(schema, typeDesc)) + .project(schema) + .build()) { + assertRecordEquals(expectedRecords, reader, expectedRecords.size()); + } + } + + @Override + protected void writeAndValidate(Schema schema) throws IOException { + List expectedRecords = RandomGenericData.generate(schema, NUM_RECORDS, 1992L); + writeAndValidate(schema, expectedRecords); + } + + @Test + public void testDecimalWithTrailingZero() throws IOException { + Schema schema = + new Schema( + required(1, "d1", Types.DecimalType.of(10, 2)), + required(2, "d2", Types.DecimalType.of(20, 5)), + required(3, "d3", Types.DecimalType.of(38, 20))); + + List expected = Lists.newArrayList(); + + GenericRecord record = GenericRecord.create(schema); + record.set(0, new BigDecimal("101.00")); + record.set(1, new BigDecimal("10.00E-3")); + record.set(2, new BigDecimal("1001.0000E-16")); + + expected.add(record.copy()); + + writeAndValidate(schema, expected); + } + + private static void assertRecordEquals( + Iterable expected, Iterable actual, int size) { + Iterator expectedIter = expected.iterator(); + Iterator actualIter = actual.iterator(); + for (int i = 0; i < size; i += 1) { + Assert.assertTrue("Expected iterator should have more rows", expectedIter.hasNext()); + Assert.assertTrue("Actual iterator should have more rows", actualIter.hasNext()); + Assert.assertEquals("Should have same rows.", expectedIter.next(), actualIter.next()); + } + Assert.assertFalse("Expected iterator should not have any extra rows.", expectedIter.hasNext()); + Assert.assertFalse("Actual iterator should not have any extra rows.", actualIter.hasNext()); + } + + private static void assertEqualsUnsafe( + Types.StructType struct, Iterable expected, Iterable actual, int size) { + Iterator expectedIter = expected.iterator(); + Iterator actualIter = actual.iterator(); + for (int i = 0; i < size; i += 1) { + Assert.assertTrue("Expected iterator should have more rows", expectedIter.hasNext()); + Assert.assertTrue("Actual iterator should have more rows", actualIter.hasNext()); + GenericsHelpers.assertEqualsUnsafe(struct, expectedIter.next(), actualIter.next()); + } + Assert.assertFalse("Expected iterator should not have any extra rows.", expectedIter.hasNext()); + Assert.assertFalse("Actual iterator should not have any extra rows.", 
actualIter.hasNext()); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestVectorizedOrcDataReader.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestVectorizedOrcDataReader.java new file mode 100644 index 000000000000..b247ef20d152 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestVectorizedOrcDataReader.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import org.apache.iceberg.Files; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.data.orc.GenericOrcWriter; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.DataWriter; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.orc.ORC; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterators; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkOrcReaders; +import org.apache.iceberg.types.Types; +import org.apache.orc.OrcConf; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.assertj.core.api.WithAssertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public class TestVectorizedOrcDataReader implements WithAssertions { + @TempDir public static Path temp; + + private static final Schema SCHEMA = + new Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get()), + Types.NestedField.optional(3, "binary", Types.BinaryType.get()), + Types.NestedField.required( + 4, "array", Types.ListType.ofOptional(5, Types.IntegerType.get()))); + private static OutputFile outputFile; + + @BeforeAll + public static void createDataFile() throws IOException { + GenericRecord bufferRecord = GenericRecord.create(SCHEMA); + + ImmutableList.Builder builder = ImmutableList.builder(); + builder.add( + bufferRecord.copy( + ImmutableMap.of("id", 1L, "data", "a", "array", Collections.singletonList(1)))); + builder.add( + bufferRecord.copy(ImmutableMap.of("id", 2L, "data", "b", "array", Arrays.asList(2, 3)))); + builder.add( + bufferRecord.copy(ImmutableMap.of("id", 3L, "data", 
"c", "array", Arrays.asList(3, 4, 5)))); + builder.add( + bufferRecord.copy( + ImmutableMap.of("id", 4L, "data", "d", "array", Arrays.asList(4, 5, 6, 7)))); + builder.add( + bufferRecord.copy( + ImmutableMap.of("id", 5L, "data", "e", "array", Arrays.asList(5, 6, 7, 8, 9)))); + + outputFile = Files.localOutput(File.createTempFile("test", ".orc", temp.toFile())); + + try (DataWriter dataWriter = + ORC.writeData(outputFile) + .schema(SCHEMA) + .createWriterFunc(GenericOrcWriter::buildWriter) + .overwrite() + .withSpec(PartitionSpec.unpartitioned()) + .build()) { + for (Record record : builder.build()) { + dataWriter.write(record); + } + } + } + + private Iterator batchesToRows(Iterator batches) { + return Iterators.concat(Iterators.transform(batches, ColumnarBatch::rowIterator)); + } + + private void validateAllRows(Iterator rows) { + long rowCount = 0; + long expId = 1; + char expChar = 'a'; + while (rows.hasNext()) { + InternalRow row = rows.next(); + assertThat(row.getLong(0)).isEqualTo(expId); + assertThat(row.getString(1)).isEqualTo(Character.toString(expChar)); + assertThat(row.isNullAt(2)).isTrue(); + expId += 1; + expChar += 1; + rowCount += 1; + } + assertThat(rowCount).isEqualTo(5); + } + + @Test + public void testReader() throws IOException { + try (CloseableIterable reader = + ORC.read(outputFile.toInputFile()) + .project(SCHEMA) + .createBatchedReaderFunc( + readOrcSchema -> + VectorizedSparkOrcReaders.buildReader(SCHEMA, readOrcSchema, ImmutableMap.of())) + .build()) { + validateAllRows(batchesToRows(reader.iterator())); + } + } + + @Test + public void testReaderWithFilter() throws IOException { + try (CloseableIterable reader = + ORC.read(outputFile.toInputFile()) + .project(SCHEMA) + .createBatchedReaderFunc( + readOrcSchema -> + VectorizedSparkOrcReaders.buildReader(SCHEMA, readOrcSchema, ImmutableMap.of())) + .filter(Expressions.equal("id", 3L)) + .config(OrcConf.ALLOW_SARG_TO_FILTER.getAttribute(), String.valueOf(true)) + .build()) { + validateAllRows(batchesToRows(reader.iterator())); + } + } + + @Test + public void testWithFilterWithSelected() throws IOException { + try (CloseableIterable reader = + ORC.read(outputFile.toInputFile()) + .project(SCHEMA) + .createBatchedReaderFunc( + readOrcSchema -> + VectorizedSparkOrcReaders.buildReader(SCHEMA, readOrcSchema, ImmutableMap.of())) + .filter(Expressions.equal("id", 3L)) + .config(OrcConf.ALLOW_SARG_TO_FILTER.getAttribute(), String.valueOf(true)) + .config(OrcConf.READER_USE_SELECTED.getAttribute(), String.valueOf(true)) + .build()) { + Iterator rows = batchesToRows(reader.iterator()); + assertThat(rows).hasNext(); + InternalRow row = rows.next(); + assertThat(row.getLong(0)).isEqualTo(3L); + assertThat(row.getString(1)).isEqualTo("c"); + assertThat(row.getArray(3).toIntArray()).isEqualTo(new int[] {3, 4, 5}); + assertThat(rows).isExhausted(); + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryEncodedVectorizedReads.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryEncodedVectorizedReads.java new file mode 100644 index 000000000000..756f49a2aad6 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryEncodedVectorizedReads.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data.parquet.vectorized; + +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT; + +import java.io.File; +import java.io.IOException; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.base.Function; +import org.apache.iceberg.relocated.com.google.common.collect.FluentIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.data.RandomData; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Test; + +public class TestParquetDictionaryEncodedVectorizedReads extends TestParquetVectorizedReads { + + @Override + Iterable generateData( + Schema schema, + int numRecords, + long seed, + float nullPercentage, + Function transform) { + Iterable data = + RandomData.generateDictionaryEncodableData(schema, numRecords, seed, nullPercentage); + return transform == IDENTITY ? 
data : Iterables.transform(data, transform); + } + + @Test + @Override + @Ignore // Ignored since this code path is already tested in TestParquetVectorizedReads + public void testVectorizedReadsWithNewContainers() throws IOException {} + + @Test + public void testMixedDictionaryNonDictionaryReads() throws IOException { + Schema schema = new Schema(SUPPORTED_PRIMITIVES.fields()); + File dictionaryEncodedFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", dictionaryEncodedFile.delete()); + Iterable dictionaryEncodableData = + RandomData.generateDictionaryEncodableData( + schema, 10000, 0L, RandomData.DEFAULT_NULL_PERCENTAGE); + try (FileAppender writer = + getParquetWriter(schema, dictionaryEncodedFile)) { + writer.addAll(dictionaryEncodableData); + } + + File plainEncodingFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", plainEncodingFile.delete()); + Iterable nonDictionaryData = + RandomData.generate(schema, 10000, 0L, RandomData.DEFAULT_NULL_PERCENTAGE); + try (FileAppender writer = getParquetWriter(schema, plainEncodingFile)) { + writer.addAll(nonDictionaryData); + } + + int rowGroupSize = PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT; + File mixedFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", mixedFile.delete()); + Parquet.concat( + ImmutableList.of(dictionaryEncodedFile, plainEncodingFile, dictionaryEncodedFile), + mixedFile, + rowGroupSize, + schema, + ImmutableMap.of()); + assertRecordsMatch( + schema, + 30000, + FluentIterable.concat(dictionaryEncodableData, nonDictionaryData, dictionaryEncodableData), + mixedFile, + false, + true, + BATCH_SIZE); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryFallbackToPlainEncodingVectorizedReads.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryFallbackToPlainEncodingVectorizedReads.java new file mode 100644 index 000000000000..42ea34936b5f --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryFallbackToPlainEncodingVectorizedReads.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data.parquet.vectorized; + +import java.io.File; +import java.io.IOException; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.base.Function; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.data.RandomData; +import org.junit.Ignore; +import org.junit.Test; + +public class TestParquetDictionaryFallbackToPlainEncodingVectorizedReads + extends TestParquetVectorizedReads { + private static final int NUM_ROWS = 1_000_000; + + @Override + protected int getNumRows() { + return NUM_ROWS; + } + + @Override + Iterable generateData( + Schema schema, + int numRecords, + long seed, + float nullPercentage, + Function transform) { + // TODO: take into account nullPercentage when generating fallback encoding data + Iterable data = RandomData.generateFallbackData(schema, numRecords, seed, numRecords / 20); + return transform == IDENTITY ? data : Iterables.transform(data, transform); + } + + @Override + FileAppender getParquetWriter(Schema schema, File testFile) + throws IOException { + return Parquet.write(Files.localOutput(testFile)) + .schema(schema) + .named("test") + .set(TableProperties.PARQUET_DICT_SIZE_BYTES, "512000") + .build(); + } + + @Test + @Override + @Ignore // Fallback encoding not triggered when data is mostly null + public void testMostlyNullsForOptionalFields() {} + + @Test + @Override + @Ignore // Ignored since this code path is already tested in TestParquetVectorizedReads + public void testVectorizedReadsWithNewContainers() throws IOException {} +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetVectorizedReads.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetVectorizedReads.java new file mode 100644 index 000000000000..56e9490b997b --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetVectorizedReads.java @@ -0,0 +1,362 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data.parquet.vectorized; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.base.Function; +import org.apache.iceberg.relocated.com.google.common.base.Strings; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.data.AvroDataTest; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkParquetReaders; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.parquet.column.ParquetProperties; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.Type; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Ignore; +import org.junit.Test; + +public class TestParquetVectorizedReads extends AvroDataTest { + private static final int NUM_ROWS = 200_000; + static final int BATCH_SIZE = 10_000; + + static final Function IDENTITY = record -> record; + + @Override + protected void writeAndValidate(Schema schema) throws IOException { + writeAndValidate(schema, getNumRows(), 0L, RandomData.DEFAULT_NULL_PERCENTAGE, false, true); + } + + private void writeAndValidate( + Schema schema, + int numRecords, + long seed, + float nullPercentage, + boolean setAndCheckArrowValidityVector, + boolean reuseContainers) + throws IOException { + writeAndValidate( + schema, + numRecords, + seed, + nullPercentage, + setAndCheckArrowValidityVector, + reuseContainers, + BATCH_SIZE, + IDENTITY); + } + + private void writeAndValidate( + Schema schema, + int numRecords, + long seed, + float nullPercentage, + boolean setAndCheckArrowValidityVector, + boolean reuseContainers, + int batchSize, + Function transform) + throws IOException { + // Write test data + Assume.assumeTrue( + "Parquet Avro cannot write non-string map keys", + null + == TypeUtil.find( + schema, + type -> type.isMapType() && type.asMapType().keyType() != Types.StringType.get())); + + Iterable expected = + generateData(schema, numRecords, seed, nullPercentage, transform); + + // write a test parquet file using iceberg writer + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (FileAppender writer = getParquetWriter(schema, testFile)) { + writer.addAll(expected); + } + assertRecordsMatch( + schema, + numRecords, + expected, + testFile, + setAndCheckArrowValidityVector, + reuseContainers, + batchSize); + } + + protected int getNumRows() { + return NUM_ROWS; + } + + Iterable generateData( + Schema schema, + int numRecords, + long seed, + float nullPercentage, + Function transform) { + Iterable data = + RandomData.generate(schema, numRecords, seed, nullPercentage); + return transform == IDENTITY ? 
data : Iterables.transform(data, transform); + } + + FileAppender getParquetWriter(Schema schema, File testFile) + throws IOException { + return Parquet.write(Files.localOutput(testFile)).schema(schema).named("test").build(); + } + + FileAppender getParquetV2Writer(Schema schema, File testFile) + throws IOException { + return Parquet.write(Files.localOutput(testFile)) + .schema(schema) + .named("test") + .writerVersion(ParquetProperties.WriterVersion.PARQUET_2_0) + .build(); + } + + void assertRecordsMatch( + Schema schema, + int expectedSize, + Iterable expected, + File testFile, + boolean setAndCheckArrowValidityBuffer, + boolean reuseContainers, + int batchSize) + throws IOException { + Parquet.ReadBuilder readBuilder = + Parquet.read(Files.localInput(testFile)) + .project(schema) + .recordsPerBatch(batchSize) + .createBatchedReaderFunc( + type -> + VectorizedSparkParquetReaders.buildReader( + schema, type, setAndCheckArrowValidityBuffer)); + if (reuseContainers) { + readBuilder.reuseContainers(); + } + try (CloseableIterable batchReader = readBuilder.build()) { + Iterator expectedIter = expected.iterator(); + Iterator batches = batchReader.iterator(); + int numRowsRead = 0; + while (batches.hasNext()) { + ColumnarBatch batch = batches.next(); + numRowsRead += batch.numRows(); + TestHelpers.assertEqualsBatch( + schema.asStruct(), expectedIter, batch, setAndCheckArrowValidityBuffer); + } + Assert.assertEquals(expectedSize, numRowsRead); + } + } + + @Override + @Test + @Ignore + public void testArray() {} + + @Override + @Test + @Ignore + public void testArrayOfStructs() {} + + @Override + @Test + @Ignore + public void testMap() {} + + @Override + @Test + @Ignore + public void testNumericMapKey() {} + + @Override + @Test + @Ignore + public void testComplexMapKey() {} + + @Override + @Test + @Ignore + public void testMapOfStructs() {} + + @Override + @Test + @Ignore + public void testMixedTypes() {} + + @Test + @Override + public void testNestedStruct() { + AssertHelpers.assertThrows( + "Vectorized reads are not supported yet for struct fields", + UnsupportedOperationException.class, + "Vectorized reads are not supported yet for struct fields", + () -> + VectorizedSparkParquetReaders.buildReader( + TypeUtil.assignIncreasingFreshIds( + new Schema(required(1, "struct", SUPPORTED_PRIMITIVES))), + new MessageType( + "struct", new GroupType(Type.Repetition.OPTIONAL, "struct").withId(1)), + false)); + } + + @Test + public void testMostlyNullsForOptionalFields() throws IOException { + writeAndValidate( + TypeUtil.assignIncreasingFreshIds(new Schema(SUPPORTED_PRIMITIVES.fields())), + getNumRows(), + 0L, + 0.99f, + false, + true); + } + + @Test + public void testSettingArrowValidityVector() throws IOException { + writeAndValidate( + new Schema(Lists.transform(SUPPORTED_PRIMITIVES.fields(), Types.NestedField::asOptional)), + getNumRows(), + 0L, + RandomData.DEFAULT_NULL_PERCENTAGE, + true, + true); + } + + @Test + public void testVectorizedReadsWithNewContainers() throws IOException { + writeAndValidate( + TypeUtil.assignIncreasingFreshIds(new Schema(SUPPORTED_PRIMITIVES.fields())), + getNumRows(), + 0L, + RandomData.DEFAULT_NULL_PERCENTAGE, + true, + false); + } + + @Test + public void testVectorizedReadsWithReallocatedArrowBuffers() throws IOException { + // With a batch size of 2, 256 bytes are allocated in the VarCharVector. By adding strings of + // length 512, the vector will need to be reallocated for storing the batch. 
+ writeAndValidate( + new Schema( + Lists.newArrayList( + SUPPORTED_PRIMITIVES.field("id"), SUPPORTED_PRIMITIVES.field("data"))), + 10, + 0L, + RandomData.DEFAULT_NULL_PERCENTAGE, + true, + true, + 2, + record -> { + if (record.get("data") != null) { + record.put("data", Strings.padEnd((String) record.get("data"), 512, 'a')); + } else { + record.put("data", Strings.padEnd("", 512, 'a')); + } + return record; + }); + } + + @Test + public void testReadsForTypePromotedColumns() throws Exception { + Schema writeSchema = + new Schema( + required(100, "id", Types.LongType.get()), + optional(101, "int_data", Types.IntegerType.get()), + optional(102, "float_data", Types.FloatType.get()), + optional(103, "decimal_data", Types.DecimalType.of(10, 5))); + + File dataFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", dataFile.delete()); + Iterable data = + generateData(writeSchema, 30000, 0L, RandomData.DEFAULT_NULL_PERCENTAGE, IDENTITY); + try (FileAppender writer = getParquetWriter(writeSchema, dataFile)) { + writer.addAll(data); + } + + Schema readSchema = + new Schema( + required(100, "id", Types.LongType.get()), + optional(101, "int_data", Types.LongType.get()), + optional(102, "float_data", Types.DoubleType.get()), + optional(103, "decimal_data", Types.DecimalType.of(25, 5))); + + assertRecordsMatch(readSchema, 30000, data, dataFile, false, true, BATCH_SIZE); + } + + @Test + public void testSupportedReadsForParquetV2() throws Exception { + // Float and double column types are written using plain encoding with Parquet V2, + // also Parquet V2 will dictionary encode decimals that use fixed length binary + // (i.e. decimals > 8 bytes) + Schema schema = + new Schema( + optional(102, "float_data", Types.FloatType.get()), + optional(103, "double_data", Types.DoubleType.get()), + optional(104, "decimal_data", Types.DecimalType.of(25, 5))); + + File dataFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", dataFile.delete()); + Iterable data = + generateData(schema, 30000, 0L, RandomData.DEFAULT_NULL_PERCENTAGE, IDENTITY); + try (FileAppender writer = getParquetV2Writer(schema, dataFile)) { + writer.addAll(data); + } + assertRecordsMatch(schema, 30000, data, dataFile, false, true, BATCH_SIZE); + } + + @Test + public void testUnsupportedReadsForParquetV2() throws Exception { + // Longs, ints, string types etc use delta encoding and which are not supported for vectorized + // reads + Schema schema = new Schema(SUPPORTED_PRIMITIVES.fields()); + File dataFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", dataFile.delete()); + Iterable data = + generateData(schema, 30000, 0L, RandomData.DEFAULT_NULL_PERCENTAGE, IDENTITY); + try (FileAppender writer = getParquetV2Writer(schema, dataFile)) { + writer.addAll(data); + } + AssertHelpers.assertThrows( + "Vectorized reads not supported", + UnsupportedOperationException.class, + "Cannot support vectorized reads for column", + () -> { + assertRecordsMatch(schema, 30000, data, dataFile, false, true, BATCH_SIZE); + return null; + }); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/ComplexRecord.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/ComplexRecord.java new file mode 100644 index 000000000000..42e8552578cd --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/ComplexRecord.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; +import org.apache.iceberg.relocated.com.google.common.base.Objects; + +public class ComplexRecord { + private long id; + private NestedRecord struct; + + public ComplexRecord() {} + + public ComplexRecord(long id, NestedRecord struct) { + this.id = id; + this.struct = struct; + } + + public long getId() { + return id; + } + + public void setId(long id) { + this.id = id; + } + + public NestedRecord getStruct() { + return struct; + } + + public void setStruct(NestedRecord struct) { + this.struct = struct; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + ComplexRecord record = (ComplexRecord) o; + return id == record.id && Objects.equal(struct, record.struct); + } + + @Override + public int hashCode() { + return Objects.hashCode(id, struct); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this).add("id", id).add("struct", struct).toString(); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/FilePathLastModifiedRecord.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/FilePathLastModifiedRecord.java new file mode 100644 index 000000000000..c62c1de6ba33 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/FilePathLastModifiedRecord.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.sql.Timestamp; +import java.util.Objects; + +public class FilePathLastModifiedRecord { + private String filePath; + private Timestamp lastModified; + + public FilePathLastModifiedRecord() {} + + public FilePathLastModifiedRecord(String filePath, Timestamp lastModified) { + this.filePath = filePath; + this.lastModified = lastModified; + } + + public String getFilePath() { + return filePath; + } + + public void setFilePath(String filePath) { + this.filePath = filePath; + } + + public Timestamp getLastModified() { + return lastModified; + } + + public void setLastModified(Timestamp lastModified) { + this.lastModified = lastModified; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FilePathLastModifiedRecord that = (FilePathLastModifiedRecord) o; + return Objects.equals(filePath, that.filePath) + && Objects.equals(lastModified, that.lastModified); + } + + @Override + public int hashCode() { + return Objects.hash(filePath, lastModified); + } + + @Override + public String toString() { + return "FilePathLastModifiedRecord{" + + "filePath='" + + filePath + + '\'' + + ", lastModified='" + + lastModified + + '\'' + + '}'; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/LogMessage.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/LogMessage.java new file mode 100644 index 000000000000..53a35eec61ce --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/LogMessage.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.time.Instant; +import java.util.concurrent.atomic.AtomicInteger; + +public class LogMessage { + private static AtomicInteger idCounter = new AtomicInteger(0); + + static LogMessage debug(String date, String message) { + return new LogMessage(idCounter.getAndIncrement(), date, "DEBUG", message); + } + + static LogMessage debug(String date, String message, Instant timestamp) { + return new LogMessage(idCounter.getAndIncrement(), date, "DEBUG", message, timestamp); + } + + static LogMessage info(String date, String message) { + return new LogMessage(idCounter.getAndIncrement(), date, "INFO", message); + } + + static LogMessage info(String date, String message, Instant timestamp) { + return new LogMessage(idCounter.getAndIncrement(), date, "INFO", message, timestamp); + } + + static LogMessage error(String date, String message) { + return new LogMessage(idCounter.getAndIncrement(), date, "ERROR", message); + } + + static LogMessage error(String date, String message, Instant timestamp) { + return new LogMessage(idCounter.getAndIncrement(), date, "ERROR", message, timestamp); + } + + static LogMessage warn(String date, String message) { + return new LogMessage(idCounter.getAndIncrement(), date, "WARN", message); + } + + static LogMessage warn(String date, String message, Instant timestamp) { + return new LogMessage(idCounter.getAndIncrement(), date, "WARN", message, timestamp); + } + + private int id; + private String date; + private String level; + private String message; + private Instant timestamp; + + private LogMessage(int id, String date, String level, String message) { + this.id = id; + this.date = date; + this.level = level; + this.message = message; + } + + private LogMessage(int id, String date, String level, String message, Instant timestamp) { + this.id = id; + this.date = date; + this.level = level; + this.message = message; + this.timestamp = timestamp; + } + + public int getId() { + return id; + } + + public void setId(int id) { + this.id = id; + } + + public String getDate() { + return date; + } + + public void setDate(String date) { + this.date = date; + } + + public String getLevel() { + return level; + } + + public void setLevel(String level) { + this.level = level; + } + + public String getMessage() { + return message; + } + + public void setMessage(String message) { + this.message = message; + } + + public Instant getTimestamp() { + return timestamp; + } + + public void setTimestamp(Instant timestamp) { + this.timestamp = timestamp; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/ManualSource.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/ManualSource.java new file mode 100644 index 000000000000..c9c1c29ea8fc --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/ManualSource.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableProvider; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.sources.DataSourceRegister; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +public class ManualSource implements TableProvider, DataSourceRegister { + public static final String SHORT_NAME = "manual_source"; + public static final String TABLE_NAME = "TABLE_NAME"; + private static final Map tableMap = Maps.newHashMap(); + + public static void setTable(String name, Table table) { + Preconditions.checkArgument( + !tableMap.containsKey(name), "Cannot set " + name + ". It is already set"); + tableMap.put(name, table); + } + + public static void clearTables() { + tableMap.clear(); + } + + @Override + public String shortName() { + return SHORT_NAME; + } + + @Override + public StructType inferSchema(CaseInsensitiveStringMap options) { + return getTable(null, null, options).schema(); + } + + @Override + public Transform[] inferPartitioning(CaseInsensitiveStringMap options) { + return getTable(null, null, options).partitioning(); + } + + @Override + public org.apache.spark.sql.connector.catalog.Table getTable( + StructType schema, Transform[] partitioning, Map properties) { + Preconditions.checkArgument( + properties.containsKey(TABLE_NAME), "Missing property " + TABLE_NAME); + String tableName = properties.get(TABLE_NAME); + Preconditions.checkArgument(tableMap.containsKey(tableName), "Table missing " + tableName); + return tableMap.get(tableName); + } + + @Override + public boolean supportsExternalMetadata() { + return false; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/NestedRecord.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/NestedRecord.java new file mode 100644 index 000000000000..ca36bfd4938b --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/NestedRecord.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
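ManualSource lets a test hand a pre-built connector Table directly to the DataFrame reader by name instead of going through a catalog. A hedged sketch of the round trip: `sparkTable` stands in for whatever Table implementation the test constructs, and the fully qualified class name is used as the format so the example does not depend on the DataSourceRegister service registration that the short name requires.

```java
package org.apache.iceberg.spark.source;

import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.connector.catalog.Table;

public class ManualSourceExample {
  // `sparkTable` is a placeholder for a connector Table built elsewhere in a test.
  static long countThroughManualSource(SparkSession spark, Table sparkTable) {
    ManualSource.setTable("my_table", sparkTable);
    try {
      return spark
          .read()
          .format(ManualSource.class.getName()) // SHORT_NAME also works once registered
          .option(ManualSource.TABLE_NAME, "my_table")
          .load()
          .count();
    } finally {
      ManualSource.clearTables(); // keep the static table map clean between tests
    }
  }
}
```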
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; +import org.apache.iceberg.relocated.com.google.common.base.Objects; + +public class NestedRecord { + private long innerId; + private String innerName; + + public NestedRecord() {} + + public NestedRecord(long innerId, String innerName) { + this.innerId = innerId; + this.innerName = innerName; + } + + public long getInnerId() { + return innerId; + } + + public String getInnerName() { + return innerName; + } + + public void setInnerId(long iId) { + innerId = iId; + } + + public void setInnerName(String name) { + innerName = name; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + NestedRecord that = (NestedRecord) o; + return innerId == that.innerId && Objects.equal(innerName, that.innerName); + } + + @Override + public int hashCode() { + return Objects.hashCode(innerId, innerName); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("innerId", innerId) + .add("innerName", innerName) + .toString(); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/SimpleRecord.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/SimpleRecord.java new file mode 100644 index 000000000000..550e20b9338e --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/SimpleRecord.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.relocated.com.google.common.base.Objects; + +public class SimpleRecord { + private Integer id; + private String data; + + public SimpleRecord() {} + + public SimpleRecord(Integer id, String data) { + this.id = id; + this.data = data; + } + + public Integer getId() { + return id; + } + + public void setId(Integer id) { + this.id = id; + } + + public String getData() { + return data; + } + + public void setData(String data) { + this.data = data; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + SimpleRecord record = (SimpleRecord) o; + return Objects.equal(id, record.id) && Objects.equal(data, record.data); + } + + @Override + public int hashCode() { + return Objects.hashCode(id, data); + } + + @Override + public String toString() { + StringBuilder buffer = new StringBuilder(); + buffer.append("{\"id\"="); + buffer.append(id); + buffer.append(",\"data\"=\""); + buffer.append(data); + buffer.append("\"}"); + return buffer.toString(); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/SparkSQLExecutionHelper.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/SparkSQLExecutionHelper.java new file mode 100644 index 000000000000..3b350bc91e72 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/SparkSQLExecutionHelper.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.execution.ui.SQLAppStatusStore; +import org.apache.spark.sql.execution.ui.SQLExecutionUIData; +import org.apache.spark.sql.execution.ui.SQLPlanMetric; +import org.junit.Assert; +import scala.Option; + +public class SparkSQLExecutionHelper { + + private SparkSQLExecutionHelper() {} + + /** + * Finds the value of a specified metric for the last SQL query that was executed. Metric values + * are stored in the `SQLAppStatusStore` as strings. 
+ * + * @param spark SparkSession used to run the SQL query + * @param metricName name of the metric + * @return value of the metric + */ + public static String lastExecutedMetricValue(SparkSession spark, String metricName) { + SQLAppStatusStore statusStore = spark.sharedState().statusStore(); + SQLExecutionUIData lastExecution = statusStore.executionsList().last(); + Option sqlPlanMetric = + lastExecution.metrics().find(metric -> metric.name().equals(metricName)); + Assert.assertTrue( + String.format("Metric '%s' not found in last execution", metricName), + sqlPlanMetric.isDefined()); + long metricId = sqlPlanMetric.get().accumulatorId(); + + // Refresh metricValues, they will remain null until the execution is complete and metrics are + // aggregated + int attempts = 3; + while (lastExecution.metricValues() == null && attempts > 0) { + try { + Thread.sleep(100); + attempts--; + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + + lastExecution = statusStore.execution(lastExecution.executionId()).get(); + } + + Assert.assertNotNull("Metric values were not finalized", lastExecution.metricValues()); + String metricValue = lastExecution.metricValues().get(metricId).getOrElse(null); + Assert.assertNotNull(String.format("Metric '%s' was not finalized", metricName), metricValue); + return metricValue; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestAvroScan.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestAvroScan.java new file mode 100644 index 000000000000..9491adde4605 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestAvroScan.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
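A quick sketch of how `lastExecutedMetricValue` is typically consumed: run a query through the session, force execution, then assert on a driver-side SQL metric. The table name and metric name below are illustrative; the metric string must match whatever the executed plan actually reports.

```java
package org.apache.iceberg.spark.source;

import org.apache.spark.sql.SparkSession;
import org.junit.Assert;

public class MetricCheckExample {
  // `spark` and the table are assumed to exist in the surrounding test fixture.
  static void assertLastQueryReportedRows(SparkSession spark) {
    spark.sql("SELECT count(*) FROM db.tbl").collect(); // collect() forces execution

    // Metric names are plan-specific; "number of output rows" is only an example.
    String outputRows =
        SparkSQLExecutionHelper.lastExecutedMetricValue(spark, "number of output rows");
    Assert.assertNotNull(outputRows);
  }
}
```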
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.Files.localOutput; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.UUID; +import org.apache.avro.generic.GenericData.Record; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.spark.data.AvroDataTest; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.rules.TemporaryFolder; + +public class TestAvroScan extends AvroDataTest { + private static final Configuration CONF = new Configuration(); + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private static SparkSession spark = null; + + @BeforeClass + public static void startSpark() { + TestAvroScan.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestAvroScan.spark; + TestAvroScan.spark = null; + currentSpark.stop(); + } + + @Override + protected void writeAndValidate(Schema schema) throws IOException { + File parent = temp.newFolder("avro"); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + dataFolder.mkdirs(); + + File avroFile = + new File(dataFolder, FileFormat.AVRO.addExtension(UUID.randomUUID().toString())); + + HadoopTables tables = new HadoopTables(CONF); + Table table = tables.create(schema, PartitionSpec.unpartitioned(), location.toString()); + + // Important: use the table's schema for the rest of the test + // When tables are created, the column ids are reassigned. + Schema tableSchema = table.schema(); + + List expected = RandomData.generateList(tableSchema, 100, 1L); + + try (FileAppender writer = + Avro.write(localOutput(avroFile)).schema(tableSchema).build()) { + writer.addAll(expected); + } + + DataFile file = + DataFiles.builder(PartitionSpec.unpartitioned()) + .withRecordCount(100) + .withFileSizeInBytes(avroFile.length()) + .withPath(avroFile.toString()) + .build(); + + table.newAppend().appendFile(file).commit(); + + Dataset df = spark.read().format("iceberg").load(location.toString()); + + List rows = df.collectAsList(); + Assert.assertEquals("Should contain 100 rows", 100, rows.size()); + + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe(tableSchema.asStruct(), expected.get(i), rows.get(i)); + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestBaseReader.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestBaseReader.java new file mode 100644 index 000000000000..3d94966eb76c --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestBaseReader.java @@ -0,0 +1,283 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.FileFormat.PARQUET; +import static org.apache.iceberg.Files.localOutput; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.BaseCombinedScanTask; +import org.apache.iceberg.ContentFile; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.types.Types; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestBaseReader { + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private Table table; + + // Simulates the closeable iterator of data to be read + private static class CloseableIntegerRange implements CloseableIterator { + boolean closed; + Iterator iter; + + CloseableIntegerRange(long range) { + this.closed = false; + this.iter = IntStream.range(0, (int) range).iterator(); + } + + @Override + public void close() { + this.closed = true; + } + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public Integer next() { + return iter.next(); + } + } + + // Main reader class to test base class iteration logic. + // Keeps track of iterator closure. 
+  private static class ClosureTrackingReader extends BaseReader<Integer, FileScanTask> {
+    private Map<String, CloseableIntegerRange> tracker = Maps.newHashMap();
+
+    ClosureTrackingReader(Table table, List<FileScanTask> tasks) {
+      super(table, new BaseCombinedScanTask(tasks), null, null, false);
+    }
+
+    @Override
+    protected Stream<ContentFile<?>> referencedFiles(FileScanTask task) {
+      return Stream.of();
+    }
+
+    @Override
+    protected CloseableIterator<Integer> open(FileScanTask task) {
+      CloseableIntegerRange intRange = new CloseableIntegerRange(task.file().recordCount());
+      tracker.put(getKey(task), intRange);
+      return intRange;
+    }
+
+    public Boolean isIteratorClosed(FileScanTask task) {
+      return tracker.get(getKey(task)).closed;
+    }
+
+    public Boolean hasIterator(FileScanTask task) {
+      return tracker.containsKey(getKey(task));
+    }
+
+    private String getKey(FileScanTask task) {
+      return task.file().path().toString();
+    }
+  }
+
+  @Test
+  public void testClosureOnDataExhaustion() throws IOException {
+    Integer totalTasks = 10;
+    Integer recordPerTask = 10;
+    List<FileScanTask> tasks = createFileScanTasks(totalTasks, recordPerTask);
+
+    ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks);
+
+    int countRecords = 0;
+    while (reader.next()) {
+      countRecords += 1;
+      Assert.assertNotNull("Reader should return non-null value", reader.get());
+    }
+
+    Assert.assertEquals(
+        "Reader returned incorrect number of records", totalTasks * recordPerTask, countRecords);
+    tasks.forEach(
+        t ->
+            Assert.assertTrue(
+                "All iterators should be closed after read exhaustion",
+                reader.isIteratorClosed(t)));
+  }
+
+  @Test
+  public void testClosureDuringIteration() throws IOException {
+    Integer totalTasks = 2;
+    Integer recordPerTask = 1;
+    List<FileScanTask> tasks = createFileScanTasks(totalTasks, recordPerTask);
+    Assert.assertEquals(2, tasks.size());
+    FileScanTask firstTask = tasks.get(0);
+    FileScanTask secondTask = tasks.get(1);
+
+    ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks);
+
+    // Total of 2 elements
+    Assert.assertTrue(reader.next());
+    Assert.assertFalse(
+        "First iter should not be closed on its last element", reader.isIteratorClosed(firstTask));
+
+    Assert.assertTrue(reader.next());
+    Assert.assertTrue(
+        "First iter should be closed after moving to second iter",
+        reader.isIteratorClosed(firstTask));
+    Assert.assertFalse(
+        "Second iter should not be closed on its last element",
+        reader.isIteratorClosed(secondTask));
+
+    Assert.assertFalse(reader.next());
+    Assert.assertTrue(reader.isIteratorClosed(firstTask));
+    Assert.assertTrue(reader.isIteratorClosed(secondTask));
+  }
+
+  @Test
+  public void testClosureWithoutAnyRead() throws IOException {
+    Integer totalTasks = 10;
+    Integer recordPerTask = 10;
+    List<FileScanTask> tasks = createFileScanTasks(totalTasks, recordPerTask);
+
+    ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks);
+
+    reader.close();
+
+    tasks.forEach(
+        t ->
+            Assert.assertFalse(
+                "Iterator should not be created eagerly for tasks", reader.hasIterator(t)));
+  }
+
+  @Test
+  public void testExplicitClosure() throws IOException {
+    Integer totalTasks = 10;
+    Integer recordPerTask = 10;
+    List<FileScanTask> tasks = createFileScanTasks(totalTasks, recordPerTask);
+
+    ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks);
+
+    Integer halfDataSize = (totalTasks * recordPerTask) / 2;
+    for (int i = 0; i < halfDataSize; i++) {
+      Assert.assertTrue("Reader should have some element", reader.next());
+      Assert.assertNotNull("Reader should return non-null value", reader.get());
+    }
+
+    reader.close();
+
+    // Some tasks might have not been opened yet, so we don't have a corresponding tracker for it.
+    // But all that have been created must be closed.
+    tasks.forEach(
+        t -> {
+          if (reader.hasIterator(t)) {
+            Assert.assertTrue(
+                "Iterator should be closed after read exhaustion", reader.isIteratorClosed(t));
+          }
+        });
+  }
+
+  @Test
+  public void testIdempotentExplicitClosure() throws IOException {
+    Integer totalTasks = 10;
+    Integer recordPerTask = 10;
+    List<FileScanTask> tasks = createFileScanTasks(totalTasks, recordPerTask);
+
+    ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks);
+
+    // Total 100 elements, only 5 iterators have been created
+    for (int i = 0; i < 45; i++) {
+      Assert.assertTrue("Reader should have some element", reader.next());
+      Assert.assertNotNull("Reader should return non-null value", reader.get());
+    }
+
+    for (int closeAttempt = 0; closeAttempt < 5; closeAttempt++) {
+      reader.close();
+      for (int i = 0; i < 5; i++) {
+        Assert.assertTrue(
+            "Iterator should be closed after read exhaustion",
+            reader.isIteratorClosed(tasks.get(i)));
+      }
+      for (int i = 5; i < 10; i++) {
+        Assert.assertFalse(
+            "Iterator should not be created eagerly for tasks", reader.hasIterator(tasks.get(i)));
+      }
+    }
+  }
+
+  private List<FileScanTask> createFileScanTasks(Integer totalTasks, Integer recordPerTask)
+      throws IOException {
+    String desc = "make_scan_tasks";
+    File parent = temp.newFolder(desc);
+    File location = new File(parent, "test");
+    File dataFolder = new File(location, "data");
+    Assert.assertTrue("mkdirs should succeed", dataFolder.mkdirs());
+
+    Schema schema = new Schema(Types.NestedField.required(0, "id", Types.LongType.get()));
+
+    try {
+      this.table = TestTables.create(location, desc, schema, PartitionSpec.unpartitioned());
+      // Important: use the table's schema for the rest of the test
+      // When tables are created, the column ids are reassigned.
+      Schema tableSchema = table.schema();
+      List<GenericData.Record> expected = RandomData.generateList(tableSchema, recordPerTask, 1L);
+
+      AppendFiles appendFiles = table.newAppend();
+      for (int i = 0; i < totalTasks; i++) {
+        File parquetFile =
+            new File(dataFolder, PARQUET.addExtension(UUID.randomUUID().toString()));
+        try (FileAppender<GenericData.Record> writer =
            Parquet.write(localOutput(parquetFile)).schema(tableSchema).build()) {
+          writer.addAll(expected);
+        }
+        DataFile file =
+            DataFiles.builder(PartitionSpec.unpartitioned())
+                .withFileSizeInBytes(parquetFile.length())
+                .withPath(parquetFile.toString())
+                .withRecordCount(recordPerTask)
+                .build();
+        appendFiles.appendFile(file);
+      }
+      appendFiles.commit();
+
+      return StreamSupport.stream(table.newScan().planFiles().spliterator(), false)
+          .collect(Collectors.toList());
+    } finally {
+      TestTables.clearTables();
+    }
+  }
+}
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestChangelogReader.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestChangelogReader.java
new file mode 100644
index 000000000000..fc17547fad41
--- /dev/null
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestChangelogReader.java
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
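The tests above pin down the reader contract this module relies on: per-task iterators are opened lazily on `next()`, the previous one is closed as the reader advances, and `close()` is safe to call repeatedly. A short sketch of the consumption loop they exercise, using the names from TestBaseReader itself (it only compiles inside that class, since ClosureTrackingReader is a private helper):

```java
// Inside TestBaseReader:
List<FileScanTask> tasks = createFileScanTasks(5, 10);
ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks);
try {
  while (reader.next()) {         // lazily opens the iterator for the current task
    Integer value = reader.get(); // rows come from that per-task iterator
    Assert.assertNotNull(value);
  }
} finally {
  reader.close();                 // idempotent; closes whichever iterators were opened
}
```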
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.iceberg.ChangelogOperation; +import org.apache.iceberg.ChangelogScanTask; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.Files; +import org.apache.iceberg.IncrementalChangelogScan; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.ScanTaskGroup; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.data.FileHelpers; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestChangelogReader extends SparkTestBase { + private static final Schema SCHEMA = + new Schema( + required(1, "id", Types.IntegerType.get()), optional(2, "data", Types.StringType.get())); + private static final PartitionSpec SPEC = + PartitionSpec.builderFor(SCHEMA).bucket("data", 16).build(); + private final List records1 = Lists.newArrayList(); + private final List records2 = Lists.newArrayList(); + + private Table table; + private DataFile dataFile1; + private DataFile dataFile2; + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + @Before + public void before() throws IOException { + table = catalog.createTable(TableIdentifier.of("default", "test"), SCHEMA, SPEC); + // create some data + GenericRecord record = GenericRecord.create(table.schema()); + records1.add(record.copy("id", 29, "data", "a")); + records1.add(record.copy("id", 43, "data", "b")); + records1.add(record.copy("id", 61, "data", "c")); + records1.add(record.copy("id", 89, "data", "d")); + + records2.add(record.copy("id", 100, "data", "e")); + records2.add(record.copy("id", 121, "data", "f")); + records2.add(record.copy("id", 122, "data", "g")); + + // write data to files + dataFile1 = writeDataFile(records1); + dataFile2 = writeDataFile(records2); + } + + @After + public void after() { + catalog.dropTable(TableIdentifier.of("default", "test")); + } + + @Test + public void testInsert() throws IOException { + table.newAppend().appendFile(dataFile1).commit(); + long snapshotId1 = table.currentSnapshot().snapshotId(); + + table.newAppend().appendFile(dataFile2).commit(); + long snapshotId2 = table.currentSnapshot().snapshotId(); + + CloseableIterable> taskGroups = newScan().planTasks(); + + List rows = Lists.newArrayList(); + + for 
(ScanTaskGroup taskGroup : taskGroups) { + ChangelogRowReader reader = + new ChangelogRowReader(table, taskGroup, table.schema(), table.schema(), false); + while (reader.next()) { + rows.add(reader.get().copy()); + } + reader.close(); + } + + rows.sort((r1, r2) -> r1.getInt(0) - r2.getInt(0)); + + List expectedRows = Lists.newArrayList(); + addExpectedRows(expectedRows, ChangelogOperation.INSERT, snapshotId1, 0, records1); + addExpectedRows(expectedRows, ChangelogOperation.INSERT, snapshotId2, 1, records2); + + assertEquals("Should have expected rows", expectedRows, internalRowsToJava(rows)); + } + + @Test + public void testDelete() throws IOException { + table.newAppend().appendFile(dataFile1).commit(); + long snapshotId1 = table.currentSnapshot().snapshotId(); + + table.newDelete().deleteFile(dataFile1).commit(); + long snapshotId2 = table.currentSnapshot().snapshotId(); + + CloseableIterable> taskGroups = + newScan().fromSnapshotExclusive(snapshotId1).planTasks(); + + List rows = Lists.newArrayList(); + + for (ScanTaskGroup taskGroup : taskGroups) { + ChangelogRowReader reader = + new ChangelogRowReader(table, taskGroup, table.schema(), table.schema(), false); + while (reader.next()) { + rows.add(reader.get().copy()); + } + reader.close(); + } + + rows.sort((r1, r2) -> r1.getInt(0) - r2.getInt(0)); + + List expectedRows = Lists.newArrayList(); + addExpectedRows(expectedRows, ChangelogOperation.DELETE, snapshotId2, 0, records1); + + assertEquals("Should have expected rows", expectedRows, internalRowsToJava(rows)); + } + + @Test + public void testDataFileRewrite() throws IOException { + table.newAppend().appendFile(dataFile1).commit(); + table.newAppend().appendFile(dataFile2).commit(); + long snapshotId2 = table.currentSnapshot().snapshotId(); + + table + .newRewrite() + .rewriteFiles(ImmutableSet.of(dataFile1), ImmutableSet.of(dataFile2)) + .commit(); + + // the rewrite operation should generate no Changelog rows + CloseableIterable> taskGroups = + newScan().fromSnapshotExclusive(snapshotId2).planTasks(); + + List rows = Lists.newArrayList(); + + for (ScanTaskGroup taskGroup : taskGroups) { + ChangelogRowReader reader = + new ChangelogRowReader(table, taskGroup, table.schema(), table.schema(), false); + while (reader.next()) { + rows.add(reader.get().copy()); + } + reader.close(); + } + + Assert.assertEquals("Should have no rows", 0, rows.size()); + } + + @Test + public void testMixDeleteAndInsert() throws IOException { + table.newAppend().appendFile(dataFile1).commit(); + long snapshotId1 = table.currentSnapshot().snapshotId(); + + table.newDelete().deleteFile(dataFile1).commit(); + long snapshotId2 = table.currentSnapshot().snapshotId(); + + table.newAppend().appendFile(dataFile2).commit(); + long snapshotId3 = table.currentSnapshot().snapshotId(); + + CloseableIterable> taskGroups = newScan().planTasks(); + + List rows = Lists.newArrayList(); + + for (ScanTaskGroup taskGroup : taskGroups) { + ChangelogRowReader reader = + new ChangelogRowReader(table, taskGroup, table.schema(), table.schema(), false); + while (reader.next()) { + rows.add(reader.get().copy()); + } + reader.close(); + } + + // order by the change ordinal + rows.sort( + (r1, r2) -> { + if (r1.getInt(3) != r2.getInt(3)) { + return r1.getInt(3) - r2.getInt(3); + } else { + return r1.getInt(0) - r2.getInt(0); + } + }); + + List expectedRows = Lists.newArrayList(); + addExpectedRows(expectedRows, ChangelogOperation.INSERT, snapshotId1, 0, records1); + addExpectedRows(expectedRows, ChangelogOperation.DELETE, snapshotId2, 1, 
records1); + addExpectedRows(expectedRows, ChangelogOperation.INSERT, snapshotId3, 2, records2); + + assertEquals("Should have expected rows", expectedRows, internalRowsToJava(rows)); + } + + private IncrementalChangelogScan newScan() { + return table.newIncrementalChangelogScan(); + } + + private List addExpectedRows( + List expectedRows, + ChangelogOperation operation, + long snapshotId, + int changeOrdinal, + List records) { + records.forEach( + r -> + expectedRows.add(row(r.get(0), r.get(1), operation.name(), changeOrdinal, snapshotId))); + return expectedRows; + } + + protected List internalRowsToJava(List rows) { + return rows.stream().map(this::toJava).collect(Collectors.toList()); + } + + private Object[] toJava(InternalRow row) { + Object[] values = new Object[row.numFields()]; + values[0] = row.getInt(0); + values[1] = row.getString(1); + values[2] = row.getString(2); + values[3] = row.getInt(3); + values[4] = row.getLong(4); + return values; + } + + private DataFile writeDataFile(List records) throws IOException { + // records all use IDs that are in bucket id_bucket=0 + return FileHelpers.writeDataFile( + table, Files.localOutput(temp.newFile()), TestHelpers.Row.of(0), records); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2.java new file mode 100644 index 000000000000..a9b4f0d3ad2f --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
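For readers unfamiliar with the API under test: ChangelogRowReader consumes task groups planned by the table's incremental changelog scan. A minimal planning sketch, with the table and snapshot id assumed to come from the surrounding test:

```java
package org.apache.iceberg.spark.source;

import java.io.IOException;
import org.apache.iceberg.ChangelogScanTask;
import org.apache.iceberg.ScanTaskGroup;
import org.apache.iceberg.Table;
import org.apache.iceberg.io.CloseableIterable;

public class ChangelogScanExample {
  static void planChangelog(Table table, long startSnapshotId) throws IOException {
    try (CloseableIterable<ScanTaskGroup<ChangelogScanTask>> taskGroups =
        table.newIncrementalChangelogScan().fromSnapshotExclusive(startSnapshotId).planTasks()) {
      for (ScanTaskGroup<ChangelogScanTask> taskGroup : taskGroups) {
        // each group would be handed to a ChangelogRowReader, as in the tests above
      }
    }
  }
}
```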
+ */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkTestBaseWithCatalog; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestDataFrameWriterV2 extends SparkTestBaseWithCatalog { + @Before + public void createTable() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testMergeSchemaFailsWithoutWriterOption() throws Exception { + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s'='true')", + tableName, TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA); + + Dataset twoColDF = + jsonToDF( + "id bigint, data string", + "{ \"id\": 1, \"data\": \"a\" }", + "{ \"id\": 2, \"data\": \"b\" }"); + + twoColDF.writeTo(tableName).append(); + + assertEquals( + "Should have initial 2-column rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("select * from %s order by id", tableName)); + + Dataset threeColDF = + jsonToDF( + "id bigint, data string, new_col float", + "{ \"id\": 3, \"data\": \"c\", \"new_col\": 12.06 }", + "{ \"id\": 4, \"data\": \"d\", \"new_col\": 14.41 }"); + + // this has a different error message than the case without accept-any-schema because it uses + // Iceberg checks + AssertHelpers.assertThrows( + "Should fail when merge-schema is not enabled on the writer", + IllegalArgumentException.class, + "Field new_col not found in source schema", + () -> { + try { + threeColDF.writeTo(tableName).append(); + } catch (NoSuchTableException e) { + // needed because append has checked exceptions + throw new RuntimeException(e); + } + }); + } + + @Test + public void testMergeSchemaWithoutAcceptAnySchema() throws Exception { + Dataset twoColDF = + jsonToDF( + "id bigint, data string", + "{ \"id\": 1, \"data\": \"a\" }", + "{ \"id\": 2, \"data\": \"b\" }"); + + twoColDF.writeTo(tableName).append(); + + assertEquals( + "Should have initial 2-column rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("select * from %s order by id", tableName)); + + Dataset threeColDF = + jsonToDF( + "id bigint, data string, new_col float", + "{ \"id\": 3, \"data\": \"c\", \"new_col\": 12.06 }", + "{ \"id\": 4, \"data\": \"d\", \"new_col\": 14.41 }"); + + AssertHelpers.assertThrows( + "Should fail when accept-any-schema is not enabled on the table", + AnalysisException.class, + "too many data columns", + () -> { + try { + threeColDF.writeTo(tableName).option("merge-schema", "true").append(); + } catch (NoSuchTableException e) { + // needed because append has checked exceptions + throw new RuntimeException(e); + } + }); + } + + @Test + public void testMergeSchemaSparkProperty() throws Exception { + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s'='true')", + tableName, TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA); + + 
Dataset twoColDF = + jsonToDF( + "id bigint, data string", + "{ \"id\": 1, \"data\": \"a\" }", + "{ \"id\": 2, \"data\": \"b\" }"); + + twoColDF.writeTo(tableName).append(); + + assertEquals( + "Should have initial 2-column rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("select * from %s order by id", tableName)); + + Dataset threeColDF = + jsonToDF( + "id bigint, data string, new_col float", + "{ \"id\": 3, \"data\": \"c\", \"new_col\": 12.06 }", + "{ \"id\": 4, \"data\": \"d\", \"new_col\": 14.41 }"); + + threeColDF.writeTo(tableName).option("mergeSchema", "true").append(); + + assertEquals( + "Should have 3-column rows", + ImmutableList.of( + row(1L, "a", null), row(2L, "b", null), row(3L, "c", 12.06F), row(4L, "d", 14.41F)), + sql("select * from %s order by id", tableName)); + } + + @Test + public void testMergeSchemaIcebergProperty() throws Exception { + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s'='true')", + tableName, TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA); + + Dataset twoColDF = + jsonToDF( + "id bigint, data string", + "{ \"id\": 1, \"data\": \"a\" }", + "{ \"id\": 2, \"data\": \"b\" }"); + + twoColDF.writeTo(tableName).append(); + + assertEquals( + "Should have initial 2-column rows", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("select * from %s order by id", tableName)); + + Dataset threeColDF = + jsonToDF( + "id bigint, data string, new_col float", + "{ \"id\": 3, \"data\": \"c\", \"new_col\": 12.06 }", + "{ \"id\": 4, \"data\": \"d\", \"new_col\": 14.41 }"); + + threeColDF.writeTo(tableName).option("merge-schema", "true").append(); + + assertEquals( + "Should have 3-column rows", + ImmutableList.of( + row(1L, "a", null), row(2L, "b", null), row(3L, "c", 12.06F), row(4L, "d", 14.41F)), + sql("select * from %s order by id", tableName)); + } + + @Test + public void testWriteWithCaseSensitiveOption() throws NoSuchTableException, ParseException { + SparkSession sparkSession = spark.cloneSession(); + sparkSession + .sql( + String.format( + "ALTER TABLE %s SET TBLPROPERTIES ('%s'='true')", + tableName, TableProperties.SPARK_WRITE_ACCEPT_ANY_SCHEMA)) + .collect(); + + String schema = "ID bigint, DaTa string"; + ImmutableList records = + ImmutableList.of("{ \"id\": 1, \"data\": \"a\" }", "{ \"id\": 2, \"data\": \"b\" }"); + + // disable spark.sql.caseSensitive + sparkSession.sql(String.format("SET %s=false", SQLConf.CASE_SENSITIVE().key())); + Dataset jsonDF = + sparkSession.createDataset(ImmutableList.copyOf(records), Encoders.STRING()); + Dataset ds = sparkSession.read().schema(schema).json(jsonDF); + // write should succeed + ds.writeTo(tableName).option("merge-schema", "true").option("check-ordering", "false").append(); + List fields = + Spark3Util.loadIcebergTable(sparkSession, tableName).schema().asStruct().fields(); + // Additional columns should not be created + Assert.assertEquals(2, fields.size()); + + // enable spark.sql.caseSensitive + sparkSession.sql(String.format("SET %s=true", SQLConf.CASE_SENSITIVE().key())); + ds.writeTo(tableName).option("merge-schema", "true").option("check-ordering", "false").append(); + fields = Spark3Util.loadIcebergTable(sparkSession, tableName).schema().asStruct().fields(); + Assert.assertEquals(4, fields.size()); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWrites.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWrites.java new file mode 100644 index 000000000000..310e69b827a9 --- /dev/null +++ 
b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWrites.java @@ -0,0 +1,423 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.spark.SparkSchemaUtil.convert; +import static org.apache.iceberg.spark.data.TestHelpers.assertEqualsSafe; +import static org.apache.iceberg.spark.data.TestHelpers.assertEqualsUnsafe; + +import java.io.File; +import java.io.IOException; +import java.net.URI; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Random; +import org.apache.avro.generic.GenericData.Record; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.Files; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.avro.AvroIterable; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.data.AvroDataTest; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkAvroReader; +import org.apache.iceberg.types.Types; +import org.apache.spark.SparkException; +import org.apache.spark.TaskContext; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.MapPartitionsFunction; +import org.apache.spark.sql.DataFrameWriter; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.encoders.RowEncoder; +import org.assertj.core.api.Assertions; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class TestDataFrameWrites extends AvroDataTest { + private static final Configuration CONF = new Configuration(); + + private final String format; + + @Parameterized.Parameters(name = "format = {0}") + public static Object[] parameters() { + return new Object[] {"parquet", "avro", "orc"}; + } + + public 
TestDataFrameWrites(String format) { + this.format = format; + } + + private static SparkSession spark = null; + private static JavaSparkContext sc = null; + + private Map tableProperties; + + private org.apache.spark.sql.types.StructType sparkSchema = + new org.apache.spark.sql.types.StructType( + new org.apache.spark.sql.types.StructField[] { + new org.apache.spark.sql.types.StructField( + "optionalField", + org.apache.spark.sql.types.DataTypes.StringType, + true, + org.apache.spark.sql.types.Metadata.empty()), + new org.apache.spark.sql.types.StructField( + "requiredField", + org.apache.spark.sql.types.DataTypes.StringType, + false, + org.apache.spark.sql.types.Metadata.empty()) + }); + + private Schema icebergSchema = + new Schema( + Types.NestedField.optional(1, "optionalField", Types.StringType.get()), + Types.NestedField.required(2, "requiredField", Types.StringType.get())); + + private List data0 = + Arrays.asList( + "{\"optionalField\": \"a1\", \"requiredField\": \"bid_001\"}", + "{\"optionalField\": \"a2\", \"requiredField\": \"bid_002\"}"); + private List data1 = + Arrays.asList( + "{\"optionalField\": \"d1\", \"requiredField\": \"bid_101\"}", + "{\"optionalField\": \"d2\", \"requiredField\": \"bid_102\"}", + "{\"optionalField\": \"d3\", \"requiredField\": \"bid_103\"}", + "{\"optionalField\": \"d4\", \"requiredField\": \"bid_104\"}"); + + @BeforeClass + public static void startSpark() { + TestDataFrameWrites.spark = SparkSession.builder().master("local[2]").getOrCreate(); + TestDataFrameWrites.sc = JavaSparkContext.fromSparkContext(spark.sparkContext()); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestDataFrameWrites.spark; + TestDataFrameWrites.spark = null; + TestDataFrameWrites.sc = null; + currentSpark.stop(); + } + + @Override + protected void writeAndValidate(Schema schema) throws IOException { + File location = createTableFolder(); + Table table = createTable(schema, location); + writeAndValidateWithLocations(table, location, new File(location, "data")); + } + + @Test + public void testWriteWithCustomDataLocation() throws IOException { + File location = createTableFolder(); + File tablePropertyDataLocation = temp.newFolder("test-table-property-data-dir"); + Table table = createTable(new Schema(SUPPORTED_PRIMITIVES.fields()), location); + table + .updateProperties() + .set(TableProperties.WRITE_DATA_LOCATION, tablePropertyDataLocation.getAbsolutePath()) + .commit(); + writeAndValidateWithLocations(table, location, tablePropertyDataLocation); + } + + private File createTableFolder() throws IOException { + File parent = temp.newFolder("parquet"); + File location = new File(parent, "test"); + Assert.assertTrue("Mkdir should succeed", location.mkdirs()); + return location; + } + + private Table createTable(Schema schema, File location) { + HadoopTables tables = new HadoopTables(CONF); + return tables.create(schema, PartitionSpec.unpartitioned(), location.toString()); + } + + private void writeAndValidateWithLocations(Table table, File location, File expectedDataDir) + throws IOException { + Schema tableSchema = table.schema(); // use the table schema because ids are reassigned + + table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit(); + + Iterable expected = RandomData.generate(tableSchema, 100, 0L); + writeData(expected, tableSchema, location.toString()); + + table.refresh(); + + List actual = readTable(location.toString()); + + Iterator expectedIter = expected.iterator(); + Iterator actualIter = 
actual.iterator(); + while (expectedIter.hasNext() && actualIter.hasNext()) { + assertEqualsSafe(tableSchema.asStruct(), expectedIter.next(), actualIter.next()); + } + Assert.assertEquals( + "Both iterators should be exhausted", expectedIter.hasNext(), actualIter.hasNext()); + + table + .currentSnapshot() + .addedDataFiles(table.io()) + .forEach( + dataFile -> + Assert.assertTrue( + String.format( + "File should have the parent directory %s, but has: %s.", + expectedDataDir.getAbsolutePath(), dataFile.path()), + URI.create(dataFile.path().toString()) + .getPath() + .startsWith(expectedDataDir.getAbsolutePath()))); + } + + private List readTable(String location) { + Dataset result = spark.read().format("iceberg").load(location); + + return result.collectAsList(); + } + + private void writeData(Iterable records, Schema schema, String location) + throws IOException { + Dataset df = createDataset(records, schema); + DataFrameWriter writer = df.write().format("iceberg").mode("append"); + writer.save(location); + } + + private void writeDataWithFailOnPartition( + Iterable records, Schema schema, String location) throws IOException, SparkException { + final int numPartitions = 10; + final int partitionToFail = new Random().nextInt(numPartitions); + MapPartitionsFunction failOnFirstPartitionFunc = + (MapPartitionsFunction) + input -> { + int partitionId = TaskContext.getPartitionId(); + + if (partitionId == partitionToFail) { + throw new SparkException( + String.format("Intended exception in partition %d !", partitionId)); + } + return input; + }; + + Dataset df = + createDataset(records, schema) + .repartition(numPartitions) + .mapPartitions(failOnFirstPartitionFunc, RowEncoder.apply(convert(schema))); + // This trick is needed because Spark 3 handles decimal overflow in RowEncoder which "changes" + // nullability of the column to "true" regardless of original nullability. + // Setting "check-nullability" option to "false" doesn't help as it fails at Spark analyzer. 
+ Dataset convertedDf = df.sqlContext().createDataFrame(df.rdd(), convert(schema)); + DataFrameWriter writer = convertedDf.write().format("iceberg").mode("append"); + writer.save(location); + } + + private Dataset createDataset(Iterable records, Schema schema) throws IOException { + // this uses the SparkAvroReader to create a DataFrame from the list of records + // it assumes that SparkAvroReader is correct + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (FileAppender writer = + Avro.write(Files.localOutput(testFile)).schema(schema).named("test").build()) { + for (Record rec : records) { + writer.add(rec); + } + } + + // make sure the dataframe matches the records before moving on + List rows = Lists.newArrayList(); + try (AvroIterable reader = + Avro.read(Files.localInput(testFile)) + .createReaderFunc(SparkAvroReader::new) + .project(schema) + .build()) { + + Iterator recordIter = records.iterator(); + Iterator readIter = reader.iterator(); + while (recordIter.hasNext() && readIter.hasNext()) { + InternalRow row = readIter.next(); + assertEqualsUnsafe(schema.asStruct(), recordIter.next(), row); + rows.add(row); + } + Assert.assertEquals( + "Both iterators should be exhausted", recordIter.hasNext(), readIter.hasNext()); + } + + JavaRDD rdd = sc.parallelize(rows); + return spark.internalCreateDataFrame(JavaRDD.toRDD(rdd), convert(schema), false); + } + + @Test + public void testNullableWithWriteOption() throws IOException { + Assume.assumeTrue( + "Spark 3 rejects writing nulls to a required column", spark.version().startsWith("2")); + + File location = new File(temp.newFolder("parquet"), "test"); + String sourcePath = String.format("%s/nullable_poc/sourceFolder/", location.toString()); + String targetPath = String.format("%s/nullable_poc/targetFolder/", location.toString()); + + tableProperties = ImmutableMap.of(TableProperties.WRITE_DATA_LOCATION, targetPath); + + // read this and append to iceberg dataset + spark + .read() + .schema(sparkSchema) + .json(JavaSparkContext.fromSparkContext(spark.sparkContext()).parallelize(data1)) + .write() + .parquet(sourcePath); + + // this is our iceberg dataset to which we will append data + new HadoopTables(spark.sessionState().newHadoopConf()) + .create( + icebergSchema, + PartitionSpec.builderFor(icebergSchema).identity("requiredField").build(), + tableProperties, + targetPath); + + // this is the initial data inside the iceberg dataset + spark + .read() + .schema(sparkSchema) + .json(JavaSparkContext.fromSparkContext(spark.sparkContext()).parallelize(data0)) + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(targetPath); + + // read from parquet and append to iceberg w/ nullability check disabled + spark + .read() + .schema(SparkSchemaUtil.convert(icebergSchema)) + .parquet(sourcePath) + .write() + .format("iceberg") + .option(SparkWriteOptions.CHECK_NULLABILITY, false) + .mode(SaveMode.Append) + .save(targetPath); + + // read all data + List rows = spark.read().format("iceberg").load(targetPath).collectAsList(); + Assert.assertEquals("Should contain 6 rows", 6, rows.size()); + } + + @Test + public void testNullableWithSparkSqlOption() throws IOException { + Assume.assumeTrue( + "Spark 3 rejects writing nulls to a required column", spark.version().startsWith("2")); + + File location = new File(temp.newFolder("parquet"), "test"); + String sourcePath = String.format("%s/nullable_poc/sourceFolder/", location.toString()); + String targetPath = 
String.format("%s/nullable_poc/targetFolder/", location.toString()); + + tableProperties = ImmutableMap.of(TableProperties.WRITE_DATA_LOCATION, targetPath); + + // read this and append to iceberg dataset + spark + .read() + .schema(sparkSchema) + .json(JavaSparkContext.fromSparkContext(spark.sparkContext()).parallelize(data1)) + .write() + .parquet(sourcePath); + + SparkSession newSparkSession = + SparkSession.builder() + .master("local[2]") + .appName("NullableTest") + .config(SparkSQLProperties.CHECK_NULLABILITY, false) + .getOrCreate(); + + // this is our iceberg dataset to which we will append data + new HadoopTables(newSparkSession.sessionState().newHadoopConf()) + .create( + icebergSchema, + PartitionSpec.builderFor(icebergSchema).identity("requiredField").build(), + tableProperties, + targetPath); + + // this is the initial data inside the iceberg dataset + newSparkSession + .read() + .schema(sparkSchema) + .json(JavaSparkContext.fromSparkContext(spark.sparkContext()).parallelize(data0)) + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(targetPath); + + // read from parquet and append to iceberg + newSparkSession + .read() + .schema(SparkSchemaUtil.convert(icebergSchema)) + .parquet(sourcePath) + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(targetPath); + + // read all data + List rows = newSparkSession.read().format("iceberg").load(targetPath).collectAsList(); + Assert.assertEquals("Should contain 6 rows", 6, rows.size()); + } + + @Test + public void testFaultToleranceOnWrite() throws IOException { + File location = createTableFolder(); + Schema schema = new Schema(SUPPORTED_PRIMITIVES.fields()); + Table table = createTable(schema, location); + + Iterable records = RandomData.generate(schema, 100, 0L); + writeData(records, schema, location.toString()); + + table.refresh(); + + Snapshot snapshotBeforeFailingWrite = table.currentSnapshot(); + List resultBeforeFailingWrite = readTable(location.toString()); + + Iterable records2 = RandomData.generate(schema, 100, 0L); + + Assertions.assertThatThrownBy( + () -> writeDataWithFailOnPartition(records2, schema, location.toString())) + .isInstanceOf(SparkException.class); + + table.refresh(); + + Snapshot snapshotAfterFailingWrite = table.currentSnapshot(); + List resultAfterFailingWrite = readTable(location.toString()); + + Assert.assertEquals(snapshotAfterFailingWrite, snapshotBeforeFailingWrite); + Assert.assertEquals(resultAfterFailingWrite, resultBeforeFailingWrite); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java new file mode 100644 index 000000000000..60dd716c631e --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java @@ -0,0 +1,451 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.optional; + +import java.io.IOException; +import java.math.RoundingMode; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.relocated.com.google.common.math.LongMath; +import org.apache.iceberg.spark.CommitMetadata; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestDataSourceOptions { + + private static final Configuration CONF = new Configuration(); + private static final Schema SCHEMA = + new Schema( + optional(1, "id", Types.IntegerType.get()), optional(2, "data", Types.StringType.get())); + private static SparkSession spark = null; + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + @BeforeClass + public static void startSpark() { + TestDataSourceOptions.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestDataSourceOptions.spark; + TestDataSourceOptions.spark = null; + currentSpark.stop(); + } + + @Test + public void testWriteFormatOptionOverridesTableProperties() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + options.put(TableProperties.DEFAULT_FILE_FORMAT, "avro"); + Table table = tables.create(SCHEMA, spec, options, tableLocation); + + List expectedRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset df = spark.createDataFrame(expectedRecords, SimpleRecord.class); + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, "parquet") + .mode(SaveMode.Append) + 
.save(tableLocation); + + try (CloseableIterable tasks = table.newScan().planFiles()) { + tasks.forEach( + task -> { + FileFormat fileFormat = FileFormat.fromFileName(task.file().path()); + Assert.assertEquals(FileFormat.PARQUET, fileFormat); + }); + } + } + + @Test + public void testNoWriteFormatOption() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + options.put(TableProperties.DEFAULT_FILE_FORMAT, "avro"); + Table table = tables.create(SCHEMA, spec, options, tableLocation); + + List expectedRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset df = spark.createDataFrame(expectedRecords, SimpleRecord.class); + df.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + try (CloseableIterable tasks = table.newScan().planFiles()) { + tasks.forEach( + task -> { + FileFormat fileFormat = FileFormat.fromFileName(task.file().path()); + Assert.assertEquals(FileFormat.AVRO, fileFormat); + }); + } + } + + @Test + public void testHadoopOptions() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + Configuration sparkHadoopConf = spark.sessionState().newHadoopConf(); + String originalDefaultFS = sparkHadoopConf.get("fs.default.name"); + + try { + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + tables.create(SCHEMA, spec, options, tableLocation); + + // set an invalid value for 'fs.default.name' in Spark Hadoop config + // to verify that 'hadoop.' data source options are propagated correctly + sparkHadoopConf.set("fs.default.name", "hdfs://localhost:9000"); + + List expectedRecords = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset originalDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + originalDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .option("hadoop.fs.default.name", "file:///") + .save(tableLocation); + + Dataset resultDf = + spark + .read() + .format("iceberg") + .option("hadoop.fs.default.name", "file:///") + .load(tableLocation); + List resultRecords = + resultDf.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + + Assert.assertEquals("Records should match", expectedRecords, resultRecords); + } finally { + sparkHadoopConf.set("fs.default.name", originalDefaultFS); + } + } + + @Test + public void testSplitOptionsOverridesTableProperties() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + options.put(TableProperties.SPLIT_SIZE, String.valueOf(128L * 1024 * 1024)); // 128Mb + options.put( + TableProperties.DEFAULT_FILE_FORMAT, + String.valueOf(FileFormat.AVRO)); // Arbitrarily splittable + Table icebergTable = tables.create(SCHEMA, spec, options, tableLocation); + + List expectedRecords = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset originalDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + originalDf + .select("id", "data") + .repartition(1) + .write() + .format("iceberg") + .mode("append") + .save(tableLocation); + + List files = + 
Lists.newArrayList(icebergTable.currentSnapshot().addedDataFiles(icebergTable.io())); + Assert.assertEquals("Should have written 1 file", 1, files.size()); + + long fileSize = files.get(0).fileSizeInBytes(); + long splitSize = LongMath.divide(fileSize, 2, RoundingMode.CEILING); + + Dataset resultDf = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SPLIT_SIZE, String.valueOf(splitSize)) + .load(tableLocation); + + Assert.assertEquals("Spark partitions should match", 2, resultDf.javaRDD().getNumPartitions()); + } + + @Test + public void testIncrementalScanOptions() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + Table table = tables.create(SCHEMA, spec, options, tableLocation); + + List expectedRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "d")); + for (SimpleRecord record : expectedRecords) { + Dataset originalDf = + spark.createDataFrame(Lists.newArrayList(record), SimpleRecord.class); + originalDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + } + List snapshotIds = SnapshotUtil.currentAncestorIds(table); + + // start-snapshot-id and snapshot-id are both configured. + AssertHelpers.assertThrows( + "Check both start-snapshot-id and snapshot-id are configured", + IllegalArgumentException.class, + "Cannot set start-snapshot-id and end-snapshot-id for incremental scans", + () -> { + spark + .read() + .format("iceberg") + .option("snapshot-id", snapshotIds.get(3).toString()) + .option("start-snapshot-id", snapshotIds.get(3).toString()) + .load(tableLocation) + .explain(); + }); + + // end-snapshot-id and as-of-timestamp are both configured. + AssertHelpers.assertThrows( + "Check both start-snapshot-id and snapshot-id are configured", + IllegalArgumentException.class, + "Cannot set start-snapshot-id and end-snapshot-id for incremental scans", + () -> { + spark + .read() + .format("iceberg") + .option( + SparkReadOptions.AS_OF_TIMESTAMP, + Long.toString(table.snapshot(snapshotIds.get(3)).timestampMillis())) + .option("end-snapshot-id", snapshotIds.get(2).toString()) + .load(tableLocation) + .explain(); + }); + + // only end-snapshot-id is configured. + AssertHelpers.assertThrows( + "Check both start-snapshot-id and snapshot-id are configured", + IllegalArgumentException.class, + "Cannot set only end-snapshot-id for incremental scans", + () -> { + spark + .read() + .format("iceberg") + .option("end-snapshot-id", snapshotIds.get(2).toString()) + .load(tableLocation) + .explain(); + }); + + // test (1st snapshot, current snapshot] incremental scan. + List result = + spark + .read() + .format("iceberg") + .option("start-snapshot-id", snapshotIds.get(3).toString()) + .load(tableLocation) + .orderBy("id") + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + Assert.assertEquals("Records should match", expectedRecords.subList(1, 4), result); + + // test (2nd snapshot, 3rd snapshot] incremental scan. 
+ List result1 = + spark + .read() + .format("iceberg") + .option("start-snapshot-id", snapshotIds.get(2).toString()) + .option("end-snapshot-id", snapshotIds.get(1).toString()) + .load(tableLocation) + .orderBy("id") + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + Assert.assertEquals("Records should match", expectedRecords.subList(2, 3), result1); + } + + @Test + public void testMetadataSplitSizeOptionOverrideTableProperties() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + Table table = tables.create(SCHEMA, spec, options, tableLocation); + + List expectedRecords = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset originalDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + // produce 1st manifest + originalDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + // produce 2nd manifest + originalDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + List manifests = table.currentSnapshot().allManifests(table.io()); + + Assert.assertEquals("Must be 2 manifests", 2, manifests.size()); + + // set the target metadata split size so each manifest ends up in a separate split + table + .updateProperties() + .set(TableProperties.METADATA_SPLIT_SIZE, String.valueOf(manifests.get(0).length())) + .commit(); + + Dataset entriesDf = spark.read().format("iceberg").load(tableLocation + "#entries"); + Assert.assertEquals("Num partitions must match", 2, entriesDf.javaRDD().getNumPartitions()); + + // override the table property using options + entriesDf = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SPLIT_SIZE, String.valueOf(128 * 1024 * 1024)) + .load(tableLocation + "#entries"); + Assert.assertEquals("Num partitions must match", 1, entriesDf.javaRDD().getNumPartitions()); + } + + @Test + public void testDefaultMetadataSplitSize() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map options = Maps.newHashMap(); + Table icebergTable = tables.create(SCHEMA, spec, options, tableLocation); + + List expectedRecords = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset originalDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + originalDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + int splitSize = (int) TableProperties.METADATA_SPLIT_SIZE_DEFAULT; // 32MB split size + + int expectedSplits = + ((int) + tables + .load(tableLocation + "#entries") + .currentSnapshot() + .allManifests(icebergTable.io()) + .get(0) + .length() + + splitSize + - 1) + / splitSize; + + Dataset metadataDf = spark.read().format("iceberg").load(tableLocation + "#entries"); + + int partitionNum = metadataDf.javaRDD().getNumPartitions(); + Assert.assertEquals("Spark partitions should match", expectedSplits, partitionNum); + } + + @Test + public void testExtraSnapshotMetadata() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + HadoopTables tables = new HadoopTables(CONF); + tables.create(SCHEMA, PartitionSpec.unpartitioned(), Maps.newHashMap(), tableLocation); + + List expectedRecords = + Lists.newArrayList(new SimpleRecord(1, "a"), new 
SimpleRecord(2, "b")); + Dataset originalDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + originalDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .option(SparkWriteOptions.SNAPSHOT_PROPERTY_PREFIX + ".extra-key", "someValue") + .option(SparkWriteOptions.SNAPSHOT_PROPERTY_PREFIX + ".another-key", "anotherValue") + .save(tableLocation); + + Table table = tables.load(tableLocation); + + Assert.assertTrue(table.currentSnapshot().summary().get("extra-key").equals("someValue")); + Assert.assertTrue(table.currentSnapshot().summary().get("another-key").equals("anotherValue")); + } + + @Test + public void testExtraSnapshotMetadataWithSQL() throws InterruptedException, IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + HadoopTables tables = new HadoopTables(CONF); + + Table table = + tables.create(SCHEMA, PartitionSpec.unpartitioned(), Maps.newHashMap(), tableLocation); + + List expectedRecords = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + Dataset originalDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + originalDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + spark.read().format("iceberg").load(tableLocation).createOrReplaceTempView("target"); + Thread writerThread = + new Thread( + () -> { + Map properties = Maps.newHashMap(); + properties.put("writer-thread", String.valueOf(Thread.currentThread().getName())); + CommitMetadata.withCommitProperties( + properties, + () -> { + spark.sql("INSERT INTO target VALUES (3, 'c'), (4, 'd')"); + return 0; + }, + RuntimeException.class); + }); + writerThread.setName("test-extra-commit-message-writer-thread"); + writerThread.start(); + writerThread.join(); + Set threadNames = Sets.newHashSet(); + for (Snapshot snapshot : table.snapshots()) { + threadNames.add(snapshot.summary().get("writer-thread")); + } + Assert.assertEquals(2, threadNames.size()); + Assert.assertTrue(threadNames.contains(null)); + Assert.assertTrue(threadNames.contains("test-extra-commit-message-writer-thread")); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java new file mode 100644 index 000000000000..a616b764a6b1 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java @@ -0,0 +1,683 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.Files.localOutput; +import static org.apache.spark.sql.catalyst.util.DateTimeUtils.fromJavaTimestamp; +import static org.apache.spark.sql.functions.callUDF; +import static org.apache.spark.sql.functions.column; + +import java.io.File; +import java.io.IOException; +import java.sql.Timestamp; +import java.time.OffsetDateTime; +import java.util.List; +import java.util.UUID; +import java.util.function.Function; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.data.GenericAppenderFactory; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.data.GenericsHelpers; +import org.apache.iceberg.transforms.Transforms; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.api.java.UDF1; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.read.SupportsPushDownFilters; +import org.apache.spark.sql.sources.And; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.Not; +import org.apache.spark.sql.sources.StringStartsWith; +import org.apache.spark.sql.types.IntegerType$; +import org.apache.spark.sql.types.LongType$; +import org.apache.spark.sql.types.StringType$; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.assertj.core.api.Assertions; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class TestFilteredScan { + private static final Configuration CONF = new Configuration(); + private static final HadoopTables TABLES = new HadoopTables(CONF); + + private static final Schema SCHEMA = + new Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "ts", Types.TimestampType.withZone()), + Types.NestedField.optional(3, "data", Types.StringType.get())); + + private static final PartitionSpec BUCKET_BY_ID = + PartitionSpec.builderFor(SCHEMA).bucket("id", 4).build(); + + private static final PartitionSpec PARTITION_BY_DAY = + PartitionSpec.builderFor(SCHEMA).day("ts").build(); + + private static final PartitionSpec PARTITION_BY_HOUR = + 
PartitionSpec.builderFor(SCHEMA).hour("ts").build(); + + private static final PartitionSpec PARTITION_BY_DATA = + PartitionSpec.builderFor(SCHEMA).identity("data").build(); + + private static final PartitionSpec PARTITION_BY_ID = + PartitionSpec.builderFor(SCHEMA).identity("id").build(); + + private static SparkSession spark = null; + + @BeforeClass + public static void startSpark() { + TestFilteredScan.spark = SparkSession.builder().master("local[2]").getOrCreate(); + + // define UDFs used by partition tests + Function bucket4 = Transforms.bucket(4).bind(Types.LongType.get()); + spark.udf().register("bucket4", (UDF1) bucket4::apply, IntegerType$.MODULE$); + + Function day = Transforms.day().bind(Types.TimestampType.withZone()); + spark + .udf() + .register( + "ts_day", + (UDF1) timestamp -> day.apply(fromJavaTimestamp(timestamp)), + IntegerType$.MODULE$); + + Function hour = Transforms.hour().bind(Types.TimestampType.withZone()); + spark + .udf() + .register( + "ts_hour", + (UDF1) timestamp -> hour.apply(fromJavaTimestamp(timestamp)), + IntegerType$.MODULE$); + + spark.udf().register("data_ident", (UDF1) data -> data, StringType$.MODULE$); + spark.udf().register("id_ident", (UDF1) id -> id, LongType$.MODULE$); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestFilteredScan.spark; + TestFilteredScan.spark = null; + currentSpark.stop(); + } + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private final String format; + private final boolean vectorized; + + @Parameterized.Parameters(name = "format = {0}, vectorized = {1}") + public static Object[][] parameters() { + return new Object[][] { + {"parquet", false}, + {"parquet", true}, + {"avro", false}, + {"orc", false}, + {"orc", true} + }; + } + + public TestFilteredScan(String format, boolean vectorized) { + this.format = format; + this.vectorized = vectorized; + } + + private File parent = null; + private File unpartitioned = null; + private List records = null; + + @Before + public void writeUnpartitionedTable() throws IOException { + this.parent = temp.newFolder("TestFilteredScan"); + this.unpartitioned = new File(parent, "unpartitioned"); + File dataFolder = new File(unpartitioned, "data"); + Assert.assertTrue("Mkdir should succeed", dataFolder.mkdirs()); + + Table table = TABLES.create(SCHEMA, PartitionSpec.unpartitioned(), unpartitioned.toString()); + Schema tableSchema = table.schema(); // use the table schema because ids are reassigned + + FileFormat fileFormat = FileFormat.fromString(format); + + File testFile = new File(dataFolder, fileFormat.addExtension(UUID.randomUUID().toString())); + + this.records = testRecords(tableSchema); + + try (FileAppender writer = + new GenericAppenderFactory(tableSchema).newAppender(localOutput(testFile), fileFormat)) { + writer.addAll(records); + } + + DataFile file = + DataFiles.builder(PartitionSpec.unpartitioned()) + .withRecordCount(records.size()) + .withFileSizeInBytes(testFile.length()) + .withPath(testFile.toString()) + .build(); + + table.newAppend().appendFile(file).commit(); + } + + @Test + public void testUnpartitionedIDFilters() { + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", unpartitioned.toString())); + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + for (int i = 0; i < 10; i += 1) { + pushFilters(builder, EqualTo.apply("id", i)); + Batch scan = builder.build().toBatch(); + + InputPartition[] partitions = 
scan.planInputPartitions(); + Assert.assertEquals("Should only create one task for a small file", 1, partitions.length); + + // validate row filtering + assertEqualsSafe( + SCHEMA.asStruct(), expected(i), read(unpartitioned.toString(), vectorized, "id = " + i)); + } + } + + @Test + public void testUnpartitionedCaseInsensitiveIDFilters() { + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", unpartitioned.toString())); + + // set spark.sql.caseSensitive to false + String caseSensitivityBeforeTest = TestFilteredScan.spark.conf().get("spark.sql.caseSensitive"); + TestFilteredScan.spark.conf().set("spark.sql.caseSensitive", "false"); + + try { + + for (int i = 0; i < 10; i += 1) { + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options) + .caseSensitive(false); + + pushFilters( + builder, + EqualTo.apply("ID", i)); // note lower(ID) == lower(id), so there must be a match + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + Assert.assertEquals("Should only create one task for a small file", 1, tasks.length); + + // validate row filtering + assertEqualsSafe( + SCHEMA.asStruct(), + expected(i), + read(unpartitioned.toString(), vectorized, "id = " + i)); + } + } finally { + // return global conf to previous state + TestFilteredScan.spark.conf().set("spark.sql.caseSensitive", caseSensitivityBeforeTest); + } + } + + @Test + public void testUnpartitionedTimestampFilter() { + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", unpartitioned.toString())); + + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, LessThan.apply("ts", "2017-12-22T00:00:00+00:00")); + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + Assert.assertEquals("Should only create one task for a small file", 1, tasks.length); + + assertEqualsSafe( + SCHEMA.asStruct(), + expected(5, 6, 7, 8, 9), + read( + unpartitioned.toString(), + vectorized, + "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)")); + } + + @Test + public void testBucketPartitionedIDFilters() { + Table table = buildPartitionedTable("bucketed_by_id", BUCKET_BY_ID, "bucket4", "id"); + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", table.location())); + + Batch unfiltered = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options).build().toBatch(); + Assert.assertEquals( + "Unfiltered table should created 4 read tasks", 4, unfiltered.planInputPartitions().length); + + for (int i = 0; i < 10; i += 1) { + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, EqualTo.apply("id", i)); + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + + // validate predicate push-down + Assert.assertEquals("Should create one task for a single bucket", 1, tasks.length); + + // validate row filtering + assertEqualsSafe( + SCHEMA.asStruct(), expected(i), read(table.location(), vectorized, "id = " + i)); + } + } + + @SuppressWarnings("checkstyle:AvoidNestedBlocks") + @Test + public void testDayPartitionedTimestampFilters() { + Table table = buildPartitionedTable("partitioned_by_day", PARTITION_BY_DAY, "ts_day", "ts"); + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", 
table.location())); + Batch unfiltered = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options).build().toBatch(); + + Assert.assertEquals( + "Unfiltered table should created 2 read tasks", 2, unfiltered.planInputPartitions().length); + + { + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, LessThan.apply("ts", "2017-12-22T00:00:00+00:00")); + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + Assert.assertEquals("Should create one task for 2017-12-21", 1, tasks.length); + + assertEqualsSafe( + SCHEMA.asStruct(), + expected(5, 6, 7, 8, 9), + read( + table.location(), vectorized, "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)")); + } + + { + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters( + builder, + And.apply( + GreaterThan.apply("ts", "2017-12-22T06:00:00+00:00"), + LessThan.apply("ts", "2017-12-22T08:00:00+00:00"))); + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + Assert.assertEquals("Should create one task for 2017-12-22", 1, tasks.length); + + assertEqualsSafe( + SCHEMA.asStruct(), + expected(1, 2), + read( + table.location(), + vectorized, + "ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " + + "ts < cast('2017-12-22 08:00:00+00:00' as timestamp)")); + } + } + + @SuppressWarnings("checkstyle:AvoidNestedBlocks") + @Test + public void testHourPartitionedTimestampFilters() { + Table table = buildPartitionedTable("partitioned_by_hour", PARTITION_BY_HOUR, "ts_hour", "ts"); + + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", table.location())); + Batch unfiltered = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options).build().toBatch(); + + Assert.assertEquals( + "Unfiltered table should created 9 read tasks", 9, unfiltered.planInputPartitions().length); + + { + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, LessThan.apply("ts", "2017-12-22T00:00:00+00:00")); + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + Assert.assertEquals("Should create 4 tasks for 2017-12-21: 15, 17, 21, 22", 4, tasks.length); + + assertEqualsSafe( + SCHEMA.asStruct(), + expected(8, 9, 7, 6, 5), + read( + table.location(), vectorized, "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)")); + } + + { + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters( + builder, + And.apply( + GreaterThan.apply("ts", "2017-12-22T06:00:00+00:00"), + LessThan.apply("ts", "2017-12-22T08:00:00+00:00"))); + Batch scan = builder.build().toBatch(); + + InputPartition[] tasks = scan.planInputPartitions(); + Assert.assertEquals("Should create 2 tasks for 2017-12-22: 6, 7", 2, tasks.length); + + assertEqualsSafe( + SCHEMA.asStruct(), + expected(2, 1), + read( + table.location(), + vectorized, + "ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " + + "ts < cast('2017-12-22 08:00:00+00:00' as timestamp)")); + } + } + + @SuppressWarnings("checkstyle:AvoidNestedBlocks") + @Test + public void testFilterByNonProjectedColumn() { + { + Schema actualProjection = SCHEMA.select("id", "data"); + List expected = Lists.newArrayList(); + for (Record rec : expected(5, 6, 7, 8, 9)) { + 
expected.add(projectFlat(actualProjection, rec)); + } + + assertEqualsSafe( + actualProjection.asStruct(), + expected, + read( + unpartitioned.toString(), + vectorized, + "ts < cast('2017-12-22 00:00:00+00:00' as timestamp)", + "id", + "data")); + } + + { + // only project id: ts will be projected because of the filter, but data will not be included + + Schema actualProjection = SCHEMA.select("id"); + List expected = Lists.newArrayList(); + for (Record rec : expected(1, 2)) { + expected.add(projectFlat(actualProjection, rec)); + } + + assertEqualsSafe( + actualProjection.asStruct(), + expected, + read( + unpartitioned.toString(), + vectorized, + "ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " + + "ts < cast('2017-12-22 08:00:00+00:00' as timestamp)", + "id")); + } + } + + @Test + public void testPartitionedByDataStartsWithFilter() { + Table table = + buildPartitionedTable("partitioned_by_data", PARTITION_BY_DATA, "data_ident", "data"); + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", table.location())); + + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, new StringStartsWith("data", "junc")); + Batch scan = builder.build().toBatch(); + + Assert.assertEquals(1, scan.planInputPartitions().length); + } + + @Test + public void testPartitionedByDataNotStartsWithFilter() { + Table table = + buildPartitionedTable("partitioned_by_data", PARTITION_BY_DATA, "data_ident", "data"); + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", table.location())); + + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, new Not(new StringStartsWith("data", "junc"))); + Batch scan = builder.build().toBatch(); + + Assert.assertEquals(9, scan.planInputPartitions().length); + } + + @Test + public void testPartitionedByIdStartsWith() { + Table table = buildPartitionedTable("partitioned_by_id", PARTITION_BY_ID, "id_ident", "id"); + + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", table.location())); + + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, new StringStartsWith("data", "junc")); + Batch scan = builder.build().toBatch(); + + Assert.assertEquals(1, scan.planInputPartitions().length); + } + + @Test + public void testPartitionedByIdNotStartsWith() { + Table table = buildPartitionedTable("partitioned_by_id", PARTITION_BY_ID, "id_ident", "id"); + + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap(ImmutableMap.of("path", table.location())); + + SparkScanBuilder builder = + new SparkScanBuilder(spark, TABLES.load(options.get("path")), options); + + pushFilters(builder, new Not(new StringStartsWith("data", "junc"))); + Batch scan = builder.build().toBatch(); + + Assert.assertEquals(9, scan.planInputPartitions().length); + } + + @Test + public void testUnpartitionedStartsWith() { + Dataset df = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(unpartitioned.toString()); + + List matchedData = + df.select("data").where("data LIKE 'jun%'").as(Encoders.STRING()).collectAsList(); + + Assert.assertEquals(1, matchedData.size()); + Assert.assertEquals("junction", matchedData.get(0)); + } + + @Test + public void testUnpartitionedNotStartsWith() { + Dataset df = + spark + 
.read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(unpartitioned.toString()); + + List matchedData = + df.select("data").where("data NOT LIKE 'jun%'").as(Encoders.STRING()).collectAsList(); + + List expected = + testRecords(SCHEMA).stream() + .map(r -> r.getField("data").toString()) + .filter(d -> !d.startsWith("jun")) + .collect(Collectors.toList()); + + Assert.assertEquals(9, matchedData.size()); + Assert.assertEquals(Sets.newHashSet(expected), Sets.newHashSet(matchedData)); + } + + private static Record projectFlat(Schema projection, Record record) { + Record result = GenericRecord.create(projection); + List fields = projection.asStruct().fields(); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + result.set(i, record.getField(field.name())); + } + return result; + } + + public static void assertEqualsUnsafe( + Types.StructType struct, List expected, List actual) { + // TODO: match records by ID + int numRecords = Math.min(expected.size(), actual.size()); + for (int i = 0; i < numRecords; i += 1) { + GenericsHelpers.assertEqualsUnsafe(struct, expected.get(i), actual.get(i)); + } + Assert.assertEquals("Number of results should match expected", expected.size(), actual.size()); + } + + public static void assertEqualsSafe( + Types.StructType struct, List expected, List actual) { + // TODO: match records by ID + int numRecords = Math.min(expected.size(), actual.size()); + for (int i = 0; i < numRecords; i += 1) { + GenericsHelpers.assertEqualsSafe(struct, expected.get(i), actual.get(i)); + } + Assert.assertEquals("Number of results should match expected", expected.size(), actual.size()); + } + + private List expected(int... ordinals) { + List expected = Lists.newArrayListWithExpectedSize(ordinals.length); + for (int ord : ordinals) { + expected.add(records.get(ord)); + } + return expected; + } + + private void pushFilters(ScanBuilder scan, Filter... filters) { + Assertions.assertThat(scan).isInstanceOf(SupportsPushDownFilters.class); + SupportsPushDownFilters filterable = (SupportsPushDownFilters) scan; + filterable.pushFilters(filters); + } + + private Table buildPartitionedTable( + String desc, PartitionSpec spec, String udf, String partitionColumn) { + File location = new File(parent, desc); + Table table = TABLES.create(SCHEMA, spec, location.toString()); + + // Do not combine or split files because the tests expect a split per partition. + // A target split size of 2048 helps us achieve that. 
+ table.updateProperties().set("read.split.target-size", "2048").commit(); + + // copy the unpartitioned table into the partitioned table to produce the partitioned data + Dataset allRows = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(unpartitioned.toString()); + + allRows + .coalesce(1) // ensure only 1 file per partition is written + .withColumn("part", callUDF(udf, column(partitionColumn))) + .sortWithinPartitions("part") + .drop("part") + .write() + .format("iceberg") + .mode("append") + .save(table.location()); + + table.refresh(); + + return table; + } + + private List testRecords(Schema schema) { + return Lists.newArrayList( + record(schema, 0L, parse("2017-12-22T09:20:44.294658+00:00"), "junction"), + record(schema, 1L, parse("2017-12-22T07:15:34.582910+00:00"), "alligator"), + record(schema, 2L, parse("2017-12-22T06:02:09.243857+00:00"), ""), + record(schema, 3L, parse("2017-12-22T03:10:11.134509+00:00"), "clapping"), + record(schema, 4L, parse("2017-12-22T00:34:00.184671+00:00"), "brush"), + record(schema, 5L, parse("2017-12-21T22:20:08.935889+00:00"), "trap"), + record(schema, 6L, parse("2017-12-21T21:55:30.589712+00:00"), "element"), + record(schema, 7L, parse("2017-12-21T17:31:14.532797+00:00"), "limited"), + record(schema, 8L, parse("2017-12-21T15:21:51.237521+00:00"), "global"), + record(schema, 9L, parse("2017-12-21T15:02:15.230570+00:00"), "goldfish")); + } + + private static List read(String table, boolean vectorized, String expr) { + return read(table, vectorized, expr, "*"); + } + + private static List read( + String table, boolean vectorized, String expr, String select0, String... selectN) { + Dataset dataset = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(table) + .filter(expr) + .select(select0, selectN); + return dataset.collectAsList(); + } + + private static OffsetDateTime parse(String timestamp) { + return OffsetDateTime.parse(timestamp); + } + + private static Record record(Schema schema, Object... values) { + Record rec = GenericRecord.create(schema); + for (int i = 0; i < values.length; i += 1) { + rec.set(i, values[i]); + } + return rec; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestForwardCompatibility.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestForwardCompatibility.java new file mode 100644 index 000000000000..fe440235901c --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestForwardCompatibility.java @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.Files.localInput; +import static org.apache.iceberg.Files.localOutput; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.TimeoutException; +import org.apache.avro.generic.GenericData; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.ManifestWriter; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.PartitionSpecParser; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.execution.streaming.MemoryStream; +import org.apache.spark.sql.streaming.StreamingQuery; +import org.apache.spark.sql.streaming.StreamingQueryException; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import scala.Option; +import scala.collection.JavaConverters; + +public class TestForwardCompatibility { + private static final Configuration CONF = new Configuration(); + + private static final Schema SCHEMA = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get())); + + // create a spec for the schema that uses a "zero" transform that produces all 0s + private static final PartitionSpec UNKNOWN_SPEC = + PartitionSpecParser.fromJson( + SCHEMA, + "{ \"spec-id\": 0, \"fields\": [ { \"name\": \"id_zero\", \"transform\": \"zero\", \"source-id\": 1 } ] }"); + // create a fake spec to use to write table metadata + private static final PartitionSpec FAKE_SPEC = + PartitionSpecParser.fromJson( + SCHEMA, + "{ \"spec-id\": 0, \"fields\": [ { \"name\": \"id_zero\", \"transform\": \"identity\", \"source-id\": 1 } ] }"); + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private static SparkSession spark = null; + + @BeforeClass + public static void startSpark() { + TestForwardCompatibility.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestForwardCompatibility.spark; + TestForwardCompatibility.spark = null; + currentSpark.stop(); + } + + @Test + public void testSparkWriteFailsUnknownTransform() throws IOException { + File parent = temp.newFolder("avro"); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + dataFolder.mkdirs(); + + HadoopTables tables = new HadoopTables(CONF); + tables.create(SCHEMA, UNKNOWN_SPEC, location.toString()); + + List expected = + Lists.newArrayList( + 
new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + AssertHelpers.assertThrows( + "Should reject write with unsupported transform", + UnsupportedOperationException.class, + "Cannot write using unsupported transforms: zero", + () -> + df.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(location.toString())); + } + + @Test + public void testSparkStreamingWriteFailsUnknownTransform() throws IOException, TimeoutException { + File parent = temp.newFolder("avro"); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + dataFolder.mkdirs(); + File checkpoint = new File(parent, "checkpoint"); + checkpoint.mkdirs(); + + HadoopTables tables = new HadoopTables(CONF); + tables.create(SCHEMA, UNKNOWN_SPEC, location.toString()); + + MemoryStream inputStream = newMemoryStream(1, spark.sqlContext(), Encoders.INT()); + StreamingQuery query = + inputStream + .toDF() + .selectExpr("value AS id", "CAST (value AS STRING) AS data") + .writeStream() + .outputMode("append") + .format("iceberg") + .option("checkpointLocation", checkpoint.toString()) + .option("path", location.toString()) + .start(); + + List batch1 = Lists.newArrayList(1, 2); + send(batch1, inputStream); + + AssertHelpers.assertThrows( + "Should reject streaming write with unsupported transform", + StreamingQueryException.class, + "Cannot write using unsupported transforms: zero", + query::processAllAvailable); + } + + @Test + public void testSparkCanReadUnknownTransform() throws IOException { + File parent = temp.newFolder("avro"); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + dataFolder.mkdirs(); + + HadoopTables tables = new HadoopTables(CONF); + Table table = tables.create(SCHEMA, UNKNOWN_SPEC, location.toString()); + + // enable snapshot inheritance to avoid rewriting the manifest with an unknown transform + table.updateProperties().set(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, "true").commit(); + + List expected = RandomData.generateList(table.schema(), 100, 1L); + + File parquetFile = + new File(dataFolder, FileFormat.PARQUET.addExtension(UUID.randomUUID().toString())); + FileAppender writer = + Parquet.write(localOutput(parquetFile)).schema(table.schema()).build(); + try { + writer.addAll(expected); + } finally { + writer.close(); + } + + DataFile file = + DataFiles.builder(FAKE_SPEC) + .withInputFile(localInput(parquetFile)) + .withMetrics(writer.metrics()) + .withPartitionPath("id_zero=0") + .build(); + + OutputFile manifestFile = localOutput(FileFormat.AVRO.addExtension(temp.newFile().toString())); + ManifestWriter manifestWriter = ManifestFiles.write(FAKE_SPEC, manifestFile); + try { + manifestWriter.add(file); + } finally { + manifestWriter.close(); + } + + table.newFastAppend().appendManifest(manifestWriter.toManifestFile()).commit(); + + Dataset df = spark.read().format("iceberg").load(location.toString()); + + List rows = df.collectAsList(); + Assert.assertEquals("Should contain 100 rows", 100, rows.size()); + + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe(table.schema().asStruct(), expected.get(i), rows.get(i)); + } + } + + private MemoryStream newMemoryStream(int id, SQLContext sqlContext, Encoder encoder) { + return new MemoryStream<>(id, sqlContext, Option.empty(), encoder); + } + + private void send(List records, MemoryStream stream) { + 
stream.addData(JavaConverters.asScalaBuffer(records)); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSource.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSource.java new file mode 100644 index 000000000000..a850275118db --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSource.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +public class TestIcebergSource extends IcebergSource { + @Override + public String shortName() { + return "iceberg-test"; + } + + @Override + public Identifier extractIdentifier(CaseInsensitiveStringMap options) { + TableIdentifier ti = TableIdentifier.parse(options.get("iceberg.table.name")); + return Identifier.of(ti.namespace().levels(), ti.name()); + } + + @Override + public String extractCatalog(CaseInsensitiveStringMap options) { + return SparkSession.active().sessionState().catalogManager().currentCatalog().name(); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceHadoopTables.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceHadoopTables.java new file mode 100644 index 000000000000..9bd7220b905a --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceHadoopTables.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.io.File; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.hadoop.HadoopTables; +import org.junit.Before; + +public class TestIcebergSourceHadoopTables extends TestIcebergSourceTablesBase { + + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + + File tableDir = null; + String tableLocation = null; + + @Before + public void setupTable() throws Exception { + this.tableDir = temp.newFolder(); + tableDir.delete(); // created by table create + + this.tableLocation = tableDir.toURI().toString(); + } + + @Override + public Table createTable(TableIdentifier ident, Schema schema, PartitionSpec spec) { + if (spec.equals(PartitionSpec.unpartitioned())) { + return TABLES.create(schema, tableLocation); + } + return TABLES.create(schema, spec, tableLocation); + } + + @Override + public void dropTable(TableIdentifier ident) { + TABLES.dropTable(tableLocation); + } + + @Override + public Table loadTable(TableIdentifier ident, String entriesSuffix) { + return TABLES.load(loadLocation(ident, entriesSuffix)); + } + + @Override + public String loadLocation(TableIdentifier ident, String entriesSuffix) { + return String.format("%s#%s", loadLocation(ident), entriesSuffix); + } + + @Override + public String loadLocation(TableIdentifier ident) { + return tableLocation; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceHiveTables.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceHiveTables.java new file mode 100644 index 000000000000..6292a2c1a834 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceHiveTables.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.junit.After; +import org.junit.BeforeClass; + +public class TestIcebergSourceHiveTables extends TestIcebergSourceTablesBase { + + private static TableIdentifier currentIdentifier; + + @BeforeClass + public static void start() { + Namespace db = Namespace.of("db"); + if (!catalog.namespaceExists(db)) { + catalog.createNamespace(db); + } + } + + @After + public void dropTable() throws IOException { + if (!catalog.tableExists(currentIdentifier)) { + return; + } + + dropTable(currentIdentifier); + } + + @Override + public Table createTable(TableIdentifier ident, Schema schema, PartitionSpec spec) { + TestIcebergSourceHiveTables.currentIdentifier = ident; + return TestIcebergSourceHiveTables.catalog.createTable(ident, schema, spec); + } + + @Override + public void dropTable(TableIdentifier ident) throws IOException { + Table table = catalog.loadTable(ident); + Path tablePath = new Path(table.location()); + FileSystem fs = tablePath.getFileSystem(spark.sessionState().newHadoopConf()); + fs.delete(tablePath, true); + catalog.dropTable(ident, false); + } + + @Override + public Table loadTable(TableIdentifier ident, String entriesSuffix) { + TableIdentifier identifier = + TableIdentifier.of(ident.namespace().level(0), ident.name(), entriesSuffix); + return TestIcebergSourceHiveTables.catalog.loadTable(identifier); + } + + @Override + public String loadLocation(TableIdentifier ident, String entriesSuffix) { + return String.format("%s.%s", loadLocation(ident), entriesSuffix); + } + + @Override + public String loadLocation(TableIdentifier ident) { + return ident.toString(); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceTablesBase.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceTablesBase.java new file mode 100644 index 000000000000..fc023cdfdb51 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceTablesBase.java @@ -0,0 +1,1908 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.ManifestContent.DATA; +import static org.apache.iceberg.ManifestContent.DELETES; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.time.LocalDateTime; +import java.time.ZoneOffset; +import java.util.Comparator; +import java.util.List; +import java.util.StringJoiner; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericRecordBuilder; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.actions.DeleteOrphanFiles; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.deletes.PositionDeleteWriter; +import org.apache.iceberg.encryption.EncryptedOutputFile; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.mapping.MappingUtil; +import org.apache.iceberg.mapping.NameMapping; +import org.apache.iceberg.mapping.NameMappingParser; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.Pair; +import org.apache.spark.SparkException; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.functions; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.StructType; +import org.assertj.core.api.Assertions; +import org.junit.After; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public abstract class TestIcebergSourceTablesBase extends SparkTestBase { + + private static final Schema SCHEMA = + new Schema( + optional(1, "id", Types.IntegerType.get()), optional(2, "data", Types.StringType.get())); + + private static final Schema SCHEMA2 = + new Schema( + optional(1, 
"id", Types.IntegerType.get()), + optional(2, "data", Types.StringType.get()), + optional(3, "category", Types.StringType.get())); + + private static final Schema SCHEMA3 = + new Schema( + optional(1, "id", Types.IntegerType.get()), + optional(3, "category", Types.StringType.get())); + + private static final PartitionSpec SPEC = PartitionSpec.builderFor(SCHEMA).identity("id").build(); + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + public abstract Table createTable(TableIdentifier ident, Schema schema, PartitionSpec spec); + + public abstract Table loadTable(TableIdentifier ident, String entriesSuffix); + + public abstract String loadLocation(TableIdentifier ident, String entriesSuffix); + + public abstract String loadLocation(TableIdentifier ident); + + public abstract void dropTable(TableIdentifier ident) throws IOException; + + @After + public void removeTable() { + spark.sql("DROP TABLE IF EXISTS parquet_table"); + } + + @Test + public synchronized void testTablesSupport() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "table"); + createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List expectedRecords = + Lists.newArrayList( + new SimpleRecord(1, "1"), new SimpleRecord(2, "2"), new SimpleRecord(3, "3")); + + Dataset inputDf = spark.createDataFrame(expectedRecords, SimpleRecord.class); + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(loadLocation(tableIdentifier)); + + Dataset resultDf = spark.read().format("iceberg").load(loadLocation(tableIdentifier)); + List actualRecords = + resultDf.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + + Assert.assertEquals("Records should match", expectedRecords, actualRecords); + } + + @Test + public void testEntriesTable() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + Table entriesTable = loadTable(tableIdentifier, "entries"); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "entries")) + .collectAsList(); + + Snapshot snapshot = table.currentSnapshot(); + + Assert.assertEquals( + "Should only contain one manifest", 1, snapshot.allManifests(table.io()).size()); + + InputFile manifest = table.io().newInputFile(snapshot.allManifests(table.io()).get(0).path()); + List expected = Lists.newArrayList(); + try (CloseableIterable rows = + Avro.read(manifest).project(entriesTable.schema()).build()) { + // each row must inherit snapshot_id and sequence_number + rows.forEach( + row -> { + row.put(2, 0L); // data sequence number + row.put(3, 0L); // file sequence number + GenericData.Record file = (GenericData.Record) row.get("data_file"); + TestHelpers.asMetadataRecord(file); + expected.add(row); + }); + } + + Assert.assertEquals("Entries table should have one row", 1, expected.size()); + Assert.assertEquals("Actual results should have one row", 1, actual.size()); + TestHelpers.assertEqualsSafe(entriesTable.schema().asStruct(), expected.get(0), actual.get(0)); + } + + @Test + public void testEntriesTablePartitionedPrune() throws Exception { + TableIdentifier tableIdentifier = 
TableIdentifier.of("db", "entries_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "entries")) + .select("status") + .collectAsList(); + + Assert.assertEquals("Results should contain only one status", 1, actual.size()); + Assert.assertEquals("That status should be Added (1)", 1, actual.get(0).getInt(0)); + } + + @Test + public void testEntriesTableDataFilePrune() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + DataFile file = table.currentSnapshot().addedDataFiles(table.io()).iterator().next(); + + List singleActual = + rowsToJava( + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "entries")) + .select("data_file.file_path") + .collectAsList()); + + List singleExpected = ImmutableList.of(row(file.path())); + + assertEquals( + "Should prune a single element from a nested struct", singleExpected, singleActual); + } + + @Test + public void testEntriesTableDataFilePruneMulti() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + DataFile file = table.currentSnapshot().addedDataFiles(table.io()).iterator().next(); + + List multiActual = + rowsToJava( + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "entries")) + .select( + "data_file.file_path", + "data_file.value_counts", + "data_file.record_count", + "data_file.column_sizes") + .collectAsList()); + + List multiExpected = + ImmutableList.of( + row(file.path(), file.valueCounts(), file.recordCount(), file.columnSizes())); + + assertEquals("Should prune a single element from a nested struct", multiExpected, multiActual); + } + + @Test + public void testFilesSelectMap() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + DataFile file = table.currentSnapshot().addedDataFiles(table.io()).iterator().next(); + + List multiActual = + rowsToJava( + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "files")) + .select("file_path", "value_counts", 
"record_count", "column_sizes") + .collectAsList()); + + List multiExpected = + ImmutableList.of( + row(file.path(), file.valueCounts(), file.recordCount(), file.columnSizes())); + + assertEquals("Should prune a single element from a row", multiExpected, multiActual); + } + + @Test + public void testAllEntriesTable() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + Table entriesTable = loadTable(tableIdentifier, "all_entries"); + + Dataset df1 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + Dataset df2 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "b")), SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // delete the first file to test that not only live files are listed + table.newDelete().deleteFromRowFilter(Expressions.equal("id", 1)).commit(); + + // add a second file + df2.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // ensure table data isn't stale + table.refresh(); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "all_entries")) + .orderBy("snapshot_id") + .collectAsList(); + + List expected = Lists.newArrayList(); + for (ManifestFile manifest : + Iterables.concat(Iterables.transform(table.snapshots(), s -> s.allManifests(table.io())))) { + InputFile in = table.io().newInputFile(manifest.path()); + try (CloseableIterable rows = + Avro.read(in).project(entriesTable.schema()).build()) { + // each row must inherit snapshot_id and sequence_number + rows.forEach( + row -> { + row.put(2, 0L); // data sequence number + row.put(3, 0L); // file sequence number + GenericData.Record file = (GenericData.Record) row.get("data_file"); + TestHelpers.asMetadataRecord(file); + expected.add(row); + }); + } + } + + expected.sort(Comparator.comparing(o -> (Long) o.get("snapshot_id"))); + + Assert.assertEquals("Entries table should have 3 rows", 3, expected.size()); + Assert.assertEquals("Actual results should have 3 rows", 3, actual.size()); + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe( + entriesTable.schema().asStruct(), expected.get(i), actual.get(i)); + } + } + + @Test + public void testCountEntriesTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "count_entries_test"); + createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + // init load + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + final int expectedEntryCount = 1; + + // count entries + Assert.assertEquals( + "Count should return " + expectedEntryCount, + expectedEntryCount, + spark.read().format("iceberg").load(loadLocation(tableIdentifier, "entries")).count()); + + // count all_entries + Assert.assertEquals( + "Count should return " + expectedEntryCount, + expectedEntryCount, + spark.read().format("iceberg").load(loadLocation(tableIdentifier, "all_entries")).count()); + } + + @Test + public void testFilesTable() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "files_test"); + Table table = createTable(tableIdentifier, 
SCHEMA, SPEC); + Table entriesTable = loadTable(tableIdentifier, "entries"); + + Dataset df1 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + Dataset df2 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(2, "b")), SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // add a second file + df2.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // delete the first file to test that only live files are listed + table.newDelete().deleteFromRowFilter(Expressions.equal("id", 1)).commit(); + + Dataset filesTableDs = + spark.read().format("iceberg").load(loadLocation(tableIdentifier, "files")); + List actual = TestHelpers.selectNonDerived(filesTableDs).collectAsList(); + + List expected = Lists.newArrayList(); + for (ManifestFile manifest : table.currentSnapshot().dataManifests(table.io())) { + InputFile in = table.io().newInputFile(manifest.path()); + try (CloseableIterable rows = + Avro.read(in).project(entriesTable.schema()).build()) { + for (GenericData.Record record : rows) { + if ((Integer) record.get("status") < 2 /* added or existing */) { + GenericData.Record file = (GenericData.Record) record.get("data_file"); + TestHelpers.asMetadataRecord(file); + expected.add(file); + } + } + } + } + + Assert.assertEquals("Files table should have one row", 1, expected.size()); + Assert.assertEquals("Actual results should have one row", 1, actual.size()); + + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(filesTableDs), expected.get(0), actual.get(0)); + } + + @Test + public void testFilesTableWithSnapshotIdInheritance() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "files_inheritance_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + table.updateProperties().set(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, "true").commit(); + Table entriesTable = loadTable(tableIdentifier, "entries"); + + spark.sql( + String.format( + "CREATE TABLE parquet_table (data string, id int) " + + "USING parquet PARTITIONED BY (id) LOCATION '%s'", + temp.newFolder())); + + List records = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + + Dataset inputDF = spark.createDataFrame(records, SimpleRecord.class); + inputDF.select("data", "id").write().mode("overwrite").insertInto("parquet_table"); + + NameMapping mapping = MappingUtil.create(table.schema()); + String mappingJson = NameMappingParser.toJson(mapping); + + table.updateProperties().set(TableProperties.DEFAULT_NAME_MAPPING, mappingJson).commit(); + + String stagingLocation = table.location() + "/metadata"; + SparkTableUtil.importSparkTable( + spark, + new org.apache.spark.sql.catalyst.TableIdentifier("parquet_table"), + table, + stagingLocation); + + Dataset filesTableDs = + spark.read().format("iceberg").load(loadLocation(tableIdentifier, "files")); + List actual = TestHelpers.selectNonDerived(filesTableDs).collectAsList(); + + List expected = Lists.newArrayList(); + for (ManifestFile manifest : table.currentSnapshot().dataManifests(table.io())) { + InputFile in = table.io().newInputFile(manifest.path()); + try (CloseableIterable rows = + Avro.read(in).project(entriesTable.schema()).build()) { + for (GenericData.Record record : rows) { + GenericData.Record file = (GenericData.Record) record.get("data_file"); + TestHelpers.asMetadataRecord(file); + 
expected.add(file); + } + } + } + + Types.StructType struct = TestHelpers.nonDerivedSchema(filesTableDs); + Assert.assertEquals("Files table should have one row", 2, expected.size()); + Assert.assertEquals("Actual results should have one row", 2, actual.size()); + TestHelpers.assertEqualsSafe(struct, expected.get(0), actual.get(0)); + TestHelpers.assertEqualsSafe(struct, expected.get(1), actual.get(1)); + } + + @Test + public void testEntriesTableWithSnapshotIdInheritance() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_inheritance_test"); + PartitionSpec spec = SPEC; + Table table = createTable(tableIdentifier, SCHEMA, spec); + + table.updateProperties().set(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, "true").commit(); + + spark.sql( + String.format( + "CREATE TABLE parquet_table (data string, id int) " + + "USING parquet PARTITIONED BY (id) LOCATION '%s'", + temp.newFolder())); + + List records = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "b")); + + Dataset inputDF = spark.createDataFrame(records, SimpleRecord.class); + inputDF.select("data", "id").write().mode("overwrite").insertInto("parquet_table"); + + String stagingLocation = table.location() + "/metadata"; + SparkTableUtil.importSparkTable( + spark, + new org.apache.spark.sql.catalyst.TableIdentifier("parquet_table"), + table, + stagingLocation); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "entries")) + .select("sequence_number", "snapshot_id", "data_file") + .collectAsList(); + + table.refresh(); + + long snapshotId = table.currentSnapshot().snapshotId(); + + Assert.assertEquals("Entries table should have 2 rows", 2, actual.size()); + Assert.assertEquals("Sequence number must match", 0, actual.get(0).getLong(0)); + Assert.assertEquals("Snapshot id must match", snapshotId, actual.get(0).getLong(1)); + Assert.assertEquals("Sequence number must match", 0, actual.get(1).getLong(0)); + Assert.assertEquals("Snapshot id must match", snapshotId, actual.get(1).getLong(1)); + } + + @Test + public void testFilesUnpartitionedTable() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "unpartitioned_files_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + Table entriesTable = loadTable(tableIdentifier, "entries"); + + Dataset df1 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + Dataset df2 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(2, "b")), SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + DataFile toDelete = + Iterables.getOnlyElement(table.currentSnapshot().addedDataFiles(table.io())); + + // add a second file + df2.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // delete the first file to test that only live files are listed + table.newDelete().deleteFile(toDelete).commit(); + + Dataset filesTableDs = + spark.read().format("iceberg").load(loadLocation(tableIdentifier, "files")); + List actual = TestHelpers.selectNonDerived(filesTableDs).collectAsList(); + + List expected = Lists.newArrayList(); + for (ManifestFile manifest : table.currentSnapshot().dataManifests(table.io())) { + InputFile in = table.io().newInputFile(manifest.path()); + try (CloseableIterable rows = + 
Avro.read(in).project(entriesTable.schema()).build()) { + for (GenericData.Record record : rows) { + if ((Integer) record.get("status") < 2 /* added or existing */) { + GenericData.Record file = (GenericData.Record) record.get("data_file"); + TestHelpers.asMetadataRecord(file); + expected.add(file); + } + } + } + } + + Assert.assertEquals("Files table should have one row", 1, expected.size()); + Assert.assertEquals("Actual results should have one row", 1, actual.size()); + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(filesTableDs), expected.get(0), actual.get(0)); + } + + @Test + public void testAllMetadataTablesWithStagedCommits() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "stage_aggregate_table_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + + table.updateProperties().set(TableProperties.WRITE_AUDIT_PUBLISH_ENABLED, "true").commit(); + spark.conf().set(SparkSQLProperties.WAP_ID, "1234567"); + Dataset df1 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + Dataset df2 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(2, "b")), SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // add a second file + df2.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + List actualAllData = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "all_data_files")) + .collectAsList(); + + List actualAllManifests = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "all_manifests")) + .collectAsList(); + + List actualAllEntries = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "all_entries")) + .collectAsList(); + + Assert.assertTrue( + "Stage table should have some snapshots", table.snapshots().iterator().hasNext()); + Assert.assertEquals( + "Stage table should have null currentSnapshot", null, table.currentSnapshot()); + Assert.assertEquals("Actual results should have two rows", 2, actualAllData.size()); + Assert.assertEquals("Actual results should have two rows", 2, actualAllManifests.size()); + Assert.assertEquals("Actual results should have two rows", 2, actualAllEntries.size()); + } + + @Test + public void testAllDataFilesTable() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "files_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + Table entriesTable = loadTable(tableIdentifier, "entries"); + + Dataset df1 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + Dataset df2 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(2, "b")), SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // delete the first file to test that not only live files are listed + table.newDelete().deleteFromRowFilter(Expressions.equal("id", 1)).commit(); + + // add a second file + df2.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // ensure table data isn't stale + table.refresh(); + + Dataset filesTableDs = + spark.read().format("iceberg").load(loadLocation(tableIdentifier, "all_data_files")); + List actual = TestHelpers.selectNonDerived(filesTableDs).collectAsList(); + actual.sort(Comparator.comparing(o -> 
o.getString(1))); + + List expected = Lists.newArrayList(); + Iterable dataManifests = + Iterables.concat( + Iterables.transform(table.snapshots(), snapshot -> snapshot.dataManifests(table.io()))); + for (ManifestFile manifest : dataManifests) { + InputFile in = table.io().newInputFile(manifest.path()); + try (CloseableIterable rows = + Avro.read(in).project(entriesTable.schema()).build()) { + for (GenericData.Record record : rows) { + if ((Integer) record.get("status") < 2 /* added or existing */) { + GenericData.Record file = (GenericData.Record) record.get("data_file"); + TestHelpers.asMetadataRecord(file); + expected.add(file); + } + } + } + } + + expected.sort(Comparator.comparing(o -> o.get("file_path").toString())); + + Assert.assertEquals("Files table should have two rows", 2, expected.size()); + Assert.assertEquals("Actual results should have two rows", 2, actual.size()); + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe( + TestHelpers.nonDerivedSchema(filesTableDs), expected.get(i), actual.get(i)); + } + } + + @Test + public void testHistoryTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "history_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + Table historyTable = loadTable(tableIdentifier, "history"); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long firstSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + long firstSnapshotId = table.currentSnapshot().snapshotId(); + + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long secondSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + long secondSnapshotId = table.currentSnapshot().snapshotId(); + + // rollback the table state to the first snapshot + table.manageSnapshots().rollbackTo(firstSnapshotId).commit(); + long rollbackTimestamp = Iterables.getLast(table.history()).timestampMillis(); + + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long thirdSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + long thirdSnapshotId = table.currentSnapshot().snapshotId(); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "history")) + .collectAsList(); + + GenericRecordBuilder builder = + new GenericRecordBuilder(AvroSchemaUtil.convert(historyTable.schema(), "history")); + List expected = + Lists.newArrayList( + builder + .set("made_current_at", firstSnapshotTimestamp * 1000) + .set("snapshot_id", firstSnapshotId) + .set("parent_id", null) + .set("is_current_ancestor", true) + .build(), + builder + .set("made_current_at", secondSnapshotTimestamp * 1000) + .set("snapshot_id", secondSnapshotId) + .set("parent_id", firstSnapshotId) + .set( + "is_current_ancestor", + false) // commit rolled back, not an ancestor of the current table state + .build(), + builder + .set("made_current_at", rollbackTimestamp * 1000) + .set("snapshot_id", firstSnapshotId) + .set("parent_id", null) + .set("is_current_ancestor", true) + .build(), + builder + .set("made_current_at", thirdSnapshotTimestamp * 1000) + .set("snapshot_id", thirdSnapshotId) + 
.set("parent_id", firstSnapshotId) + .set("is_current_ancestor", true) + .build()); + + Assert.assertEquals("History table should have a row for each commit", 4, actual.size()); + TestHelpers.assertEqualsSafe(historyTable.schema().asStruct(), expected.get(0), actual.get(0)); + TestHelpers.assertEqualsSafe(historyTable.schema().asStruct(), expected.get(1), actual.get(1)); + TestHelpers.assertEqualsSafe(historyTable.schema().asStruct(), expected.get(2), actual.get(2)); + TestHelpers.assertEqualsSafe(historyTable.schema().asStruct(), expected.get(3), actual.get(3)); + } + + @Test + public void testSnapshotsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "snapshots_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + Table snapTable = loadTable(tableIdentifier, "snapshots"); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long firstSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + long firstSnapshotId = table.currentSnapshot().snapshotId(); + String firstManifestList = table.currentSnapshot().manifestListLocation(); + + table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit(); + + long secondSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + long secondSnapshotId = table.currentSnapshot().snapshotId(); + String secondManifestList = table.currentSnapshot().manifestListLocation(); + + // rollback the table state to the first snapshot + table.manageSnapshots().rollbackTo(firstSnapshotId).commit(); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "snapshots")) + .collectAsList(); + + GenericRecordBuilder builder = + new GenericRecordBuilder(AvroSchemaUtil.convert(snapTable.schema(), "snapshots")); + List expected = + Lists.newArrayList( + builder + .set("committed_at", firstSnapshotTimestamp * 1000) + .set("snapshot_id", firstSnapshotId) + .set("parent_id", null) + .set("operation", "append") + .set("manifest_list", firstManifestList) + .set( + "summary", + ImmutableMap.of( + "added-records", "1", + "added-data-files", "1", + "changed-partition-count", "1", + "total-data-files", "1", + "total-records", "1")) + .build(), + builder + .set("committed_at", secondSnapshotTimestamp * 1000) + .set("snapshot_id", secondSnapshotId) + .set("parent_id", firstSnapshotId) + .set("operation", "delete") + .set("manifest_list", secondManifestList) + .set( + "summary", + ImmutableMap.of( + "deleted-records", "1", + "deleted-data-files", "1", + "changed-partition-count", "1", + "total-records", "0", + "total-data-files", "0")) + .build()); + + Assert.assertEquals("Snapshots table should have a row for each snapshot", 2, actual.size()); + TestHelpers.assertEqualsSafe(snapTable.schema().asStruct(), expected.get(0), actual.get(0)); + TestHelpers.assertEqualsSafe(snapTable.schema().asStruct(), expected.get(1), actual.get(1)); + } + + @Test + public void testPrunedSnapshotsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "snapshots_test"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + + inputDf + .select("id", "data") + .write() + 
.format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long firstSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + long firstSnapshotId = table.currentSnapshot().snapshotId(); + + table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit(); + + long secondSnapshotTimestamp = table.currentSnapshot().timestampMillis(); + + // rollback the table state to the first snapshot + table.manageSnapshots().rollbackTo(firstSnapshotId).commit(); + + Dataset actualDf = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "snapshots")) + .select("operation", "committed_at", "summary", "parent_id"); + + Schema projectedSchema = SparkSchemaUtil.convert(actualDf.schema()); + + List actual = actualDf.collectAsList(); + + GenericRecordBuilder builder = + new GenericRecordBuilder(AvroSchemaUtil.convert(projectedSchema, "snapshots")); + List expected = + Lists.newArrayList( + builder + .set("committed_at", firstSnapshotTimestamp * 1000) + .set("parent_id", null) + .set("operation", "append") + .set( + "summary", + ImmutableMap.of( + "added-records", "1", + "added-data-files", "1", + "changed-partition-count", "1", + "total-data-files", "1", + "total-records", "1")) + .build(), + builder + .set("committed_at", secondSnapshotTimestamp * 1000) + .set("parent_id", firstSnapshotId) + .set("operation", "delete") + .set( + "summary", + ImmutableMap.of( + "deleted-records", "1", + "deleted-data-files", "1", + "changed-partition-count", "1", + "total-records", "0", + "total-data-files", "0")) + .build()); + + Assert.assertEquals("Snapshots table should have a row for each snapshot", 2, actual.size()); + TestHelpers.assertEqualsSafe(projectedSchema.asStruct(), expected.get(0), actual.get(0)); + TestHelpers.assertEqualsSafe(projectedSchema.asStruct(), expected.get(1), actual.get(1)); + } + + @Test + public void testManifestsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "manifests_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + Table manifestTable = loadTable(tableIdentifier, "manifests"); + Dataset df1 = + spark.createDataFrame( + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(null, "b")), + SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .option(SparkWriteOptions.DISTRIBUTION_MODE, TableProperties.WRITE_DISTRIBUTION_MODE_NONE) + .save(loadLocation(tableIdentifier)); + + table.updateProperties().set(TableProperties.FORMAT_VERSION, "2").commit(); + + DataFile dataFile = + Iterables.getFirst(table.currentSnapshot().addedDataFiles(table.io()), null); + PartitionSpec dataFileSpec = table.specs().get(dataFile.specId()); + StructLike dataFilePartition = dataFile.partition(); + + PositionDelete delete = PositionDelete.create(); + delete.set(dataFile.path(), 0L, null); + + DeleteFile deleteFile = + writePositionDeletes(table, dataFileSpec, dataFilePartition, ImmutableList.of(delete)); + + table.newRowDelta().addDeletes(deleteFile).commit(); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "manifests")) + .collectAsList(); + + table.refresh(); + + GenericRecordBuilder builder = + new GenericRecordBuilder(AvroSchemaUtil.convert(manifestTable.schema(), "manifests")); + GenericRecordBuilder summaryBuilder = + new GenericRecordBuilder( + AvroSchemaUtil.convert( + manifestTable.schema().findType("partition_summaries.element").asStructType(), + "partition_summary")); + List expected 
= + Lists.transform( + table.currentSnapshot().allManifests(table.io()), + manifest -> + builder + .set("content", manifest.content().id()) + .set("path", manifest.path()) + .set("length", manifest.length()) + .set("partition_spec_id", manifest.partitionSpecId()) + .set("added_snapshot_id", manifest.snapshotId()) + .set( + "added_data_files_count", + manifest.content() == DATA ? manifest.addedFilesCount() : 0) + .set( + "existing_data_files_count", + manifest.content() == DATA ? manifest.existingFilesCount() : 0) + .set( + "deleted_data_files_count", + manifest.content() == DATA ? manifest.deletedFilesCount() : 0) + .set( + "added_delete_files_count", + manifest.content() == DELETES ? manifest.addedFilesCount() : 0) + .set( + "existing_delete_files_count", + manifest.content() == DELETES ? manifest.existingFilesCount() : 0) + .set( + "deleted_delete_files_count", + manifest.content() == DELETES ? manifest.deletedFilesCount() : 0) + .set( + "partition_summaries", + Lists.transform( + manifest.partitions(), + partition -> + summaryBuilder + .set("contains_null", manifest.content() == DATA) + .set("contains_nan", false) + .set("lower_bound", "1") + .set("upper_bound", "1") + .build())) + .build()); + + Assert.assertEquals("Manifests table should have two manifest rows", 2, actual.size()); + TestHelpers.assertEqualsSafe(manifestTable.schema().asStruct(), expected.get(0), actual.get(0)); + TestHelpers.assertEqualsSafe(manifestTable.schema().asStruct(), expected.get(1), actual.get(1)); + } + + @Test + public void testPruneManifestsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "manifests_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + Table manifestTable = loadTable(tableIdentifier, "manifests"); + Dataset df1 = + spark.createDataFrame( + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(null, "b")), + SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + if (!spark.version().startsWith("2")) { + // Spark 2 isn't able to actually push down nested struct projections so this will not break + AssertHelpers.assertThrows( + "Can't prune struct inside list", + SparkException.class, + "Cannot project a partial list element struct", + () -> + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "manifests")) + .select("partition_spec_id", "path", "partition_summaries.contains_null") + .collectAsList()); + } + + Dataset actualDf = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "manifests")) + .select("partition_spec_id", "path", "partition_summaries"); + + Schema projectedSchema = SparkSchemaUtil.convert(actualDf.schema()); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "manifests")) + .select("partition_spec_id", "path", "partition_summaries") + .collectAsList(); + + table.refresh(); + + GenericRecordBuilder builder = + new GenericRecordBuilder(AvroSchemaUtil.convert(projectedSchema.asStruct())); + GenericRecordBuilder summaryBuilder = + new GenericRecordBuilder( + AvroSchemaUtil.convert( + projectedSchema.findType("partition_summaries.element").asStructType(), + "partition_summary")); + List expected = + Lists.transform( + table.currentSnapshot().allManifests(table.io()), + manifest -> + builder + .set("partition_spec_id", manifest.partitionSpecId()) + .set("path", manifest.path()) + .set( + "partition_summaries", + Lists.transform( + manifest.partitions(), + 
partition -> + summaryBuilder + .set("contains_null", true) + .set("contains_nan", false) + .set("lower_bound", "1") + .set("upper_bound", "1") + .build())) + .build()); + + Assert.assertEquals("Manifests table should have one manifest row", 1, actual.size()); + TestHelpers.assertEqualsSafe(projectedSchema.asStruct(), expected.get(0), actual.get(0)); + } + + @Test + public void testAllManifestsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "manifests_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + Table manifestTable = loadTable(tableIdentifier, "all_manifests"); + Dataset<Row> df1 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.updateProperties().set(TableProperties.FORMAT_VERSION, "2").commit(); + + DataFile dataFile = + Iterables.getFirst(table.currentSnapshot().addedDataFiles(table.io()), null); + PartitionSpec dataFileSpec = table.specs().get(dataFile.specId()); + StructLike dataFilePartition = dataFile.partition(); + + PositionDelete<InternalRow> delete = PositionDelete.create(); + delete.set(dataFile.path(), 0L, null); + + DeleteFile deleteFile = + writePositionDeletes(table, dataFileSpec, dataFilePartition, ImmutableList.of(delete)); + + table.newRowDelta().addDeletes(deleteFile).commit(); + + table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit(); + + Stream<Pair<Long, ManifestFile>> snapshotIdToManifests = + StreamSupport.stream(table.snapshots().spliterator(), false) + .flatMap( + snapshot -> + snapshot.allManifests(table.io()).stream() + .map(manifest -> Pair.of(snapshot.snapshotId(), manifest))); + + List<Row> actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "all_manifests")) + .orderBy("path") + .collectAsList(); + + table.refresh(); + + List<GenericData.Record> expected = + snapshotIdToManifests + .map( + snapshotManifest -> + manifestRecord( + manifestTable, snapshotManifest.first(), snapshotManifest.second())) + .collect(Collectors.toList()); + expected.sort(Comparator.comparing(o -> o.get("path").toString())); + + Assert.assertEquals("Manifests table should have 5 manifest rows", 5, actual.size()); + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe( + manifestTable.schema().asStruct(), expected.get(i), actual.get(i)); + } + } + + @Test + public void testUnpartitionedPartitionsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "unpartitioned_partitions_test"); + createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + Dataset<Row> df = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + + df.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + Types.StructType expectedSchema = + Types.StructType.of( + required(2, "record_count", Types.LongType.get(), "Count of records in data files"), + required(3, "file_count", Types.IntegerType.get(), "Count of data files")); + + Table partitionsTable = loadTable(tableIdentifier, "partitions"); + + Assert.assertEquals( + "Schema should not have partition field", + expectedSchema, + partitionsTable.schema().asStruct()); + + GenericRecordBuilder builder = + new GenericRecordBuilder(AvroSchemaUtil.convert(partitionsTable.schema(), "partitions")); + GenericData.Record expectedRow = builder.set("record_count", 1L).set("file_count", 1).build(); + + List<Row> actual = + spark + .read()
+ .format("iceberg") + .load(loadLocation(tableIdentifier, "partitions")) + .collectAsList(); + + Assert.assertEquals("Unpartitioned partitions table should have one row", 1, actual.size()); + TestHelpers.assertEqualsSafe(expectedSchema, expectedRow, actual.get(0)); + } + + @Test + public void testPartitionsTable() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "partitions_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + Table partitionsTable = loadTable(tableIdentifier, "partitions"); + Dataset df1 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + Dataset df2 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(2, "b")), SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + long firstCommitId = table.currentSnapshot().snapshotId(); + + // add a second file + df2.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "partitions")) + .orderBy("partition.id") + .collectAsList(); + + GenericRecordBuilder builder = + new GenericRecordBuilder(AvroSchemaUtil.convert(partitionsTable.schema(), "partitions")); + GenericRecordBuilder partitionBuilder = + new GenericRecordBuilder( + AvroSchemaUtil.convert( + partitionsTable.schema().findType("partition").asStructType(), "partition")); + List expected = Lists.newArrayList(); + expected.add( + builder + .set("partition", partitionBuilder.set("id", 1).build()) + .set("record_count", 1L) + .set("file_count", 1) + .set("spec_id", 0) + .build()); + expected.add( + builder + .set("partition", partitionBuilder.set("id", 2).build()) + .set("record_count", 1L) + .set("file_count", 1) + .set("spec_id", 0) + .build()); + + Assert.assertEquals("Partitions table should have two rows", 2, expected.size()); + Assert.assertEquals("Actual results should have two rows", 2, actual.size()); + for (int i = 0; i < 2; i += 1) { + TestHelpers.assertEqualsSafe( + partitionsTable.schema().asStruct(), expected.get(i), actual.get(i)); + } + + // check time travel + List actualAfterFirstCommit = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SNAPSHOT_ID, String.valueOf(firstCommitId)) + .load(loadLocation(tableIdentifier, "partitions")) + .orderBy("partition.id") + .collectAsList(); + + Assert.assertEquals("Actual results should have one row", 1, actualAfterFirstCommit.size()); + TestHelpers.assertEqualsSafe( + partitionsTable.schema().asStruct(), expected.get(0), actualAfterFirstCommit.get(0)); + + // check predicate push down + List filtered = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "partitions")) + .filter("partition.id < 2") + .collectAsList(); + Assert.assertEquals("Actual results should have one row", 1, filtered.size()); + TestHelpers.assertEqualsSafe( + partitionsTable.schema().asStruct(), expected.get(0), filtered.get(0)); + + List nonFiltered = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "partitions")) + .filter("partition.id < 2 or record_count=1") + .collectAsList(); + Assert.assertEquals("Actual results should have one row", 2, nonFiltered.size()); + for (int i = 0; i < 2; i += 1) { + TestHelpers.assertEqualsSafe( + partitionsTable.schema().asStruct(), expected.get(i), actual.get(i)); + } + } + + @Test + public synchronized void 
testSnapshotReadAfterAddColumn() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "table"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List originalRecords = + Lists.newArrayList( + RowFactory.create(1, "x"), RowFactory.create(2, "y"), RowFactory.create(3, "z")); + + StructType originalSparkSchema = SparkSchemaUtil.convert(SCHEMA); + Dataset inputDf = spark.createDataFrame(originalRecords, originalSparkSchema); + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + Dataset resultDf = spark.read().format("iceberg").load(loadLocation(tableIdentifier)); + Assert.assertEquals( + "Records should match", originalRecords, resultDf.orderBy("id").collectAsList()); + + Snapshot snapshotBeforeAddColumn = table.currentSnapshot(); + + table.updateSchema().addColumn("category", Types.StringType.get()).commit(); + + List newRecords = + Lists.newArrayList(RowFactory.create(4, "xy", "B"), RowFactory.create(5, "xyz", "C")); + + StructType newSparkSchema = SparkSchemaUtil.convert(SCHEMA2); + Dataset inputDf2 = spark.createDataFrame(newRecords, newSparkSchema); + inputDf2 + .select("id", "data", "category") + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + List updatedRecords = + Lists.newArrayList( + RowFactory.create(1, "x", null), + RowFactory.create(2, "y", null), + RowFactory.create(3, "z", null), + RowFactory.create(4, "xy", "B"), + RowFactory.create(5, "xyz", "C")); + + Dataset resultDf2 = spark.read().format("iceberg").load(loadLocation(tableIdentifier)); + Assert.assertEquals( + "Records should match", updatedRecords, resultDf2.orderBy("id").collectAsList()); + + Dataset resultDf3 = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SNAPSHOT_ID, snapshotBeforeAddColumn.snapshotId()) + .load(loadLocation(tableIdentifier)); + Assert.assertEquals( + "Records should match", originalRecords, resultDf3.orderBy("id").collectAsList()); + Assert.assertEquals("Schemas should match", originalSparkSchema, resultDf3.schema()); + } + + @Test + public synchronized void testSnapshotReadAfterDropColumn() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "table"); + Table table = createTable(tableIdentifier, SCHEMA2, PartitionSpec.unpartitioned()); + + List originalRecords = + Lists.newArrayList( + RowFactory.create(1, "x", "A"), + RowFactory.create(2, "y", "A"), + RowFactory.create(3, "z", "B")); + + StructType originalSparkSchema = SparkSchemaUtil.convert(SCHEMA2); + Dataset inputDf = spark.createDataFrame(originalRecords, originalSparkSchema); + inputDf + .select("id", "data", "category") + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + Dataset resultDf = spark.read().format("iceberg").load(loadLocation(tableIdentifier)); + Assert.assertEquals( + "Records should match", originalRecords, resultDf.orderBy("id").collectAsList()); + + long tsBeforeDropColumn = waitUntilAfter(System.currentTimeMillis()); + table.updateSchema().deleteColumn("data").commit(); + long tsAfterDropColumn = waitUntilAfter(System.currentTimeMillis()); + + List newRecords = Lists.newArrayList(RowFactory.create(4, "B"), RowFactory.create(5, "C")); + + StructType newSparkSchema = SparkSchemaUtil.convert(SCHEMA3); + Dataset inputDf2 = spark.createDataFrame(newRecords, newSparkSchema); + inputDf2 + .select("id", 
"category") + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + List updatedRecords = + Lists.newArrayList( + RowFactory.create(1, "A"), + RowFactory.create(2, "A"), + RowFactory.create(3, "B"), + RowFactory.create(4, "B"), + RowFactory.create(5, "C")); + + Dataset resultDf2 = spark.read().format("iceberg").load(loadLocation(tableIdentifier)); + Assert.assertEquals( + "Records should match", updatedRecords, resultDf2.orderBy("id").collectAsList()); + + Dataset resultDf3 = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.AS_OF_TIMESTAMP, tsBeforeDropColumn) + .load(loadLocation(tableIdentifier)); + Assert.assertEquals( + "Records should match", originalRecords, resultDf3.orderBy("id").collectAsList()); + Assert.assertEquals("Schemas should match", originalSparkSchema, resultDf3.schema()); + + // At tsAfterDropColumn, there has been a schema change, but no new snapshot, + // so the snapshot as of tsAfterDropColumn is the same as that as of tsBeforeDropColumn. + Dataset resultDf4 = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.AS_OF_TIMESTAMP, tsAfterDropColumn) + .load(loadLocation(tableIdentifier)); + Assert.assertEquals( + "Records should match", originalRecords, resultDf4.orderBy("id").collectAsList()); + Assert.assertEquals("Schemas should match", originalSparkSchema, resultDf4.schema()); + } + + @Test + public synchronized void testSnapshotReadAfterAddAndDropColumn() { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "table"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List originalRecords = + Lists.newArrayList( + RowFactory.create(1, "x"), RowFactory.create(2, "y"), RowFactory.create(3, "z")); + + StructType originalSparkSchema = SparkSchemaUtil.convert(SCHEMA); + Dataset inputDf = spark.createDataFrame(originalRecords, originalSparkSchema); + inputDf + .select("id", "data") + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + Dataset resultDf = spark.read().format("iceberg").load(loadLocation(tableIdentifier)); + Assert.assertEquals( + "Records should match", originalRecords, resultDf.orderBy("id").collectAsList()); + + Snapshot snapshotBeforeAddColumn = table.currentSnapshot(); + + table.updateSchema().addColumn("category", Types.StringType.get()).commit(); + + List newRecords = + Lists.newArrayList(RowFactory.create(4, "xy", "B"), RowFactory.create(5, "xyz", "C")); + + StructType sparkSchemaAfterAddColumn = SparkSchemaUtil.convert(SCHEMA2); + Dataset inputDf2 = spark.createDataFrame(newRecords, sparkSchemaAfterAddColumn); + inputDf2 + .select("id", "data", "category") + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + List updatedRecords = + Lists.newArrayList( + RowFactory.create(1, "x", null), + RowFactory.create(2, "y", null), + RowFactory.create(3, "z", null), + RowFactory.create(4, "xy", "B"), + RowFactory.create(5, "xyz", "C")); + + Dataset resultDf2 = spark.read().format("iceberg").load(loadLocation(tableIdentifier)); + Assert.assertEquals( + "Records should match", updatedRecords, resultDf2.orderBy("id").collectAsList()); + + table.updateSchema().deleteColumn("data").commit(); + + List recordsAfterDropColumn = + Lists.newArrayList( + RowFactory.create(1, null), + RowFactory.create(2, null), + RowFactory.create(3, null), + RowFactory.create(4, "B"), + RowFactory.create(5, 
"C")); + + Dataset resultDf3 = spark.read().format("iceberg").load(loadLocation(tableIdentifier)); + Assert.assertEquals( + "Records should match", recordsAfterDropColumn, resultDf3.orderBy("id").collectAsList()); + + Dataset resultDf4 = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SNAPSHOT_ID, snapshotBeforeAddColumn.snapshotId()) + .load(loadLocation(tableIdentifier)); + Assert.assertEquals( + "Records should match", originalRecords, resultDf4.orderBy("id").collectAsList()); + Assert.assertEquals("Schemas should match", originalSparkSchema, resultDf4.schema()); + } + + @Test + public void testRemoveOrphanFilesActionSupport() throws InterruptedException { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "table"); + Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.unpartitioned()); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + + df.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + df.write().mode("append").parquet(table.location() + "/data"); + + // sleep for 1 second to ensure files will be old enough + Thread.sleep(1000); + + SparkActions actions = SparkActions.get(); + + DeleteOrphanFiles.Result result1 = + actions + .deleteOrphanFiles(table) + .location(table.location() + "/metadata") + .olderThan(System.currentTimeMillis()) + .execute(); + Assert.assertTrue( + "Should not delete any metadata files", Iterables.isEmpty(result1.orphanFileLocations())); + + DeleteOrphanFiles.Result result2 = + actions.deleteOrphanFiles(table).olderThan(System.currentTimeMillis()).execute(); + Assert.assertEquals( + "Should delete 1 data file", 1, Iterables.size(result2.orphanFileLocations())); + + Dataset resultDF = spark.read().format("iceberg").load(loadLocation(tableIdentifier)); + List actualRecords = + resultDF.as(Encoders.bean(SimpleRecord.class)).collectAsList(); + + Assert.assertEquals("Rows must match", records, actualRecords); + } + + @Test + public void testFilesTablePartitionId() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "files_test"); + Table table = + createTable( + tableIdentifier, SCHEMA, PartitionSpec.builderFor(SCHEMA).identity("id").build()); + int spec0 = table.spec().specId(); + + Dataset df1 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + Dataset df2 = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(2, "b")), SimpleRecord.class); + + df1.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + // change partition spec + table.refresh(); + table.updateSpec().removeField("id").commit(); + int spec1 = table.spec().specId(); + + // add a second file + df2.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + List actual = + spark.read().format("iceberg").load(loadLocation(tableIdentifier, "files")) + .sort(DataFile.SPEC_ID.name()).collectAsList().stream() + .map(r -> (Integer) r.getAs(DataFile.SPEC_ID.name())) + .collect(Collectors.toList()); + + Assert.assertEquals("Should have two partition specs", ImmutableList.of(spec0, spec1), actual); + } + + @Test + public void testAllManifestTableSnapshotFiltering() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "all_manifest_snapshot_filtering"); + Table table = createTable(tableIdentifier, SCHEMA, 
SPEC); + Table manifestTable = loadTable(tableIdentifier, "all_manifests"); + Dataset<Row> df = + spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); + + List<Pair<Long, ManifestFile>> snapshotIdToManifests = Lists.newArrayList(); + + df.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + Snapshot snapshot1 = table.currentSnapshot(); + snapshotIdToManifests.addAll( + snapshot1.allManifests(table.io()).stream() + .map(manifest -> Pair.of(snapshot1.snapshotId(), manifest)) + .collect(Collectors.toList())); + + df.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + Snapshot snapshot2 = table.currentSnapshot(); + Assert.assertEquals("Should have two manifests", 2, snapshot2.allManifests(table.io()).size()); + snapshotIdToManifests.addAll( + snapshot2.allManifests(table.io()).stream() + .map(manifest -> Pair.of(snapshot2.snapshotId(), manifest)) + .collect(Collectors.toList())); + + // Add manifests that will not be selected + df.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + df.select("id", "data") + .write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + StringJoiner snapshotIds = new StringJoiner(",", "(", ")"); + snapshotIds.add(String.valueOf(snapshot1.snapshotId())); + snapshotIds.add(String.valueOf(snapshot2.snapshotId())); + snapshotIds.toString(); + + List<Row> actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier, "all_manifests")) + .filter("reference_snapshot_id in " + snapshotIds) + .orderBy("path") + .collectAsList(); + table.refresh(); + + List<GenericData.Record> expected = + snapshotIdToManifests.stream() + .map( + snapshotManifest -> + manifestRecord( + manifestTable, snapshotManifest.first(), snapshotManifest.second())) + .collect(Collectors.toList()); + expected.sort(Comparator.comparing(o -> o.get("path").toString())); + + Assert.assertEquals("Manifests table should have 3 manifest rows", 3, actual.size()); + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe( + manifestTable.schema().asStruct(), expected.get(i), actual.get(i)); + } + } + + @Test + public void testTableWithInt96Timestamp() throws IOException { + File parquetTableDir = temp.newFolder("table_timestamp_int96"); + String parquetTableLocation = parquetTableDir.toURI().toString(); + Schema schema = + new Schema( + optional(1, "id", Types.LongType.get()), + optional(2, "tmp_col", Types.TimestampType.withZone())); + spark.conf().set(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE().key(), "INT96"); + + LocalDateTime start = LocalDateTime.of(2000, 1, 31, 0, 0, 0); + LocalDateTime end = LocalDateTime.of(2100, 1, 1, 0, 0, 0); + long startSec = start.toEpochSecond(ZoneOffset.UTC); + long endSec = end.toEpochSecond(ZoneOffset.UTC); + Column idColumn = functions.expr("id"); + Column secondsColumn = + functions.expr("(id % " + (endSec - startSec) + " + " + startSec + ")").as("seconds"); + Column timestampColumn = functions.expr("cast( seconds as timestamp) as tmp_col"); + + for (Boolean useDict : new Boolean[] {true, false}) { + for (Boolean useVectorization : new Boolean[] {true, false}) { + spark.sql("DROP TABLE IF EXISTS parquet_table"); + spark + .range(0, 5000, 100, 1) + .select(idColumn, secondsColumn) + .select(idColumn, timestampColumn) + .write() + .format("parquet") + .option("parquet.enable.dictionary", useDict) +
.mode("overwrite") + .option("path", parquetTableLocation) + .saveAsTable("parquet_table"); + TableIdentifier tableIdentifier = TableIdentifier.of("db", "table_with_timestamp_int96"); + Table table = createTable(tableIdentifier, schema, PartitionSpec.unpartitioned()); + table + .updateProperties() + .set(TableProperties.PARQUET_VECTORIZATION_ENABLED, useVectorization.toString()) + .commit(); + + String stagingLocation = table.location() + "/metadata"; + SparkTableUtil.importSparkTable( + spark, + new org.apache.spark.sql.catalyst.TableIdentifier("parquet_table"), + table, + stagingLocation); + + // validate we get the expected results back + List expected = spark.table("parquet_table").select("tmp_col").collectAsList(); + List actual = + spark + .read() + .format("iceberg") + .load(loadLocation(tableIdentifier)) + .select("tmp_col") + .collectAsList(); + Assertions.assertThat(actual) + .as("Rows must match") + .containsExactlyInAnyOrderElementsOf(expected); + dropTable(tableIdentifier); + } + } + } + + private GenericData.Record manifestRecord( + Table manifestTable, Long referenceSnapshotId, ManifestFile manifest) { + GenericRecordBuilder builder = + new GenericRecordBuilder(AvroSchemaUtil.convert(manifestTable.schema(), "manifests")); + GenericRecordBuilder summaryBuilder = + new GenericRecordBuilder( + AvroSchemaUtil.convert( + manifestTable.schema().findType("partition_summaries.element").asStructType(), + "partition_summary")); + return builder + .set("content", manifest.content().id()) + .set("path", manifest.path()) + .set("length", manifest.length()) + .set("partition_spec_id", manifest.partitionSpecId()) + .set("added_snapshot_id", manifest.snapshotId()) + .set("added_data_files_count", manifest.content() == DATA ? manifest.addedFilesCount() : 0) + .set( + "existing_data_files_count", + manifest.content() == DATA ? manifest.existingFilesCount() : 0) + .set( + "deleted_data_files_count", + manifest.content() == DATA ? manifest.deletedFilesCount() : 0) + .set( + "added_delete_files_count", + manifest.content() == DELETES ? manifest.addedFilesCount() : 0) + .set( + "existing_delete_files_count", + manifest.content() == DELETES ? manifest.existingFilesCount() : 0) + .set( + "deleted_delete_files_count", + manifest.content() == DELETES ? 
manifest.deletedFilesCount() : 0) + .set( + "partition_summaries", + Lists.transform( + manifest.partitions(), + partition -> + summaryBuilder + .set("contains_null", false) + .set("contains_nan", false) + .set("lower_bound", "1") + .set("upper_bound", "1") + .build())) + .set("reference_snapshot_id", referenceSnapshotId) + .build(); + } + + private PositionDeleteWriter newPositionDeleteWriter( + Table table, PartitionSpec spec, StructLike partition) { + OutputFileFactory fileFactory = OutputFileFactory.builderFor(table, 0, 0).build(); + EncryptedOutputFile outputFile = fileFactory.newOutputFile(spec, partition); + + SparkFileWriterFactory fileWriterFactory = SparkFileWriterFactory.builderFor(table).build(); + return fileWriterFactory.newPositionDeleteWriter(outputFile, spec, partition); + } + + private DeleteFile writePositionDeletes( + Table table, + PartitionSpec spec, + StructLike partition, + Iterable> deletes) { + PositionDeleteWriter positionDeleteWriter = + newPositionDeleteWriter(table, spec, partition); + + try (PositionDeleteWriter writer = positionDeleteWriter) { + for (PositionDelete delete : deletes) { + writer.write(delete); + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + + return positionDeleteWriter.toDeleteFile(); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java new file mode 100644 index 000000000000..37e329a8b97b --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
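The two helpers above only produce a DeleteFile; nothing in this file commits it. As a minimal sketch (not part of the diff), assuming a format-version-2 Table named `table` and an already-written DataFile named `dataFile`, the output of `writePositionDeletes` could be applied through a row delta roughly like this:

// Sketch only: build two position deletes against dataFile and commit them as a row delta.
// `table`, `dataFile` and the use of writePositionDeletes(...) here are illustrative assumptions.
List<PositionDelete<InternalRow>> deletes = Lists.newArrayList();
for (long pos : new long[] {0L, 1L}) {
  PositionDelete<InternalRow> delete = PositionDelete.create();
  delete.set(dataFile.path(), pos, null); // deleted row content is optional
  deletes.add(delete);
}

DeleteFile posDeleteFile = writePositionDeletes(table, table.spec(), null, deletes);
table.newRowDelta().addDeletes(posDeleteFile).commit();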
+ */ +package org.apache.iceberg.spark.source; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.sql.Date; +import java.sql.Timestamp; +import java.util.List; +import org.apache.iceberg.spark.IcebergSpark; +import org.apache.iceberg.transforms.Transforms; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.types.CharType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.VarcharType; +import org.assertj.core.api.Assertions; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestIcebergSpark { + + private static SparkSession spark = null; + + @BeforeClass + public static void startSpark() { + TestIcebergSpark.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestIcebergSpark.spark; + TestIcebergSpark.spark = null; + currentSpark.stop(); + } + + @Test + public void testRegisterIntegerBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_int_16", DataTypes.IntegerType, 16); + List results = spark.sql("SELECT iceberg_bucket_int_16(1)").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals( + (int) Transforms.bucket(16).bind(Types.IntegerType.get()).apply(1), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterShortBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_short_16", DataTypes.ShortType, 16); + List results = spark.sql("SELECT iceberg_bucket_short_16(1S)").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals( + (int) Transforms.bucket(16).bind(Types.IntegerType.get()).apply(1), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterByteBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_byte_16", DataTypes.ByteType, 16); + List results = spark.sql("SELECT iceberg_bucket_byte_16(1Y)").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals( + (int) Transforms.bucket(16).bind(Types.IntegerType.get()).apply(1), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterLongBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_long_16", DataTypes.LongType, 16); + List results = spark.sql("SELECT iceberg_bucket_long_16(1L)").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals( + (int) Transforms.bucket(16).bind(Types.LongType.get()).apply(1L), results.get(0).getInt(0)); + } + + @Test + public void testRegisterStringBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_string_16", DataTypes.StringType, 16); + List results = spark.sql("SELECT iceberg_bucket_string_16('hello')").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals( + (int) Transforms.bucket(16).bind(Types.StringType.get()).apply("hello"), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterCharBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_char_16", new CharType(5), 16); + List results = spark.sql("SELECT iceberg_bucket_char_16('hello')").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals( + (int) Transforms.bucket(16).bind(Types.StringType.get()).apply("hello"), + 
results.get(0).getInt(0)); + } + + @Test + public void testRegisterVarCharBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_varchar_16", new VarcharType(5), 16); + List results = spark.sql("SELECT iceberg_bucket_varchar_16('hello')").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals( + (int) Transforms.bucket(16).bind(Types.StringType.get()).apply("hello"), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterDateBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_date_16", DataTypes.DateType, 16); + List results = + spark.sql("SELECT iceberg_bucket_date_16(DATE '2021-06-30')").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals( + (int) + Transforms.bucket(16) + .bind(Types.DateType.get()) + .apply(DateTimeUtils.fromJavaDate(Date.valueOf("2021-06-30"))), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterTimestampBucketUDF() { + IcebergSpark.registerBucketUDF( + spark, "iceberg_bucket_timestamp_16", DataTypes.TimestampType, 16); + List results = + spark + .sql("SELECT iceberg_bucket_timestamp_16(TIMESTAMP '2021-06-30 00:00:00.000')") + .collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals( + (int) + Transforms.bucket(16) + .bind(Types.TimestampType.withZone()) + .apply( + DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2021-06-30 00:00:00.000"))), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterBinaryBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_binary_16", DataTypes.BinaryType, 16); + List results = spark.sql("SELECT iceberg_bucket_binary_16(X'0020001F')").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals( + (int) + Transforms.bucket(16) + .bind(Types.BinaryType.get()) + .apply(ByteBuffer.wrap(new byte[] {0x00, 0x20, 0x00, 0x1F})), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterDecimalBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_decimal_16", new DecimalType(4, 2), 16); + List results = spark.sql("SELECT iceberg_bucket_decimal_16(11.11)").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals( + (int) Transforms.bucket(16).bind(Types.DecimalType.of(4, 2)).apply(new BigDecimal("11.11")), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterBooleanBucketUDF() { + Assertions.assertThatThrownBy( + () -> + IcebergSpark.registerBucketUDF( + spark, "iceberg_bucket_boolean_16", DataTypes.BooleanType, 16)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot bucket by type: boolean"); + } + + @Test + public void testRegisterDoubleBucketUDF() { + Assertions.assertThatThrownBy( + () -> + IcebergSpark.registerBucketUDF( + spark, "iceberg_bucket_double_16", DataTypes.DoubleType, 16)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot bucket by type: double"); + } + + @Test + public void testRegisterFloatBucketUDF() { + Assertions.assertThatThrownBy( + () -> + IcebergSpark.registerBucketUDF( + spark, "iceberg_bucket_float_16", DataTypes.FloatType, 16)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot bucket by type: float"); + } + + @Test + public void testRegisterIntegerTruncateUDF() { + IcebergSpark.registerTruncateUDF(spark, "iceberg_truncate_int_4", DataTypes.IntegerType, 4); + List results = spark.sql("SELECT iceberg_truncate_int_4(1)").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals( + 
Transforms.truncate(4).bind(Types.IntegerType.get()).apply(1), results.get(0).getInt(0)); + } + + @Test + public void testRegisterLongTruncateUDF() { + IcebergSpark.registerTruncateUDF(spark, "iceberg_truncate_long_4", DataTypes.LongType, 4); + List results = spark.sql("SELECT iceberg_truncate_long_4(1L)").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals( + Transforms.truncate(4).bind(Types.LongType.get()).apply(1L), results.get(0).getLong(0)); + } + + @Test + public void testRegisterDecimalTruncateUDF() { + IcebergSpark.registerTruncateUDF(spark, "iceberg_truncate_decimal_4", new DecimalType(4, 2), 4); + List results = spark.sql("SELECT iceberg_truncate_decimal_4(11.11)").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals( + Transforms.truncate(4).bind(Types.DecimalType.of(4, 2)).apply(new BigDecimal("11.11")), + results.get(0).getDecimal(0)); + } + + @Test + public void testRegisterStringTruncateUDF() { + IcebergSpark.registerTruncateUDF(spark, "iceberg_truncate_string_4", DataTypes.StringType, 4); + List results = spark.sql("SELECT iceberg_truncate_string_4('hello')").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals( + Transforms.truncate(4).bind(Types.StringType.get()).apply("hello"), + results.get(0).getString(0)); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData.java new file mode 100644 index 000000000000..7313c18cc09d --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
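Taken together, the tests above double as a usage guide for the UDF registration API. A minimal sketch (the SparkSession `spark` and the function names are assumptions, not part of the diff) of registering the UDFs once per session and calling them from SQL:

// Sketch: expose Iceberg's bucket and truncate transforms as Spark SQL functions.
IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_int_16", DataTypes.IntegerType, 16);
IcebergSpark.registerTruncateUDF(spark, "iceberg_truncate_string_4", DataTypes.StringType, 4);

// The UDFs can then be used to inspect how values would be laid out by a partition spec.
spark
    .sql("SELECT iceberg_bucket_int_16(42) AS bucket, iceberg_truncate_string_4('iceberg') AS prefix")
    .show();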
+ */ +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class TestIdentityPartitionData extends SparkTestBase { + private static final Configuration CONF = new Configuration(); + private static final HadoopTables TABLES = new HadoopTables(CONF); + + @Parameterized.Parameters(name = "format = {0}, vectorized = {1}") + public static Object[][] parameters() { + return new Object[][] { + {"parquet", false}, + {"parquet", true}, + {"avro", false}, + {"orc", false}, + {"orc", true}, + }; + } + + private final String format; + private final boolean vectorized; + + public TestIdentityPartitionData(String format, boolean vectorized) { + this.format = format; + this.vectorized = vectorized; + } + + private static final Schema LOG_SCHEMA = + new Schema( + Types.NestedField.optional(1, "id", Types.IntegerType.get()), + Types.NestedField.optional(2, "date", Types.StringType.get()), + Types.NestedField.optional(3, "level", Types.StringType.get()), + Types.NestedField.optional(4, "message", Types.StringType.get())); + + private static final List LOGS = + ImmutableList.of( + LogMessage.debug("2020-02-02", "debug event 1"), + LogMessage.info("2020-02-02", "info event 1"), + LogMessage.debug("2020-02-02", "debug event 2"), + LogMessage.info("2020-02-03", "info event 2"), + LogMessage.debug("2020-02-03", "debug event 3"), + LogMessage.info("2020-02-03", "info event 3"), + LogMessage.error("2020-02-03", "error event 1"), + LogMessage.debug("2020-02-04", "debug event 4"), + LogMessage.warn("2020-02-04", "warn event 1"), + LogMessage.debug("2020-02-04", "debug event 5")); + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private PartitionSpec spec = + PartitionSpec.builderFor(LOG_SCHEMA).identity("date").identity("level").build(); + private Table table = null; + private Dataset logs = null; + + /** + * Use the Hive Based table to make Identity Partition Columns with no duplication of the data in + * the underlying parquet files. This makes sure that if the identity mapping fails, the test will + * also fail. 
+ */ + private void setupParquet() throws Exception { + File location = temp.newFolder("logs"); + File hiveLocation = temp.newFolder("hive"); + String hiveTable = "hivetable"; + Assert.assertTrue("Temp folder should exist", location.exists()); + + Map properties = ImmutableMap.of(TableProperties.DEFAULT_FILE_FORMAT, format); + this.logs = + spark.createDataFrame(LOGS, LogMessage.class).select("id", "date", "level", "message"); + spark.sql(String.format("DROP TABLE IF EXISTS %s", hiveTable)); + logs.orderBy("date", "level", "id") + .write() + .partitionBy("date", "level") + .format("parquet") + .option("path", hiveLocation.toString()) + .saveAsTable(hiveTable); + + this.table = + TABLES.create( + SparkSchemaUtil.schemaForTable(spark, hiveTable), + SparkSchemaUtil.specForTable(spark, hiveTable), + properties, + location.toString()); + + SparkTableUtil.importSparkTable( + spark, new TableIdentifier(hiveTable), table, location.toString()); + } + + @Before + public void setupTable() throws Exception { + if (format.equals("parquet")) { + setupParquet(); + } else { + File location = temp.newFolder("logs"); + Assert.assertTrue("Temp folder should exist", location.exists()); + + Map properties = ImmutableMap.of(TableProperties.DEFAULT_FILE_FORMAT, format); + this.table = TABLES.create(LOG_SCHEMA, spec, properties, location.toString()); + this.logs = + spark.createDataFrame(LOGS, LogMessage.class).select("id", "date", "level", "message"); + + logs.orderBy("date", "level", "id") + .write() + .format("iceberg") + .mode("append") + .save(location.toString()); + } + } + + @Test + public void testFullProjection() { + List expected = logs.orderBy("id").collectAsList(); + List actual = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(table.location()) + .orderBy("id") + .select("id", "date", "level", "message") + .collectAsList(); + Assert.assertEquals("Rows should match", expected, actual); + } + + @Test + public void testProjections() { + String[][] cases = + new String[][] { + // individual fields + new String[] {"date"}, + new String[] {"level"}, + new String[] {"message"}, + // field pairs + new String[] {"date", "message"}, + new String[] {"level", "message"}, + new String[] {"date", "level"}, + // out-of-order pairs + new String[] {"message", "date"}, + new String[] {"message", "level"}, + new String[] {"level", "date"}, + // full projection, different orderings + new String[] {"date", "level", "message"}, + new String[] {"level", "date", "message"}, + new String[] {"date", "message", "level"}, + new String[] {"level", "message", "date"}, + new String[] {"message", "date", "level"}, + new String[] {"message", "level", "date"} + }; + + for (String[] ordering : cases) { + List expected = logs.select("id", ordering).orderBy("id").collectAsList(); + List actual = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(table.location()) + .select("id", ordering) + .orderBy("id") + .collectAsList(); + Assert.assertEquals( + "Rows should match for ordering: " + Arrays.toString(ordering), expected, actual); + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestInternalRowWrapper.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestInternalRowWrapper.java new file mode 100644 index 000000000000..9e75145faff9 --- /dev/null +++ 
b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestInternalRowWrapper.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import java.util.Iterator;
+import org.apache.iceberg.RecordWrapperTest;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.StructLike;
+import org.apache.iceberg.data.InternalRecordWrapper;
+import org.apache.iceberg.data.RandomGenericData;
+import org.apache.iceberg.data.Record;
+import org.apache.iceberg.spark.SparkSchemaUtil;
+import org.apache.iceberg.spark.data.RandomData;
+import org.apache.iceberg.util.StructLikeWrapper;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.junit.Assert;
+import org.junit.Ignore;
+
+public class TestInternalRowWrapper extends RecordWrapperTest {
+
+  @Ignore
+  @Override
+  public void testTimestampWithoutZone() {
+    // Spark does not support timestamp without zone.
+  }
+
+  @Ignore
+  @Override
+  public void testTime() {
+    // Spark does not support time fields.
+  }
+
+  @Override
+  protected void generateAndValidate(Schema schema, AssertMethod assertMethod) {
+    int numRecords = 100;
+    Iterable<Record> recordList = RandomGenericData.generate(schema, numRecords, 101L);
+    Iterable<InternalRow> rowList = RandomData.generateSpark(schema, numRecords, 101L);
+
+    InternalRecordWrapper recordWrapper = new InternalRecordWrapper(schema.asStruct());
+    InternalRowWrapper rowWrapper = new InternalRowWrapper(SparkSchemaUtil.convert(schema));
+
+    Iterator<Record> actual = recordList.iterator();
+    Iterator<InternalRow> expected = rowList.iterator();
+
+    StructLikeWrapper actualWrapper = StructLikeWrapper.forType(schema.asStruct());
+    StructLikeWrapper expectedWrapper = StructLikeWrapper.forType(schema.asStruct());
+    for (int i = 0; i < numRecords; i++) {
+      Assert.assertTrue("Should have more records", actual.hasNext());
+      Assert.assertTrue("Should have more InternalRow", expected.hasNext());
+
+      StructLike recordStructLike = recordWrapper.wrap(actual.next());
+      StructLike rowStructLike = rowWrapper.wrap(expected.next());
+
+      assertMethod.assertEquals(
+          "Should have expected StructLike values",
+          actualWrapper.set(recordStructLike),
+          expectedWrapper.set(rowStructLike));
+    }
+
+    Assert.assertFalse("Shouldn't have more record", actual.hasNext());
+    Assert.assertFalse("Shouldn't have more InternalRow", expected.hasNext());
+  }
+}
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestMetadataTableReadableMetrics.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestMetadataTableReadableMetrics.java
new file mode 100644
index 000000000000..343943b0f891
--- /dev/null
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestMetadataTableReadableMetrics.java
@@ -0,0 +1,286 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.

+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.Base64; +import java.util.List; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.Files; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.data.FileHelpers; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkTestBaseWithCatalog; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.junit.After; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestMetadataTableReadableMetrics extends SparkTestBaseWithCatalog { + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private static final Types.StructType LEAF_STRUCT_TYPE = + Types.StructType.of( + optional(1, "leafLongCol", Types.LongType.get()), + optional(2, "leafDoubleCol", Types.DoubleType.get())); + + private static final Types.StructType NESTED_STRUCT_TYPE = + Types.StructType.of(required(3, "leafStructCol", LEAF_STRUCT_TYPE)); + + private static final Schema NESTED_SCHEMA = + new Schema(required(4, "nestedStructCol", NESTED_STRUCT_TYPE)); + + private static final Schema PRIMITIVE_SCHEMA = + new Schema( + required(1, "booleanCol", Types.BooleanType.get()), + required(2, "intCol", Types.IntegerType.get()), + required(3, "longCol", Types.LongType.get()), + required(4, "floatCol", Types.FloatType.get()), + required(5, "doubleCol", Types.DoubleType.get()), + optional(6, "decimalCol", Types.DecimalType.of(10, 2)), + optional(7, "stringCol", Types.StringType.get()), + optional(8, "fixedCol", Types.FixedType.ofLength(3)), + optional(9, "binaryCol", Types.BinaryType.get())); + + public TestMetadataTableReadableMetrics() { + // only SparkCatalog supports metadata table sql queries + super(SparkCatalogConfig.HIVE); + } + + protected String tableName() { + return tableName.split("\\.")[2]; + } + + protected String database() { + return tableName.split("\\.")[1]; + } + + private Table createPrimitiveTable() throws IOException { + Table table = + catalog.createTable( + TableIdentifier.of(Namespace.of(database()), tableName()), + PRIMITIVE_SCHEMA, + PartitionSpec.unpartitioned(), + ImmutableMap.of()); + List records = + Lists.newArrayList( + createPrimitiveRecord( + false, + 1, + 1L, + 0, + 1.0D, + new BigDecimal("1.00"), + "1", + Base64.getDecoder().decode("1111"), + ByteBuffer.wrap(Base64.getDecoder().decode("1111"))), + createPrimitiveRecord( + true, + 2, + 2L, + 0, + 2.0D, + new BigDecimal("2.00"), + "2", + Base64.getDecoder().decode("2222"), + ByteBuffer.wrap(Base64.getDecoder().decode("2222"))), + createPrimitiveRecord(false, 1, 1, Float.NaN, Double.NaN, null, "1", null, null), + createPrimitiveRecord( + false, 2, 2L, Float.NaN, 2.0D, new BigDecimal("2.00"), "2", null, null)); + + 
DataFile dataFile = + FileHelpers.writeDataFile(table, Files.localOutput(temp.newFile()), records); + table.newAppend().appendFile(dataFile).commit(); + return table; + } + + private void createNestedTable() throws IOException { + Table table = + catalog.createTable( + TableIdentifier.of(Namespace.of(database()), tableName()), + NESTED_SCHEMA, + PartitionSpec.unpartitioned(), + ImmutableMap.of()); + + List records = + Lists.newArrayList( + createNestedRecord(0L, 0.0), + createNestedRecord(1L, Double.NaN), + createNestedRecord(null, null)); + DataFile dataFile = + FileHelpers.writeDataFile(table, Files.localOutput(temp.newFile()), records); + table.newAppend().appendFile(dataFile).commit(); + } + + @After + public void dropTable() { + sql("DROP TABLE %s", tableName); + } + + private Dataset filesDf() { + return spark.read().format("iceberg").load(database() + "." + tableName() + ".files"); + } + + protected GenericRecord createPrimitiveRecord( + boolean booleanCol, + int intCol, + long longCol, + float floatCol, + double doubleCol, + BigDecimal decimalCol, + String stringCol, + byte[] fixedCol, + ByteBuffer binaryCol) { + GenericRecord record = GenericRecord.create(PRIMITIVE_SCHEMA); + record.set(0, booleanCol); + record.set(1, intCol); + record.set(2, longCol); + record.set(3, floatCol); + record.set(4, doubleCol); + record.set(5, decimalCol); + record.set(6, stringCol); + record.set(7, fixedCol); + record.set(8, binaryCol); + return record; + } + + private GenericRecord createNestedRecord(Long longCol, Double doubleCol) { + GenericRecord record = GenericRecord.create(NESTED_SCHEMA); + GenericRecord nested = GenericRecord.create(NESTED_STRUCT_TYPE); + GenericRecord leaf = GenericRecord.create(LEAF_STRUCT_TYPE); + leaf.set(0, longCol); + leaf.set(1, doubleCol); + nested.set(0, leaf); + record.set(0, nested); + return record; + } + + @Test + public void testPrimitiveColumns() throws Exception { + createPrimitiveTable(); + + Object[] binaryCol = + row( + 59L, + 4L, + 2L, + null, + Base64.getDecoder().decode("1111"), + Base64.getDecoder().decode("2222")); + Object[] booleanCol = row(44L, 4L, 0L, null, false, true); + Object[] decimalCol = row(97L, 4L, 1L, null, new BigDecimal("1.00"), new BigDecimal("2.00")); + Object[] doubleCol = row(99L, 4L, 0L, 1L, 1.0D, 2.0D); + Object[] fixedCol = + row( + 55L, + 4L, + 2L, + null, + Base64.getDecoder().decode("1111"), + Base64.getDecoder().decode("2222")); + Object[] floatCol = row(90L, 4L, 0L, 2L, 0f, 0f); + Object[] intCol = row(91L, 4L, 0L, null, 1, 2); + Object[] longCol = row(91L, 4L, 0L, null, 1L, 2L); + Object[] stringCol = row(99L, 4L, 0L, null, "1", "2"); + + Object[] metrics = + row( + binaryCol, + booleanCol, + decimalCol, + doubleCol, + fixedCol, + floatCol, + intCol, + longCol, + stringCol); + + assertEquals( + "Row should match", + ImmutableList.of(new Object[] {metrics}), + sql("SELECT readable_metrics FROM %s.files", tableName)); + } + + @Test + public void testSelectPrimitiveValues() throws Exception { + createPrimitiveTable(); + + assertEquals( + "select of primitive readable_metrics fields should work", + ImmutableList.of(row(1, true)), + sql( + "SELECT readable_metrics.intCol.lower_bound, readable_metrics.booleanCol.upper_bound FROM %s.files", + tableName)); + + assertEquals( + "mixed select of readable_metrics and other field should work", + ImmutableList.of(row(0, 4L)), + sql("SELECT content, readable_metrics.longCol.value_count FROM %s.files", tableName)); + + assertEquals( + "mixed select of readable_metrics and other field should 
work, in the other order", + ImmutableList.of(row(4L, 0)), + sql("SELECT readable_metrics.longCol.value_count, content FROM %s.files", tableName)); + } + + @Test + public void testSelectNestedValues() throws Exception { + createNestedTable(); + + assertEquals( + "select of nested readable_metrics fields should work", + ImmutableList.of(row(0L, 3L)), + sql( + "SELECT readable_metrics.`nestedStructCol.leafStructCol.leafLongCol`.lower_bound, " + + "readable_metrics.`nestedStructCol.leafStructCol.leafDoubleCol`.value_count FROM %s.files", + tableName)); + } + + @Test + public void testNestedValues() throws Exception { + createNestedTable(); + + Object[] leafDoubleCol = row(53L, 3L, 1L, 1L, 0.0D, 0.0D); + Object[] leafLongCol = row(54L, 3L, 1L, null, 0L, 1L); + Object[] metrics = row(leafDoubleCol, leafLongCol); + + assertEquals( + "Row should match", + ImmutableList.of(new Object[] {metrics}), + sql("SELECT readable_metrics FROM %s.files", tableName)); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestMetadataTablesWithPartitionEvolution.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestMetadataTablesWithPartitionEvolution.java new file mode 100644 index 000000000000..82c9a58e33ea --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestMetadataTablesWithPartitionEvolution.java @@ -0,0 +1,759 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
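For reference, the readable_metrics column exercised above is an ordinary nested struct on the files metadata table, so it can be mixed freely with other file columns. A small sketch (the catalog, the db.table name, and the intCol field come from the test fixtures above and are placeholders) of what such a query might look like outside the test harness:

// Sketch: per-file, human-readable column statistics from the files metadata table.
spark
    .sql(
        "SELECT file_path, "
            + "readable_metrics.intCol.lower_bound, "
            + "readable_metrics.intCol.upper_bound, "
            + "readable_metrics.intCol.null_value_count "
            + "FROM db.table.files")
    .show(false);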
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.FileFormat.AVRO; +import static org.apache.iceberg.FileFormat.ORC; +import static org.apache.iceberg.FileFormat.PARQUET; +import static org.apache.iceberg.MetadataTableType.ALL_DATA_FILES; +import static org.apache.iceberg.MetadataTableType.ALL_ENTRIES; +import static org.apache.iceberg.MetadataTableType.ENTRIES; +import static org.apache.iceberg.MetadataTableType.FILES; +import static org.apache.iceberg.MetadataTableType.PARTITIONS; +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.FORMAT_VERSION; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.PartitionSpecParser; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.StructType; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +public class TestMetadataTablesWithPartitionEvolution extends SparkCatalogTestBase { + + @Parameters(name = "catalog = {0}, impl = {1}, conf = {2}, fileFormat = {3}, formatVersion = {4}") + public static Object[][] parameters() { + return new Object[][] { + { + "testhive", + SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default"), + ORC, + 1 + }, + { + "testhive", + SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default"), + ORC, + 2 + }, + {"testhadoop", SparkCatalog.class.getName(), ImmutableMap.of("type", "hadoop"), PARQUET, 1}, + {"testhadoop", SparkCatalog.class.getName(), ImmutableMap.of("type", "hadoop"), PARQUET, 2}, + { + "spark_catalog", + SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "clients", "1", + "parquet-enabled", "false", + "cache-enabled", + "false" // Spark will delete tables using v1, leaving the cache out of sync + ), + AVRO, + 1 + }, + { + "spark_catalog", + SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "clients", "1", + "parquet-enabled", "false", + "cache-enabled", + "false" // Spark will delete tables using v1, leaving the cache out of sync + ), + AVRO, + 2 + } + }; + } + + private final FileFormat fileFormat; + private final int formatVersion; + + public TestMetadataTablesWithPartitionEvolution( + String catalogName, + String implementation, + Map config, + 
FileFormat fileFormat, + int formatVersion) { + super(catalogName, implementation, config); + this.fileFormat = fileFormat; + this.formatVersion = formatVersion; + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testFilesMetadataTable() throws ParseException { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg", + tableName); + initTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables while the current spec is still unpartitioned + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + Dataset df = loadMetadataTable(tableType); + Assert.assertTrue( + "Partition must be skipped", df.schema().getFieldIndex("partition").isEmpty()); + } + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec().addField("data").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after adding the first partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(new Object[] {null}), row("b1")), "STRUCT", tableType); + } + + table.updateSpec().addField(Expressions.bucket("category", 8)).commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after adding the second partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(null, null), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + + table.updateSpec().removeField("data").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after dropping the first partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(null, null), row(null, 2), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + + table.updateSpec().renameField("category_bucket_8", "category_bucket_8_another_name").commit(); + sql("REFRESH TABLE %s", tableName); + + // verify the metadata tables after renaming the second partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(null, null), row(null, 2), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + } + + @Test + public void testFilesMetadataTableFilter() throws ParseException { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg " + + "TBLPROPERTIES ('commit.manifest-merge.enabled' 'false')", + tableName); + initTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // verify the metadata tables while the current spec is still unpartitioned + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + Dataset df = loadMetadataTable(tableType); + Assert.assertTrue( + "Partition must be skipped", df.schema().getFieldIndex("partition").isEmpty()); + } + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec().addField("data").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE 
%s VALUES (2, 'c2', 'd2')", tableName); + + // verify the metadata tables after adding the first partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row("d2")), "STRUCT", tableType, "partition.data = 'd2'"); + } + + table.updateSpec().addField("category").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // verify the metadata tables after adding the second partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row("d2", null), row("d2", "c2")), + "STRUCT", + tableType, + "partition.data = 'd2'"); + } + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row("d2", "c2")), + "STRUCT", + tableType, + "partition.category = 'c2'"); + } + + table.updateSpec().removeField("data").commit(); + sql("REFRESH TABLE %s", tableName); + + // Verify new partitions do not show up for removed 'partition.data=d2' query + sql("INSERT INTO TABLE %s VALUES (3, 'c3', 'd2')", tableName); + sql("INSERT INTO TABLE %s VALUES (4, 'c4', 'd2')", tableName); + + // Verify new partitions do show up for 'partition.category=c2' query + sql("INSERT INTO TABLE %s VALUES (5, 'c2', 'd5')", tableName); + + // no new partition should show up for 'data' partition query as partition field has been + // removed + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row("d2", null), row("d2", "c2")), + "STRUCT", + tableType, + "partition.data = 'd2'"); + } + // new partition shows up from 'category' partition field query + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(null, "c2"), row("d2", "c2")), + "STRUCT", + tableType, + "partition.category = 'c2'"); + } + + table.updateSpec().renameField("category", "category_another_name").commit(); + sql("REFRESH TABLE %s", tableName); + + // Verify new partitions do show up for 'category=c2' query + sql("INSERT INTO TABLE %s VALUES (6, 'c2', 'd6')", tableName); + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(null, "c2"), row(null, "c2"), row("d2", "c2")), + "STRUCT", + tableType, + "partition.category_another_name = 'c2'"); + } + } + + @Test + public void testEntriesMetadataTable() throws ParseException { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg", + tableName); + initTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables while the current spec is still unpartitioned + for (MetadataTableType tableType : Arrays.asList(ENTRIES, ALL_ENTRIES)) { + Dataset df = loadMetadataTable(tableType); + StructType dataFileType = (StructType) df.schema().apply("data_file").dataType(); + Assert.assertTrue("Partition must be skipped", dataFileType.getFieldIndex("").isEmpty()); + } + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec().addField("data").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after adding the first partition column + for (MetadataTableType tableType : Arrays.asList(ENTRIES, ALL_ENTRIES)) { + assertPartitions( + 
ImmutableList.of(row(new Object[] {null}), row("b1")), "STRUCT", tableType); + } + + table.updateSpec().addField(Expressions.bucket("category", 8)).commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after adding the second partition column + for (MetadataTableType tableType : Arrays.asList(ENTRIES, ALL_ENTRIES)) { + assertPartitions( + ImmutableList.of(row(null, null), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + + table.updateSpec().removeField("data").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after dropping the first partition column + for (MetadataTableType tableType : Arrays.asList(ENTRIES, ALL_ENTRIES)) { + assertPartitions( + ImmutableList.of(row(null, null), row(null, 2), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + + table.updateSpec().renameField("category_bucket_8", "category_bucket_8_another_name").commit(); + sql("REFRESH TABLE %s", tableName); + + // verify the metadata tables after renaming the second partition column + for (MetadataTableType tableType : Arrays.asList(ENTRIES, ALL_ENTRIES)) { + assertPartitions( + ImmutableList.of(row(null, null), row(null, 2), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + } + + @Test + public void testPartitionsTableAddRemoveFields() throws ParseException { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg ", + tableName); + initTable(); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // verify the metadata tables while the current spec is still unpartitioned + Dataset df = loadMetadataTable(PARTITIONS); + Assert.assertTrue( + "Partition must be skipped", df.schema().getFieldIndex("partition").isEmpty()); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec().addField("data").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // verify the metadata tables after adding the first partition column + assertPartitions( + ImmutableList.of(row(new Object[] {null}), row("d1"), row("d2")), + "STRUCT", + PARTITIONS); + + table.updateSpec().addField("category").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // verify the metadata tables after adding the second partition column + assertPartitions( + ImmutableList.of( + row(null, null), row("d1", null), row("d1", "c1"), row("d2", null), row("d2", "c2")), + "STRUCT", + PARTITIONS); + + // verify the metadata tables after removing the first partition column + table.updateSpec().removeField("data").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of( + row(null, null), + row(null, "c1"), + row(null, "c2"), + row("d1", null), + row("d1", "c1"), + row("d2", null), + row("d2", "c2")), + "STRUCT", + PARTITIONS); + } + + @Test + public void testPartitionsTableRenameFields() throws ParseException { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING 
iceberg", + tableName); + initTable(); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec().addField("data").addField("category").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of(row("d1", "c1"), row("d2", "c2")), + "STRUCT", + PARTITIONS); + + table.updateSpec().renameField("category", "category_another_name").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of(row("d1", "c1"), row("d2", "c2")), + "STRUCT", + PARTITIONS); + } + + @Test + public void testPartitionsTableSwitchFields() throws Exception { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg", + tableName); + initTable(); + Table table = validationCatalog.loadTable(tableIdent); + + // verify the metadata tables after re-adding the first dropped column in the second location + table.updateSpec().addField("data").addField("category").commit(); + sql("REFRESH TABLE %s", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of(row("d1", "c1"), row("d2", "c2")), + "STRUCT", + PARTITIONS); + + table.updateSpec().removeField("data").commit(); + sql("REFRESH TABLE %s", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of(row(null, "c1"), row(null, "c2"), row("d1", "c1"), row("d2", "c2")), + "STRUCT", + PARTITIONS); + + table.updateSpec().addField("data").commit(); + sql("REFRESH TABLE %s", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + sql("INSERT INTO TABLE %s VALUES (3, 'c3', 'd3')", tableName); + + if (formatVersion == 1) { + assertPartitions( + ImmutableList.of( + row(null, "c1", null), + row(null, "c1", "d1"), + row(null, "c2", null), + row(null, "c2", "d2"), + row(null, "c3", "d3"), + row("d1", "c1", null), + row("d2", "c2", null)), + "STRUCT", + PARTITIONS); + } else { + // In V2 re-adding a former partition field that was part of an older spec will not change its + // name or its + // field ID either, thus values will be collapsed into a single common column (as opposed to + // V1 where any new + // partition field addition will result in a new column in this metadata table) + assertPartitions( + ImmutableList.of( + row(null, "c1"), row(null, "c2"), row("d1", "c1"), row("d2", "c2"), row("d3", "c3")), + "STRUCT", + PARTITIONS); + } + } + + @Test + public void testPartitionTableFilterAddRemoveFields() throws ParseException { + // Create un-partitioned table + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg", + tableName); + initTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // Partition Table with one partition column + Table table = validationCatalog.loadTable(tableIdent); + table.updateSpec().addField("data").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO 
TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of(row("d2")), "STRUCT", PARTITIONS, "partition.data = 'd2'"); + + // Partition Table with two partition column + table.updateSpec().addField("category").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of(row("d2", null), row("d2", "c2")), + "STRUCT", + PARTITIONS, + "partition.data = 'd2'"); + assertPartitions( + ImmutableList.of(row("d2", "c2")), + "STRUCT", + PARTITIONS, + "partition.category = 'c2'"); + + // Partition Table with first partition column removed + table.updateSpec().removeField("data").commit(); + sql("REFRESH TABLE %s", tableName); + + sql("INSERT INTO TABLE %s VALUES (3, 'c3', 'd2')", tableName); + sql("INSERT INTO TABLE %s VALUES (4, 'c4', 'd2')", tableName); + sql("INSERT INTO TABLE %s VALUES (5, 'c2', 'd5')", tableName); + assertPartitions( + ImmutableList.of(row("d2", null), row("d2", "c2")), + "STRUCT", + PARTITIONS, + "partition.data = 'd2'"); + assertPartitions( + ImmutableList.of(row(null, "c2"), row("d2", "c2")), + "STRUCT", + PARTITIONS, + "partition.category = 'c2'"); + } + + @Test + public void testPartitionTableFilterSwitchFields() throws Exception { + // Re-added partition fields currently not re-associated: + // https://github.com/apache/iceberg/issues/4292 + // In V1, dropped partition fields show separately when field is re-added + // In V2, re-added field currently conflicts with its deleted form + Assume.assumeTrue(formatVersion == 1); + + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg", + tableName); + initTable(); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + Table table = validationCatalog.loadTable(tableIdent); + + // Two partition columns + table.updateSpec().addField("data").addField("category").commit(); + sql("REFRESH TABLE %s", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // Drop first partition column + table.updateSpec().removeField("data").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + // Re-add first partition column at the end + table.updateSpec().addField("data").commit(); + sql("REFRESH TABLE %s", tableName); + + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of(row(null, "c2", null), row(null, "c2", "d2"), row("d2", "c2", null)), + "STRUCT", + PARTITIONS, + "partition.category = 'c2'"); + + assertPartitions( + ImmutableList.of(row(null, "c1", "d1")), + "STRUCT", + PARTITIONS, + "partition.data = 'd1'"); + } + + @Test + public void testPartitionsTableFilterRenameFields() throws ParseException { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg", + tableName); + initTable(); + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec().addField("data").addField("category").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + 
table.updateSpec().renameField("category", "category_another_name").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'c1', 'd1')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'c2', 'd2')", tableName); + + assertPartitions( + ImmutableList.of(row("d1", "c1")), + "STRUCT", + PARTITIONS, + "partition.category_another_name = 'c1'"); + } + + @Test + public void testMetadataTablesWithUnknownTransforms() { + sql( + "CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg", + tableName); + initTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + PartitionSpec unknownSpec = + PartitionSpecParser.fromJson( + table.schema(), + "{ \"spec-id\": 1, \"fields\": [ { \"name\": \"id_zero\", \"transform\": \"zero\", \"source-id\": 1 } ] }"); + + // replace the table spec to include an unknown transform + TableOperations ops = ((HasTableOperations) table).operations(); + TableMetadata base = ops.current(); + ops.commit(base, base.updatePartitionSpec(unknownSpec)); + + sql("REFRESH TABLE %s", tableName); + + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES, ENTRIES, ALL_ENTRIES)) { + AssertHelpers.assertThrows( + "Should complain about the partition type", + ValidationException.class, + "Cannot build table partition type, unknown transforms", + () -> loadMetadataTable(tableType)); + } + } + + @Test + public void testPartitionColumnNamedPartition() { + sql( + "CREATE TABLE %s (id int, partition int) USING iceberg PARTITIONED BY (partition)", + tableName); + sql("INSERT INTO %s VALUES (1, 1), (2, 1), (3, 2), (2, 2)", tableName); + List expected = ImmutableList.of(row(1, 1), row(2, 1), row(3, 2), row(2, 2)); + assertEquals("Should return all expected rows", expected, sql("SELECT * FROM %s", tableName)); + Assert.assertEquals(2, sql("SELECT * FROM %s.files", tableName).size()); + } + + private void assertPartitions( + List expectedPartitions, String expectedTypeAsString, MetadataTableType tableType) + throws ParseException { + assertPartitions(expectedPartitions, expectedTypeAsString, tableType, null); + } + + private void assertPartitions( + List expectedPartitions, + String expectedTypeAsString, + MetadataTableType tableType, + String filter) + throws ParseException { + Dataset df = loadMetadataTable(tableType); + if (filter != null) { + df = df.filter(filter); + } + + DataType expectedType = spark.sessionState().sqlParser().parseDataType(expectedTypeAsString); + switch (tableType) { + case PARTITIONS: + case FILES: + case ALL_DATA_FILES: + DataType actualFilesType = df.schema().apply("partition").dataType(); + Assert.assertEquals("Partition type must match", expectedType, actualFilesType); + break; + + case ENTRIES: + case ALL_ENTRIES: + StructType dataFileType = (StructType) df.schema().apply("data_file").dataType(); + DataType actualEntriesType = dataFileType.apply("partition").dataType(); + Assert.assertEquals("Partition type must match", expectedType, actualEntriesType); + break; + + default: + throw new UnsupportedOperationException("Unsupported metadata table type: " + tableType); + } + + switch (tableType) { + case PARTITIONS: + case FILES: + case ALL_DATA_FILES: + List actualFilesPartitions = + df.orderBy("partition").select("partition.*").collectAsList(); + assertEquals( + "Partitions must match", expectedPartitions, rowsToJava(actualFilesPartitions)); + break; + + case ENTRIES: + case ALL_ENTRIES: + List 
actualEntriesPartitions = + df.orderBy("data_file.partition").select("data_file.partition.*").collectAsList(); + assertEquals( + "Partitions must match", expectedPartitions, rowsToJava(actualEntriesPartitions)); + break; + + default: + throw new UnsupportedOperationException("Unsupported metadata table type: " + tableType); + } + } + + private Dataset loadMetadataTable(MetadataTableType tableType) { + return spark.read().format("iceberg").load(tableName + "." + tableType.name()); + } + + private void initTable() { + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, DEFAULT_FILE_FORMAT, fileFormat.name()); + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", tableName, FORMAT_VERSION, formatVersion); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetScan.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetScan.java new file mode 100644 index 000000000000..f585ed360f95 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetScan.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.Files.localOutput; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.UUID; +import org.apache.avro.generic.GenericData; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.spark.data.AvroDataTest; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class TestParquetScan extends AvroDataTest { + private static final Configuration CONF = new Configuration(); + + private static SparkSession spark = null; + + @BeforeClass + public static void startSpark() { + TestParquetScan.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestParquetScan.spark; + TestParquetScan.spark = null; + currentSpark.stop(); + } + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + @Parameterized.Parameters(name = "vectorized = {0}") + public static Object[] parameters() { + return new Object[] {false, true}; + } + + private final boolean vectorized; + + public TestParquetScan(boolean vectorized) { + this.vectorized = vectorized; + } + + @Override + protected void writeAndValidate(Schema schema) throws IOException { + Assume.assumeTrue( + "Cannot handle non-string map keys in parquet-avro", + null + == TypeUtil.find( + schema, + type -> type.isMapType() && type.asMapType().keyType() != Types.StringType.get())); + + File parent = temp.newFolder("parquet"); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + dataFolder.mkdirs(); + + File parquetFile = + new File(dataFolder, FileFormat.PARQUET.addExtension(UUID.randomUUID().toString())); + + HadoopTables tables = new HadoopTables(CONF); + Table table = tables.create(schema, PartitionSpec.unpartitioned(), location.toString()); + + // Important: use the table's schema for the rest of the test + // When tables are created, the column ids are reassigned. 
+ Schema tableSchema = table.schema(); + + List expected = RandomData.generateList(tableSchema, 100, 1L); + + try (FileAppender writer = + Parquet.write(localOutput(parquetFile)).schema(tableSchema).build()) { + writer.addAll(expected); + } + + DataFile file = + DataFiles.builder(PartitionSpec.unpartitioned()) + .withFileSizeInBytes(parquetFile.length()) + .withPath(parquetFile.toString()) + .withRecordCount(100) + .build(); + + table.newAppend().appendFile(file).commit(); + table + .updateProperties() + .set(TableProperties.PARQUET_VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .commit(); + + Dataset df = spark.read().format("iceberg").load(location.toString()); + + List rows = df.collectAsList(); + Assert.assertEquals("Should contain 100 rows", 100, rows.size()); + + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe(tableSchema.asStruct(), expected.get(i), rows.get(i)); + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionPruning.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionPruning.java new file mode 100644 index 000000000000..4ef022c50c59 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionPruning.java @@ -0,0 +1,465 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.io.IOException; +import java.net.URI; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.fs.RawLocalFileSystem; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.expressions.Literal; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.transforms.Transforms; +import org.apache.iceberg.types.Types; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.unsafe.types.UTF8String; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class TestPartitionPruning { + + private static final Configuration CONF = new Configuration(); + private static final HadoopTables TABLES = new HadoopTables(CONF); + + @Parameterized.Parameters(name = "format = {0}, vectorized = {1}") + public static Object[][] parameters() { + return new Object[][] { + {"parquet", false}, + {"parquet", true}, + {"avro", false}, + {"orc", false}, + {"orc", true} + }; + } + + private final String format; + private final boolean vectorized; + + public TestPartitionPruning(String format, boolean vectorized) { + this.format = format; + this.vectorized = vectorized; + } + + private static SparkSession spark = null; + private static JavaSparkContext sparkContext = null; + + private static final Function BUCKET_FUNC = + Transforms.bucket(3).bind(Types.IntegerType.get()); + private static final Function TRUNCATE_FUNC = + Transforms.truncate(5).bind(Types.StringType.get()); + private static final Function HOUR_FUNC = + Transforms.hour().bind(Types.TimestampType.withoutZone()); + + @BeforeClass + public static void startSpark() { + TestPartitionPruning.spark = SparkSession.builder().master("local[2]").getOrCreate(); + TestPartitionPruning.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + + String optionKey = String.format("fs.%s.impl", CountOpenLocalFileSystem.scheme); + CONF.set(optionKey, CountOpenLocalFileSystem.class.getName()); + spark.conf().set(optionKey, 
CountOpenLocalFileSystem.class.getName()); + spark.conf().set("spark.sql.session.timeZone", "UTC"); + spark.udf().register("bucket3", (Integer num) -> BUCKET_FUNC.apply(num), DataTypes.IntegerType); + spark + .udf() + .register("truncate5", (String str) -> TRUNCATE_FUNC.apply(str), DataTypes.StringType); + // NOTE: date transforms take the type long, not Timestamp + spark + .udf() + .register( + "hour", + (Timestamp ts) -> + HOUR_FUNC.apply( + org.apache.spark.sql.catalyst.util.DateTimeUtils.fromJavaTimestamp(ts)), + DataTypes.IntegerType); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestPartitionPruning.spark; + TestPartitionPruning.spark = null; + currentSpark.stop(); + } + + private static final Schema LOG_SCHEMA = + new Schema( + Types.NestedField.optional(1, "id", Types.IntegerType.get()), + Types.NestedField.optional(2, "date", Types.StringType.get()), + Types.NestedField.optional(3, "level", Types.StringType.get()), + Types.NestedField.optional(4, "message", Types.StringType.get()), + Types.NestedField.optional(5, "timestamp", Types.TimestampType.withZone())); + + private static final List LOGS = + ImmutableList.of( + LogMessage.debug("2020-02-02", "debug event 1", getInstant("2020-02-02T00:00:00")), + LogMessage.info("2020-02-02", "info event 1", getInstant("2020-02-02T01:00:00")), + LogMessage.debug("2020-02-02", "debug event 2", getInstant("2020-02-02T02:00:00")), + LogMessage.info("2020-02-03", "info event 2", getInstant("2020-02-03T00:00:00")), + LogMessage.debug("2020-02-03", "debug event 3", getInstant("2020-02-03T01:00:00")), + LogMessage.info("2020-02-03", "info event 3", getInstant("2020-02-03T02:00:00")), + LogMessage.error("2020-02-03", "error event 1", getInstant("2020-02-03T03:00:00")), + LogMessage.debug("2020-02-04", "debug event 4", getInstant("2020-02-04T01:00:00")), + LogMessage.warn("2020-02-04", "warn event 1", getInstant("2020-02-04T02:00:00")), + LogMessage.debug("2020-02-04", "debug event 5", getInstant("2020-02-04T03:00:00"))); + + private static Instant getInstant(String timestampWithoutZone) { + Long epochMicros = + (Long) Literal.of(timestampWithoutZone).to(Types.TimestampType.withoutZone()).value(); + return Instant.ofEpochMilli(TimeUnit.MICROSECONDS.toMillis(epochMicros)); + } + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private PartitionSpec spec = + PartitionSpec.builderFor(LOG_SCHEMA) + .identity("date") + .identity("level") + .bucket("id", 3) + .truncate("message", 5) + .hour("timestamp") + .build(); + + @Test + public void testPartitionPruningIdentityString() { + String filterCond = "date >= '2020-02-03' AND level = 'DEBUG'"; + Predicate partCondition = + (Row r) -> { + String date = r.getString(0); + String level = r.getString(1); + return date.compareTo("2020-02-03") >= 0 && level.equals("DEBUG"); + }; + + runTest(filterCond, partCondition); + } + + @Test + public void testPartitionPruningBucketingInteger() { + final int[] ids = new int[] {LOGS.get(3).getId(), LOGS.get(7).getId()}; + String condForIds = + Arrays.stream(ids).mapToObj(String::valueOf).collect(Collectors.joining(",", "(", ")")); + String filterCond = "id in " + condForIds; + Predicate partCondition = + (Row r) -> { + int bucketId = r.getInt(2); + Set buckets = + Arrays.stream(ids).map(BUCKET_FUNC::apply).boxed().collect(Collectors.toSet()); + return buckets.contains(bucketId); + }; + + runTest(filterCond, partCondition); + } + + @Test + public void testPartitionPruningTruncatedString() { + String filterCond = "message 
like 'info event%'"; + Predicate partCondition = + (Row r) -> { + String truncatedMessage = r.getString(3); + return truncatedMessage.equals("info "); + }; + + runTest(filterCond, partCondition); + } + + @Test + public void testPartitionPruningTruncatedStringComparingValueShorterThanPartitionValue() { + String filterCond = "message like 'inf%'"; + Predicate partCondition = + (Row r) -> { + String truncatedMessage = r.getString(3); + return truncatedMessage.startsWith("inf"); + }; + + runTest(filterCond, partCondition); + } + + @Test + public void testPartitionPruningHourlyPartition() { + String filterCond; + if (spark.version().startsWith("2")) { + // Looks like from Spark 2 we need to compare timestamp with timestamp to push down the + // filter. + filterCond = "timestamp >= to_timestamp('2020-02-03T01:00:00')"; + } else { + filterCond = "timestamp >= '2020-02-03T01:00:00'"; + } + Predicate partCondition = + (Row r) -> { + int hourValue = r.getInt(4); + Instant instant = getInstant("2020-02-03T01:00:00"); + Integer hourValueToFilter = + HOUR_FUNC.apply(TimeUnit.MILLISECONDS.toMicros(instant.toEpochMilli())); + return hourValue >= hourValueToFilter; + }; + + runTest(filterCond, partCondition); + } + + private void runTest(String filterCond, Predicate partCondition) { + File originTableLocation = createTempDir(); + Assert.assertTrue("Temp folder should exist", originTableLocation.exists()); + + Table table = createTable(originTableLocation); + Dataset logs = createTestDataset(); + saveTestDatasetToTable(logs, table); + + List expected = + logs.select("id", "date", "level", "message", "timestamp") + .filter(filterCond) + .orderBy("id") + .collectAsList(); + Assert.assertFalse("Expected rows should be not empty", expected.isEmpty()); + + // remove records which may be recorded during storing to table + CountOpenLocalFileSystem.resetRecordsInPathPrefix(originTableLocation.getAbsolutePath()); + + List actual = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(table.location()) + .select("id", "date", "level", "message", "timestamp") + .filter(filterCond) + .orderBy("id") + .collectAsList(); + Assert.assertFalse("Actual rows should not be empty", actual.isEmpty()); + + Assert.assertEquals("Rows should match", expected, actual); + + assertAccessOnDataFiles(originTableLocation, table, partCondition); + } + + private File createTempDir() { + try { + return temp.newFolder(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private Table createTable(File originTableLocation) { + String trackedTableLocation = CountOpenLocalFileSystem.convertPath(originTableLocation); + Map properties = ImmutableMap.of(TableProperties.DEFAULT_FILE_FORMAT, format); + return TABLES.create(LOG_SCHEMA, spec, properties, trackedTableLocation); + } + + private Dataset createTestDataset() { + List rows = + LOGS.stream() + .map( + logMessage -> { + Object[] underlying = + new Object[] { + logMessage.getId(), + UTF8String.fromString(logMessage.getDate()), + UTF8String.fromString(logMessage.getLevel()), + UTF8String.fromString(logMessage.getMessage()), + // discard the nanoseconds part to simplify + TimeUnit.MILLISECONDS.toMicros(logMessage.getTimestamp().toEpochMilli()) + }; + return new GenericInternalRow(underlying); + }) + .collect(Collectors.toList()); + + JavaRDD rdd = sparkContext.parallelize(rows); + Dataset df = + spark.internalCreateDataFrame( + JavaRDD.toRDD(rdd), SparkSchemaUtil.convert(LOG_SCHEMA), false); + + return 
df.selectExpr("id", "date", "level", "message", "timestamp") + .selectExpr( + "id", + "date", + "level", + "message", + "timestamp", + "bucket3(id) AS bucket_id", + "truncate5(message) AS truncated_message", + "hour(timestamp) AS ts_hour"); + } + + private void saveTestDatasetToTable(Dataset logs, Table table) { + logs.orderBy("date", "level", "bucket_id", "truncated_message", "ts_hour") + .select("id", "date", "level", "message", "timestamp") + .write() + .format("iceberg") + .mode("append") + .save(table.location()); + } + + private void assertAccessOnDataFiles( + File originTableLocation, Table table, Predicate partCondition) { + // only use files in current table location to avoid side-effects on concurrent test runs + Set readFilesInQuery = + CountOpenLocalFileSystem.pathToNumOpenCalled.keySet().stream() + .filter(path -> path.startsWith(originTableLocation.getAbsolutePath())) + .collect(Collectors.toSet()); + + List files = + spark.read().format("iceberg").load(table.location() + "#files").collectAsList(); + + Set filesToRead = extractFilePathsMatchingConditionOnPartition(files, partCondition); + Set filesToNotRead = extractFilePathsNotIn(files, filesToRead); + + // Just to be sure, they should be mutually exclusive. + Assert.assertTrue(Sets.intersection(filesToRead, filesToNotRead).isEmpty()); + + Assert.assertFalse("The query should prune some data files.", filesToNotRead.isEmpty()); + + // We don't check "all" data files bound to the condition are being read, as data files can be + // pruned on + // other conditions like lower/upper bound of columns. + Assert.assertFalse( + "Some of data files in partition range should be read. " + + "Read files in query: " + + readFilesInQuery + + " / data files in partition range: " + + filesToRead, + Sets.intersection(filesToRead, readFilesInQuery).isEmpty()); + + // Data files which aren't bound to the condition shouldn't be read. + Assert.assertTrue( + "Data files outside of partition range should not be read. 
" + + "Read files in query: " + + readFilesInQuery + + " / data files outside of partition range: " + + filesToNotRead, + Sets.intersection(filesToNotRead, readFilesInQuery).isEmpty()); + } + + private Set extractFilePathsMatchingConditionOnPartition( + List files, Predicate condition) { + // idx 1: file_path, idx 3: partition + return files.stream() + .filter( + r -> { + Row partition = r.getStruct(4); + return condition.test(partition); + }) + .map(r -> CountOpenLocalFileSystem.stripScheme(r.getString(1))) + .collect(Collectors.toSet()); + } + + private Set extractFilePathsNotIn(List files, Set filePaths) { + Set allFilePaths = + files.stream() + .map(r -> CountOpenLocalFileSystem.stripScheme(r.getString(1))) + .collect(Collectors.toSet()); + return Sets.newHashSet(Sets.symmetricDifference(allFilePaths, filePaths)); + } + + public static class CountOpenLocalFileSystem extends RawLocalFileSystem { + public static String scheme = + String.format("TestIdentityPartitionData%dfs", new Random().nextInt()); + public static Map pathToNumOpenCalled = Maps.newConcurrentMap(); + + public static String convertPath(String absPath) { + return scheme + "://" + absPath; + } + + public static String convertPath(File file) { + return convertPath(file.getAbsolutePath()); + } + + public static String stripScheme(String pathWithScheme) { + if (!pathWithScheme.startsWith(scheme + ":")) { + throw new IllegalArgumentException("Received unexpected path: " + pathWithScheme); + } + + int idxToCut = scheme.length() + 1; + while (pathWithScheme.charAt(idxToCut) == '/') { + idxToCut++; + } + + // leave the last '/' + idxToCut--; + + return pathWithScheme.substring(idxToCut); + } + + public static void resetRecordsInPathPrefix(String pathPrefix) { + pathToNumOpenCalled.keySet().stream() + .filter(p -> p.startsWith(pathPrefix)) + .forEach(key -> pathToNumOpenCalled.remove(key)); + } + + @Override + public URI getUri() { + return URI.create(scheme + ":///"); + } + + @Override + public String getScheme() { + return scheme; + } + + @Override + public FSDataInputStream open(Path f, int bufferSize) throws IOException { + String path = f.toUri().getPath(); + pathToNumOpenCalled.compute( + path, + (ignored, v) -> { + if (v == null) { + return 1L; + } else { + return v + 1; + } + }); + return super.open(f, bufferSize); + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues.java new file mode 100644 index 000000000000..ad0984ef4220 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues.java @@ -0,0 +1,504 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.Files; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.types.Types; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class TestPartitionValues { + @Parameterized.Parameters(name = "format = {0}, vectorized = {1}") + public static Object[][] parameters() { + return new Object[][] { + {"parquet", false}, + {"parquet", true}, + {"avro", false}, + {"orc", false}, + {"orc", true} + }; + } + + private static final Schema SUPPORTED_PRIMITIVES = + new Schema( + required(100, "id", Types.LongType.get()), + required(101, "data", Types.StringType.get()), + required(102, "b", Types.BooleanType.get()), + required(103, "i", Types.IntegerType.get()), + required(104, "l", Types.LongType.get()), + required(105, "f", Types.FloatType.get()), + required(106, "d", Types.DoubleType.get()), + required(107, "date", Types.DateType.get()), + required(108, "ts", Types.TimestampType.withZone()), + required(110, "s", Types.StringType.get()), + required(113, "bytes", Types.BinaryType.get()), + required(114, "dec_9_0", Types.DecimalType.of(9, 0)), + required(115, "dec_11_2", Types.DecimalType.of(11, 2)), + required(116, "dec_38_10", Types.DecimalType.of(38, 10)) // spark's maximum precision + ); + + private static final Schema SIMPLE_SCHEMA = + new Schema( + optional(1, "id", Types.IntegerType.get()), optional(2, "data", Types.StringType.get())); + + private static final PartitionSpec SPEC = + PartitionSpec.builderFor(SIMPLE_SCHEMA).identity("data").build(); + + private static SparkSession spark = null; + + @BeforeClass + public static void startSpark() { + TestPartitionValues.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestPartitionValues.spark; + TestPartitionValues.spark = null; + currentSpark.stop(); + } + + @Rule public TemporaryFolder temp = 
new TemporaryFolder();
+
+  private final String format;
+  private final boolean vectorized;
+
+  public TestPartitionValues(String format, boolean vectorized) {
+    this.format = format;
+    this.vectorized = vectorized;
+  }
+
+  @Test
+  public void testNullPartitionValue() throws Exception {
+    String desc = "null_part";
+    File parent = temp.newFolder(desc);
+    File location = new File(parent, "test");
+    File dataFolder = new File(location, "data");
+    Assert.assertTrue("mkdirs should succeed", dataFolder.mkdirs());
+
+    HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf());
+    Table table = tables.create(SIMPLE_SCHEMA, SPEC, location.toString());
+    table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit();
+
+    List<SimpleRecord> expected =
+        Lists.newArrayList(
+            new SimpleRecord(1, "a"),
+            new SimpleRecord(2, "b"),
+            new SimpleRecord(3, "c"),
+            new SimpleRecord(4, null));
+
+    Dataset<Row> df = spark.createDataFrame(expected, SimpleRecord.class);
+
+    df.select("id", "data")
+        .write()
+        .format("iceberg")
+        .mode(SaveMode.Append)
+        .save(location.toString());
+
+    Dataset<Row> result =
+        spark
+            .read()
+            .format("iceberg")
+            .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized))
+            .load(location.toString());
+
+    List<SimpleRecord> actual =
+        result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
+
+    Assert.assertEquals("Number of rows should match", expected.size(), actual.size());
+    Assert.assertEquals("Result rows should match", expected, actual);
+  }
+
+  @Test
+  public void testReorderedColumns() throws Exception {
+    String desc = "reorder_columns";
+    File parent = temp.newFolder(desc);
+    File location = new File(parent, "test");
+    File dataFolder = new File(location, "data");
+    Assert.assertTrue("mkdirs should succeed", dataFolder.mkdirs());
+
+    HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf());
+    Table table = tables.create(SIMPLE_SCHEMA, SPEC, location.toString());
+    table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit();
+
+    List<SimpleRecord> expected =
+        Lists.newArrayList(
+            new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c"));
+
+    Dataset<Row> df = spark.createDataFrame(expected, SimpleRecord.class);
+
+    df.select("data", "id")
+        .write()
+        .format("iceberg")
+        .mode(SaveMode.Append)
+        .option(SparkWriteOptions.CHECK_ORDERING, "false")
+        .save(location.toString());
+
+    Dataset<Row> result =
+        spark
+            .read()
+            .format("iceberg")
+            .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized))
+            .load(location.toString());
+
+    List<SimpleRecord> actual =
+        result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
+
+    Assert.assertEquals("Number of rows should match", expected.size(), actual.size());
+    Assert.assertEquals("Result rows should match", expected, actual);
+  }
+
+  @Test
+  public void testReorderedColumnsNoNullability() throws Exception {
+    String desc = "reorder_columns_no_nullability";
+    File parent = temp.newFolder(desc);
+    File location = new File(parent, "test");
+    File dataFolder = new File(location, "data");
+    Assert.assertTrue("mkdirs should succeed", dataFolder.mkdirs());
+
+    HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf());
+    Table table = tables.create(SIMPLE_SCHEMA, SPEC, location.toString());
+    table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit();
+
+    List<SimpleRecord> expected =
+        Lists.newArrayList(
+            new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c"));
+
+    Dataset<Row> df =
spark.createDataFrame(expected, SimpleRecord.class); + + df.select("data", "id") + .write() + .format("iceberg") + .mode(SaveMode.Append) + .option(SparkWriteOptions.CHECK_ORDERING, "false") + .option(SparkWriteOptions.CHECK_NULLABILITY, "false") + .save(location.toString()); + + Dataset result = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(location.toString()); + + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + } + + @Test + public void testPartitionValueTypes() throws Exception { + String[] columnNames = + new String[] { + "b", "i", "l", "f", "d", "date", "ts", "s", "bytes", "dec_9_0", "dec_11_2", "dec_38_10" + }; + + HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf()); + + // create a table around the source data + String sourceLocation = temp.newFolder("source_table").toString(); + Table source = tables.create(SUPPORTED_PRIMITIVES, sourceLocation); + + // write out an Avro data file with all of the data types for source data + List expected = RandomData.generateList(source.schema(), 2, 128735L); + File avroData = temp.newFile("data.avro"); + Assert.assertTrue(avroData.delete()); + try (FileAppender appender = + Avro.write(Files.localOutput(avroData)).schema(source.schema()).build()) { + appender.addAll(expected); + } + + // add the Avro data file to the source table + source + .newAppend() + .appendFile( + DataFiles.builder(PartitionSpec.unpartitioned()) + .withRecordCount(10) + .withInputFile(Files.localInput(avroData)) + .build()) + .commit(); + + Dataset sourceDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(sourceLocation); + + for (String column : columnNames) { + String desc = "partition_by_" + SUPPORTED_PRIMITIVES.findType(column).toString(); + + File parent = temp.newFolder(desc); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + Assert.assertTrue("mkdirs should succeed", dataFolder.mkdirs()); + + PartitionSpec spec = PartitionSpec.builderFor(SUPPORTED_PRIMITIVES).identity(column).build(); + + Table table = tables.create(SUPPORTED_PRIMITIVES, spec, location.toString()); + table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit(); + + sourceDF + .write() + .format("iceberg") + .mode(SaveMode.Append) + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false") + .save(location.toString()); + + List actual = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(location.toString()) + .collectAsList(); + + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe( + SUPPORTED_PRIMITIVES.asStruct(), expected.get(i), actual.get(i)); + } + } + } + + @Test + public void testNestedPartitionValues() throws Exception { + String[] columnNames = + new String[] { + "b", "i", "l", "f", "d", "date", "ts", "s", "bytes", "dec_9_0", "dec_11_2", "dec_38_10" + }; + + HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf()); + Schema nestedSchema = new Schema(optional(1, "nested", SUPPORTED_PRIMITIVES.asStruct())); + + // create a table 
around the source data + String sourceLocation = temp.newFolder("source_table").toString(); + Table source = tables.create(nestedSchema, sourceLocation); + + // write out an Avro data file with all of the data types for source data + List expected = RandomData.generateList(source.schema(), 2, 128735L); + File avroData = temp.newFile("data.avro"); + Assert.assertTrue(avroData.delete()); + try (FileAppender appender = + Avro.write(Files.localOutput(avroData)).schema(source.schema()).build()) { + appender.addAll(expected); + } + + // add the Avro data file to the source table + source + .newAppend() + .appendFile( + DataFiles.builder(PartitionSpec.unpartitioned()) + .withRecordCount(10) + .withInputFile(Files.localInput(avroData)) + .build()) + .commit(); + + Dataset sourceDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(sourceLocation); + + for (String column : columnNames) { + String desc = "partition_by_" + SUPPORTED_PRIMITIVES.findType(column).toString(); + + File parent = temp.newFolder(desc); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + Assert.assertTrue("mkdirs should succeed", dataFolder.mkdirs()); + + PartitionSpec spec = + PartitionSpec.builderFor(nestedSchema).identity("nested." + column).build(); + + Table table = tables.create(nestedSchema, spec, location.toString()); + table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit(); + + sourceDF + .write() + .format("iceberg") + .mode(SaveMode.Append) + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false") + .save(location.toString()); + + List actual = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(location.toString()) + .collectAsList(); + + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe(nestedSchema.asStruct(), expected.get(i), actual.get(i)); + } + } + } + + /** + * To verify if WrappedPositionAccessor is generated against a string field within a nested field, + * rather than a Position2Accessor. 
Or when building the partition path, a ClassCastException is + * thrown with the message like: Cannot cast org.apache.spark.unsafe.types.UTF8String to + * java.lang.CharSequence + */ + @Test + public void testPartitionedByNestedString() throws Exception { + // schema and partition spec + Schema nestedSchema = + new Schema( + Types.NestedField.required( + 1, + "struct", + Types.StructType.of( + Types.NestedField.required(2, "string", Types.StringType.get())))); + PartitionSpec spec = PartitionSpec.builderFor(nestedSchema).identity("struct.string").build(); + + // create table + HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf()); + String baseLocation = temp.newFolder("partition_by_nested_string").toString(); + tables.create(nestedSchema, spec, baseLocation); + + // input data frame + StructField[] structFields = { + new StructField( + "struct", + DataTypes.createStructType( + new StructField[] { + new StructField("string", DataTypes.StringType, false, Metadata.empty()) + }), + false, + Metadata.empty()) + }; + + List rows = Lists.newArrayList(); + rows.add(RowFactory.create(RowFactory.create("nested_string_value"))); + Dataset sourceDF = spark.createDataFrame(rows, new StructType(structFields)); + + // write into iceberg + sourceDF.write().format("iceberg").mode(SaveMode.Append).save(baseLocation); + + // verify + List actual = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(baseLocation) + .collectAsList(); + + Assert.assertEquals("Number of rows should match", rows.size(), actual.size()); + } + + @Test + public void testReadPartitionColumn() throws Exception { + Assume.assumeTrue("Temporary skip ORC", !"orc".equals(format)); + + Schema nestedSchema = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional( + 2, + "struct", + Types.StructType.of( + Types.NestedField.optional(3, "innerId", Types.LongType.get()), + Types.NestedField.optional(4, "innerName", Types.StringType.get())))); + PartitionSpec spec = + PartitionSpec.builderFor(nestedSchema).identity("struct.innerName").build(); + + // create table + HadoopTables tables = new HadoopTables(spark.sessionState().newHadoopConf()); + String baseLocation = temp.newFolder("partition_by_nested_string").toString(); + Table table = tables.create(nestedSchema, spec, baseLocation); + table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit(); + + // write into iceberg + MapFunction func = + value -> new ComplexRecord(value, new NestedRecord(value, "name_" + value)); + spark + .range(0, 10, 1, 1) + .map(func, Encoders.bean(ComplexRecord.class)) + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(baseLocation); + + List actual = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(baseLocation) + .select("struct.innerName") + .orderBy("struct.innerName") + .as(Encoders.STRING()) + .collectAsList(); + + Assert.assertEquals("Number of rows should match", 10, actual.size()); + + List inputRecords = + IntStream.range(0, 10).mapToObj(i -> "name_" + i).collect(Collectors.toList()); + Assert.assertEquals("Read object should be matched", inputRecords, actual); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestPathIdentifier.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestPathIdentifier.java new file mode 100644 index 
000000000000..5baf6071233d --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestPathIdentifier.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.Schema; +import org.apache.iceberg.hadoop.HadoopTableOperations; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.PathIdentifier; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.assertj.core.api.Assertions; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestPathIdentifier extends SparkTestBase { + + private static final Schema SCHEMA = + new Schema( + required(1, "id", Types.LongType.get()), required(2, "data", Types.StringType.get())); + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + private File tableLocation; + private PathIdentifier identifier; + private SparkCatalog sparkCatalog; + + @Before + public void before() throws IOException { + tableLocation = temp.newFolder(); + identifier = new PathIdentifier(tableLocation.getAbsolutePath()); + sparkCatalog = new SparkCatalog(); + sparkCatalog.initialize("test", new CaseInsensitiveStringMap(ImmutableMap.of())); + } + + @After + public void after() { + tableLocation.delete(); + sparkCatalog = null; + } + + @Test + public void testPathIdentifier() throws TableAlreadyExistsException, NoSuchTableException { + SparkTable table = + (SparkTable) + sparkCatalog.createTable( + identifier, SparkSchemaUtil.convert(SCHEMA), new Transform[0], ImmutableMap.of()); + + Assert.assertEquals(table.table().location(), tableLocation.getAbsolutePath()); + Assertions.assertThat(table.table()).isInstanceOf(BaseTable.class); + Assertions.assertThat(((BaseTable) table.table()).operations()) + .isInstanceOf(HadoopTableOperations.class); + + Assert.assertEquals(sparkCatalog.loadTable(identifier), table); + Assert.assertTrue(sparkCatalog.dropTable(identifier)); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestPositionDeletesTable.java 
b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestPositionDeletesTable.java new file mode 100644 index 000000000000..2ec4f2f4f907 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestPositionDeletesTable.java @@ -0,0 +1,1566 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.function.BiFunction; +import java.util.stream.Collectors; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Files; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.MetadataTableUtils; +import org.apache.iceberg.PartitionField; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.PositionDeletesScanTask; +import org.apache.iceberg.ScanTask; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.data.FileHelpers; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.deletes.PositionDelete; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.PositionDeletesRewriteCoordinator; +import org.apache.iceberg.spark.ScanTaskSetManager; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkStructLike; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.CharSequenceSet; +import org.apache.iceberg.util.Pair; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; 
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.functions; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class TestPositionDeletesTable extends SparkCatalogTestBase { + + public static final Schema SCHEMA = + new Schema( + Types.NestedField.required(1, "id", Types.IntegerType.get()), + Types.NestedField.required(2, "data", Types.StringType.get())); + private static final Map CATALOG_PROPS = + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "cache-enabled", "false"); + private static final List NON_PATH_COLS = + ImmutableList.of("file_path", "pos", "row", "partition", "spec_id"); + + private final FileFormat format; + + @Parameterized.Parameters( + name = + "formatVersion = {0}, catalogName = {1}, implementation = {2}, config = {3}, fileFormat = {4}") + public static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + CATALOG_PROPS, + FileFormat.PARQUET + }, + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + CATALOG_PROPS, + FileFormat.AVRO + }, + { + SparkCatalogConfig.HIVE.catalogName(), + SparkCatalogConfig.HIVE.implementation(), + CATALOG_PROPS, + FileFormat.ORC + }, + }; + } + + public TestPositionDeletesTable( + String catalogName, String implementation, Map config, FileFormat format) { + super(catalogName, implementation, config); + this.format = format; + } + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + @Test + public void testNullRows() throws IOException { + String tableName = "null_rows"; + Table tab = createTable(tableName, SCHEMA, PartitionSpec.unpartitioned()); + + DataFile dFile = dataFile(tab); + tab.newAppend().appendFile(dFile).commit(); + + List> deletes = Lists.newArrayList(); + deletes.add(Pair.of(dFile.path(), 0L)); + deletes.add(Pair.of(dFile.path(), 1L)); + Pair posDeletes = + FileHelpers.writeDeleteFile( + tab, Files.localOutput(temp.newFile()), TestHelpers.Row.of(0), deletes); + tab.newRowDelta().addDeletes(posDeletes.first()).commit(); + + StructLikeSet actual = actual(tableName, tab); + + List> expectedDeletes = + Lists.newArrayList(positionDelete(dFile.path(), 0L), positionDelete(dFile.path(), 1L)); + StructLikeSet expected = + expected(tab, expectedDeletes, null, posDeletes.first().path().toString()); + + Assert.assertEquals("Position Delete table should contain expected rows", expected, actual); + dropTable(tableName); + } + + @Test + public void testPartitionedTable() throws IOException { + // Create table with two partitions + String tableName = "partitioned_table"; + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table tab = createTable(tableName, SCHEMA, spec); + + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + // Add position deletes for both partitions + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b"); + + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Select deletes from one partition + StructLikeSet actual = actual(tableName, tab, "row.data='b'"); + GenericRecord partitionB = 
GenericRecord.create(tab.spec().partitionType());
+    partitionB.setField("data", "b");
+    StructLikeSet expected =
+        expected(tab, deletesB.first(), partitionB, deletesB.second().path().toString());
+
+    Assert.assertEquals("Position Delete table should contain expected rows", expected, actual);
+    dropTable(tableName);
+  }
+
+  @Test
+  public void testSelect() throws IOException {
+    // Create table with two partitions
+    String tableName = "select";
+    PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build();
+    Table tab = createTable(tableName, SCHEMA, spec);
+
+    DataFile dataFileA = dataFile(tab, "a");
+    DataFile dataFileB = dataFile(tab, "b");
+
+    tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit();
+
+    // Add position deletes for both partitions
+    Pair<List<PositionDelete<?>>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a");
+    Pair<List<PositionDelete<?>>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b");
+
+    tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit();
+
+    // Select certain columns
+    Dataset<Row> df =
+        spark
+            .read()
+            .format("iceberg")
+            .load("default." + tableName + ".position_deletes")
+            .withColumn("input_file", functions.input_file_name())
+            .select("row.id", "pos", "delete_file_path", "input_file");
+    List<Object[]> actual = rowsToJava(df.collectAsList());
+
+    // Select cols from expected delete values
+    List<Object[]> expected = Lists.newArrayList();
+    BiFunction<PositionDelete<?>, DeleteFile, Object[]> toRow =
+        (delete, file) -> {
+          int rowData = delete.get(2, GenericRecord.class).get(0, Integer.class);
+          long pos = delete.get(1, Long.class);
+          return row(rowData, pos, file.path().toString(), file.path().toString());
+        };
+    expected.addAll(
+        deletesA.first().stream()
+            .map(d -> toRow.apply(d, deletesA.second()))
+            .collect(Collectors.toList()));
+    expected.addAll(
+        deletesB.first().stream()
+            .map(d -> toRow.apply(d, deletesB.second()))
+            .collect(Collectors.toList()));
+
+    // Sort and compare
+    Comparator<Object[]> comp =
+        (o1, o2) -> {
+          int result = Integer.compare((int) o1[0], (int) o2[0]);
+          if (result != 0) {
+            return result;
+          } else {
+            return ((String) o1[2]).compareTo((String) o2[2]);
+          }
+        };
+    actual.sort(comp);
+    expected.sort(comp);
+    assertEquals("Position Delete table should contain expected rows", expected, actual);
+    dropTable(tableName);
+  }
+
+  @Test
+  public void testSplitTasks() throws IOException {
+    String tableName = "big_table";
+    Table tab = createTable(tableName, SCHEMA, PartitionSpec.unpartitioned());
+    tab.updateProperties().set("read.split.target-size", "100").commit();
+    int records = 500;
+
+    GenericRecord record = GenericRecord.create(tab.schema());
+    List<Record> dataRecords = Lists.newArrayList();
+    for (int i = 0; i < records; i++) {
+      dataRecords.add(record.copy("id", i, "data", String.valueOf(i)));
+    }
+    DataFile dFile =
+        FileHelpers.writeDataFile(
+            tab,
+            Files.localOutput(temp.newFile()),
+            org.apache.iceberg.TestHelpers.Row.of(),
+            dataRecords);
+    tab.newAppend().appendFile(dFile).commit();
+
+    List<PositionDelete<?>> deletes = Lists.newArrayList();
+    for (long i = 0; i < records; i++) {
+      deletes.add(positionDelete(tab.schema(), dFile.path(), i, (int) i, String.valueOf(i)));
+    }
+    DeleteFile posDeletes =
+        FileHelpers.writePosDeleteFile(
+            tab, Files.localOutput(temp.newFile()), TestHelpers.Row.of(0), deletes);
+    tab.newRowDelta().addDeletes(posDeletes).commit();
+
+    Table deleteTable =
+        MetadataTableUtils.createMetadataTableInstance(tab, MetadataTableType.POSITION_DELETES);
+
+    if (format.equals(FileFormat.AVRO)) {
+      Assert.assertTrue(
+          "Position delete
scan should produce more than one split", + Iterables.size(deleteTable.newBatchScan().planTasks()) > 1); + } else { + Assert.assertEquals( + "Position delete scan should produce one split", + 1, + Iterables.size(deleteTable.newBatchScan().planTasks())); + } + + StructLikeSet actual = actual(tableName, tab); + StructLikeSet expected = expected(tab, deletes, null, posDeletes.path().toString()); + + Assert.assertEquals("Position Delete table should contain expected rows", expected, actual); + dropTable(tableName); + } + + @Test + public void testPartitionFilter() throws IOException { + // Create table with two partitions + String tableName = "partition_filter"; + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table tab = createTable(tableName, SCHEMA, spec); + Table deletesTab = + MetadataTableUtils.createMetadataTableInstance(tab, MetadataTableType.POSITION_DELETES); + + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + // Add position deletes for both partitions + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileA, "b"); + + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Prepare expected values + GenericRecord partitionRecordTemplate = GenericRecord.create(tab.spec().partitionType()); + Record partitionA = partitionRecordTemplate.copy("data", "a"); + Record partitionB = partitionRecordTemplate.copy("data", "b"); + StructLikeSet expectedA = + expected(tab, deletesA.first(), partitionA, deletesA.second().path().toString()); + StructLikeSet expectedB = + expected(tab, deletesB.first(), partitionB, deletesB.second().path().toString()); + StructLikeSet allExpected = StructLikeSet.create(deletesTab.schema().asStruct()); + allExpected.addAll(expectedA); + allExpected.addAll(expectedB); + + // Select deletes from all partitions + StructLikeSet actual = actual(tableName, tab); + Assert.assertEquals("Position Delete table should contain expected rows", allExpected, actual); + + // Select deletes from one partition + StructLikeSet actual2 = actual(tableName, tab, "partition.data = 'a' AND pos >= 0"); + + Assert.assertEquals("Position Delete table should contain expected rows", expectedA, actual2); + dropTable(tableName); + } + + @Test + public void testPartitionTransformFilter() throws IOException { + // Create table with two partitions + String tableName = "partition_filter"; + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).truncate("data", 1).build(); + Table tab = createTable(tableName, SCHEMA, spec); + Table deletesTable = + MetadataTableUtils.createMetadataTableInstance(tab, MetadataTableType.POSITION_DELETES); + + DataFile dataFileA = dataFile(tab, new Object[] {"aa"}, new Object[] {"a"}); + DataFile dataFileB = dataFile(tab, new Object[] {"bb"}, new Object[] {"b"}); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + // Add position deletes for both partitions + Pair>, DeleteFile> deletesA = + deleteFile(tab, dataFileA, new Object[] {"aa"}, new Object[] {"a"}); + Pair>, DeleteFile> deletesB = + deleteFile(tab, dataFileA, new Object[] {"bb"}, new Object[] {"b"}); + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Prepare expected values + GenericRecord partitionRecordTemplate = GenericRecord.create(tab.spec().partitionType()); + Record partitionA = 
partitionRecordTemplate.copy("data_trunc", "a"); + Record partitionB = partitionRecordTemplate.copy("data_trunc", "b"); + StructLikeSet expectedA = + expected(tab, deletesA.first(), partitionA, deletesA.second().path().toString()); + StructLikeSet expectedB = + expected(tab, deletesB.first(), partitionB, deletesB.second().path().toString()); + StructLikeSet allExpected = StructLikeSet.create(deletesTable.schema().asStruct()); + allExpected.addAll(expectedA); + allExpected.addAll(expectedB); + + // Select deletes from all partitions + StructLikeSet actual = actual(tableName, tab); + Assert.assertEquals("Position Delete table should contain expected rows", allExpected, actual); + + // Select deletes from one partition + StructLikeSet actual2 = actual(tableName, tab, "partition.data_trunc = 'a' AND pos >= 0"); + + Assert.assertEquals("Position Delete table should contain expected rows", expectedA, actual2); + dropTable(tableName); + } + + @Test + public void testPartitionEvolutionReplace() throws Exception { + // Create table with spec (data) + String tableName = "partition_evolution"; + PartitionSpec originalSpec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table tab = createTable(tableName, SCHEMA, originalSpec); + int dataSpec = tab.spec().specId(); + + // Add files with old spec (data) + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileA, "b"); + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Switch partition spec from (data) to (id) + tab.updateSpec().removeField("data").addField("id").commit(); + + // Add data and delete files with new spec (id) + DataFile dataFile10 = dataFile(tab, 10); + DataFile dataFile99 = dataFile(tab, 99); + tab.newAppend().appendFile(dataFile10).appendFile(dataFile99).commit(); + + Pair>, DeleteFile> deletes10 = deleteFile(tab, dataFile10, 10); + Pair>, DeleteFile> deletes99 = deleteFile(tab, dataFile10, 99); + tab.newRowDelta().addDeletes(deletes10.second()).addDeletes(deletes99.second()).commit(); + + // Query partition of old spec + GenericRecord partitionRecordTemplate = GenericRecord.create(Partitioning.partitionType(tab)); + Record partitionA = partitionRecordTemplate.copy("data", "a"); + StructLikeSet expectedA = + expected(tab, deletesA.first(), partitionA, dataSpec, deletesA.second().path().toString()); + StructLikeSet actualA = actual(tableName, tab, "partition.data = 'a' AND pos >= 0"); + Assert.assertEquals("Position Delete table should contain expected rows", expectedA, actualA); + + // Query partition of new spec + Record partition10 = partitionRecordTemplate.copy("id", 10); + StructLikeSet expected10 = + expected( + tab, + deletes10.first(), + partition10, + tab.spec().specId(), + deletes10.second().path().toString()); + StructLikeSet actual10 = actual(tableName, tab, "partition.id = 10 AND pos >= 0"); + + Assert.assertEquals("Position Delete table should contain expected rows", expected10, actual10); + dropTable(tableName); + } + + @Test + public void testPartitionEvolutionAdd() throws Exception { + // Create unpartitioned table + String tableName = "partition_evolution_add"; + Table tab = createTable(tableName, SCHEMA, PartitionSpec.unpartitioned()); + int specId0 = tab.spec().specId(); + + // Add files with unpartitioned spec + DataFile 
dataFileUnpartitioned = dataFile(tab); + tab.newAppend().appendFile(dataFileUnpartitioned).commit(); + Pair>, DeleteFile> deletesUnpartitioned = + deleteFile(tab, dataFileUnpartitioned); + tab.newRowDelta().addDeletes(deletesUnpartitioned.second()).commit(); + + // Switch partition spec to (data) + tab.updateSpec().addField("data").commit(); + int specId1 = tab.spec().specId(); + + // Add files with new spec (data) + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b"); + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Select deletes from new spec (data) + GenericRecord partitionRecordTemplate = GenericRecord.create(Partitioning.partitionType(tab)); + Record partitionA = partitionRecordTemplate.copy("data", "a"); + StructLikeSet expectedA = + expected(tab, deletesA.first(), partitionA, specId1, deletesA.second().path().toString()); + StructLikeSet actualA = actual(tableName, tab, "partition.data = 'a' AND pos >= 0"); + Assert.assertEquals("Position Delete table should contain expected rows", expectedA, actualA); + + // Select deletes from 'unpartitioned' data + Record unpartitionedRecord = partitionRecordTemplate.copy("data", null); + StructLikeSet expectedUnpartitioned = + expected( + tab, + deletesUnpartitioned.first(), + unpartitionedRecord, + specId0, + deletesUnpartitioned.second().path().toString()); + StructLikeSet actualUnpartitioned = + actual(tableName, tab, "partition.data IS NULL and pos >= 0"); + + Assert.assertEquals( + "Position Delete table should contain expected rows", + expectedUnpartitioned, + actualUnpartitioned); + dropTable(tableName); + } + + @Test + public void testPartitionEvolutionRemove() throws Exception { + // Create table with spec (data) + String tableName = "partition_evolution_remove"; + PartitionSpec originalSpec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table tab = createTable(tableName, SCHEMA, originalSpec); + int specId0 = tab.spec().specId(); + + // Add files with spec (data) + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b"); + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Remove partition field + tab.updateSpec().removeField("data").commit(); + int specId1 = tab.spec().specId(); + + // Add unpartitioned files + DataFile dataFileUnpartitioned = dataFile(tab); + tab.newAppend().appendFile(dataFileUnpartitioned).commit(); + Pair>, DeleteFile> deletesUnpartitioned = + deleteFile(tab, dataFileUnpartitioned); + tab.newRowDelta().addDeletes(deletesUnpartitioned.second()).commit(); + + // Select deletes from (data) spec + GenericRecord partitionRecordTemplate = GenericRecord.create(Partitioning.partitionType(tab)); + Record partitionA = partitionRecordTemplate.copy("data", "a"); + StructLikeSet expectedA = + expected(tab, deletesA.first(), partitionA, specId0, deletesA.second().path().toString()); + StructLikeSet actualA = actual(tableName, tab, "partition.data = 'a' AND pos >= 0"); + Assert.assertEquals("Position Delete table should contain expected rows", expectedA, actualA); 
+ + // Select deletes from 'unpartitioned' spec + Record unpartitionedRecord = partitionRecordTemplate.copy("data", null); + StructLikeSet expectedUnpartitioned = + expected( + tab, + deletesUnpartitioned.first(), + unpartitionedRecord, + specId1, + deletesUnpartitioned.second().path().toString()); + StructLikeSet actualUnpartitioned = + actual(tableName, tab, "partition.data IS NULL and pos >= 0"); + + Assert.assertEquals( + "Position Delete table should contain expected rows", + expectedUnpartitioned, + actualUnpartitioned); + dropTable(tableName); + } + + @Test + public void testSpecIdFilter() throws Exception { + // Create table with spec (data) + String tableName = "spec_id_filter"; + Table tab = createTable(tableName, SCHEMA, PartitionSpec.unpartitioned()); + int unpartitionedSpec = tab.spec().specId(); + + // Add data file and delete + DataFile dataFileUnpartitioned = dataFile(tab); + tab.newAppend().appendFile(dataFileUnpartitioned).commit(); + Pair>, DeleteFile> deletesUnpartitioned = + deleteFile(tab, dataFileUnpartitioned); + tab.newRowDelta().addDeletes(deletesUnpartitioned.second()).commit(); + + // Switch partition spec to (data) and add files + tab.updateSpec().addField("data").commit(); + int dataSpec = tab.spec().specId(); + + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b"); + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Select deletes from 'unpartitioned' + GenericRecord partitionRecordTemplate = GenericRecord.create(Partitioning.partitionType(tab)); + StructLikeSet expectedUnpartitioned = + expected( + tab, + deletesUnpartitioned.first(), + partitionRecordTemplate, + unpartitionedSpec, + deletesUnpartitioned.second().path().toString()); + StructLikeSet actualUnpartitioned = + actual(tableName, tab, String.format("spec_id = %d", unpartitionedSpec)); + Assert.assertEquals( + "Position Delete table should contain expected rows", + expectedUnpartitioned, + actualUnpartitioned); + + // Select deletes from 'data' partition spec + StructLike partitionA = partitionRecordTemplate.copy("data", "a"); + StructLike partitionB = partitionRecordTemplate.copy("data", "b"); + StructLikeSet expected = + expected(tab, deletesA.first(), partitionA, dataSpec, deletesA.second().path().toString()); + expected.addAll( + expected(tab, deletesB.first(), partitionB, dataSpec, deletesB.second().path().toString())); + + StructLikeSet actual = actual(tableName, tab, String.format("spec_id = %d", dataSpec)); + Assert.assertEquals("Position Delete table should contain expected rows", expected, actual); + dropTable(tableName); + } + + @Test + public void testSchemaEvolutionAdd() throws Exception { + // Create table with original schema + String tableName = "schema_evolution_add"; + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table tab = createTable(tableName, SCHEMA, spec); + + // Add files with original schema + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b"); + 
tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Add files with new schema + tab.updateSchema() + .addColumn("new_col_1", Types.IntegerType.get()) + .addColumn("new_col_2", Types.IntegerType.get()) + .commit(); + + // Add files with new schema + DataFile dataFileC = dataFile(tab, "c"); + DataFile dataFileD = dataFile(tab, "d"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesC = deleteFile(tab, dataFileC, "c"); + Pair>, DeleteFile> deletesD = deleteFile(tab, dataFileD, "d"); + tab.newRowDelta().addDeletes(deletesC.second()).addDeletes(deletesD.second()).commit(); + + // Select deletes from old schema + GenericRecord partitionRecordTemplate = GenericRecord.create(Partitioning.partitionType(tab)); + Record partitionA = partitionRecordTemplate.copy("data", "a"); + // pad expected delete rows with null values for new columns + List> expectedDeletesA = deletesA.first(); + expectedDeletesA.forEach( + d -> { + GenericRecord nested = d.get(2, GenericRecord.class); + GenericRecord padded = GenericRecord.create(tab.schema().asStruct()); + padded.set(0, nested.get(0)); + padded.set(1, nested.get(1)); + padded.set(2, null); + padded.set(3, null); + d.set(2, padded); + }); + StructLikeSet expectedA = + expected(tab, expectedDeletesA, partitionA, deletesA.second().path().toString()); + StructLikeSet actualA = actual(tableName, tab, "partition.data = 'a' AND pos >= 0"); + Assert.assertEquals("Position Delete table should contain expected rows", expectedA, actualA); + + // Select deletes from new schema + Record partitionC = partitionRecordTemplate.copy("data", "c"); + StructLikeSet expectedC = + expected(tab, deletesC.first(), partitionC, deletesC.second().path().toString()); + StructLikeSet actualC = actual(tableName, tab, "partition.data = 'c' and pos >= 0"); + + Assert.assertEquals("Position Delete table should contain expected rows", expectedC, actualC); + dropTable(tableName); + } + + @Test + public void testSchemaEvolutionRemove() throws Exception { + // Create table with original schema + String tableName = "schema_evolution_remove"; + Schema oldSchema = + new Schema( + Types.NestedField.required(1, "id", Types.IntegerType.get()), + Types.NestedField.required(2, "data", Types.StringType.get()), + Types.NestedField.optional(3, "new_col_1", Types.IntegerType.get()), + Types.NestedField.optional(4, "new_col_2", Types.IntegerType.get())); + PartitionSpec spec = PartitionSpec.builderFor(oldSchema).identity("data").build(); + Table tab = createTable(tableName, oldSchema, spec); + + // Add files with original schema + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b"); + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Add files with new schema + tab.updateSchema().deleteColumn("new_col_1").deleteColumn("new_col_2").commit(); + + // Add files with new schema + DataFile dataFileC = dataFile(tab, "c"); + DataFile dataFileD = dataFile(tab, "d"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesC = deleteFile(tab, dataFileC, "c"); + Pair>, DeleteFile> deletesD = deleteFile(tab, dataFileD, "d"); + 
tab.newRowDelta().addDeletes(deletesC.second()).addDeletes(deletesD.second()).commit(); + + // Select deletes from old schema + GenericRecord partitionRecordTemplate = GenericRecord.create(Partitioning.partitionType(tab)); + Record partitionA = partitionRecordTemplate.copy("data", "a"); + // remove deleted columns from expected result + List> expectedDeletesA = deletesA.first(); + expectedDeletesA.forEach( + d -> { + GenericRecord nested = d.get(2, GenericRecord.class); + GenericRecord padded = GenericRecord.create(tab.schema().asStruct()); + padded.set(0, nested.get(0)); + padded.set(1, nested.get(1)); + d.set(2, padded); + }); + StructLikeSet expectedA = + expected(tab, expectedDeletesA, partitionA, deletesA.second().path().toString()); + StructLikeSet actualA = actual(tableName, tab, "partition.data = 'a' AND pos >= 0"); + Assert.assertEquals("Position Delete table should contain expected rows", expectedA, actualA); + + // Select deletes from new schema + Record partitionC = partitionRecordTemplate.copy("data", "c"); + StructLikeSet expectedC = + expected(tab, deletesC.first(), partitionC, deletesC.second().path().toString()); + StructLikeSet actualC = actual(tableName, tab, "partition.data = 'c' and pos >= 0"); + + Assert.assertEquals("Position Delete table should contain expected rows", expectedC, actualC); + dropTable(tableName); + } + + @Test + public void testWrite() throws IOException, NoSuchTableException { + String tableName = "test_write"; + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table tab = createTable(tableName, SCHEMA, spec); + + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + // Add position deletes for both partitions + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b"); + + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + Table posDeletesTable = + MetadataTableUtils.createMetadataTableInstance(tab, MetadataTableType.POSITION_DELETES); + String posDeletesTableName = catalogName + ".default." 
+ tableName + ".position_deletes"; + for (String partValue : ImmutableList.of("a", "b")) { + try (CloseableIterable tasks = tasks(posDeletesTable, "data", partValue)) { + String fileSetID = UUID.randomUUID().toString(); + stageTask(tab, fileSetID, tasks); + + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.FILE_OPEN_COST, Integer.MAX_VALUE) + .load(posDeletesTableName); + + Assert.assertEquals(1, scanDF.javaRDD().getNumPartitions()); + scanDF + .writeTo(posDeletesTableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + + commit(tab, posDeletesTable, fileSetID, 1); + } + } + + // Prepare expected values (without 'delete_file_path' as these have been rewritten) + GenericRecord partitionRecordTemplate = GenericRecord.create(tab.spec().partitionType()); + Record partitionA = partitionRecordTemplate.copy("data", "a"); + Record partitionB = partitionRecordTemplate.copy("data", "b"); + StructLikeSet expectedA = expected(tab, deletesA.first(), partitionA, null); + StructLikeSet expectedB = expected(tab, deletesB.first(), partitionB, null); + StructLikeSet allExpected = + StructLikeSet.create( + TypeUtil.selectNot( + posDeletesTable.schema(), ImmutableSet.of(MetadataColumns.FILE_PATH_COLUMN_ID)) + .asStruct()); + allExpected.addAll(expectedA); + allExpected.addAll(expectedB); + + // Compare values without 'delete_file_path' as these have been rewritten + StructLikeSet actual = actual(tableName, tab, null, NON_PATH_COLS); + Assert.assertEquals("Position Delete table should contain expected rows", allExpected, actual); + dropTable(tableName); + } + + @Test + public void testWriteUnpartitionedNullRows() throws Exception { + String tableName = "write_null_rows"; + Table tab = createTable(tableName, SCHEMA, PartitionSpec.unpartitioned()); + + DataFile dFile = dataFile(tab); + tab.newAppend().appendFile(dFile).commit(); + + List> deletes = Lists.newArrayList(); + deletes.add(Pair.of(dFile.path(), 0L)); + deletes.add(Pair.of(dFile.path(), 1L)); + Pair posDeletes = + FileHelpers.writeDeleteFile( + tab, Files.localOutput(temp.newFile()), TestHelpers.Row.of(0), deletes); + tab.newRowDelta().addDeletes(posDeletes.first()).commit(); + + Table posDeletesTable = + MetadataTableUtils.createMetadataTableInstance(tab, MetadataTableType.POSITION_DELETES); + String posDeletesTableName = catalogName + ".default." 
+ tableName + ".position_deletes"; + try (CloseableIterable tasks = posDeletesTable.newBatchScan().planFiles()) { + String fileSetID = UUID.randomUUID().toString(); + stageTask(tab, fileSetID, tasks); + + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.FILE_OPEN_COST, Integer.MAX_VALUE) + .load(posDeletesTableName); + Assert.assertEquals(1, scanDF.javaRDD().getNumPartitions()); + scanDF + .writeTo(posDeletesTableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + + commit(tab, posDeletesTable, fileSetID, 1); + } + + // Compare values without 'delete_file_path' as these have been rewritten + StructLikeSet actual = + actual(tableName, tab, null, ImmutableList.of("file_path", "pos", "row", "spec_id")); + + List> expectedDeletes = + Lists.newArrayList(positionDelete(dFile.path(), 0L), positionDelete(dFile.path(), 1L)); + StructLikeSet expected = expected(tab, expectedDeletes, null, null); + + Assert.assertEquals("Position Delete table should contain expected rows", expected, actual); + dropTable(tableName); + } + + @Test + public void testWriteMixedRows() throws Exception { + String tableName = "write_mixed_rows"; + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table tab = createTable(tableName, SCHEMA, spec); + + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + // Add a delete file with row and without row + List> deletes = Lists.newArrayList(); + deletes.add(Pair.of(dataFileA.path(), 0L)); + deletes.add(Pair.of(dataFileA.path(), 1L)); + Pair deletesWithoutRow = + FileHelpers.writeDeleteFile( + tab, Files.localOutput(temp.newFile()), TestHelpers.Row.of("a"), deletes); + + Pair>, DeleteFile> deletesWithRow = deleteFile(tab, dataFileB, "b"); + + tab.newRowDelta() + .addDeletes(deletesWithoutRow.first()) + .addDeletes(deletesWithRow.second()) + .commit(); + + // rewrite delete files + Table posDeletesTable = + MetadataTableUtils.createMetadataTableInstance(tab, MetadataTableType.POSITION_DELETES); + String posDeletesTableName = catalogName + ".default." 
+ tableName + ".position_deletes"; + for (String partValue : ImmutableList.of("a", "b")) { + try (CloseableIterable tasks = tasks(posDeletesTable, "data", partValue)) { + String fileSetID = UUID.randomUUID().toString(); + stageTask(tab, fileSetID, tasks); + + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .load(posDeletesTableName); + Assert.assertEquals(1, scanDF.javaRDD().getNumPartitions()); + scanDF + .writeTo(posDeletesTableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + + commit(tab, posDeletesTable, fileSetID, 1); + } + } + + // Compare values without 'delete_file_path' as these have been rewritten + StructLikeSet actual = + actual( + tableName, + tab, + null, + ImmutableList.of("file_path", "pos", "row", "partition", "spec_id")); + + // Prepare expected values + GenericRecord partitionRecordTemplate = GenericRecord.create(tab.spec().partitionType()); + Record partitionA = partitionRecordTemplate.copy("data", "a"); + Record partitionB = partitionRecordTemplate.copy("data", "b"); + StructLikeSet allExpected = + StructLikeSet.create( + TypeUtil.selectNot( + posDeletesTable.schema(), ImmutableSet.of(MetadataColumns.FILE_PATH_COLUMN_ID)) + .asStruct()); + allExpected.addAll( + expected( + tab, + Lists.newArrayList( + positionDelete(dataFileA.path(), 0L), positionDelete(dataFileA.path(), 1L)), + partitionA, + null)); + allExpected.addAll(expected(tab, deletesWithRow.first(), partitionB, null)); + + Assert.assertEquals("Position Delete table should contain expected rows", allExpected, actual); + dropTable(tableName); + } + + @Test + public void testWritePartitionEvolutionAdd() throws Exception { + // Create unpartitioned table + String tableName = "write_partition_evolution_add"; + Table tab = createTable(tableName, SCHEMA, PartitionSpec.unpartitioned()); + int specId0 = tab.spec().specId(); + + // Add files with unpartitioned spec + DataFile dataFileUnpartitioned = dataFile(tab); + tab.newAppend().appendFile(dataFileUnpartitioned).commit(); + Pair>, DeleteFile> deletesUnpartitioned = + deleteFile(tab, dataFileUnpartitioned); + tab.newRowDelta().addDeletes(deletesUnpartitioned.second()).commit(); + + // Switch partition spec to (data) + tab.updateSpec().addField("data").commit(); + int specId1 = tab.spec().specId(); + + // Add files with new spec (data) + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b"); + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + Table posDeletesTable = + MetadataTableUtils.createMetadataTableInstance(tab, MetadataTableType.POSITION_DELETES); + String posDeletesTableName = catalogName + ".default." 
+ tableName + ".position_deletes"; + + // Read/write back unpartitioned data + try (CloseableIterable tasks = + posDeletesTable.newBatchScan().filter(Expressions.isNull("partition.data")).planFiles()) { + String fileSetID = UUID.randomUUID().toString(); + stageTask(tab, fileSetID, tasks); + + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.FILE_OPEN_COST, Integer.MAX_VALUE) + .load(posDeletesTableName); + Assert.assertEquals(1, scanDF.javaRDD().getNumPartitions()); + scanDF + .writeTo(posDeletesTableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + + commit(tab, posDeletesTable, fileSetID, 1); + } + + // Select deletes from unpartitioned data + // Compare values without 'delete_file_path' as these have been rewritten + GenericRecord partitionRecordTemplate = GenericRecord.create(Partitioning.partitionType(tab)); + Record unpartitionedRecord = partitionRecordTemplate.copy("data", null); + StructLikeSet expectedUnpartitioned = + expected(tab, deletesUnpartitioned.first(), unpartitionedRecord, specId0, null); + StructLikeSet actualUnpartitioned = + actual(tableName, tab, "partition.data IS NULL", NON_PATH_COLS); + Assert.assertEquals( + "Position Delete table should contain expected rows", + expectedUnpartitioned, + actualUnpartitioned); + + // Read/write back new partition spec (data) + for (String partValue : ImmutableList.of("a", "b")) { + try (CloseableIterable tasks = tasks(posDeletesTable, "data", partValue)) { + String fileSetID = UUID.randomUUID().toString(); + stageTask(tab, fileSetID, tasks); + + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.FILE_OPEN_COST, Integer.MAX_VALUE) + .load(posDeletesTableName); + Assert.assertEquals(1, scanDF.javaRDD().getNumPartitions()); + scanDF + .writeTo(posDeletesTableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + + // commit the rewrite + commit(tab, posDeletesTable, fileSetID, 1); + } + } + + // Select deletes from new spec (data) + Record partitionA = partitionRecordTemplate.copy("data", "a"); + Record partitionB = partitionRecordTemplate.copy("data", "b"); + StructLikeSet expectedAll = + StructLikeSet.create( + TypeUtil.selectNot( + posDeletesTable.schema(), ImmutableSet.of(MetadataColumns.FILE_PATH_COLUMN_ID)) + .asStruct()); + expectedAll.addAll(expected(tab, deletesA.first(), partitionA, specId1, null)); + expectedAll.addAll(expected(tab, deletesB.first(), partitionB, specId1, null)); + StructLikeSet actualAll = + actual(tableName, tab, "partition.data = 'a' OR partition.data = 'b'", NON_PATH_COLS); + Assert.assertEquals( + "Position Delete table should contain expected rows", expectedAll, actualAll); + + dropTable(tableName); + } + + @Test + public void testWritePartitionEvolutionDisallowed() throws Exception { + // Create unpartitioned table + String tableName = "write_partition_evolution_write"; + Table tab = createTable(tableName, SCHEMA, PartitionSpec.unpartitioned()); + + // Add files with unpartitioned spec + DataFile dataFileUnpartitioned = dataFile(tab); + tab.newAppend().appendFile(dataFileUnpartitioned).commit(); + Pair>, DeleteFile> deletesUnpartitioned = + deleteFile(tab, dataFileUnpartitioned); + tab.newRowDelta().addDeletes(deletesUnpartitioned.second()).commit(); + + Table posDeletesTable = + MetadataTableUtils.createMetadataTableInstance(tab, 
MetadataTableType.POSITION_DELETES); + String posDeletesTableName = catalogName + ".default." + tableName + ".position_deletes"; + + Dataset scanDF; + String fileSetID = UUID.randomUUID().toString(); + try (CloseableIterable tasks = posDeletesTable.newBatchScan().planFiles()) { + stageTask(tab, fileSetID, tasks); + + scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.FILE_OPEN_COST, Integer.MAX_VALUE) + .load(posDeletesTableName); + Assert.assertEquals(1, scanDF.javaRDD().getNumPartitions()); + + // Add partition field to render the original un-partitioned dataset un-commitable + tab.updateSpec().addField("data").commit(); + } + + Assert.assertThrows( + AnalysisException.class, + () -> + scanDF + .writeTo(posDeletesTableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append()); + + dropTable(tableName); + } + + @Test + public void testWriteSchemaEvolutionAdd() throws Exception { + // Create table with original schema + String tableName = "write_schema_evolution_add"; + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table tab = createTable(tableName, SCHEMA, spec); + + // Add files with original schema + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b"); + tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Add files with new schema + tab.updateSchema() + .addColumn("new_col_1", Types.IntegerType.get()) + .addColumn("new_col_2", Types.IntegerType.get()) + .commit(); + + // Add files with new schema + DataFile dataFileC = dataFile(tab, "c"); + DataFile dataFileD = dataFile(tab, "d"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesC = deleteFile(tab, dataFileC, "c"); + Pair>, DeleteFile> deletesD = deleteFile(tab, dataFileD, "d"); + tab.newRowDelta().addDeletes(deletesC.second()).addDeletes(deletesD.second()).commit(); + + Table posDeletesTable = + MetadataTableUtils.createMetadataTableInstance(tab, MetadataTableType.POSITION_DELETES); + String posDeletesTableName = catalogName + ".default." 
+ tableName + ".position_deletes"; + + // rewrite files of old schema + try (CloseableIterable tasks = tasks(posDeletesTable, "data", "a")) { + String fileSetID = UUID.randomUUID().toString(); + stageTask(tab, fileSetID, tasks); + + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.FILE_OPEN_COST, Integer.MAX_VALUE) + .load(posDeletesTableName); + + Assert.assertEquals(1, scanDF.javaRDD().getNumPartitions()); + scanDF + .writeTo(posDeletesTableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + + commit(tab, posDeletesTable, fileSetID, 1); + } + + // Select deletes from old schema + GenericRecord partitionRecordTemplate = GenericRecord.create(Partitioning.partitionType(tab)); + Record partitionA = partitionRecordTemplate.copy("data", "a"); + // pad expected delete rows with null values for new columns + List> expectedDeletesA = deletesA.first(); + expectedDeletesA.forEach( + d -> { + GenericRecord nested = d.get(2, GenericRecord.class); + GenericRecord padded = GenericRecord.create(tab.schema().asStruct()); + padded.set(0, nested.get(0)); + padded.set(1, nested.get(1)); + padded.set(2, null); + padded.set(3, null); + d.set(2, padded); + }); + StructLikeSet expectedA = expected(tab, expectedDeletesA, partitionA, null); + StructLikeSet actualA = actual(tableName, tab, "partition.data = 'a'", NON_PATH_COLS); + Assert.assertEquals("Position Delete table should contain expected rows", expectedA, actualA); + + // rewrite files of new schema + try (CloseableIterable tasks = tasks(posDeletesTable, "data", "c")) { + String fileSetID = UUID.randomUUID().toString(); + stageTask(tab, fileSetID, tasks); + + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.FILE_OPEN_COST, Integer.MAX_VALUE) + .load(posDeletesTableName); + + Assert.assertEquals(1, scanDF.javaRDD().getNumPartitions()); + scanDF + .writeTo(posDeletesTableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + + commit(tab, posDeletesTable, fileSetID, 1); + } + + // Select deletes from new schema + Record partitionC = partitionRecordTemplate.copy("data", "c"); + StructLikeSet expectedC = expected(tab, deletesC.first(), partitionC, null); + StructLikeSet actualC = actual(tableName, tab, "partition.data = 'c'", NON_PATH_COLS); + + Assert.assertEquals("Position Delete table should contain expected rows", expectedC, actualC); + dropTable(tableName); + } + + @Test + public void testWriteSchemaEvolutionRemove() throws Exception { + // Create table with original schema + String tableName = "write_schema_evolution_remove"; + Schema oldSchema = + new Schema( + Types.NestedField.required(1, "id", Types.IntegerType.get()), + Types.NestedField.required(2, "data", Types.StringType.get()), + Types.NestedField.optional(3, "new_col_1", Types.IntegerType.get()), + Types.NestedField.optional(4, "new_col_2", Types.IntegerType.get())); + PartitionSpec spec = PartitionSpec.builderFor(oldSchema).identity("data").build(); + Table tab = createTable(tableName, oldSchema, spec); + + // Add files with original schema + DataFile dataFileA = dataFile(tab, "a"); + DataFile dataFileB = dataFile(tab, "b"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + Pair>, DeleteFile> deletesB = deleteFile(tab, dataFileB, "b"); + 
tab.newRowDelta().addDeletes(deletesA.second()).addDeletes(deletesB.second()).commit(); + + // Add files with new schema + tab.updateSchema().deleteColumn("new_col_1").deleteColumn("new_col_2").commit(); + + // Add files with new schema + DataFile dataFileC = dataFile(tab, "c"); + DataFile dataFileD = dataFile(tab, "d"); + tab.newAppend().appendFile(dataFileA).appendFile(dataFileB).commit(); + + Pair>, DeleteFile> deletesC = deleteFile(tab, dataFileC, "c"); + Pair>, DeleteFile> deletesD = deleteFile(tab, dataFileD, "d"); + tab.newRowDelta().addDeletes(deletesC.second()).addDeletes(deletesD.second()).commit(); + + Table posDeletesTable = + MetadataTableUtils.createMetadataTableInstance(tab, MetadataTableType.POSITION_DELETES); + String posDeletesTableName = catalogName + ".default." + tableName + ".position_deletes"; + + // rewrite files + for (String partValue : ImmutableList.of("a", "b", "c", "d")) { + try (CloseableIterable tasks = tasks(posDeletesTable, "data", partValue)) { + String fileSetID = UUID.randomUUID().toString(); + stageTask(tab, fileSetID, tasks); + + Dataset scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID) + .option(SparkReadOptions.FILE_OPEN_COST, Integer.MAX_VALUE) + .load(posDeletesTableName); + Assert.assertEquals(1, scanDF.javaRDD().getNumPartitions()); + scanDF + .writeTo(posDeletesTableName) + .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, fileSetID) + .append(); + + commit(tab, posDeletesTable, fileSetID, 1); + } + } + + // Select deletes from old schema + GenericRecord partitionRecordTemplate = GenericRecord.create(Partitioning.partitionType(tab)); + Record partitionA = partitionRecordTemplate.copy("data", "a"); + // remove deleted columns from expected result + List> expectedDeletesA = deletesA.first(); + expectedDeletesA.forEach( + d -> { + GenericRecord nested = d.get(2, GenericRecord.class); + GenericRecord padded = GenericRecord.create(tab.schema().asStruct()); + padded.set(0, nested.get(0)); + padded.set(1, nested.get(1)); + d.set(2, padded); + }); + StructLikeSet expectedA = expected(tab, expectedDeletesA, partitionA, null); + StructLikeSet actualA = actual(tableName, tab, "partition.data = 'a'", NON_PATH_COLS); + Assert.assertEquals("Position Delete table should contain expected rows", expectedA, actualA); + + // Select deletes from new schema + Record partitionC = partitionRecordTemplate.copy("data", "c"); + StructLikeSet expectedC = expected(tab, deletesC.first(), partitionC, null); + StructLikeSet actualC = actual(tableName, tab, "partition.data = 'c'", NON_PATH_COLS); + + Assert.assertEquals("Position Delete table should contain expected rows", expectedC, actualC); + dropTable(tableName); + } + + @Test + public void testNormalWritesNotAllowed() throws IOException { + String tableName = "test_normal_write_not_allowed"; + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table tab = createTable(tableName, SCHEMA, spec); + + DataFile dataFileA = dataFile(tab, "a"); + tab.newAppend().appendFile(dataFileA).commit(); + + Pair>, DeleteFile> deletesA = deleteFile(tab, dataFileA, "a"); + tab.newRowDelta().addDeletes(deletesA.second()).commit(); + + String posDeletesTableName = catalogName + ".default." 
+ tableName + ".position_deletes"; + + Dataset scanDF = spark.read().format("iceberg").load(posDeletesTableName); + + Assert.assertThrows( + "position_deletes table can only be written by RewriteDeleteFiles", + IllegalArgumentException.class, + () -> scanDF.writeTo(posDeletesTableName).append()); + + dropTable(tableName); + } + + private StructLikeSet actual(String tableName, Table table) { + return actual(tableName, table, null, null); + } + + private StructLikeSet actual(String tableName, Table table, String filter) { + return actual(tableName, table, filter, null); + } + + private StructLikeSet actual(String tableName, Table table, String filter, List cols) { + Dataset df = + spark + .read() + .format("iceberg") + .load(catalogName + ".default." + tableName + ".position_deletes"); + if (filter != null) { + df = df.filter(filter); + } + if (cols != null) { + df = df.select(cols.get(0), cols.subList(1, cols.size()).toArray(new String[0])); + } + Table deletesTable = + MetadataTableUtils.createMetadataTableInstance(table, MetadataTableType.POSITION_DELETES); + Types.StructType projection = deletesTable.schema().asStruct(); + if (cols != null) { + projection = + Types.StructType.of( + projection.fields().stream() + .filter(f -> cols.contains(f.name())) + .collect(Collectors.toList())); + } + Types.StructType finalProjection = projection; + StructLikeSet set = StructLikeSet.create(projection); + df.collectAsList() + .forEach( + row -> { + SparkStructLike rowWrapper = new SparkStructLike(finalProjection); + set.add(rowWrapper.wrap(row)); + }); + + return set; + } + + protected Table createTable(String name, Schema schema, PartitionSpec spec) { + Map properties = + ImmutableMap.of( + TableProperties.FORMAT_VERSION, + "2", + TableProperties.DEFAULT_FILE_FORMAT, + format.toString()); + return validationCatalog.createTable( + TableIdentifier.of("default", name), schema, spec, properties); + } + + protected void dropTable(String name) { + validationCatalog.dropTable(TableIdentifier.of("default", name), false); + } + + private PositionDelete positionDelete(CharSequence path, Long position) { + PositionDelete posDelete = PositionDelete.create(); + posDelete.set(path, position, null); + return posDelete; + } + + private PositionDelete positionDelete( + Schema tableSchema, CharSequence path, Long position, Object... 
values) { + PositionDelete posDelete = PositionDelete.create(); + GenericRecord nested = GenericRecord.create(tableSchema); + for (int i = 0; i < values.length; i++) { + nested.set(i, values[i]); + } + posDelete.set(path, position, nested); + return posDelete; + } + + private StructLikeSet expected( + Table testTable, + List> deletes, + StructLike partitionStruct, + int specId, + String deleteFilePath) { + Table deletesTable = + MetadataTableUtils.createMetadataTableInstance( + testTable, MetadataTableType.POSITION_DELETES); + Types.StructType posDeleteSchema = deletesTable.schema().asStruct(); + // Do not compare file paths + if (deleteFilePath == null) { + posDeleteSchema = + TypeUtil.selectNot( + deletesTable.schema(), ImmutableSet.of(MetadataColumns.FILE_PATH_COLUMN_ID)) + .asStruct(); + } + final Types.StructType finalSchema = posDeleteSchema; + StructLikeSet set = StructLikeSet.create(posDeleteSchema); + deletes.stream() + .map( + p -> { + GenericRecord record = GenericRecord.create(finalSchema); + record.setField("file_path", p.path()); + record.setField("pos", p.pos()); + record.setField("row", p.row()); + if (partitionStruct != null) { + record.setField("partition", partitionStruct); + } + record.setField("spec_id", specId); + if (deleteFilePath != null) { + record.setField("delete_file_path", deleteFilePath); + } + return record; + }) + .forEach(set::add); + return set; + } + + private StructLikeSet expected( + Table testTable, + List> deletes, + StructLike partitionStruct, + String deleteFilePath) { + return expected(testTable, deletes, partitionStruct, testTable.spec().specId(), deleteFilePath); + } + + private DataFile dataFile(Table tab, Object... partValues) throws IOException { + return dataFile(tab, partValues, partValues); + } + + @SuppressWarnings("checkstyle:CyclomaticComplexity") + private DataFile dataFile(Table tab, Object[] partDataValues, Object[] partFieldValues) + throws IOException { + GenericRecord record = GenericRecord.create(tab.schema()); + List partitionFieldNames = + tab.spec().fields().stream().map(PartitionField::name).collect(Collectors.toList()); + int idIndex = partitionFieldNames.indexOf("id"); + int dataIndex = partitionFieldNames.indexOf("data"); + Integer idPartition = idIndex != -1 ? (Integer) partDataValues[idIndex] : null; + String dataPartition = dataIndex != -1 ? (String) partDataValues[dataIndex] : null; + + // fill columns with partition source fields, or preset values + List records = + Lists.newArrayList( + record.copy( + "id", + idPartition != null ? idPartition : 29, + "data", + dataPartition != null ? dataPartition : "c"), + record.copy( + "id", + idPartition != null ? idPartition : 43, + "data", + dataPartition != null ? dataPartition : "k"), + record.copy( + "id", + idPartition != null ? idPartition : 61, + "data", + dataPartition != null ? dataPartition : "r"), + record.copy( + "id", + idPartition != null ? idPartition : 89, + "data", + dataPartition != null ? dataPartition : "t")); + + // fill remaining columns with incremental values + List cols = tab.schema().columns(); + if (cols.size() > 2) { + for (int i = 2; i < cols.size(); i++) { + final int pos = i; + records.forEach(r -> r.set(pos, pos)); + } + } + + TestHelpers.Row partitionInfo = TestHelpers.Row.of(partFieldValues); + return FileHelpers.writeDataFile( + tab, Files.localOutput(temp.newFile()), partitionInfo, records); + } + + private Pair>, DeleteFile> deleteFile( + Table tab, DataFile dataFile, Object... 
partValues) throws IOException { + return deleteFile(tab, dataFile, partValues, partValues); + } + + private Pair>, DeleteFile> deleteFile( + Table tab, DataFile dataFile, Object[] partDataValues, Object[] partFieldValues) + throws IOException { + List partFields = tab.spec().fields(); + List partitionFieldNames = + partFields.stream().map(PartitionField::name).collect(Collectors.toList()); + int idIndex = partitionFieldNames.indexOf("id"); + int dataIndex = partitionFieldNames.indexOf("data"); + Integer idPartition = idIndex != -1 ? (Integer) partDataValues[idIndex] : null; + String dataPartition = dataIndex != -1 ? (String) partDataValues[dataIndex] : null; + + // fill columns with partition source fields, or preset values + List> deletes = + Lists.newArrayList( + positionDelete( + tab.schema(), + dataFile.path(), + 0L, + idPartition != null ? idPartition : 29, + dataPartition != null ? dataPartition : "c"), + positionDelete( + tab.schema(), + dataFile.path(), + 1L, + idPartition != null ? idPartition : 61, + dataPartition != null ? dataPartition : "r")); + + // fill remaining columns with incremental values + List cols = tab.schema().columns(); + if (cols.size() > 2) { + for (int i = 2; i < cols.size(); i++) { + final int pos = i; + deletes.forEach(d -> d.get(2, GenericRecord.class).set(pos, pos)); + } + } + + TestHelpers.Row partitionInfo = TestHelpers.Row.of(partFieldValues); + + DeleteFile deleteFile = + FileHelpers.writePosDeleteFile( + tab, Files.localOutput(temp.newFile()), partitionInfo, deletes); + return Pair.of(deletes, deleteFile); + } + + private void stageTask( + Table tab, String fileSetID, CloseableIterable tasks) { + ScanTaskSetManager taskSetManager = ScanTaskSetManager.get(); + taskSetManager.stageTasks(tab, fileSetID, Lists.newArrayList(tasks)); + } + + private void commit( + Table baseTab, + Table posDeletesTable, + String fileSetID, + int expectedSourceFiles, + int expectedTargetFiles) { + PositionDeletesRewriteCoordinator rewriteCoordinator = PositionDeletesRewriteCoordinator.get(); + Set rewrittenFiles = + ScanTaskSetManager.get().fetchTasks(posDeletesTable, fileSetID).stream() + .map(t -> ((PositionDeletesScanTask) t).file()) + .collect(Collectors.toSet()); + Set addedFiles = rewriteCoordinator.fetchNewFiles(posDeletesTable, fileSetID); + + // Assert new files and old files are equal in number but different in paths + Assert.assertEquals(expectedSourceFiles, rewrittenFiles.size()); + Assert.assertEquals(expectedTargetFiles, addedFiles.size()); + + List sortedAddedFiles = + addedFiles.stream().map(f -> f.path().toString()).sorted().collect(Collectors.toList()); + List sortedRewrittenFiles = + rewrittenFiles.stream().map(f -> f.path().toString()).sorted().collect(Collectors.toList()); + Assert.assertNotEquals("Lists should not be the same", sortedAddedFiles, sortedRewrittenFiles); + + baseTab + .newRewrite() + .rewriteFiles(ImmutableSet.of(), rewrittenFiles, ImmutableSet.of(), addedFiles) + .commit(); + } + + private void commit(Table baseTab, Table posDeletesTable, String fileSetID, int expectedFiles) { + commit(baseTab, posDeletesTable, fileSetID, expectedFiles, expectedFiles); + } + + private CloseableIterable tasks( + Table posDeletesTable, String partitionColumn, String partitionValue) { + + Expression filter = Expressions.equal("partition." 
+ partitionColumn, partitionValue);
+    CloseableIterable files = posDeletesTable.newBatchScan().filter(filter).planFiles();
+
+    // take care of fail to filter in some partition evolution cases
+    return CloseableIterable.filter(
+        files,
+        t -> {
+          StructLike filePartition = ((PositionDeletesScanTask) t).partition();
+          String filePartitionValue = filePartition.get(0, String.class);
+          return filePartitionValue != null && filePartitionValue.equals(partitionValue);
+        });
+  }
+}
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestReadProjection.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestReadProjection.java
new file mode 100644
index 000000000000..eecc405b1a09
--- /dev/null
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestReadProjection.java
@@ -0,0 +1,609 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import static org.apache.avro.Schema.Type.UNION;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.data.GenericRecord;
+import org.apache.iceberg.data.Record;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.relocated.com.google.common.collect.Sets;
+import org.apache.iceberg.types.Comparators;
+import org.apache.iceberg.types.Types;
+import org.assertj.core.api.Assertions;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+public abstract class TestReadProjection {
+  final String format;
+
+  TestReadProjection(String format) {
+    this.format = format;
+  }
+
+  protected abstract Record writeAndRead(
+      String desc, Schema writeSchema, Schema readSchema, Record record) throws IOException;
+
+  @Rule public TemporaryFolder temp = new TemporaryFolder();
+
+  @Test
+  public void testFullProjection() throws Exception {
+    Schema schema =
+        new Schema(
+            Types.NestedField.required(0, "id", Types.LongType.get()),
+            Types.NestedField.optional(1, "data", Types.StringType.get()));
+
+    Record record = GenericRecord.create(schema);
+    record.setField("id", 34L);
+    record.setField("data", "test");
+
+    Record projected = writeAndRead("full_projection", schema, schema, record);
+
+    Assert.assertEquals(
+        "Should contain the correct id value", 34L, (long) projected.getField("id"));
+
+    int cmp =
+        Comparators.charSequences().compare("test", (CharSequence) projected.getField("data"));
+    Assert.assertEquals("Should contain the correct data value", 0, cmp);
+  }
+
+  @Test
+  public void testReorderedFullProjection() throws Exception {
+    // Assume.assumeTrue(
+    //    "Spark's Parquet read support does not support reordered columns",
+    //    !format.equalsIgnoreCase("parquet"));
+
+    Schema schema =
+        new Schema(
+            Types.NestedField.required(0, "id", Types.LongType.get()),
+            Types.NestedField.optional(1, "data", Types.StringType.get()));
+
+    Record record = GenericRecord.create(schema);
+    record.setField("id", 34L);
+    record.setField("data", "test");
+
+    Schema reordered =
+        new Schema(
+            Types.NestedField.optional(1, "data", Types.StringType.get()),
+            Types.NestedField.required(0, "id", Types.LongType.get()));
+
+    Record projected = writeAndRead("reordered_full_projection", schema, reordered, record);
+
+    Assert.assertEquals("Should contain the correct 0 value", "test", projected.get(0).toString());
+    Assert.assertEquals("Should contain the correct 1 value", 34L, projected.get(1));
+  }
+
+  @Test
+  public void testReorderedProjection() throws Exception {
+    // Assume.assumeTrue(
+    //    "Spark's Parquet read support does not support reordered columns",
+    //    !format.equalsIgnoreCase("parquet"));
+
+    Schema schema =
+        new Schema(
+            Types.NestedField.required(0, "id", Types.LongType.get()),
+            Types.NestedField.optional(1, "data", Types.StringType.get()));
+
+    Record record = GenericRecord.create(schema);
+    record.setField("id", 34L);
+    record.setField("data", "test");
+
+    Schema reordered =
+        new Schema(
+            Types.NestedField.optional(2, "missing_1", Types.StringType.get()),
+            Types.NestedField.optional(1, "data", Types.StringType.get()),
+            Types.NestedField.optional(3, "missing_2", Types.LongType.get()));
+
+    Record projected = writeAndRead("reordered_projection", schema, reordered, record);
+
+    Assert.assertNull("Should contain the correct 0 value", projected.get(0));
+    Assert.assertEquals("Should contain the correct 1 value", "test", projected.get(1).toString());
+    Assert.assertNull("Should contain the correct 2 value", projected.get(2));
+  }
+
+  @Test
+  public void testEmptyProjection() throws Exception {
+    Schema schema =
+        new Schema(
+            Types.NestedField.required(0, "id", Types.LongType.get()),
+            Types.NestedField.optional(1, "data", Types.StringType.get()));
+
+    Record record = GenericRecord.create(schema);
+    record.setField("id", 34L);
+    record.setField("data", "test");
+
+    Record projected = writeAndRead("empty_projection", schema, schema.select(), record);
+
+    Assert.assertNotNull("Should read a non-null record", projected);
+    // this is expected because there are no values
+    Assertions.assertThatThrownBy(() -> projected.get(0))
+        .isInstanceOf(ArrayIndexOutOfBoundsException.class);
+  }
+
+  @Test
+  public void testBasicProjection() throws Exception {
+    Schema writeSchema =
+        new Schema(
+            Types.NestedField.required(0, "id", Types.LongType.get()),
+            Types.NestedField.optional(1, "data", Types.StringType.get()));
+
+    Record record = GenericRecord.create(writeSchema);
+    record.setField("id", 34L);
+    record.setField("data", "test");
+
+    Schema idOnly = new Schema(Types.NestedField.required(0, "id", Types.LongType.get()));
+
+    Record projected = writeAndRead("basic_projection_id", writeSchema, idOnly, record);
+    Assert.assertNull("Should not project data", projected.getField("data"));
+    Assert.assertEquals(
+        "Should contain the correct id value", 34L, (long) projected.getField("id"));
+
+    Schema dataOnly = new Schema(Types.NestedField.optional(1, "data", Types.StringType.get()));
+
projected = writeAndRead("basic_projection_data", writeSchema, dataOnly, record); + + Assert.assertNull("Should not project id", projected.getField("id")); + int cmp = + Comparators.charSequences().compare("test", (CharSequence) projected.getField("data")); + Assert.assertEquals("Should contain the correct data value", 0, cmp); + } + + @Test + public void testRename() throws Exception { + Schema writeSchema = + new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "data", Types.StringType.get())); + + Record record = GenericRecord.create(writeSchema); + record.setField("id", 34L); + record.setField("data", "test"); + + Schema readSchema = + new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "renamed", Types.StringType.get())); + + Record projected = writeAndRead("project_and_rename", writeSchema, readSchema, record); + + Assert.assertEquals( + "Should contain the correct id value", 34L, (long) projected.getField("id")); + int cmp = + Comparators.charSequences().compare("test", (CharSequence) projected.getField("renamed")); + Assert.assertEquals("Should contain the correct data/renamed value", 0, cmp); + } + + @Test + public void testNestedStructProjection() throws Exception { + Schema writeSchema = + new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional( + 3, + "location", + Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()), + Types.NestedField.required(2, "long", Types.FloatType.get())))); + + Record record = GenericRecord.create(writeSchema); + record.setField("id", 34L); + Record location = GenericRecord.create(writeSchema.findType("location").asStructType()); + location.setField("lat", 52.995143f); + location.setField("long", -1.539054f); + record.setField("location", location); + + Schema idOnly = new Schema(Types.NestedField.required(0, "id", Types.LongType.get())); + + Record projected = writeAndRead("id_only", writeSchema, idOnly, record); + Record projectedLocation = (Record) projected.getField("location"); + Assert.assertEquals( + "Should contain the correct id value", 34L, (long) projected.getField("id")); + Assert.assertNull("Should not project location", projectedLocation); + + Schema latOnly = + new Schema( + Types.NestedField.optional( + 3, + "location", + Types.StructType.of(Types.NestedField.required(1, "lat", Types.FloatType.get())))); + + projected = writeAndRead("latitude_only", writeSchema, latOnly, record); + projectedLocation = (Record) projected.getField("location"); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertNotNull("Should project location", projected.getField("location")); + Assert.assertNull("Should not project longitude", projectedLocation.getField("long")); + Assert.assertEquals( + "Should project latitude", + 52.995143f, + (float) projectedLocation.getField("lat"), + 0.000001f); + + Schema longOnly = + new Schema( + Types.NestedField.optional( + 3, + "location", + Types.StructType.of(Types.NestedField.required(2, "long", Types.FloatType.get())))); + + projected = writeAndRead("longitude_only", writeSchema, longOnly, record); + projectedLocation = (Record) projected.getField("location"); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertNotNull("Should project location", projected.getField("location")); + Assert.assertNull("Should not project latitutde", projectedLocation.getField("lat")); + 
Assert.assertEquals( + "Should project longitude", + -1.539054f, + (float) projectedLocation.getField("long"), + 0.000001f); + + Schema locationOnly = writeSchema.select("location"); + projected = writeAndRead("location_only", writeSchema, locationOnly, record); + projectedLocation = (Record) projected.getField("location"); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertNotNull("Should project location", projected.getField("location")); + Assert.assertEquals( + "Should project latitude", + 52.995143f, + (float) projectedLocation.getField("lat"), + 0.000001f); + Assert.assertEquals( + "Should project longitude", + -1.539054f, + (float) projectedLocation.getField("long"), + 0.000001f); + } + + @Test + public void testMapProjection() throws IOException { + Schema writeSchema = + new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional( + 5, + "properties", + Types.MapType.ofOptional(6, 7, Types.StringType.get(), Types.StringType.get()))); + + Map properties = ImmutableMap.of("a", "A", "b", "B"); + + Record record = GenericRecord.create(writeSchema); + record.setField("id", 34L); + record.setField("properties", properties); + + Schema idOnly = new Schema(Types.NestedField.required(0, "id", Types.LongType.get())); + + Record projected = writeAndRead("id_only", writeSchema, idOnly, record); + Assert.assertEquals( + "Should contain the correct id value", 34L, (long) projected.getField("id")); + Assert.assertNull("Should not project properties map", projected.getField("properties")); + + Schema keyOnly = writeSchema.select("properties.key"); + projected = writeAndRead("key_only", writeSchema, keyOnly, record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertEquals( + "Should project entire map", + properties, + toStringMap((Map) projected.getField("properties"))); + + Schema valueOnly = writeSchema.select("properties.value"); + projected = writeAndRead("value_only", writeSchema, valueOnly, record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertEquals( + "Should project entire map", + properties, + toStringMap((Map) projected.getField("properties"))); + + Schema mapOnly = writeSchema.select("properties"); + projected = writeAndRead("map_only", writeSchema, mapOnly, record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertEquals( + "Should project entire map", + properties, + toStringMap((Map) projected.getField("properties"))); + } + + private Map toStringMap(Map map) { + Map stringMap = Maps.newHashMap(); + for (Map.Entry entry : map.entrySet()) { + if (entry.getValue() instanceof CharSequence) { + stringMap.put(entry.getKey().toString(), entry.getValue().toString()); + } else { + stringMap.put(entry.getKey().toString(), entry.getValue()); + } + } + return stringMap; + } + + @Test + public void testMapOfStructsProjection() throws IOException { + Schema writeSchema = + new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional( + 5, + "locations", + Types.MapType.ofOptional( + 6, + 7, + Types.StringType.get(), + Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()), + Types.NestedField.required(2, "long", Types.FloatType.get()))))); + + Record record = GenericRecord.create(writeSchema); + record.setField("id", 34L); + Record l1 = GenericRecord.create(writeSchema.findType("locations.value").asStructType()); + l1.setField("lat", 
53.992811f); + l1.setField("long", -1.542616f); + Record l2 = GenericRecord.create(l1.struct()); + l2.setField("lat", 52.995143f); + l2.setField("long", -1.539054f); + record.setField("locations", ImmutableMap.of("L1", l1, "L2", l2)); + + Schema idOnly = new Schema(Types.NestedField.required(0, "id", Types.LongType.get())); + + Record projected = writeAndRead("id_only", writeSchema, idOnly, record); + Assert.assertEquals( + "Should contain the correct id value", 34L, (long) projected.getField("id")); + Assert.assertNull("Should not project locations map", projected.getField("locations")); + + projected = writeAndRead("all_locations", writeSchema, writeSchema.select("locations"), record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertEquals( + "Should project locations map", + record.getField("locations"), + toStringMap((Map) projected.getField("locations"))); + + projected = writeAndRead("lat_only", writeSchema, writeSchema.select("locations.lat"), record); + Assert.assertNull("Should not project id", projected.getField("id")); + Map locations = toStringMap((Map) projected.getField("locations")); + Assert.assertNotNull("Should project locations map", locations); + Assert.assertEquals( + "Should contain L1 and L2", Sets.newHashSet("L1", "L2"), locations.keySet()); + Record projectedL1 = (Record) locations.get("L1"); + Assert.assertNotNull("L1 should not be null", projectedL1); + Assert.assertEquals( + "L1 should contain lat", 53.992811f, (float) projectedL1.getField("lat"), 0.000001); + Assert.assertNull("L1 should not contain long", projectedL1.getField("long")); + Record projectedL2 = (Record) locations.get("L2"); + Assert.assertNotNull("L2 should not be null", projectedL2); + Assert.assertEquals( + "L2 should contain lat", 52.995143f, (float) projectedL2.getField("lat"), 0.000001); + Assert.assertNull("L2 should not contain long", projectedL2.getField("long")); + + projected = + writeAndRead("long_only", writeSchema, writeSchema.select("locations.long"), record); + Assert.assertNull("Should not project id", projected.getField("id")); + locations = toStringMap((Map) projected.getField("locations")); + Assert.assertNotNull("Should project locations map", locations); + Assert.assertEquals( + "Should contain L1 and L2", Sets.newHashSet("L1", "L2"), locations.keySet()); + projectedL1 = (Record) locations.get("L1"); + Assert.assertNotNull("L1 should not be null", projectedL1); + Assert.assertNull("L1 should not contain lat", projectedL1.getField("lat")); + Assert.assertEquals( + "L1 should contain long", -1.542616f, (float) projectedL1.getField("long"), 0.000001); + projectedL2 = (Record) locations.get("L2"); + Assert.assertNotNull("L2 should not be null", projectedL2); + Assert.assertNull("L2 should not contain lat", projectedL2.getField("lat")); + Assert.assertEquals( + "L2 should contain long", -1.539054f, (float) projectedL2.getField("long"), 0.000001); + + Schema latitiudeRenamed = + new Schema( + Types.NestedField.optional( + 5, + "locations", + Types.MapType.ofOptional( + 6, + 7, + Types.StringType.get(), + Types.StructType.of( + Types.NestedField.required(1, "latitude", Types.FloatType.get()))))); + + projected = writeAndRead("latitude_renamed", writeSchema, latitiudeRenamed, record); + Assert.assertNull("Should not project id", projected.getField("id")); + locations = toStringMap((Map) projected.getField("locations")); + Assert.assertNotNull("Should project locations map", locations); + Assert.assertEquals( + "Should contain L1 and L2", 
Sets.newHashSet("L1", "L2"), locations.keySet()); + projectedL1 = (Record) locations.get("L1"); + Assert.assertNotNull("L1 should not be null", projectedL1); + Assert.assertEquals( + "L1 should contain latitude", + 53.992811f, + (float) projectedL1.getField("latitude"), + 0.000001); + Assert.assertNull("L1 should not contain lat", projectedL1.getField("lat")); + Assert.assertNull("L1 should not contain long", projectedL1.getField("long")); + projectedL2 = (Record) locations.get("L2"); + Assert.assertNotNull("L2 should not be null", projectedL2); + Assert.assertEquals( + "L2 should contain latitude", + 52.995143f, + (float) projectedL2.getField("latitude"), + 0.000001); + Assert.assertNull("L2 should not contain lat", projectedL2.getField("lat")); + Assert.assertNull("L2 should not contain long", projectedL2.getField("long")); + } + + @Test + public void testListProjection() throws IOException { + Schema writeSchema = + new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional( + 10, "values", Types.ListType.ofOptional(11, Types.LongType.get()))); + + List values = ImmutableList.of(56L, 57L, 58L); + + Record record = GenericRecord.create(writeSchema); + record.setField("id", 34L); + record.setField("values", values); + + Schema idOnly = new Schema(Types.NestedField.required(0, "id", Types.LongType.get())); + + Record projected = writeAndRead("id_only", writeSchema, idOnly, record); + Assert.assertEquals( + "Should contain the correct id value", 34L, (long) projected.getField("id")); + Assert.assertNull("Should not project values list", projected.getField("values")); + + Schema elementOnly = writeSchema.select("values.element"); + projected = writeAndRead("element_only", writeSchema, elementOnly, record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertEquals("Should project entire list", values, projected.getField("values")); + + Schema listOnly = writeSchema.select("values"); + projected = writeAndRead("list_only", writeSchema, listOnly, record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertEquals("Should project entire list", values, projected.getField("values")); + } + + @Test + @SuppressWarnings("unchecked") + public void testListOfStructsProjection() throws IOException { + Schema writeSchema = + new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional( + 22, + "points", + Types.ListType.ofOptional( + 21, + Types.StructType.of( + Types.NestedField.required(19, "x", Types.IntegerType.get()), + Types.NestedField.optional(18, "y", Types.IntegerType.get()))))); + + Record record = GenericRecord.create(writeSchema); + record.setField("id", 34L); + Record p1 = GenericRecord.create(writeSchema.findType("points.element").asStructType()); + p1.setField("x", 1); + p1.setField("y", 2); + Record p2 = GenericRecord.create(p1.struct()); + p2.setField("x", 3); + p2.setField("y", null); + record.setField("points", ImmutableList.of(p1, p2)); + + Schema idOnly = new Schema(Types.NestedField.required(0, "id", Types.LongType.get())); + + Record projected = writeAndRead("id_only", writeSchema, idOnly, record); + Assert.assertEquals( + "Should contain the correct id value", 34L, (long) projected.getField("id")); + Assert.assertNull("Should not project points list", projected.getField("points")); + + projected = writeAndRead("all_points", writeSchema, writeSchema.select("points"), record); + Assert.assertNull("Should not project id", 
projected.getField("id")); + Assert.assertEquals( + "Should project points list", record.getField("points"), projected.getField("points")); + + projected = writeAndRead("x_only", writeSchema, writeSchema.select("points.x"), record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertNotNull("Should project points list", projected.getField("points")); + List points = (List) projected.getField("points"); + Assert.assertEquals("Should read 2 points", 2, points.size()); + Record projectedP1 = points.get(0); + Assert.assertEquals("Should project x", 1, (int) projectedP1.getField("x")); + Assert.assertNull("Should not project y", projectedP1.getField("y")); + Record projectedP2 = points.get(1); + Assert.assertEquals("Should project x", 3, (int) projectedP2.getField("x")); + Assert.assertNull("Should not project y", projectedP2.getField("y")); + + projected = writeAndRead("y_only", writeSchema, writeSchema.select("points.y"), record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertNotNull("Should project points list", projected.getField("points")); + points = (List) projected.getField("points"); + Assert.assertEquals("Should read 2 points", 2, points.size()); + projectedP1 = points.get(0); + Assert.assertNull("Should not project x", projectedP1.getField("x")); + Assert.assertEquals("Should project y", 2, (int) projectedP1.getField("y")); + projectedP2 = points.get(1); + Assert.assertNull("Should not project x", projectedP2.getField("x")); + Assert.assertNull("Should project null y", projectedP2.getField("y")); + + Schema yRenamed = + new Schema( + Types.NestedField.optional( + 22, + "points", + Types.ListType.ofOptional( + 21, + Types.StructType.of( + Types.NestedField.optional(18, "z", Types.IntegerType.get()))))); + + projected = writeAndRead("y_renamed", writeSchema, yRenamed, record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertNotNull("Should project points list", projected.getField("points")); + points = (List) projected.getField("points"); + Assert.assertEquals("Should read 2 points", 2, points.size()); + projectedP1 = points.get(0); + Assert.assertNull("Should not project x", projectedP1.getField("x")); + Assert.assertNull("Should not project y", projectedP1.getField("y")); + Assert.assertEquals("Should project z", 2, (int) projectedP1.getField("z")); + projectedP2 = points.get(1); + Assert.assertNull("Should not project x", projectedP2.getField("x")); + Assert.assertNull("Should not project y", projectedP2.getField("y")); + Assert.assertNull("Should project null z", projectedP2.getField("z")); + + Schema zAdded = + new Schema( + Types.NestedField.optional( + 22, + "points", + Types.ListType.ofOptional( + 21, + Types.StructType.of( + Types.NestedField.required(19, "x", Types.IntegerType.get()), + Types.NestedField.optional(18, "y", Types.IntegerType.get()), + Types.NestedField.optional(20, "z", Types.IntegerType.get()))))); + + projected = writeAndRead("z_added", writeSchema, zAdded, record); + Assert.assertNull("Should not project id", projected.getField("id")); + Assert.assertNotNull("Should project points list", projected.getField("points")); + points = (List) projected.getField("points"); + Assert.assertEquals("Should read 2 points", 2, points.size()); + projectedP1 = points.get(0); + Assert.assertEquals("Should project x", 1, (int) projectedP1.getField("x")); + Assert.assertEquals("Should project y", 2, (int) projectedP1.getField("y")); + Assert.assertNull("Should 
contain null z", projectedP1.getField("z")); + projectedP2 = points.get(1); + Assert.assertEquals("Should project x", 3, (int) projectedP2.getField("x")); + Assert.assertNull("Should project null y", projectedP2.getField("y")); + Assert.assertNull("Should contain null z", projectedP2.getField("z")); + } + + private static org.apache.avro.Schema fromOption(org.apache.avro.Schema schema) { + Preconditions.checkArgument( + schema.getType() == UNION, "Expected union schema but was passed: %s", schema); + Preconditions.checkArgument( + schema.getTypes().size() == 2, "Expected optional schema, but was passed: %s", schema); + if (schema.getTypes().get(0).getType() == org.apache.avro.Schema.Type.NULL) { + return schema.getTypes().get(1); + } else { + return schema.getTypes().get(0); + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestRequiredDistributionAndOrdering.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestRequiredDistributionAndOrdering.java new file mode 100644 index 000000000000..521d90299d2b --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestRequiredDistributionAndOrdering.java @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.After; +import org.junit.Test; + +public class TestRequiredDistributionAndOrdering extends SparkCatalogTestBase { + + public TestRequiredDistributionAndOrdering( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void dropTestTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testDefaultLocalSort() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (c3)", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should insert a local sort by partition columns by default + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @Test + public void testPartitionColumnsArePrependedForRangeDistribution() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (c3)", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + Table table = validationCatalog.loadTable(tableIdent); + + // should automatically prepend partition columns to the ordering + table + .updateProperties() + .set(TableProperties.WRITE_DISTRIBUTION_MODE, TableProperties.WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + table.replaceSortOrder().asc("c1").asc("c2").commit(); + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @Test + public void testSortOrderIncludesPartitionColumns() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (c3)", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, 
"BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + Table table = validationCatalog.loadTable(tableIdent); + + // should succeed with a correct sort order + table.replaceSortOrder().asc("c3").asc("c1").asc("c2").commit(); + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @Test + public void testDisabledDistributionAndOrdering() { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (c3)", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should fail if ordering is disabled + AssertHelpers.assertThrowsCause( + "Should reject writes without ordering", + IllegalStateException.class, + "Incoming records violate the writer assumption", + () -> { + try { + inputDF + .writeTo(tableName) + .option(SparkWriteOptions.USE_TABLE_DISTRIBUTION_AND_ORDERING, "false") + .append(); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + public void testHashDistribution() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (c3)", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + Table table = validationCatalog.loadTable(tableIdent); + + // should automatically prepend partition columns to the local ordering after hash distribution + table + .updateProperties() + .set(TableProperties.WRITE_DISTRIBUTION_MODE, TableProperties.WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + table.replaceSortOrder().asc("c1").asc("c2").commit(); + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @Test + public void testNoSortBucketTransformsWithoutExtensions() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(2, c1))", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBB", "B"), + new ThreeColumnRecord(3, "BBBB", "B"), + new ThreeColumnRecord(4, "BBBB", "B")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c1"); + + // should fail by default as extensions are disabled + AssertHelpers.assertThrowsCause( + "Should reject writes without ordering", + 
IllegalStateException.class, + "Incoming records violate the writer assumption", + () -> { + try { + inputDF.writeTo(tableName).append(); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + + inputDF.writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + List expected = + ImmutableList.of( + row(1, null, "A"), row(2, "BBBB", "B"), row(3, "BBBB", "B"), row(4, "BBBB", "B")); + + assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName)); + } + + @Test + public void testRangeDistributionWithQuotedColumnsNames() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, `c.3` STRING) " + + "USING iceberg " + + "PARTITIONED BY (`c.3`)", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = + ds.selectExpr("c1", "c2", "c3 as `c.3`").coalesce(1).sortWithinPartitions("c1"); + + Table table = validationCatalog.loadTable(tableIdent); + + table + .updateProperties() + .set(TableProperties.WRITE_DISTRIBUTION_MODE, TableProperties.WRITE_DISTRIBUTION_MODE_RANGE) + .commit(); + table.replaceSortOrder().asc("c1").asc("c2").commit(); + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } + + @Test + public void testHashDistributionWithQuotedColumnsNames() throws NoSuchTableException { + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, `c``3` STRING) " + + "USING iceberg " + + "PARTITIONED BY (`c``3`)", + tableName); + + List data = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); + Dataset ds = spark.createDataFrame(data, ThreeColumnRecord.class); + Dataset inputDF = + ds.selectExpr("c1", "c2", "c3 as `c``3`").coalesce(1).sortWithinPartitions("c1"); + + Table table = validationCatalog.loadTable(tableIdent); + + table + .updateProperties() + .set(TableProperties.WRITE_DISTRIBUTION_MODE, TableProperties.WRITE_DISTRIBUTION_MODE_HASH) + .commit(); + table.replaceSortOrder().asc("c1").asc("c2").commit(); + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestRuntimeFiltering.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestRuntimeFiltering.java new file mode 100644 index 000000000000..beaf7b75c6c0 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestRuntimeFiltering.java @@ -0,0 +1,472 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; +import java.util.Set; +import org.apache.commons.lang3.StringUtils; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkTestBaseWithCatalog; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestRuntimeFiltering extends SparkTestBaseWithCatalog { + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS dim"); + } + + @Test + public void testIdentityPartitionedTable() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (date)", + tableName); + + Dataset df = + spark + .range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = + spark.range(1, 10).withColumn("date", expr("DATE '1970-01-02'")).select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = + String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.date = d.date AND d.id = 1 ORDER BY id", + tableName); + + assertQueryContainsRuntimeFilter(query); + + deleteNotMatchingFiles(Expressions.equal("date", 1), 3); + + assertEquals( + "Should have expected rows", + sql("SELECT * FROM %s WHERE date = DATE '1970-01-02' ORDER BY id", tableName), + sql(query)); + } + + @Test + public void testBucketedTable() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (bucket(8, id))", + tableName); + + Dataset df = + spark + .range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + 
df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = + spark.range(1, 2).withColumn("date", expr("DATE '1970-01-02'")).select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = + String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.id = d.id AND d.date = DATE '1970-01-02' ORDER BY date", + tableName); + + assertQueryContainsRuntimeFilter(query); + + deleteNotMatchingFiles(Expressions.equal("id", 1), 7); + + assertEquals( + "Should have expected rows", + sql("SELECT * FROM %s WHERE id = 1 ORDER BY date", tableName), + sql(query)); + } + + @Test + public void testRenamedSourceColumnTable() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (bucket(8, id))", + tableName); + + Dataset df = + spark + .range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = + spark.range(1, 2).withColumn("date", expr("DATE '1970-01-02'")).select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + sql("ALTER TABLE %s RENAME COLUMN id TO row_id", tableName); + + String query = + String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.row_id = d.id AND d.date = DATE '1970-01-02' ORDER BY date", + tableName); + + assertQueryContainsRuntimeFilter(query); + + deleteNotMatchingFiles(Expressions.equal("row_id", 1), 7); + + assertEquals( + "Should have expected rows", + sql("SELECT * FROM %s WHERE row_id = 1 ORDER BY date", tableName), + sql(query)); + } + + @Test + public void testMultipleRuntimeFilters() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (data, bucket(8, id))", + tableName); + + Dataset df = + spark + .range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE, data STRING) USING parquet"); + Dataset dimDF = + spark + .range(1, 2) + .withColumn("date", expr("DATE '1970-01-02'")) + .withColumn("data", expr("'1970-01-02'")) + .select("id", "date", "data"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = + String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.id = d.id AND f.data = d.data AND d.date = DATE '1970-01-02'", + tableName); + + assertQueryContainsRuntimeFilters(query, 2, "Query should have 2 runtime filters"); + + deleteNotMatchingFiles(Expressions.equal("id", 1), 31); + + assertEquals( + "Should have expected rows", + sql("SELECT * FROM %s WHERE id = 1 AND data = '1970-01-02'", tableName), + sql(query)); + } + + @Test + public void testCaseSensitivityOfRuntimeFilters() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg 
" + + "PARTITIONED BY (data, bucket(8, id))", + tableName); + + Dataset df = + spark + .range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE, data STRING) USING parquet"); + Dataset dimDF = + spark + .range(1, 2) + .withColumn("date", expr("DATE '1970-01-02'")) + .withColumn("data", expr("'1970-01-02'")) + .select("id", "date", "data"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String caseInsensitiveQuery = + String.format( + "select f.* from %s F join dim d ON f.Id = d.iD and f.DaTa = d.dAtA and d.dAtE = date '1970-01-02'", + tableName); + + assertQueryContainsRuntimeFilters( + caseInsensitiveQuery, 2, "Query should have 2 runtime filters"); + + deleteNotMatchingFiles(Expressions.equal("id", 1), 31); + + assertEquals( + "Should have expected rows", + sql("SELECT * FROM %s WHERE id = 1 AND data = '1970-01-02'", tableName), + sql(caseInsensitiveQuery)); + } + + @Test + public void testBucketedTableWithMultipleSpecs() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) USING iceberg", + tableName); + + Dataset df1 = + spark + .range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 2 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df1.coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + table.updateSpec().addField(Expressions.bucket("id", 8)).commit(); + + sql("REFRESH TABLE %s", tableName); + + Dataset df2 = + spark + .range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df2.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = + spark.range(1, 2).withColumn("date", expr("DATE '1970-01-02'")).select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = + String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.id = d.id AND d.date = DATE '1970-01-02' ORDER BY date", + tableName); + + assertQueryContainsRuntimeFilter(query); + + deleteNotMatchingFiles(Expressions.equal("id", 1), 7); + + assertEquals( + "Should have expected rows", + sql("SELECT * FROM %s WHERE id = 1 ORDER BY date", tableName), + sql(query)); + } + + @Test + public void testSourceColumnWithDots() throws NoSuchTableException { + sql( + "CREATE TABLE %s (`i.d` BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (bucket(8, `i.d`))", + tableName); + + Dataset df = + spark + .range(1, 100) + .withColumnRenamed("id", "i.d") + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(`i.d` % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("`i.d`", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + 
sql("SELECT * FROM %s WHERE `i.d` = 1", tableName); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = + spark.range(1, 2).withColumn("date", expr("DATE '1970-01-02'")).select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = + String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.`i.d` = d.id AND d.date = DATE '1970-01-02' ORDER BY date", + tableName); + + assertQueryContainsRuntimeFilter(query); + + deleteNotMatchingFiles(Expressions.equal("i.d", 1), 7); + + sql(query); + + assertEquals( + "Should have expected rows", + sql("SELECT * FROM %s WHERE `i.d` = 1 ORDER BY date", tableName), + sql(query)); + } + + @Test + public void testSourceColumnWithBackticks() throws NoSuchTableException { + sql( + "CREATE TABLE %s (`i``d` BIGINT, data STRING, date DATE, ts TIMESTAMP) " + + "USING iceberg " + + "PARTITIONED BY (bucket(8, `i``d`))", + tableName); + + Dataset df = + spark + .range(1, 100) + .withColumnRenamed("id", "i`d") + .withColumn( + "date", date_add(expr("DATE '1970-01-01'"), expr("CAST(`i``d` % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("`i``d`", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = + spark.range(1, 2).withColumn("date", expr("DATE '1970-01-02'")).select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = + String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.`i``d` = d.id AND d.date = DATE '1970-01-02' ORDER BY date", + tableName); + + assertQueryContainsRuntimeFilter(query); + + deleteNotMatchingFiles(Expressions.equal("i`d", 1), 7); + + assertEquals( + "Should have expected rows", + sql("SELECT * FROM %s WHERE `i``d` = 1 ORDER BY date", tableName), + sql(query)); + } + + @Test + public void testUnpartitionedTable() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, date DATE, ts TIMESTAMP) USING iceberg", + tableName); + + Dataset df = + spark + .range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("data", expr("CAST(date AS STRING)")) + .select("id", "data", "date", "ts"); + + df.coalesce(1).writeTo(tableName).append(); + + sql("CREATE TABLE dim (id BIGINT, date DATE) USING parquet"); + Dataset dimDF = + spark.range(1, 2).withColumn("date", expr("DATE '1970-01-02'")).select("id", "date"); + dimDF.coalesce(1).write().mode("append").insertInto("dim"); + + String query = + String.format( + "SELECT f.* FROM %s f JOIN dim d ON f.id = d.id AND d.date = DATE '1970-01-02' ORDER BY date", + tableName); + + assertQueryContainsNoRuntimeFilter(query); + + assertEquals( + "Should have expected rows", + sql("SELECT * FROM %s WHERE id = 1 ORDER BY date", tableName), + sql(query)); + } + + private void assertQueryContainsRuntimeFilter(String query) { + assertQueryContainsRuntimeFilters(query, 1, "Query should have 1 runtime filter"); + } + + private void assertQueryContainsNoRuntimeFilter(String query) { + assertQueryContainsRuntimeFilters(query, 0, "Query should have no runtime filters"); + } + + private void assertQueryContainsRuntimeFilters( + String query, int expectedFilterCount, String errorMessage) { + List output = spark.sql("EXPLAIN EXTENDED " + query).collectAsList(); + String plan 
= output.get(0).getString(0); + int actualFilterCount = StringUtils.countMatches(plan, "dynamicpruningexpression"); + Assert.assertEquals(errorMessage, expectedFilterCount, actualFilterCount); + } + + // delete files that don't match the filter to ensure dynamic filtering works and only required + // files are read + private void deleteNotMatchingFiles(Expression filter, int expectedDeletedFileCount) { + Table table = validationCatalog.loadTable(tableIdent); + FileIO io = table.io(); + + Set matchingFileLocations = Sets.newHashSet(); + try (CloseableIterable files = table.newScan().filter(filter).planFiles()) { + for (FileScanTask file : files) { + String path = file.file().path().toString(); + matchingFileLocations.add(path); + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + + Set deletedFileLocations = Sets.newHashSet(); + try (CloseableIterable files = table.newScan().planFiles()) { + for (FileScanTask file : files) { + String path = file.file().path().toString(); + if (!matchingFileLocations.contains(path)) { + io.deleteFile(path); + deletedFileLocations.add(path); + } + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + + Assert.assertEquals( + "Deleted unexpected number of files", + expectedDeletedFileCount, + deletedFileLocations.size()); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSnapshotSelection.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSnapshotSelection.java new file mode 100644 index 000000000000..276fbcd592ae --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSnapshotSelection.java @@ -0,0 +1,457 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.optional; + +import java.io.IOException; +import java.util.List; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.assertj.core.api.Assertions; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestSnapshotSelection { + + private static final Configuration CONF = new Configuration(); + private static final Schema SCHEMA = + new Schema( + optional(1, "id", Types.IntegerType.get()), optional(2, "data", Types.StringType.get())); + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private static SparkSession spark = null; + + @BeforeClass + public static void startSpark() { + TestSnapshotSelection.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestSnapshotSelection.spark; + TestSnapshotSelection.spark = null; + currentSpark.stop(); + } + + @Test + public void testSnapshotSelectionById() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, tableLocation); + + // produce the first snapshot + List firstBatchRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + // produce the second snapshot + List secondBatchRecords = + Lists.newArrayList( + new SimpleRecord(4, "d"), new SimpleRecord(5, "e"), new SimpleRecord(6, "f")); + Dataset secondDf = spark.createDataFrame(secondBatchRecords, SimpleRecord.class); + secondDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + Assert.assertEquals("Expected 2 snapshots", 2, Iterables.size(table.snapshots())); + + // verify records in the current snapshot + Dataset currentSnapshotResult = spark.read().format("iceberg").load(tableLocation); + List currentSnapshotRecords = + currentSnapshotResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(firstBatchRecords); + expectedRecords.addAll(secondBatchRecords); + Assert.assertEquals( + "Current snapshot rows should match", expectedRecords, currentSnapshotRecords); + + // verify records in the previous snapshot + Snapshot currentSnapshot = table.currentSnapshot(); + Long parentSnapshotId = currentSnapshot.parentId(); + Dataset previousSnapshotResult = + spark.read().format("iceberg").option("snapshot-id", 
parentSnapshotId).load(tableLocation); + List previousSnapshotRecords = + previousSnapshotResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals( + "Previous snapshot rows should match", firstBatchRecords, previousSnapshotRecords); + } + + @Test + public void testSnapshotSelectionByTimestamp() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, tableLocation); + + // produce the first snapshot + List firstBatchRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + // remember the time when the first snapshot was valid + long firstSnapshotTimestamp = System.currentTimeMillis(); + + // produce the second snapshot + List secondBatchRecords = + Lists.newArrayList( + new SimpleRecord(4, "d"), new SimpleRecord(5, "e"), new SimpleRecord(6, "f")); + Dataset secondDf = spark.createDataFrame(secondBatchRecords, SimpleRecord.class); + secondDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + Assert.assertEquals("Expected 2 snapshots", 2, Iterables.size(table.snapshots())); + + // verify records in the current snapshot + Dataset currentSnapshotResult = spark.read().format("iceberg").load(tableLocation); + List currentSnapshotRecords = + currentSnapshotResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(firstBatchRecords); + expectedRecords.addAll(secondBatchRecords); + Assert.assertEquals( + "Current snapshot rows should match", expectedRecords, currentSnapshotRecords); + + // verify records in the previous snapshot + Dataset previousSnapshotResult = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.AS_OF_TIMESTAMP, firstSnapshotTimestamp) + .load(tableLocation); + List previousSnapshotRecords = + previousSnapshotResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals( + "Previous snapshot rows should match", firstBatchRecords, previousSnapshotRecords); + } + + @Test + public void testSnapshotSelectionByInvalidSnapshotId() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + tables.create(SCHEMA, spec, tableLocation); + + Dataset df = spark.read().format("iceberg").option("snapshot-id", -10).load(tableLocation); + + Assertions.assertThatThrownBy(df::collectAsList) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot find snapshot with ID -10"); + } + + @Test + public void testSnapshotSelectionByInvalidTimestamp() throws IOException { + long timestamp = System.currentTimeMillis(); + + String tableLocation = temp.newFolder("iceberg-table").toString(); + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + tables.create(SCHEMA, spec, tableLocation); + + Assertions.assertThatThrownBy( + () -> + spark + .read() + .format("iceberg") + .option(SparkReadOptions.AS_OF_TIMESTAMP, timestamp) + .load(tableLocation)) + .isInstanceOf(IllegalArgumentException.class) 
+ .hasMessageContaining("Cannot find a snapshot older than"); + } + + @Test + public void testSnapshotSelectionBySnapshotIdAndTimestamp() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, tableLocation); + + List firstBatchRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + long timestamp = System.currentTimeMillis(); + long snapshotId = table.currentSnapshot().snapshotId(); + + Assertions.assertThatThrownBy( + () -> + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SNAPSHOT_ID, snapshotId) + .option(SparkReadOptions.AS_OF_TIMESTAMP, timestamp) + .load(tableLocation)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Can specify only one of snapshot-id") + .hasMessageContaining("as-of-timestamp") + .hasMessageContaining("branch") + .hasMessageContaining("tag"); + } + + @Test + public void testSnapshotSelectionByTag() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, tableLocation); + + // produce the first snapshot + List firstBatchRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + table.manageSnapshots().createTag("tag", table.currentSnapshot().snapshotId()).commit(); + + // produce the second snapshot + List secondBatchRecords = + Lists.newArrayList( + new SimpleRecord(4, "d"), new SimpleRecord(5, "e"), new SimpleRecord(6, "f")); + Dataset secondDf = spark.createDataFrame(secondBatchRecords, SimpleRecord.class); + secondDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + // verify records in the current snapshot by tag + Dataset currentSnapshotResult = + spark.read().format("iceberg").option("tag", "tag").load(tableLocation); + List currentSnapshotRecords = + currentSnapshotResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(firstBatchRecords); + Assert.assertEquals( + "Current snapshot rows should match", expectedRecords, currentSnapshotRecords); + } + + @Test + public void testSnapshotSelectionByBranch() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, tableLocation); + + // produce the first snapshot + List firstBatchRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + table.manageSnapshots().createBranch("branch", 
table.currentSnapshot().snapshotId()).commit(); + + // produce the second snapshot + List secondBatchRecords = + Lists.newArrayList( + new SimpleRecord(4, "d"), new SimpleRecord(5, "e"), new SimpleRecord(6, "f")); + Dataset secondDf = spark.createDataFrame(secondBatchRecords, SimpleRecord.class); + secondDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + // verify records in the current snapshot by branch + Dataset currentSnapshotResult = + spark.read().format("iceberg").option("branch", "branch").load(tableLocation); + List currentSnapshotRecords = + currentSnapshotResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(firstBatchRecords); + Assert.assertEquals( + "Current snapshot rows should match", expectedRecords, currentSnapshotRecords); + } + + @Test + public void testSnapshotSelectionByBranchAndTagFails() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, tableLocation); + + // produce the first snapshot + List firstBatchRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + table.manageSnapshots().createBranch("branch", table.currentSnapshot().snapshotId()).commit(); + table.manageSnapshots().createTag("tag", table.currentSnapshot().snapshotId()).commit(); + + Assertions.assertThatThrownBy( + () -> + spark + .read() + .format("iceberg") + .option(SparkReadOptions.TAG, "tag") + .option(SparkReadOptions.BRANCH, "branch") + .load(tableLocation) + .show()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("Can specify only one of snapshot-id"); + } + + @Test + public void testSnapshotSelectionByTimestampAndBranchOrTagFails() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, tableLocation); + + List firstBatchRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + long timestamp = System.currentTimeMillis(); + table.manageSnapshots().createBranch("branch", table.currentSnapshot().snapshotId()).commit(); + table.manageSnapshots().createTag("tag", table.currentSnapshot().snapshotId()).commit(); + + Assertions.assertThatThrownBy( + () -> + spark + .read() + .format("iceberg") + .option(SparkReadOptions.AS_OF_TIMESTAMP, timestamp) + .option(SparkReadOptions.BRANCH, "branch") + .load(tableLocation) + .show()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("Can specify only one of snapshot-id"); + + Assertions.assertThatThrownBy( + () -> + spark + .read() + .format("iceberg") + .option(SparkReadOptions.AS_OF_TIMESTAMP, timestamp) + .option(SparkReadOptions.TAG, "tag") + .load(tableLocation) + .show()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("Can 
specify only one of snapshot-id"); + } + + @Test + public void testSnapshotSelectionByBranchWithSchemaChange() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, tableLocation); + + // produce the first snapshot + List firstBatchRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + table.manageSnapshots().createBranch("branch", table.currentSnapshot().snapshotId()).commit(); + + Dataset branchSnapshotResult = + spark.read().format("iceberg").option("branch", "branch").load(tableLocation); + List branchSnapshotRecords = + branchSnapshotResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(firstBatchRecords); + Assert.assertEquals( + "Current snapshot rows should match", expectedRecords, branchSnapshotRecords); + + // Deleting a column to indicate schema change + table.updateSchema().deleteColumn("data").commit(); + + // The data should have the deleted column as it was captured in an earlier snapshot. + Dataset deletedColumnBranchSnapshotResult = + spark.read().format("iceberg").option("branch", "branch").load(tableLocation); + List deletedColumnBranchSnapshotRecords = + deletedColumnBranchSnapshotResult + .orderBy("id") + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + Assert.assertEquals( + "Current snapshot rows should match", expectedRecords, deletedColumnBranchSnapshotRecords); + } + + @Test + public void testSnapshotSelectionByTagWithSchemaChange() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, tableLocation); + + // produce the first snapshot + List firstBatchRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class); + firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation); + + table.manageSnapshots().createTag("tag", table.currentSnapshot().snapshotId()).commit(); + + List expectedRecords = Lists.newArrayList(); + expectedRecords.addAll(firstBatchRecords); + + Dataset tagSnapshotResult = + spark.read().format("iceberg").option("tag", "tag").load(tableLocation); + List tagSnapshotRecords = + tagSnapshotResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Current snapshot rows should match", expectedRecords, tagSnapshotRecords); + + // Deleting a column to indicate schema change + table.updateSchema().deleteColumn("data").commit(); + + // The data should have the deleted column as it was captured in an earlier snapshot. 
+ Dataset deletedColumnTagSnapshotResult = + spark.read().format("iceberg").option("tag", "tag").load(tableLocation); + List deletedColumnTagSnapshotRecords = + deletedColumnTagSnapshotResult + .orderBy("id") + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + Assert.assertEquals( + "Current snapshot rows should match", expectedRecords, deletedColumnTagSnapshotRecords); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAggregates.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAggregates.java new file mode 100644 index 000000000000..e2d6f744f5a5 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAggregates.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkAggregates; +import org.apache.spark.sql.connector.expressions.FieldReference; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.aggregate.Count; +import org.apache.spark.sql.connector.expressions.aggregate.CountStar; +import org.apache.spark.sql.connector.expressions.aggregate.Max; +import org.apache.spark.sql.connector.expressions.aggregate.Min; +import org.junit.Assert; +import org.junit.Test; + +public class TestSparkAggregates { + + @Test + public void testAggregates() { + Map attrMap = Maps.newHashMap(); + attrMap.put("id", "id"); + attrMap.put("`i.d`", "i.d"); + attrMap.put("`i``d`", "i`d"); + attrMap.put("`d`.b.`dd```", "d.b.dd`"); + attrMap.put("a.`aa```.c", "a.aa`.c"); + + attrMap.forEach( + (quoted, unquoted) -> { + NamedReference namedReference = FieldReference.apply(quoted); + + Max max = new Max(namedReference); + Expression expectedMax = Expressions.max(unquoted); + Expression actualMax = SparkAggregates.convert(max); + Assert.assertEquals("Max must match", expectedMax.toString(), actualMax.toString()); + + Min min = new Min(namedReference); + Expression expectedMin = Expressions.min(unquoted); + Expression actualMin = SparkAggregates.convert(min); + Assert.assertEquals("Min must match", expectedMin.toString(), actualMin.toString()); + + Count count = new Count(namedReference, false); + Expression expectedCount = Expressions.count(unquoted); + Expression actualCount = SparkAggregates.convert(count); + Assert.assertEquals("Count must match", expectedCount.toString(), actualCount.toString()); + + Count countDistinct = new Count(namedReference, true); + Expression convertedCountDistinct = 
SparkAggregates.convert(countDistinct); + Assert.assertNull("Count Distinct is converted to null", convertedCountDistinct); + + CountStar countStar = new CountStar(); + Expression expectedCountStar = Expressions.countStar(); + Expression actualCountStar = SparkAggregates.convert(countStar); + Assert.assertEquals( + "CountStar must match", expectedCountStar.toString(), actualCountStar.toString()); + }); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAppenderFactory.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAppenderFactory.java new file mode 100644 index 000000000000..3fb2a630fe81 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAppenderFactory.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileAppenderFactory; +import org.apache.iceberg.io.TestAppenderFactory; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.util.ArrayUtil; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +public class TestSparkAppenderFactory extends TestAppenderFactory { + + private final StructType sparkType; + + public TestSparkAppenderFactory(String fileFormat, boolean partitioned) { + super(fileFormat, partitioned); + this.sparkType = SparkSchemaUtil.convert(SCHEMA); + } + + @Override + protected FileAppenderFactory createAppenderFactory( + List equalityFieldIds, Schema eqDeleteSchema, Schema posDeleteRowSchema) { + return SparkAppenderFactory.builderFor(table, table.schema(), sparkType) + .equalityFieldIds(ArrayUtil.toIntArray(equalityFieldIds)) + .eqDeleteRowSchema(eqDeleteSchema) + .posDelRowSchema(posDeleteRowSchema) + .build(); + } + + @Override + protected InternalRow createRow(Integer id, String data) { + InternalRow row = new GenericInternalRow(2); + row.update(0, id); + row.update(1, UTF8String.fromString(data)); + return row; + } + + @Override + protected StructLikeSet expectedRowSet(Iterable rows) { + StructLikeSet set = StructLikeSet.create(table.schema().asStruct()); + for (InternalRow row : rows) { + InternalRowWrapper wrapper = new InternalRowWrapper(sparkType); + set.add(wrapper.wrap(row)); + } + return set; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalog.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalog.java new file mode 100644 index 
000000000000..0c6cad7f369c --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalog.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.connector.catalog.FunctionCatalog; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.SupportsNamespaces; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCatalog; + +public class TestSparkCatalog + extends SparkSessionCatalog { + + private static final Map tableMap = Maps.newHashMap(); + + public static void setTable(Identifier ident, Table table) { + Preconditions.checkArgument( + !tableMap.containsKey(ident), "Cannot set " + ident + ". It is already set"); + tableMap.put(ident, table); + } + + @Override + public Table loadTable(Identifier ident) throws NoSuchTableException { + if (tableMap.containsKey(ident)) { + return tableMap.get(ident); + } + + TableIdentifier tableIdentifier = Spark3Util.identifierToTableIdentifier(ident); + Namespace namespace = tableIdentifier.namespace(); + + TestTables.TestTable table = TestTables.load(tableIdentifier.toString()); + if (table == null && namespace.equals(Namespace.of("default"))) { + table = TestTables.load(tableIdentifier.name()); + } + + return new SparkTable(table, false); + } + + public static void clearTables() { + tableMap.clear(); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalogCacheExpiration.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalogCacheExpiration.java new file mode 100644 index 000000000000..3d668197fd51 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalogCacheExpiration.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.iceberg.CachingCatalog; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.SparkTestBaseWithCatalog; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.assertj.core.api.Assertions; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestSparkCatalogCacheExpiration extends SparkTestBaseWithCatalog { + + private static final String sessionCatalogName = "spark_catalog"; + private static final String sessionCatalogImpl = SparkSessionCatalog.class.getName(); + private static final Map sessionCatalogConfig = + ImmutableMap.of( + "type", + "hadoop", + "default-namespace", + "default", + CatalogProperties.CACHE_ENABLED, + "true", + CatalogProperties.CACHE_EXPIRATION_INTERVAL_MS, + "3000"); + + private static String asSqlConfCatalogKeyFor(String catalog, String configKey) { + // configKey is empty when the catalog's class is being defined + if (configKey.isEmpty()) { + return String.format("spark.sql.catalog.%s", catalog); + } else { + return String.format("spark.sql.catalog.%s.%s", catalog, configKey); + } + } + + // Add more catalogs to the spark session, so we only need to start spark one time for multiple + // different catalog configuration tests. + @BeforeClass + public static void beforeClass() { + // Catalog - expiration_disabled: Catalog with caching on and expiration disabled. + ImmutableMap.of( + "", + "org.apache.iceberg.spark.SparkCatalog", + "type", + "hive", + CatalogProperties.CACHE_ENABLED, + "true", + CatalogProperties.CACHE_EXPIRATION_INTERVAL_MS, + "-1") + .forEach((k, v) -> spark.conf().set(asSqlConfCatalogKeyFor("expiration_disabled", k), v)); + + // Catalog - cache_disabled_implicitly: Catalog that does not cache, as the cache expiration + // interval is 0. 
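+ // Editorial note (illustrative, not part of the original patch): asSqlConfCatalogKeyFor turns
+ // each entry of the map below into a Spark SQL conf key, for example:
+ //
+ //   asSqlConfCatalogKeyFor("cache_disabled_implicitly", "")     -> spark.sql.catalog.cache_disabled_implicitly
+ //   asSqlConfCatalogKeyFor("cache_disabled_implicitly", "type") -> spark.sql.catalog.cache_disabled_implicitly.type
+ //
+ // so the empty-string key is what registers the catalog implementation class itself.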
+ ImmutableMap.of( + "", + "org.apache.iceberg.spark.SparkCatalog", + "type", + "hive", + CatalogProperties.CACHE_ENABLED, + "true", + CatalogProperties.CACHE_EXPIRATION_INTERVAL_MS, + "0") + .forEach( + (k, v) -> spark.conf().set(asSqlConfCatalogKeyFor("cache_disabled_implicitly", k), v)); + } + + public TestSparkCatalogCacheExpiration() { + super(sessionCatalogName, sessionCatalogImpl, sessionCatalogConfig); + } + + @Test + public void testSparkSessionCatalogWithExpirationEnabled() { + SparkSessionCatalog sparkCatalog = sparkSessionCatalog(); + Assertions.assertThat(sparkCatalog) + .extracting("icebergCatalog") + .extracting("cacheEnabled") + .isEqualTo(true); + + Assertions.assertThat(sparkCatalog) + .extracting("icebergCatalog") + .extracting("icebergCatalog") + .isInstanceOfSatisfying( + Catalog.class, + icebergCatalog -> { + Assertions.assertThat(icebergCatalog) + .isExactlyInstanceOf(CachingCatalog.class) + .extracting("expirationIntervalMillis") + .isEqualTo(3000L); + }); + } + + @Test + public void testCacheEnabledAndExpirationDisabled() { + SparkCatalog sparkCatalog = getSparkCatalog("expiration_disabled"); + Assertions.assertThat(sparkCatalog).extracting("cacheEnabled").isEqualTo(true); + + Assertions.assertThat(sparkCatalog) + .extracting("icebergCatalog") + .isInstanceOfSatisfying( + CachingCatalog.class, + icebergCatalog -> { + Assertions.assertThat(icebergCatalog) + .extracting("expirationIntervalMillis") + .isEqualTo(-1L); + }); + } + + @Test + public void testCacheDisabledImplicitly() { + SparkCatalog sparkCatalog = getSparkCatalog("cache_disabled_implicitly"); + Assertions.assertThat(sparkCatalog).extracting("cacheEnabled").isEqualTo(false); + + Assertions.assertThat(sparkCatalog) + .extracting("icebergCatalog") + .isInstanceOfSatisfying( + Catalog.class, + icebergCatalog -> + Assertions.assertThat(icebergCatalog).isNotInstanceOf(CachingCatalog.class)); + } + + private SparkSessionCatalog sparkSessionCatalog() { + TableCatalog catalog = + (TableCatalog) spark.sessionState().catalogManager().catalog("spark_catalog"); + return (SparkSessionCatalog) catalog; + } + + private SparkCatalog getSparkCatalog(String catalog) { + return (SparkCatalog) spark.sessionState().catalogManager().catalog(catalog); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalogHadoopOverrides.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalogHadoopOverrides.java new file mode 100644 index 000000000000..607f1d45ba3a --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalogHadoopOverrides.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.hadoop.conf.Configurable; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.KryoHelpers; +import org.apache.iceberg.Table; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runners.Parameterized; + +public class TestSparkCatalogHadoopOverrides extends SparkCatalogTestBase { + + private static final String configToOverride = "fs.s3a.buffer.dir"; + // prepend "hadoop." so that the test base formats SQLConf correctly + // as `spark.sql.catalogs..hadoop. + private static final String hadoopPrefixedConfigToOverride = "hadoop." + configToOverride; + private static final String configOverrideValue = "/tmp-overridden"; + + @Parameterized.Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][] { + { + "testhive", + SparkCatalog.class.getName(), + ImmutableMap.of( + "type", + "hive", + "default-namespace", + "default", + hadoopPrefixedConfigToOverride, + configOverrideValue) + }, + { + "testhadoop", + SparkCatalog.class.getName(), + ImmutableMap.of("type", "hadoop", hadoopPrefixedConfigToOverride, configOverrideValue) + }, + { + "spark_catalog", + SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", + "hive", + "default-namespace", + "default", + hadoopPrefixedConfigToOverride, + configOverrideValue) + } + }; + } + + public TestSparkCatalogHadoopOverrides( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Before + public void createTable() { + sql("CREATE TABLE IF NOT EXISTS %s (id bigint) USING iceberg", tableName(tableIdent.name())); + } + + @After + public void dropTable() { + sql("DROP TABLE IF EXISTS %s", tableName(tableIdent.name())); + } + + @Test + public void testTableFromCatalogHasOverrides() throws Exception { + Table table = getIcebergTableFromSparkCatalog(); + Configuration conf = ((Configurable) table.io()).getConf(); + String actualCatalogOverride = conf.get(configToOverride, "/whammies"); + Assert.assertEquals( + "Iceberg tables from spark should have the overridden hadoop configurations from the spark config", + configOverrideValue, + actualCatalogOverride); + } + + @Test + public void ensureRoundTripSerializedTableRetainsHadoopConfig() throws Exception { + Table table = getIcebergTableFromSparkCatalog(); + Configuration originalConf = ((Configurable) table.io()).getConf(); + String actualCatalogOverride = originalConf.get(configToOverride, "/whammies"); + Assert.assertEquals( + "Iceberg tables from spark should have the overridden hadoop configurations from the spark config", + configOverrideValue, + actualCatalogOverride); + + // Now convert to SerializableTable and ensure overridden property is still present. 
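+ // Editorial note (illustrative, not part of the original patch): the round trip below follows a
+ // copy / serialize / deserialize / re-check pattern, roughly:
+ //
+ //   Table copy = SerializableTableWithSize.copyOf(table);        // detach from the live catalog
+ //   Table back = KryoHelpers.roundTripSerialize(copy);           // Kryo round trip
+ //   ((Configurable) back.io()).getConf().get(configToOverride);  // override should survive
+ //
+ // and the same check is repeated with Java serialization via TestHelpers.roundTripSerialize.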
+ Table serializableTable = SerializableTableWithSize.copyOf(table); + Table kryoSerializedTable = + KryoHelpers.roundTripSerialize(SerializableTableWithSize.copyOf(table)); + Configuration configFromKryoSerde = ((Configurable) kryoSerializedTable.io()).getConf(); + String kryoSerializedCatalogOverride = configFromKryoSerde.get(configToOverride, "/whammies"); + Assert.assertEquals( + "Tables serialized with Kryo serialization should retain overridden hadoop configuration properties", + configOverrideValue, + kryoSerializedCatalogOverride); + + // Do the same for Java based serde + Table javaSerializedTable = TestHelpers.roundTripSerialize(serializableTable); + Configuration configFromJavaSerde = ((Configurable) javaSerializedTable.io()).getConf(); + String javaSerializedCatalogOverride = configFromJavaSerde.get(configToOverride, "/whammies"); + Assert.assertEquals( + "Tables serialized with Java serialization should retain overridden hadoop configuration properties", + configOverrideValue, + javaSerializedCatalogOverride); + } + + @SuppressWarnings("ThrowSpecificity") + private Table getIcebergTableFromSparkCatalog() throws Exception { + Identifier identifier = Identifier.of(tableIdent.namespace().levels(), tableIdent.name()); + TableCatalog catalog = + (TableCatalog) spark.sessionState().catalogManager().catalog(catalogName); + SparkTable sparkTable = (SparkTable) catalog.loadTable(identifier); + return sparkTable.table(); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataFile.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataFile.java new file mode 100644 index 000000000000..b1f2082b5d9b --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataFile.java @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.ManifestReader; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkDataFile; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.types.Types; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.ColumnName; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestSparkDataFile { + + private static final HadoopTables TABLES = new HadoopTables(new Configuration()); + private static final Schema SCHEMA = + new Schema( + required(100, "id", Types.LongType.get()), + optional(101, "data", Types.StringType.get()), + required(102, "b", Types.BooleanType.get()), + optional(103, "i", Types.IntegerType.get()), + required(104, "l", Types.LongType.get()), + optional(105, "f", Types.FloatType.get()), + required(106, "d", Types.DoubleType.get()), + optional(107, "date", Types.DateType.get()), + required(108, "ts", Types.TimestampType.withZone()), + required(110, "s", Types.StringType.get()), + optional(113, "bytes", Types.BinaryType.get()), + required(114, "dec_9_0", Types.DecimalType.of(9, 0)), + required(115, "dec_11_2", Types.DecimalType.of(11, 2)), + required(116, "dec_38_10", Types.DecimalType.of(38, 10)) // maximum precision + ); + private static final PartitionSpec SPEC = + PartitionSpec.builderFor(SCHEMA) + .identity("b") + .bucket("i", 2) + .identity("l") + .identity("f") + .identity("d") + .identity("date") + .hour("ts") + .identity("ts") + .truncate("s", 2) + .identity("bytes") + .bucket("dec_9_0", 2) + .bucket("dec_11_2", 2) + .bucket("dec_38_10", 2) + .build(); + + private static SparkSession spark; + private static JavaSparkContext sparkContext = null; + + @BeforeClass + public static void startSpark() { + TestSparkDataFile.spark = SparkSession.builder().master("local[2]").getOrCreate(); + TestSparkDataFile.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestSparkDataFile.spark; + 
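+ // Editorial note (illustrative, not part of the original patch): the local handle taken above
+ // lets the static references be cleared before the session is actually stopped, i.e.
+ //
+ //   SparkSession currentSpark = TestSparkDataFile.spark;  // keep a handle
+ //   TestSparkDataFile.spark = null;                        // clear the static reference first
+ //   currentSpark.stop();                                   // then stop the session
+ //
+ // a common reason for this ordering is to avoid other code observing a stopped session
+ // through the static field.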
TestSparkDataFile.spark = null; + TestSparkDataFile.sparkContext = null; + currentSpark.stop(); + } + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + private String tableLocation = null; + + @Before + public void setupTableLocation() throws Exception { + File tableDir = temp.newFolder(); + this.tableLocation = tableDir.toURI().toString(); + } + + @Test + public void testValueConversion() throws IOException { + Table table = + TABLES.create(SCHEMA, PartitionSpec.unpartitioned(), Maps.newHashMap(), tableLocation); + checkSparkDataFile(table); + } + + @Test + public void testValueConversionPartitionedTable() throws IOException { + Table table = TABLES.create(SCHEMA, SPEC, Maps.newHashMap(), tableLocation); + checkSparkDataFile(table); + } + + @Test + public void testValueConversionWithEmptyStats() throws IOException { + Map props = Maps.newHashMap(); + props.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "none"); + Table table = TABLES.create(SCHEMA, SPEC, props, tableLocation); + checkSparkDataFile(table); + } + + private void checkSparkDataFile(Table table) throws IOException { + Iterable rows = RandomData.generateSpark(table.schema(), 200, 0); + JavaRDD rdd = sparkContext.parallelize(Lists.newArrayList(rows)); + Dataset df = + spark.internalCreateDataFrame( + JavaRDD.toRDD(rdd), SparkSchemaUtil.convert(table.schema()), false); + + df.write().format("iceberg").mode("append").save(tableLocation); + + table.refresh(); + + List manifests = table.currentSnapshot().allManifests(table.io()); + Assert.assertEquals("Should have 1 manifest", 1, manifests.size()); + + List dataFiles = Lists.newArrayList(); + try (ManifestReader reader = ManifestFiles.read(manifests.get(0), table.io())) { + for (DataFile dataFile : reader) { + checkDataFile(dataFile.copy(), DataFiles.builder(table.spec()).copy(dataFile).build()); + dataFiles.add(dataFile.copy()); + } + } + + Dataset dataFileDF = spark.read().format("iceberg").load(tableLocation + "#files"); + + // reorder columns to test arbitrary projections + List columns = + Arrays.stream(dataFileDF.columns()).map(ColumnName::new).collect(Collectors.toList()); + Collections.shuffle(columns); + + List sparkDataFiles = + dataFileDF.select(Iterables.toArray(columns, Column.class)).collectAsList(); + + Assert.assertEquals( + "The number of files should match", dataFiles.size(), sparkDataFiles.size()); + + Types.StructType dataFileType = DataFile.getType(table.spec().partitionType()); + StructType sparkDataFileType = sparkDataFiles.get(0).schema(); + SparkDataFile wrapper = new SparkDataFile(dataFileType, sparkDataFileType); + + for (int i = 0; i < dataFiles.size(); i++) { + checkDataFile(dataFiles.get(i), wrapper.wrap(sparkDataFiles.get(i))); + } + } + + private void checkDataFile(DataFile expected, DataFile actual) { + Assert.assertEquals("Path must match", expected.path(), actual.path()); + Assert.assertEquals("Format must match", expected.format(), actual.format()); + Assert.assertEquals("Record count must match", expected.recordCount(), actual.recordCount()); + Assert.assertEquals("Size must match", expected.fileSizeInBytes(), actual.fileSizeInBytes()); + Assert.assertEquals( + "Record value counts must match", expected.valueCounts(), actual.valueCounts()); + Assert.assertEquals( + "Record null value counts must match", + expected.nullValueCounts(), + actual.nullValueCounts()); + Assert.assertEquals( + "Record nan value counts must match", expected.nanValueCounts(), actual.nanValueCounts()); + Assert.assertEquals("Lower bounds must match", 
expected.lowerBounds(), actual.lowerBounds()); + Assert.assertEquals("Upper bounds must match", expected.upperBounds(), actual.upperBounds()); + Assert.assertEquals("Key metadata must match", expected.keyMetadata(), actual.keyMetadata()); + Assert.assertEquals("Split offsets must match", expected.splitOffsets(), actual.splitOffsets()); + Assert.assertEquals("Sort order id must match", expected.sortOrderId(), actual.sortOrderId()); + + checkStructLike(expected.partition(), actual.partition()); + } + + private void checkStructLike(StructLike expected, StructLike actual) { + Assert.assertEquals("Struct size should match", expected.size(), actual.size()); + for (int i = 0; i < expected.size(); i++) { + Assert.assertEquals( + "Struct values must match", expected.get(i, Object.class), actual.get(i, Object.class)); + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataWrite.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataWrite.java new file mode 100644 index 000000000000..991719d61615 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataWrite.java @@ -0,0 +1,741 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.assertj.core.api.Assertions; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class TestSparkDataWrite { + private static final Configuration CONF = new Configuration(); + private final FileFormat format; + private final String branch; + private static SparkSession spark = null; + private static final Schema SCHEMA = + new Schema( + optional(1, "id", Types.IntegerType.get()), optional(2, "data", Types.StringType.get())); + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + @Parameterized.Parameters(name = "format = {0}, branch = {1}") + public static Object[] parameters() { + return new Object[] { + new Object[] {"parquet", null}, + new Object[] {"parquet", "main"}, + new Object[] {"parquet", "testBranch"}, + new Object[] {"avro", null}, + new Object[] {"orc", "testBranch"} + }; + } + + @BeforeClass + public static void startSpark() { + TestSparkDataWrite.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @Parameterized.AfterParam + public static void clearSourceCache() { + ManualSource.clearTables(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestSparkDataWrite.spark; + TestSparkDataWrite.spark = null; + currentSpark.stop(); + } + + public TestSparkDataWrite(String format, String branch) { + this.format = FileFormat.fromString(format); + this.branch = branch; + } + + @Test + public void testBasicWrite() throws IOException { + File parent = temp.newFolder(format.toString()); + File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = 
PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List expected = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + // TODO: incoming columns must be ordered according to the table's schema + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + createBranch(table); + table.refresh(); + + Dataset result = spark.read().format("iceberg").load(targetLocation); + + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + for (ManifestFile manifest : + SnapshotUtil.latestSnapshot(table, branch).allManifests(table.io())) { + for (DataFile file : ManifestFiles.read(manifest, table.io())) { + // TODO: avro not support split + if (!format.equals(FileFormat.AVRO)) { + Assert.assertNotNull("Split offsets not present", file.splitOffsets()); + } + Assert.assertEquals("Should have reported record count as 1", 1, file.recordCount()); + // TODO: append more metric info + if (format.equals(FileFormat.PARQUET)) { + Assert.assertNotNull("Column sizes metric not present", file.columnSizes()); + Assert.assertNotNull("Counts metric not present", file.valueCounts()); + Assert.assertNotNull("Null value counts metric not present", file.nullValueCounts()); + Assert.assertNotNull("Lower bounds metric not present", file.lowerBounds()); + Assert.assertNotNull("Upper bounds metric not present", file.upperBounds()); + } + } + } + } + + @Test + public void testAppend() throws IOException { + File parent = temp.newFolder(format.toString()); + File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + + List expected = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "a"), + new SimpleRecord(5, "b"), + new SimpleRecord(6, "c")); + + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + createBranch(table); + + df.withColumn("id", df.col("id").plus(3)) + .select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(targetLocation); + + table.refresh(); + + Dataset result = spark.read().format("iceberg").load(targetLocation); + + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + } + + @Test + public void testEmptyOverwrite() throws IOException { + File parent = temp.newFolder(format.toString()); + File location = new 
File(parent, "test"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("id").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + + List expected = records; + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + createBranch(table); + + Dataset empty = spark.createDataFrame(ImmutableList.of(), SimpleRecord.class); + empty + .select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Overwrite) + .option("overwrite-mode", "dynamic") + .save(targetLocation); + + table.refresh(); + + Dataset result = spark.read().format("iceberg").load(targetLocation); + + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + } + + @Test + public void testOverwrite() throws IOException { + File parent = temp.newFolder(format.toString()); + File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("id").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + + List expected = + Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "a"), + new SimpleRecord(3, "c"), + new SimpleRecord(4, "b"), + new SimpleRecord(6, "c")); + + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + createBranch(table); + + // overwrite with 2*id to replace record 2, append 4 and 6 + df.withColumn("id", df.col("id").multiply(2)) + .select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Overwrite) + .option("overwrite-mode", "dynamic") + .save(targetLocation); + + table.refresh(); + + Dataset result = spark.read().format("iceberg").load(targetLocation); + + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + } + + @Test + public void testUnpartitionedOverwrite() throws IOException { + File parent = temp.newFolder(format.toString()); + File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List expected = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + + Dataset df = 
spark.createDataFrame(expected, SimpleRecord.class); + + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + createBranch(table); + + // overwrite with the same data; should not produce two copies + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Overwrite) + .save(targetLocation); + + table.refresh(); + + Dataset result = spark.read().format("iceberg").load(targetLocation); + + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + } + + @Test + public void testUnpartitionedCreateWithTargetFileSizeViaTableProperties() throws IOException { + File parent = temp.newFolder(format.toString()); + File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + table + .updateProperties() + .set(TableProperties.WRITE_TARGET_FILE_SIZE_BYTES, "4") // ~4 bytes; low enough to trigger + .commit(); + + List expected = Lists.newArrayListWithCapacity(4000); + for (int i = 0; i < 4000; i++) { + expected.add(new SimpleRecord(i, "a")); + } + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + createBranch(table); + table.refresh(); + + Dataset result = spark.read().format("iceberg").load(targetLocation); + + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + + List files = Lists.newArrayList(); + for (ManifestFile manifest : + SnapshotUtil.latestSnapshot(table, branch).allManifests(table.io())) { + for (DataFile file : ManifestFiles.read(manifest, table.io())) { + files.add(file); + } + } + + Assert.assertEquals("Should have 4 DataFiles", 4, files.size()); + Assert.assertTrue( + "All DataFiles contain 1000 rows", files.stream().allMatch(d -> d.recordCount() == 1000)); + } + + @Test + public void testPartitionedCreateWithTargetFileSizeViaOption() throws IOException { + partitionedCreateWithTargetFileSizeViaOption(IcebergOptionsType.NONE); + } + + @Test + public void testPartitionedFanoutCreateWithTargetFileSizeViaOption() throws IOException { + partitionedCreateWithTargetFileSizeViaOption(IcebergOptionsType.TABLE); + } + + @Test + public void testPartitionedFanoutCreateWithTargetFileSizeViaOption2() throws IOException { + partitionedCreateWithTargetFileSizeViaOption(IcebergOptionsType.JOB); + } + + @Test + public void testWriteProjection() throws IOException { + Assume.assumeTrue( + "Not supported in Spark 3; analysis requires all columns are present", + spark.version().startsWith("2")); + + File parent = temp.newFolder(format.toString()); + File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = 
PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List expected = + Lists.newArrayList( + new SimpleRecord(1, null), new SimpleRecord(2, null), new SimpleRecord(3, null)); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + df.select("id") + .write() // select only id column + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + createBranch(table); + table.refresh(); + + Dataset result = spark.read().format("iceberg").load(targetLocation); + + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + } + + @Test + public void testWriteProjectionWithMiddle() throws IOException { + Assume.assumeTrue( + "Not supported in Spark 3; analysis requires all columns are present", + spark.version().startsWith("2")); + + File parent = temp.newFolder(format.toString()); + File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Schema schema = + new Schema( + optional(1, "c1", Types.IntegerType.get()), + optional(2, "c2", Types.StringType.get()), + optional(3, "c3", Types.StringType.get())); + Table table = tables.create(schema, spec, location.toString()); + + List expected = + Lists.newArrayList( + new ThreeColumnRecord(1, null, "hello"), + new ThreeColumnRecord(2, null, "world"), + new ThreeColumnRecord(3, null, null)); + + Dataset df = spark.createDataFrame(expected, ThreeColumnRecord.class); + + df.select("c1", "c3") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + createBranch(table); + table.refresh(); + + Dataset result = spark.read().format("iceberg").load(targetLocation); + + List actual = + result.orderBy("c1").as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + } + + @Test + public void testViewsReturnRecentResults() throws IOException { + File parent = temp.newFolder(format.toString()); + File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + tables.create(SCHEMA, spec, location.toString()); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + Table table = tables.load(location.toString()); + createBranch(table); + + Dataset query = spark.read().format("iceberg").load(targetLocation).where("id = 1"); + query.createOrReplaceTempView("tmp"); + + List actual1 = + spark.table("tmp").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + List expected1 = Lists.newArrayList(new SimpleRecord(1, "a")); + Assert.assertEquals("Number of rows should match", 
expected1.size(), actual1.size()); + Assert.assertEquals("Result rows should match", expected1, actual1); + + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(targetLocation); + + List actual2 = + spark.table("tmp").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + List expected2 = + Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(1, "a")); + Assert.assertEquals("Number of rows should match", expected2.size(), actual2.size()); + Assert.assertEquals("Result rows should match", expected2, actual2); + } + + public void partitionedCreateWithTargetFileSizeViaOption(IcebergOptionsType option) + throws IOException { + File parent = temp.newFolder(format.toString()); + File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Map properties = + ImmutableMap.of( + TableProperties.WRITE_DISTRIBUTION_MODE, TableProperties.WRITE_DISTRIBUTION_MODE_NONE); + Table table = tables.create(SCHEMA, spec, properties, location.toString()); + + List expected = Lists.newArrayListWithCapacity(8000); + for (int i = 0; i < 2000; i++) { + expected.add(new SimpleRecord(i, "a")); + expected.add(new SimpleRecord(i, "b")); + expected.add(new SimpleRecord(i, "c")); + expected.add(new SimpleRecord(i, "d")); + } + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + switch (option) { + case NONE: + df.select("id", "data") + .sort("data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES, 4) // ~4 bytes; low enough to trigger + .save(location.toString()); + break; + case TABLE: + table.updateProperties().set(SPARK_WRITE_PARTITIONED_FANOUT_ENABLED, "true").commit(); + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES, 4) // ~4 bytes; low enough to trigger + .save(location.toString()); + break; + case JOB: + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .option(SparkWriteOptions.TARGET_FILE_SIZE_BYTES, 4) // ~4 bytes; low enough to trigger + .option(SparkWriteOptions.FANOUT_ENABLED, true) + .save(location.toString()); + break; + default: + break; + } + + createBranch(table); + table.refresh(); + + Dataset result = spark.read().format("iceberg").load(targetLocation); + + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + + List files = Lists.newArrayList(); + for (ManifestFile manifest : + SnapshotUtil.latestSnapshot(table, branch).allManifests(table.io())) { + for (DataFile file : ManifestFiles.read(manifest, table.io())) { + files.add(file); + } + } + + Assert.assertEquals("Should have 8 DataFiles", 8, files.size()); + Assert.assertTrue( + "All DataFiles contain 1000 rows", files.stream().allMatch(d -> d.recordCount() == 1000)); + } + + @Test + public void testCommitUnknownException() throws IOException { + File parent = temp.newFolder(format.toString()); + File 
location = new File(parent, "commitunknown"); + String targetLocation = locationWithBranch(location); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + createBranch(table); + table.refresh(); + + List records2 = + Lists.newArrayList( + new SimpleRecord(4, "d"), new SimpleRecord(5, "e"), new SimpleRecord(6, "f")); + + Dataset df2 = spark.createDataFrame(records2, SimpleRecord.class); + + AppendFiles append = table.newFastAppend(); + if (branch != null) { + append.toBranch(branch); + } + + AppendFiles spyAppend = spy(append); + doAnswer( + invocation -> { + append.commit(); + throw new CommitStateUnknownException(new RuntimeException("Datacenter on Fire")); + }) + .when(spyAppend) + .commit(); + + Table spyTable = spy(table); + when(spyTable.newAppend()).thenReturn(spyAppend); + SparkTable sparkTable = new SparkTable(spyTable, false); + + String manualTableName = "unknown_exception"; + ManualSource.setTable(manualTableName, sparkTable); + + // Although an exception is thrown here, write and commit have succeeded + AssertHelpers.assertThrows( + "Should throw a Commit State Unknown Exception", + CommitStateUnknownException.class, + "Datacenter on Fire", + () -> + df2.select("id", "data") + .sort("data") + .write() + .format("org.apache.iceberg.spark.source.ManualSource") + .option(ManualSource.TABLE_NAME, manualTableName) + .mode(SaveMode.Append) + .save(targetLocation)); + + // Since write and commit succeeded, the rows should be readable + Dataset result = spark.read().format("iceberg").load(targetLocation); + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals( + "Number of rows should match", records.size() + records2.size(), actual.size()); + Assertions.assertThat(actual) + .describedAs("Result rows should match") + .containsExactlyInAnyOrder( + ImmutableList.builder() + .addAll(records) + .addAll(records2) + .build() + .toArray(new SimpleRecord[0])); + } + + public enum IcebergOptionsType { + NONE, + TABLE, + JOB + } + + private String locationWithBranch(File location) { + if (branch == null) { + return location.toString(); + } + + return location + "#branch_" + branch; + } + + private void createBranch(Table table) { + if (branch != null && !branch.equals(SnapshotRef.MAIN_BRANCH)) { + table.manageSnapshots().createBranch(branch, table.currentSnapshot().snapshotId()).commit(); + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkFileWriterFactory.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkFileWriterFactory.java new file mode 100644 index 000000000000..4a3263e368c0 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkFileWriterFactory.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileWriterFactory; +import org.apache.iceberg.io.TestFileWriterFactory; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.util.ArrayUtil; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +public class TestSparkFileWriterFactory extends TestFileWriterFactory { + + public TestSparkFileWriterFactory(FileFormat fileFormat, boolean partitioned) { + super(fileFormat, partitioned); + } + + @Override + protected FileWriterFactory newWriterFactory( + Schema dataSchema, + List equalityFieldIds, + Schema equalityDeleteRowSchema, + Schema positionDeleteRowSchema) { + return SparkFileWriterFactory.builderFor(table) + .dataSchema(table.schema()) + .dataFileFormat(format()) + .deleteFileFormat(format()) + .equalityFieldIds(ArrayUtil.toIntArray(equalityFieldIds)) + .equalityDeleteRowSchema(equalityDeleteRowSchema) + .positionDeleteRowSchema(positionDeleteRowSchema) + .build(); + } + + @Override + protected InternalRow toRow(Integer id, String data) { + InternalRow row = new GenericInternalRow(2); + row.update(0, id); + row.update(1, UTF8String.fromString(data)); + return row; + } + + @Override + protected StructLikeSet toSet(Iterable rows) { + StructLikeSet set = StructLikeSet.create(table.schema().asStruct()); + StructType sparkType = SparkSchemaUtil.convert(table.schema()); + for (InternalRow row : rows) { + InternalRowWrapper wrapper = new InternalRowWrapper(sparkType); + set.add(wrapper.wrap(row)); + } + return set; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMergingMetrics.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMergingMetrics.java new file mode 100644 index 000000000000..c3bb35ca7df8 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMergingMetrics.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.TestMergingMetrics; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.spark.sql.catalyst.InternalRow; + +public class TestSparkMergingMetrics extends TestMergingMetrics { + + public TestSparkMergingMetrics(FileFormat fileFormat) { + super(fileFormat); + } + + @Override + protected FileAppender writeAndGetAppender(List records) throws IOException { + Table testTable = + new BaseTable(null, "dummy") { + @Override + public Map properties() { + return Collections.emptyMap(); + } + + @Override + public SortOrder sortOrder() { + return SortOrder.unsorted(); + } + + @Override + public PartitionSpec spec() { + return PartitionSpec.unpartitioned(); + } + }; + + FileAppender appender = + SparkAppenderFactory.builderFor(testTable, SCHEMA, SparkSchemaUtil.convert(SCHEMA)) + .build() + .newAppender(org.apache.iceberg.Files.localOutput(temp.newFile()), fileFormat); + try (FileAppender fileAppender = appender) { + records.stream() + .map(r -> new StructInternalRow(SCHEMA.asStruct()).setStruct(r)) + .forEach(fileAppender::add); + } + return appender; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMetadataColumns.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMetadataColumns.java new file mode 100644 index 000000000000..e39985228570 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkMetadataColumns.java @@ -0,0 +1,288 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.FORMAT_VERSION; +import static org.apache.iceberg.TableProperties.ORC_VECTORIZATION_ENABLED; +import static org.apache.iceberg.TableProperties.PARQUET_BATCH_SIZE; +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; +import static org.apache.iceberg.TableProperties.PARQUET_VECTORIZATION_ENABLED; +import static org.apache.spark.sql.functions.lit; + +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.PartitionSpecParser; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.UpdateProperties; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class TestSparkMetadataColumns extends SparkTestBase { + + private static final String TABLE_NAME = "test_table"; + private static final Schema SCHEMA = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "category", Types.StringType.get()), + Types.NestedField.optional(3, "data", Types.StringType.get())); + private static final PartitionSpec UNKNOWN_SPEC = + PartitionSpecParser.fromJson( + SCHEMA, + "{ \"spec-id\": 1, \"fields\": [ { \"name\": \"id_zero\", \"transform\": \"zero\", \"source-id\": 1 } ] }"); + + @Parameterized.Parameters(name = "fileFormat = {0}, vectorized = {1}, formatVersion = {2}") + public static Object[][] parameters() { + return new Object[][] { + {FileFormat.PARQUET, false, 1}, + {FileFormat.PARQUET, true, 1}, + {FileFormat.PARQUET, false, 2}, + {FileFormat.PARQUET, true, 2}, + {FileFormat.AVRO, false, 1}, + {FileFormat.AVRO, false, 2}, + {FileFormat.ORC, false, 1}, + {FileFormat.ORC, true, 1}, + {FileFormat.ORC, false, 2}, + {FileFormat.ORC, true, 2}, + }; + } + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private final FileFormat fileFormat; + private final boolean vectorized; + private final int formatVersion; + + private Table table = null; + + public TestSparkMetadataColumns(FileFormat fileFormat, boolean vectorized, int formatVersion) { + this.fileFormat = fileFormat; + this.vectorized = vectorized; + 
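The parameterized matrix above exercises Parquet, Avro and ORC across format versions 1 and 2, and the assertions query Iceberg's reserved metadata columns. As a point of reference, a minimal standalone sketch of that kind of query is shown below; the catalog name "local", the warehouse path, and the table "db.events" are illustrative assumptions, not part of this patch.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class MetadataColumnsSketch {
  public static void main(String[] args) {
    // assumes the iceberg-spark-runtime jar for Spark 3.4 is on the classpath
    SparkSession spark = SparkSession.builder()
        .master("local[2]")
        .config("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog")
        .config("spark.sql.catalog.local.type", "hadoop")
        .config("spark.sql.catalog.local.warehouse", "/tmp/warehouse")
        .getOrCreate();

    // Iceberg exposes reserved metadata columns such as _spec_id, _partition, _file and _pos
    Dataset<Row> meta = spark.sql(
        "SELECT _spec_id, _partition, _file, _pos FROM local.db.events ORDER BY _spec_id");
    meta.show(false);

    spark.stop();
  }
}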
this.formatVersion = formatVersion; + } + + @BeforeClass + public static void setupSpark() { + ImmutableMap config = + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "cache-enabled", "true"); + spark + .conf() + .set("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.source.TestSparkCatalog"); + config.forEach( + (key, value) -> spark.conf().set("spark.sql.catalog.spark_catalog." + key, value)); + } + + @Before + public void setupTable() throws IOException { + createAndInitTable(); + } + + @After + public void dropTable() { + TestTables.clearTables(); + } + + @Test + public void testSpecAndPartitionMetadataColumns() { + // TODO: support metadata structs in vectorized ORC reads + Assume.assumeFalse(fileFormat == FileFormat.ORC && vectorized); + + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", TABLE_NAME); + + table.refresh(); + table.updateSpec().addField("data").commit(); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", TABLE_NAME); + + table.refresh(); + table.updateSpec().addField(Expressions.bucket("category", 8)).commit(); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", TABLE_NAME); + + table.refresh(); + table.updateSpec().removeField("data").commit(); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", TABLE_NAME); + + table.refresh(); + table.updateSpec().renameField("category_bucket_8", "category_bucket_8_another_name").commit(); + + List expected = + ImmutableList.of( + row(0, row(null, null)), + row(1, row("b1", null)), + row(2, row("b1", 2)), + row(3, row(null, 2))); + assertEquals( + "Rows must match", + expected, + sql("SELECT _spec_id, _partition FROM %s ORDER BY _spec_id", TABLE_NAME)); + } + + @Test + public void testPositionMetadataColumnWithMultipleRowGroups() throws NoSuchTableException { + Assume.assumeTrue(fileFormat == FileFormat.PARQUET); + + table.updateProperties().set(PARQUET_ROW_GROUP_SIZE_BYTES, "100").commit(); + + List ids = Lists.newArrayList(); + for (long id = 0L; id < 200L; id++) { + ids.add(id); + } + Dataset df = + spark + .createDataset(ids, Encoders.LONG()) + .withColumnRenamed("value", "id") + .withColumn("category", lit("hr")) + .withColumn("data", lit("ABCDEF")); + df.coalesce(1).writeTo(TABLE_NAME).append(); + + Assert.assertEquals(200, spark.table(TABLE_NAME).count()); + + List expectedRows = ids.stream().map(this::row).collect(Collectors.toList()); + assertEquals("Rows must match", expectedRows, sql("SELECT _pos FROM %s", TABLE_NAME)); + } + + @Test + public void testPositionMetadataColumnWithMultipleBatches() throws NoSuchTableException { + Assume.assumeTrue(fileFormat == FileFormat.PARQUET); + + table.updateProperties().set(PARQUET_BATCH_SIZE, "1000").commit(); + + List ids = Lists.newArrayList(); + for (long id = 0L; id < 7500L; id++) { + ids.add(id); + } + Dataset df = + spark + .createDataset(ids, Encoders.LONG()) + .withColumnRenamed("value", "id") + .withColumn("category", lit("hr")) + .withColumn("data", lit("ABCDEF")); + df.coalesce(1).writeTo(TABLE_NAME).append(); + + Assert.assertEquals(7500, spark.table(TABLE_NAME).count()); + + List expectedRows = ids.stream().map(this::row).collect(Collectors.toList()); + assertEquals("Rows must match", expectedRows, sql("SELECT _pos FROM %s", TABLE_NAME)); + } + + @Test + public void testPartitionMetadataColumnWithUnknownTransforms() { + // replace the table spec to include an unknown transform + TableOperations ops = ((HasTableOperations) table).operations(); + TableMetadata base = ops.current(); + ops.commit(base, 
base.updatePartitionSpec(UNKNOWN_SPEC)); + + AssertHelpers.assertThrows( + "Should fail to query the partition metadata column", + ValidationException.class, + "Cannot build table partition type, unknown transforms", + () -> sql("SELECT _partition FROM %s", TABLE_NAME)); + } + + @Test + public void testConflictingColumns() { + table + .updateSchema() + .addColumn(MetadataColumns.SPEC_ID.name(), Types.IntegerType.get()) + .addColumn(MetadataColumns.FILE_PATH.name(), Types.StringType.get()) + .commit(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1', -1, 'path/to/file')", TABLE_NAME); + + assertEquals( + "Rows must match", + ImmutableList.of(row(1L, "a1")), + sql("SELECT id, category FROM %s", TABLE_NAME)); + + AssertHelpers.assertThrows( + "Should fail to query conflicting columns", + ValidationException.class, + "column names conflict", + () -> sql("SELECT * FROM %s", TABLE_NAME)); + + table.refresh(); + + table + .updateSchema() + .renameColumn(MetadataColumns.SPEC_ID.name(), "_renamed" + MetadataColumns.SPEC_ID.name()) + .renameColumn( + MetadataColumns.FILE_PATH.name(), "_renamed" + MetadataColumns.FILE_PATH.name()) + .commit(); + + assertEquals( + "Rows must match", + ImmutableList.of(row(0, null, -1)), + sql("SELECT _spec_id, _partition, _renamed_spec_id FROM %s", TABLE_NAME)); + } + + private void createAndInitTable() throws IOException { + this.table = + TestTables.create(temp.newFolder(), TABLE_NAME, SCHEMA, PartitionSpec.unpartitioned()); + + UpdateProperties updateProperties = table.updateProperties(); + updateProperties.set(FORMAT_VERSION, String.valueOf(formatVersion)); + updateProperties.set(DEFAULT_FILE_FORMAT, fileFormat.name()); + + switch (fileFormat) { + case PARQUET: + updateProperties.set(PARQUET_VECTORIZATION_ENABLED, String.valueOf(vectorized)); + break; + case ORC: + updateProperties.set(ORC_VECTORIZATION_ENABLED, String.valueOf(vectorized)); + break; + default: + Preconditions.checkState( + !vectorized, "File format %s does not support vectorized reads", fileFormat); + } + + updateProperties.commit(); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPartitioningWriters.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPartitioningWriters.java new file mode 100644 index 000000000000..276d8c632fc0 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPartitioningWriters.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileWriterFactory; +import org.apache.iceberg.io.TestPartitioningWriters; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.util.ArrayUtil; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +public class TestSparkPartitioningWriters extends TestPartitioningWriters { + + public TestSparkPartitioningWriters(FileFormat fileFormat) { + super(fileFormat); + } + + @Override + protected FileWriterFactory newWriterFactory( + Schema dataSchema, + List equalityFieldIds, + Schema equalityDeleteRowSchema, + Schema positionDeleteRowSchema) { + return SparkFileWriterFactory.builderFor(table) + .dataSchema(table.schema()) + .dataFileFormat(format()) + .deleteFileFormat(format()) + .equalityFieldIds(ArrayUtil.toIntArray(equalityFieldIds)) + .equalityDeleteRowSchema(equalityDeleteRowSchema) + .positionDeleteRowSchema(positionDeleteRowSchema) + .build(); + } + + @Override + protected InternalRow toRow(Integer id, String data) { + InternalRow row = new GenericInternalRow(2); + row.update(0, id); + row.update(1, UTF8String.fromString(data)); + return row; + } + + @Override + protected StructLikeSet toSet(Iterable rows) { + StructLikeSet set = StructLikeSet.create(table.schema().asStruct()); + StructType sparkType = SparkSchemaUtil.convert(table.schema()); + for (InternalRow row : rows) { + InternalRowWrapper wrapper = new InternalRowWrapper(sparkType); + set.add(wrapper.wrap(row)); + } + return set; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPositionDeltaWriters.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPositionDeltaWriters.java new file mode 100644 index 000000000000..245c392774f5 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPositionDeltaWriters.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.util.List; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.FileWriterFactory; +import org.apache.iceberg.io.TestPositionDeltaWriters; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.util.ArrayUtil; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +public class TestSparkPositionDeltaWriters extends TestPositionDeltaWriters { + + public TestSparkPositionDeltaWriters(FileFormat fileFormat) { + super(fileFormat); + } + + @Override + protected FileWriterFactory newWriterFactory( + Schema dataSchema, + List equalityFieldIds, + Schema equalityDeleteRowSchema, + Schema positionDeleteRowSchema) { + return SparkFileWriterFactory.builderFor(table) + .dataSchema(table.schema()) + .dataFileFormat(format()) + .deleteFileFormat(format()) + .equalityFieldIds(ArrayUtil.toIntArray(equalityFieldIds)) + .equalityDeleteRowSchema(equalityDeleteRowSchema) + .positionDeleteRowSchema(positionDeleteRowSchema) + .build(); + } + + @Override + protected InternalRow toRow(Integer id, String data) { + InternalRow row = new GenericInternalRow(2); + row.update(0, id); + row.update(1, UTF8String.fromString(data)); + return row; + } + + @Override + protected StructLikeSet toSet(Iterable rows) { + StructLikeSet set = StructLikeSet.create(table.schema().asStruct()); + StructType sparkType = SparkSchemaUtil.convert(table.schema()); + for (InternalRow row : rows) { + InternalRowWrapper wrapper = new InternalRowWrapper(sparkType); + set.add(wrapper.wrap(row)); + } + return set; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java new file mode 100644 index 000000000000..dde1eb7b36ec --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.Files.localOutput; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.data.GenericAppenderFactory; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkValueConverter; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class TestSparkReadProjection extends TestReadProjection { + + private static SparkSession spark = null; + + @Parameterized.Parameters(name = "format = {0}, vectorized = {1}") + public static Object[][] parameters() { + return new Object[][] { + {"parquet", false}, + {"parquet", true}, + {"avro", false}, + {"orc", false}, + {"orc", true} + }; + } + + private final FileFormat format; + private final boolean vectorized; + + public TestSparkReadProjection(String format, boolean vectorized) { + super(format); + this.format = FileFormat.fromString(format); + this.vectorized = vectorized; + } + + @BeforeClass + public static void startSpark() { + TestSparkReadProjection.spark = SparkSession.builder().master("local[2]").getOrCreate(); + ImmutableMap config = + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "parquet-enabled", "true", + "cache-enabled", "false"); + spark + .conf() + .set("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.source.TestSparkCatalog"); + config.forEach( + (key, value) -> spark.conf().set("spark.sql.catalog.spark_catalog." + key, value)); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestSparkReadProjection.spark; + TestSparkReadProjection.spark = null; + currentSpark.stop(); + } + + @Override + protected Record writeAndRead(String desc, Schema writeSchema, Schema readSchema, Record record) + throws IOException { + File parent = temp.newFolder(desc); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + Assert.assertTrue("mkdirs should succeed", dataFolder.mkdirs()); + + File testFile = new File(dataFolder, format.addExtension(UUID.randomUUID().toString())); + + Table table = TestTables.create(location, desc, writeSchema, PartitionSpec.unpartitioned()); + try { + // Important: use the table's schema for the rest of the test + // When tables are created, the column ids are reassigned. 
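The id remapping in writeAndRead exists because table creation renumbers field ids. A small standalone sketch of that behaviour using TypeUtil follows; the original field ids 10 and 17 are arbitrary values chosen for illustration, and the comment about table creation is an approximation of what happens under the hood rather than a quote of this patch.

import java.util.concurrent.atomic.AtomicInteger;
import org.apache.iceberg.Schema;
import org.apache.iceberg.types.TypeUtil;
import org.apache.iceberg.types.Types;

public class FreshIdsSketch {
  public static void main(String[] args) {
    Schema writeSchema = new Schema(
        Types.NestedField.required(10, "id", Types.LongType.get()),
        Types.NestedField.optional(17, "data", Types.StringType.get()));

    // roughly what table creation does: walk the schema and hand out fresh, sequential ids
    AtomicInteger nextId = new AtomicInteger(0);
    Schema assigned = TypeUtil.assignFreshIds(writeSchema, nextId::incrementAndGet);

    System.out.println(writeSchema.findField("data").fieldId()); // 17 in the write schema
    System.out.println(assigned.findField("data").fieldId());    // 2 after reassignment
  }
}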
+ Schema tableSchema = table.schema(); + + try (FileAppender writer = + new GenericAppenderFactory(tableSchema).newAppender(localOutput(testFile), format)) { + writer.add(record); + } + + DataFile file = + DataFiles.builder(PartitionSpec.unpartitioned()) + .withRecordCount(100) + .withFileSizeInBytes(testFile.length()) + .withPath(testFile.toString()) + .build(); + + table.newAppend().appendFile(file).commit(); + + // rewrite the read schema for the table's reassigned ids + Map idMapping = Maps.newHashMap(); + for (int id : allIds(writeSchema)) { + // translate each id to the original schema's column name, then to the new schema's id + String originalName = writeSchema.findColumnName(id); + idMapping.put(id, tableSchema.findField(originalName).fieldId()); + } + Schema expectedSchema = reassignIds(readSchema, idMapping); + + // Set the schema to the expected schema directly to simulate the table schema evolving + TestTables.replaceMetadata( + desc, TestTables.readMetadata(desc).updateSchema(expectedSchema, 100)); + + Dataset df = + spark + .read() + .format("org.apache.iceberg.spark.source.TestIcebergSource") + .option("iceberg.table.name", desc) + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(); + + return SparkValueConverter.convert(readSchema, df.collectAsList().get(0)); + + } finally { + TestTables.clearTables(); + } + } + + private List allIds(Schema schema) { + List ids = Lists.newArrayList(); + TypeUtil.visit( + schema, + new TypeUtil.SchemaVisitor() { + @Override + public Void field(Types.NestedField field, Void fieldResult) { + ids.add(field.fieldId()); + return null; + } + + @Override + public Void list(Types.ListType list, Void elementResult) { + ids.add(list.elementId()); + return null; + } + + @Override + public Void map(Types.MapType map, Void keyResult, Void valueResult) { + ids.add(map.keyId()); + ids.add(map.valueId()); + return null; + } + }); + return ids; + } + + private Schema reassignIds(Schema schema, Map idMapping) { + return new Schema( + TypeUtil.visit( + schema, + new TypeUtil.SchemaVisitor() { + private int mapId(int id) { + if (idMapping.containsKey(id)) { + return idMapping.get(id); + } + return 1000 + id; // make sure the new IDs don't conflict with reassignment + } + + @Override + public Type schema(Schema schema, Type structResult) { + return structResult; + } + + @Override + public Type struct(Types.StructType struct, List fieldResults) { + List newFields = + Lists.newArrayListWithExpectedSize(fieldResults.size()); + List fields = struct.fields(); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + if (field.isOptional()) { + newFields.add( + optional(mapId(field.fieldId()), field.name(), fieldResults.get(i))); + } else { + newFields.add( + required(mapId(field.fieldId()), field.name(), fieldResults.get(i))); + } + } + return Types.StructType.of(newFields); + } + + @Override + public Type field(Types.NestedField field, Type fieldResult) { + return fieldResult; + } + + @Override + public Type list(Types.ListType list, Type elementResult) { + if (list.isElementOptional()) { + return Types.ListType.ofOptional(mapId(list.elementId()), elementResult); + } else { + return Types.ListType.ofRequired(mapId(list.elementId()), elementResult); + } + } + + @Override + public Type map(Types.MapType map, Type keyResult, Type valueResult) { + if (map.isValueOptional()) { + return Types.MapType.ofOptional( + mapId(map.keyId()), mapId(map.valueId()), keyResult, valueResult); + } else { + return 
Types.MapType.ofRequired( + mapId(map.keyId()), mapId(map.valueId()), keyResult, valueResult); + } + } + + @Override + public Type primitive(Type.PrimitiveType primitive) { + return primitive; + } + }) + .asNestedType() + .asStructType() + .fields()); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReaderDeletes.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReaderDeletes.java new file mode 100644 index 000000000000..cadcbad6aa76 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReaderDeletes.java @@ -0,0 +1,657 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS; +import static org.apache.iceberg.spark.source.SparkSQLExecutionHelper.lastExecutedMetricValue; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.Set; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.Files; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TestHelpers; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.data.DeleteReadTests; +import org.apache.iceberg.data.FileHelpers; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.InternalRecordWrapper; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.hive.HiveCatalog; +import org.apache.iceberg.hive.TestHiveMetastore; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.parquet.ParquetSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkSchemaUtil; +import 
org.apache.iceberg.spark.SparkStructLike; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkParquetWriters; +import org.apache.iceberg.spark.source.metrics.NumDeletes; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.ArrayUtil; +import org.apache.iceberg.util.CharSequenceSet; +import org.apache.iceberg.util.Pair; +import org.apache.iceberg.util.StructLikeSet; +import org.apache.iceberg.util.TableScanUtil; +import org.apache.parquet.hadoop.ParquetFileWriter; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.StructType; +import org.jetbrains.annotations.NotNull; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class TestSparkReaderDeletes extends DeleteReadTests { + + private static TestHiveMetastore metastore = null; + protected static SparkSession spark = null; + protected static HiveCatalog catalog = null; + private final String format; + private final boolean vectorized; + + public TestSparkReaderDeletes(String format, boolean vectorized) { + this.format = format; + this.vectorized = vectorized; + } + + @Parameterized.Parameters(name = "format = {0}, vectorized = {1}") + public static Object[][] parameters() { + return new Object[][] { + new Object[] {"parquet", false}, + new Object[] {"parquet", true}, + new Object[] {"orc", false}, + new Object[] {"avro", false} + }; + } + + @BeforeClass + public static void startMetastoreAndSpark() { + metastore = new TestHiveMetastore(); + metastore.start(); + HiveConf hiveConf = metastore.hiveConf(); + + spark = + SparkSession.builder() + .master("local[2]") + .config("spark.appStateStore.asyncTracking.enable", false) + .config(SQLConf.PARTITION_OVERWRITE_MODE().key(), "dynamic") + .config("spark.hadoop." + METASTOREURIS.varname, hiveConf.get(METASTOREURIS.varname)) + .enableHiveSupport() + .getOrCreate(); + + catalog = + (HiveCatalog) + CatalogUtil.loadCatalog( + HiveCatalog.class.getName(), "hive", ImmutableMap.of(), hiveConf); + + try { + catalog.createNamespace(Namespace.of("default")); + } catch (AlreadyExistsException ignored) { + // the default namespace already exists. ignore the create error + } + } + + @AfterClass + public static void stopMetastoreAndSpark() throws Exception { + catalog = null; + metastore.stop(); + metastore = null; + spark.stop(); + spark = null; + } + + @After + @Override + public void cleanup() throws IOException { + super.cleanup(); + dropTable("test3"); + } + + @Override + protected Table createTable(String name, Schema schema, PartitionSpec spec) { + Table table = catalog.createTable(TableIdentifier.of("default", name), schema); + TableOperations ops = ((BaseTable) table).operations(); + TableMetadata meta = ops.current(); + ops.commit(meta, meta.upgradeToFormatVersion(2)); + table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit(); + if (format.equals("parquet") || format.equals("orc")) { + String vectorizationEnabled = + format.equals("parquet") + ? TableProperties.PARQUET_VECTORIZATION_ENABLED + : TableProperties.ORC_VECTORIZATION_ENABLED; + String batchSize = + format.equals("parquet") + ? 
TableProperties.PARQUET_BATCH_SIZE + : TableProperties.ORC_BATCH_SIZE; + table.updateProperties().set(vectorizationEnabled, String.valueOf(vectorized)).commit(); + if (vectorized) { + // split 7 records to two batches to cover more code paths + table.updateProperties().set(batchSize, "4").commit(); + } + } + return table; + } + + @Override + protected void dropTable(String name) { + catalog.dropTable(TableIdentifier.of("default", name)); + } + + protected boolean countDeletes() { + return true; + } + + @Override + protected long deleteCount() { + return Long.parseLong(lastExecutedMetricValue(spark, NumDeletes.DISPLAY_STRING)); + } + + @Override + public StructLikeSet rowSet(String name, Table table, String... columns) { + return rowSet(name, table.schema().select(columns).asStruct(), columns); + } + + public StructLikeSet rowSet(String name, Types.StructType projection, String... columns) { + Dataset df = + spark + .read() + .format("iceberg") + .load(TableIdentifier.of("default", name).toString()) + .selectExpr(columns); + + StructLikeSet set = StructLikeSet.create(projection); + df.collectAsList() + .forEach( + row -> { + SparkStructLike rowWrapper = new SparkStructLike(projection); + set.add(rowWrapper.wrap(row)); + }); + + return set; + } + + @Test + public void testEqualityDeleteWithFilter() throws IOException { + String tableName = table.name().substring(table.name().lastIndexOf(".") + 1); + Schema deleteRowSchema = table.schema().select("data"); + Record dataDelete = GenericRecord.create(deleteRowSchema); + List dataDeletes = + Lists.newArrayList( + dataDelete.copy("data", "a"), // id = 29 + dataDelete.copy("data", "d"), // id = 89 + dataDelete.copy("data", "g") // id = 122 + ); + + DeleteFile eqDeletes = + FileHelpers.writeDeleteFile( + table, + Files.localOutput(temp.newFile()), + TestHelpers.Row.of(0), + dataDeletes, + deleteRowSchema); + + table.newRowDelta().addDeletes(eqDeletes).commit(); + + Types.StructType projection = table.schema().select("*").asStruct(); + Dataset df = + spark + .read() + .format("iceberg") + .load(TableIdentifier.of("default", tableName).toString()) + .filter("data = 'a'") // select a deleted row + .selectExpr("*"); + + StructLikeSet actual = StructLikeSet.create(projection); + df.collectAsList() + .forEach( + row -> { + SparkStructLike rowWrapper = new SparkStructLike(projection); + actual.add(rowWrapper.wrap(row)); + }); + + Assert.assertEquals("Table should contain no rows", 0, actual.size()); + } + + @Test + public void testReadEqualityDeleteRows() throws IOException { + Schema deleteSchema1 = table.schema().select("data"); + Record dataDelete = GenericRecord.create(deleteSchema1); + List dataDeletes = + Lists.newArrayList( + dataDelete.copy("data", "a"), // id = 29 + dataDelete.copy("data", "d") // id = 89 + ); + + Schema deleteSchema2 = table.schema().select("id"); + Record idDelete = GenericRecord.create(deleteSchema2); + List idDeletes = + Lists.newArrayList( + idDelete.copy("id", 121), // id = 121 + idDelete.copy("id", 122) // id = 122 + ); + + DeleteFile eqDelete1 = + FileHelpers.writeDeleteFile( + table, + Files.localOutput(temp.newFile()), + TestHelpers.Row.of(0), + dataDeletes, + deleteSchema1); + + DeleteFile eqDelete2 = + FileHelpers.writeDeleteFile( + table, + Files.localOutput(temp.newFile()), + TestHelpers.Row.of(0), + idDeletes, + deleteSchema2); + + table.newRowDelta().addDeletes(eqDelete1).addDeletes(eqDelete2).commit(); + + StructLikeSet expectedRowSet = rowSetWithIds(29, 89, 121, 122); + + Types.StructType type = 
table.schema().asStruct(); + StructLikeSet actualRowSet = StructLikeSet.create(type); + + CloseableIterable tasks = + TableScanUtil.planTasks( + table.newScan().planFiles(), + TableProperties.METADATA_SPLIT_SIZE_DEFAULT, + TableProperties.SPLIT_LOOKBACK_DEFAULT, + TableProperties.SPLIT_OPEN_FILE_COST_DEFAULT); + + for (CombinedScanTask task : tasks) { + try (EqualityDeleteRowReader reader = + new EqualityDeleteRowReader(task, table, null, table.schema(), false)) { + while (reader.next()) { + actualRowSet.add( + new InternalRowWrapper(SparkSchemaUtil.convert(table.schema())) + .wrap(reader.get().copy())); + } + } + } + + Assert.assertEquals("should include 4 deleted row", 4, actualRowSet.size()); + Assert.assertEquals("deleted row should be matched", expectedRowSet, actualRowSet); + } + + @Test + public void testPosDeletesAllRowsInBatch() throws IOException { + // read.parquet.vectorization.batch-size is set to 4, so the 4 rows in the first batch are all + // deleted. + List> deletes = + Lists.newArrayList( + Pair.of(dataFile.path(), 0L), // id = 29 + Pair.of(dataFile.path(), 1L), // id = 43 + Pair.of(dataFile.path(), 2L), // id = 61 + Pair.of(dataFile.path(), 3L) // id = 89 + ); + + Pair posDeletes = + FileHelpers.writeDeleteFile( + table, Files.localOutput(temp.newFile()), TestHelpers.Row.of(0), deletes); + + table + .newRowDelta() + .addDeletes(posDeletes.first()) + .validateDataFilesExist(posDeletes.second()) + .commit(); + + StructLikeSet expected = rowSetWithoutIds(table, records, 29, 43, 61, 89); + StructLikeSet actual = rowSet(tableName, table, "*"); + + Assert.assertEquals("Table should contain expected rows", expected, actual); + checkDeleteCount(4L); + } + + @Test + public void testPosDeletesWithDeletedColumn() throws IOException { + // read.parquet.vectorization.batch-size is set to 4, so the 4 rows in the first batch are all + // deleted. 
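The assertions in the delete tests lean on the _deleted metadata column, which keeps rows masked by position or equality deletes visible and marks them true. A minimal sketch of querying it directly is below; it assumes a SparkSession that already has an Iceberg catalog configured, and the table identifier "default.test" is illustrative.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class DeletedColumnSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().master("local[2]").getOrCreate();

    // rows that are still live after the delete files are applied
    Dataset<Row> live = spark.read().format("iceberg")
        .load("default.test")
        .select("id", "data", "_deleted")
        .filter("_deleted = false");

    // rows that a position or equality delete has marked as removed;
    // projecting _deleted keeps them visible instead of filtering them out
    Dataset<Row> deleted = spark.read().format("iceberg")
        .load("default.test")
        .select("id", "data", "_deleted")
        .filter("_deleted = true");

    System.out.printf("%d live rows, %d deleted rows%n", live.count(), deleted.count());
    spark.stop();
  }
}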
+ List> deletes = + Lists.newArrayList( + Pair.of(dataFile.path(), 0L), // id = 29 + Pair.of(dataFile.path(), 1L), // id = 43 + Pair.of(dataFile.path(), 2L), // id = 61 + Pair.of(dataFile.path(), 3L) // id = 89 + ); + + Pair posDeletes = + FileHelpers.writeDeleteFile( + table, Files.localOutput(temp.newFile()), TestHelpers.Row.of(0), deletes); + + table + .newRowDelta() + .addDeletes(posDeletes.first()) + .validateDataFilesExist(posDeletes.second()) + .commit(); + + StructLikeSet expected = expectedRowSet(29, 43, 61, 89); + StructLikeSet actual = + rowSet(tableName, PROJECTION_SCHEMA.asStruct(), "id", "data", "_deleted"); + + Assert.assertEquals("Table should contain expected row", expected, actual); + checkDeleteCount(4L); + } + + @Test + public void testEqualityDeleteWithDeletedColumn() throws IOException { + String tableName = table.name().substring(table.name().lastIndexOf(".") + 1); + Schema deleteRowSchema = table.schema().select("data"); + Record dataDelete = GenericRecord.create(deleteRowSchema); + List dataDeletes = + Lists.newArrayList( + dataDelete.copy("data", "a"), // id = 29 + dataDelete.copy("data", "d"), // id = 89 + dataDelete.copy("data", "g") // id = 122 + ); + + DeleteFile eqDeletes = + FileHelpers.writeDeleteFile( + table, + Files.localOutput(temp.newFile()), + TestHelpers.Row.of(0), + dataDeletes, + deleteRowSchema); + + table.newRowDelta().addDeletes(eqDeletes).commit(); + + StructLikeSet expected = expectedRowSet(29, 89, 122); + StructLikeSet actual = + rowSet(tableName, PROJECTION_SCHEMA.asStruct(), "id", "data", "_deleted"); + + Assert.assertEquals("Table should contain expected row", expected, actual); + checkDeleteCount(3L); + } + + @Test + public void testMixedPosAndEqDeletesWithDeletedColumn() throws IOException { + Schema dataSchema = table.schema().select("data"); + Record dataDelete = GenericRecord.create(dataSchema); + List dataDeletes = + Lists.newArrayList( + dataDelete.copy("data", "a"), // id = 29 + dataDelete.copy("data", "d"), // id = 89 + dataDelete.copy("data", "g") // id = 122 + ); + + DeleteFile eqDeletes = + FileHelpers.writeDeleteFile( + table, + Files.localOutput(temp.newFile()), + TestHelpers.Row.of(0), + dataDeletes, + dataSchema); + + List> deletes = + Lists.newArrayList( + Pair.of(dataFile.path(), 3L), // id = 89 + Pair.of(dataFile.path(), 5L) // id = 121 + ); + + Pair posDeletes = + FileHelpers.writeDeleteFile( + table, Files.localOutput(temp.newFile()), TestHelpers.Row.of(0), deletes); + + table + .newRowDelta() + .addDeletes(eqDeletes) + .addDeletes(posDeletes.first()) + .validateDataFilesExist(posDeletes.second()) + .commit(); + + StructLikeSet expected = expectedRowSet(29, 89, 121, 122); + StructLikeSet actual = + rowSet(tableName, PROJECTION_SCHEMA.asStruct(), "id", "data", "_deleted"); + + Assert.assertEquals("Table should contain expected row", expected, actual); + checkDeleteCount(4L); + } + + @Test + public void testFilterOnDeletedMetadataColumn() throws IOException { + List> deletes = + Lists.newArrayList( + Pair.of(dataFile.path(), 0L), // id = 29 + Pair.of(dataFile.path(), 1L), // id = 43 + Pair.of(dataFile.path(), 2L), // id = 61 + Pair.of(dataFile.path(), 3L) // id = 89 + ); + + Pair posDeletes = + FileHelpers.writeDeleteFile( + table, Files.localOutput(temp.newFile()), TestHelpers.Row.of(0), deletes); + + table + .newRowDelta() + .addDeletes(posDeletes.first()) + .validateDataFilesExist(posDeletes.second()) + .commit(); + + StructLikeSet expected = expectedRowSetWithNonDeletesOnly(29, 43, 61, 89); + + // get non-deleted 
rows + Dataset df = + spark + .read() + .format("iceberg") + .load(TableIdentifier.of("default", tableName).toString()) + .select("id", "data", "_deleted") + .filter("_deleted = false"); + + Types.StructType projection = PROJECTION_SCHEMA.asStruct(); + StructLikeSet actual = StructLikeSet.create(projection); + df.collectAsList() + .forEach( + row -> { + SparkStructLike rowWrapper = new SparkStructLike(projection); + actual.add(rowWrapper.wrap(row)); + }); + + Assert.assertEquals("Table should contain expected row", expected, actual); + + StructLikeSet expectedDeleted = expectedRowSetWithDeletesOnly(29, 43, 61, 89); + + // get deleted rows + df = + spark + .read() + .format("iceberg") + .load(TableIdentifier.of("default", tableName).toString()) + .select("id", "data", "_deleted") + .filter("_deleted = true"); + + StructLikeSet actualDeleted = StructLikeSet.create(projection); + df.collectAsList() + .forEach( + row -> { + SparkStructLike rowWrapper = new SparkStructLike(projection); + actualDeleted.add(rowWrapper.wrap(row)); + }); + + Assert.assertEquals("Table should contain expected row", expectedDeleted, actualDeleted); + } + + @Test + public void testIsDeletedColumnWithoutDeleteFile() { + StructLikeSet expected = expectedRowSet(); + StructLikeSet actual = + rowSet(tableName, PROJECTION_SCHEMA.asStruct(), "id", "data", "_deleted"); + Assert.assertEquals("Table should contain expected row", expected, actual); + checkDeleteCount(0L); + } + + @Test + public void testPosDeletesOnParquetFileWithMultipleRowGroups() throws IOException { + Assume.assumeTrue(format.equals("parquet")); + + String tblName = "test3"; + Table tbl = createTable(tblName, SCHEMA, PartitionSpec.unpartitioned()); + + List fileSplits = Lists.newArrayList(); + StructType sparkSchema = SparkSchemaUtil.convert(SCHEMA); + Configuration conf = new Configuration(); + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + Path testFilePath = new Path(testFile.getAbsolutePath()); + + // Write a Parquet file with more than one row group + ParquetFileWriter parquetFileWriter = + new ParquetFileWriter(conf, ParquetSchemaUtil.convert(SCHEMA, "test3Schema"), testFilePath); + parquetFileWriter.start(); + for (int i = 0; i < 2; i += 1) { + File split = temp.newFile(); + Assert.assertTrue("Delete should succeed", split.delete()); + Path splitPath = new Path(split.getAbsolutePath()); + fileSplits.add(splitPath); + try (FileAppender writer = + Parquet.write(Files.localOutput(split)) + .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(sparkSchema, msgType)) + .schema(SCHEMA) + .overwrite() + .build()) { + Iterable records = RandomData.generateSpark(SCHEMA, 100, 34 * i + 37); + writer.addAll(records); + } + parquetFileWriter.appendFile( + org.apache.parquet.hadoop.util.HadoopInputFile.fromPath(splitPath, conf)); + } + parquetFileWriter.end( + ParquetFileWriter.mergeMetadataFiles(fileSplits, conf) + .getFileMetaData() + .getKeyValueMetaData()); + + // Add the file to the table + DataFile dataFile = + DataFiles.builder(PartitionSpec.unpartitioned()) + .withInputFile(org.apache.iceberg.hadoop.HadoopInputFile.fromPath(testFilePath, conf)) + .withFormat("parquet") + .withRecordCount(200) + .build(); + tbl.newAppend().appendFile(dataFile).commit(); + + // Add positional deletes to the table + List> deletes = + Lists.newArrayList( + Pair.of(dataFile.path(), 97L), + Pair.of(dataFile.path(), 98L), + Pair.of(dataFile.path(), 99L), + Pair.of(dataFile.path(), 101L), + Pair.of(dataFile.path(), 103L), + 
Pair.of(dataFile.path(), 107L), + Pair.of(dataFile.path(), 109L)); + Pair posDeletes = + FileHelpers.writeDeleteFile(table, Files.localOutput(temp.newFile()), deletes); + tbl.newRowDelta() + .addDeletes(posDeletes.first()) + .validateDataFilesExist(posDeletes.second()) + .commit(); + + Assert.assertEquals(193, rowSet(tblName, tbl, "*").size()); + } + + private static final Schema PROJECTION_SCHEMA = + new Schema( + required(1, "id", Types.IntegerType.get()), + required(2, "data", Types.StringType.get()), + MetadataColumns.IS_DELETED); + + private static StructLikeSet expectedRowSet(int... idsToRemove) { + return expectedRowSet(false, false, idsToRemove); + } + + private static StructLikeSet expectedRowSetWithDeletesOnly(int... idsToRemove) { + return expectedRowSet(false, true, idsToRemove); + } + + private static StructLikeSet expectedRowSetWithNonDeletesOnly(int... idsToRemove) { + return expectedRowSet(true, false, idsToRemove); + } + + private static StructLikeSet expectedRowSet( + boolean removeDeleted, boolean removeNonDeleted, int... idsToRemove) { + Set deletedIds = Sets.newHashSet(ArrayUtil.toIntList(idsToRemove)); + List records = recordsWithDeletedColumn(); + // mark rows deleted + records.forEach( + record -> { + if (deletedIds.contains(record.getField("id"))) { + record.setField(MetadataColumns.IS_DELETED.name(), true); + } + }); + + records.removeIf(record -> deletedIds.contains(record.getField("id")) && removeDeleted); + records.removeIf(record -> !deletedIds.contains(record.getField("id")) && removeNonDeleted); + + StructLikeSet set = StructLikeSet.create(PROJECTION_SCHEMA.asStruct()); + records.forEach( + record -> set.add(new InternalRecordWrapper(PROJECTION_SCHEMA.asStruct()).wrap(record))); + + return set; + } + + @NotNull + private static List recordsWithDeletedColumn() { + List records = Lists.newArrayList(); + + // records all use IDs that are in bucket id_bucket=0 + GenericRecord record = GenericRecord.create(PROJECTION_SCHEMA); + records.add(record.copy("id", 29, "data", "a", "_deleted", false)); + records.add(record.copy("id", 43, "data", "b", "_deleted", false)); + records.add(record.copy("id", 61, "data", "c", "_deleted", false)); + records.add(record.copy("id", 89, "data", "d", "_deleted", false)); + records.add(record.copy("id", 100, "data", "e", "_deleted", false)); + records.add(record.copy("id", 121, "data", "f", "_deleted", false)); + records.add(record.copy("id", 122, "data", "g", "_deleted", false)); + return records; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReaderWithBloomFilter.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReaderWithBloomFilter.java new file mode 100644 index 000000000000..e5831b76e424 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReaderWithBloomFilter.java @@ -0,0 +1,373 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS; +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT_DEFAULT; +import static org.apache.iceberg.TableProperties.PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX; +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT; + +import java.io.Closeable; +import java.io.IOException; +import java.math.BigDecimal; +import java.time.LocalDate; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TestHelpers.Row; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.data.GenericAppenderFactory; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.hive.HiveCatalog; +import org.apache.iceberg.hive.TestHiveMetastore; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkValueConverter; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.PropertyUtil; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.SparkSession; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class TestSparkReaderWithBloomFilter { + + protected String tableName = null; + protected Table table = null; + protected List records = null; + protected DataFile dataFile = null; + + private static TestHiveMetastore metastore = null; + protected static SparkSession spark = null; + protected static HiveCatalog catalog = null; + protected final boolean vectorized; + protected final boolean useBloomFilter; + + public TestSparkReaderWithBloomFilter(boolean vectorized, boolean useBloomFilter) { + this.vectorized = vectorized; + this.useBloomFilter = useBloomFilter; + } + + // Schema passed to create tables + public static final Schema SCHEMA = + new Schema( + 
Types.NestedField.required(1, "id", Types.IntegerType.get()), + Types.NestedField.required(2, "id_long", Types.LongType.get()), + Types.NestedField.required(3, "id_double", Types.DoubleType.get()), + Types.NestedField.required(4, "id_float", Types.FloatType.get()), + Types.NestedField.required(5, "id_string", Types.StringType.get()), + Types.NestedField.optional(6, "id_boolean", Types.BooleanType.get()), + Types.NestedField.optional(7, "id_date", Types.DateType.get()), + Types.NestedField.optional(8, "id_int_decimal", Types.DecimalType.of(8, 2)), + Types.NestedField.optional(9, "id_long_decimal", Types.DecimalType.of(14, 2)), + Types.NestedField.optional(10, "id_fixed_decimal", Types.DecimalType.of(31, 2))); + + private static final int INT_MIN_VALUE = 30; + private static final int INT_MAX_VALUE = 329; + private static final int INT_VALUE_COUNT = INT_MAX_VALUE - INT_MIN_VALUE + 1; + private static final long LONG_BASE = 1000L; + private static final double DOUBLE_BASE = 10000D; + private static final float FLOAT_BASE = 100000F; + private static final String BINARY_PREFIX = "BINARY测试_"; + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + @Before + public void writeTestDataFile() throws IOException { + this.tableName = "test"; + createTable(tableName, SCHEMA); + this.records = Lists.newArrayList(); + + // records all use IDs that are in bucket id_bucket=0 + GenericRecord record = GenericRecord.create(table.schema()); + + for (int i = 0; i < INT_VALUE_COUNT; i += 1) { + records.add( + record.copy( + ImmutableMap.of( + "id", + INT_MIN_VALUE + i, + "id_long", + LONG_BASE + INT_MIN_VALUE + i, + "id_double", + DOUBLE_BASE + INT_MIN_VALUE + i, + "id_float", + FLOAT_BASE + INT_MIN_VALUE + i, + "id_string", + BINARY_PREFIX + (INT_MIN_VALUE + i), + "id_boolean", + i % 2 == 0, + "id_date", + LocalDate.parse("2021-09-05"), + "id_int_decimal", + new BigDecimal(String.valueOf(77.77)), + "id_long_decimal", + new BigDecimal(String.valueOf(88.88)), + "id_fixed_decimal", + new BigDecimal(String.valueOf(99.99))))); + } + + this.dataFile = writeDataFile(Files.localOutput(temp.newFile()), Row.of(0), records); + + table.newAppend().appendFile(dataFile).commit(); + } + + @After + public void cleanup() throws IOException { + dropTable("test"); + } + + @Parameterized.Parameters(name = "vectorized = {0}, useBloomFilter = {1}") + public static Object[][] parameters() { + return new Object[][] {{false, false}, {true, false}, {false, true}, {true, true}}; + } + + @BeforeClass + public static void startMetastoreAndSpark() { + metastore = new TestHiveMetastore(); + metastore.start(); + HiveConf hiveConf = metastore.hiveConf(); + + spark = + SparkSession.builder() + .master("local[2]") + .config("spark.hadoop." + METASTOREURIS.varname, hiveConf.get(METASTOREURIS.varname)) + .enableHiveSupport() + .getOrCreate(); + + catalog = + (HiveCatalog) + CatalogUtil.loadCatalog( + HiveCatalog.class.getName(), "hive", ImmutableMap.of(), hiveConf); + + try { + catalog.createNamespace(Namespace.of("default")); + } catch (AlreadyExistsException ignored) { + // the default namespace already exists. 
ignore the create error + } + } + + @AfterClass + public static void stopMetastoreAndSpark() throws Exception { + catalog = null; + metastore.stop(); + metastore = null; + spark.stop(); + spark = null; + } + + protected void createTable(String name, Schema schema) { + table = catalog.createTable(TableIdentifier.of("default", name), schema); + TableOperations ops = ((BaseTable) table).operations(); + TableMetadata meta = ops.current(); + ops.commit(meta, meta.upgradeToFormatVersion(2)); + + if (useBloomFilter) { + table + .updateProperties() + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id", "true") + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_long", "true") + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_double", "true") + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_float", "true") + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_string", "true") + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_boolean", "true") + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_date", "true") + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_int_decimal", "true") + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_long_decimal", "true") + .set(PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_fixed_decimal", "true") + .commit(); + } + + table + .updateProperties() + .set(TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES, "100") // to have multiple row groups + .commit(); + if (vectorized) { + table + .updateProperties() + .set(TableProperties.PARQUET_VECTORIZATION_ENABLED, "true") + .set(TableProperties.PARQUET_BATCH_SIZE, "4") + .commit(); + } + } + + protected void dropTable(String name) { + catalog.dropTable(TableIdentifier.of("default", name)); + } + + private DataFile writeDataFile(OutputFile out, StructLike partition, List rows) + throws IOException { + FileFormat format = defaultFormat(table.properties()); + GenericAppenderFactory factory = new GenericAppenderFactory(table.schema(), table.spec()); + + boolean useBloomFilterCol1 = + PropertyUtil.propertyAsBoolean( + table.properties(), PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id", false); + factory.set( + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id", Boolean.toString(useBloomFilterCol1)); + boolean useBloomFilterCol2 = + PropertyUtil.propertyAsBoolean( + table.properties(), PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_long", false); + factory.set( + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_long", + Boolean.toString(useBloomFilterCol2)); + boolean useBloomFilterCol3 = + PropertyUtil.propertyAsBoolean( + table.properties(), PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_double", false); + factory.set( + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_double", + Boolean.toString(useBloomFilterCol3)); + boolean useBloomFilterCol4 = + PropertyUtil.propertyAsBoolean( + table.properties(), PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_float", false); + factory.set( + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_float", + Boolean.toString(useBloomFilterCol4)); + boolean useBloomFilterCol5 = + PropertyUtil.propertyAsBoolean( + table.properties(), PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_string", false); + factory.set( + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_string", + Boolean.toString(useBloomFilterCol5)); + boolean useBloomFilterCol6 = + PropertyUtil.propertyAsBoolean( + table.properties(), PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_boolean", false); + factory.set( + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_boolean", + 
Boolean.toString(useBloomFilterCol6)); + boolean useBloomFilterCol7 = + PropertyUtil.propertyAsBoolean( + table.properties(), PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_date", false); + factory.set( + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_date", + Boolean.toString(useBloomFilterCol7)); + boolean useBloomFilterCol8 = + PropertyUtil.propertyAsBoolean( + table.properties(), + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_int_decimal", + false); + factory.set( + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_int_decimal", + Boolean.toString(useBloomFilterCol8)); + boolean useBloomFilterCol9 = + PropertyUtil.propertyAsBoolean( + table.properties(), + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_long_decimal", + false); + factory.set( + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_long_decimal", + Boolean.toString(useBloomFilterCol9)); + boolean useBloomFilterCol10 = + PropertyUtil.propertyAsBoolean( + table.properties(), + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_fixed_decimal", + false); + factory.set( + PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX + "id_fixed_decimal", + Boolean.toString(useBloomFilterCol10)); + int blockSize = + PropertyUtil.propertyAsInt( + table.properties(), PARQUET_ROW_GROUP_SIZE_BYTES, PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT); + factory.set(PARQUET_ROW_GROUP_SIZE_BYTES, Integer.toString(blockSize)); + + FileAppender writer = factory.newAppender(out, format); + try (Closeable toClose = writer) { + writer.addAll(rows); + } + + return DataFiles.builder(table.spec()) + .withFormat(format) + .withPath(out.location()) + .withPartition(partition) + .withFileSizeInBytes(writer.length()) + .withSplitOffsets(writer.splitOffsets()) + .withMetrics(writer.metrics()) + .build(); + } + + private FileFormat defaultFormat(Map properties) { + String formatString = properties.getOrDefault(DEFAULT_FILE_FORMAT, DEFAULT_FILE_FORMAT_DEFAULT); + return FileFormat.fromString(formatString); + } + + @Test + public void testReadWithFilter() { + Dataset df = + spark + .read() + .format("iceberg") + .load(TableIdentifier.of("default", tableName).toString()) + // this is from the first row group + .filter( + "id = 30 AND id_long = 1030 AND id_double = 10030.0 AND id_float = 100030.0" + + " AND id_string = 'BINARY测试_30' AND id_boolean = true AND id_date = '2021-09-05'" + + " AND id_int_decimal = 77.77 AND id_long_decimal = 88.88 AND id_fixed_decimal = 99.99"); + + Record record = SparkValueConverter.convert(table.schema(), df.collectAsList().get(0)); + + Assert.assertEquals("Table should contain 1 row", 1, df.collectAsList().size()); + + Assert.assertEquals("Table should contain expected rows", record.get(0), 30); + + df = + spark + .read() + .format("iceberg") + .load(TableIdentifier.of("default", tableName).toString()) + // this is from the third row group + .filter( + "id = 250 AND id_long = 1250 AND id_double = 10250.0 AND id_float = 100250.0" + + " AND id_string = 'BINARY测试_250' AND id_boolean = true AND id_date = '2021-09-05'" + + " AND id_int_decimal = 77.77 AND id_long_decimal = 88.88 AND id_fixed_decimal = 99.99"); + + record = SparkValueConverter.convert(table.schema(), df.collectAsList().get(0)); + + Assert.assertEquals("Table should contain 1 row", 1, df.collectAsList().size()); + + Assert.assertEquals("Table should contain expected rows", record.get(0), 250); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkRollingFileWriters.java 
b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkRollingFileWriters.java
new file mode 100644
index 000000000000..dcf9140a8885
--- /dev/null
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkRollingFileWriters.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import java.util.List;
+import org.apache.iceberg.FileFormat;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.io.FileWriterFactory;
+import org.apache.iceberg.io.TestRollingFileWriters;
+import org.apache.iceberg.util.ArrayUtil;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
+import org.apache.spark.unsafe.types.UTF8String;
+
+public class TestSparkRollingFileWriters extends TestRollingFileWriters<InternalRow> {
+
+  public TestSparkRollingFileWriters(FileFormat fileFormat, boolean partitioned) {
+    super(fileFormat, partitioned);
+  }
+
+  @Override
+  protected FileWriterFactory<InternalRow> newWriterFactory(
+      Schema dataSchema,
+      List<Integer> equalityFieldIds,
+      Schema equalityDeleteRowSchema,
+      Schema positionDeleteRowSchema) {
+    return SparkFileWriterFactory.builderFor(table)
+        .dataSchema(table.schema())
+        .dataFileFormat(format())
+        .deleteFileFormat(format())
+        .equalityFieldIds(ArrayUtil.toIntArray(equalityFieldIds))
+        .equalityDeleteRowSchema(equalityDeleteRowSchema)
+        .positionDeleteRowSchema(positionDeleteRowSchema)
+        .build();
+  }
+
+  @Override
+  protected InternalRow toRow(Integer id, String data) {
+    InternalRow row = new GenericInternalRow(2);
+    row.update(0, id);
+    row.update(1, UTF8String.fromString(data));
+    return row;
+  }
+}
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkScan.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkScan.java
new file mode 100644
index 000000000000..905a4e7dfef6
--- /dev/null
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkScan.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.spark.sql.functions.date_add; +import static org.apache.spark.sql.functions.expr; + +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.spark.SparkTestBaseWithCatalog; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class TestSparkScan extends SparkTestBaseWithCatalog { + + private final String format; + + @Parameterized.Parameters(name = "format = {0}") + public static Object[] parameters() { + return new Object[] {"parquet", "avro", "orc"}; + } + + public TestSparkScan(String format) { + this.format = format; + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testEstimatedRowCount() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT, date DATE) USING iceberg TBLPROPERTIES('%s' = '%s')", + tableName, TableProperties.DEFAULT_FILE_FORMAT, format); + + Dataset df = + spark + .range(10000) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id AS INT)"))) + .select("id", "date"); + + df.coalesce(1).writeTo(tableName).append(); + + Table table = validationCatalog.loadTable(tableIdent); + SparkScanBuilder scanBuilder = + new SparkScanBuilder(spark, table, CaseInsensitiveStringMap.empty()); + SparkScan scan = (SparkScan) scanBuilder.build(); + Statistics stats = scan.estimateStatistics(); + + Assert.assertEquals(10000L, stats.numRows().getAsLong()); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkStagedScan.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkStagedScan.java new file mode 100644 index 000000000000..241293f367aa --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkStagedScan.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+package org.apache.iceberg.spark.source;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import org.apache.iceberg.FileScanTask;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.io.CloseableIterable;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
+import org.apache.iceberg.spark.ScanTaskSetManager;
+import org.apache.iceberg.spark.SparkCatalogTestBase;
+import org.apache.iceberg.spark.SparkReadOptions;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestSparkStagedScan extends SparkCatalogTestBase {
+
+  public TestSparkStagedScan(
+      String catalogName, String implementation, Map<String, String> config) {
+    super(catalogName, implementation, config);
+  }
+
+  @After
+  public void removeTables() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+  }
+
+  @Test
+  public void testTaskSetLoading() throws NoSuchTableException, IOException {
+    sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName);
+
+    List<SimpleRecord> records =
+        ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b"));
+    Dataset<Row> df = spark.createDataFrame(records, SimpleRecord.class);
+    df.writeTo(tableName).append();
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    Assert.assertEquals("Should produce 1 snapshot", 1, Iterables.size(table.snapshots()));
+
+    try (CloseableIterable<FileScanTask> fileScanTasks = table.newScan().planFiles()) {
+      ScanTaskSetManager taskSetManager = ScanTaskSetManager.get();
+      String setID = UUID.randomUUID().toString();
+      taskSetManager.stageTasks(table, setID, ImmutableList.copyOf(fileScanTasks));
+
+      // load the staged file set
+      Dataset<Row> scanDF =
+          spark
+              .read()
+              .format("iceberg")
+              .option(SparkReadOptions.SCAN_TASK_SET_ID, setID)
+              .load(tableName);
+
+      // write the records back essentially duplicating data
+      scanDF.writeTo(tableName).append();
+    }
+
+    assertEquals(
+        "Should have expected rows",
+        ImmutableList.of(row(1, "a"), row(1, "a"), row(2, "b"), row(2, "b")),
+        sql("SELECT * FROM %s ORDER BY id", tableName));
+  }
+
+  @Test
+  public void testTaskSetPlanning() throws NoSuchTableException, IOException {
+    sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName);
+
+    List<SimpleRecord> records =
+        ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b"));
+    Dataset<Row> df = spark.createDataFrame(records, SimpleRecord.class);
+    df.coalesce(1).writeTo(tableName).append();
+    df.coalesce(1).writeTo(tableName).append();
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    Assert.assertEquals("Should produce 2 snapshots", 2, Iterables.size(table.snapshots()));
+
+    try (CloseableIterable<FileScanTask> fileScanTasks = table.newScan().planFiles()) {
+      ScanTaskSetManager taskSetManager = ScanTaskSetManager.get();
+      String setID = UUID.randomUUID().toString();
+      List<FileScanTask> tasks = ImmutableList.copyOf(fileScanTasks);
+      taskSetManager.stageTasks(table, setID, tasks);
+
+      // load the staged file set and make sure each file is in a separate split
+      Dataset<Row> scanDF =
+          spark
+              .read()
+              .format("iceberg")
+              .option(SparkReadOptions.SCAN_TASK_SET_ID, setID)
+              .option(SparkReadOptions.SPLIT_SIZE, tasks.get(0).file().fileSizeInBytes())
+              .load(tableName);
+      Assert.assertEquals("Num partitions should match", 2, scanDF.javaRDD().getNumPartitions());
+
+      // load the
staged file set and make sure we combine both files into a single split + scanDF = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SCAN_TASK_SET_ID, setID) + .option(SparkReadOptions.SPLIT_SIZE, Long.MAX_VALUE) + .load(tableName); + Assert.assertEquals("Num partitions should match", 1, scanDF.javaRDD().getNumPartitions()); + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkTable.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkTable.java new file mode 100644 index 000000000000..616a196872de --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkTable.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Map; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.connector.catalog.CatalogManager; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestSparkTable extends SparkCatalogTestBase { + + public TestSparkTable(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Before + public void createTable() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testTableEquality() throws NoSuchTableException { + CatalogManager catalogManager = spark.sessionState().catalogManager(); + TableCatalog catalog = (TableCatalog) catalogManager.catalog(catalogName); + Identifier identifier = Identifier.of(tableIdent.namespace().levels(), tableIdent.name()); + SparkTable table1 = (SparkTable) catalog.loadTable(identifier); + SparkTable table2 = (SparkTable) catalog.loadTable(identifier); + + // different instances pointing to the same table must be equivalent + Assert.assertNotSame("References must be different", table1, table2); + Assert.assertEquals("Tables must be equivalent", table1, table2); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkWriterMetrics.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkWriterMetrics.java new file mode 100644 index 000000000000..06ecc20c2fc3 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkWriterMetrics.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.FileWriterFactory; +import org.apache.iceberg.io.TestWriterMetrics; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.unsafe.types.UTF8String; + +public class TestSparkWriterMetrics extends TestWriterMetrics { + + public TestSparkWriterMetrics(FileFormat fileFormat) { + super(fileFormat); + } + + @Override + protected FileWriterFactory newWriterFactory(Table sourceTable) { + return SparkFileWriterFactory.builderFor(sourceTable) + .dataSchema(sourceTable.schema()) + .dataFileFormat(fileFormat) + .deleteFileFormat(fileFormat) + .positionDeleteRowSchema(sourceTable.schema()) + .build(); + } + + @Override + protected InternalRow toRow(Integer id, String data, boolean boolValue, Long longValue) { + InternalRow row = new GenericInternalRow(3); + row.update(0, id); + row.update(1, UTF8String.fromString(data)); + + InternalRow nested = new GenericInternalRow(2); + nested.update(0, boolValue); + nested.update(1, longValue); + + row.update(2, nested); + return row; + } + + @Override + protected InternalRow toGenericRow(int value, int repeated) { + InternalRow row = new GenericInternalRow(repeated); + for (int i = 0; i < repeated; i++) { + row.update(i, value); + } + return row; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestStreamingOffset.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestStreamingOffset.java new file mode 100644 index 000000000000..17370aaa22f2 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestStreamingOffset.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import com.fasterxml.jackson.databind.node.ObjectNode; +import java.util.Arrays; +import org.apache.iceberg.util.JsonUtil; +import org.junit.Assert; +import org.junit.Test; + +public class TestStreamingOffset { + + @Test + public void testJsonConversion() { + StreamingOffset[] expected = + new StreamingOffset[] { + new StreamingOffset(System.currentTimeMillis(), 1L, false), + new StreamingOffset(System.currentTimeMillis(), 2L, false), + new StreamingOffset(System.currentTimeMillis(), 3L, false), + new StreamingOffset(System.currentTimeMillis(), 4L, true) + }; + Assert.assertArrayEquals( + "StreamingOffsets should match", + expected, + Arrays.stream(expected).map(elem -> StreamingOffset.fromJson(elem.json())).toArray()); + } + + @Test + public void testToJson() throws Exception { + StreamingOffset expected = new StreamingOffset(System.currentTimeMillis(), 1L, false); + ObjectNode actual = JsonUtil.mapper().createObjectNode(); + actual.put("version", 1); + actual.put("snapshot_id", expected.snapshotId()); + actual.put("position", 1L); + actual.put("scan_all_files", false); + String expectedJson = expected.json(); + String actualJson = JsonUtil.mapper().writeValueAsString(actual); + Assert.assertEquals("Json should match", expectedJson, actualJson); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreaming.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreaming.java new file mode 100644 index 000000000000..464f1f5922b3 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreaming.java @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.types.Types.NestedField.optional; + +import java.io.File; +import java.util.List; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.execution.streaming.MemoryStream; +import org.apache.spark.sql.streaming.DataStreamWriter; +import org.apache.spark.sql.streaming.StreamingQuery; +import org.apache.spark.sql.streaming.StreamingQueryException; +import org.assertj.core.api.Assertions; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import scala.Option; +import scala.collection.JavaConverters; + +public class TestStructuredStreaming { + + private static final Configuration CONF = new Configuration(); + private static final Schema SCHEMA = + new Schema( + optional(1, "id", Types.IntegerType.get()), optional(2, "data", Types.StringType.get())); + private static SparkSession spark = null; + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + @BeforeClass + public static void startSpark() { + TestStructuredStreaming.spark = + SparkSession.builder() + .master("local[2]") + .config("spark.sql.shuffle.partitions", 4) + .getOrCreate(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestStructuredStreaming.spark; + TestStructuredStreaming.spark = null; + currentSpark.stop(); + } + + @Test + public void testStreamingWriteAppendMode() throws Exception { + File parent = temp.newFolder("parquet"); + File location = new File(parent, "test-table"); + File checkpoint = new File(parent, "checkpoint"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List expected = + Lists.newArrayList( + new SimpleRecord(1, "1"), + new SimpleRecord(2, "2"), + new SimpleRecord(3, "3"), + new SimpleRecord(4, "4")); + + MemoryStream inputStream = newMemoryStream(1, spark.sqlContext(), Encoders.INT()); + DataStreamWriter streamWriter = + inputStream + .toDF() + .selectExpr("value AS id", "CAST (value AS STRING) AS data") + .writeStream() + .outputMode("append") + .format("iceberg") + .option("checkpointLocation", checkpoint.toString()) + .option("path", location.toString()); + + try { + // start the original query with checkpointing + StreamingQuery query = streamWriter.start(); + List batch1 = Lists.newArrayList(1, 2); + send(batch1, inputStream); + query.processAllAvailable(); + List batch2 = Lists.newArrayList(3, 4); + send(batch2, inputStream); + query.processAllAvailable(); + query.stop(); + + // remove the last commit to force Spark to reprocess batch #1 + File lastCommitFile = new File(checkpoint.toString() + "/commits/1"); + Assert.assertTrue("The commit file must be deleted", lastCommitFile.delete()); + + // restart the query 
from the checkpoint + StreamingQuery restartedQuery = streamWriter.start(); + restartedQuery.processAllAvailable(); + + // ensure the write was idempotent + Dataset result = spark.read().format("iceberg").load(location.toString()); + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + Assert.assertEquals("Number of snapshots should match", 2, Iterables.size(table.snapshots())); + } finally { + for (StreamingQuery query : spark.streams().active()) { + query.stop(); + } + } + } + + @Test + public void testStreamingWriteCompleteMode() throws Exception { + File parent = temp.newFolder("parquet"); + File location = new File(parent, "test-table"); + File checkpoint = new File(parent, "checkpoint"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List expected = + Lists.newArrayList( + new SimpleRecord(2, "1"), new SimpleRecord(3, "2"), new SimpleRecord(1, "3")); + + MemoryStream inputStream = newMemoryStream(1, spark.sqlContext(), Encoders.INT()); + DataStreamWriter streamWriter = + inputStream + .toDF() + .groupBy("value") + .count() + .selectExpr("CAST(count AS INT) AS id", "CAST (value AS STRING) AS data") + .writeStream() + .outputMode("complete") + .format("iceberg") + .option("checkpointLocation", checkpoint.toString()) + .option("path", location.toString()); + + try { + // start the original query with checkpointing + StreamingQuery query = streamWriter.start(); + List batch1 = Lists.newArrayList(1, 2); + send(batch1, inputStream); + query.processAllAvailable(); + List batch2 = Lists.newArrayList(1, 2, 2, 3); + send(batch2, inputStream); + query.processAllAvailable(); + query.stop(); + + // remove the last commit to force Spark to reprocess batch #1 + File lastCommitFile = new File(checkpoint.toString() + "/commits/1"); + Assert.assertTrue("The commit file must be deleted", lastCommitFile.delete()); + + // restart the query from the checkpoint + StreamingQuery restartedQuery = streamWriter.start(); + restartedQuery.processAllAvailable(); + + // ensure the write was idempotent + Dataset result = spark.read().format("iceberg").load(location.toString()); + List actual = + result.orderBy("data").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + Assert.assertEquals("Number of snapshots should match", 2, Iterables.size(table.snapshots())); + } finally { + for (StreamingQuery query : spark.streams().active()) { + query.stop(); + } + } + } + + @Test + public void testStreamingWriteCompleteModeWithProjection() throws Exception { + File parent = temp.newFolder("parquet"); + File location = new File(parent, "test-table"); + File checkpoint = new File(parent, "checkpoint"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Table table = tables.create(SCHEMA, spec, location.toString()); + + List expected = + Lists.newArrayList( + new SimpleRecord(1, null), new SimpleRecord(2, null), new SimpleRecord(3, null)); + + MemoryStream inputStream = newMemoryStream(1, spark.sqlContext(), Encoders.INT()); + DataStreamWriter streamWriter = + inputStream 
+ .toDF() + .groupBy("value") + .count() + .selectExpr("CAST(count AS INT) AS id") // select only id column + .writeStream() + .outputMode("complete") + .format("iceberg") + .option("checkpointLocation", checkpoint.toString()) + .option("path", location.toString()); + + try { + // start the original query with checkpointing + StreamingQuery query = streamWriter.start(); + List batch1 = Lists.newArrayList(1, 2); + send(batch1, inputStream); + query.processAllAvailable(); + List batch2 = Lists.newArrayList(1, 2, 2, 3); + send(batch2, inputStream); + query.processAllAvailable(); + query.stop(); + + // remove the last commit to force Spark to reprocess batch #1 + File lastCommitFile = new File(checkpoint.toString() + "/commits/1"); + Assert.assertTrue("The commit file must be deleted", lastCommitFile.delete()); + + // restart the query from the checkpoint + StreamingQuery restartedQuery = streamWriter.start(); + restartedQuery.processAllAvailable(); + + // ensure the write was idempotent + Dataset result = spark.read().format("iceberg").load(location.toString()); + List actual = + result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + Assert.assertEquals("Number of snapshots should match", 2, Iterables.size(table.snapshots())); + } finally { + for (StreamingQuery query : spark.streams().active()) { + query.stop(); + } + } + } + + @Test + public void testStreamingWriteUpdateMode() throws Exception { + File parent = temp.newFolder("parquet"); + File location = new File(parent, "test-table"); + File checkpoint = new File(parent, "checkpoint"); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + tables.create(SCHEMA, spec, location.toString()); + + MemoryStream inputStream = newMemoryStream(1, spark.sqlContext(), Encoders.INT()); + DataStreamWriter streamWriter = + inputStream + .toDF() + .selectExpr("value AS id", "CAST (value AS STRING) AS data") + .writeStream() + .outputMode("update") + .format("iceberg") + .option("checkpointLocation", checkpoint.toString()) + .option("path", location.toString()); + + try { + StreamingQuery query = streamWriter.start(); + List batch1 = Lists.newArrayList(1, 2); + send(batch1, inputStream); + + Assertions.assertThatThrownBy(query::processAllAvailable) + .isInstanceOf(StreamingQueryException.class) + .hasMessageContaining("does not support Update mode"); + } finally { + for (StreamingQuery query : spark.streams().active()) { + query.stop(); + } + } + } + + private MemoryStream newMemoryStream(int id, SQLContext sqlContext, Encoder encoder) { + return new MemoryStream<>(id, sqlContext, Option.empty(), encoder); + } + + private void send(List records, MemoryStream stream) { + stream.addData(JavaConverters.asScalaBuffer(records)); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreamingRead3.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreamingRead3.java new file mode 100644 index 000000000000..dd456f22371e --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestStructuredStreamingRead3.java @@ -0,0 +1,574 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import static org.apache.iceberg.expressions.Expressions.ref;
+import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeoutException;
+import java.util.stream.IntStream;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.BaseTable;
+import org.apache.iceberg.DataOperations;
+import org.apache.iceberg.DeleteFile;
+import org.apache.iceberg.Files;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.TableMetadata;
+import org.apache.iceberg.TableOperations;
+import org.apache.iceberg.TestHelpers;
+import org.apache.iceberg.data.FileHelpers;
+import org.apache.iceberg.data.GenericRecord;
+import org.apache.iceberg.data.Record;
+import org.apache.iceberg.expressions.Expressions;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.spark.SparkCatalogTestBase;
+import org.apache.iceberg.spark.SparkReadOptions;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.streaming.DataStreamWriter;
+import org.apache.spark.sql.streaming.OutputMode;
+import org.apache.spark.sql.streaming.StreamingQuery;
+import org.assertj.core.api.Assertions;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public final class TestStructuredStreamingRead3 extends SparkCatalogTestBase {
+  public TestStructuredStreamingRead3(
+      String catalogName, String implementation, Map<String, String> config) {
+    super(catalogName, implementation, config);
+  }
+
+  private Table table;
+
+  /**
+   * Test data to be used by multiple writes; each write creates a snapshot and writes a list of
+   * records.
+   */
+  private static final List<List<SimpleRecord>> TEST_DATA_MULTIPLE_SNAPSHOTS =
+      Lists.newArrayList(
+          Lists.newArrayList(
+              new SimpleRecord(1, "one"), new SimpleRecord(2, "two"), new SimpleRecord(3, "three")),
+          Lists.newArrayList(new SimpleRecord(4, "four"), new SimpleRecord(5, "five")),
+          Lists.newArrayList(new SimpleRecord(6, "six"), new SimpleRecord(7, "seven")));
+
+  /**
+   * Test data to be used for multiple write batches; each batch in turn will have multiple
+   * snapshots.
+   */
+  private static final List<List<List<SimpleRecord>>> TEST_DATA_MULTIPLE_WRITES_MULTIPLE_SNAPSHOTS =
+      Lists.newArrayList(
+          Lists.newArrayList(
+              Lists.newArrayList(
+                  new SimpleRecord(1, "one"),
+ new SimpleRecord(2, "two"), + new SimpleRecord(3, "three")), + Lists.newArrayList(new SimpleRecord(4, "four"), new SimpleRecord(5, "five"))), + Lists.newArrayList( + Lists.newArrayList(new SimpleRecord(6, "six"), new SimpleRecord(7, "seven")), + Lists.newArrayList(new SimpleRecord(8, "eight"), new SimpleRecord(9, "nine"))), + Lists.newArrayList( + Lists.newArrayList( + new SimpleRecord(10, "ten"), + new SimpleRecord(11, "eleven"), + new SimpleRecord(12, "twelve")), + Lists.newArrayList( + new SimpleRecord(13, "thirteen"), new SimpleRecord(14, "fourteen")), + Lists.newArrayList( + new SimpleRecord(15, "fifteen"), new SimpleRecord(16, "sixteen")))); + + @Before + public void setupTable() { + sql( + "CREATE TABLE %s " + + "(id INT, data STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(3, id))", + tableName); + this.table = validationCatalog.loadTable(tableIdent); + } + + @After + public void stopStreams() throws TimeoutException { + for (StreamingQuery query : spark.streams().active()) { + query.stop(); + } + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testReadStreamOnIcebergTableWithMultipleSnapshots() throws Exception { + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(expected); + + StreamingQuery query = startStream(); + + List actual = rowsAvailable(query); + Assertions.assertThat(actual).containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + + @Test + public void testReadStreamOnIcebergThenAddData() throws Exception { + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + + StreamingQuery query = startStream(); + + appendDataAsMultipleSnapshots(expected); + + List actual = rowsAvailable(query); + Assertions.assertThat(actual).containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + + @Test + public void testReadingStreamFromTimestamp() throws Exception { + List dataBeforeTimestamp = + Lists.newArrayList( + new SimpleRecord(-2, "minustwo"), + new SimpleRecord(-1, "minusone"), + new SimpleRecord(0, "zero")); + + appendData(dataBeforeTimestamp); + + table.refresh(); + long streamStartTimestamp = table.currentSnapshot().timestampMillis() + 1; + + StreamingQuery query = + startStream(SparkReadOptions.STREAM_FROM_TIMESTAMP, Long.toString(streamStartTimestamp)); + + List empty = rowsAvailable(query); + Assertions.assertThat(empty.isEmpty()).isTrue(); + + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(expected); + + List actual = rowsAvailable(query); + + Assertions.assertThat(actual).containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + + @Test + public void testReadingStreamFromFutureTimetsamp() throws Exception { + long futureTimestamp = System.currentTimeMillis() + 10000; + + StreamingQuery query = + startStream(SparkReadOptions.STREAM_FROM_TIMESTAMP, Long.toString(futureTimestamp)); + + List actual = rowsAvailable(query); + Assertions.assertThat(actual.isEmpty()).isTrue(); + + List data = + Lists.newArrayList( + new SimpleRecord(-2, "minustwo"), + new SimpleRecord(-1, "minusone"), + new SimpleRecord(0, "zero")); + + // Perform several inserts that should not show up because the fromTimestamp has not elapsed + IntStream.range(0, 3) + .forEach( + x -> { + appendData(data); + Assertions.assertThat(rowsAvailable(query).isEmpty()).isTrue(); + }); + + waitUntilAfter(futureTimestamp); + + // Data appended after the timestamp should appear + appendData(data); + actual = rowsAvailable(query); + 
Assertions.assertThat(actual).containsExactlyInAnyOrderElementsOf(data); + } + + @Test + public void testReadingStreamFromTimestampFutureWithExistingSnapshots() throws Exception { + List dataBeforeTimestamp = + Lists.newArrayList( + new SimpleRecord(1, "one"), new SimpleRecord(2, "two"), new SimpleRecord(3, "three")); + appendData(dataBeforeTimestamp); + + long streamStartTimestamp = System.currentTimeMillis() + 2000; + + // Start the stream with a future timestamp after the current snapshot + StreamingQuery query = + startStream(SparkReadOptions.STREAM_FROM_TIMESTAMP, Long.toString(streamStartTimestamp)); + List actual = rowsAvailable(query); + Assert.assertEquals(Collections.emptyList(), actual); + + // Stream should contain data added after the timestamp elapses + waitUntilAfter(streamStartTimestamp); + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(expected); + Assertions.assertThat(rowsAvailable(query)) + .containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + + @Test + public void testReadingStreamFromTimestampOfExistingSnapshot() throws Exception { + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + + // Create an existing snapshot with some data + appendData(expected.get(0)); + table.refresh(); + long firstSnapshotTime = table.currentSnapshot().timestampMillis(); + + // Start stream giving the first Snapshot's time as the start point + StreamingQuery stream = + startStream(SparkReadOptions.STREAM_FROM_TIMESTAMP, Long.toString(firstSnapshotTime)); + + // Append rest of expected data + for (int i = 1; i < expected.size(); i++) { + appendData(expected.get(i)); + } + + List actual = rowsAvailable(stream); + Assertions.assertThat(actual).containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + + @Test + public void testReadingStreamWithExpiredSnapshotFromTimestamp() throws TimeoutException { + List firstSnapshotRecordList = Lists.newArrayList(new SimpleRecord(1, "one")); + + List secondSnapshotRecordList = Lists.newArrayList(new SimpleRecord(2, "two")); + + List thirdSnapshotRecordList = Lists.newArrayList(new SimpleRecord(3, "three")); + + List expectedRecordList = Lists.newArrayList(); + expectedRecordList.addAll(secondSnapshotRecordList); + expectedRecordList.addAll(thirdSnapshotRecordList); + + appendData(firstSnapshotRecordList); + table.refresh(); + long firstSnapshotid = table.currentSnapshot().snapshotId(); + long firstSnapshotCommitTime = table.currentSnapshot().timestampMillis(); + + appendData(secondSnapshotRecordList); + appendData(thirdSnapshotRecordList); + + table.expireSnapshots().expireSnapshotId(firstSnapshotid).commit(); + + StreamingQuery query = + startStream( + SparkReadOptions.STREAM_FROM_TIMESTAMP, String.valueOf(firstSnapshotCommitTime)); + List actual = rowsAvailable(query); + Assertions.assertThat(actual).containsExactlyInAnyOrderElementsOf(expectedRecordList); + } + + @Test + public void testResumingStreamReadFromCheckpoint() throws Exception { + File writerCheckpointFolder = temp.newFolder("writer-checkpoint-folder"); + File writerCheckpoint = new File(writerCheckpointFolder, "writer-checkpoint"); + File output = temp.newFolder(); + + DataStreamWriter querySource = + spark + .readStream() + .format("iceberg") + .load(tableName) + .writeStream() + .option("checkpointLocation", writerCheckpoint.toString()) + .format("parquet") + .queryName("checkpoint_test") + .option("path", output.getPath()); + + StreamingQuery startQuery = querySource.start(); + startQuery.processAllAvailable(); + 
startQuery.stop(); + + List expected = Lists.newArrayList(); + for (List> expectedCheckpoint : + TEST_DATA_MULTIPLE_WRITES_MULTIPLE_SNAPSHOTS) { + // New data was added while the stream was down + appendDataAsMultipleSnapshots(expectedCheckpoint); + expected.addAll(Lists.newArrayList(Iterables.concat(Iterables.concat(expectedCheckpoint)))); + + // Stream starts up again from checkpoint read the newly added data and shut down + StreamingQuery restartedQuery = querySource.start(); + restartedQuery.processAllAvailable(); + restartedQuery.stop(); + + // Read data added by the stream + List actual = + spark.read().load(output.getPath()).as(Encoders.bean(SimpleRecord.class)).collectAsList(); + Assertions.assertThat(actual).containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + } + + @Test + public void testFailReadingCheckpointInvalidSnapshot() throws IOException, TimeoutException { + File writerCheckpointFolder = temp.newFolder("writer-checkpoint-folder"); + File writerCheckpoint = new File(writerCheckpointFolder, "writer-checkpoint"); + File output = temp.newFolder(); + + DataStreamWriter querySource = + spark + .readStream() + .format("iceberg") + .load(tableName) + .writeStream() + .option("checkpointLocation", writerCheckpoint.toString()) + .format("parquet") + .queryName("checkpoint_test") + .option("path", output.getPath()); + + List firstSnapshotRecordList = Lists.newArrayList(new SimpleRecord(1, "one")); + List secondSnapshotRecordList = Lists.newArrayList(new SimpleRecord(2, "two")); + StreamingQuery startQuery = querySource.start(); + + appendData(firstSnapshotRecordList); + table.refresh(); + long firstSnapshotid = table.currentSnapshot().snapshotId(); + startQuery.processAllAvailable(); + startQuery.stop(); + + appendData(secondSnapshotRecordList); + + table.expireSnapshots().expireSnapshotId(firstSnapshotid).commit(); + + StreamingQuery restartedQuery = querySource.start(); + assertThatThrownBy(restartedQuery::processAllAvailable) + .hasCauseInstanceOf(IllegalStateException.class) + .hasMessageContaining( + String.format( + "Cannot load current offset at snapshot %d, the snapshot was expired or removed", + firstSnapshotid)); + } + + @Test + public void testParquetOrcAvroDataInOneTable() throws Exception { + List parquetFileRecords = + Lists.newArrayList( + new SimpleRecord(1, "one"), new SimpleRecord(2, "two"), new SimpleRecord(3, "three")); + + List orcFileRecords = + Lists.newArrayList(new SimpleRecord(4, "four"), new SimpleRecord(5, "five")); + + List avroFileRecords = + Lists.newArrayList(new SimpleRecord(6, "six"), new SimpleRecord(7, "seven")); + + appendData(parquetFileRecords); + appendData(orcFileRecords, "orc"); + appendData(avroFileRecords, "avro"); + + StreamingQuery query = startStream(); + Assertions.assertThat(rowsAvailable(query)) + .containsExactlyInAnyOrderElementsOf( + Iterables.concat(parquetFileRecords, orcFileRecords, avroFileRecords)); + } + + @Test + public void testReadStreamFromEmptyTable() throws Exception { + StreamingQuery stream = startStream(); + List actual = rowsAvailable(stream); + Assert.assertEquals(Collections.emptyList(), actual); + } + + @Test + public void testReadStreamWithSnapshotTypeOverwriteErrorsOut() throws Exception { + // upgrade table to version 2 - to facilitate creation of Snapshot of type OVERWRITE. 
+ TableOperations ops = ((BaseTable) table).operations(); + TableMetadata meta = ops.current(); + ops.commit(meta, meta.upgradeToFormatVersion(2)); + + // fill table with some initial data + List> dataAcrossSnapshots = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(dataAcrossSnapshots); + + Schema deleteRowSchema = table.schema().select("data"); + Record dataDelete = GenericRecord.create(deleteRowSchema); + List dataDeletes = + Lists.newArrayList( + dataDelete.copy("data", "one") // id = 1 + ); + + DeleteFile eqDeletes = + FileHelpers.writeDeleteFile( + table, + Files.localOutput(temp.newFile()), + TestHelpers.Row.of(0), + dataDeletes, + deleteRowSchema); + + table.newRowDelta().addDeletes(eqDeletes).commit(); + + // check pre-condition - that the above Delete file write - actually resulted in snapshot of + // type OVERWRITE + Assert.assertEquals(DataOperations.OVERWRITE, table.currentSnapshot().operation()); + + StreamingQuery query = startStream(); + + AssertHelpers.assertThrowsCause( + "Streaming should fail with IllegalStateException, as the snapshot is not of type APPEND", + IllegalStateException.class, + "Cannot process overwrite snapshot", + () -> query.processAllAvailable()); + } + + @Test + public void testReadStreamWithSnapshotTypeReplaceIgnoresReplace() throws Exception { + // fill table with some data + List> expected = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(expected); + + // this should create a snapshot with type Replace. + table.rewriteManifests().clusterBy(f -> 1).commit(); + + // check pre-condition + Assert.assertEquals(DataOperations.REPLACE, table.currentSnapshot().operation()); + + StreamingQuery query = startStream(); + List actual = rowsAvailable(query); + Assertions.assertThat(actual).containsExactlyInAnyOrderElementsOf(Iterables.concat(expected)); + } + + @Test + public void testReadStreamWithSnapshotTypeDeleteErrorsOut() throws Exception { + table.updateSpec().removeField("id_bucket").addField(ref("id")).commit(); + + // fill table with some data + List> dataAcrossSnapshots = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(dataAcrossSnapshots); + + // this should create a snapshot with type delete. + table.newDelete().deleteFromRowFilter(Expressions.equal("id", 4)).commit(); + + // check pre-condition - that the above delete operation on table resulted in Snapshot of Type + // DELETE. + Assert.assertEquals(DataOperations.DELETE, table.currentSnapshot().operation()); + + StreamingQuery query = startStream(); + + AssertHelpers.assertThrowsCause( + "Streaming should fail with IllegalStateException, as the snapshot is not of type APPEND", + IllegalStateException.class, + "Cannot process delete snapshot", + () -> query.processAllAvailable()); + } + + @Test + public void testReadStreamWithSnapshotTypeDeleteAndSkipDeleteOption() throws Exception { + table.updateSpec().removeField("id_bucket").addField(ref("id")).commit(); + + // fill table with some data + List> dataAcrossSnapshots = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(dataAcrossSnapshots); + + // this should create a snapshot with type delete. + table.newDelete().deleteFromRowFilter(Expressions.equal("id", 4)).commit(); + + // check pre-condition - that the above delete operation on table resulted in Snapshot of Type + // DELETE. 
+ Assert.assertEquals(DataOperations.DELETE, table.currentSnapshot().operation()); + + StreamingQuery query = startStream(SparkReadOptions.STREAMING_SKIP_DELETE_SNAPSHOTS, "true"); + Assertions.assertThat(rowsAvailable(query)) + .containsExactlyInAnyOrderElementsOf(Iterables.concat(dataAcrossSnapshots)); + } + + @Test + public void testReadStreamWithSnapshotTypeDeleteAndSkipOverwriteOption() throws Exception { + table.updateSpec().removeField("id_bucket").addField(ref("id")).commit(); + + // fill table with some data + List> dataAcrossSnapshots = TEST_DATA_MULTIPLE_SNAPSHOTS; + appendDataAsMultipleSnapshots(dataAcrossSnapshots); + + // this should create a snapshot with type overwrite. + table.newOverwrite().overwriteByRowFilter(Expressions.greaterThan("id", 4)).commit(); + + // check pre-condition - that the above delete operation on table resulted in Snapshot of Type + // OVERWRITE. + Assert.assertEquals(DataOperations.OVERWRITE, table.currentSnapshot().operation()); + + StreamingQuery query = startStream(SparkReadOptions.STREAMING_SKIP_OVERWRITE_SNAPSHOTS, "true"); + Assertions.assertThat(rowsAvailable(query)) + .containsExactlyInAnyOrderElementsOf(Iterables.concat(dataAcrossSnapshots)); + } + + /** + * appends each list as a Snapshot on the iceberg table at the given location. accepts a list of + * lists - each list representing data per snapshot. + */ + private void appendDataAsMultipleSnapshots(List> data) { + for (List l : data) { + appendData(l); + } + } + + private void appendData(List data) { + appendData(data, "parquet"); + } + + private void appendData(List data, String format) { + Dataset df = spark.createDataFrame(data, SimpleRecord.class); + df.select("id", "data") + .write() + .format("iceberg") + .option("write-format", format) + .mode("append") + .save(tableName); + } + + private static final String MEMORY_TABLE = "_stream_view_mem"; + + private StreamingQuery startStream(Map options) throws TimeoutException { + return spark + .readStream() + .options(options) + .format("iceberg") + .load(tableName) + .writeStream() + .options(options) + .format("memory") + .queryName(MEMORY_TABLE) + .outputMode(OutputMode.Append()) + .start(); + } + + private StreamingQuery startStream() throws TimeoutException { + return startStream(Collections.emptyMap()); + } + + private StreamingQuery startStream(String key, String value) throws TimeoutException { + return startStream(ImmutableMap.of(key, value)); + } + + private List rowsAvailable(StreamingQuery query) { + query.processAllAvailable(); + return spark + .sql("select * from " + MEMORY_TABLE) + .as(Encoders.bean(SimpleRecord.class)) + .collectAsList(); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestTables.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestTables.java new file mode 100644 index 000000000000..0650cb9738a6 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestTables.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.io.File; +import java.util.Map; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.Files; +import org.apache.iceberg.LocationProviders; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.exceptions.CommitFailedException; +import org.apache.iceberg.exceptions.RuntimeIOException; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.io.LocationProvider; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; + +// TODO: Use the copy of this from core. +class TestTables { + private TestTables() {} + + static TestTable create(File temp, String name, Schema schema, PartitionSpec spec) { + TestTableOperations ops = new TestTableOperations(name); + if (ops.current() != null) { + throw new AlreadyExistsException("Table %s already exists at location: %s", name, temp); + } + ops.commit( + null, TableMetadata.newTableMetadata(schema, spec, temp.toString(), ImmutableMap.of())); + return new TestTable(ops, name); + } + + static TestTable load(String name) { + TestTableOperations ops = new TestTableOperations(name); + if (ops.current() == null) { + return null; + } + return new TestTable(ops, name); + } + + static boolean drop(String name) { + synchronized (METADATA) { + return METADATA.remove(name) != null; + } + } + + static class TestTable extends BaseTable { + private final TestTableOperations ops; + + private TestTable(TestTableOperations ops, String name) { + super(ops, name); + this.ops = ops; + } + + @Override + public TestTableOperations operations() { + return ops; + } + } + + private static final Map METADATA = Maps.newHashMap(); + + static void clearTables() { + synchronized (METADATA) { + METADATA.clear(); + } + } + + static TableMetadata readMetadata(String tableName) { + synchronized (METADATA) { + return METADATA.get(tableName); + } + } + + static void replaceMetadata(String tableName, TableMetadata metadata) { + synchronized (METADATA) { + METADATA.put(tableName, metadata); + } + } + + static class TestTableOperations implements TableOperations { + + private final String tableName; + private TableMetadata current = null; + private long lastSnapshotId = 0; + private int failCommits = 0; + + TestTableOperations(String tableName) { + this.tableName = tableName; + refresh(); + if (current != null) { + for (Snapshot snap : current.snapshots()) { + this.lastSnapshotId = Math.max(lastSnapshotId, snap.snapshotId()); + } + } else { + this.lastSnapshotId = 0; + } + } + + void failCommits(int numFailures) { + this.failCommits = numFailures; + } + + @Override + public TableMetadata current() { + return current; + } + + 
@Override + public TableMetadata refresh() { + synchronized (METADATA) { + this.current = METADATA.get(tableName); + } + return current; + } + + @Override + public void commit(TableMetadata base, TableMetadata metadata) { + if (base != current) { + throw new CommitFailedException("Cannot commit changes based on stale metadata"); + } + synchronized (METADATA) { + refresh(); + if (base == current) { + if (failCommits > 0) { + this.failCommits -= 1; + throw new CommitFailedException("Injected failure"); + } + METADATA.put(tableName, metadata); + this.current = metadata; + } else { + throw new CommitFailedException( + "Commit failed: table was updated at %d", base.lastUpdatedMillis()); + } + } + } + + @Override + public FileIO io() { + return new LocalFileIO(); + } + + @Override + public LocationProvider locationProvider() { + Preconditions.checkNotNull( + current, "Current metadata should not be null when locationProvider is called"); + return LocationProviders.locationsFor(current.location(), current.properties()); + } + + @Override + public String metadataFileLocation(String fileName) { + return new File(new File(current.location(), "metadata"), fileName).getAbsolutePath(); + } + + @Override + public long newSnapshotId() { + long nextSnapshotId = lastSnapshotId + 1; + this.lastSnapshotId = nextSnapshotId; + return nextSnapshotId; + } + } + + static class LocalFileIO implements FileIO { + + @Override + public InputFile newInputFile(String path) { + return Files.localInput(path); + } + + @Override + public OutputFile newOutputFile(String path) { + return Files.localOutput(new File(path)); + } + + @Override + public void deleteFile(String path) { + if (!new File(path).delete()) { + throw new RuntimeIOException("Failed to delete file: " + path); + } + } + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestTimestampWithoutZone.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestTimestampWithoutZone.java new file mode 100644 index 000000000000..053f6dbaea46 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestTimestampWithoutZone.java @@ -0,0 +1,309 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.Files.localOutput; + +import java.io.File; +import java.io.IOException; +import java.time.LocalDateTime; +import java.util.List; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.data.GenericAppenderFactory; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.data.GenericsHelpers; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class TestTimestampWithoutZone extends SparkTestBase { + private static final Configuration CONF = new Configuration(); + private static final HadoopTables TABLES = new HadoopTables(CONF); + + private static final Schema SCHEMA = + new Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "ts", Types.TimestampType.withoutZone()), + Types.NestedField.optional(3, "data", Types.StringType.get())); + + private static SparkSession spark = null; + + @BeforeClass + public static void startSpark() { + TestTimestampWithoutZone.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestTimestampWithoutZone.spark; + TestTimestampWithoutZone.spark = null; + currentSpark.stop(); + } + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private final String format; + private final boolean vectorized; + + @Parameterized.Parameters(name = "format = {0}, vectorized = {1}") + public static Object[][] parameters() { + return new Object[][] { + {"parquet", false}, + {"parquet", true}, + {"avro", false} + }; + } + + public TestTimestampWithoutZone(String format, boolean vectorized) { + this.format = format; + this.vectorized = vectorized; + } + + private File parent = null; + private File unpartitioned = null; + private List records = null; + + @Before + public void writeUnpartitionedTable() throws IOException { + this.parent = temp.newFolder("TestTimestampWithoutZone"); + this.unpartitioned = new File(parent, "unpartitioned"); + File dataFolder = new File(unpartitioned, "data"); + Assert.assertTrue("Mkdir should succeed", dataFolder.mkdirs()); + + Table table = 
TABLES.create(SCHEMA, PartitionSpec.unpartitioned(), unpartitioned.toString()); + Schema tableSchema = table.schema(); // use the table schema because ids are reassigned + + FileFormat fileFormat = FileFormat.fromString(format); + + File testFile = new File(dataFolder, fileFormat.addExtension(UUID.randomUUID().toString())); + + // create records using the table's schema + this.records = testRecords(tableSchema); + + try (FileAppender writer = + new GenericAppenderFactory(tableSchema).newAppender(localOutput(testFile), fileFormat)) { + writer.addAll(records); + } + + DataFile file = + DataFiles.builder(PartitionSpec.unpartitioned()) + .withRecordCount(records.size()) + .withFileSizeInBytes(testFile.length()) + .withPath(testFile.toString()) + .build(); + + table.newAppend().appendFile(file).commit(); + } + + @Test + public void testUnpartitionedTimestampWithoutZone() { + assertEqualsSafe(SCHEMA.asStruct(), records, read(unpartitioned.toString(), vectorized)); + } + + @Test + public void testUnpartitionedTimestampWithoutZoneProjection() { + Schema projection = SCHEMA.select("id", "ts"); + assertEqualsSafe( + projection.asStruct(), + records.stream().map(r -> projectFlat(projection, r)).collect(Collectors.toList()), + read(unpartitioned.toString(), vectorized, "id", "ts")); + } + + @Test + public void testUnpartitionedTimestampWithoutZoneError() { + AssertHelpers.assertThrows( + String.format( + "Read operation performed on a timestamp without timezone field while " + + "'%s' set to false should throw exception", + SparkReadOptions.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE), + IllegalArgumentException.class, + SparkUtil.TIMESTAMP_WITHOUT_TIMEZONE_ERROR, + () -> + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .option(SparkReadOptions.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "false") + .load(unpartitioned.toString()) + .collectAsList()); + } + + @Test + public void testUnpartitionedTimestampWithoutZoneAppend() { + spark + .read() + .format("iceberg") + .option(SparkReadOptions.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "true") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(unpartitioned.toString()) + .write() + .format("iceberg") + .option(SparkWriteOptions.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "true") + .mode(SaveMode.Append) + .save(unpartitioned.toString()); + + assertEqualsSafe( + SCHEMA.asStruct(), + Stream.concat(records.stream(), records.stream()).collect(Collectors.toList()), + read(unpartitioned.toString(), vectorized)); + } + + @Test + public void testUnpartitionedTimestampWithoutZoneWriteError() { + String errorMessage = + String.format( + "Write operation performed on a timestamp without timezone field while " + + "'%s' set to false should throw exception", + SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE); + Runnable writeOperation = + () -> + spark + .read() + .format("iceberg") + .option(SparkReadOptions.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "true") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(unpartitioned.toString()) + .write() + .format("iceberg") + .option(SparkWriteOptions.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "false") + .mode(SaveMode.Append) + .save(unpartitioned.toString()); + + AssertHelpers.assertThrows( + errorMessage, + IllegalArgumentException.class, + SparkUtil.TIMESTAMP_WITHOUT_TIMEZONE_ERROR, + writeOperation); + } + + @Test + public void testUnpartitionedTimestampWithoutZoneSessionProperties() { + withSQLConf( + 
ImmutableMap.of(SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "true"), + () -> { + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .load(unpartitioned.toString()) + .write() + .format("iceberg") + .mode(SaveMode.Append) + .save(unpartitioned.toString()); + + assertEqualsSafe( + SCHEMA.asStruct(), + Stream.concat(records.stream(), records.stream()).collect(Collectors.toList()), + read(unpartitioned.toString(), vectorized)); + }); + } + + private static Record projectFlat(Schema projection, Record record) { + Record result = GenericRecord.create(projection); + List fields = projection.asStruct().fields(); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + result.set(i, record.getField(field.name())); + } + return result; + } + + public static void assertEqualsSafe( + Types.StructType struct, List expected, List actual) { + Assert.assertEquals("Number of results should match expected", expected.size(), actual.size()); + for (int i = 0; i < expected.size(); i += 1) { + GenericsHelpers.assertEqualsSafe(struct, expected.get(i), actual.get(i)); + } + } + + private List testRecords(Schema schema) { + return Lists.newArrayList( + record(schema, 0L, parseToLocal("2017-12-22T09:20:44.294658"), "junction"), + record(schema, 1L, parseToLocal("2017-12-22T07:15:34.582910"), "alligator"), + record(schema, 2L, parseToLocal("2017-12-22T06:02:09.243857"), "forrest"), + record(schema, 3L, parseToLocal("2017-12-22T03:10:11.134509"), "clapping"), + record(schema, 4L, parseToLocal("2017-12-22T00:34:00.184671"), "brush"), + record(schema, 5L, parseToLocal("2017-12-21T22:20:08.935889"), "trap"), + record(schema, 6L, parseToLocal("2017-12-21T21:55:30.589712"), "element"), + record(schema, 7L, parseToLocal("2017-12-21T17:31:14.532797"), "limited"), + record(schema, 8L, parseToLocal("2017-12-21T15:21:51.237521"), "global"), + record(schema, 9L, parseToLocal("2017-12-21T15:02:15.230570"), "goldfish")); + } + + private static List read(String table, boolean vectorized) { + return read(table, vectorized, "*"); + } + + private static List read( + String table, boolean vectorized, String select0, String... selectN) { + Dataset dataset = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VECTORIZATION_ENABLED, String.valueOf(vectorized)) + .option(SparkReadOptions.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "true") + .load(table) + .select(select0, selectN); + return dataset.collectAsList(); + } + + private static LocalDateTime parseToLocal(String timestamp) { + return LocalDateTime.parse(timestamp); + } + + private static Record record(Schema schema, Object... values) { + Record rec = GenericRecord.create(schema); + for (int i = 0; i < values.length; i += 1) { + rec.set(i, values[i]); + } + return rec; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestWriteMetricsConfig.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestWriteMetricsConfig.java new file mode 100644 index 000000000000..9bf00f1b1365 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestWriteMetricsConfig.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.spark.SparkSchemaUtil.convert; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.apache.iceberg.types.Types.NestedField.required; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.ByteBuffers; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestWriteMetricsConfig { + + private static final Configuration CONF = new Configuration(); + private static final Schema SIMPLE_SCHEMA = + new Schema( + optional(1, "id", Types.IntegerType.get()), optional(2, "data", Types.StringType.get())); + private static final Schema COMPLEX_SCHEMA = + new Schema( + required(1, "longCol", Types.IntegerType.get()), + optional(2, "strCol", Types.StringType.get()), + required( + 3, + "record", + Types.StructType.of( + required(4, "id", Types.IntegerType.get()), + required(5, "data", Types.StringType.get())))); + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + private static SparkSession spark = null; + private static JavaSparkContext sc = null; + + @BeforeClass + public static void startSpark() { + TestWriteMetricsConfig.spark = SparkSession.builder().master("local[2]").getOrCreate(); + TestWriteMetricsConfig.sc = JavaSparkContext.fromSparkContext(spark.sparkContext()); + } + + @AfterClass + public static void stopSpark() { + SparkSession currentSpark = TestWriteMetricsConfig.spark; + TestWriteMetricsConfig.spark = null; + TestWriteMetricsConfig.sc = null; + currentSpark.stop(); + } + + @Test + public void testFullMetricsCollectionForParquet() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + 
PartitionSpec spec = PartitionSpec.unpartitioned(); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "full"); + Table table = tables.create(SIMPLE_SCHEMA, spec, properties, tableLocation); + + List expectedRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset df = spark.createDataFrame(expectedRecords, SimpleRecord.class); + df.select("id", "data") + .coalesce(1) + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, "parquet") + .mode(SaveMode.Append) + .save(tableLocation); + + for (FileScanTask task : table.newScan().includeColumnStats().planFiles()) { + DataFile file = task.file(); + Assert.assertEquals(2, file.nullValueCounts().size()); + Assert.assertEquals(2, file.valueCounts().size()); + Assert.assertEquals(2, file.lowerBounds().size()); + Assert.assertEquals(2, file.upperBounds().size()); + } + } + + @Test + public void testCountMetricsCollectionForParquet() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "counts"); + Table table = tables.create(SIMPLE_SCHEMA, spec, properties, tableLocation); + + List expectedRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset df = spark.createDataFrame(expectedRecords, SimpleRecord.class); + df.select("id", "data") + .coalesce(1) + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, "parquet") + .mode(SaveMode.Append) + .save(tableLocation); + + for (FileScanTask task : table.newScan().includeColumnStats().planFiles()) { + DataFile file = task.file(); + Assert.assertEquals(2, file.nullValueCounts().size()); + Assert.assertEquals(2, file.valueCounts().size()); + Assert.assertTrue(file.lowerBounds().isEmpty()); + Assert.assertTrue(file.upperBounds().isEmpty()); + } + } + + @Test + public void testNoMetricsCollectionForParquet() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "none"); + Table table = tables.create(SIMPLE_SCHEMA, spec, properties, tableLocation); + + List expectedRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset df = spark.createDataFrame(expectedRecords, SimpleRecord.class); + df.select("id", "data") + .coalesce(1) + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, "parquet") + .mode(SaveMode.Append) + .save(tableLocation); + + for (FileScanTask task : table.newScan().includeColumnStats().planFiles()) { + DataFile file = task.file(); + Assert.assertTrue(file.nullValueCounts().isEmpty()); + Assert.assertTrue(file.valueCounts().isEmpty()); + Assert.assertTrue(file.lowerBounds().isEmpty()); + Assert.assertTrue(file.upperBounds().isEmpty()); + } + } + + @Test + public void testCustomMetricCollectionForParquet() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map properties = 
Maps.newHashMap(); + properties.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "counts"); + properties.put("write.metadata.metrics.column.id", "full"); + Table table = tables.create(SIMPLE_SCHEMA, spec, properties, tableLocation); + + List expectedRecords = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset df = spark.createDataFrame(expectedRecords, SimpleRecord.class); + df.select("id", "data") + .coalesce(1) + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, "parquet") + .mode(SaveMode.Append) + .save(tableLocation); + + Schema schema = table.schema(); + Types.NestedField id = schema.findField("id"); + for (FileScanTask task : table.newScan().includeColumnStats().planFiles()) { + DataFile file = task.file(); + Assert.assertEquals(2, file.nullValueCounts().size()); + Assert.assertEquals(2, file.valueCounts().size()); + Assert.assertEquals(1, file.lowerBounds().size()); + Assert.assertTrue(file.lowerBounds().containsKey(id.fieldId())); + Assert.assertEquals(1, file.upperBounds().size()); + Assert.assertTrue(file.upperBounds().containsKey(id.fieldId())); + } + } + + @Test + public void testBadCustomMetricCollectionForParquet() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.unpartitioned(); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "counts"); + properties.put("write.metadata.metrics.column.ids", "full"); + + AssertHelpers.assertThrows( + "Creating a table with invalid metrics should fail", + ValidationException.class, + null, + () -> tables.create(SIMPLE_SCHEMA, spec, properties, tableLocation)); + } + + @Test + public void testCustomMetricCollectionForNestedParquet() throws IOException { + String tableLocation = temp.newFolder("iceberg-table").toString(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(COMPLEX_SCHEMA).identity("strCol").build(); + Map properties = Maps.newHashMap(); + properties.put(TableProperties.DEFAULT_WRITE_METRICS_MODE, "none"); + properties.put("write.metadata.metrics.column.longCol", "counts"); + properties.put("write.metadata.metrics.column.record.id", "full"); + properties.put("write.metadata.metrics.column.record.data", "truncate(2)"); + Table table = tables.create(COMPLEX_SCHEMA, spec, properties, tableLocation); + + Iterable rows = RandomData.generateSpark(COMPLEX_SCHEMA, 10, 0); + JavaRDD rdd = sc.parallelize(Lists.newArrayList(rows)); + Dataset df = + spark.internalCreateDataFrame(JavaRDD.toRDD(rdd), convert(COMPLEX_SCHEMA), false); + + df.coalesce(1) + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, "parquet") + .mode(SaveMode.Append) + .save(tableLocation); + + Schema schema = table.schema(); + Types.NestedField longCol = schema.findField("longCol"); + Types.NestedField recordId = schema.findField("record.id"); + Types.NestedField recordData = schema.findField("record.data"); + for (FileScanTask task : table.newScan().includeColumnStats().planFiles()) { + DataFile file = task.file(); + + Map nullValueCounts = file.nullValueCounts(); + Assert.assertEquals(3, nullValueCounts.size()); + Assert.assertTrue(nullValueCounts.containsKey(longCol.fieldId())); + Assert.assertTrue(nullValueCounts.containsKey(recordId.fieldId())); + Assert.assertTrue(nullValueCounts.containsKey(recordData.fieldId())); + + Map valueCounts = 
file.valueCounts();
+      Assert.assertEquals(3, valueCounts.size());
+      Assert.assertTrue(valueCounts.containsKey(longCol.fieldId()));
+      Assert.assertTrue(valueCounts.containsKey(recordId.fieldId()));
+      Assert.assertTrue(valueCounts.containsKey(recordData.fieldId()));
+
+      Map<Integer, ByteBuffer> lowerBounds = file.lowerBounds();
+      Assert.assertEquals(2, lowerBounds.size());
+      Assert.assertTrue(lowerBounds.containsKey(recordId.fieldId()));
+      ByteBuffer recordDataLowerBound = lowerBounds.get(recordData.fieldId());
+      Assert.assertEquals(2, ByteBuffers.toByteArray(recordDataLowerBound).length);
+
+      Map<Integer, ByteBuffer> upperBounds = file.upperBounds();
+      Assert.assertEquals(2, upperBounds.size());
+      Assert.assertTrue(upperBounds.containsKey(recordId.fieldId()));
+      ByteBuffer recordDataUpperBound = upperBounds.get(recordData.fieldId());
+      Assert.assertEquals(2, ByteBuffers.toByteArray(recordDataUpperBound).length);
+    }
+  }
+}
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/ThreeColumnRecord.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/ThreeColumnRecord.java
new file mode 100644
index 000000000000..554557df416c
--- /dev/null
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/ThreeColumnRecord.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ +package org.apache.iceberg.spark.source; + +import java.util.Objects; + +public class ThreeColumnRecord { + private Integer c1; + private String c2; + private String c3; + + public ThreeColumnRecord() {} + + public ThreeColumnRecord(Integer c1, String c2, String c3) { + this.c1 = c1; + this.c2 = c2; + this.c3 = c3; + } + + public Integer getC1() { + return c1; + } + + public void setC1(Integer c1) { + this.c1 = c1; + } + + public String getC2() { + return c2; + } + + public void setC2(String c2) { + this.c2 = c2; + } + + public String getC3() { + return c3; + } + + public void setC3(String c3) { + this.c3 = c3; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ThreeColumnRecord that = (ThreeColumnRecord) o; + return Objects.equals(c1, that.c1) + && Objects.equals(c2, that.c2) + && Objects.equals(c3, that.c3); + } + + @Override + public int hashCode() { + return Objects.hash(c1, c2, c3); + } + + @Override + public String toString() { + return "ThreeColumnRecord{" + "c1=" + c1 + ", c2='" + c2 + '\'' + ", c3='" + c3 + '\'' + '}'; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/PartitionedWritesTestBase.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/PartitionedWritesTestBase.java new file mode 100644 index 000000000000..77dccbf1e064 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/PartitionedWritesTestBase.java @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.functions; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public abstract class PartitionedWritesTestBase extends SparkCatalogTestBase { + public PartitionedWritesTestBase( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Before + public void createTables() { + sql( + "CREATE TABLE %s (id bigint, data string) USING iceberg PARTITIONED BY (truncate(id, 3))", + tableName); + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testInsertAppend() { + Assert.assertEquals( + "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget())); + + sql("INSERT INTO %s VALUES (4, 'd'), (5, 'e')", commitTarget()); + + Assert.assertEquals( + "Should have 5 rows after insert", + 5L, + scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List expected = + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c"), row(4L, "d"), row(5L, "e")); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testInsertOverwrite() { + Assert.assertEquals( + "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget())); + + // 4 and 5 replace 3 in the partition (id - (id % 3)) = 3 + sql("INSERT OVERWRITE %s VALUES (4, 'd'), (5, 'e')", commitTarget()); + + Assert.assertEquals( + "Should have 4 rows after overwrite", + 4L, + scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List expected = + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(4L, "d"), row(5L, "e")); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testDataFrameV2Append() throws NoSuchTableException { + Assert.assertEquals( + "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e")); + Dataset ds = spark.createDataFrame(data, SimpleRecord.class); + + ds.writeTo(commitTarget()).append(); + + Assert.assertEquals( + "Should have 5 rows after insert", + 5L, + scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List expected = + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c"), row(4L, "d"), row(5L, "e")); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testDataFrameV2DynamicOverwrite() throws NoSuchTableException { + Assert.assertEquals( + "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e")); + Dataset ds = spark.createDataFrame(data, SimpleRecord.class); + + 
ds.writeTo(commitTarget()).overwritePartitions(); + + Assert.assertEquals( + "Should have 4 rows after overwrite", + 4L, + scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List expected = + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(4L, "d"), row(5L, "e")); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testDataFrameV2Overwrite() throws NoSuchTableException { + Assert.assertEquals( + "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e")); + Dataset ds = spark.createDataFrame(data, SimpleRecord.class); + + ds.writeTo(commitTarget()).overwrite(functions.col("id").$less(3)); + + Assert.assertEquals( + "Should have 3 rows after overwrite", + 3L, + scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List expected = ImmutableList.of(row(3L, "c"), row(4L, "d"), row(5L, "e")); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testViewsReturnRecentResults() { + Assert.assertEquals( + "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget())); + + Dataset query = spark.sql("SELECT * FROM " + commitTarget() + " WHERE id = 1"); + query.createOrReplaceTempView("tmp"); + + assertEquals( + "View should have expected rows", ImmutableList.of(row(1L, "a")), sql("SELECT * FROM tmp")); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", commitTarget()); + + assertEquals( + "View should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM tmp")); + } + + // Asserts whether the given table .partitions table has the expected rows. Note that the output + // row should have spec_id and it is sorted by spec_id and selectPartitionColumns. + protected void assertPartitionMetadata( + String tableName, List expected, String... selectPartitionColumns) { + String[] fullyQualifiedCols = + Arrays.stream(selectPartitionColumns).map(s -> "partition." + s).toArray(String[]::new); + Dataset actualPartitionRows = + spark + .read() + .format("iceberg") + .load(tableName + ".partitions") + .select("spec_id", fullyQualifiedCols) + .orderBy("spec_id", fullyQualifiedCols); + + assertEquals( + "There are 3 partitions, one with the original spec ID and two with the new one", + expected, + rowsToJava(actualPartitionRows.collectAsList())); + } + + @Test + public void testWriteWithOutputSpec() throws NoSuchTableException { + Table table = validationCatalog.loadTable(tableIdent); + + // Drop all records in table to have a fresh start. + table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit(); + + final int originalSpecId = table.spec().specId(); + table.updateSpec().addField("data").commit(); + + // Refresh this when using SparkCatalog since otherwise the new spec would not be caught. + sql("REFRESH TABLE %s", tableName); + + // By default, we write to the current spec. + List data = ImmutableList.of(new SimpleRecord(10, "a")); + spark.createDataFrame(data, SimpleRecord.class).toDF().writeTo(tableName).append(); + + List expected = ImmutableList.of(row(10L, "a", table.spec().specId())); + assertEquals( + "Rows must match", + expected, + sql("SELECT id, data, _spec_id FROM %s WHERE id >= 10 ORDER BY id", tableName)); + + // Output spec ID should be respected when present. 
+ data = ImmutableList.of(new SimpleRecord(11, "b"), new SimpleRecord(12, "c")); + spark + .createDataFrame(data, SimpleRecord.class) + .toDF() + .writeTo(tableName) + .option("output-spec-id", Integer.toString(originalSpecId)) + .append(); + + expected = + ImmutableList.of( + row(10L, "a", table.spec().specId()), + row(11L, "b", originalSpecId), + row(12L, "c", originalSpecId)); + assertEquals( + "Rows must match", + expected, + sql("SELECT id, data, _spec_id FROM %s WHERE id >= 10 ORDER BY id", tableName)); + + // Verify that the actual partitions are written with the correct spec ID. + // Two of the partitions should have the original spec ID and one should have the new one. + // TODO: WAP branch does not support reading partitions table, skip this check for now. + expected = + ImmutableList.of( + row(originalSpecId, 9L, null), + row(originalSpecId, 12L, null), + row(table.spec().specId(), 9L, "a")); + assertPartitionMetadata(tableName, expected, "id_trunc", "data"); + + // Even the default spec ID should be followed when present. + data = ImmutableList.of(new SimpleRecord(13, "d")); + spark + .createDataFrame(data, SimpleRecord.class) + .toDF() + .writeTo(tableName) + .option("output-spec-id", Integer.toString(table.spec().specId())) + .append(); + + expected = + ImmutableList.of( + row(10L, "a", table.spec().specId()), + row(11L, "b", originalSpecId), + row(12L, "c", originalSpecId), + row(13L, "d", table.spec().specId())); + assertEquals( + "Rows must match", + expected, + sql("SELECT id, data, _spec_id FROM %s WHERE id >= 10 ORDER BY id", tableName)); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java new file mode 100644 index 000000000000..37ae96a248ef --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java @@ -0,0 +1,682 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import java.math.BigDecimal; +import java.sql.Date; +import java.sql.Timestamp; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.hive.HiveCatalog; +import org.apache.iceberg.hive.TestHiveMetastore; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.spark.sql.SparkSession; +import org.junit.After; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestAggregatePushDown extends SparkCatalogTestBase { + + public TestAggregatePushDown( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @BeforeClass + public static void startMetastoreAndSpark() { + SparkTestBase.metastore = new TestHiveMetastore(); + metastore.start(); + SparkTestBase.hiveConf = metastore.hiveConf(); + + SparkTestBase.spark = + SparkSession.builder() + .master("local[2]") + .config("spark.sql.iceberg.aggregate_pushdown", "true") + .enableHiveSupport() + .getOrCreate(); + + SparkTestBase.catalog = + (HiveCatalog) + CatalogUtil.loadCatalog( + HiveCatalog.class.getName(), "hive", ImmutableMap.of(), hiveConf); + + try { + catalog.createNamespace(Namespace.of("default")); + } catch (AlreadyExistsException ignored) { + // the default namespace already exists. ignore the create error + } + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testDifferentDataTypesAggregatePushDownInPartitionedTable() { + testDifferentDataTypesAggregatePushDown(true); + } + + @Test + public void testDifferentDataTypesAggregatePushDownInNonPartitionedTable() { + testDifferentDataTypesAggregatePushDown(false); + } + + @SuppressWarnings("checkstyle:CyclomaticComplexity") + private void testDifferentDataTypesAggregatePushDown(boolean hasPartitionCol) { + String createTable; + if (hasPartitionCol) { + createTable = + "CREATE TABLE %s (id LONG, int_data INT, boolean_data BOOLEAN, float_data FLOAT, double_data DOUBLE, " + + "decimal_data DECIMAL(14, 2), binary_data binary) USING iceberg PARTITIONED BY (id)"; + } else { + createTable = + "CREATE TABLE %s (id LONG, int_data INT, boolean_data BOOLEAN, float_data FLOAT, double_data DOUBLE, " + + "decimal_data DECIMAL(14, 2), binary_data binary) USING iceberg"; + } + + sql(createTable, tableName); + sql( + "INSERT INTO TABLE %s VALUES " + + "(1, null, false, null, null, 11.11, X'1111')," + + " (1, null, true, 2.222, 2.222222, 22.22, X'2222')," + + " (2, 33, false, 3.333, 3.333333, 33.33, X'3333')," + + " (2, 44, true, null, 4.444444, 44.44, X'4444')," + + " (3, 55, false, 5.555, 5.555555, 55.55, X'5555')," + + " (3, null, true, null, 6.666666, 66.66, null) ", + tableName); + + String select = + "SELECT count(*), max(id), min(id), count(id), " + + "max(int_data), min(int_data), count(int_data), " + + "max(boolean_data), min(boolean_data), count(boolean_data), " + + "max(float_data), min(float_data), count(float_data), " + + "max(double_data), min(double_data), count(double_data), " + + "max(decimal_data), min(decimal_data), count(decimal_data), " + 
+ "max(binary_data), min(binary_data), count(binary_data) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("count(*)") + && explainString.contains("max(id)") + && explainString.contains("min(id)") + && explainString.contains("count(id)") + && explainString.contains("max(int_data)") + && explainString.contains("min(int_data)") + && explainString.contains("count(int_data)") + && explainString.contains("max(boolean_data)") + && explainString.contains("min(boolean_data)") + && explainString.contains("count(boolean_data)") + && explainString.contains("max(float_data)") + && explainString.contains("min(float_data)") + && explainString.contains("count(float_data)") + && explainString.contains("max(double_data)") + && explainString.contains("min(double_data)") + && explainString.contains("count(double_data)") + && explainString.contains("max(decimal_data)") + && explainString.contains("min(decimal_data)") + && explainString.contains("count(decimal_data)") + && explainString.contains("max(binary_data)") + && explainString.contains("min(binary_data)") + && explainString.contains("count(binary_data)")) { + explainContainsPushDownAggregates = true; + } + + Assert.assertTrue( + "explain should contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add( + new Object[] { + 6L, + 3L, + 1L, + 6L, + 55, + 33, + 3L, + true, + false, + 6L, + 5.555f, + 2.222f, + 3L, + 6.666666, + 2.222222, + 5L, + new BigDecimal("66.66"), + new BigDecimal("11.11"), + 6L, + new byte[] {85, 85}, + new byte[] {17, 17}, + 5L + }); + assertEquals("min/max/count push down", expected, actual); + } + + @Test + public void testDateAndTimestampWithPartition() { + sql( + "CREATE TABLE %s (id bigint, data string, d date, ts timestamp) USING iceberg PARTITIONED BY (id)", + tableName); + sql( + "INSERT INTO %s VALUES (1, '1', date('2021-11-10'), null)," + + "(1, '2', date('2021-11-11'), timestamp('2021-11-11 22:22:22')), " + + "(2, '3', date('2021-11-12'), timestamp('2021-11-12 22:22:22')), " + + "(2, '4', date('2021-11-13'), timestamp('2021-11-13 22:22:22')), " + + "(3, '5', null, timestamp('2021-11-14 22:22:22')), " + + "(3, '6', date('2021-11-14'), null)", + tableName); + String select = "SELECT max(d), min(d), count(d), max(ts), min(ts), count(ts) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("max(d)") + && explainString.contains("min(d)") + && explainString.contains("count(d)") + && explainString.contains("max(ts)") + && explainString.contains("min(ts)") + && explainString.contains("count(ts)")) { + explainContainsPushDownAggregates = true; + } + + Assert.assertTrue( + "explain should contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add( + new Object[] { + Date.valueOf("2021-11-14"), + Date.valueOf("2021-11-10"), + 5L, + Timestamp.valueOf("2021-11-14 22:22:22.0"), + Timestamp.valueOf("2021-11-11 22:22:22.0"), + 4L + }); + assertEquals("min/max/count push down", expected, actual); + } + + @Test + public void testAggregateNotPushDownIfOneCantPushDown() { + 
sql("CREATE TABLE %s (id LONG, data DOUBLE) USING iceberg", tableName); + sql( + "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ", + tableName); + String select = "SELECT COUNT(data), SUM(data) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("count(data)")) { + explainContainsPushDownAggregates = true; + } + + Assert.assertFalse( + "explain should not contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add(new Object[] {6L, 23331.0}); + assertEquals("expected and actual should equal", expected, actual); + } + + @Test + public void testAggregatePushDownWithMetricsMode() { + sql("CREATE TABLE %s (id LONG, data DOUBLE) USING iceberg", tableName); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, TableProperties.DEFAULT_WRITE_METRICS_MODE, "none"); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX + "id", "counts"); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX + "data", "none"); + sql( + "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666)", + tableName); + + String select1 = "SELECT COUNT(data) FROM %s"; + + List explain1 = sql("EXPLAIN " + select1, tableName); + String explainString1 = explain1.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString1.contains("count(data)")) { + explainContainsPushDownAggregates = true; + } + + // count(data) is not pushed down because the metrics mode is `none` + Assert.assertFalse( + "explain should not contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual1 = sql(select1, tableName); + List expected1 = Lists.newArrayList(); + expected1.add(new Object[] {6L}); + assertEquals("expected and actual should equal", expected1, actual1); + + String select2 = "SELECT COUNT(id) FROM %s"; + List explain2 = sql("EXPLAIN " + select2, tableName); + String explainString2 = explain2.get(0)[0].toString().toLowerCase(Locale.ROOT); + if (explainString2.contains("count(id)")) { + explainContainsPushDownAggregates = true; + } + + // count(id) is pushed down because the metrics mode is `counts` + Assert.assertTrue( + "explain should contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual2 = sql(select2, tableName); + List expected2 = Lists.newArrayList(); + expected2.add(new Object[] {6L}); + assertEquals("expected and actual should equal", expected2, actual2); + + String select3 = "SELECT COUNT(id), MAX(id) FROM %s"; + explainContainsPushDownAggregates = false; + List explain3 = sql("EXPLAIN " + select3, tableName); + String explainString3 = explain3.get(0)[0].toString().toLowerCase(Locale.ROOT); + if (explainString3.contains("count(id)")) { + explainContainsPushDownAggregates = true; + } + + // COUNT(id), MAX(id) are not pushed down because MAX(id) is not pushed down (metrics mode is + // `counts`) + Assert.assertFalse( + "explain should not contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual3 = sql(select3, tableName); + List expected3 = Lists.newArrayList(); + expected3.add(new 
Object[] {6L, 3L}); + assertEquals("expected and actual should equal", expected3, actual3); + } + + @Test + public void testAggregateNotPushDownForStringType() { + sql("CREATE TABLE %s (id LONG, data STRING) USING iceberg", tableName); + sql( + "INSERT INTO TABLE %s VALUES (1, '1111'), (1, '2222'), (2, '3333'), (2, '4444'), (3, '5555'), (3, '6666') ", + tableName); + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, TableProperties.DEFAULT_WRITE_METRICS_MODE, "truncate(16)"); + + String select1 = "SELECT MAX(id), MAX(data) FROM %s"; + + List explain1 = sql("EXPLAIN " + select1, tableName); + String explainString1 = explain1.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString1.contains("max(id)")) { + explainContainsPushDownAggregates = true; + } + + Assert.assertFalse( + "explain should not contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual1 = sql(select1, tableName); + List expected1 = Lists.newArrayList(); + expected1.add(new Object[] {3L, "6666"}); + assertEquals("expected and actual should equal", expected1, actual1); + + String select2 = "SELECT COUNT(data) FROM %s"; + List explain2 = sql("EXPLAIN " + select2, tableName); + String explainString2 = explain2.get(0)[0].toString().toLowerCase(Locale.ROOT); + if (explainString2.contains("count(data)")) { + explainContainsPushDownAggregates = true; + } + + Assert.assertTrue( + "explain should contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual2 = sql(select2, tableName); + List expected2 = Lists.newArrayList(); + expected2.add(new Object[] {6L}); + assertEquals("expected and actual should equal", expected2, actual2); + + explainContainsPushDownAggregates = false; + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", + tableName, TableProperties.DEFAULT_WRITE_METRICS_MODE, "full"); + String select3 = "SELECT count(data), max(data) FROM %s"; + List explain3 = sql("EXPLAIN " + select3, tableName); + String explainString3 = explain3.get(0)[0].toString().toLowerCase(Locale.ROOT); + if (explainString3.contains("count(data)") && explainString3.contains("max(data)")) { + explainContainsPushDownAggregates = true; + } + + Assert.assertTrue( + "explain should contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual3 = sql(select3, tableName); + List expected3 = Lists.newArrayList(); + expected3.add(new Object[] {6L, "6666"}); + assertEquals("expected and actual should equal", expected3, actual3); + } + + @Test + public void testAggregatePushDownWithDataFilter() { + testAggregatePushDownWithFilter(false); + } + + @Test + public void testAggregatePushDownWithPartitionFilter() { + testAggregatePushDownWithFilter(true); + } + + private void testAggregatePushDownWithFilter(boolean partitionFilerOnly) { + String createTable; + if (!partitionFilerOnly) { + createTable = "CREATE TABLE %s (id LONG, data INT) USING iceberg"; + } else { + createTable = "CREATE TABLE %s (id LONG, data INT) USING iceberg PARTITIONED BY (id)"; + } + + sql(createTable, tableName); + + sql( + "INSERT INTO TABLE %s VALUES" + + " (1, 11)," + + " (1, 22)," + + " (2, 33)," + + " (2, 44)," + + " (3, 55)," + + " (3, 66) ", + tableName); + + String select = "SELECT MIN(data) FROM %s WHERE id > 1"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if 
(explainString.contains("min(data)")) {
+      explainContainsPushDownAggregates = true;
+    }
+
+    if (!partitionFilerOnly) {
+      // Filters are not completely pushed down, so we can't push down aggregates
+      Assert.assertFalse(
+          "explain should not contain the pushed down aggregates",
+          explainContainsPushDownAggregates);
+    } else {
+      // Partition filters are completely pushed down, so we can push down aggregates
+      Assert.assertTrue(
+          "explain should contain the pushed down aggregates", explainContainsPushDownAggregates);
+    }
+
+    List<Object[]> actual = sql(select, tableName);
+    List<Object[]> expected = Lists.newArrayList();
+    expected.add(new Object[] {33});
+    assertEquals("expected and actual should equal", expected, actual);
+  }
+
+  @Test
+  public void testAggregateWithComplexType() {
+    sql("CREATE TABLE %s (id INT, complex STRUCT<c1:INT,c2:STRING>) USING iceberg", tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", 3, \"c2\", \"v1\")),"
+            + "(2, named_struct(\"c1\", 2, \"c2\", \"v2\")), (3, null)",
+        tableName);
+    String select1 = "SELECT count(complex), count(id) FROM %s";
+    List<Object[]> explain = sql("EXPLAIN " + select1, tableName);
+    String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT);
+    boolean explainContainsPushDownAggregates = false;
+    if (explainString.contains("count(complex)")) {
+      explainContainsPushDownAggregates = true;
+    }
+
+    Assert.assertFalse(
+        "count not pushed down for complex types", explainContainsPushDownAggregates);
+
+    List<Object[]> actual = sql(select1, tableName);
+    List<Object[]> expected = Lists.newArrayList();
+    expected.add(new Object[] {2L, 3L});
+    assertEquals("count not push down", actual, expected);
+
+    String select2 = "SELECT max(complex) FROM %s";
+    explain = sql("EXPLAIN " + select2, tableName);
+    explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT);
+    explainContainsPushDownAggregates = false;
+    if (explainString.contains("max(complex)")) {
+      explainContainsPushDownAggregates = true;
+    }
+
+    Assert.assertFalse("max not pushed down for complex types", explainContainsPushDownAggregates);
+  }
+
+  @Test
+  public void testAggregatePushDownInDeleteCopyOnWrite() {
+    sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ",
+        tableName);
+    sql("DELETE FROM %s WHERE data = 1111", tableName);
+    String select = "SELECT max(data), min(data), count(data) FROM %s";
+
+    List<Object[]> explain = sql("EXPLAIN " + select, tableName);
+    String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT);
+    boolean explainContainsPushDownAggregates = false;
+    if (explainString.contains("max(data)")
+        && explainString.contains("min(data)")
+        && explainString.contains("count(data)")) {
+      explainContainsPushDownAggregates = true;
+    }
+
+    Assert.assertTrue("min/max/count pushed down for deleted", explainContainsPushDownAggregates);
+
+    List<Object[]> actual = sql(select, tableName);
+    List<Object[]> expected = Lists.newArrayList();
+    expected.add(new Object[] {6666, 2222, 5L});
+    assertEquals("min/max/count push down", expected, actual);
+  }
+
+  @Test
+  public void testAggregatePushDownForTimeTravel() {
+    sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ",
+        tableName);
+
+    long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId();
+    List<Object[]> expected1 = sql("SELECT count(id) FROM %s", tableName);
+
+    sql("INSERT INTO %s VALUES (4, 7777), (5,
8888)", tableName); + List expected2 = sql("SELECT count(id) FROM %s", tableName); + + List explain1 = + sql("EXPLAIN SELECT count(id) FROM %s VERSION AS OF %s", tableName, snapshotId); + String explainString1 = explain1.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates1 = false; + if (explainString1.contains("count(id)")) { + explainContainsPushDownAggregates1 = true; + } + Assert.assertTrue("count pushed down", explainContainsPushDownAggregates1); + + List actual1 = + sql("SELECT count(id) FROM %s VERSION AS OF %s", tableName, snapshotId); + assertEquals("count push down", expected1, actual1); + + List explain2 = sql("EXPLAIN SELECT count(id) FROM %s", tableName); + String explainString2 = explain2.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates2 = false; + if (explainString2.contains("count(id)")) { + explainContainsPushDownAggregates2 = true; + } + + Assert.assertTrue("count pushed down", explainContainsPushDownAggregates2); + + List actual2 = sql("SELECT count(id) FROM %s", tableName); + assertEquals("count push down", expected2, actual2); + } + + @Test + public void testAllNull() { + sql("CREATE TABLE %s (id int, data int) USING iceberg PARTITIONED BY (id)", tableName); + sql( + "INSERT INTO %s VALUES (1, null)," + + "(1, null), " + + "(2, null), " + + "(2, null), " + + "(3, null), " + + "(3, null)", + tableName); + String select = "SELECT count(*), max(data), min(data), count(data) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("max(data)") + && explainString.contains("min(data)") + && explainString.contains("count(data)")) { + explainContainsPushDownAggregates = true; + } + + Assert.assertTrue( + "explain should contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add(new Object[] {6L, null, null, 0L}); + assertEquals("min/max/count push down", expected, actual); + } + + @Test + public void testAllNaN() { + sql("CREATE TABLE %s (id int, data float) USING iceberg PARTITIONED BY (id)", tableName); + sql( + "INSERT INTO %s VALUES (1, float('nan'))," + + "(1, float('nan')), " + + "(2, float('nan')), " + + "(2, float('nan')), " + + "(3, float('nan')), " + + "(3, float('nan'))", + tableName); + String select = "SELECT count(*), max(data), min(data), count(data) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("max(data)") + || explainString.contains("min(data)") + || explainString.contains("count(data)")) { + explainContainsPushDownAggregates = true; + } + + Assert.assertFalse( + "explain should not contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add(new Object[] {6L, Float.NaN, Float.NaN, 6L}); + assertEquals("expected and actual should equal", expected, actual); + } + + @Test + public void testNaN() { + sql("CREATE TABLE %s (id int, data float) USING iceberg PARTITIONED BY (id)", tableName); + sql( + "INSERT INTO %s VALUES (1, float('nan'))," + + "(1, float('nan')), " + + "(2, 2), " + + "(2, float('nan')), " + + "(3, 
float('nan')), " + + "(3, 1)", + tableName); + String select = "SELECT count(*), max(data), min(data), count(data) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("max(data)") + || explainString.contains("min(data)") + || explainString.contains("count(data)")) { + explainContainsPushDownAggregates = true; + } + + Assert.assertFalse( + "explain should not contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add(new Object[] {6L, Float.NaN, 1.0F, 6L}); + assertEquals("expected and actual should equal", expected, actual); + } + + @Test + public void testInfinity() { + sql( + "CREATE TABLE %s (id int, data1 float, data2 double, data3 double) USING iceberg PARTITIONED BY (id)", + tableName); + sql( + "INSERT INTO %s VALUES (1, float('-infinity'), double('infinity'), 1.23), " + + "(1, float('-infinity'), double('infinity'), -1.23), " + + "(1, float('-infinity'), double('infinity'), double('infinity')), " + + "(1, float('-infinity'), double('infinity'), 2.23), " + + "(1, float('-infinity'), double('infinity'), double('-infinity')), " + + "(1, float('-infinity'), double('infinity'), -2.23)", + tableName); + String select = + "SELECT count(*), max(data1), min(data1), count(data1), max(data2), min(data2), count(data2), max(data3), min(data3), count(data3) FROM %s"; + + List explain = sql("EXPLAIN " + select, tableName); + String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); + boolean explainContainsPushDownAggregates = false; + if (explainString.contains("max(data1)") + && explainString.contains("min(data1)") + && explainString.contains("count(data1)") + && explainString.contains("max(data2)") + && explainString.contains("min(data2)") + && explainString.contains("count(data2)") + && explainString.contains("max(data3)") + && explainString.contains("min(data3)") + && explainString.contains("count(data3)")) { + explainContainsPushDownAggregates = true; + } + + Assert.assertTrue( + "explain should contain the pushed down aggregates", explainContainsPushDownAggregates); + + List actual = sql(select, tableName); + List expected = Lists.newArrayList(); + expected.add( + new Object[] { + 6L, + Float.NEGATIVE_INFINITY, + Float.NEGATIVE_INFINITY, + 6L, + Double.POSITIVE_INFINITY, + Double.POSITIVE_INFINITY, + 6L, + Double.POSITIVE_INFINITY, + Double.NEGATIVE_INFINITY, + 6L + }); + assertEquals("min/max/count push down", expected, actual); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestAlterTable.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestAlterTable.java new file mode 100644 index 000000000000..e347cde7ba32 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestAlterTable.java @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.sql;
+
+import java.util.Map;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.catalog.Namespace;
+import org.apache.iceberg.catalog.TableIdentifier;
+import org.apache.iceberg.hadoop.HadoopCatalog;
+import org.apache.iceberg.spark.SparkCatalogTestBase;
+import org.apache.iceberg.types.Types;
+import org.apache.iceberg.types.Types.NestedField;
+import org.apache.spark.SparkException;
+import org.apache.spark.sql.AnalysisException;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Assume;
+import org.junit.Before;
+import org.junit.Test;
+
+public class TestAlterTable extends SparkCatalogTestBase {
+  private final TableIdentifier renamedIdent =
+      TableIdentifier.of(Namespace.of("default"), "table2");
+
+  public TestAlterTable(String catalogName, String implementation, Map<String, String> config) {
+    super(catalogName, implementation, config);
+  }
+
+  @Before
+  public void createTable() {
+    sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+  }
+
+  @After
+  public void removeTable() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+    sql("DROP TABLE IF EXISTS %s2", tableName);
+  }
+
+  @Test
+  public void testAddColumnNotNull() {
+    AssertHelpers.assertThrows(
+        "Should reject adding NOT NULL column",
+        SparkException.class,
+        "Incompatible change: cannot add required column",
+        () -> sql("ALTER TABLE %s ADD COLUMN c3 INT NOT NULL", tableName));
+  }
+
+  @Test
+  public void testAddColumn() {
+    sql(
+        "ALTER TABLE %s ADD COLUMN point struct<x: double NOT NULL, y: double NOT NULL> AFTER id",
+        tableName);
+
+    Types.StructType expectedSchema =
+        Types.StructType.of(
+            NestedField.required(1, "id", Types.LongType.get()),
+            NestedField.optional(
+                3,
+                "point",
+                Types.StructType.of(
+                    NestedField.required(4, "x", Types.DoubleType.get()),
+                    NestedField.required(5, "y", Types.DoubleType.get()))),
+            NestedField.optional(2, "data", Types.StringType.get()));
+
+    Assert.assertEquals(
+        "Schema should match expected",
+        expectedSchema,
+        validationCatalog.loadTable(tableIdent).schema().asStruct());
+
+    sql("ALTER TABLE %s ADD COLUMN point.z double COMMENT 'May be null' FIRST", tableName);
+
+    Types.StructType expectedSchema2 =
+        Types.StructType.of(
+            NestedField.required(1, "id", Types.LongType.get()),
+            NestedField.optional(
+                3,
+                "point",
+                Types.StructType.of(
+                    NestedField.optional(6, "z", Types.DoubleType.get(), "May be null"),
+                    NestedField.required(4, "x", Types.DoubleType.get()),
+                    NestedField.required(5, "y", Types.DoubleType.get()))),
+            NestedField.optional(2, "data", Types.StringType.get()));
+
+    Assert.assertEquals(
+        "Schema should match expected",
+        expectedSchema2,
+        validationCatalog.loadTable(tableIdent).schema().asStruct());
+  }
+
+  @Test
+  public void testAddColumnWithArray() {
+    sql("ALTER TABLE %s ADD COLUMN data2 array<struct<a:INT,b:INT,c:INT>>", tableName);
+    // use the implicit column name 'element' to access members of the array and add column d to the struct
+    sql("ALTER TABLE %s ADD COLUMN data2.element.d int", tableName);
+    Types.StructType expectedSchema =
+        Types.StructType.of(
+            NestedField.required(1, "id", Types.LongType.get()),
+            NestedField.optional(2, "data", Types.StringType.get()),
+            NestedField.optional(
+                3,
+                "data2",
+                Types.ListType.ofOptional(
+                    4,
+                    Types.StructType.of(
+                        NestedField.optional(5, "a", Types.IntegerType.get()),
+                        NestedField.optional(6, "b", Types.IntegerType.get()),
+                        NestedField.optional(7, "c", Types.IntegerType.get()),
+                        NestedField.optional(8, "d", Types.IntegerType.get())))));
+    Assert.assertEquals(
+        "Schema should match expected",
+        expectedSchema,
+        validationCatalog.loadTable(tableIdent).schema().asStruct());
+  }
+
+  @Test
+  public void testAddColumnWithMap() {
+    sql("ALTER TABLE %s ADD COLUMN data2 map<struct<x:INT>, struct<a:INT,b:INT>>", tableName);
+    // use the implicit column names 'key' and 'value' to access members of the map
+    // and add a column to the value struct
+    sql("ALTER TABLE %s ADD COLUMN data2.value.c int", tableName);
+    Types.StructType expectedSchema =
+        Types.StructType.of(
+            NestedField.required(1, "id", Types.LongType.get()),
+            NestedField.optional(2, "data", Types.StringType.get()),
+            NestedField.optional(
+                3,
+                "data2",
+                Types.MapType.ofOptional(
+                    4,
+                    5,
+                    Types.StructType.of(NestedField.optional(6, "x", Types.IntegerType.get())),
+                    Types.StructType.of(
+                        NestedField.optional(7, "a", Types.IntegerType.get()),
+                        NestedField.optional(8, "b", Types.IntegerType.get()),
+                        NestedField.optional(9, "c", Types.IntegerType.get())))));
+    Assert.assertEquals(
+        "Schema should match expected",
+        expectedSchema,
+        validationCatalog.loadTable(tableIdent).schema().asStruct());
+
+    // should not allow changing the map key column
+    AssertHelpers.assertThrows(
+        "Should reject changing key of the map column",
+        SparkException.class,
+        "Unsupported table change: Cannot add fields to map keys:",
+        () -> sql("ALTER TABLE %s ADD COLUMN data2.key.y int", tableName));
+  }
+
+  @Test
+  public void testDropColumn() {
+    sql("ALTER TABLE %s DROP COLUMN data", tableName);
+
+    Types.StructType expectedSchema =
+        Types.StructType.of(NestedField.required(1, "id", Types.LongType.get()));
+
+    Assert.assertEquals(
+        "Schema should match expected",
+        expectedSchema,
+        validationCatalog.loadTable(tableIdent).schema().asStruct());
+  }
+
+  @Test
+  public void testRenameColumn() {
+    sql("ALTER TABLE %s RENAME COLUMN id TO row_id", tableName);
+
+    Types.StructType expectedSchema =
+        Types.StructType.of(
+            NestedField.required(1, "row_id", Types.LongType.get()),
+            NestedField.optional(2, "data", Types.StringType.get()));
+
+    Assert.assertEquals(
+        "Schema should match expected",
+        expectedSchema,
+        validationCatalog.loadTable(tableIdent).schema().asStruct());
+  }
+
+  @Test
+  public void testAlterColumnComment() {
+    sql("ALTER TABLE %s ALTER COLUMN id COMMENT 'Record id'", tableName);
+
+    Types.StructType expectedSchema =
+        Types.StructType.of(
+            NestedField.required(1, "id", Types.LongType.get(), "Record id"),
+            NestedField.optional(2, "data", Types.StringType.get()));
+
+    Assert.assertEquals(
+        "Schema should match expected",
+        expectedSchema,
+        validationCatalog.loadTable(tableIdent).schema().asStruct());
+  }
+
+  @Test
+  public void testAlterColumnType() {
+    sql("ALTER TABLE %s ADD COLUMN count int", tableName);
+    sql("ALTER TABLE %s ALTER COLUMN count TYPE bigint", tableName);
+
+    Types.StructType expectedSchema =
+        Types.StructType.of(
+            NestedField.required(1, "id", Types.LongType.get()),
+            NestedField.optional(2, "data", Types.StringType.get()),
+ NestedField.optional(3, "count", Types.LongType.get())); + + Assert.assertEquals( + "Schema should match expected", + expectedSchema, + validationCatalog.loadTable(tableIdent).schema().asStruct()); + } + + @Test + public void testAlterColumnDropNotNull() { + sql("ALTER TABLE %s ALTER COLUMN id DROP NOT NULL", tableName); + + Types.StructType expectedSchema = + Types.StructType.of( + NestedField.optional(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + + Assert.assertEquals( + "Schema should match expected", + expectedSchema, + validationCatalog.loadTable(tableIdent).schema().asStruct()); + } + + @Test + public void testAlterColumnSetNotNull() { + // no-op changes are allowed + sql("ALTER TABLE %s ALTER COLUMN id SET NOT NULL", tableName); + + Types.StructType expectedSchema = + Types.StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + + Assert.assertEquals( + "Schema should match expected", + expectedSchema, + validationCatalog.loadTable(tableIdent).schema().asStruct()); + + AssertHelpers.assertThrows( + "Should reject adding NOT NULL constraint to an optional column", + AnalysisException.class, + "Cannot change nullable column to non-nullable: data", + () -> sql("ALTER TABLE %s ALTER COLUMN data SET NOT NULL", tableName)); + } + + @Test + public void testAlterColumnPositionAfter() { + sql("ALTER TABLE %s ADD COLUMN count int", tableName); + sql("ALTER TABLE %s ALTER COLUMN count AFTER id", tableName); + + Types.StructType expectedSchema = + Types.StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(3, "count", Types.IntegerType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + + Assert.assertEquals( + "Schema should match expected", + expectedSchema, + validationCatalog.loadTable(tableIdent).schema().asStruct()); + } + + @Test + public void testAlterColumnPositionFirst() { + sql("ALTER TABLE %s ADD COLUMN count int", tableName); + sql("ALTER TABLE %s ALTER COLUMN count FIRST", tableName); + + Types.StructType expectedSchema = + Types.StructType.of( + NestedField.optional(3, "count", Types.IntegerType.get()), + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + + Assert.assertEquals( + "Schema should match expected", + expectedSchema, + validationCatalog.loadTable(tableIdent).schema().asStruct()); + } + + @Test + public void testTableRename() { + Assume.assumeFalse( + "Hadoop catalog does not support rename", validationCatalog instanceof HadoopCatalog); + + Assert.assertTrue("Initial name should exist", validationCatalog.tableExists(tableIdent)); + Assert.assertFalse("New name should not exist", validationCatalog.tableExists(renamedIdent)); + + sql("ALTER TABLE %s RENAME TO %s2", tableName, tableName); + + Assert.assertFalse("Initial name should not exist", validationCatalog.tableExists(tableIdent)); + Assert.assertTrue("New name should exist", validationCatalog.tableExists(renamedIdent)); + } + + @Test + public void testSetTableProperties() { + sql("ALTER TABLE %s SET TBLPROPERTIES ('prop'='value')", tableName); + + Assert.assertEquals( + "Should have the new table property", + "value", + validationCatalog.loadTable(tableIdent).properties().get("prop")); + + sql("ALTER TABLE %s UNSET TBLPROPERTIES ('prop')", tableName); + + Assert.assertNull( + "Should not have the removed table property", + 
validationCatalog.loadTable(tableIdent).properties().get("prop")); + + AssertHelpers.assertThrows( + "Cannot specify the 'sort-order' because it's a reserved table property", + UnsupportedOperationException.class, + () -> sql("ALTER TABLE %s SET TBLPROPERTIES ('sort-order'='value')", tableName)); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestCreateTable.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestCreateTable.java new file mode 100644 index 000000000000..1411c83ddc65 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestCreateTable.java @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import java.io.File; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopCatalog; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.types.Types.StructType; +import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Test; + +public class TestCreateTable extends SparkCatalogTestBase { + public TestCreateTable(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void dropTestTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testTransformIgnoreCase() { + Assert.assertFalse("Table should not already exist", validationCatalog.tableExists(tableIdent)); + sql( + "CREATE TABLE IF NOT EXISTS %s (id BIGINT NOT NULL, ts timestamp) " + + "USING iceberg partitioned by (HOURS(ts))", + tableName); + Assert.assertTrue("Table should already exist", validationCatalog.tableExists(tableIdent)); + sql( + "CREATE TABLE IF NOT EXISTS %s (id BIGINT NOT NULL, ts timestamp) " + + "USING iceberg partitioned by (hours(ts))", + tableName); + Assert.assertTrue("Table should already exist", validationCatalog.tableExists(tableIdent)); + } + + @Test + public void testCreateTable() { + Assert.assertFalse("Table should not already exist", validationCatalog.tableExists(tableIdent)); + + sql("CREATE TABLE %s (id BIGINT NOT NULL, data STRING) USING iceberg", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertNotNull("Should load the new table", table); + + StructType expectedSchema = + StructType.of( 
+ NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + Assert.assertEquals( + "Should have the expected schema", expectedSchema, table.schema().asStruct()); + Assert.assertEquals("Should not be partitioned", 0, table.spec().fields().size()); + Assert.assertNull( + "Should not have the default format set", + table.properties().get(TableProperties.DEFAULT_FILE_FORMAT)); + } + + @Test + public void testCreateTableInRootNamespace() { + Assume.assumeTrue( + "Hadoop has no default namespace configured", "testhadoop".equals(catalogName)); + + try { + sql("CREATE TABLE %s.table (id bigint) USING iceberg", catalogName); + } finally { + sql("DROP TABLE IF EXISTS %s.table", catalogName); + } + } + + @Test + public void testCreateTableUsingParquet() { + Assume.assumeTrue( + "Not working with session catalog because Spark will not use v2 for a Parquet table", + !"spark_catalog".equals(catalogName)); + + Assert.assertFalse("Table should not already exist", validationCatalog.tableExists(tableIdent)); + + sql("CREATE TABLE %s (id BIGINT NOT NULL, data STRING) USING parquet", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertNotNull("Should load the new table", table); + + StructType expectedSchema = + StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + Assert.assertEquals( + "Should have the expected schema", expectedSchema, table.schema().asStruct()); + Assert.assertEquals("Should not be partitioned", 0, table.spec().fields().size()); + Assert.assertEquals( + "Should not have default format parquet", + "parquet", + table.properties().get(TableProperties.DEFAULT_FILE_FORMAT)); + + AssertHelpers.assertThrows( + "Should reject unsupported format names", + IllegalArgumentException.class, + "Unsupported format in USING: crocodile", + () -> + sql( + "CREATE TABLE %s.default.fail (id BIGINT NOT NULL, data STRING) USING crocodile", + catalogName)); + } + + @Test + public void testCreateTablePartitionedBy() { + Assert.assertFalse("Table should not already exist", validationCatalog.tableExists(tableIdent)); + + sql( + "CREATE TABLE %s " + + "(id BIGINT NOT NULL, created_at TIMESTAMP, category STRING, data STRING) " + + "USING iceberg " + + "PARTITIONED BY (category, bucket(8, id), days(created_at))", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertNotNull("Should load the new table", table); + + StructType expectedSchema = + StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "created_at", Types.TimestampType.withZone()), + NestedField.optional(3, "category", Types.StringType.get()), + NestedField.optional(4, "data", Types.StringType.get())); + Assert.assertEquals( + "Should have the expected schema", expectedSchema, table.schema().asStruct()); + + PartitionSpec expectedSpec = + PartitionSpec.builderFor(new Schema(expectedSchema.fields())) + .identity("category") + .bucket("id", 8) + .day("created_at") + .build(); + Assert.assertEquals("Should be partitioned correctly", expectedSpec, table.spec()); + + Assert.assertNull( + "Should not have the default format set", + table.properties().get(TableProperties.DEFAULT_FILE_FORMAT)); + } + + @Test + public void testCreateTableColumnComments() { + Assert.assertFalse("Table should not already exist", validationCatalog.tableExists(tableIdent)); + + sql( + "CREATE TABLE %s " + + "(id BIGINT NOT NULL COMMENT 'Unique 
identifier', data STRING COMMENT 'Data value') " + + "USING iceberg", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertNotNull("Should load the new table", table); + + StructType expectedSchema = + StructType.of( + NestedField.required(1, "id", Types.LongType.get(), "Unique identifier"), + NestedField.optional(2, "data", Types.StringType.get(), "Data value")); + Assert.assertEquals( + "Should have the expected schema", expectedSchema, table.schema().asStruct()); + Assert.assertEquals("Should not be partitioned", 0, table.spec().fields().size()); + Assert.assertNull( + "Should not have the default format set", + table.properties().get(TableProperties.DEFAULT_FILE_FORMAT)); + } + + @Test + public void testCreateTableComment() { + Assert.assertFalse("Table should not already exist", validationCatalog.tableExists(tableIdent)); + + sql( + "CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "COMMENT 'Table doc'", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertNotNull("Should load the new table", table); + + StructType expectedSchema = + StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + Assert.assertEquals( + "Should have the expected schema", expectedSchema, table.schema().asStruct()); + Assert.assertEquals("Should not be partitioned", 0, table.spec().fields().size()); + Assert.assertNull( + "Should not have the default format set", + table.properties().get(TableProperties.DEFAULT_FILE_FORMAT)); + Assert.assertEquals( + "Should have the table comment set in properties", + "Table doc", + table.properties().get(TableCatalog.PROP_COMMENT)); + } + + @Test + public void testCreateTableLocation() throws Exception { + Assume.assumeTrue( + "Cannot set custom locations for Hadoop catalog tables", + !(validationCatalog instanceof HadoopCatalog)); + + Assert.assertFalse("Table should not already exist", validationCatalog.tableExists(tableIdent)); + + File tableLocation = temp.newFolder(); + Assert.assertTrue(tableLocation.delete()); + + String location = "file:" + tableLocation.toString(); + + sql( + "CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "LOCATION '%s'", + tableName, location); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertNotNull("Should load the new table", table); + + StructType expectedSchema = + StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.optional(2, "data", Types.StringType.get())); + Assert.assertEquals( + "Should have the expected schema", expectedSchema, table.schema().asStruct()); + Assert.assertEquals("Should not be partitioned", 0, table.spec().fields().size()); + Assert.assertNull( + "Should not have the default format set", + table.properties().get(TableProperties.DEFAULT_FILE_FORMAT)); + Assert.assertEquals("Should have a custom table location", location, table.location()); + } + + @Test + public void testCreateTableProperties() { + Assert.assertFalse("Table should not already exist", validationCatalog.tableExists(tableIdent)); + + sql( + "CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "TBLPROPERTIES (p1=2, p2='x')", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertNotNull("Should load the new table", table); + + StructType expectedSchema = + StructType.of( + NestedField.required(1, "id", Types.LongType.get()), + 
NestedField.optional(2, "data", Types.StringType.get())); + Assert.assertEquals( + "Should have the expected schema", expectedSchema, table.schema().asStruct()); + Assert.assertEquals("Should not be partitioned", 0, table.spec().fields().size()); + Assert.assertEquals("Should have property p1", "2", table.properties().get("p1")); + Assert.assertEquals("Should have property p2", "x", table.properties().get("p2")); + } + + @Test + public void testCreateTableWithFormatV2ThroughTableProperty() { + Assert.assertFalse("Table should not already exist", validationCatalog.tableExists(tableIdent)); + + sql( + "CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "TBLPROPERTIES ('format-version'='2')", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals( + "should create table using format v2", + 2, + ((BaseTable) table).operations().current().formatVersion()); + } + + @Test + public void testUpgradeTableWithFormatV2ThroughTableProperty() { + Assert.assertFalse("Table should not already exist", validationCatalog.tableExists(tableIdent)); + + sql( + "CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "TBLPROPERTIES ('format-version'='1')", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + TableOperations ops = ((BaseTable) table).operations(); + Assert.assertEquals("should create table using format v1", 1, ops.refresh().formatVersion()); + + sql("ALTER TABLE %s SET TBLPROPERTIES ('format-version'='2')", tableName); + Assert.assertEquals("should update table to use format v2", 2, ops.refresh().formatVersion()); + } + + @Test + public void testDowngradeTableToFormatV1ThroughTablePropertyFails() { + Assert.assertFalse("Table should not already exist", validationCatalog.tableExists(tableIdent)); + + sql( + "CREATE TABLE %s " + + "(id BIGINT NOT NULL, data STRING) " + + "USING iceberg " + + "TBLPROPERTIES ('format-version'='2')", + tableName); + + Table table = validationCatalog.loadTable(tableIdent); + TableOperations ops = ((BaseTable) table).operations(); + Assert.assertEquals("should create table using format v2", 2, ops.refresh().formatVersion()); + + AssertHelpers.assertThrowsCause( + "should fail to downgrade to v1", + IllegalArgumentException.class, + "Cannot downgrade v2 table to v1", + () -> sql("ALTER TABLE %s SET TBLPROPERTIES ('format-version'='1')", tableName)); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestCreateTableAsSelect.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestCreateTableAsSelect.java new file mode 100644 index 000000000000..2581c0fd3c56 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestCreateTableAsSelect.java @@ -0,0 +1,416 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.lit; +import static org.apache.spark.sql.functions.when; + +import java.util.Map; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.types.Types; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestCreateTableAsSelect extends SparkCatalogTestBase { + + private final String sourceName; + + public TestCreateTableAsSelect( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + this.sourceName = tableName("source"); + + sql( + "CREATE TABLE IF NOT EXISTS %s (id bigint NOT NULL, data string) " + + "USING iceberg PARTITIONED BY (truncate(id, 3))", + sourceName); + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", sourceName); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testUnpartitionedCTAS() { + sql("CREATE TABLE %s USING iceberg AS SELECT * FROM %s", tableName, sourceName); + + Schema expectedSchema = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get())); + + Table ctasTable = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals( + "Should have expected nullable schema", + expectedSchema.asStruct(), + ctasTable.schema().asStruct()); + Assert.assertEquals("Should be an unpartitioned table", 0, ctasTable.spec().fields().size()); + assertEquals( + "Should have rows matching the source table", + sql("SELECT * FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testPartitionedCTAS() { + sql( + "CREATE TABLE %s USING iceberg PARTITIONED BY (id) AS SELECT * FROM %s ORDER BY id", + tableName, sourceName); + + Schema expectedSchema = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get())); + + PartitionSpec expectedSpec = PartitionSpec.builderFor(expectedSchema).identity("id").build(); + + Table ctasTable = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals( + "Should have expected nullable schema", + expectedSchema.asStruct(), + ctasTable.schema().asStruct()); + Assert.assertEquals("Should be partitioned by id", expectedSpec, ctasTable.spec()); + assertEquals( + "Should have rows matching the source table", + sql("SELECT * FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testRTAS() { + sql( + "CREATE TABLE %s USING iceberg TBLPROPERTIES ('prop1'='val1', 'prop2'='val2')" + + "AS SELECT * FROM %s", + tableName, sourceName); + + assertEquals( + "Should have rows matching the source table", + sql("SELECT * FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql( + "REPLACE TABLE %s USING iceberg PARTITIONED BY (part) TBLPROPERTIES ('prop1'='newval1', 'prop3'='val3') AS " + + "SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY 3, 1", + tableName, sourceName); + + Schema 
expectedSchema = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get()), + Types.NestedField.optional(3, "part", Types.StringType.get())); + + PartitionSpec expectedSpec = + PartitionSpec.builderFor(expectedSchema).identity("part").withSpecId(1).build(); + + Table rtasTable = validationCatalog.loadTable(tableIdent); + + // the replacement table has a different schema and partition spec than the original + Assert.assertEquals( + "Should have expected nullable schema", + expectedSchema.asStruct(), + rtasTable.schema().asStruct()); + Assert.assertEquals("Should be partitioned by part", expectedSpec, rtasTable.spec()); + + assertEquals( + "Should have rows matching the source table", + sql( + "SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", + sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + Assert.assertEquals( + "Table should have expected snapshots", 2, Iterables.size(rtasTable.snapshots())); + + Assert.assertEquals( + "Should have updated table property", "newval1", rtasTable.properties().get("prop1")); + Assert.assertEquals( + "Should have preserved table property", "val2", rtasTable.properties().get("prop2")); + Assert.assertEquals( + "Should have new table property", "val3", rtasTable.properties().get("prop3")); + } + + @Test + public void testCreateRTAS() { + sql( + "CREATE OR REPLACE TABLE %s USING iceberg PARTITIONED BY (part) AS " + + "SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY 3, 1", + tableName, sourceName); + + assertEquals( + "Should have rows matching the source table", + sql( + "SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", + sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql( + "CREATE OR REPLACE TABLE %s USING iceberg PARTITIONED BY (part) AS " + + "SELECT 2 * id as id, data, CASE WHEN ((2 * id) %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY 3, 1", + tableName, sourceName); + + Schema expectedSchema = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get()), + Types.NestedField.optional(3, "part", Types.StringType.get())); + + PartitionSpec expectedSpec = + PartitionSpec.builderFor(expectedSchema) + .identity("part") + .withSpecId(0) // the spec is identical and should be reused + .build(); + + Table rtasTable = validationCatalog.loadTable(tableIdent); + + // the replacement table has a different schema and partition spec than the original + Assert.assertEquals( + "Should have expected nullable schema", + expectedSchema.asStruct(), + rtasTable.schema().asStruct()); + Assert.assertEquals("Should be partitioned by part", expectedSpec, rtasTable.spec()); + + assertEquals( + "Should have rows matching the source table", + sql( + "SELECT 2 * id, data, CASE WHEN ((2 * id) %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", + sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + Assert.assertEquals( + "Table should have expected snapshots", 2, Iterables.size(rtasTable.snapshots())); + } + + @Test + public void testDataFrameV2Create() throws Exception { + spark.table(sourceName).writeTo(tableName).using("iceberg").create(); + + Schema expectedSchema = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + 
Types.NestedField.optional(2, "data", Types.StringType.get())); + + Table ctasTable = validationCatalog.loadTable(tableIdent); + + Assert.assertEquals( + "Should have expected nullable schema", + expectedSchema.asStruct(), + ctasTable.schema().asStruct()); + Assert.assertEquals("Should be an unpartitioned table", 0, ctasTable.spec().fields().size()); + assertEquals( + "Should have rows matching the source table", + sql("SELECT * FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testDataFrameV2Replace() throws Exception { + spark.table(sourceName).writeTo(tableName).using("iceberg").create(); + + assertEquals( + "Should have rows matching the source table", + sql("SELECT * FROM %s ORDER BY id", sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + spark + .table(sourceName) + .select( + col("id"), + col("data"), + when(col("id").mod(lit(2)).equalTo(lit(0)), lit("even")).otherwise("odd").as("part")) + .orderBy("part", "id") + .writeTo(tableName) + .partitionedBy(col("part")) + .using("iceberg") + .replace(); + + Schema expectedSchema = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get()), + Types.NestedField.optional(3, "part", Types.StringType.get())); + + PartitionSpec expectedSpec = + PartitionSpec.builderFor(expectedSchema).identity("part").withSpecId(1).build(); + + Table rtasTable = validationCatalog.loadTable(tableIdent); + + // the replacement table has a different schema and partition spec than the original + Assert.assertEquals( + "Should have expected nullable schema", + expectedSchema.asStruct(), + rtasTable.schema().asStruct()); + Assert.assertEquals("Should be partitioned by part", expectedSpec, rtasTable.spec()); + + assertEquals( + "Should have rows matching the source table", + sql( + "SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", + sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + Assert.assertEquals( + "Table should have expected snapshots", 2, Iterables.size(rtasTable.snapshots())); + } + + @Test + public void testDataFrameV2CreateOrReplace() { + spark + .table(sourceName) + .select( + col("id"), + col("data"), + when(col("id").mod(lit(2)).equalTo(lit(0)), lit("even")).otherwise("odd").as("part")) + .orderBy("part", "id") + .writeTo(tableName) + .partitionedBy(col("part")) + .using("iceberg") + .createOrReplace(); + + assertEquals( + "Should have rows matching the source table", + sql( + "SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", + sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + spark + .table(sourceName) + .select(col("id").multiply(lit(2)).as("id"), col("data")) + .select( + col("id"), + col("data"), + when(col("id").mod(lit(2)).equalTo(lit(0)), lit("even")).otherwise("odd").as("part")) + .orderBy("part", "id") + .writeTo(tableName) + .partitionedBy(col("part")) + .using("iceberg") + .createOrReplace(); + + Schema expectedSchema = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get()), + Types.NestedField.optional(3, "part", Types.StringType.get())); + + PartitionSpec expectedSpec = + PartitionSpec.builderFor(expectedSchema) + .identity("part") + .withSpecId(0) // the spec is identical and should be reused + .build(); + + Table rtasTable = 
validationCatalog.loadTable(tableIdent); + + // the replacement table has a different schema and partition spec than the original + Assert.assertEquals( + "Should have expected nullable schema", + expectedSchema.asStruct(), + rtasTable.schema().asStruct()); + Assert.assertEquals("Should be partitioned by part", expectedSpec, rtasTable.spec()); + + assertEquals( + "Should have rows matching the source table", + sql( + "SELECT 2 * id, data, CASE WHEN ((2 * id) %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", + sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + Assert.assertEquals( + "Table should have expected snapshots", 2, Iterables.size(rtasTable.snapshots())); + } + + @Test + public void testCreateRTASWithPartitionSpecChanging() { + sql( + "CREATE OR REPLACE TABLE %s USING iceberg PARTITIONED BY (part) AS " + + "SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY 3, 1", + tableName, sourceName); + + Table rtasTable = validationCatalog.loadTable(tableIdent); + + assertEquals( + "Should have rows matching the source table", + sql( + "SELECT id, data, CASE WHEN (id %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", + sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + // Change the partitioning of the table + rtasTable.updateSpec().removeField("part").commit(); // Spec 1 + + sql( + "CREATE OR REPLACE TABLE %s USING iceberg PARTITIONED BY (part, id) AS " + + "SELECT 2 * id as id, data, CASE WHEN ((2 * id) %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY 3, 1", + tableName, sourceName); + + Schema expectedSchema = + new Schema( + Types.NestedField.optional(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get()), + Types.NestedField.optional(3, "part", Types.StringType.get())); + + PartitionSpec expectedSpec = + PartitionSpec.builderFor(expectedSchema) + .alwaysNull("part", "part_1000") + .identity("part") + .identity("id") + .withSpecId(2) // The Spec is new + .build(); + + Assert.assertEquals("Should be partitioned by part and id", expectedSpec, rtasTable.spec()); + + // the replacement table has a different schema and partition spec than the original + Assert.assertEquals( + "Should have expected nullable schema", + expectedSchema.asStruct(), + rtasTable.schema().asStruct()); + + assertEquals( + "Should have rows matching the source table", + sql( + "SELECT 2 * id, data, CASE WHEN ((2 * id) %% 2) = 0 THEN 'even' ELSE 'odd' END AS part " + + "FROM %s ORDER BY id", + sourceName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + Assert.assertEquals( + "Table should have expected snapshots", 2, Iterables.size(rtasTable.snapshots())); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestDeleteFrom.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestDeleteFrom.java new file mode 100644 index 000000000000..cae1901aa713 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestDeleteFrom.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestDeleteFrom extends SparkCatalogTestBase { + public TestDeleteFrom(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testDeleteFromUnpartitionedTable() throws NoSuchTableException { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.coalesce(1).writeTo(tableName).append(); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("DELETE FROM %s WHERE id < 2", tableName); + + assertEquals( + "Should have no rows after successful delete", + ImmutableList.of(row(2L, "b"), row(3L, "c")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("DELETE FROM %s WHERE id < 4", tableName); + + assertEquals( + "Should have no rows after successful delete", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testDeleteFromTableAtSnapshot() throws NoSuchTableException { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.coalesce(1).writeTo(tableName).append(); + + long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + String prefix = "snapshot_id_"; + AssertHelpers.assertThrows( + "Should not be able to delete from a table at a specific snapshot", + IllegalArgumentException.class, + "Cannot delete from table at a specific snapshot", + () -> sql("DELETE FROM %s.%s WHERE id < 4", tableName, prefix + snapshotId)); + } + + @Test + public void testDeleteFromPartitionedTable() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id bigint, data string) " + + "USING iceberg " + + "PARTITIONED BY (truncate(id, 2))", + tableName); + + List records = + Lists.newArrayList( + new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c")); + Dataset df = 
spark.createDataFrame(records, SimpleRecord.class); + df.coalesce(1).writeTo(tableName).append(); + + assertEquals( + "Should have 3 rows in 2 partitions", + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("DELETE FROM %s WHERE id > 2", tableName); + assertEquals( + "Should have two rows in the second partition", + ImmutableList.of(row(1L, "a"), row(2L, "b")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("DELETE FROM %s WHERE id < 2", tableName); + + assertEquals( + "Should have two rows in the second partition", + ImmutableList.of(row(2L, "b")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @Test + public void testDeleteFromWhereFalse() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 1 snapshot", 1, Iterables.size(table.snapshots())); + + sql("DELETE FROM %s WHERE false", tableName); + + table.refresh(); + + Assert.assertEquals( + "Delete should not produce a new snapshot", 1, Iterables.size(table.snapshots())); + } + + @Test + public void testTruncate() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + Table table = validationCatalog.loadTable(tableIdent); + Assert.assertEquals("Should have 1 snapshot", 1, Iterables.size(table.snapshots())); + + sql("TRUNCATE TABLE %s", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestDropTable.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestDropTable.java new file mode 100644 index 000000000000..34b6899a1c08 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestDropTable.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Streams; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; + +public class TestDropTable extends SparkCatalogTestBase { + + public TestDropTable(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Before + public void createTable() { + sql("CREATE TABLE %s (id INT, name STRING) USING iceberg", tableName); + sql("INSERT INTO %s VALUES (1, 'test')", tableName); + } + + @After + public void removeTable() throws IOException { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testDropTable() throws IOException { + dropTableInternal(); + } + + @Test + public void testDropTableGCDisabled() throws IOException { + sql("ALTER TABLE %s SET TBLPROPERTIES (gc.enabled = false)", tableName); + dropTableInternal(); + } + + private void dropTableInternal() throws IOException { + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "test")), + sql("SELECT * FROM %s", tableName)); + + List manifestAndFiles = manifestsAndFiles(); + Assert.assertEquals( + "There should be 2 files for manifests and files", 2, manifestAndFiles.size()); + Assert.assertTrue("All files should be existed", checkFilesExist(manifestAndFiles, true)); + + sql("DROP TABLE %s", tableName); + Assert.assertFalse("Table should not exist", validationCatalog.tableExists(tableIdent)); + + if (catalogName.equals("testhadoop")) { + // HadoopCatalog drop table without purge will delete the base table location. 
+ Assert.assertTrue("All files should be deleted", checkFilesExist(manifestAndFiles, false)); + } else { + Assert.assertTrue("All files should not be deleted", checkFilesExist(manifestAndFiles, true)); + } + } + + // TODO: enable once SPARK-43203 is fixed + @Ignore + public void testPurgeTable() throws IOException { + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "test")), + sql("SELECT * FROM %s", tableName)); + + List manifestAndFiles = manifestsAndFiles(); + Assert.assertEquals( + "There should be 2 files for manifests and files", 2, manifestAndFiles.size()); + Assert.assertTrue("All files should exist", checkFilesExist(manifestAndFiles, true)); + + sql("DROP TABLE %s PURGE", tableName); + Assert.assertFalse("Table should not exist", validationCatalog.tableExists(tableIdent)); + Assert.assertTrue("All files should be deleted", checkFilesExist(manifestAndFiles, false)); + } + + // TODO: enable once SPARK-43203 is fixed + @Ignore + public void testPurgeTableGCDisabled() throws IOException { + sql("ALTER TABLE %s SET TBLPROPERTIES (gc.enabled = false)", tableName); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1, "test")), + sql("SELECT * FROM %s", tableName)); + + List manifestAndFiles = manifestsAndFiles(); + Assert.assertEquals( + "There totally should have 2 files for manifests and files", 2, manifestAndFiles.size()); + Assert.assertTrue("All files should be existed", checkFilesExist(manifestAndFiles, true)); + + AssertHelpers.assertThrows( + "Purge table is not allowed when GC is disabled", + ValidationException.class, + "Cannot purge table: GC is disabled (deleting files may corrupt other tables", + () -> sql("DROP TABLE %s PURGE", tableName)); + + Assert.assertTrue("Table should not been dropped", validationCatalog.tableExists(tableIdent)); + Assert.assertTrue("All files should not be deleted", checkFilesExist(manifestAndFiles, true)); + } + + private List manifestsAndFiles() { + List files = sql("SELECT file_path FROM %s.%s", tableName, MetadataTableType.FILES); + List manifests = + sql("SELECT path FROM %s.%s", tableName, MetadataTableType.MANIFESTS); + return Streams.concat(files.stream(), manifests.stream()) + .map(row -> (String) row[0]) + .collect(Collectors.toList()); + } + + private boolean checkFilesExist(List files, boolean shouldExist) throws IOException { + boolean mask = !shouldExist; + if (files.isEmpty()) { + return mask; + } + + FileSystem fs = new Path(files.get(0)).getFileSystem(hiveConf); + return files.stream() + .allMatch( + file -> { + try { + return fs.exists(new Path(file)) ^ mask; + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestFilterPushDown.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestFilterPushDown.java new file mode 100644 index 000000000000..0ea34e187f1d --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestFilterPushDown.java @@ -0,0 +1,556 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import java.sql.Timestamp; +import java.time.Instant; +import java.util.List; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkTestBaseWithCatalog; +import org.apache.spark.sql.execution.SparkPlan; +import org.assertj.core.api.Assertions; +import org.junit.After; +import org.junit.Test; + +public class TestFilterPushDown extends SparkTestBaseWithCatalog { + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS tmp_view"); + } + + @Test + public void testFilterPushdownWithIdentityTransform() { + sql( + "CREATE TABLE %s (id INT, salary INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep)", + tableName); + + sql("INSERT INTO %s VALUES (1, 100, 'd1')", tableName); + sql("INSERT INTO %s VALUES (2, 200, 'd2')", tableName); + sql("INSERT INTO %s VALUES (3, 300, 'd3')", tableName); + sql("INSERT INTO %s VALUES (4, 400, 'd4')", tableName); + sql("INSERT INTO %s VALUES (5, 500, 'd5')", tableName); + sql("INSERT INTO %s VALUES (6, 600, null)", tableName); + + checkOnlyIcebergFilters( + "dep IS NULL" /* query predicate */, + "dep IS NULL" /* Iceberg scan filters */, + ImmutableList.of(row(6, 600, null))); + + checkOnlyIcebergFilters( + "dep IS NOT NULL" /* query predicate */, + "dep IS NOT NULL" /* Iceberg scan filters */, + ImmutableList.of( + row(1, 100, "d1"), + row(2, 200, "d2"), + row(3, 300, "d3"), + row(4, 400, "d4"), + row(5, 500, "d5"))); + + checkOnlyIcebergFilters( + "dep = 'd3'" /* query predicate */, + "dep IS NOT NULL, dep = 'd3'" /* Iceberg scan filters */, + ImmutableList.of(row(3, 300, "d3"))); + + checkOnlyIcebergFilters( + "dep > 'd3'" /* query predicate */, + "dep IS NOT NULL, dep > 'd3'" /* Iceberg scan filters */, + ImmutableList.of(row(4, 400, "d4"), row(5, 500, "d5"))); + + checkOnlyIcebergFilters( + "dep >= 'd5'" /* query predicate */, + "dep IS NOT NULL, dep >= 'd5'" /* Iceberg scan filters */, + ImmutableList.of(row(5, 500, "d5"))); + + checkOnlyIcebergFilters( + "dep < 'd2'" /* query predicate */, + "dep IS NOT NULL, dep < 'd2'" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"))); + + checkOnlyIcebergFilters( + "dep <= 'd2'" /* query predicate */, + "dep IS NOT NULL, dep <= 'd2'" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"), row(2, 200, "d2"))); + + checkOnlyIcebergFilters( + "dep <=> 'd3'" /* query predicate */, + "dep = 'd3'" /* Iceberg scan filters */, + ImmutableList.of(row(3, 300, "d3"))); + + checkOnlyIcebergFilters( + "dep IN (null, 'd1')" /* query predicate */, + "dep IN ('d1')" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"))); + + checkOnlyIcebergFilters( + "dep NOT IN ('d2', 'd4')" /* query predicate */, + "(dep IS NOT NULL AND dep NOT IN ('d2', 'd4'))" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"), row(3, 300, "d3"), row(5, 500, "d5"))); + + checkOnlyIcebergFilters( + "dep = 'd1' AND dep IS NOT NULL" /* query 
predicate */, + "dep = 'd1', dep IS NOT NULL" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"))); + + checkOnlyIcebergFilters( + "dep = 'd1' OR dep = 'd2' OR dep = 'd3'" /* query predicate */, + "((dep = 'd1' OR dep = 'd2') OR dep = 'd3')" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"), row(2, 200, "d2"), row(3, 300, "d3"))); + + checkFilters( + "dep = 'd1' AND id = 1" /* query predicate */, + "isnotnull(id) AND (id = 1)" /* Spark post scan filter */, + "dep IS NOT NULL, id IS NOT NULL, dep = 'd1', id = 1" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"))); + + checkFilters( + "dep = 'd2' OR id = 1" /* query predicate */, + "(dep = d2) OR (id = 1)" /* Spark post scan filter */, + "(dep = 'd2' OR id = 1)" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"), row(2, 200, "d2"))); + + checkFilters( + "dep LIKE 'd1%' AND id = 1" /* query predicate */, + "isnotnull(id) AND (id = 1)" /* Spark post scan filter */, + "dep IS NOT NULL, id IS NOT NULL, dep LIKE 'd1%', id = 1" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"))); + + checkFilters( + "dep NOT LIKE 'd5%' AND (id = 1 OR id = 5)" /* query predicate */, + "(id = 1) OR (id = 5)" /* Spark post scan filter */, + "dep IS NOT NULL, NOT (dep LIKE 'd5%'), (id = 1 OR id = 5)" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"))); + + checkFilters( + "dep LIKE '%d5' AND id IN (1, 5)" /* query predicate */, + "EndsWith(dep, d5) AND id IN (1,5)" /* Spark post scan filter */, + "dep IS NOT NULL, id IN (1, 5)" /* Iceberg scan filters */, + ImmutableList.of(row(5, 500, "d5"))); + } + + @Test + public void testFilterPushdownWithHoursTransform() { + sql( + "CREATE TABLE %s (id INT, price INT, t TIMESTAMP)" + + "USING iceberg " + + "PARTITIONED BY (hours(t))", + tableName); + + sql("INSERT INTO %s VALUES (1, 100, TIMESTAMP '2021-06-30T01:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (2, 200, TIMESTAMP '2021-06-30T02:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (3, 300, null)", tableName); + + withDefaultTimeZone( + "UTC", + () -> { + checkOnlyIcebergFilters( + "t IS NULL" /* query predicate */, + "t IS NULL" /* Iceberg scan filters */, + ImmutableList.of(row(3, 300, null))); + + // strict/inclusive projections for t < TIMESTAMP '2021-06-30T02:00:00.000Z' are equal, + // so this filter selects entire partitions and can be pushed down completely + checkOnlyIcebergFilters( + "t < TIMESTAMP '2021-06-30T02:00:00.000Z'" /* query predicate */, + "t IS NOT NULL, t < 1625018400000000" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, timestamp("2021-06-30T01:00:00.0Z")))); + + // strict/inclusive projections for t < TIMESTAMP '2021-06-30T01:00:00.001Z' differ, + // so this filter does NOT select entire partitions and can't be pushed down completely + checkFilters( + "t < TIMESTAMP '2021-06-30T01:00:00.001Z'" /* query predicate */, + "t < 2021-06-30 01:00:00.001" /* Spark post scan filter */, + "t IS NOT NULL, t < 1625014800001000" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, timestamp("2021-06-30T01:00:00.0Z")))); + + // strict/inclusive projections for t <= TIMESTAMP '2021-06-30T01:00:00.000Z' differ, + // so this filter does NOT select entire partitions and can't be pushed down completely + checkFilters( + "t <= TIMESTAMP '2021-06-30T01:00:00.000Z'" /* query predicate */, + "t <= 2021-06-30 01:00:00" /* Spark post scan filter */, + "t IS NOT NULL, t <= 1625014800000000" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, 
timestamp("2021-06-30T01:00:00.0Z")))); + }); + } + + @Test + public void testFilterPushdownWithDaysTransform() { + sql( + "CREATE TABLE %s (id INT, price INT, t TIMESTAMP)" + + "USING iceberg " + + "PARTITIONED BY (days(t))", + tableName); + + sql("INSERT INTO %s VALUES (1, 100, TIMESTAMP '2021-06-15T01:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (2, 200, TIMESTAMP '2021-06-30T02:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (3, 300, TIMESTAMP '2021-07-15T10:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (4, 400, null)", tableName); + + withDefaultTimeZone( + "UTC", + () -> { + checkOnlyIcebergFilters( + "t IS NULL" /* query predicate */, + "t IS NULL" /* Iceberg scan filters */, + ImmutableList.of(row(4, 400, null))); + + // strict/inclusive projections for t < TIMESTAMP '2021-07-05T00:00:00.000Z' are equal, + // so this filter selects entire partitions and can be pushed down completely + checkOnlyIcebergFilters( + "t < TIMESTAMP '2021-07-05T00:00:00.000Z'" /* query predicate */, + "t IS NOT NULL, t < 1625443200000000" /* Iceberg scan filters */, + ImmutableList.of( + row(1, 100, timestamp("2021-06-15T01:00:00.000Z")), + row(2, 200, timestamp("2021-06-30T02:00:00.000Z")))); + + // strict/inclusive projections for t < TIMESTAMP '2021-06-30T03:00:00.000Z' differ, + // so this filter does NOT select entire partitions and can't be pushed down completely + checkFilters( + "t < TIMESTAMP '2021-06-30T03:00:00.000Z'" /* query predicate */, + "t < 2021-06-30 03:00:00" /* Spark post scan filter */, + "t IS NOT NULL, t < 1625022000000000" /* Iceberg scan filters */, + ImmutableList.of( + row(1, 100, timestamp("2021-06-15T01:00:00.000Z")), + row(2, 200, timestamp("2021-06-30T02:00:00.000Z")))); + }); + } + + @Test + public void testFilterPushdownWithMonthsTransform() { + sql( + "CREATE TABLE %s (id INT, price INT, t TIMESTAMP)" + + "USING iceberg " + + "PARTITIONED BY (months(t))", + tableName); + + sql("INSERT INTO %s VALUES (1, 100, TIMESTAMP '2021-06-30T01:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (2, 200, TIMESTAMP '2021-06-30T02:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (3, 300, TIMESTAMP '2021-07-15T10:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (4, 400, null)", tableName); + + withDefaultTimeZone( + "UTC", + () -> { + checkOnlyIcebergFilters( + "t IS NULL" /* query predicate */, + "t IS NULL" /* Iceberg scan filters */, + ImmutableList.of(row(4, 400, null))); + + // strict/inclusive projections for t < TIMESTAMP '2021-07-01T00:00:00.000Z' are equal, + // so this filter selects entire partitions and can be pushed down completely + checkOnlyIcebergFilters( + "t < TIMESTAMP '2021-07-01T00:00:00.000Z'" /* query predicate */, + "t IS NOT NULL, t < 1625097600000000" /* Iceberg scan filters */, + ImmutableList.of( + row(1, 100, timestamp("2021-06-30T01:00:00.000Z")), + row(2, 200, timestamp("2021-06-30T02:00:00.000Z")))); + + // strict/inclusive projections for t < TIMESTAMP '2021-06-30T03:00:00.000Z' differ, + // so this filter does NOT select entire partitions and can't be pushed down completely + checkFilters( + "t < TIMESTAMP '2021-06-30T03:00:00.000Z'" /* query predicate */, + "t < 2021-06-30 03:00:00" /* Spark post scan filter */, + "t IS NOT NULL, t < 1625022000000000" /* Iceberg scan filters */, + ImmutableList.of( + row(1, 100, timestamp("2021-06-30T01:00:00.000Z")), + row(2, 200, timestamp("2021-06-30T02:00:00.000Z")))); + }); + } + + @Test + public void testFilterPushdownWithYearsTransform() { + sql( + "CREATE TABLE %s 
(id INT, price INT, t TIMESTAMP)" + + "USING iceberg " + + "PARTITIONED BY (years(t))", + tableName); + + sql("INSERT INTO %s VALUES (1, 100, TIMESTAMP '2021-06-30T01:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (2, 200, TIMESTAMP '2021-06-30T02:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (2, 200, TIMESTAMP '2022-09-25T02:00:00.000Z')", tableName); + sql("INSERT INTO %s VALUES (3, 300, null)", tableName); + + withDefaultTimeZone( + "UTC", + () -> { + checkOnlyIcebergFilters( + "t IS NULL" /* query predicate */, + "t IS NULL" /* Iceberg scan filters */, + ImmutableList.of(row(3, 300, null))); + + // strict/inclusive projections for t < TIMESTAMP '2022-01-01T00:00:00.000Z' are equal, + // so this filter selects entire partitions and can be pushed down completely + checkOnlyIcebergFilters( + "t < TIMESTAMP '2022-01-01T00:00:00.000Z'" /* query predicate */, + "t IS NOT NULL, t < 1640995200000000" /* Iceberg scan filters */, + ImmutableList.of( + row(1, 100, timestamp("2021-06-30T01:00:00.000Z")), + row(2, 200, timestamp("2021-06-30T02:00:00.000Z")))); + + // strict/inclusive projections for t < TIMESTAMP '2021-06-30T03:00:00.000Z' differ, + // so this filter does NOT select entire partitions and can't be pushed down completely + checkFilters( + "t < TIMESTAMP '2021-06-30T03:00:00.000Z'" /* query predicate */, + "t < 2021-06-30 03:00:00" /* Spark post scan filter */, + "t IS NOT NULL, t < 1625022000000000" /* Iceberg scan filters */, + ImmutableList.of( + row(1, 100, timestamp("2021-06-30T01:00:00.000Z")), + row(2, 200, timestamp("2021-06-30T02:00:00.000Z")))); + }); + } + + @Test + public void testFilterPushdownWithBucketTransform() { + sql( + "CREATE TABLE %s (id INT, salary INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep, bucket(8, id))", + tableName); + + sql("INSERT INTO %s VALUES (1, 100, 'd1')", tableName); + sql("INSERT INTO %s VALUES (2, 200, 'd2')", tableName); + + checkFilters( + "dep = 'd1' AND id = 1" /* query predicate */, + "id = 1" /* Spark post scan filter */, + "dep IS NOT NULL, id IS NOT NULL, dep = 'd1'" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"))); + } + + @Test + public void testFilterPushdownWithTruncateTransform() { + sql( + "CREATE TABLE %s (id INT, salary INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (truncate(1, dep))", + tableName); + + sql("INSERT INTO %s VALUES (1, 100, 'd1')", tableName); + sql("INSERT INTO %s VALUES (2, 200, 'd2')", tableName); + sql("INSERT INTO %s VALUES (3, 300, 'a3')", tableName); + + checkOnlyIcebergFilters( + "dep LIKE 'd%'" /* query predicate */, + "dep IS NOT NULL, dep LIKE 'd%'" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"), row(2, 200, "d2"))); + + checkFilters( + "dep = 'd1'" /* query predicate */, + "dep = d1" /* Spark post scan filter */, + "dep IS NOT NULL" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"))); + } + + @Test + public void testFilterPushdownWithSpecEvolutionAndIdentityTransforms() { + sql( + "CREATE TABLE %s (id INT, salary INT, dep STRING, sub_dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep)", + tableName); + + sql("INSERT INTO %s VALUES (1, 100, 'd1', 'sd1')", tableName); + + // the filter can be pushed completely because all specs include identity(dep) + checkOnlyIcebergFilters( + "dep = 'd1'" /* query predicate */, + "dep IS NOT NULL, dep = 'd1'" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1", "sd1"))); + + Table table = validationCatalog.loadTable(tableIdent); + + 
table.updateSpec().addField("sub_dep").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO %s VALUES (2, 200, 'd2', 'sd2')", tableName); + + // the filter can be pushed completely because all specs include identity(dep) + checkOnlyIcebergFilters( + "dep = 'd1'" /* query predicate */, + "dep IS NOT NULL, dep = 'd1'" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1", "sd1"))); + + table.updateSpec().removeField("sub_dep").removeField("dep").commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO %s VALUES (3, 300, 'd3', 'sd3')", tableName); + + // the filter can't be pushed completely because not all specs include identity(dep) + checkFilters( + "dep = 'd1'" /* query predicate */, + "isnotnull(dep) AND (dep = d1)" /* Spark post scan filter */, + "dep IS NOT NULL, dep = 'd1'" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1", "sd1"))); + } + + @Test + public void testFilterPushdownWithSpecEvolutionAndTruncateTransform() { + sql( + "CREATE TABLE %s (id INT, salary INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (truncate(2, dep))", + tableName); + + sql("INSERT INTO %s VALUES (1, 100, 'd1')", tableName); + + // the filter can be pushed completely because the current spec supports it + checkOnlyIcebergFilters( + "dep LIKE 'd1%'" /* query predicate */, + "dep IS NOT NULL, dep LIKE 'd1%'" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"))); + + Table table = validationCatalog.loadTable(tableIdent); + table + .updateSpec() + .removeField(Expressions.truncate("dep", 2)) + .addField(Expressions.truncate("dep", 1)) + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO %s VALUES (2, 200, 'd2')", tableName); + + // the filter can be pushed completely because both specs support it + checkOnlyIcebergFilters( + "dep LIKE 'd%'" /* query predicate */, + "dep IS NOT NULL, dep LIKE 'd%'" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"), row(2, 200, "d2"))); + + // the filter can't be pushed completely because the second spec is truncate(dep, 1) and + // the predicate literal is d1, which is two chars + checkFilters( + "dep LIKE 'd1%' AND id = 1" /* query predicate */, + "(isnotnull(id) AND StartsWith(dep, d1)) AND (id = 1)" /* Spark post scan filter */, + "dep IS NOT NULL, id IS NOT NULL, dep LIKE 'd1%', id = 1" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, "d1"))); + } + + @Test + public void testFilterPushdownWithSpecEvolutionAndTimeTransforms() { + sql( + "CREATE TABLE %s (id INT, price INT, t TIMESTAMP)" + + "USING iceberg " + + "PARTITIONED BY (hours(t))", + tableName); + + withDefaultTimeZone( + "UTC", + () -> { + sql("INSERT INTO %s VALUES (1, 100, TIMESTAMP '2021-06-30T01:00:00.000Z')", tableName); + + // the filter can be pushed completely because the current spec supports it + checkOnlyIcebergFilters( + "t < TIMESTAMP '2021-07-01T00:00:00.000Z'" /* query predicate */, + "t IS NOT NULL, t < 1625097600000000" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100, timestamp("2021-06-30T01:00:00.000Z")))); + + Table table = validationCatalog.loadTable(tableIdent); + table + .updateSpec() + .removeField(Expressions.hour("t")) + .addField(Expressions.month("t")) + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO %s VALUES (2, 200, TIMESTAMP '2021-05-30T01:00:00.000Z')", tableName); + + // the filter can be pushed completely because both specs support it + checkOnlyIcebergFilters( + "t < TIMESTAMP '2021-06-01T00:00:00.000Z'" /* query predicate */, + "t IS NOT 
NULL, t < 1622505600000000" /* Iceberg scan filters */, + ImmutableList.of(row(2, 200, timestamp("2021-05-30T01:00:00.000Z")))); + }); + } + + @Test + public void testFilterPushdownWithSpecialFloatingPointPartitionValues() { + sql( + "CREATE TABLE %s (id INT, salary DOUBLE)" + "USING iceberg " + "PARTITIONED BY (salary)", + tableName); + + sql("INSERT INTO %s VALUES (1, 100.5)", tableName); + sql("INSERT INTO %s VALUES (2, double('NaN'))", tableName); + sql("INSERT INTO %s VALUES (3, double('infinity'))", tableName); + sql("INSERT INTO %s VALUES (4, double('-infinity'))", tableName); + + checkOnlyIcebergFilters( + "salary = 100.5" /* query predicate */, + "salary IS NOT NULL, salary = 100.5" /* Iceberg scan filters */, + ImmutableList.of(row(1, 100.5))); + + checkOnlyIcebergFilters( + "salary = double('NaN')" /* query predicate */, + "salary IS NOT NULL, is_nan(salary)" /* Iceberg scan filters */, + ImmutableList.of(row(2, Double.NaN))); + + checkOnlyIcebergFilters( + "salary != double('NaN')" /* query predicate */, + "salary IS NOT NULL, NOT (is_nan(salary))" /* Iceberg scan filters */, + ImmutableList.of( + row(1, 100.5), row(3, Double.POSITIVE_INFINITY), row(4, Double.NEGATIVE_INFINITY))); + + checkOnlyIcebergFilters( + "salary = double('infinity')" /* query predicate */, + "salary IS NOT NULL, salary = Infinity" /* Iceberg scan filters */, + ImmutableList.of(row(3, Double.POSITIVE_INFINITY))); + + checkOnlyIcebergFilters( + "salary = double('-infinity')" /* query predicate */, + "salary IS NOT NULL, salary = -Infinity" /* Iceberg scan filters */, + ImmutableList.of(row(4, Double.NEGATIVE_INFINITY))); + } + + private void checkOnlyIcebergFilters( + String predicate, String icebergFilters, List expectedRows) { + + checkFilters(predicate, null, icebergFilters, expectedRows); + } + + private void checkFilters( + String predicate, String sparkFilter, String icebergFilters, List expectedRows) { + + Action check = + () -> { + assertEquals( + "Rows must match", + expectedRows, + sql("SELECT * FROM %s WHERE %s ORDER BY id", tableName, predicate)); + }; + SparkPlan sparkPlan = executeAndKeepPlan(check); + String planAsString = sparkPlan.toString().replaceAll("#(\\d+L?)", ""); + + if (sparkFilter != null) { + Assertions.assertThat(planAsString) + .as("Post scan filter should match") + .contains("Filter (" + sparkFilter + ")"); + } else { + Assertions.assertThat(planAsString) + .as("Should be no post scan filter") + .doesNotContain("Filter ("); + } + + Assertions.assertThat(planAsString) + .as("Pushed filters must match") + .contains("[filters=" + icebergFilters + ","); + } + + private Timestamp timestamp(String timestampAsString) { + return Timestamp.from(Instant.parse(timestampAsString)); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestNamespaceSQL.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestNamespaceSQL.java new file mode 100644 index 000000000000..6c29ea4442ef --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestNamespaceSQL.java @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import java.io.File; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.exceptions.NamespaceNotEmptyException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Test; + +public class TestNamespaceSQL extends SparkCatalogTestBase { + private static final Namespace NS = Namespace.of("db"); + + private final String fullNamespace; + private final boolean isHadoopCatalog; + + public TestNamespaceSQL(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + this.fullNamespace = ("spark_catalog".equals(catalogName) ? "" : catalogName + ".") + NS; + this.isHadoopCatalog = "testhadoop".equals(catalogName); + } + + @After + public void cleanNamespaces() { + sql("DROP TABLE IF EXISTS %s.table", fullNamespace); + sql("DROP NAMESPACE IF EXISTS %s", fullNamespace); + } + + @Test + public void testCreateNamespace() { + Assert.assertFalse( + "Namespace should not already exist", validationNamespaceCatalog.namespaceExists(NS)); + + sql("CREATE NAMESPACE %s", fullNamespace); + + Assert.assertTrue("Namespace should exist", validationNamespaceCatalog.namespaceExists(NS)); + } + + @Test + public void testDefaultNamespace() { + Assume.assumeFalse("Hadoop has no default namespace configured", isHadoopCatalog); + + sql("USE %s", catalogName); + + Object[] current = Iterables.getOnlyElement(sql("SHOW CURRENT NAMESPACE")); + Assert.assertEquals("Should use the current catalog", current[0], catalogName); + Assert.assertEquals("Should use the configured default namespace", current[1], "default"); + } + + @Test + public void testDropEmptyNamespace() { + Assert.assertFalse( + "Namespace should not already exist", validationNamespaceCatalog.namespaceExists(NS)); + + sql("CREATE NAMESPACE %s", fullNamespace); + + Assert.assertTrue("Namespace should exist", validationNamespaceCatalog.namespaceExists(NS)); + + sql("DROP NAMESPACE %s", fullNamespace); + + Assert.assertFalse( + "Namespace should have been dropped", validationNamespaceCatalog.namespaceExists(NS)); + } + + @Test + public void testDropNonEmptyNamespace() { + Assume.assumeFalse("Session catalog has flaky behavior", "spark_catalog".equals(catalogName)); + + Assert.assertFalse( + "Namespace should not already exist", validationNamespaceCatalog.namespaceExists(NS)); + + sql("CREATE NAMESPACE %s", fullNamespace); + sql("CREATE TABLE %s.table (id bigint) USING iceberg", fullNamespace); + + Assert.assertTrue("Namespace should exist", validationNamespaceCatalog.namespaceExists(NS)); + Assert.assertTrue( + "Table should exist", validationCatalog.tableExists(TableIdentifier.of(NS, "table"))); + + 
AssertHelpers.assertThrows( + "Should fail if trying to delete a non-empty namespace", + NamespaceNotEmptyException.class, + "Namespace db is not empty.", + () -> sql("DROP NAMESPACE %s", fullNamespace)); + + sql("DROP TABLE %s.table", fullNamespace); + } + + @Test + public void testListTables() { + Assert.assertFalse( + "Namespace should not already exist", validationNamespaceCatalog.namespaceExists(NS)); + + sql("CREATE NAMESPACE %s", fullNamespace); + + Assert.assertTrue("Namespace should exist", validationNamespaceCatalog.namespaceExists(NS)); + + List rows = sql("SHOW TABLES IN %s", fullNamespace); + Assert.assertEquals("Should not list any tables", 0, rows.size()); + + sql("CREATE TABLE %s.table (id bigint) USING iceberg", fullNamespace); + + Object[] row = Iterables.getOnlyElement(sql("SHOW TABLES IN %s", fullNamespace)); + Assert.assertEquals("Namespace should match", "db", row[0]); + Assert.assertEquals("Table name should match", "table", row[1]); + } + + @Test + public void testListNamespace() { + Assert.assertFalse( + "Namespace should not already exist", validationNamespaceCatalog.namespaceExists(NS)); + + sql("CREATE NAMESPACE %s", fullNamespace); + + Assert.assertTrue("Namespace should exist", validationNamespaceCatalog.namespaceExists(NS)); + + List namespaces = sql("SHOW NAMESPACES IN %s", catalogName); + + if (isHadoopCatalog) { + Assert.assertEquals("Should have 1 namespace", 1, namespaces.size()); + Set namespaceNames = + namespaces.stream().map(arr -> arr[0].toString()).collect(Collectors.toSet()); + Assert.assertEquals("Should have only db namespace", ImmutableSet.of("db"), namespaceNames); + } else { + Assert.assertEquals("Should have 2 namespaces", 2, namespaces.size()); + Set namespaceNames = + namespaces.stream().map(arr -> arr[0].toString()).collect(Collectors.toSet()); + Assert.assertEquals( + "Should have default and db namespaces", + ImmutableSet.of("default", "db"), + namespaceNames); + } + + List nestedNamespaces = sql("SHOW NAMESPACES IN %s", fullNamespace); + + Set nestedNames = + nestedNamespaces.stream().map(arr -> arr[0].toString()).collect(Collectors.toSet()); + Assert.assertEquals("Should not have nested namespaces", ImmutableSet.of(), nestedNames); + } + + @Test + public void testCreateNamespaceWithMetadata() { + Assume.assumeFalse("HadoopCatalog does not support namespace metadata", isHadoopCatalog); + + Assert.assertFalse( + "Namespace should not already exist", validationNamespaceCatalog.namespaceExists(NS)); + + sql("CREATE NAMESPACE %s WITH PROPERTIES ('prop'='value')", fullNamespace); + + Assert.assertTrue("Namespace should exist", validationNamespaceCatalog.namespaceExists(NS)); + + Map nsMetadata = validationNamespaceCatalog.loadNamespaceMetadata(NS); + + Assert.assertEquals( + "Namespace should have expected prop value", "value", nsMetadata.get("prop")); + } + + @Test + public void testCreateNamespaceWithComment() { + Assume.assumeFalse("HadoopCatalog does not support namespace metadata", isHadoopCatalog); + + Assert.assertFalse( + "Namespace should not already exist", validationNamespaceCatalog.namespaceExists(NS)); + + sql("CREATE NAMESPACE %s COMMENT 'namespace doc'", fullNamespace); + + Assert.assertTrue("Namespace should exist", validationNamespaceCatalog.namespaceExists(NS)); + + Map nsMetadata = validationNamespaceCatalog.loadNamespaceMetadata(NS); + + Assert.assertEquals( + "Namespace should have expected comment", "namespace doc", nsMetadata.get("comment")); + } + + @Test + public void testCreateNamespaceWithLocation() throws 
Exception { + Assume.assumeFalse("HadoopCatalog does not support namespace locations", isHadoopCatalog); + + Assert.assertFalse( + "Namespace should not already exist", validationNamespaceCatalog.namespaceExists(NS)); + + File location = temp.newFile(); + Assert.assertTrue(location.delete()); + + sql("CREATE NAMESPACE %s LOCATION '%s'", fullNamespace, location); + + Assert.assertTrue("Namespace should exist", validationNamespaceCatalog.namespaceExists(NS)); + + Map nsMetadata = validationNamespaceCatalog.loadNamespaceMetadata(NS); + + Assert.assertEquals( + "Namespace should have expected location", + "file:" + location.getPath(), + nsMetadata.get("location")); + } + + @Test + public void testSetProperties() { + Assume.assumeFalse("HadoopCatalog does not support namespace metadata", isHadoopCatalog); + + Assert.assertFalse( + "Namespace should not already exist", validationNamespaceCatalog.namespaceExists(NS)); + + sql("CREATE NAMESPACE %s", fullNamespace); + + Assert.assertTrue("Namespace should exist", validationNamespaceCatalog.namespaceExists(NS)); + + Map defaultMetadata = validationNamespaceCatalog.loadNamespaceMetadata(NS); + Assert.assertFalse( + "Default metadata should not have custom property", defaultMetadata.containsKey("prop")); + + sql("ALTER NAMESPACE %s SET PROPERTIES ('prop'='value')", fullNamespace); + + Map nsMetadata = validationNamespaceCatalog.loadNamespaceMetadata(NS); + + Assert.assertEquals( + "Namespace should have expected prop value", "value", nsMetadata.get("prop")); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java new file mode 100644 index 000000000000..a18bd997250b --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import java.util.Map; + +public class TestPartitionedWrites extends PartitionedWritesTestBase { + + public TestPartitionedWrites( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesAsSelect.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesAsSelect.java new file mode 100644 index 000000000000..3ffd38b83c3b --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesAsSelect.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import java.util.List; +import java.util.stream.IntStream; +import org.apache.iceberg.spark.IcebergSpark; +import org.apache.iceberg.spark.SparkTestBaseWithCatalog; +import org.apache.spark.sql.types.DataTypes; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestPartitionedWritesAsSelect extends SparkTestBaseWithCatalog { + + private final String targetTable = tableName("target_table"); + + @Before + public void createTables() { + sql( + "CREATE TABLE %s (id bigint, data string, category string, ts timestamp) USING iceberg", + tableName); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s", targetTable); + } + + @Test + public void testInsertAsSelectAppend() { + insertData(3); + List expected = currentData(); + + sql( + "CREATE TABLE %s (id bigint, data string, category string, ts timestamp)" + + "USING iceberg PARTITIONED BY (days(ts), category)", + targetTable); + + sql( + "INSERT INTO %s SELECT id, data, category, ts FROM %s ORDER BY ts,category", + targetTable, tableName); + Assert.assertEquals( + "Should have 15 rows after insert", + 3 * 5L, + scalarSql("SELECT count(*) FROM %s", targetTable)); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", targetTable)); + } + + @Test + public void testInsertAsSelectWithBucket() { + insertData(3); + List expected = currentData(); + + sql( + "CREATE TABLE %s (id bigint, data string, category string, ts timestamp)" + + "USING iceberg PARTITIONED BY (bucket(8, data))", + targetTable); + + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket8", DataTypes.StringType, 8); + sql( + "INSERT INTO %s SELECT id, data, category, ts FROM %s ORDER BY iceberg_bucket8(data)", + targetTable, tableName); + Assert.assertEquals( + "Should have 15 rows after insert", + 3 * 5L, + scalarSql("SELECT count(*) FROM %s", targetTable)); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", targetTable)); + } + + @Test + public void testInsertAsSelectWithTruncate() { + insertData(3); + List expected = currentData(); + + sql( + "CREATE TABLE %s (id bigint, data string, category string, ts timestamp)" + + "USING iceberg PARTITIONED BY (truncate(data, 4), truncate(id, 4))", + targetTable); + + IcebergSpark.registerTruncateUDF(spark, "iceberg_truncate_string4", DataTypes.StringType, 4); + IcebergSpark.registerTruncateUDF(spark, "iceberg_truncate_long4", DataTypes.LongType, 4); + sql( + "INSERT INTO %s SELECT id, data, category, ts FROM %s " + + "ORDER BY iceberg_truncate_string4(data),iceberg_truncate_long4(id)", + targetTable, tableName); + Assert.assertEquals( + "Should have 15 rows after insert", + 3 
* 5L, + scalarSql("SELECT count(*) FROM %s", targetTable)); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", targetTable)); + } + + private void insertData(int repeatCounter) { + IntStream.range(0, repeatCounter) + .forEach( + i -> { + sql( + "INSERT INTO %s VALUES (13, '1', 'bgd16', timestamp('2021-11-10 11:20:10'))," + + "(21, '2', 'bgd13', timestamp('2021-11-10 11:20:10')), " + + "(12, '3', 'bgd14', timestamp('2021-11-10 11:20:10'))," + + "(222, '3', 'bgd15', timestamp('2021-11-10 11:20:10'))," + + "(45, '4', 'bgd16', timestamp('2021-11-10 11:20:10'))", + tableName); + }); + } + + private List currentData() { + return rowsToJava(spark.sql("SELECT * FROM " + tableName + " order by id").collectAsList()); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToBranch.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToBranch.java new file mode 100644 index 000000000000..c6cde7a5524e --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToBranch.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import java.util.Map; +import org.apache.iceberg.Table; +import org.junit.Before; + +public class TestPartitionedWritesToBranch extends PartitionedWritesTestBase { + + private static final String BRANCH = "test"; + + public TestPartitionedWritesToBranch( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Before + @Override + public void createTables() { + super.createTables(); + Table table = validationCatalog.loadTable(tableIdent); + table.manageSnapshots().createBranch(BRANCH, table.currentSnapshot().snapshotId()).commit(); + sql("REFRESH TABLE " + tableName); + } + + @Override + protected String commitTarget() { + return String.format("%s.branch_%s", tableName, BRANCH); + } + + @Override + protected String selectTarget() { + return String.format("%s VERSION AS OF '%s'", tableName, BRANCH); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToWapBranch.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToWapBranch.java new file mode 100644 index 000000000000..5dde5f33d965 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToWapBranch.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import java.util.List; +import java.util.Map; +import java.util.UUID; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.assertj.core.api.Assertions; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class TestPartitionedWritesToWapBranch extends PartitionedWritesTestBase { + + private static final String BRANCH = "test"; + + public TestPartitionedWritesToWapBranch( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Before + @Override + public void createTables() { + spark.conf().set(SparkSQLProperties.WAP_BRANCH, BRANCH); + sql( + "CREATE TABLE %s (id bigint, data string) USING iceberg PARTITIONED BY (truncate(id, 3)) OPTIONS (%s = 'true')", + tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED); + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); + } + + @After + @Override + public void removeTables() { + super.removeTables(); + spark.conf().unset(SparkSQLProperties.WAP_BRANCH); + spark.conf().unset(SparkSQLProperties.WAP_ID); + } + + @Override + protected String commitTarget() { + return tableName; + } + + @Override + protected String selectTarget() { + return String.format("%s VERSION AS OF '%s'", tableName, BRANCH); + } + + @Test + public void testBranchAndWapBranchCannotBothBeSetForWrite() { + Table table = validationCatalog.loadTable(tableIdent); + table.manageSnapshots().createBranch("test2", table.refs().get(BRANCH).snapshotId()).commit(); + sql("REFRESH TABLE " + tableName); + Assertions.assertThatThrownBy( + () -> sql("INSERT INTO %s.branch_test2 VALUES (4, 'd')", tableName)) + .isInstanceOf(ValidationException.class) + .hasMessage( + "Cannot write to both branch and WAP branch, but got branch [test2] and WAP branch [%s]", + BRANCH); + } + + @Test + public void testWapIdAndWapBranchCannotBothBeSetForWrite() { + String wapId = UUID.randomUUID().toString(); + spark.conf().set(SparkSQLProperties.WAP_ID, wapId); + Assertions.assertThatThrownBy(() -> sql("INSERT INTO %s VALUES (4, 'd')", tableName)) + .isInstanceOf(ValidationException.class) + .hasMessage( + "Cannot set both WAP ID and branch, but got ID [%s] and branch [%s]", wapId, BRANCH); + } + + @Override + protected void assertPartitionMetadata( + String tableName, List expected, String... selectPartitionColumns) { + // Cannot read from the .partitions table newly written data into the WAP branch. See + // https://github.com/apache/iceberg/issues/7297 for more details. 
+ } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestRefreshTable.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestRefreshTable.java new file mode 100644 index 000000000000..7da2dc0882db --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestRefreshTable.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class TestRefreshTable extends SparkCatalogTestBase { + + public TestRefreshTable(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Before + public void createTables() { + sql("CREATE TABLE %s (key int, value int) USING iceberg", tableName); + sql("INSERT INTO %s VALUES (1,1)", tableName); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testRefreshCommand() { + // We are not allowed to change the session catalog after it has been initialized, so build a + // new one + if (catalogName.equals(SparkCatalogConfig.SPARK.catalogName()) + || catalogName.equals(SparkCatalogConfig.HADOOP.catalogName())) { + spark.conf().set("spark.sql.catalog." 
+ catalogName + ".cache-enabled", true); + spark = spark.cloneSession(); + } + + List originalExpected = ImmutableList.of(row(1, 1)); + List originalActual = sql("SELECT * FROM %s", tableName); + assertEquals("Table should start as expected", originalExpected, originalActual); + + // Modify table outside of spark, it should be cached so Spark should see the same value after + // mutation + Table table = validationCatalog.loadTable(tableIdent); + DataFile file = table.currentSnapshot().addedDataFiles(table.io()).iterator().next(); + table.newDelete().deleteFile(file).commit(); + + List cachedActual = sql("SELECT * FROM %s", tableName); + assertEquals("Cached table should be unchanged", originalExpected, cachedActual); + + // Refresh the Spark catalog, should be empty + sql("REFRESH TABLE %s", tableName); + List refreshedExpected = ImmutableList.of(); + List refreshedActual = sql("SELECT * FROM %s", tableName); + assertEquals("Refreshed table should be empty", refreshedExpected, refreshedActual); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSelect.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSelect.java new file mode 100644 index 000000000000..e08bc4574dbf --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSelect.java @@ -0,0 +1,490 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Table; +import org.apache.iceberg.events.Listeners; +import org.apache.iceberg.events.ScanEvent; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.SparkReadOptions; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.assertj.core.api.Assertions; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Before; +import org.junit.Test; + +public class TestSelect extends SparkCatalogTestBase { + private int scanEventCount = 0; + private ScanEvent lastScanEvent = null; + private String binaryTableName = tableName("binary_table"); + + public TestSelect(String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + + // register a scan event listener to validate pushdown + Listeners.register( + event -> { + scanEventCount += 1; + lastScanEvent = event; + }, + ScanEvent.class); + } + + @Before + public void createTables() { + sql("CREATE TABLE %s (id bigint, data string, float float) USING iceberg", tableName); + sql("INSERT INTO %s VALUES (1, 'a', 1.0), (2, 'b', 2.0), (3, 'c', float('NaN'))", tableName); + + this.scanEventCount = 0; + this.lastScanEvent = null; + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s", binaryTableName); + } + + @Test + public void testSelect() { + List expected = + ImmutableList.of(row(1L, "a", 1.0F), row(2L, "b", 2.0F), row(3L, "c", Float.NaN)); + + assertEquals("Should return all expected rows", expected, sql("SELECT * FROM %s", tableName)); + } + + @Test + public void testSelectRewrite() { + List expected = ImmutableList.of(row(3L, "c", Float.NaN)); + + assertEquals( + "Should return all expected rows", + expected, + sql("SELECT * FROM %s where float = float('NaN')", tableName)); + + Assert.assertEquals("Should create only one scan", 1, scanEventCount); + Assert.assertEquals( + "Should push down expected filter", + "(float IS NOT NULL AND is_nan(float))", + Spark3Util.describe(lastScanEvent.filter())); + } + + @Test + public void testProjection() { + List expected = ImmutableList.of(row(1L), row(2L), row(3L)); + + assertEquals("Should return all expected rows", expected, sql("SELECT id FROM %s", tableName)); + + Assert.assertEquals("Should create only one scan", 1, scanEventCount); + Assert.assertEquals( + "Should not push down a filter", Expressions.alwaysTrue(), lastScanEvent.filter()); + Assert.assertEquals( + "Should project only the id column", + validationCatalog.loadTable(tableIdent).schema().select("id").asStruct(), + lastScanEvent.projection().asStruct()); + } + + @Test + public void testExpressionPushdown() { + List expected = ImmutableList.of(row("b")); + + assertEquals( + "Should return all expected rows", + expected, + sql("SELECT data FROM %s WHERE id = 2", tableName)); + + Assert.assertEquals("Should create only one scan", 1, scanEventCount); + Assert.assertEquals( + "Should push down expected filter", + "(id IS NOT NULL AND id = 2)", + 
Spark3Util.describe(lastScanEvent.filter())); + Assert.assertEquals( + "Should project only id and data columns", + validationCatalog.loadTable(tableIdent).schema().select("id", "data").asStruct(), + lastScanEvent.projection().asStruct()); + } + + @Test + public void testMetadataTables() { + Assume.assumeFalse( + "Spark session catalog does not support metadata tables", + "spark_catalog".equals(catalogName)); + + assertEquals( + "Snapshot metadata table", + ImmutableList.of(row(ANY, ANY, null, "append", ANY, ANY)), + sql("SELECT * FROM %s.snapshots", tableName)); + } + + @Test + public void testSnapshotInTableName() { + Assume.assumeFalse( + "Spark session catalog does not support extended table names", + "spark_catalog".equals(catalogName)); + + // get the snapshot ID of the last write and get the current row set as expected + long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + List expected = sql("SELECT * FROM %s", tableName); + + // create a second snapshot + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + + String prefix = "snapshot_id_"; + // read the table at the snapshot + List actual = sql("SELECT * FROM %s.%s", tableName, prefix + snapshotId); + assertEquals("Snapshot at specific ID, prefix " + prefix, expected, actual); + + // read the table using DataFrameReader option + Dataset df = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SNAPSHOT_ID, snapshotId) + .load(tableName); + List fromDF = rowsToJava(df.collectAsList()); + assertEquals("Snapshot at specific ID " + snapshotId, expected, fromDF); + } + + @Test + public void testTimestampInTableName() { + Assume.assumeFalse( + "Spark session catalog does not support extended table names", + "spark_catalog".equals(catalogName)); + + // get a timestamp just after the last write and get the current row set as expected + long snapshotTs = validationCatalog.loadTable(tableIdent).currentSnapshot().timestampMillis(); + long timestamp = waitUntilAfter(snapshotTs + 2); + List expected = sql("SELECT * FROM %s", tableName); + + // create a second snapshot + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + + String prefix = "at_timestamp_"; + // read the table at the snapshot + List actual = sql("SELECT * FROM %s.%s", tableName, prefix + timestamp); + assertEquals("Snapshot at timestamp, prefix " + prefix, expected, actual); + + // read the table using DataFrameReader option + Dataset df = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.AS_OF_TIMESTAMP, timestamp) + .load(tableName); + List fromDF = rowsToJava(df.collectAsList()); + assertEquals("Snapshot at timestamp " + timestamp, expected, fromDF); + } + + @Test + public void testVersionAsOf() { + // get the snapshot ID of the last write and get the current row set as expected + long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + List expected = sql("SELECT * FROM %s", tableName); + + // create a second snapshot + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + + // read the table at the snapshot + List actual1 = sql("SELECT * FROM %s VERSION AS OF %s", tableName, snapshotId); + assertEquals("Snapshot at specific ID", expected, actual1); + + // read the table at the snapshot + // HIVE time travel syntax + List actual2 = + sql("SELECT * FROM %s FOR SYSTEM_VERSION AS OF %s", tableName, snapshotId); + assertEquals("Snapshot at specific ID", expected, actual2); + + // read the table using DataFrameReader 
option: versionAsOf + Dataset df = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.VERSION_AS_OF, snapshotId) + .load(tableName); + List fromDF = rowsToJava(df.collectAsList()); + assertEquals("Snapshot at specific ID " + snapshotId, expected, fromDF); + } + + @Test + public void testTagReferenceAsOf() { + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + table.manageSnapshots().createTag("test_tag", snapshotId).commit(); + + // create a second snapshot, read the table at the snapshot + List expected = sql("SELECT * FROM %s", tableName); + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + List actual1 = sql("SELECT * FROM %s VERSION AS OF 'test_tag'", tableName); + assertEquals("Snapshot at specific tag reference name", expected, actual1); + + // read the table at the snapshot + // HIVE time travel syntax + List actual2 = sql("SELECT * FROM %s FOR SYSTEM_VERSION AS OF 'test_tag'", tableName); + assertEquals("Snapshot at specific tag reference name", expected, actual2); + + // read the table using DataFrameReader option: branch + Dataset df = + spark.read().format("iceberg").option(SparkReadOptions.TAG, "test_tag").load(tableName); + List fromDF = rowsToJava(df.collectAsList()); + assertEquals("Snapshot at specific tag reference name", expected, fromDF); + } + + @Test + public void testUseSnapshotIdForTagReferenceAsOf() { + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId1 = table.currentSnapshot().snapshotId(); + + // create a second snapshot, read the table at the snapshot + List actual = sql("SELECT * FROM %s", tableName); + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + + table.refresh(); + long snapshotId2 = table.currentSnapshot().snapshotId(); + table.manageSnapshots().createTag(Long.toString(snapshotId1), snapshotId2).commit(); + + // currently Spark version travel ignores the type of the AS OF + // this means if a tag name matches a snapshot ID, it will always choose snapshotID to travel + // to. 
+ List travelWithStringResult = + sql("SELECT * FROM %s VERSION AS OF '%s'", tableName, snapshotId1); + assertEquals("Snapshot at specific tag reference name", actual, travelWithStringResult); + + List travelWithLongResult = + sql("SELECT * FROM %s VERSION AS OF %s", tableName, snapshotId1); + assertEquals("Snapshot at specific tag reference name", actual, travelWithLongResult); + } + + @Test + public void testBranchReferenceAsOf() { + Table table = validationCatalog.loadTable(tableIdent); + long snapshotId = table.currentSnapshot().snapshotId(); + table.manageSnapshots().createBranch("test_branch", snapshotId).commit(); + + // create a second snapshot, read the table at the snapshot + List expected = sql("SELECT * FROM %s", tableName); + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + List actual1 = sql("SELECT * FROM %s VERSION AS OF 'test_branch'", tableName); + assertEquals("Snapshot at specific branch reference name", expected, actual1); + + // read the table at the snapshot + // HIVE time travel syntax + List actual2 = + sql("SELECT * FROM %s FOR SYSTEM_VERSION AS OF 'test_branch'", tableName); + assertEquals("Snapshot at specific branch reference name", expected, actual2); + + // read the table using DataFrameReader option: branch + Dataset df = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.BRANCH, "test_branch") + .load(tableName); + List fromDF = rowsToJava(df.collectAsList()); + assertEquals("Snapshot at specific branch reference name", expected, fromDF); + } + + @Test + public void testUnknownReferenceAsOf() { + Assertions.assertThatThrownBy( + () -> sql("SELECT * FROM %s VERSION AS OF 'test_unknown'", tableName)) + .hasMessageContaining("Cannot find matching snapshot ID or reference name for version") + .isInstanceOf(ValidationException.class); + } + + @Test + public void testTimestampAsOf() { + long snapshotTs = validationCatalog.loadTable(tableIdent).currentSnapshot().timestampMillis(); + long timestamp = waitUntilAfter(snapshotTs + 1000); + waitUntilAfter(timestamp + 1000); + // AS OF expects the timestamp if given in long format will be of seconds precision + long timestampInSeconds = TimeUnit.MILLISECONDS.toSeconds(timestamp); + SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); + String formattedDate = sdf.format(new Date(timestamp)); + + List expected = sql("SELECT * FROM %s", tableName); + + // create a second snapshot + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + + // read the table at the timestamp in long format i.e 1656507980463. + List actualWithLongFormat = + sql("SELECT * FROM %s TIMESTAMP AS OF %s", tableName, timestampInSeconds); + assertEquals("Snapshot at timestamp", expected, actualWithLongFormat); + + // read the table at the timestamp in date format i.e 2022-06-29 18:40:37 + List actualWithDateFormat = + sql("SELECT * FROM %s TIMESTAMP AS OF '%s'", tableName, formattedDate); + assertEquals("Snapshot at timestamp", expected, actualWithDateFormat); + + // HIVE time travel syntax + // read the table at the timestamp in long format i.e 1656507980463. 
+ List actualWithLongFormatInHiveSyntax = + sql("SELECT * FROM %s FOR SYSTEM_TIME AS OF %s", tableName, timestampInSeconds); + assertEquals("Snapshot at specific ID", expected, actualWithLongFormatInHiveSyntax); + + // read the table at the timestamp in date format i.e 2022-06-29 18:40:37 + List actualWithDateFormatInHiveSyntax = + sql("SELECT * FROM %s FOR SYSTEM_TIME AS OF '%s'", tableName, formattedDate); + assertEquals("Snapshot at specific ID", expected, actualWithDateFormatInHiveSyntax); + + // read the table using DataFrameReader option + Dataset df = + spark + .read() + .format("iceberg") + .option(SparkReadOptions.TIMESTAMP_AS_OF, formattedDate) + .load(tableName); + List fromDF = rowsToJava(df.collectAsList()); + assertEquals("Snapshot at timestamp " + timestamp, expected, fromDF); + } + + @Test + public void testInvalidTimeTravelBasedOnBothAsOfAndTableIdentifier() { + // get the snapshot ID of the last write + long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + // get a timestamp just after the last write + long timestamp = + validationCatalog.loadTable(tableIdent).currentSnapshot().timestampMillis() + 2; + + String timestampPrefix = "at_timestamp_"; + String snapshotPrefix = "snapshot_id_"; + + // create a second snapshot + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + + // using snapshot in table identifier and VERSION AS OF + AssertHelpers.assertThrows( + "Cannot do time-travel based on both table identifier and AS OF", + IllegalArgumentException.class, + "Cannot do time-travel based on both table identifier and AS OF", + () -> { + sql( + "SELECT * FROM %s.%s VERSION AS OF %s", + tableName, snapshotPrefix + snapshotId, snapshotId); + }); + + // using snapshot in table identifier and TIMESTAMP AS OF + AssertHelpers.assertThrows( + "Cannot do time-travel based on both table identifier and AS OF", + IllegalArgumentException.class, + "Cannot do time-travel based on both table identifier and AS OF", + () -> { + sql( + "SELECT * FROM %s.%s VERSION AS OF %s", + tableName, timestampPrefix + timestamp, snapshotId); + }); + + // using timestamp in table identifier and VERSION AS OF + AssertHelpers.assertThrows( + "Cannot do time-travel based on both table identifier and AS OF", + IllegalArgumentException.class, + "Cannot do time-travel based on both table identifier and AS OF", + () -> { + sql( + "SELECT * FROM %s.%s TIMESTAMP AS OF %s", + tableName, snapshotPrefix + snapshotId, timestamp); + }); + + // using timestamp in table identifier and TIMESTAMP AS OF + AssertHelpers.assertThrows( + "Cannot do time-travel based on both table identifier and AS OF", + IllegalArgumentException.class, + "Cannot do time-travel based on both table identifier and AS OF", + () -> { + sql( + "SELECT * FROM %s.%s TIMESTAMP AS OF %s", + tableName, timestampPrefix + timestamp, timestamp); + }); + } + + @Test + public void testSpecifySnapshotAndTimestamp() { + // get the snapshot ID of the last write + long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + // get a timestamp just after the last write + long timestamp = + validationCatalog.loadTable(tableIdent).currentSnapshot().timestampMillis() + 2; + + // create a second snapshot + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0)", tableName); + + AssertHelpers.assertThrows( + "Should not be able to specify both snapshot id and timestamp", + IllegalArgumentException.class, + String.format( + "Can specify only one of snapshot-id (%s), 
as-of-timestamp (%s)", + snapshotId, timestamp), + () -> { + spark + .read() + .format("iceberg") + .option(SparkReadOptions.SNAPSHOT_ID, snapshotId) + .option(SparkReadOptions.AS_OF_TIMESTAMP, timestamp) + .load(tableName) + .collectAsList(); + }); + } + + @Test + public void testBinaryInFilter() { + sql("CREATE TABLE %s (id bigint, binary binary) USING iceberg", binaryTableName); + sql("INSERT INTO %s VALUES (1, X''), (2, X'1111'), (3, X'11')", binaryTableName); + List expected = ImmutableList.of(row(2L, new byte[] {0x11, 0x11})); + + assertEquals( + "Should return all expected rows", + expected, + sql("SELECT id, binary FROM %s where binary > X'11'", binaryTableName)); + } + + @Test + public void testComplexTypeFilter() { + String complexTypeTableName = tableName("complex_table"); + sql( + "CREATE TABLE %s (id INT, complex STRUCT) USING iceberg", + complexTypeTableName); + sql( + "INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", 3, \"c2\", \"v1\"))", + complexTypeTableName); + sql( + "INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2, \"c2\", \"v2\"))", + complexTypeTableName); + + List result = + sql( + "SELECT id FROM %s WHERE complex = named_struct(\"c1\", 3, \"c2\", \"v1\")", + complexTypeTableName); + + assertEquals("Should return all expected rows", ImmutableList.of(row(1)), result); + sql("DROP TABLE IF EXISTS %s", complexTypeTableName); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkBucketFunction.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkBucketFunction.java new file mode 100644 index 000000000000..c9c8c02b417c --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkBucketFunction.java @@ -0,0 +1,361 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.expressions.Literal; +import org.apache.iceberg.relocated.com.google.common.io.BaseEncoding; +import org.apache.iceberg.spark.SparkTestBaseWithCatalog; +import org.apache.iceberg.spark.functions.BucketFunction; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.types.DataTypes; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestSparkBucketFunction extends SparkTestBaseWithCatalog { + @Before + public void useCatalog() { + sql("USE %s", catalogName); + } + + @Test + public void testSpecValues() { + Assert.assertEquals( + "Spec example: hash(34) = 2017239379", + 2017239379, + new BucketFunction.BucketInt(DataTypes.IntegerType).hash(34)); + + Assert.assertEquals( + "Spec example: hash(34L) = 2017239379", + 2017239379, + new BucketFunction.BucketLong(DataTypes.LongType).hash(34L)); + + Assert.assertEquals( + "Spec example: hash(decimal2(14.20)) = -500754589", + -500754589, + new BucketFunction.BucketDecimal(DataTypes.createDecimalType(9, 2)) + .hash(new BigDecimal("14.20"))); + + Literal date = Literal.of("2017-11-16").to(Types.DateType.get()); + Assert.assertEquals( + "Spec example: hash(2017-11-16) = -653330422", + -653330422, + new BucketFunction.BucketInt(DataTypes.DateType).hash(date.value())); + + Literal timestampVal = + Literal.of("2017-11-16T22:31:08").to(Types.TimestampType.withoutZone()); + Assert.assertEquals( + "Spec example: hash(2017-11-16T22:31:08) = -2047944441", + -2047944441, + new BucketFunction.BucketLong(DataTypes.TimestampType).hash(timestampVal.value())); + + Assert.assertEquals( + "Spec example: hash(\"iceberg\") = 1210000089", + 1210000089, + new BucketFunction.BucketString().hash("iceberg")); + + ByteBuffer bytes = ByteBuffer.wrap(new byte[] {0, 1, 2, 3}); + Assert.assertEquals( + "Spec example: hash([00 01 02 03]) = -188683207", + -188683207, + new BucketFunction.BucketBinary().hash(bytes)); + } + + @Test + public void testBucketIntegers() { + Assert.assertEquals( + "Byte type should bucket similarly to integer", + 3, + scalarSql("SELECT system.bucket(10, 8Y)")); + Assert.assertEquals( + "Short type should bucket similarly to integer", + 3, + scalarSql("SELECT system.bucket(10, 8S)")); + // Integers + Assert.assertEquals(3, scalarSql("SELECT system.bucket(10, 8)")); + Assert.assertEquals(79, scalarSql("SELECT system.bucket(100, 34)")); + Assert.assertNull(scalarSql("SELECT system.bucket(1, CAST(null AS INT))")); + } + + @Test + public void testBucketDates() { + Assert.assertEquals(3, scalarSql("SELECT system.bucket(10, date('1970-01-09'))")); + Assert.assertEquals(79, scalarSql("SELECT system.bucket(100, date('1970-02-04'))")); + Assert.assertNull(scalarSql("SELECT system.bucket(1, CAST(null AS DATE))")); + } + + @Test + public void testBucketLong() { + Assert.assertEquals(79, scalarSql("SELECT system.bucket(100, 34L)")); + Assert.assertEquals(76, scalarSql("SELECT system.bucket(100, 0L)")); + Assert.assertEquals(97, scalarSql("SELECT system.bucket(100, -34L)")); + Assert.assertEquals(0, scalarSql("SELECT system.bucket(2, -1L)")); + Assert.assertNull(scalarSql("SELECT system.bucket(2, CAST(null AS LONG))")); + } + + @Test + public void testBucketDecimal() { + Assert.assertEquals(56, scalarSql("SELECT 
system.bucket(64, CAST('12.34' as DECIMAL(9, 2)))")); + Assert.assertEquals(13, scalarSql("SELECT system.bucket(18, CAST('12.30' as DECIMAL(9, 2)))")); + Assert.assertEquals(2, scalarSql("SELECT system.bucket(16, CAST('12.999' as DECIMAL(9, 3)))")); + Assert.assertEquals(21, scalarSql("SELECT system.bucket(32, CAST('0.05' as DECIMAL(5, 2)))")); + Assert.assertEquals(85, scalarSql("SELECT system.bucket(128, CAST('0.05' as DECIMAL(9, 2)))")); + Assert.assertEquals(3, scalarSql("SELECT system.bucket(18, CAST('0.05' as DECIMAL(9, 2)))")); + + Assert.assertNull( + "Null input should return null", + scalarSql("SELECT system.bucket(2, CAST(null AS decimal))")); + } + + @Test + public void testBucketTimestamp() { + Assert.assertEquals( + 99, scalarSql("SELECT system.bucket(100, TIMESTAMP '1997-01-01 00:00:00 UTC+00:00')")); + Assert.assertEquals( + 85, scalarSql("SELECT system.bucket(100, TIMESTAMP '1997-01-31 09:26:56 UTC+00:00')")); + Assert.assertEquals( + 62, scalarSql("SELECT system.bucket(100, TIMESTAMP '2022-08-08 00:00:00 UTC+00:00')")); + Assert.assertNull(scalarSql("SELECT system.bucket(2, CAST(null AS timestamp))")); + } + + @Test + public void testBucketString() { + Assert.assertEquals(4, scalarSql("SELECT system.bucket(5, 'abcdefg')")); + Assert.assertEquals(122, scalarSql("SELECT system.bucket(128, 'abc')")); + Assert.assertEquals(54, scalarSql("SELECT system.bucket(64, 'abcde')")); + Assert.assertEquals(8, scalarSql("SELECT system.bucket(12, '测试')")); + Assert.assertEquals(1, scalarSql("SELECT system.bucket(16, '测试raul试测')")); + Assert.assertEquals( + "Varchar should work like string", + 1, + scalarSql("SELECT system.bucket(16, CAST('测试raul试测' AS varchar(8)))")); + Assert.assertEquals( + "Char should work like string", + 1, + scalarSql("SELECT system.bucket(16, CAST('测试raul试测' AS char(8)))")); + Assert.assertEquals( + "Should not fail on the empty string", 0, scalarSql("SELECT system.bucket(16, '')")); + Assert.assertNull( + "Null input should return null as output", + scalarSql("SELECT system.bucket(16, CAST(null AS string))")); + } + + @Test + public void testBucketBinary() { + Assert.assertEquals( + 1, scalarSql("SELECT system.bucket(10, X'0102030405060708090a0b0c0d0e0f')")); + Assert.assertEquals(10, scalarSql("SELECT system.bucket(12, %s)", asBytesLiteral("abcdefg"))); + Assert.assertEquals(13, scalarSql("SELECT system.bucket(18, %s)", asBytesLiteral("abc\0\0"))); + Assert.assertEquals(42, scalarSql("SELECT system.bucket(48, %s)", asBytesLiteral("abc"))); + Assert.assertEquals(3, scalarSql("SELECT system.bucket(16, %s)", asBytesLiteral("测试_"))); + + Assert.assertNull( + "Null input should return null as output", + scalarSql("SELECT system.bucket(100, CAST(null AS binary))")); + } + + @Test + public void testNumBucketsAcceptsShortAndByte() { + Assert.assertEquals( + "Short types should be usable for the number of buckets field", + 1, + scalarSql("SELECT system.bucket(5S, 1L)")); + + Assert.assertEquals( + "Byte types should be allowed for the number of buckets field", + 1, + scalarSql("SELECT system.bucket(5Y, 1)")); + } + + @Test + public void testWrongNumberOfArguments() { + AssertHelpers.assertThrows( + "Function resolution should not work with zero arguments", + AnalysisException.class, + "Function 'bucket' cannot process input: (): Wrong number of inputs (expected numBuckets and value)", + () -> scalarSql("SELECT system.bucket()")); + + AssertHelpers.assertThrows( + "Function resolution should not work with only one argument", + AnalysisException.class, + "Function 'bucket' 
cannot process input: (int): Wrong number of inputs (expected numBuckets and value)", + () -> scalarSql("SELECT system.bucket(1)")); + + AssertHelpers.assertThrows( + "Function resolution should not work with more than two arguments", + AnalysisException.class, + "Function 'bucket' cannot process input: (int, bigint, int): Wrong number of inputs (expected numBuckets and value)", + () -> scalarSql("SELECT system.bucket(1, 1L, 1)")); + } + + @Test + public void testInvalidTypesCannotBeUsedForNumberOfBuckets() { + AssertHelpers.assertThrows( + "Decimal type should not be coercible to the number of buckets", + AnalysisException.class, + "Function 'bucket' cannot process input: (decimal(9,2), int): Expected number of buckets to be tinyint, shortint or int", + () -> scalarSql("SELECT system.bucket(CAST('12.34' as DECIMAL(9, 2)), 10)")); + + AssertHelpers.assertThrows( + "Long type should not be coercible to the number of buckets", + AnalysisException.class, + "Function 'bucket' cannot process input: (bigint, int): Expected number of buckets to be tinyint, shortint or int", + () -> scalarSql("SELECT system.bucket(12L, 10)")); + + AssertHelpers.assertThrows( + "String type should not be coercible to the number of buckets", + AnalysisException.class, + "Function 'bucket' cannot process input: (string, int): Expected number of buckets to be tinyint, shortint or int", + () -> scalarSql("SELECT system.bucket('5', 10)")); + + AssertHelpers.assertThrows( + "Interval year to month type should not be coercible to the number of buckets", + AnalysisException.class, + "Function 'bucket' cannot process input: (interval year to month, int): Expected number of buckets to be tinyint, shortint or int", + () -> scalarSql("SELECT system.bucket(INTERVAL '100-00' YEAR TO MONTH, 10)")); + + AssertHelpers.assertThrows( + "Interval day-time type should not be coercible to the number of buckets", + AnalysisException.class, + "Function 'bucket' cannot process input: (interval day to second, int): Expected number of buckets to be tinyint, shortint or int", + () -> scalarSql("SELECT system.bucket(CAST('11 23:4:0' AS INTERVAL DAY TO SECOND), 10)")); + } + + @Test + public void testInvalidTypesForBucketColumn() { + AssertHelpers.assertThrows( + "Double type should not be bucketable", + AnalysisException.class, + "Function 'bucket' cannot process input: (int, float): Expected column to be date, tinyint, smallint, int, bigint, decimal, timestamp, string, or binary", + () -> scalarSql("SELECT system.bucket(10, cast(12.3456 as float))")); + + AssertHelpers.assertThrows( + "Double type should not be bucketable", + AnalysisException.class, + "Function 'bucket' cannot process input: (int, double): Expected column to be date, tinyint, smallint, int, bigint, decimal, timestamp, string, or binary", + () -> scalarSql("SELECT system.bucket(10, cast(12.3456 as double))")); + + AssertHelpers.assertThrows( + "Boolean type should not be bucketable", + AnalysisException.class, + "Function 'bucket' cannot process input: (int, boolean)", + () -> scalarSql("SELECT system.bucket(10, true)")); + + AssertHelpers.assertThrows( + "Map types should not be bucketable", + AnalysisException.class, + "Function 'bucket' cannot process input: (int, map)", + () -> scalarSql("SELECT system.bucket(10, map(1, 1))")); + + AssertHelpers.assertThrows( + "Array types should not be bucketable", + AnalysisException.class, + "Function 'bucket' cannot process input: (int, array)", + () -> scalarSql("SELECT system.bucket(10, array(1L))")); + + 
AssertHelpers.assertThrows( + "Interval year-to-month type should not be bucketable", + AnalysisException.class, + "Function 'bucket' cannot process input: (int, interval year to month)", + () -> scalarSql("SELECT system.bucket(10, INTERVAL '100-00' YEAR TO MONTH)")); + + AssertHelpers.assertThrows( + "Interval day-time type should not be bucketable", + AnalysisException.class, + "Function 'bucket' cannot process input: (int, interval day to second)", + () -> scalarSql("SELECT system.bucket(10, CAST('11 23:4:0' AS INTERVAL DAY TO SECOND))")); + } + + @Test + public void testThatMagicFunctionsAreInvoked() { + // TinyInt + Assertions.assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.bucket(5, 6Y)")) + .asString() + .isNotNull() + .contains("staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketInt"); + + // SmallInt + Assertions.assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.bucket(5, 6S)")) + .asString() + .isNotNull() + .contains("staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketInt"); + + // Int + Assertions.assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.bucket(5, 6)")) + .asString() + .isNotNull() + .contains("staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketInt"); + + // Date + Assertions.assertThat( + scalarSql("EXPLAIN EXTENDED SELECT system.bucket(100, DATE '2022-08-08')")) + .asString() + .isNotNull() + .contains("staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketInt"); + + // Long + Assertions.assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.bucket(5, 6L)")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketLong"); + + // Timestamp + Assertions.assertThat( + scalarSql("EXPLAIN EXTENDED SELECT system.bucket(100, TIMESTAMP '2022-08-08')")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketLong"); + + // String + Assertions.assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.bucket(5, 'abcdefg')")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketString"); + + // Decimal + Assertions.assertThat( + scalarSql("EXPLAIN EXTENDED SELECT system.bucket(5, CAST('12.34' AS DECIMAL))")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketDecimal"); + + // Binary + Assertions.assertThat( + scalarSql("EXPLAIN EXTENDED SELECT system.bucket(4, X'0102030405060708')")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketBinary"); + } + + private String asBytesLiteral(String value) { + byte[] bytes = value.getBytes(StandardCharsets.UTF_8); + return "X'" + BaseEncoding.base16().encode(bytes) + "'"; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkDaysFunction.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkDaysFunction.java new file mode 100644 index 000000000000..ccba28735e33 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkDaysFunction.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import java.sql.Date; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.spark.SparkTestBaseWithCatalog; +import org.apache.spark.sql.AnalysisException; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestSparkDaysFunction extends SparkTestBaseWithCatalog { + + @Before + public void useCatalog() { + sql("USE %s", catalogName); + } + + @Test + public void testDates() { + Assert.assertEquals( + "Expected to produce 2017-12-01", + Date.valueOf("2017-12-01"), + scalarSql("SELECT system.days(date('2017-12-01'))")); + Assert.assertEquals( + "Expected to produce 1970-01-01", + Date.valueOf("1970-01-01"), + scalarSql("SELECT system.days(date('1970-01-01'))")); + Assert.assertEquals( + "Expected to produce 1969-12-31", + Date.valueOf("1969-12-31"), + scalarSql("SELECT system.days(date('1969-12-31'))")); + Assert.assertNull(scalarSql("SELECT system.days(CAST(null AS DATE))")); + } + + @Test + public void testTimestamps() { + Assert.assertEquals( + "Expected to produce 2017-12-01", + Date.valueOf("2017-12-01"), + scalarSql("SELECT system.days(TIMESTAMP '2017-12-01 10:12:55.038194 UTC+00:00')")); + Assert.assertEquals( + "Expected to produce 1970-01-01", + Date.valueOf("1970-01-01"), + scalarSql("SELECT system.days(TIMESTAMP '1970-01-01 00:00:01.000001 UTC+00:00')")); + Assert.assertEquals( + "Expected to produce 1969-12-31", + Date.valueOf("1969-12-31"), + scalarSql("SELECT system.days(TIMESTAMP '1969-12-31 23:59:58.999999 UTC+00:00')")); + Assert.assertNull(scalarSql("SELECT system.days(CAST(null AS DATE))")); + } + + @Test + public void testWrongNumberOfArguments() { + AssertHelpers.assertThrows( + "Function resolution should not work with zero arguments", + AnalysisException.class, + "Function 'days' cannot process input: (): Wrong number of inputs", + () -> scalarSql("SELECT system.days()")); + + AssertHelpers.assertThrows( + "Function resolution should not work with more than one argument", + AnalysisException.class, + "Function 'days' cannot process input: (date, date): Wrong number of inputs", + () -> scalarSql("SELECT system.days(date('1969-12-31'), date('1969-12-31'))")); + } + + @Test + public void testInvalidInputTypes() { + AssertHelpers.assertThrows( + "Int type should not be coercible to date/timestamp", + AnalysisException.class, + "Function 'days' cannot process input: (int): Expected value to be date or timestamp", + () -> scalarSql("SELECT system.days(1)")); + + AssertHelpers.assertThrows( + "Long type should not be coercible to date/timestamp", + AnalysisException.class, + "Function 'days' cannot process input: (bigint): Expected value to be date or timestamp", + () -> scalarSql("SELECT system.days(1L)")); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkHoursFunction.java 
b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkHoursFunction.java new file mode 100644 index 000000000000..fc0d781318c9 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkHoursFunction.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.spark.SparkTestBaseWithCatalog; +import org.apache.spark.sql.AnalysisException; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestSparkHoursFunction extends SparkTestBaseWithCatalog { + + @Before + public void useCatalog() { + sql("USE %s", catalogName); + } + + @Test + public void testTimestamps() { + Assert.assertEquals( + "Expected to produce 17501 * 24 + 10", + 420034, + scalarSql("SELECT system.hours(TIMESTAMP '2017-12-01 10:12:55.038194 UTC+00:00')")); + Assert.assertEquals( + "Expected to produce 0 * 24 + 0 = 0", + 0, + scalarSql("SELECT system.hours(TIMESTAMP '1970-01-01 00:00:01.000001 UTC+00:00')")); + Assert.assertEquals( + "Expected to produce -1", + -1, + scalarSql("SELECT system.hours(TIMESTAMP '1969-12-31 23:59:58.999999 UTC+00:00')")); + Assert.assertNull(scalarSql("SELECT system.hours(CAST(null AS TIMESTAMP))")); + } + + @Test + public void testWrongNumberOfArguments() { + AssertHelpers.assertThrows( + "Function resolution should not work with zero arguments", + AnalysisException.class, + "Function 'hours' cannot process input: (): Wrong number of inputs", + () -> scalarSql("SELECT system.hours()")); + + AssertHelpers.assertThrows( + "Function resolution should not work with more than one argument", + AnalysisException.class, + "Function 'hours' cannot process input: (date, date): Wrong number of inputs", + () -> scalarSql("SELECT system.hours(date('1969-12-31'), date('1969-12-31'))")); + } + + @Test + public void testInvalidInputTypes() { + AssertHelpers.assertThrows( + "Int type should not be coercible to timestamp", + AnalysisException.class, + "Function 'hours' cannot process input: (int): Expected value to be timestamp", + () -> scalarSql("SELECT system.hours(1)")); + + AssertHelpers.assertThrows( + "Long type should not be coercible to timestamp", + AnalysisException.class, + "Function 'hours' cannot process input: (bigint): Expected value to be timestamp", + () -> scalarSql("SELECT system.hours(1L)")); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkMonthsFunction.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkMonthsFunction.java new file mode 100644 index 000000000000..b88bf00256b0 --- /dev/null +++ 
b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkMonthsFunction.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.spark.SparkTestBaseWithCatalog; +import org.apache.iceberg.spark.functions.MonthsFunction; +import org.apache.spark.sql.AnalysisException; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestSparkMonthsFunction extends SparkTestBaseWithCatalog { + + @Before + public void useCatalog() { + sql("USE %s", catalogName); + } + + @Test + public void testDates() { + Assert.assertEquals( + "Expected to produce 47 * 12 + 11 = 575", + 575, + scalarSql("SELECT system.months(date('2017-12-01'))")); + Assert.assertEquals( + "Expected to produce 0 * 12 + 0 = 0", + 0, + scalarSql("SELECT system.months(date('1970-01-01'))")); + Assert.assertEquals( + "Expected to produce -1", -1, scalarSql("SELECT system.months(date('1969-12-31'))")); + Assert.assertNull(scalarSql("SELECT system.months(CAST(null AS DATE))")); + } + + @Test + public void testTimestamps() { + Assert.assertEquals( + "Expected to produce 47 * 12 + 11 = 575", + 575, + scalarSql("SELECT system.months(TIMESTAMP '2017-12-01 10:12:55.038194 UTC+00:00')")); + Assert.assertEquals( + "Expected to produce 0 * 12 + 0 = 0", + 0, + scalarSql("SELECT system.months(TIMESTAMP '1970-01-01 00:00:01.000001 UTC+00:00')")); + Assert.assertEquals( + "Expected to produce -1", + -1, + scalarSql("SELECT system.months(TIMESTAMP '1969-12-31 23:59:58.999999 UTC+00:00')")); + Assert.assertNull(scalarSql("SELECT system.months(CAST(null AS TIMESTAMP))")); + } + + @Test + public void testWrongNumberOfArguments() { + AssertHelpers.assertThrows( + "Function resolution should not work with zero arguments", + AnalysisException.class, + "Function 'months' cannot process input: (): Wrong number of inputs", + () -> scalarSql("SELECT system.months()")); + + AssertHelpers.assertThrows( + "Function resolution should not work with more than one argument", + AnalysisException.class, + "Function 'months' cannot process input: (date, date): Wrong number of inputs", + () -> scalarSql("SELECT system.months(date('1969-12-31'), date('1969-12-31'))")); + } + + @Test + public void testInvalidInputTypes() { + AssertHelpers.assertThrows( + "Int type should not be coercible to date/timestamp", + AnalysisException.class, + "Function 'months' cannot process input: (int): Expected value to be date or timestamp", + () -> scalarSql("SELECT system.months(1)")); + + AssertHelpers.assertThrows( + "Long type should not be coercible to date/timestamp", + AnalysisException.class, + "Function 'months' 
cannot process input: (bigint): Expected value to be date or timestamp", + () -> scalarSql("SELECT system.months(1L)")); + } + + @Test + public void testThatMagicFunctionsAreInvoked() { + String dateValue = "date('2017-12-01')"; + String dateTransformClass = MonthsFunction.DateToMonthsFunction.class.getName(); + Assertions.assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.months(%s)", dateValue)) + .asString() + .isNotNull() + .contains("staticinvoke(class " + dateTransformClass); + + String timestampValue = "TIMESTAMP '2017-12-01 10:12:55.038194 UTC+00:00'"; + String timestampTransformClass = MonthsFunction.TimestampToMonthsFunction.class.getName(); + Assertions.assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.months(%s)", timestampValue)) + .asString() + .isNotNull() + .contains("staticinvoke(class " + timestampTransformClass); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkTruncateFunction.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkTruncateFunction.java new file mode 100644 index 000000000000..f21544fcdf7a --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkTruncateFunction.java @@ -0,0 +1,470 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.relocated.com.google.common.io.BaseEncoding; +import org.apache.iceberg.spark.SparkTestBaseWithCatalog; +import org.apache.spark.sql.AnalysisException; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestSparkTruncateFunction extends SparkTestBaseWithCatalog { + public TestSparkTruncateFunction() {} + + @Before + public void useCatalog() { + sql("USE %s", catalogName); + } + + @Test + public void testTruncateTinyInt() { + Assert.assertEquals((byte) 0, scalarSql("SELECT system.truncate(10, 0Y)")); + Assert.assertEquals((byte) 0, scalarSql("SELECT system.truncate(10, 1Y)")); + Assert.assertEquals((byte) 0, scalarSql("SELECT system.truncate(10, 5Y)")); + Assert.assertEquals((byte) 0, scalarSql("SELECT system.truncate(10, 9Y)")); + Assert.assertEquals((byte) 10, scalarSql("SELECT system.truncate(10, 10Y)")); + Assert.assertEquals((byte) 10, scalarSql("SELECT system.truncate(10, 11Y)")); + Assert.assertEquals((byte) -10, scalarSql("SELECT system.truncate(10, -1Y)")); + Assert.assertEquals((byte) -10, scalarSql("SELECT system.truncate(10, -5Y)")); + Assert.assertEquals((byte) -10, scalarSql("SELECT system.truncate(10, -10Y)")); + Assert.assertEquals((byte) -20, scalarSql("SELECT system.truncate(10, -11Y)")); + + // Check that different widths can be used + Assert.assertEquals((byte) -2, scalarSql("SELECT system.truncate(2, -1Y)")); + + Assert.assertNull( + "Null input should return null", + scalarSql("SELECT system.truncate(2, CAST(null AS tinyint))")); + } + + @Test + public void testTruncateSmallInt() { + Assert.assertEquals((short) 0, scalarSql("SELECT system.truncate(10, 0S)")); + Assert.assertEquals((short) 0, scalarSql("SELECT system.truncate(10, 1S)")); + Assert.assertEquals((short) 0, scalarSql("SELECT system.truncate(10, 5S)")); + Assert.assertEquals((short) 0, scalarSql("SELECT system.truncate(10, 9S)")); + Assert.assertEquals((short) 10, scalarSql("SELECT system.truncate(10, 10S)")); + Assert.assertEquals((short) 10, scalarSql("SELECT system.truncate(10, 11S)")); + Assert.assertEquals((short) -10, scalarSql("SELECT system.truncate(10, -1S)")); + Assert.assertEquals((short) -10, scalarSql("SELECT system.truncate(10, -5S)")); + Assert.assertEquals((short) -10, scalarSql("SELECT system.truncate(10, -10S)")); + Assert.assertEquals((short) -20, scalarSql("SELECT system.truncate(10, -11S)")); + + // Check that different widths can be used + Assert.assertEquals((short) -2, scalarSql("SELECT system.truncate(2, -1S)")); + + Assert.assertNull( + "Null input should return null", + scalarSql("SELECT system.truncate(2, CAST(null AS smallint))")); + } + + @Test + public void testTruncateInt() { + Assert.assertEquals(0, scalarSql("SELECT system.truncate(10, 0)")); + Assert.assertEquals(0, scalarSql("SELECT system.truncate(10, 1)")); + Assert.assertEquals(0, scalarSql("SELECT system.truncate(10, 5)")); + Assert.assertEquals(0, scalarSql("SELECT system.truncate(10, 9)")); + Assert.assertEquals(10, scalarSql("SELECT system.truncate(10, 10)")); + Assert.assertEquals(10, scalarSql("SELECT system.truncate(10, 11)")); + Assert.assertEquals(-10, scalarSql("SELECT system.truncate(10, -1)")); + Assert.assertEquals(-10, scalarSql("SELECT system.truncate(10, -5)")); + Assert.assertEquals(-10, scalarSql("SELECT system.truncate(10, -10)")); + 
Assert.assertEquals(-20, scalarSql("SELECT system.truncate(10, -11)")); + + // Check that different widths can be used + Assert.assertEquals(-2, scalarSql("SELECT system.truncate(2, -1)")); + Assert.assertEquals(0, scalarSql("SELECT system.truncate(300, 1)")); + + Assert.assertNull( + "Null input should return null", scalarSql("SELECT system.truncate(2, CAST(null AS int))")); + } + + @Test + public void testTruncateBigInt() { + Assert.assertEquals(0L, scalarSql("SELECT system.truncate(10, 0L)")); + Assert.assertEquals(0L, scalarSql("SELECT system.truncate(10, 1L)")); + Assert.assertEquals(0L, scalarSql("SELECT system.truncate(10, 5L)")); + Assert.assertEquals(0L, scalarSql("SELECT system.truncate(10, 9L)")); + Assert.assertEquals(10L, scalarSql("SELECT system.truncate(10, 10L)")); + Assert.assertEquals(10L, scalarSql("SELECT system.truncate(10, 11L)")); + Assert.assertEquals(-10L, scalarSql("SELECT system.truncate(10, -1L)")); + Assert.assertEquals(-10L, scalarSql("SELECT system.truncate(10, -5L)")); + Assert.assertEquals(-10L, scalarSql("SELECT system.truncate(10, -10L)")); + Assert.assertEquals(-20L, scalarSql("SELECT system.truncate(10, -11L)")); + + // Check that different widths can be used + Assert.assertEquals(-2L, scalarSql("SELECT system.truncate(2, -1L)")); + + Assert.assertNull( + "Null input should return null", + scalarSql("SELECT system.truncate(2, CAST(null AS bigint))")); + } + + @Test + public void testTruncateDecimal() { + // decimal truncation works by applying the decimal scale to the width: ie 10 scale 2 = 0.10 + Assert.assertEquals( + new BigDecimal("12.30"), + scalarSql("SELECT system.truncate(10, CAST(%s as DECIMAL(9, 2)))", "12.34")); + + Assert.assertEquals( + new BigDecimal("12.30"), + scalarSql("SELECT system.truncate(10, CAST(%s as DECIMAL(9, 2)))", "12.30")); + + Assert.assertEquals( + new BigDecimal("12.290"), + scalarSql("SELECT system.truncate(10, CAST(%s as DECIMAL(9, 3)))", "12.299")); + + Assert.assertEquals( + new BigDecimal("0.03"), + scalarSql("SELECT system.truncate(3, CAST(%s as DECIMAL(5, 2)))", "0.05")); + + Assert.assertEquals( + new BigDecimal("0.00"), + scalarSql("SELECT system.truncate(10, CAST(%s as DECIMAL(9, 2)))", "0.05")); + + Assert.assertEquals( + new BigDecimal("-0.10"), + scalarSql("SELECT system.truncate(10, CAST(%s as DECIMAL(9, 2)))", "-0.05")); + + Assert.assertEquals( + "Implicit decimal scale and precision should be allowed", + new BigDecimal("12345.3480"), + scalarSql("SELECT system.truncate(10, 12345.3482)")); + + BigDecimal truncatedDecimal = + (BigDecimal) scalarSql("SELECT system.truncate(10, CAST(%s as DECIMAL(6, 4)))", "-0.05"); + Assert.assertEquals( + "Truncating a decimal should return a decimal with the same scale", + 4, + truncatedDecimal.scale()); + + Assert.assertEquals( + "Truncating a decimal should return a decimal with the correct scale", + BigDecimal.valueOf(-500, 4), + truncatedDecimal); + + Assert.assertNull( + "Null input should return null", + scalarSql("SELECT system.truncate(2, CAST(null AS decimal))")); + } + + @SuppressWarnings("checkstyle:AvoidEscapedUnicodeCharacters") + @Test + public void testTruncateString() { + Assert.assertEquals( + "Should system.truncate strings longer than length", + "abcde", + scalarSql("SELECT system.truncate(5, 'abcdefg')")); + + Assert.assertEquals( + "Should not pad strings shorter than length", + "abc", + scalarSql("SELECT system.truncate(5, 'abc')")); + + Assert.assertEquals( + "Should not alter strings equal to length", + "abcde", + scalarSql("SELECT 
system.truncate(5, 'abcde')"));
+
+ Assert.assertEquals(
+ "Strings with multibyte unicode characters should truncate along codepoint boundaries",
+ "イロ",
+ scalarSql("SELECT system.truncate(2, 'イロハニホヘト')"));
+
+ Assert.assertEquals(
+ "Strings with multibyte unicode characters should truncate along codepoint boundaries",
+ "イロハ",
+ scalarSql("SELECT system.truncate(3, 'イロハニホヘト')"));
+
+ Assert.assertEquals(
+ "Strings with multibyte unicode characters should not alter input with fewer codepoints than width",
+ "イロハニホヘト",
+ scalarSql("SELECT system.truncate(7, 'イロハニホヘト')"));
+
+ String stringWithTwoCodePointsEachFourBytes = "\uD800\uDC00\uD800\uDC00";
+ Assert.assertEquals(
+ "String truncation on four byte codepoints should work as expected",
+ "\uD800\uDC00",
+ scalarSql("SELECT system.truncate(1, '%s')", stringWithTwoCodePointsEachFourBytes));
+
+ Assert.assertEquals(
+ "Should handle three-byte UTF-8 characters appropriately",
+ "测",
+ scalarSql("SELECT system.truncate(1, '测试')"));
+
+ Assert.assertEquals(
+ "Should handle three-byte UTF-8 characters mixed with two byte utf-8 characters",
+ "测试ra",
+ scalarSql("SELECT system.truncate(4, '测试raul试测')"));
+
+ Assert.assertEquals(
+ "Should not fail on the empty string", "", scalarSql("SELECT system.truncate(10, '')"));
+
+ Assert.assertNull(
+ "Null input should return null as output",
+ scalarSql("SELECT system.truncate(3, CAST(null AS string))"));
+
+ Assert.assertEquals(
+ "Varchar should work like string",
+ "测试ra",
+ scalarSql("SELECT system.truncate(4, CAST('测试raul试测' AS varchar(8)))"));
+
+ Assert.assertEquals(
+ "Char should work like string",
+ "测试ra",
+ scalarSql("SELECT system.truncate(4, CAST('测试raul试测' AS char(8)))"));
+ }
+
+ @Test
+ public void testTruncateBinary() {
+ Assert.assertArrayEquals(
+ new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
+ (byte[]) scalarSql("SELECT system.truncate(10, X'0102030405060708090a0b0c0d0e0f')"));
+ Assert.assertArrayEquals(
+ "Should truncate the input when its length is greater than the truncation width",
+ "abc".getBytes(StandardCharsets.UTF_8),
+ (byte[]) scalarSql("SELECT system.truncate(3, %s)", asBytesLiteral("abcdefg")));
+ Assert.assertArrayEquals(
+ "Should not truncate, pad, or trim the input when its length is less than the width",
+ "abc\0\0".getBytes(StandardCharsets.UTF_8),
+ (byte[]) scalarSql("SELECT system.truncate(10, %s)", asBytesLiteral("abc\0\0")));
+ Assert.assertArrayEquals(
+ "Should not pad the input when its length is equal to the width",
+ "abc".getBytes(StandardCharsets.UTF_8),
+ (byte[]) scalarSql("SELECT system.truncate(3, %s)", asBytesLiteral("abc")));
+ Assert.assertArrayEquals(
+ "Should handle three-byte UTF-8 characters appropriately",
+ "测试".getBytes(StandardCharsets.UTF_8),
+ (byte[]) scalarSql("SELECT system.truncate(6, %s)", asBytesLiteral("测试_")));
+
+ Assert.assertNull(
+ "Null input should return null as output",
+ scalarSql("SELECT system.truncate(3, CAST(null AS binary))"));
+ }
+
+ @Test
+ public void testTruncateUsingDataframeForWidthWithVaryingWidth() {
+ // This situation is atypical but allowed. Typically, width is static as data is partitioned on
+ // one width.
+ long numRows = 10L;
+ long numZeroValues =
+ spark
+ .range(numRows)
+ .toDF("value")
+ .selectExpr("CAST(value + 1 AS INT) AS width", "value")
+ .selectExpr("system.truncate(width, value) as truncated_value")
+ .filter("truncated_value == 0")
+ .count();
+ Assert.assertEquals(
+ "A truncate function with variable widths should be usable on dataframe columns",
+ numRows,
+ numZeroValues);
+ }
+
+ @Test
+ public void testWidthAcceptsShortAndByte() {
+ Assert.assertEquals(
+ "Short types should be usable for the width field",
+ 0L,
+ scalarSql("SELECT system.truncate(5S, 1L)"));
+
+ Assert.assertEquals(
+ "Byte types should be allowed for the width field",
+ 0,
+ scalarSql("SELECT system.truncate(5Y, 1)"));
+ }
+
+ @Test
+ public void testWrongNumberOfArguments() {
+ AssertHelpers.assertThrows(
+ "Function resolution should not work with zero arguments",
+ AnalysisException.class,
+ "Function 'truncate' cannot process input: (): Wrong number of inputs (expected width and value)",
+ () -> scalarSql("SELECT system.truncate()"));
+
+ AssertHelpers.assertThrows(
+ "Function resolution should not work with only one argument",
+ AnalysisException.class,
+ "Function 'truncate' cannot process input: (int): Wrong number of inputs (expected width and value)",
+ () -> scalarSql("SELECT system.truncate(1)"));
+
+ AssertHelpers.assertThrows(
+ "Function resolution should not work with more than two arguments",
+ AnalysisException.class,
+ "Function 'truncate' cannot process input: (int, bigint, int): Wrong number of inputs (expected width and value)",
+ () -> scalarSql("SELECT system.truncate(1, 1L, 1)"));
+ }
+
+ @Test
+ public void testInvalidTypesCannotBeUsedForWidth() {
+ AssertHelpers.assertThrows(
+ "Decimal type should not be coercible to the width field",
+ AnalysisException.class,
+ "Function 'truncate' cannot process input: (decimal(9,2), int): Expected truncation width to be tinyint, shortint or int",
+ () -> scalarSql("SELECT system.truncate(CAST('12.34' as DECIMAL(9, 2)), 10)"));
+
+ AssertHelpers.assertThrows(
+ "String type should not be coercible to the width field",
+ AnalysisException.class,
+ "Function 'truncate' cannot process input: (string, int): Expected truncation width to be tinyint, shortint or int",
+ () -> scalarSql("SELECT system.truncate('5', 10)"));
+
+ AssertHelpers.assertThrows(
+ "Interval year to month type should not be coercible to the width field",
+ AnalysisException.class,
+ "Function 'truncate' cannot process input: (interval year to month, int): Expected truncation width to be tinyint, shortint or int",
+ () -> scalarSql("SELECT system.truncate(INTERVAL '100-00' YEAR TO MONTH, 10)"));
+
+ AssertHelpers.assertThrows(
+ "Interval day-time type should not be coercible to the width field",
+ AnalysisException.class,
+ "Function 'truncate' cannot process input: (interval day to second, int): Expected truncation width to be tinyint, shortint or int",
+ () -> scalarSql("SELECT system.truncate(CAST('11 23:4:0' AS INTERVAL DAY TO SECOND), 10)"));
+ }
+
+ @Test
+ public void testInvalidTypesForTruncationColumn() {
+ AssertHelpers.assertThrows(
+ "Float type should not be truncatable",
+ AnalysisException.class,
+ "Function 'truncate' cannot process input: (int, float): Expected truncation col to be tinyint, shortint, int, bigint, decimal, string, or binary",
+ () -> scalarSql("SELECT system.truncate(10, cast(12.3456 as float))"));
+
+ AssertHelpers.assertThrows(
+ "Double type should not be truncatable",
+ AnalysisException.class,
+ "Function 'truncate' cannot process input:
(int, double): Expected truncation col to be tinyint, shortint, int, bigint, decimal, string, or binary", + () -> scalarSql("SELECT system.truncate(10, cast(12.3456 as double))")); + + AssertHelpers.assertThrows( + "Boolean type should not be truncatable", + AnalysisException.class, + "Function 'truncate' cannot process input: (int, boolean): Expected truncation col to be tinyint, shortint, int, bigint, decimal, string, or binary", + () -> scalarSql("SELECT system.truncate(10, true)")); + + AssertHelpers.assertThrows( + "Map types should not be truncatable", + AnalysisException.class, + "Function 'truncate' cannot process input: (int, map): Expected truncation col to be tinyint, shortint, int, bigint, decimal, string, or binary", + () -> scalarSql("SELECT system.truncate(10, map(1, 1))")); + + AssertHelpers.assertThrows( + "Array types should not be truncatable", + AnalysisException.class, + "Function 'truncate' cannot process input: (int, array): Expected truncation col to be tinyint, shortint, int, bigint, decimal, string, or binary", + () -> scalarSql("SELECT system.truncate(10, array(1L))")); + + AssertHelpers.assertThrows( + "Interval year-to-month type should not be truncatable", + AnalysisException.class, + "Function 'truncate' cannot process input: (int, interval year to month): Expected truncation col to be tinyint, shortint, int, bigint, decimal, string, or binary", + () -> scalarSql("SELECT system.truncate(10, INTERVAL '100-00' YEAR TO MONTH)")); + + AssertHelpers.assertThrows( + "Interval day-time type should not be truncatable", + AnalysisException.class, + "Function 'truncate' cannot process input: (int, interval day to second): Expected truncation col to be tinyint, shortint, int, bigint, decimal, string, or binary", + () -> scalarSql("SELECT system.truncate(10, CAST('11 23:4:0' AS INTERVAL DAY TO SECOND))")); + } + + @Test + public void testMagicFunctionsResolveForTinyIntAndSmallIntWidths() { + // Magic functions have staticinvoke in the explain output. Nonmagic calls use + // applyfunctionexpression instead. + String tinyIntWidthExplain = + (String) scalarSql("EXPLAIN EXTENDED SELECT system.truncate(1Y, 6)"); + Assertions.assertThat(tinyIntWidthExplain) + .contains("cast(1 as int)") + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateInt"); + + String smallIntWidth = (String) scalarSql("EXPLAIN EXTENDED SELECT system.truncate(5S, 6L)"); + Assertions.assertThat(smallIntWidth) + .contains("cast(5 as int)") + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateBigInt"); + } + + @Test + public void testThatMagicFunctionsAreInvoked() { + // Magic functions have `staticinvoke` in the explain output. + // Non-magic calls have `applyfunctionexpression` instead. 
+ + // TinyInt + Assertions.assertThat(scalarSql("EXPLAIN EXTENDED select system.truncate(5, 6Y)")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateTinyInt"); + + // SmallInt + Assertions.assertThat(scalarSql("EXPLAIN EXTENDED select system.truncate(5, 6S)")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateSmallInt"); + + // Int + Assertions.assertThat(scalarSql("EXPLAIN EXTENDED select system.truncate(5, 6)")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateInt"); + + // Long + Assertions.assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.truncate(5, 6L)")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateBigInt"); + + // String + Assertions.assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.truncate(5, 'abcdefg')")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateString"); + + // Decimal + Assertions.assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.truncate(5, 12.34)")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateDecimal"); + + // Binary + Assertions.assertThat( + scalarSql("EXPLAIN EXTENDED SELECT system.truncate(4, X'0102030405060708')")) + .asString() + .isNotNull() + .contains( + "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateBinary"); + } + + private String asBytesLiteral(String value) { + byte[] bytes = value.getBytes(StandardCharsets.UTF_8); + return "X'" + BaseEncoding.base16().encode(bytes) + "'"; + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkYearsFunction.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkYearsFunction.java new file mode 100644 index 000000000000..d4676716a612 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkYearsFunction.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.spark.SparkTestBaseWithCatalog; +import org.apache.iceberg.spark.functions.YearsFunction; +import org.apache.spark.sql.AnalysisException; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestSparkYearsFunction extends SparkTestBaseWithCatalog { + + @Before + public void useCatalog() { + sql("USE %s", catalogName); + } + + @Test + public void testDates() { + Assert.assertEquals( + "Expected to produce 2017 - 1970 = 47", + 47, + scalarSql("SELECT system.years(date('2017-12-01'))")); + Assert.assertEquals( + "Expected to produce 1970 - 1970 = 0", + 0, + scalarSql("SELECT system.years(date('1970-01-01'))")); + Assert.assertEquals( + "Expected to produce 1969 - 1970 = -1", + -1, + scalarSql("SELECT system.years(date('1969-12-31'))")); + Assert.assertNull(scalarSql("SELECT system.years(CAST(null AS DATE))")); + } + + @Test + public void testTimestamps() { + Assert.assertEquals( + "Expected to produce 2017 - 1970 = 47", + 47, + scalarSql("SELECT system.years(TIMESTAMP '2017-12-01 10:12:55.038194 UTC+00:00')")); + Assert.assertEquals( + "Expected to produce 1970 - 1970 = 0", + 0, + scalarSql("SELECT system.years(TIMESTAMP '1970-01-01 00:00:01.000001 UTC+00:00')")); + Assert.assertEquals( + "Expected to produce 1969 - 1970 = -1", + -1, + scalarSql("SELECT system.years(TIMESTAMP '1969-12-31 23:59:58.999999 UTC+00:00')")); + Assert.assertNull(scalarSql("SELECT system.years(CAST(null AS TIMESTAMP))")); + } + + @Test + public void testWrongNumberOfArguments() { + AssertHelpers.assertThrows( + "Function resolution should not work with zero arguments", + AnalysisException.class, + "Function 'years' cannot process input: (): Wrong number of inputs", + () -> scalarSql("SELECT system.years()")); + + AssertHelpers.assertThrows( + "Function resolution should not work with more than one argument", + AnalysisException.class, + "Function 'years' cannot process input: (date, date): Wrong number of inputs", + () -> scalarSql("SELECT system.years(date('1969-12-31'), date('1969-12-31'))")); + } + + @Test + public void testInvalidInputTypes() { + AssertHelpers.assertThrows( + "Int type should not be coercible to date/timestamp", + AnalysisException.class, + "Function 'years' cannot process input: (int): Expected value to be date or timestamp", + () -> scalarSql("SELECT system.years(1)")); + + AssertHelpers.assertThrows( + "Long type should not be coercible to date/timestamp", + AnalysisException.class, + "Function 'years' cannot process input: (bigint): Expected value to be date or timestamp", + () -> scalarSql("SELECT system.years(1L)")); + } + + @Test + public void testThatMagicFunctionsAreInvoked() { + String dateValue = "date('2017-12-01')"; + String dateTransformClass = YearsFunction.DateToYearsFunction.class.getName(); + Assertions.assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.years(%s)", dateValue)) + .asString() + .isNotNull() + .contains("staticinvoke(class " + dateTransformClass); + + String timestampValue = "TIMESTAMP '2017-12-01 10:12:55.038194 UTC+00:00'"; + String timestampTransformClass = YearsFunction.TimestampToYearsFunction.class.getName(); + Assertions.assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.years(%s)", timestampValue)) + .asString() + .isNotNull() + .contains("staticinvoke(class " + timestampTransformClass); + } +} diff --git 
a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestStoragePartitionedJoins.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestStoragePartitionedJoins.java new file mode 100644 index 000000000000..78e18516f70d --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestStoragePartitionedJoins.java @@ -0,0 +1,586 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.commons.lang3.StringUtils; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.SparkTestBaseWithCatalog; +import org.apache.iceberg.spark.SparkWriteOptions; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.StructType; +import org.junit.After; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestStoragePartitionedJoins extends SparkTestBaseWithCatalog { + + private static final String OTHER_TABLE_NAME = "other_table"; + + // open file cost and split size are set as 16 MB to produce a split per file + private static final Map TABLE_PROPERTIES = + ImmutableMap.of( + TableProperties.SPLIT_SIZE, "16777216", TableProperties.SPLIT_OPEN_FILE_COST, "16777216"); + + // only v2 bucketing and preserve data grouping properties have to be enabled to trigger SPJ + // other properties are only to simplify testing and validation + private static final Map ENABLED_SPJ_SQL_CONF = + ImmutableMap.of( + SQLConf.V2_BUCKETING_ENABLED().key(), + "true", + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION().key(), + "false", + SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), + "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD().key(), + "-1", + SparkSQLProperties.PRESERVE_DATA_GROUPING, + "true"); + + private static final Map DISABLED_SPJ_SQL_CONF = + ImmutableMap.of( + SQLConf.V2_BUCKETING_ENABLED().key(), + "false", + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION().key(), + "false", + SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), + "false", + 
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD().key(), + "-1", + SparkSQLProperties.PRESERVE_DATA_GROUPING, + "true"); + + @BeforeClass + public static void setupSparkConf() { + spark.conf().set("spark.sql.shuffle.partitions", "4"); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s", tableName(OTHER_TABLE_NAME)); + } + + // TODO: add tests for truncate transforms once SPARK-40295 is released + // TODO: add tests for cases when one side contains a subset of keys once SPARK-41398 is released + + @Test + public void testJoinsWithBucketingOnByteColumn() throws NoSuchTableException { + checkJoin("byte_col", "TINYINT", "bucket(4, byte_col)"); + } + + @Test + public void testJoinsWithBucketingOnShortColumn() throws NoSuchTableException { + checkJoin("short_col", "SMALLINT", "bucket(4, short_col)"); + } + + @Test + public void testJoinsWithBucketingOnIntColumn() throws NoSuchTableException { + checkJoin("int_col", "INT", "bucket(16, int_col)"); + } + + @Test + public void testJoinsWithBucketingOnLongColumn() throws NoSuchTableException { + checkJoin("long_col", "BIGINT", "bucket(16, long_col)"); + } + + @Test + public void testJoinsWithBucketingOnTimestampColumn() throws NoSuchTableException { + checkJoin("timestamp_col", "TIMESTAMP", "bucket(16, timestamp_col)"); + } + + @Test + public void testJoinsWithBucketingOnDateColumn() throws NoSuchTableException { + checkJoin("date_col", "DATE", "bucket(8, date_col)"); + } + + @Test + public void testJoinsWithBucketingOnDecimalColumn() throws NoSuchTableException { + checkJoin("decimal_col", "DECIMAL(20, 2)", "bucket(8, decimal_col)"); + } + + @Test + public void testJoinsWithBucketingOnBinaryColumn() throws NoSuchTableException { + checkJoin("binary_col", "BINARY", "bucket(8, binary_col)"); + } + + @Test + public void testJoinsWithYearsOnTimestampColumn() throws NoSuchTableException { + checkJoin("timestamp_col", "TIMESTAMP", "years(timestamp_col)"); + } + + @Test + public void testJoinsWithYearsOnDateColumn() throws NoSuchTableException { + checkJoin("date_col", "DATE", "years(date_col)"); + } + + @Test + public void testJoinsWithMonthsOnTimestampColumn() throws NoSuchTableException { + checkJoin("timestamp_col", "TIMESTAMP", "months(timestamp_col)"); + } + + @Test + public void testJoinsWithMonthsOnDateColumn() throws NoSuchTableException { + checkJoin("date_col", "DATE", "months(date_col)"); + } + + @Test + public void testJoinsWithDaysOnTimestampColumn() throws NoSuchTableException { + checkJoin("timestamp_col", "TIMESTAMP", "days(timestamp_col)"); + } + + @Test + public void testJoinsWithDaysOnDateColumn() throws NoSuchTableException { + checkJoin("date_col", "DATE", "days(date_col)"); + } + + @Test + public void testJoinsWithHoursOnTimestampColumn() throws NoSuchTableException { + checkJoin("timestamp_col", "TIMESTAMP", "hours(timestamp_col)"); + } + + @Test + public void testJoinsWithMultipleTransformTypes() throws NoSuchTableException { + String createTableStmt = + "CREATE TABLE %s (" + + " id BIGINT, int_col INT, date_col1 DATE, date_col2 DATE, date_col3 DATE," + + " timestamp_col TIMESTAMP, string_col STRING, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (" + + " years(date_col1), months(date_col2), days(date_col3), hours(timestamp_col), " + + " bucket(8, int_col), dep)" + + "TBLPROPERTIES (%s)"; + + sql(createTableStmt, tableName, tablePropsAsString(TABLE_PROPERTIES)); + sql(createTableStmt, tableName(OTHER_TABLE_NAME), tablePropsAsString(TABLE_PROPERTIES)); + + Table 
table = validationCatalog.loadTable(tableIdent); + + Dataset dataDF = randomDataDF(table.schema(), 16); + + // write to the first table 1 time to generate 1 file per partition + append(tableName, dataDF); + + // write to the second table 2 times to generate 2 files per partition + append(tableName(OTHER_TABLE_NAME), dataDF); + append(tableName(OTHER_TABLE_NAME), dataDF); + + // Spark SPJ support is limited at the moment and requires all source partitioning columns, + // which were projected in the query, to be part of the join condition + // suppose a table is partitioned by `p1`, `bucket(8, pk)` + // queries covering `p1` and `pk` columns must include equality predicates + // on both `p1` and `pk` to benefit from SPJ + // this is a temporary Spark limitation that will be removed in a future release + + assertPartitioningAwarePlan( + 1, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT t1.id " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.dep = t2.dep " + + "ORDER BY t1.id", + tableName, + tableName(OTHER_TABLE_NAME)); + + assertPartitioningAwarePlan( + 1, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT t1.id, t1.int_col, t1.date_col1 " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.int_col = t2.int_col AND t1.date_col1 = t2.date_col1 " + + "ORDER BY t1.id, t1.int_col, t1.date_col1", + tableName, + tableName(OTHER_TABLE_NAME)); + + assertPartitioningAwarePlan( + 1, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT t1.id, t1.timestamp_col, t1.string_col " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.timestamp_col = t2.timestamp_col AND t1.string_col = t2.string_col " + + "ORDER BY t1.id, t1.timestamp_col, t1.string_col", + tableName, + tableName(OTHER_TABLE_NAME)); + + assertPartitioningAwarePlan( + 1, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT t1.id, t1.date_col1, t1.date_col2, t1.date_col3 " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.date_col1 = t2.date_col1 AND t1.date_col2 = t2.date_col2 AND t1.date_col3 = t2.date_col3 " + + "ORDER BY t1.id, t1.date_col1, t1.date_col2, t1.date_col3", + tableName, + tableName(OTHER_TABLE_NAME)); + + assertPartitioningAwarePlan( + 1, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT t1.id, t1.int_col, t1.timestamp_col, t1.dep " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.int_col = t2.int_col AND t1.timestamp_col = t2.timestamp_col AND t1.dep = t2.dep " + + "ORDER BY t1.id, t1.int_col, t1.timestamp_col, t1.dep", + tableName, + tableName(OTHER_TABLE_NAME)); + } + + @Test + public void testJoinsWithCompatibleSpecEvolution() { + // create a table with an empty spec + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "TBLPROPERTIES (%s)", + tableName, tablePropsAsString(TABLE_PROPERTIES)); + + Table table = validationCatalog.loadTable(tableIdent); + + // evolve the spec in the first table by adding `dep` + table.updateSpec().addField("dep").commit(); + + // insert data into the first table partitioned by `dep` + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO %s VALUES (1L, 100, 'software')", tableName); + + // evolve the spec in the first table by adding `bucket(int_col, 8)` + table.updateSpec().addField(Expressions.bucket("int_col", 
8)).commit(); + + // insert data into the first table partitioned by `dep`, `bucket(8, int_col)` + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO %s VALUES (2L, 200, 'hr')", tableName); + + // create another table partitioned by `other_dep` + sql( + "CREATE TABLE %s (other_id BIGINT, other_int_col INT, other_dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (other_dep)" + + "TBLPROPERTIES (%s)", + tableName(OTHER_TABLE_NAME), tablePropsAsString(TABLE_PROPERTIES)); + + // insert data into the second table partitioned by 'other_dep' + sql("INSERT INTO %s VALUES (1L, 100, 'software')", tableName(OTHER_TABLE_NAME)); + sql("INSERT INTO %s VALUES (2L, 200, 'hr')", tableName(OTHER_TABLE_NAME)); + + // SPJ would apply as the grouping keys are compatible + // the first table: `dep` (an intersection of all active partition fields across scanned specs) + // the second table: `other_dep` (the only partition field). + + assertPartitioningAwarePlan( + 1, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT * " + + "FROM %s " + + "INNER JOIN %s " + + "ON id = other_id AND int_col = other_int_col AND dep = other_dep " + + "ORDER BY id, int_col, dep", + tableName, + tableName(OTHER_TABLE_NAME)); + } + + @Test + public void testJoinsWithIncompatibleSpecs() { + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep)" + + "TBLPROPERTIES (%s)", + tableName, tablePropsAsString(TABLE_PROPERTIES)); + + sql("INSERT INTO %s VALUES (1L, 100, 'software')", tableName); + sql("INSERT INTO %s VALUES (2L, 200, 'software')", tableName); + sql("INSERT INTO %s VALUES (3L, 300, 'software')", tableName); + + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (bucket(8, int_col))" + + "TBLPROPERTIES (%s)", + tableName(OTHER_TABLE_NAME), tablePropsAsString(TABLE_PROPERTIES)); + + sql("INSERT INTO %s VALUES (1L, 100, 'software')", tableName(OTHER_TABLE_NAME)); + sql("INSERT INTO %s VALUES (2L, 200, 'software')", tableName(OTHER_TABLE_NAME)); + sql("INSERT INTO %s VALUES (3L, 300, 'software')", tableName(OTHER_TABLE_NAME)); + + // queries can't benefit from SPJ as specs are not compatible + // the first table: `dep` + // the second table: `bucket(8, int_col)` + + assertPartitioningAwarePlan( + 3, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT * " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.int_col = t2.int_col AND t1.dep = t2.dep " + + "ORDER BY t1.id, t1.int_col, t1.dep, t2.id, t2.int_col, t2.dep", + tableName, + tableName(OTHER_TABLE_NAME)); + } + + @Test + public void testJoinsWithUnpartitionedTables() { + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "TBLPROPERTIES (" + + " 'read.split.target-size' = 16777216," + + " 'read.split.open-file-cost' = 16777216)", + tableName); + + sql("INSERT INTO %s VALUES (1L, 100, 'software')", tableName); + sql("INSERT INTO %s VALUES (2L, 200, 'software')", tableName); + sql("INSERT INTO %s VALUES (3L, 300, 'software')", tableName); + + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "TBLPROPERTIES (" + + " 'read.split.target-size' = 16777216," + + " 'read.split.open-file-cost' = 16777216)", + tableName(OTHER_TABLE_NAME)); + + sql("INSERT INTO %s VALUES (1L, 100, 'software')", tableName(OTHER_TABLE_NAME)); + sql("INSERT INTO %s VALUES (2L, 200, 'software')", 
tableName(OTHER_TABLE_NAME)); + sql("INSERT INTO %s VALUES (3L, 300, 'software')", tableName(OTHER_TABLE_NAME)); + + // queries covering unpartitioned tables can't benefit from SPJ but shouldn't fail + + assertPartitioningAwarePlan( + 3, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT * " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.int_col = t2.int_col AND t1.dep = t2.dep " + + "ORDER BY t1.id, t1.int_col, t1.dep, t2.id, t2.int_col, t2.dep", + tableName, + tableName(OTHER_TABLE_NAME)); + } + + @Test + public void testJoinsWithEmptyTable() { + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep)" + + "TBLPROPERTIES (%s)", + tableName, tablePropsAsString(TABLE_PROPERTIES)); + + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep)" + + "TBLPROPERTIES (%s)", + tableName(OTHER_TABLE_NAME), tablePropsAsString(TABLE_PROPERTIES)); + + sql("INSERT INTO %s VALUES (1L, 100, 'software')", tableName(OTHER_TABLE_NAME)); + sql("INSERT INTO %s VALUES (2L, 200, 'software')", tableName(OTHER_TABLE_NAME)); + sql("INSERT INTO %s VALUES (3L, 300, 'software')", tableName(OTHER_TABLE_NAME)); + + assertPartitioningAwarePlan( + 3, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT * " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.int_col = t2.int_col AND t1.dep = t2.dep " + + "ORDER BY t1.id, t1.int_col, t1.dep, t2.id, t2.int_col, t2.dep", + tableName, + tableName(OTHER_TABLE_NAME)); + } + + @Test + public void testJoinsWithOneSplitTables() { + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep)" + + "TBLPROPERTIES (%s)", + tableName, tablePropsAsString(TABLE_PROPERTIES)); + + sql("INSERT INTO %s VALUES (1L, 100, 'software')", tableName); + + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep)" + + "TBLPROPERTIES (%s)", + tableName(OTHER_TABLE_NAME), tablePropsAsString(TABLE_PROPERTIES)); + + sql("INSERT INTO %s VALUES (1L, 100, 'software')", tableName(OTHER_TABLE_NAME)); + + // Spark should be able to avoid shuffles even without SPJ if each side has only one split + + assertPartitioningAwarePlan( + 0, /* expected num of shuffles with SPJ */ + 0, /* expected num of shuffles without SPJ */ + "SELECT * " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.int_col = t2.int_col AND t1.dep = t2.dep " + + "ORDER BY t1.id, t1.int_col, t1.dep, t2.id, t2.int_col, t2.dep", + tableName, + tableName(OTHER_TABLE_NAME)); + } + + @Test + public void testAggregates() throws NoSuchTableException { + sql( + "CREATE TABLE %s (id BIGINT, int_col INT, dep STRING)" + + "USING iceberg " + + "PARTITIONED BY (dep, bucket(8, int_col))" + + "TBLPROPERTIES (%s)", + tableName, tablePropsAsString(TABLE_PROPERTIES)); + + // write to the table 3 times to generate 3 files per partition + Table table = validationCatalog.loadTable(tableIdent); + Dataset<Row> dataDF = randomDataDF(table.schema(), 100); + append(tableName, dataDF); + append(tableName, dataDF); + append(tableName, dataDF); + + assertPartitioningAwarePlan( + 1, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT COUNT (DISTINCT id) AS count FROM %s GROUP BY dep, int_col ORDER BY count", + tableName, + tableName(OTHER_TABLE_NAME)); + + assertPartitioningAwarePlan( + 1, /* expected num of shuffles with SPJ */ + 3, /* 
expected num of shuffles without SPJ */ + "SELECT COUNT (DISTINCT id) AS count FROM %s GROUP BY dep ORDER BY count", + tableName, + tableName(OTHER_TABLE_NAME)); + } + + private void checkJoin(String sourceColumnName, String sourceColumnType, String transform) + throws NoSuchTableException { + + String createTableStmt = + "CREATE TABLE %s (id BIGINT, salary INT, %s %s)" + + "USING iceberg " + + "PARTITIONED BY (%s)" + + "TBLPROPERTIES (%s)"; + + sql( + createTableStmt, + tableName, + sourceColumnName, + sourceColumnType, + transform, + tablePropsAsString(TABLE_PROPERTIES)); + + sql( + createTableStmt, + tableName(OTHER_TABLE_NAME), + sourceColumnName, + sourceColumnType, + transform, + tablePropsAsString(TABLE_PROPERTIES)); + + Table table = validationCatalog.loadTable(tableIdent); + Dataset<Row> dataDF = randomDataDF(table.schema(), 200); + append(tableName, dataDF); + append(tableName(OTHER_TABLE_NAME), dataDF); + + assertPartitioningAwarePlan( + 1, /* expected num of shuffles with SPJ */ + 3, /* expected num of shuffles without SPJ */ + "SELECT t1.id, t1.salary, t1.%s " + + "FROM %s t1 " + + "INNER JOIN %s t2 " + + "ON t1.id = t2.id AND t1.%s = t2.%s " + + "ORDER BY t1.id, t1.%s", + sourceColumnName, + tableName, + tableName(OTHER_TABLE_NAME), + sourceColumnName, + sourceColumnName, + sourceColumnName); + } + + private void assertPartitioningAwarePlan( + int expectedNumShufflesWithSPJ, + int expectedNumShufflesWithoutSPJ, + String query, + Object... args) { + + AtomicReference<List<Object[]>> rowsWithSPJ = new AtomicReference<>(); + AtomicReference<List<Object[]>> rowsWithoutSPJ = new AtomicReference<>(); + + withSQLConf( + ENABLED_SPJ_SQL_CONF, + () -> { + String plan = executeAndKeepPlan(query, args).toString(); + int actualNumShuffles = StringUtils.countMatches(plan, "Exchange"); + Assert.assertEquals( + "Number of shuffles with enabled SPJ must match", + expectedNumShufflesWithSPJ, + actualNumShuffles); + + rowsWithSPJ.set(sql(query, args)); + }); + + withSQLConf( + DISABLED_SPJ_SQL_CONF, + () -> { + String plan = executeAndKeepPlan(query, args).toString(); + int actualNumShuffles = StringUtils.countMatches(plan, "Exchange"); + Assert.assertEquals( + "Number of shuffles with disabled SPJ must match", + expectedNumShufflesWithoutSPJ, + actualNumShuffles); + + rowsWithoutSPJ.set(sql(query, args)); + }); + + assertEquals("SPJ should not change query output", rowsWithoutSPJ.get(), rowsWithSPJ.get()); + } + + private Dataset<Row> randomDataDF(Schema schema, int numRows) { + Iterable<InternalRow> rows = RandomData.generateSpark(schema, numRows, 0); + JavaRDD<InternalRow> rowRDD = sparkContext.parallelize(Lists.newArrayList(rows)); + StructType rowSparkType = SparkSchemaUtil.convert(schema); + return spark.internalCreateDataFrame(JavaRDD.toRDD(rowRDD), rowSparkType, false); + } + + private void append(String table, Dataset<Row> df) throws NoSuchTableException { + // fanout writes are enabled as write-time clustering is not supported without Spark extensions + df.coalesce(1).writeTo(table).option(SparkWriteOptions.FANOUT_ENABLED, "true").append(); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestTimestampWithoutZone.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestTimestampWithoutZone.java new file mode 100644 index 000000000000..51b8d255a99b --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestTimestampWithoutZone.java @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import java.sql.Timestamp; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.joda.time.DateTime; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runners.Parameterized; + +public class TestTimestampWithoutZone extends SparkCatalogTestBase { + + private static final String newTableName = "created_table"; + private final Map config; + + private static final Schema schema = + new Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.required(2, "ts", Types.TimestampType.withoutZone()), + Types.NestedField.required(3, "tsz", Types.TimestampType.withZone())); + + private final List values = + ImmutableList.of( + row(1L, toTimestamp("2021-01-01T00:00:00.0"), toTimestamp("2021-02-01T00:00:00.0")), + row(2L, toTimestamp("2021-01-01T00:00:00.0"), toTimestamp("2021-02-01T00:00:00.0")), + row(3L, toTimestamp("2021-01-01T00:00:00.0"), toTimestamp("2021-02-01T00:00:00.0"))); + + @Parameterized.Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + public static Object[][] parameters() { + return new Object[][] { + { + "spark_catalog", + SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "parquet-enabled", "true", + "cache-enabled", "false") + } + }; + } + + public TestTimestampWithoutZone( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + this.config = config; + } + + @Before + public void createTables() { + validationCatalog.createTable(tableIdent, schema); + } + + @After + public void removeTables() { + validationCatalog.dropTable(tableIdent, true); + sql("DROP TABLE IF EXISTS %s", newTableName); + } + + @Test + public void testWriteTimestampWithoutZoneError() { + AssertHelpers.assertThrows( + String.format( + "Write operation performed on a timestamp without timezone field while " + + "'%s' set to false should throw exception", + 
SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE), + IllegalArgumentException.class, + SparkUtil.TIMESTAMP_WITHOUT_TIMEZONE_ERROR, + () -> sql("INSERT INTO %s VALUES %s", tableName, rowToSqlValues(values))); + } + + @Test + public void testAppendTimestampWithoutZone() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "true"), + () -> { + sql("INSERT INTO %s VALUES %s", tableName, rowToSqlValues(values)); + + Assert.assertEquals( + "Should have " + values.size() + " row", + (long) values.size(), + scalarSql("SELECT count(*) FROM %s", tableName)); + + assertEquals( + "Row data should match expected", + values, + sql("SELECT * FROM %s ORDER BY id", tableName)); + }); + } + + @Test + public void testCreateAsSelectWithTimestampWithoutZone() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "true"), + () -> { + sql("INSERT INTO %s VALUES %s", tableName, rowToSqlValues(values)); + + sql("CREATE TABLE %s USING iceberg AS SELECT * FROM %s", newTableName, tableName); + + Assert.assertEquals( + "Should have " + values.size() + " row", + (long) values.size(), + scalarSql("SELECT count(*) FROM %s", newTableName)); + + assertEquals( + "Row data should match expected", + sql("SELECT * FROM %s ORDER BY id", tableName), + sql("SELECT * FROM %s ORDER BY id", newTableName)); + }); + } + + @Test + public void testCreateNewTableShouldHaveTimestampWithZoneIcebergType() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "true"), + () -> { + sql("INSERT INTO %s VALUES %s", tableName, rowToSqlValues(values)); + + sql("CREATE TABLE %s USING iceberg AS SELECT * FROM %s", newTableName, tableName); + + Assert.assertEquals( + "Should have " + values.size() + " row", + (long) values.size(), + scalarSql("SELECT count(*) FROM %s", newTableName)); + + assertEquals( + "Data from created table should match data from base table", + sql("SELECT * FROM %s ORDER BY id", tableName), + sql("SELECT * FROM %s ORDER BY id", newTableName)); + + Table createdTable = + validationCatalog.loadTable(TableIdentifier.of("default", newTableName)); + assertFieldsType(createdTable.schema(), Types.TimestampType.withZone(), "ts", "tsz"); + }); + } + + @Test + public void testCreateNewTableShouldHaveTimestampWithoutZoneIcebergType() { + withSQLConf( + ImmutableMap.of( + SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE, "true", + SparkSQLProperties.USE_TIMESTAMP_WITHOUT_TIME_ZONE_IN_NEW_TABLES, "true"), + () -> { + spark + .sessionState() + .catalogManager() + .currentCatalog() + .initialize(catalog.name(), new CaseInsensitiveStringMap(config)); + sql("INSERT INTO %s VALUES %s", tableName, rowToSqlValues(values)); + + sql("CREATE TABLE %s USING iceberg AS SELECT * FROM %s", newTableName, tableName); + + Assert.assertEquals( + "Should have " + values.size() + " row", + (long) values.size(), + scalarSql("SELECT count(*) FROM %s", newTableName)); + + assertEquals( + "Row data should match expected", + sql("SELECT * FROM %s ORDER BY id", tableName), + sql("SELECT * FROM %s ORDER BY id", newTableName)); + Table createdTable = + validationCatalog.loadTable(TableIdentifier.of("default", newTableName)); + assertFieldsType(createdTable.schema(), Types.TimestampType.withoutZone(), "ts", "tsz"); + }); + } + + private Timestamp toTimestamp(String value) { + return new Timestamp(DateTime.parse(value).getMillis()); + } + + private String rowToSqlValues(List rows) { + List rowValues = + rows.stream() + .map( + row -> { + List columns = + 
Arrays.stream(row) + .map( + value -> { + if (value instanceof Long) { + return value.toString(); + } else if (value instanceof Timestamp) { + return String.format("timestamp '%s'", value); + } + throw new RuntimeException("Type is not supported"); + }) + .collect(Collectors.toList()); + return "(" + Joiner.on(",").join(columns) + ")"; + }) + .collect(Collectors.toList()); + return Joiner.on(",").join(rowValues); + } + + private void assertFieldsType(Schema actual, Type.PrimitiveType expected, String... fields) { + actual + .select(fields) + .asStruct() + .fields() + .forEach(field -> Assert.assertEquals(expected, field.type())); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWrites.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWrites.java new file mode 100644 index 000000000000..d01ccab00f55 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWrites.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import java.util.Map; + +public class TestUnpartitionedWrites extends UnpartitionedWritesTestBase { + + public TestUnpartitionedWrites( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWritesToBranch.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWritesToBranch.java new file mode 100644 index 000000000000..1f5bee42af05 --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWritesToBranch.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import java.util.Map; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.assertj.core.api.Assertions; +import org.junit.Test; + +public class TestUnpartitionedWritesToBranch extends UnpartitionedWritesTestBase { + + private static final String BRANCH = "test"; + + public TestUnpartitionedWritesToBranch( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Override + public void createTables() { + super.createTables(); + Table table = validationCatalog.loadTable(tableIdent); + table.manageSnapshots().createBranch(BRANCH, table.currentSnapshot().snapshotId()).commit(); + sql("REFRESH TABLE " + tableName); + } + + @Override + protected String commitTarget() { + return String.format("%s.branch_%s", tableName, BRANCH); + } + + @Override + protected String selectTarget() { + return String.format("%s VERSION AS OF '%s'", tableName, BRANCH); + } + + @Test + public void testInsertIntoNonExistingBranchFails() { + Assertions.assertThatThrownBy( + () -> sql("INSERT INTO %s.branch_not_exist VALUES (4, 'd'), (5, 'e')", tableName)) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot use branch (does not exist): not_exist"); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/UnpartitionedWritesTestBase.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/UnpartitionedWritesTestBase.java new file mode 100644 index 000000000000..71089ebfd79e --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/UnpartitionedWritesTestBase.java @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.functions; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Before; +import org.junit.Test; + +public abstract class UnpartitionedWritesTestBase extends SparkCatalogTestBase { + public UnpartitionedWritesTestBase( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Before + public void createTables() { + sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName); + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testInsertAppend() { + Assert.assertEquals( + "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget())); + + sql("INSERT INTO %s VALUES (4, 'd'), (5, 'e')", commitTarget()); + + Assert.assertEquals( + "Should have 5 rows after insert", + 5L, + scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List expected = + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c"), row(4L, "d"), row(5L, "e")); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testInsertOverwrite() { + Assert.assertEquals( + "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget())); + + sql("INSERT OVERWRITE %s VALUES (4, 'd'), (5, 'e')", commitTarget()); + + Assert.assertEquals( + "Should have 2 rows after overwrite", + 2L, + scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List expected = ImmutableList.of(row(4L, "d"), row(5L, "e")); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testInsertAppendAtSnapshot() { + Assume.assumeTrue(tableName.equals(commitTarget())); + long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + String prefix = "snapshot_id_"; + AssertHelpers.assertThrows( + "Should not be able to insert into a table at a specific snapshot", + IllegalArgumentException.class, + "Cannot write to table at a specific snapshot", + () -> sql("INSERT INTO %s.%s VALUES (4, 'd'), (5, 'e')", tableName, prefix + snapshotId)); + } + + @Test + public void testInsertOverwriteAtSnapshot() { + Assume.assumeTrue(tableName.equals(commitTarget())); + long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); + String prefix = "snapshot_id_"; + AssertHelpers.assertThrows( + "Should not be able to insert into a table at a specific snapshot", + IllegalArgumentException.class, + "Cannot write to table at a specific snapshot", + () -> + sql( + "INSERT OVERWRITE %s.%s VALUES (4, 'd'), (5, 'e')", + tableName, prefix + snapshotId)); + } + + @Test + public void testDataFrameV2Append() throws NoSuchTableException { + Assert.assertEquals( + "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List data = ImmutableList.of(new 
SimpleRecord(4, "d"), new SimpleRecord(5, "e")); + Dataset ds = spark.createDataFrame(data, SimpleRecord.class); + + ds.writeTo(commitTarget()).append(); + + Assert.assertEquals( + "Should have 5 rows after insert", + 5L, + scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List expected = + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c"), row(4L, "d"), row(5L, "e")); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testDataFrameV2DynamicOverwrite() throws NoSuchTableException { + Assert.assertEquals( + "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e")); + Dataset ds = spark.createDataFrame(data, SimpleRecord.class); + + ds.writeTo(commitTarget()).overwritePartitions(); + + Assert.assertEquals( + "Should have 2 rows after overwrite", + 2L, + scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List expected = ImmutableList.of(row(4L, "d"), row(5L, "e")); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testDataFrameV2Overwrite() throws NoSuchTableException { + Assert.assertEquals( + "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e")); + Dataset ds = spark.createDataFrame(data, SimpleRecord.class); + + ds.writeTo(commitTarget()).overwrite(functions.col("id").$less$eq(3)); + + Assert.assertEquals( + "Should have 2 rows after overwrite", + 2L, + scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List expected = ImmutableList.of(row(4L, "d"), row(5L, "e")); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } +}
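
Reviewer note (not part of the patch): the new TestStoragePartitionedJoins suite triggers storage-partitioned joins by enabling Spark's v2 bucketing together with Iceberg's preserve-data-grouping property. The sketch below shows one way to reproduce the same setup in a standalone Spark 3.4 session; it is illustrative only. The class name, the local[*] master, and the table names local.db.t1 / local.db.t2 are placeholders, and the broadcast threshold is lowered purely so the join strategy is visible in the plan, mirroring what the test configures through SQLConf.

    import org.apache.iceberg.spark.SparkSQLProperties;
    import org.apache.spark.sql.SparkSession;

    public class StoragePartitionedJoinSketch {
      public static void main(String[] args) {
        SparkSession spark =
            SparkSession.builder()
                .appName("spj-sketch")
                .master("local[*]")
                // Spark 3.4 switch for DataSource V2 storage-partitioned joins
                .config("spark.sql.sources.v2.bucketing.enabled", "true")
                // ask Iceberg scans to report their data grouping (partitioning) to Spark
                .config(SparkSQLProperties.PRESERVE_DATA_GROUPING, "true")
                // disable broadcast joins so the shuffle-free plan is observable
                .config("spark.sql.autoBroadcastJoinThreshold", "-1")
                .getOrCreate();

        // Assumes both placeholder tables are partitioned by the same transform,
        // e.g. bucket(16, id); with SPJ the printed plan should contain no
        // Exchange nodes feeding the join, matching what the test asserts.
        spark
            .sql(
                "SELECT t1.id FROM local.db.t1 t1 "
                    + "INNER JOIN local.db.t2 t2 ON t1.id = t2.id")
            .explain();
      }
    }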