26 changes: 16 additions & 10 deletions spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java
@@ -80,6 +80,7 @@
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import scala.Function2;
import scala.Option;
import scala.Predef;
import scala.Some;
import scala.Tuple2;
import scala.collection.JavaConverters;
@@ -140,7 +141,7 @@ public static Dataset<Row> partitionDFByFilter(SparkSession spark, String table,
public static List<SparkPartition> getPartitions(SparkSession spark, String table) {
try {
TableIdentifier tableIdent = spark.sessionState().sqlParser().parseTableIdentifier(table);
-      return getPartitions(spark, tableIdent);
+      return getPartitions(spark, tableIdent, null);
} catch (ParseException e) {
throw SparkExceptionUtil.toUncheckedException(e, "Unable to parse table identifier: %s", table);
}
@@ -151,15 +152,23 @@ public static List<SparkPartition> getPartitions(SparkSession spark, String tabl
*
* @param spark a Spark session
* @param tableIdent a table identifier
* @param partitionFilter partition filter, or null if no filter
* @return all table's partitions
*/
-  public static List<SparkPartition> getPartitions(SparkSession spark, TableIdentifier tableIdent) {
+  public static List<SparkPartition> getPartitions(SparkSession spark, TableIdentifier tableIdent,
+      Map<String, String> partitionFilter) {

Member: Do we benefit from having this as a Java Optional? Since we have to immediately convert it to Scala, maybe we should just pass a normal Map?

Member: I think the "filterPartitions" function would just use an empty map as "no filter"?

Member (Author): Done

try {
SessionCatalog catalog = spark.sessionState().catalog();
CatalogTable catalogTable = catalog.getTableMetadata(tableIdent);

-      Seq<CatalogTablePartition> partitions = catalog.listPartitions(tableIdent, Option.empty());
+      Option<scala.collection.immutable.Map<String, String>> scalaPartitionFilter;
+      if (partitionFilter != null && !partitionFilter.isEmpty()) {
+        scalaPartitionFilter = Option.apply(JavaConverters.mapAsScalaMapConverter(partitionFilter).asScala()
+            .toMap(Predef.conforms()));
+      } else {
+        scalaPartitionFilter = Option.empty();
+      }
+      Seq<CatalogTablePartition> partitions = catalog.listPartitions(tableIdent, scalaPartitionFilter);

Member (Author): The Scala API requires an immutable map, hence this extra step.

return JavaConverters
.seqAsJavaListConverter(partitions)
.asJava()
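
For orientation, a minimal usage sketch of the new overload. The session setup and the `db.source_table` name are hypothetical and not part of this change; only `SparkTableUtil.getPartitions`, `SparkPartition`, and the null/empty-map "no filter" behaviour come from the code above.

```java
import java.util.List;
import java.util.Map;

import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.spark.SparkTableUtil;
import org.apache.iceberg.spark.SparkTableUtil.SparkPartition;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.TableIdentifier;

public class GetPartitionsExample {
  public static void main(String[] args) throws Exception {
    // Hypothetical session; listPartitions needs a catalog that tracks partitions (e.g. Hive).
    SparkSession spark = SparkSession.builder()
        .master("local[2]")
        .enableHiveSupport()
        .getOrCreate();

    // Illustrative table name.
    TableIdentifier table = spark.sessionState().sqlParser()
        .parseTableIdentifier("db.source_table");

    // Passing null (or an empty map) means "no filter": every partition is listed,
    // exactly as the old two-argument overload behaved.
    List<SparkPartition> allPartitions = SparkTableUtil.getPartitions(spark, table, null);

    // A plain java.util.Map is the filter; it is converted to an immutable Scala map
    // internally and pushed down to SessionCatalog.listPartitions.
    Map<String, String> filter = ImmutableMap.of("dept", "hr");
    List<SparkPartition> hrPartitions = SparkTableUtil.getPartitions(spark, table, filter);

    System.out.printf("all=%d, dept=hr=%d%n", allPartitions.size(), hrPartitions.size());
  }
}
```

With the filter handled here, importSparkTable (below) can hand the map straight to getPartitions and let the catalog prune partitions, rather than fetching the full listing and filtering it in Java.
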
@@ -375,14 +384,11 @@ public static void importSparkTable(SparkSession spark, TableIdentifier sourceTa
if (Objects.equal(spec, PartitionSpec.unpartitioned())) {
importUnpartitionedSparkTable(spark, sourceTableIdentWithDB, targetTable);
} else {
-        List<SparkPartition> sourceTablePartitions = getPartitions(spark, sourceTableIdent);
+        List<SparkPartition> sourceTablePartitions = getPartitions(spark, sourceTableIdent,
+            partitionFilter);
Preconditions.checkArgument(!sourceTablePartitions.isEmpty(),
"Cannot find any partitions in table %s", sourceTableIdent);
-        List<SparkPartition> filteredPartitions = filterPartitions(sourceTablePartitions, partitionFilter);
-        Preconditions.checkArgument(!filteredPartitions.isEmpty(),
-            "Cannot find any partitions which match the given filter. Partition filter is %s",
-            MAP_JOINER.join(partitionFilter));
-        importSparkPartitions(spark, filteredPartitions, targetTable, spec, stagingDir);
+        importSparkPartitions(spark, sourceTablePartitions, targetTable, spec, stagingDir);
}
} catch (AnalysisException e) {
throw SparkExceptionUtil.toUncheckedException(
@@ -22,6 +22,9 @@
import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.iceberg.FileFormat;
@@ -34,6 +37,7 @@
import org.apache.iceberg.mapping.NameMappingParser;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Sets;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.spark.SparkTableUtil;
import org.apache.iceberg.spark.SparkTableUtil.SparkPartition;
@@ -369,6 +373,78 @@ public void testImportUnpartitionedWithWhitespace() throws Exception {
}
}

public static class GetPartitions {

@Rule
public TemporaryFolder temp = new TemporaryFolder();

// This logic does not really depend on format
private final FileFormat format = FileFormat.PARQUET;

@Test
public void testPartitionScan() throws Exception {

List<ThreeColumnRecord> records = Lists.newArrayList(
new ThreeColumnRecord(1, "ab", "data"),
new ThreeColumnRecord(2, "b c", "data"),
new ThreeColumnRecord(1, "b c", "data"),
new ThreeColumnRecord(2, "ab", "data"));

String tableName = "external_table";

spark.createDataFrame(records, ThreeColumnRecord.class)
.write().mode("overwrite").format(format.toString())
.partitionBy("c1", "c2").saveAsTable(tableName);

TableIdentifier source = spark.sessionState().sqlParser()
.parseTableIdentifier(tableName);

Map<String, String> partition1 = ImmutableMap.of(
"c1", "1",
"c2", "ab");
Map<String, String> partition2 = ImmutableMap.of(
"c1", "2",
"c2", "b c");
Map<String, String> partition3 = ImmutableMap.of(
"c1", "1",
"c2", "b c");
Map<String, String> partition4 = ImmutableMap.of(
"c1", "2",
"c2", "ab");

List<SparkPartition> partitionsC11 =
SparkTableUtil.getPartitions(spark, source, ImmutableMap.of("c1", "1"));
Set<Map<String, String>> expectedC11 =
Sets.newHashSet(partition1, partition3);
Set<Map<String, String>> actualC11 = partitionsC11.stream().map(
p -> p.getValues()).collect(Collectors.toSet());
Assert.assertEquals("Wrong partitions fetched for c1=1", expectedC11, actualC11);

List<SparkPartition> partitionsC12 =
SparkTableUtil.getPartitions(spark, source, ImmutableMap.of("c1", "2"));
Set<Map<String, String>> expectedC12 = Sets.newHashSet(partition2, partition4);
Set<Map<String, String>> actualC12 = partitionsC12.stream().map(
p -> p.getValues()).collect(Collectors.toSet());
Assert.assertEquals("Wrong partitions fetched for c1=2", expectedC12, actualC12);

List<SparkPartition> partitionsC21 =
SparkTableUtil.getPartitions(spark, source, ImmutableMap.of("c2", "ab"));
Set<Map<String, String>> expectedC21 =
Sets.newHashSet(partition1, partition4);
Set<Map<String, String>> actualC21 = partitionsC21.stream().map(
p -> p.getValues()).collect(Collectors.toSet());
Assert.assertEquals("Wrong partitions fetched for c2=ab", expectedC21, actualC21);

List<SparkPartition> partitionsC22 =
SparkTableUtil.getPartitions(spark, source, ImmutableMap.of("c2", "b c"));
Set<Map<String, String>> expectedC22 =
Sets.newHashSet(partition2, partition3);
Set<Map<String, String>> actualC22 = partitionsC22.stream().map(
p -> p.getValues()).collect(Collectors.toSet());
Assert.assertEquals("Wrong partitions fetched for c2=b c", expectedC22, actualC22);
}
}

public static class PartitionScan {

@Before
@@ -318,6 +318,26 @@ public void addFilteredPartitionsToPartitioned() {
sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName));
}

@Test
public void addFilteredPartitionsToPartitioned2() {
createCompositePartitionedTable("parquet");

Member: I think we need a SparkTableUtil test for the new getPartitions code as well, unless that's a pain.

String createIceberg =
"CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg " +
"PARTITIONED BY (id, dept)";

sql(createIceberg, tableName);

Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`', map('dept', 'hr'))",
catalogName, tableName, fileTableDir.getAbsolutePath());

Assert.assertEquals(6L, result);

assertEquals("Iceberg table contains correct data",
sql("SELECT id, name, dept, subdept FROM %s WHERE dept = 'hr' ORDER BY id", sourceTableName),
sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName));
}

@Test
public void addWeirdCaseHiveTable() {
createWeirdCaseTable();