[#2543] feat(spark-connector): support row-level operations to iceberg Table #3243

Merged · 13 commits · May 13, 2024
4 changes: 4 additions & 0 deletions integration-test/build.gradle.kts
@@ -13,6 +13,8 @@ plugins {

val scalaVersion: String = project.properties["scalaVersion"] as? String ?: extra["defaultScalaVersion"].toString()
val sparkVersion: String = libs.versions.spark.get()
val sparkMajorVersion: String = sparkVersion.substringBeforeLast(".")
val kyuubiVersion: String = libs.versions.kyuubi.get()
val icebergVersion: String = libs.versions.iceberg.get()
val scalaCollectionCompatVersion: String = libs.versions.scala.collection.compat.get()

@@ -114,6 +116,8 @@ dependencies {
exclude("io.dropwizard.metrics")
exclude("org.rocksdb")
}
testImplementation("org.apache.iceberg:iceberg-spark-runtime-${sparkMajorVersion}_$scalaVersion:$icebergVersion")
testImplementation("org.apache.kyuubi:kyuubi-spark-connector-hive_$scalaVersion:$kyuubiVersion")

testImplementation(libs.okhttp3.loginterceptor)
testImplementation(libs.postgresql.driver)
SparkCommonIT.java
@@ -68,11 +68,39 @@ protected static String getDeleteSql(String tableName, String condition) {
return String.format("DELETE FROM %s where %s", tableName, condition);
}

private static String getUpdateTableSql(String tableName, String setClause, String whereClause) {
return String.format("UPDATE %s SET %s WHERE %s", tableName, setClause, whereClause);
}
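// For illustration (hypothetical arguments): getUpdateTableSql("employees", "age = 30", "id = 1")
// renders to: UPDATE employees SET age = 30 WHERE id = 1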

private static String getRowLevelUpdateTableSql(
String targetTableName, String selectClause, String sourceTableName, String onClause) {
return String.format(
"MERGE INTO %s "
+ "USING (%s) %s "
+ "ON %s "
+ "WHEN MATCHED THEN UPDATE SET * "
+ "WHEN NOT MATCHED THEN INSERT *",
targetTableName, selectClause, sourceTableName, onClause);
}

private static String getRowLevelDeleteTableSql(
String targetTableName, String selectClause, String sourceTableName, String onClause) {
return String.format(
"MERGE INTO %s "
+ "USING (%s) %s "
+ "ON %s "
+ "WHEN MATCHED THEN DELETE "
+ "WHEN NOT MATCHED THEN INSERT *",
targetTableName, selectClause, sourceTableName, onClause);
}
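// For illustration (hypothetical target table "employees"): the delete template above renders to
//   MERGE INTO employees
//   USING (SELECT 1 AS id, '1' AS name, 1 AS age) source_table
//   ON employees.id = source_table.id
//   WHEN MATCHED THEN DELETE
//   WHEN NOT MATCHED THEN INSERT *
// i.e. matching target rows are deleted and unmatched source rows are inserted.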

// Whether the catalog supports Spark SQL
// [CLUSTERED BY col_name SORTED BY col_name INTO num_buckets BUCKETS]
protected abstract boolean supportsSparkSQLClusteredBy();

protected abstract boolean supportsPartition();

// Whether the catalog supports row-level DELETE; gates testDeleteOperation via @EnabledIf.
protected abstract boolean supportsDelete();

// Use a custom database rather than the default database, because SparkCommonIT cannot
// read and write tables in the default database. The main reason is that the default database
// location is determined by `hive.metastore.warehouse.dir` in hive-site.xml, which is a local HDFS address
@@ -702,6 +730,28 @@ void testTableOptions() {
checkTableReadWrite(tableInfo);
}

@Test
@EnabledIf("supportsDelete")
void testDeleteOperation() {
String tableName = "test_row_level_delete_table";
dropTableIfExists(tableName);
createSimpleTable(tableName);

SparkTableInfo table = getTableInfo(tableName);
checkTableColumns(tableName, getSimpleTableColumn(), table);
sql(
String.format(
"INSERT INTO %s VALUES (1, '1', 1),(2, '2', 2),(3, '3', 3),(4, '4', 4),(5, '5', 5)",
tableName));
List<String> queryResult1 = getTableData(tableName);
Assertions.assertEquals(5, queryResult1.size());
Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", queryResult1));
sql(getDeleteSql(tableName, "id <= 4"));
List<String> queryResult2 = getTableData(tableName);
Assertions.assertEquals(1, queryResult2.size());
Assertions.assertEquals("5,5,5", queryResult2.get(0));
}

protected void checkTableReadWrite(SparkTableInfo table) {
String name = table.getTableIdentifier();
boolean isPartitionTable = table.isPartitionTable();
@@ -760,6 +810,49 @@ protected String getExpectedTableData(SparkTableInfo table) {
.collect(Collectors.joining(","));
}

protected void checkRowLevelUpdate(String tableName) {
writeToEmptyTableAndCheckData(tableName);
String updatedValues = "id = 6, name = '6', age = 6";
sql(getUpdateTableSql(tableName, updatedValues, "id = 5"));
List<String> queryResult = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
Assertions.assertEquals(5, queryResult.size());
Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;6,6,6", String.join(";", queryResult));
}

protected void checkRowLevelDelete(String tableName) {
writeToEmptyTableAndCheckData(tableName);
sql(getDeleteSql(tableName, "id <= 2"));
List<String> queryResult = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
Assertions.assertEquals(3, queryResult.size());
Assertions.assertEquals("3,3,3;4,4,4;5,5,5", String.join(";", queryResult));
}

protected void checkDeleteByMergeInto(String tableName) {
writeToEmptyTableAndCheckData(tableName);

String sourceTableName = "source_table";
String selectClause =
"SELECT 1 AS id, '1' AS name, 1 AS age UNION ALL SELECT 6 AS id, '6' AS name, 6 AS age";
String onClause = String.format("%s.id = %s.id", tableName, sourceTableName);
sql(getRowLevelDeleteTableSql(tableName, selectClause, sourceTableName, onClause));
List<String> queryResult = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
Assertions.assertEquals(5, queryResult.size());
Assertions.assertEquals("2,2,2;3,3,3;4,4,4;5,5,5;6,6,6", String.join(";", queryResult));
}
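// In the MERGE above, id = 1 matches an existing row and is deleted, while id = 6 has no match
// and is inserted, so rows 2..5 survive and 6,6,6 is appended.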

protected void checkTableUpdateByMergeInto(String tableName) {
writeToEmptyTableAndCheckData(tableName);

String sourceTableName = "source_table";
String selectClause =
"SELECT 1 AS id, '2' AS name, 2 AS age UNION ALL SELECT 6 AS id, '6' AS name, 6 AS age";
String onClause = String.format("%s.id = %s.id", tableName, sourceTableName);
sql(getRowLevelUpdateTableSql(tableName, selectClause, sourceTableName, onClause));
List<String> queryResult = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
Assertions.assertEquals(6, queryResult.size());
Assertions.assertEquals("1,2,2;2,2,2;3,3,3;4,4,4;5,5,5;6,6,6", String.join(";", queryResult));
}
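// Here id = 1 matches and is rewritten to (1, '2', 2) by UPDATE SET *, while id = 6 has no match
// and is inserted, producing the six expected rows.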

protected String getCreateSimpleTableString(String tableName) {
return getCreateSimpleTableString(tableName, false);
}
@@ -801,6 +894,16 @@ protected void checkTableColumns(
.check(tableInfo);
}

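// Seeds an empty table with five rows (id 1..5) and verifies they read back as inserted.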
private void writeToEmptyTableAndCheckData(String tableName) {
sql(
String.format(
"INSERT INTO %s VALUES (1, '1', 1),(2, '2', 2),(3, '3', 3),(4, '4', 4),(5, '5', 5)",
tableName));
List<String> queryResult = getTableData(tableName);
Assertions.assertEquals(5, queryResult.size());
Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", queryResult));
}

// partition expression may contain "'", like a='s'/b=1
private String getPartitionExpression(SparkTableInfo table, String delimiter) {
return table.getPartitionedColumns().stream()
SparkHiveCatalogIT.java
@@ -55,6 +55,11 @@ protected boolean supportsPartition() {
return true;
}

@Override
protected boolean supportsDelete() {
return false;
}
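// Row-level DELETE is not supported for Hive tables here, so the DELETE test inherited from
// SparkCommonIT is skipped for this catalog.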

@Test
public void testCreateHiveFormatPartitionTable() {
String tableName = "hive_partition_table";
SparkIcebergCatalogIT.java
@@ -9,6 +9,7 @@
import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo;
import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfoChecker;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
@@ -17,6 +18,7 @@
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.Data;
import org.apache.hadoop.fs.Path;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
@@ -34,9 +36,16 @@
import org.apache.spark.sql.types.StructField;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

public abstract class SparkIcebergCatalogIT extends SparkCommonIT {

private static final String ICEBERG_FORMAT_VERSION = "format-version";
private static final String ICEBERG_DELETE_MODE = "write.delete.mode";
private static final String ICEBERG_UPDATE_MODE = "write.update.mode";
private static final String ICEBERG_MERGE_MODE = "write.merge.mode";
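// Iceberg supports two modes for row-level operations: "copy-on-write" rewrites the affected
// data files during the DELETE/UPDATE/MERGE, while "merge-on-read" writes delete files that are
// applied when the table is read (delete files require table format-version 2).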

@Override
protected String getCatalogName() {
return "iceberg";
@@ -57,6 +66,11 @@ protected boolean supportsPartition() {
return true;
}

@Override
protected boolean supportsDelete() {
return true;
}

@Override
protected String getTableLocation(SparkTableInfo table) {
return String.join(File.separator, table.getTableLocation(), "data");
@@ -216,6 +230,15 @@ void testIcebergMetadataColumns() throws NoSuchTableException {
testDeleteMetadataColumn();
}

@ParameterizedTest
@MethodSource("getIcebergTablePropertyValues")
void testIcebergTableRowLevelOperations(IcebergTableWriteProperties icebergTableWriteProperties) {
testIcebergDeleteOperation(icebergTableWriteProperties);
testIcebergUpdateOperation(icebergTableWriteProperties);
testIcebergMergeIntoDeleteOperation(icebergTableWriteProperties);
testIcebergMergeIntoUpdateOperation(icebergTableWriteProperties);
}
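// The parameterized test above runs DELETE, UPDATE and both MERGE INTO paths for every
// combination supplied by getIcebergTablePropertyValues():
// partitioned/unpartitioned × copy-on-write/merge-on-read.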

private void testMetadataColumns() {
String tableName = "test_metadata_columns";
dropTableIfExists(tableName);
@@ -386,6 +409,84 @@ private void testDeleteMetadataColumn() {
Assertions.assertEquals(0, queryResult1.size());
}

private void testIcebergDeleteOperation(IcebergTableWriteProperties icebergTableWriteProperties) {
String tableName =
String.format(
"test_iceberg_%s_%s_delete_operation",
icebergTableWriteProperties.isPartitionedTable,
icebergTableWriteProperties.formatVersion);
dropTableIfExists(tableName);
createIcebergTableWithTableProperties(
tableName,
icebergTableWriteProperties.isPartitionedTable,
ImmutableMap.of(
ICEBERG_FORMAT_VERSION,
String.valueOf(icebergTableWriteProperties.formatVersion),
ICEBERG_DELETE_MODE,
icebergTableWriteProperties.writeMode));
checkTableColumns(tableName, getSimpleTableColumn(), getTableInfo(tableName));
checkRowLevelDelete(tableName);
}

private void testIcebergUpdateOperation(IcebergTableWriteProperties icebergTableWriteProperties) {
String tableName =
String.format(
"test_iceberg_%s_%s_update_operation",
icebergTableWriteProperties.isPartitionedTable,
icebergTableWriteProperties.formatVersion);
dropTableIfExists(tableName);
createIcebergTableWithTableProperties(
tableName,
icebergTableWriteProperties.isPartitionedTable,
ImmutableMap.of(
ICEBERG_FORMAT_VERSION,
String.valueOf(icebergTableWriteProperties.formatVersion),
ICEBERG_UPDATE_MODE,
icebergTableWriteProperties.writeMode));
checkTableColumns(tableName, getSimpleTableColumn(), getTableInfo(tableName));
checkRowLevelUpdate(tableName);
}

private void testIcebergMergeIntoDeleteOperation(
IcebergTableWriteProperties icebergTableWriteProperties) {
String tableName =
String.format(
"test_iceberg_%s_%s_mergeinto_delete_operation",
icebergTableWriteProperties.isPartitionedTable,
icebergTableWriteProperties.formatVersion);
dropTableIfExists(tableName);
createIcebergTableWithTableProperties(
tableName,
icebergTableWriteProperties.isPartitionedTable,
ImmutableMap.of(
ICEBERG_FORMAT_VERSION,
String.valueOf(icebergTableWriteProperties.formatVersion),
ICEBERG_MERGE_MODE,
icebergTableWriteProperties.writeMode));
checkTableColumns(tableName, getSimpleTableColumn(), getTableInfo(tableName));
checkDeleteByMergeInto(tableName);
}

private void testIcebergMergeIntoUpdateOperation(
IcebergTableWriteProperties icebergTableWriteProperties) {
String tableName =
String.format(
"test_iceberg_%s_%s_mergeinto_update_operation",
icebergTableWriteProperties.isPartitionedTable,
icebergTableWriteProperties.formatVersion);
dropTableIfExists(tableName);
createIcebergTableWithTableProperties(
tableName,
icebergTableWriteProperties.isPartitionedTable,
ImmutableMap.of(
ICEBERG_FORMAT_VERSION,
String.valueOf(icebergTableWriteProperties.formatVersion),
ICEBERG_MERGE_MODE,
icebergTableWriteProperties.writeMode));
checkTableColumns(tableName, getSimpleTableColumn(), getTableInfo(tableName));
checkTableUpdateByMergeInto(tableName);
}

private List<SparkTableInfo.SparkColumnInfo> getIcebergSimpleTableColumn() {
return Arrays.asList(
SparkTableInfo.SparkColumnInfo.of("id", DataTypes.IntegerType, "id comment"),
@@ -416,4 +517,46 @@ private SparkMetadataColumnInfo[] getIcebergMetadataColumns() {
new SparkMetadataColumnInfo("_deleted", DataTypes.BooleanType, false)
};
}

private List<IcebergTableWriteProperties> getIcebergTablePropertyValues() {
return Arrays.asList(
IcebergTableWriteProperties.of(false, 1, "copy-on-write"),
IcebergTableWriteProperties.of(false, 2, "merge-on-read"),
IcebergTableWriteProperties.of(true, 1, "copy-on-write"),
IcebergTableWriteProperties.of(true, 2, "merge-on-read"));
}

private void createIcebergTableWithTableProperties(
String tableName, boolean isPartitioned, ImmutableMap<String, String> tblProperties) {
String partitionedClause = isPartitioned ? " PARTITIONED BY (name) " : "";
String tblPropertiesStr =
tblProperties.entrySet().stream()
.map(e -> String.format("'%s'='%s'", e.getKey(), e.getValue()))
.collect(Collectors.joining(","));
String createSql =
String.format(
"CREATE TABLE %s (id INT COMMENT 'id comment', name STRING COMMENT '', age INT) %s TBLPROPERTIES(%s)",
tableName, partitionedClause, tblPropertiesStr);
sql(createSql);
}
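// For illustration (hypothetical table name "t1"): createIcebergTableWithTableProperties("t1",
// true, ImmutableMap.of("format-version", "2", "write.delete.mode", "merge-on-read")) issues
//   CREATE TABLE t1 (id INT COMMENT 'id comment', name STRING COMMENT '', age INT)
//   PARTITIONED BY (name) TBLPROPERTIES('format-version'='2','write.delete.mode'='merge-on-read')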

@Data
private static class IcebergTableWriteProperties {

private boolean isPartitionedTable;
private int formatVersion;
private String writeMode;

private IcebergTableWriteProperties(
boolean isPartitionedTable, int formatVersion, String writeMode) {
this.isPartitionedTable = isPartitionedTable;
this.formatVersion = formatVersion;
this.writeMode = writeMode;
}

static IcebergTableWriteProperties of(
boolean isPartitionedTable, int formatVersion, String writeMode) {
return new IcebergTableWriteProperties(isPartitionedTable, formatVersion, writeMode);
}
}
}