Skip to content

Commit

Permalink
Core: Schema for a branch should return table schema (#9131)
Browse files Browse the repository at this point in the history
When retrieving the schema for branch we should always return the table
schema instead of the snapshot schema. This is because the table schema
is the schema that will be used when the branch will be created.
We should only return the schema of the snapshot when we have a tag.
  • Loading branch information
nastra authored Dec 5, 2023
1 parent 99843f0 commit a4d4756
Show file tree
Hide file tree
Showing 5 changed files with 380 additions and 43 deletions.
40 changes: 21 additions & 19 deletions core/src/main/java/org/apache/iceberg/util/SnapshotUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -409,49 +409,51 @@ public static Schema schemaFor(Table table, Long snapshotId, Long timestampMilli
}

/**
* Return the schema of the snapshot at a given branch.
* Return the schema of the snapshot at a given ref.
*
* <p>If branch does not exist, the table schema is returned because it will be the schema when
* the new branch is created.
* <p>If the ref does not exist or the ref is a branch, the table schema is returned because it
* will be the schema when the new branch is created. If the ref is a tag, then the snapshot
* schema is returned.
*
* @param table a {@link Table}
* @param branch branch name of the table (nullable)
* @return schema of the specific snapshot at the given branch
* @param ref ref name of the table (nullable)
* @return schema of the specific snapshot at the given ref
*/
public static Schema schemaFor(Table table, String branch) {
if (branch == null || branch.equals(SnapshotRef.MAIN_BRANCH)) {
public static Schema schemaFor(Table table, String ref) {
if (ref == null || ref.equals(SnapshotRef.MAIN_BRANCH)) {
return table.schema();
}

Snapshot ref = table.snapshot(branch);
if (ref == null) {
SnapshotRef snapshotRef = table.refs().get(ref);
if (null == snapshotRef || snapshotRef.isBranch()) {
return table.schema();
}

return schemaFor(table, ref.snapshotId());
return schemaFor(table, snapshotRef.snapshotId());
}

/**
* Return the schema of the snapshot at a given branch.
* Return the schema of the snapshot at a given ref.
*
* <p>If branch does not exist, the table schema is returned because it will be the schema when
* the new branch is created.
* <p>If the ref does not exist or the ref is a branch, the table schema is returned because it
* will be the schema when the new branch is created. If the ref is a tag, then the snapshot
* schema is returned.
*
* @param metadata a {@link TableMetadata}
* @param branch branch name of the table (nullable)
* @param ref ref name of the table (nullable)
* @return schema of the specific snapshot at the given branch
*/
public static Schema schemaFor(TableMetadata metadata, String branch) {
if (branch == null || branch.equals(SnapshotRef.MAIN_BRANCH)) {
public static Schema schemaFor(TableMetadata metadata, String ref) {
if (ref == null || ref.equals(SnapshotRef.MAIN_BRANCH)) {
return metadata.schema();
}

SnapshotRef ref = metadata.ref(branch);
if (ref == null) {
SnapshotRef snapshotRef = metadata.ref(ref);
if (snapshotRef == null || snapshotRef.isBranch()) {
return metadata.schema();
}

Snapshot snapshot = metadata.snapshot(ref.snapshotId());
Snapshot snapshot = metadata.snapshot(snapshotRef.snapshotId());
return metadata.schemas().get(snapshot.schemaId());
}

Expand Down
65 changes: 65 additions & 0 deletions core/src/test/java/org/apache/iceberg/util/TestSnapshotUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.iceberg.util;

import static org.apache.iceberg.types.Types.NestedField.optional;
import static org.apache.iceberg.types.Types.NestedField.required;
import static org.assertj.core.api.Assertions.assertThat;

Expand Down Expand Up @@ -187,4 +188,68 @@ private void expectedSnapshots(long[] snapshotIdExpected, Iterable<Snapshot> sna
.toArray();
assertThat(actualSnapshots).isEqualTo(snapshotIdExpected);
}

@Test
public void schemaForRef() {
Schema initialSchema =
new Schema(
required(1, "id", Types.IntegerType.get()),
required(2, "data", Types.StringType.get()));
assertThat(table.schema().asStruct()).isEqualTo(initialSchema.asStruct());

assertThat(SnapshotUtil.schemaFor(table, null).asStruct()).isEqualTo(initialSchema.asStruct());
assertThat(SnapshotUtil.schemaFor(table, "non-existing-ref").asStruct())
.isEqualTo(initialSchema.asStruct());
assertThat(SnapshotUtil.schemaFor(table, SnapshotRef.MAIN_BRANCH).asStruct())
.isEqualTo(initialSchema.asStruct());
}

@Test
public void schemaForBranch() {
Schema initialSchema =
new Schema(
required(1, "id", Types.IntegerType.get()),
required(2, "data", Types.StringType.get()));
assertThat(table.schema().asStruct()).isEqualTo(initialSchema.asStruct());

String branch = "branch";
table.manageSnapshots().createBranch(branch).commit();

assertThat(SnapshotUtil.schemaFor(table, branch).asStruct())
.isEqualTo(initialSchema.asStruct());

table.updateSchema().addColumn("zip", Types.IntegerType.get()).commit();
Schema expected =
new Schema(
required(1, "id", Types.IntegerType.get()),
required(2, "data", Types.StringType.get()),
optional(3, "zip", Types.IntegerType.get()));

assertThat(table.schema().asStruct()).isEqualTo(expected.asStruct());
assertThat(SnapshotUtil.schemaFor(table, branch).asStruct()).isEqualTo(expected.asStruct());
}

@Test
public void schemaForTag() {
Schema initialSchema =
new Schema(
required(1, "id", Types.IntegerType.get()),
required(2, "data", Types.StringType.get()));
assertThat(table.schema().asStruct()).isEqualTo(initialSchema.asStruct());

String tag = "tag";
table.manageSnapshots().createTag(tag, table.currentSnapshot().snapshotId()).commit();

assertThat(SnapshotUtil.schemaFor(table, tag).asStruct()).isEqualTo(initialSchema.asStruct());

table.updateSchema().addColumn("zip", Types.IntegerType.get()).commit();
Schema expected =
new Schema(
required(1, "id", Types.IntegerType.get()),
required(2, "data", Types.StringType.get()),
optional(3, "zip", Types.IntegerType.get()));

assertThat(table.schema().asStruct()).isEqualTo(expected.asStruct());
assertThat(SnapshotUtil.schemaFor(table, tag).asStruct()).isEqualTo(initialSchema.asStruct());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@
import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.spark.SparkReadOptions;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.assertj.core.api.Assertions;
import org.junit.AfterClass;
Expand Down Expand Up @@ -402,16 +404,104 @@ public void testSnapshotSelectionByBranchWithSchemaChange() throws IOException {
// Deleting a column to indicate schema change
table.updateSchema().deleteColumn("data").commit();

// The data should have the deleted column as it was captured in an earlier snapshot.
Dataset<Row> deletedColumnBranchSnapshotResult =
// The data should not have the deleted column
Assertions.assertThat(
spark
.read()
.format("iceberg")
.option("branch", "branch")
.load(tableLocation)
.orderBy("id")
.collectAsList())
.containsExactly(RowFactory.create(1), RowFactory.create(2), RowFactory.create(3));

// re-introducing the column should not let the data re-appear
table.updateSchema().addColumn("data", Types.StringType.get()).commit();

Assertions.assertThat(
spark
.read()
.format("iceberg")
.option("branch", "branch")
.load(tableLocation)
.orderBy("id")
.as(Encoders.bean(SimpleRecord.class))
.collectAsList())
.containsExactly(
new SimpleRecord(1, null), new SimpleRecord(2, null), new SimpleRecord(3, null));
}

@Test
public void testWritingToBranchAfterSchemaChange() throws IOException {
String tableLocation = temp.newFolder("iceberg-table").toString();

HadoopTables tables = new HadoopTables(CONF);
PartitionSpec spec = PartitionSpec.unpartitioned();
Table table = tables.create(SCHEMA, spec, tableLocation);

// produce the first snapshot
List<SimpleRecord> firstBatchRecords =
Lists.newArrayList(
new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c"));
Dataset<Row> firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class);
firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation);

table.manageSnapshots().createBranch("branch", table.currentSnapshot().snapshotId()).commit();

Dataset<Row> branchSnapshotResult =
spark.read().format("iceberg").option("branch", "branch").load(tableLocation);
List<SimpleRecord> deletedColumnBranchSnapshotRecords =
deletedColumnBranchSnapshotResult
.orderBy("id")
.as(Encoders.bean(SimpleRecord.class))
.collectAsList();
List<SimpleRecord> branchSnapshotRecords =
branchSnapshotResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
List<SimpleRecord> expectedRecords = Lists.newArrayList();
expectedRecords.addAll(firstBatchRecords);
Assert.assertEquals(
"Current snapshot rows should match", expectedRecords, deletedColumnBranchSnapshotRecords);
"Current snapshot rows should match", expectedRecords, branchSnapshotRecords);

// Deleting and add a new column of the same type to indicate schema change
table.updateSchema().deleteColumn("data").addColumn("zip", Types.IntegerType.get()).commit();

Assertions.assertThat(
spark
.read()
.format("iceberg")
.option("branch", "branch")
.load(tableLocation)
.orderBy("id")
.collectAsList())
.containsExactly(
RowFactory.create(1, null), RowFactory.create(2, null), RowFactory.create(3, null));

// writing new records into the branch should work with the new column
List<Row> records =
Lists.newArrayList(
RowFactory.create(4, 12345), RowFactory.create(5, 54321), RowFactory.create(6, 67890));

Dataset<Row> dataFrame =
spark.createDataFrame(
records,
SparkSchemaUtil.convert(
new Schema(
optional(1, "id", Types.IntegerType.get()),
optional(2, "zip", Types.IntegerType.get()))));
dataFrame
.select("id", "zip")
.write()
.format("iceberg")
.option("branch", "branch")
.mode("append")
.save(tableLocation);

Assertions.assertThat(
spark
.read()
.format("iceberg")
.option("branch", "branch")
.load(tableLocation)
.collectAsList())
.hasSize(6)
.contains(
RowFactory.create(1, null), RowFactory.create(2, null), RowFactory.create(3, null))
.containsAll(records);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@
import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.spark.SparkReadOptions;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.assertj.core.api.Assertions;
import org.junit.AfterClass;
Expand Down Expand Up @@ -425,16 +427,104 @@ public void testSnapshotSelectionByBranchWithSchemaChange() throws IOException {
// Deleting a column to indicate schema change
table.updateSchema().deleteColumn("data").commit();

// The data should have the deleted column as it was captured in an earlier snapshot.
Dataset<Row> deletedColumnBranchSnapshotResult =
// The data should not have the deleted column
Assertions.assertThat(
spark
.read()
.format("iceberg")
.option("branch", "branch")
.load(tableLocation)
.orderBy("id")
.collectAsList())
.containsExactly(RowFactory.create(1), RowFactory.create(2), RowFactory.create(3));

// re-introducing the column should not let the data re-appear
table.updateSchema().addColumn("data", Types.StringType.get()).commit();

Assertions.assertThat(
spark
.read()
.format("iceberg")
.option("branch", "branch")
.load(tableLocation)
.orderBy("id")
.as(Encoders.bean(SimpleRecord.class))
.collectAsList())
.containsExactly(
new SimpleRecord(1, null), new SimpleRecord(2, null), new SimpleRecord(3, null));
}

@Test
public void testWritingToBranchAfterSchemaChange() throws IOException {
String tableLocation = temp.newFolder("iceberg-table").toString();

HadoopTables tables = new HadoopTables(CONF);
PartitionSpec spec = PartitionSpec.unpartitioned();
Table table = tables.create(SCHEMA, spec, properties, tableLocation);

// produce the first snapshot
List<SimpleRecord> firstBatchRecords =
Lists.newArrayList(
new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c"));
Dataset<Row> firstDf = spark.createDataFrame(firstBatchRecords, SimpleRecord.class);
firstDf.select("id", "data").write().format("iceberg").mode("append").save(tableLocation);

table.manageSnapshots().createBranch("branch", table.currentSnapshot().snapshotId()).commit();

Dataset<Row> branchSnapshotResult =
spark.read().format("iceberg").option("branch", "branch").load(tableLocation);
List<SimpleRecord> deletedColumnBranchSnapshotRecords =
deletedColumnBranchSnapshotResult
.orderBy("id")
.as(Encoders.bean(SimpleRecord.class))
.collectAsList();
List<SimpleRecord> branchSnapshotRecords =
branchSnapshotResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
List<SimpleRecord> expectedRecords = Lists.newArrayList();
expectedRecords.addAll(firstBatchRecords);
Assert.assertEquals(
"Current snapshot rows should match", expectedRecords, deletedColumnBranchSnapshotRecords);
"Current snapshot rows should match", expectedRecords, branchSnapshotRecords);

// Deleting and add a new column of the same type to indicate schema change
table.updateSchema().deleteColumn("data").addColumn("zip", Types.IntegerType.get()).commit();

Assertions.assertThat(
spark
.read()
.format("iceberg")
.option("branch", "branch")
.load(tableLocation)
.orderBy("id")
.collectAsList())
.containsExactly(
RowFactory.create(1, null), RowFactory.create(2, null), RowFactory.create(3, null));

// writing new records into the branch should work with the new column
List<Row> records =
Lists.newArrayList(
RowFactory.create(4, 12345), RowFactory.create(5, 54321), RowFactory.create(6, 67890));

Dataset<Row> dataFrame =
spark.createDataFrame(
records,
SparkSchemaUtil.convert(
new Schema(
optional(1, "id", Types.IntegerType.get()),
optional(2, "zip", Types.IntegerType.get()))));
dataFrame
.select("id", "zip")
.write()
.format("iceberg")
.option("branch", "branch")
.mode("append")
.save(tableLocation);

Assertions.assertThat(
spark
.read()
.format("iceberg")
.option("branch", "branch")
.load(tableLocation)
.collectAsList())
.hasSize(6)
.contains(
RowFactory.create(1, null), RowFactory.create(2, null), RowFactory.create(3, null))
.containsAll(records);
}

@Test
Expand Down
Loading

0 comments on commit a4d4756

Please sign in to comment.