From 2c1d6dff8bae10df2e11b581d9254a6d90b6c413 Mon Sep 17 00:00:00 2001 From: cai can <94670132+caican00@users.noreply.github.com> Date: Mon, 13 May 2024 16:50:53 +0800 Subject: [PATCH] [#2543] feat(spark-connector): support row-level operations to iceberg Table (#3243) ### What changes were proposed in this pull request? - refactor the table implementation: make `SparkIcebergTable` extend Iceberg `SparkTable`, and `SparkHiveTable` extend Kyuubi `HiveTable`. - support row-level operations on Iceberg tables ``` 1. update tableName set c1=v1, c2=v2, ... 2. merge into targetTable t using sourceTable s on s.key=t.key when matched then ... when not matched then ... 3. delete from table where xxx ``` ### Why are the changes needed? 1. The Iceberg Spark connector explicitly uses `SparkTable` to identify whether a table is an Iceberg table, so `SparkIcebergTable` must extend `SparkTable`. 2. Support row-level operations on Iceberg tables. Fix: https://github.com/datastrato/gravitino/issues/2543 ### Does this PR introduce any user-facing change? Yes, it adds support for `UPDATE ...`, `MERGE INTO ...`, and `DELETE FROM ...`. ### How was this patch tested? New ITs. --- docs/spark-connector/spark-catalog-iceberg.md | 25 ++- integration-test/build.gradle.kts | 4 + .../integration/test/spark/SparkCommonIT.java | 103 +++++++++++++ .../test/spark/hive/SparkHiveCatalogIT.java | 5 + .../spark/iceberg/SparkIcebergCatalogIT.java | 143 ++++++++++++++++++ .../test/util/spark/SparkTableInfo.java | 34 +++-- .../test/util/spark/SparkUtilIT.java | 8 +- .../spark/connector/ConnectorConstants.java | 1 + .../connector/SparkTransformConverter.java | 6 - .../spark/connector/catalog/BaseCatalog.java | 17 +-- .../connector/hive/GravitinoHiveCatalog.java | 24 ++- .../spark/connector/hive/SparkHiveTable.java | 44 +++++- .../iceberg/GravitinoIcebergCatalog.java | 24 ++- .../connector/iceberg/SparkIcebergTable.java | 63 +++++--- .../plugin/GravitinoDriverPlugin.java | 27 +++- .../spark/connector/utils/ConnectorUtil.java | 27 ++++ .../GravitinoTableInfoHelper.java} | 69 ++------- .../connector/utils/TestConnectorUtil.java | 39 +++++ 18 files changed, 531 insertions(+), 132 deletions(-) create mode 100644 spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/utils/ConnectorUtil.java rename spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/{table/SparkBaseTable.java => utils/GravitinoTableInfoHelper.java} (65%) create mode 100644 spark-connector/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/utils/TestConnectorUtil.java diff --git a/docs/spark-connector/spark-catalog-iceberg.md b/docs/spark-connector/spark-catalog-iceberg.md index b18c14f937f..f5defef510a 100644 --- a/docs/spark-connector/spark-catalog-iceberg.md +++ b/docs/spark-connector/spark-catalog-iceberg.md @@ -8,7 +8,7 @@ This software is licensed under the Apache License version 2." ## Capabilities -#### Support basic DML and DDL operations: +#### Support DML and DDL operations: - `CREATE TABLE` @@ -18,13 +18,12 @@ Supports basic create table clause including table schema, properties, partition - `ALTER TABLE` - `INSERT INTO&OVERWRITE` - `SELECT` -- `DELETE` - -Supports file delete only. +- `MERGE INTO` +- `DELETE FROM` +- `UPDATE` #### Not supported operations: -- Row level operations. like `MERGE INOT`, `DELETE FROM`, `UPDATE` - View operations. - Branching and tagging operations. - Spark procedures.
@@ -57,6 +56,22 @@ VALUES (3, 'Charlie', 'Sales', TIMESTAMP '2021-03-01 08:45:00'); SELECT * FROM employee WHERE date(hire_date) = '2021-01-01' + +UPDATE employee SET department = 'Jenny' WHERE id = 1; + +DELETE FROM employee WHERE id < 2; + +MERGE INTO employee +USING (SELECT 4 as id, 'David' as name, 'Engineering' as department, TIMESTAMP '2021-04-01 09:00:00' as hire_date) as new_employee +ON employee.id = new_employee.id +WHEN MATCHED THEN UPDATE SET * +WHEN NOT MATCHED THEN INSERT *; + +MERGE INTO employee +USING (SELECT 4 as id, 'David' as name, 'Engineering' as department, TIMESTAMP '2021-04-01 09:00:00' as hire_date) as new_employee +ON employee.id = new_employee.id +WHEN MATCHED THEN DELETE +WHEN NOT MATCHED THEN INSERT *; ``` ## Catalog properties diff --git a/integration-test/build.gradle.kts b/integration-test/build.gradle.kts index 384f8417b18..95ce862da68 100644 --- a/integration-test/build.gradle.kts +++ b/integration-test/build.gradle.kts @@ -13,6 +13,8 @@ plugins { val scalaVersion: String = project.properties["scalaVersion"] as? String ?: extra["defaultScalaVersion"].toString() val sparkVersion: String = libs.versions.spark.get() +val sparkMajorVersion: String = sparkVersion.substringBeforeLast(".") +val kyuubiVersion: String = libs.versions.kyuubi.get() val icebergVersion: String = libs.versions.iceberg.get() val scalaCollectionCompatVersion: String = libs.versions.scala.collection.compat.get() @@ -114,6 +116,8 @@ dependencies { exclude("io.dropwizard.metrics") exclude("org.rocksdb") } + testImplementation("org.apache.iceberg:iceberg-spark-runtime-${sparkMajorVersion}_$scalaVersion:$icebergVersion") + testImplementation("org.apache.kyuubi:kyuubi-spark-connector-hive_$scalaVersion:$kyuubiVersion") testImplementation(libs.okhttp3.loginterceptor) testImplementation(libs.postgresql.driver) diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkCommonIT.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkCommonIT.java index 9dab1b46839..bd44fd33374 100644 --- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkCommonIT.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkCommonIT.java @@ -68,11 +68,39 @@ protected static String getDeleteSql(String tableName, String condition) { return String.format("DELETE FROM %s where %s", tableName, condition); } + private static String getUpdateTableSql(String tableName, String setClause, String whereClause) { + return String.format("UPDATE %s SET %s WHERE %s", tableName, setClause, whereClause); + } + + private static String getRowLevelUpdateTableSql( + String targetTableName, String selectClause, String sourceTableName, String onClause) { + return String.format( + "MERGE INTO %s " + + "USING (%s) %s " + + "ON %s " + + "WHEN MATCHED THEN UPDATE SET * " + + "WHEN NOT MATCHED THEN INSERT *", + targetTableName, selectClause, sourceTableName, onClause); + } + + private static String getRowLevelDeleteTableSql( + String targetTableName, String selectClause, String sourceTableName, String onClause) { + return String.format( + "MERGE INTO %s " + + "USING (%s) %s " + + "ON %s " + + "WHEN MATCHED THEN DELETE " + + "WHEN NOT MATCHED THEN INSERT *", + targetTableName, selectClause, sourceTableName, onClause); + } + // Whether supports [CLUSTERED BY col_name3 SORTED BY col_name INTO num_buckets BUCKETS] protected abstract boolean supportsSparkSQLClusteredBy(); protected abstract boolean 
supportsPartition(); + protected abstract boolean supportsDelete(); + // Use a custom database not the original default database because SparkCommonIT couldn't // read&write data to tables in default database. The main reason is default database location is // determined by `hive.metastore.warehouse.dir` in hive-site.xml which is local HDFS address @@ -702,6 +730,28 @@ void testTableOptions() { checkTableReadWrite(tableInfo); } + @Test + @EnabledIf("supportsDelete") + void testDeleteOperation() { + String tableName = "test_row_level_delete_table"; + dropTableIfExists(tableName); + createSimpleTable(tableName); + + SparkTableInfo table = getTableInfo(tableName); + checkTableColumns(tableName, getSimpleTableColumn(), table); + sql( + String.format( + "INSERT INTO %s VALUES (1, '1', 1),(2, '2', 2),(3, '3', 3),(4, '4', 4),(5, '5', 5)", + tableName)); + List queryResult1 = getTableData(tableName); + Assertions.assertEquals(5, queryResult1.size()); + Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", queryResult1)); + sql(getDeleteSql(tableName, "id <= 4")); + List queryResult2 = getTableData(tableName); + Assertions.assertEquals(1, queryResult2.size()); + Assertions.assertEquals("5,5,5", queryResult2.get(0)); + } + protected void checkTableReadWrite(SparkTableInfo table) { String name = table.getTableIdentifier(); boolean isPartitionTable = table.isPartitionTable(); @@ -760,6 +810,49 @@ protected String getExpectedTableData(SparkTableInfo table) { .collect(Collectors.joining(",")); } + protected void checkRowLevelUpdate(String tableName) { + writeToEmptyTableAndCheckData(tableName); + String updatedValues = "id = 6, name = '6', age = 6"; + sql(getUpdateTableSql(tableName, updatedValues, "id = 5")); + List queryResult = getQueryData(getSelectAllSqlWithOrder(tableName, "id")); + Assertions.assertEquals(5, queryResult.size()); + Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;6,6,6", String.join(";", queryResult)); + } + + protected void checkRowLevelDelete(String tableName) { + writeToEmptyTableAndCheckData(tableName); + sql(getDeleteSql(tableName, "id <= 2")); + List queryResult = getQueryData(getSelectAllSqlWithOrder(tableName, "id")); + Assertions.assertEquals(3, queryResult.size()); + Assertions.assertEquals("3,3,3;4,4,4;5,5,5", String.join(";", queryResult)); + } + + protected void checkDeleteByMergeInto(String tableName) { + writeToEmptyTableAndCheckData(tableName); + + String sourceTableName = "source_table"; + String selectClause = + "SELECT 1 AS id, '1' AS name, 1 AS age UNION ALL SELECT 6 AS id, '6' AS name, 6 AS age"; + String onClause = String.format("%s.id = %s.id", tableName, sourceTableName); + sql(getRowLevelDeleteTableSql(tableName, selectClause, sourceTableName, onClause)); + List queryResult = getQueryData(getSelectAllSqlWithOrder(tableName, "id")); + Assertions.assertEquals(5, queryResult.size()); + Assertions.assertEquals("2,2,2;3,3,3;4,4,4;5,5,5;6,6,6", String.join(";", queryResult)); + } + + protected void checkTableUpdateByMergeInto(String tableName) { + writeToEmptyTableAndCheckData(tableName); + + String sourceTableName = "source_table"; + String selectClause = + "SELECT 1 AS id, '2' AS name, 2 AS age UNION ALL SELECT 6 AS id, '6' AS name, 6 AS age"; + String onClause = String.format("%s.id = %s.id", tableName, sourceTableName); + sql(getRowLevelUpdateTableSql(tableName, selectClause, sourceTableName, onClause)); + List queryResult = getQueryData(getSelectAllSqlWithOrder(tableName, "id")); + Assertions.assertEquals(6, queryResult.size()); + 
Assertions.assertEquals("1,2,2;2,2,2;3,3,3;4,4,4;5,5,5;6,6,6", String.join(";", queryResult)); + } + protected String getCreateSimpleTableString(String tableName) { return getCreateSimpleTableString(tableName, false); } @@ -801,6 +894,16 @@ protected void checkTableColumns( .check(tableInfo); } + private void writeToEmptyTableAndCheckData(String tableName) { + sql( + String.format( + "INSERT INTO %s VALUES (1, '1', 1),(2, '2', 2),(3, '3', 3),(4, '4', 4),(5, '5', 5)", + tableName)); + List queryResult = getTableData(tableName); + Assertions.assertEquals(5, queryResult.size()); + Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", queryResult)); + } + // partition expression may contain "'", like a='s'/b=1 private String getPartitionExpression(SparkTableInfo table, String delimiter) { return table.getPartitionedColumns().stream() diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/hive/SparkHiveCatalogIT.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/hive/SparkHiveCatalogIT.java index 9e0b8291df4..c07565b9a2d 100644 --- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/hive/SparkHiveCatalogIT.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/hive/SparkHiveCatalogIT.java @@ -57,6 +57,11 @@ protected boolean supportsPartition() { return true; } + @Override + protected boolean supportsDelete() { + return false; + } + @Test void testCreateHiveFormatPartitionTable() { String tableName = "hive_partition_table"; diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java index b94d6eb5e17..a87246bfce1 100644 --- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java @@ -9,6 +9,7 @@ import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo; import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfoChecker; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import java.io.File; import java.util.ArrayList; import java.util.Arrays; @@ -17,6 +18,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import lombok.Data; import org.apache.hadoop.fs.Path; import org.apache.spark.sql.Column; import org.apache.spark.sql.Dataset; @@ -34,9 +36,16 @@ import org.apache.spark.sql.types.StructField; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; public abstract class SparkIcebergCatalogIT extends SparkCommonIT { + private static final String ICEBERG_FORMAT_VERSION = "format-version"; + private static final String ICEBERG_DELETE_MODE = "write.delete.mode"; + private static final String ICEBERG_UPDATE_MODE = "write.update.mode"; + private static final String ICEBERG_MERGE_MODE = "write.merge.mode"; + @Override protected String getCatalogName() { return "iceberg"; @@ -57,6 +66,11 @@ protected boolean supportsPartition() { return true; } + @Override + protected boolean supportsDelete() { + return true; + } + @Override protected String 
getTableLocation(SparkTableInfo table) { return String.join(File.separator, table.getTableLocation(), "data"); @@ -216,6 +230,15 @@ void testIcebergMetadataColumns() throws NoSuchTableException { testDeleteMetadataColumn(); } + @ParameterizedTest + @MethodSource("getIcebergTablePropertyValues") + void testIcebergTableRowLevelOperations(IcebergTableWriteProperties icebergTableWriteProperties) { + testIcebergDeleteOperation(icebergTableWriteProperties); + testIcebergUpdateOperation(icebergTableWriteProperties); + testIcebergMergeIntoDeleteOperation(icebergTableWriteProperties); + testIcebergMergeIntoUpdateOperation(icebergTableWriteProperties); + } + private void testMetadataColumns() { String tableName = "test_metadata_columns"; dropTableIfExists(tableName); @@ -386,6 +409,84 @@ private void testDeleteMetadataColumn() { Assertions.assertEquals(0, queryResult1.size()); } + private void testIcebergDeleteOperation(IcebergTableWriteProperties icebergTableWriteProperties) { + String tableName = + String.format( + "test_iceberg_%s_%s_delete_operation", + icebergTableWriteProperties.isPartitionedTable, + icebergTableWriteProperties.formatVersion); + dropTableIfExists(tableName); + createIcebergTableWithTableProperties( + tableName, + icebergTableWriteProperties.isPartitionedTable, + ImmutableMap.of( + ICEBERG_FORMAT_VERSION, + String.valueOf(icebergTableWriteProperties.formatVersion), + ICEBERG_DELETE_MODE, + icebergTableWriteProperties.writeMode)); + checkTableColumns(tableName, getSimpleTableColumn(), getTableInfo(tableName)); + checkRowLevelDelete(tableName); + } + + private void testIcebergUpdateOperation(IcebergTableWriteProperties icebergTableWriteProperties) { + String tableName = + String.format( + "test_iceberg_%s_%s_update_operation", + icebergTableWriteProperties.isPartitionedTable, + icebergTableWriteProperties.formatVersion); + dropTableIfExists(tableName); + createIcebergTableWithTableProperties( + tableName, + icebergTableWriteProperties.isPartitionedTable, + ImmutableMap.of( + ICEBERG_FORMAT_VERSION, + String.valueOf(icebergTableWriteProperties.formatVersion), + ICEBERG_UPDATE_MODE, + icebergTableWriteProperties.writeMode)); + checkTableColumns(tableName, getSimpleTableColumn(), getTableInfo(tableName)); + checkRowLevelUpdate(tableName); + } + + private void testIcebergMergeIntoDeleteOperation( + IcebergTableWriteProperties icebergTableWriteProperties) { + String tableName = + String.format( + "test_iceberg_%s_%s_mergeinto_delete_operation", + icebergTableWriteProperties.isPartitionedTable, + icebergTableWriteProperties.formatVersion); + dropTableIfExists(tableName); + createIcebergTableWithTableProperties( + tableName, + icebergTableWriteProperties.isPartitionedTable, + ImmutableMap.of( + ICEBERG_FORMAT_VERSION, + String.valueOf(icebergTableWriteProperties.formatVersion), + ICEBERG_MERGE_MODE, + icebergTableWriteProperties.writeMode)); + checkTableColumns(tableName, getSimpleTableColumn(), getTableInfo(tableName)); + checkDeleteByMergeInto(tableName); + } + + private void testIcebergMergeIntoUpdateOperation( + IcebergTableWriteProperties icebergTableWriteProperties) { + String tableName = + String.format( + "test_iceberg_%s_%s_mergeinto_update_operation", + icebergTableWriteProperties.isPartitionedTable, + icebergTableWriteProperties.formatVersion); + dropTableIfExists(tableName); + createIcebergTableWithTableProperties( + tableName, + icebergTableWriteProperties.isPartitionedTable, + ImmutableMap.of( + ICEBERG_FORMAT_VERSION, + 
String.valueOf(icebergTableWriteProperties.formatVersion), + ICEBERG_MERGE_MODE, + icebergTableWriteProperties.writeMode)); + checkTableColumns(tableName, getSimpleTableColumn(), getTableInfo(tableName)); + checkTableUpdateByMergeInto(tableName); + } + private List getIcebergSimpleTableColumn() { return Arrays.asList( SparkTableInfo.SparkColumnInfo.of("id", DataTypes.IntegerType, "id comment"), @@ -416,4 +517,46 @@ private SparkMetadataColumnInfo[] getIcebergMetadataColumns() { new SparkMetadataColumnInfo("_deleted", DataTypes.BooleanType, false) }; } + + private List getIcebergTablePropertyValues() { + return Arrays.asList( + IcebergTableWriteProperties.of(false, 1, "copy-on-write"), + IcebergTableWriteProperties.of(false, 2, "merge-on-read"), + IcebergTableWriteProperties.of(true, 1, "copy-on-write"), + IcebergTableWriteProperties.of(true, 2, "merge-on-read")); + } + + private void createIcebergTableWithTableProperties( + String tableName, boolean isPartitioned, ImmutableMap tblProperties) { + String partitionedClause = isPartitioned ? " PARTITIONED BY (name) " : ""; + String tblPropertiesStr = + tblProperties.entrySet().stream() + .map(e -> String.format("'%s'='%s'", e.getKey(), e.getValue())) + .collect(Collectors.joining(",")); + String createSql = + String.format( + "CREATE TABLE %s (id INT COMMENT 'id comment', name STRING COMMENT '', age INT) %s TBLPROPERTIES(%s)", + tableName, partitionedClause, tblPropertiesStr); + sql(createSql); + } + + @Data + private static class IcebergTableWriteProperties { + + private boolean isPartitionedTable; + private int formatVersion; + private String writeMode; + + private IcebergTableWriteProperties( + boolean isPartitionedTable, int formatVersion, String writeMode) { + this.isPartitionedTable = isPartitionedTable; + this.formatVersion = formatVersion; + this.writeMode = writeMode; + } + + static IcebergTableWriteProperties of( + boolean isPartitionedTable, int formatVersion, String writeMode) { + return new IcebergTableWriteProperties(isPartitionedTable, formatVersion, writeMode); + } + } } diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfo.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfo.java index ee08de46ee9..0ef93040d4e 100644 --- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfo.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfo.java @@ -6,7 +6,8 @@ package com.datastrato.gravitino.integration.test.util.spark; import com.datastrato.gravitino.spark.connector.ConnectorConstants; -import com.datastrato.gravitino.spark.connector.table.SparkBaseTable; +import com.datastrato.gravitino.spark.connector.hive.SparkHiveTable; +import com.datastrato.gravitino.spark.connector.iceberg.SparkIcebergTable; import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; @@ -18,6 +19,7 @@ import lombok.Data; import org.apache.commons.lang3.StringUtils; import org.apache.spark.sql.connector.catalog.SupportsMetadataColumns; +import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableCatalog; import org.apache.spark.sql.connector.expressions.ApplyTransform; import org.apache.spark.sql.connector.expressions.BucketTransform; @@ -29,6 +31,7 @@ import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.connector.expressions.YearsTransform; import 
org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.StructType; import org.junit.jupiter.api.Assertions; /** SparkTableInfo is used to check the result in test. */ @@ -89,7 +92,7 @@ void addPartition(Transform partition) { } } - static SparkTableInfo create(SparkBaseTable baseTable) { + static SparkTableInfo create(Table baseTable) { SparkTableInfo sparkTableInfo = new SparkTableInfo(); String identifier = baseTable.name(); String[] items = identifier.split("\\."); @@ -98,7 +101,9 @@ static SparkTableInfo create(SparkBaseTable baseTable) { sparkTableInfo.tableName = items[1]; sparkTableInfo.database = items[0]; sparkTableInfo.columns = - Arrays.stream(baseTable.schema().fields()) + // use getSchema(baseTable) instead of calling `baseTable.schema()` directly, because the + // `schema` method is deprecated in the Spark Table interface + Arrays.stream(getSchema(baseTable).fields()) .map( sparkField -> new SparkColumnInfo( @@ -109,14 +114,12 @@ static SparkTableInfo create(SparkBaseTable baseTable) { .collect(Collectors.toList()); sparkTableInfo.comment = baseTable.properties().remove(ConnectorConstants.COMMENT); sparkTableInfo.tableProperties = baseTable.properties(); - boolean supportsBucketPartition = - baseTable.getSparkTransformConverter().isSupportsBucketPartition(); Arrays.stream(baseTable.partitioning()) .forEach( transform -> { if (transform instanceof BucketTransform || transform instanceof SortedBucketTransform) { - if (isBucketPartition(supportsBucketPartition, transform)) { + if (isBucketPartition(baseTable, transform)) { sparkTableInfo.addPartition(transform); } else { sparkTableInfo.setBucket(transform); @@ -149,10 +152,6 @@ static SparkTableInfo create(SparkBaseTable baseTable) { return sparkTableInfo; } - private static boolean isBucketPartition(boolean supportsBucketPartition, Transform transform) { - return supportsBucketPartition && !(transform instanceof SortedBucketTransform); - } - public List getUnPartitionedColumns() { return columns.stream() .filter(column -> !partitionColumnNames.contains(column.name)) @@ -165,6 +164,21 @@ public List getPartitionedColumns() { .collect(Collectors.toList()); } + private static boolean isBucketPartition(Table baseTable, Transform transform) { + return baseTable instanceof SparkIcebergTable && !(transform instanceof SortedBucketTransform); + } + + private static StructType getSchema(Table baseTable) { + if (baseTable instanceof SparkHiveTable) { + return ((SparkHiveTable) baseTable).schema(); + } else if (baseTable instanceof SparkIcebergTable) { + return ((SparkIcebergTable) baseTable).schema(); + } else { + throw new IllegalArgumentException( + "Doesn't support Spark table: " + baseTable.getClass().getName()); + } + } + @Data public static class SparkColumnInfo { private String name; diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkUtilIT.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkUtilIT.java index 2603fbe8f73..cd55e1205ba 100644 --- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkUtilIT.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkUtilIT.java @@ -20,7 +20,6 @@ package com.datastrato.gravitino.integration.test.util.spark; import com.datastrato.gravitino.integration.test.util.AbstractIT; -import com.datastrato.gravitino.spark.connector.table.SparkBaseTable; import java.sql.Timestamp; import java.text.SimpleDateFormat; import 
java.util.Arrays; @@ -130,8 +129,7 @@ protected SparkTableInfo getTableInfo(String tableName) { CommandResult result = (CommandResult) ds.logicalPlan(); DescribeRelation relation = (DescribeRelation) result.commandLogicalPlan(); ResolvedTable table = (ResolvedTable) relation.child(); - SparkBaseTable baseTable = (SparkBaseTable) table.table(); - return SparkTableInfo.create(baseTable); + return SparkTableInfo.create(table.table()); } protected void dropTableIfExists(String tableName) { @@ -159,6 +157,10 @@ protected void insertTableAsSelect(String tableName, String newName) { sql(String.format("INSERT INTO TABLE %s SELECT * FROM %s", newName, tableName)); } + protected static String getSelectAllSqlWithOrder(String tableName, String orderByColumn) { + return String.format("SELECT * FROM %s ORDER BY %s", tableName, orderByColumn); + } + private static String getSelectAllSql(String tableName) { return String.format("SELECT * FROM %s", tableName); } diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/ConnectorConstants.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/ConnectorConstants.java index 3a49a21470f..9758ff42196 100644 --- a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/ConnectorConstants.java +++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/ConnectorConstants.java @@ -14,6 +14,7 @@ public class ConnectorConstants { public static final String LOCATION = "location"; public static final String DOT = "."; + public static final String COMMA = ","; private ConnectorConstants() {} } diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/SparkTransformConverter.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/SparkTransformConverter.java index a636699024d..d830af0719d 100644 --- a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/SparkTransformConverter.java +++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/SparkTransformConverter.java @@ -13,7 +13,6 @@ import com.datastrato.gravitino.rel.expressions.sorts.SortOrders; import com.datastrato.gravitino.rel.expressions.transforms.Transform; import com.datastrato.gravitino.rel.expressions.transforms.Transforms; -import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.Arrays; @@ -59,11 +58,6 @@ public SparkTransformConverter(boolean supportsBucketPartition) { this.supportsBucketPartition = supportsBucketPartition; } - @VisibleForTesting - public boolean isSupportsBucketPartition() { - return supportsBucketPartition; - } - @Getter public static class DistributionAndSortOrdersInfo { private Distribution distribution; diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/catalog/BaseCatalog.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/catalog/BaseCatalog.java index f5994b4ce86..7fc4b20fa25 100644 --- a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/catalog/BaseCatalog.java +++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/catalog/BaseCatalog.java @@ -19,7 +19,6 @@ import com.datastrato.gravitino.spark.connector.SparkTransformConverter; import 
com.datastrato.gravitino.spark.connector.SparkTransformConverter.DistributionAndSortOrdersInfo; import com.datastrato.gravitino.spark.connector.SparkTypeConverter; -import com.datastrato.gravitino.spark.connector.table.SparkBaseTable; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import java.util.Arrays; @@ -99,7 +98,7 @@ protected abstract TableCatalog createAndInitSparkCatalog( * Spark * @return a specific Spark table */ - protected abstract SparkBaseTable createSparkTable( + protected abstract Table createSparkTable( Identifier identifier, com.datastrato.gravitino.rel.Table gravitinoTable, TableCatalog sparkCatalog, @@ -184,7 +183,7 @@ public Table createTable( sparkTransformConverter.toGravitinoPartitionings(transforms); try { - com.datastrato.gravitino.rel.Table table = + com.datastrato.gravitino.rel.Table gravitinoTable = gravitinoCatalogClient .asTableCatalog() .createTable( @@ -196,7 +195,7 @@ public Table createTable( distributionAndSortOrdersInfo.getDistribution(), distributionAndSortOrdersInfo.getSortOrders()); return createSparkTable( - ident, table, sparkCatalog, propertiesConverter, sparkTransformConverter); + ident, gravitinoTable, sparkCatalog, propertiesConverter, sparkTransformConverter); } catch (NoSuchSchemaException e) { throw new NoSuchNamespaceException(ident.namespace()); } catch (com.datastrato.gravitino.exceptions.TableAlreadyExistsException e) { @@ -208,13 +207,13 @@ public Table createTable( public Table loadTable(Identifier ident) throws NoSuchTableException { try { String database = getDatabase(ident); - com.datastrato.gravitino.rel.Table table = + com.datastrato.gravitino.rel.Table gravitinoTable = gravitinoCatalogClient .asTableCatalog() .loadTable(NameIdentifier.of(metalakeName, catalogName, database, ident.name())); // Will create a catalog specific table return createSparkTable( - ident, table, sparkCatalog, propertiesConverter, sparkTransformConverter); + ident, gravitinoTable, sparkCatalog, propertiesConverter, sparkTransformConverter); } catch (com.datastrato.gravitino.exceptions.NoSuchTableException e) { throw new NoSuchTableException(ident); } @@ -235,14 +234,14 @@ public Table alterTable(Identifier ident, TableChange... 
changes) throws NoSuchT .map(BaseCatalog::transformTableChange) .toArray(com.datastrato.gravitino.rel.TableChange[]::new); try { - com.datastrato.gravitino.rel.Table table = + com.datastrato.gravitino.rel.Table gravitinoTable = gravitinoCatalogClient .asTableCatalog() .alterTable( NameIdentifier.of(metalakeName, catalogName, getDatabase(ident), ident.name()), gravitinoTableChanges); return createSparkTable( - ident, table, sparkCatalog, propertiesConverter, sparkTransformConverter); + ident, gravitinoTable, sparkCatalog, propertiesConverter, sparkTransformConverter); } catch (com.datastrato.gravitino.exceptions.NoSuchTableException e) { throw new NoSuchTableException(ident); } @@ -404,7 +403,7 @@ private com.datastrato.gravitino.rel.Column createGravitinoColumn(Column sparkCo com.datastrato.gravitino.rel.Column.DEFAULT_VALUE_NOT_SET); } - private String getDatabase(Identifier sparkIdentifier) { + protected String getDatabase(Identifier sparkIdentifier) { if (sparkIdentifier.namespace().length > 0) { return sparkIdentifier.namespace()[0]; } diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/hive/GravitinoHiveCatalog.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/hive/GravitinoHiveCatalog.java index 92400437983..cbfd09a4d15 100644 --- a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/hive/GravitinoHiveCatalog.java +++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/hive/GravitinoHiveCatalog.java @@ -9,9 +9,10 @@ import com.datastrato.gravitino.spark.connector.PropertiesConverter; import com.datastrato.gravitino.spark.connector.SparkTransformConverter; import com.datastrato.gravitino.spark.connector.catalog.BaseCatalog; -import com.datastrato.gravitino.spark.connector.table.SparkBaseTable; import java.util.Map; +import org.apache.kyuubi.spark.connector.hive.HiveTable; import org.apache.kyuubi.spark.connector.hive.HiveTableCatalog; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; import org.apache.spark.sql.connector.catalog.Identifier; import org.apache.spark.sql.connector.catalog.TableCatalog; import org.apache.spark.sql.util.CaseInsensitiveStringMap; @@ -30,14 +31,29 @@ protected TableCatalog createAndInitSparkCatalog( } @Override - protected SparkBaseTable createSparkTable( + protected org.apache.spark.sql.connector.catalog.Table createSparkTable( Identifier identifier, Table gravitinoTable, - TableCatalog sparkCatalog, + TableCatalog sparkHiveCatalog, PropertiesConverter propertiesConverter, SparkTransformConverter sparkTransformConverter) { + org.apache.spark.sql.connector.catalog.Table sparkTable; + try { + sparkTable = sparkHiveCatalog.loadTable(identifier); + } catch (NoSuchTableException e) { + throw new RuntimeException( + String.format( + "Failed to load the real sparkTable: %s", + String.join(".", getDatabase(identifier), identifier.name())), + e); + } return new SparkHiveTable( - identifier, gravitinoTable, sparkCatalog, propertiesConverter, sparkTransformConverter); + identifier, + gravitinoTable, + (HiveTable) sparkTable, + (HiveTableCatalog) sparkHiveCatalog, + propertiesConverter, + sparkTransformConverter); } @Override diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/hive/SparkHiveTable.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/hive/SparkHiveTable.java index 
91f9468178b..e27916af283 100644 --- a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/hive/SparkHiveTable.java +++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/hive/SparkHiveTable.java @@ -8,23 +8,51 @@ import com.datastrato.gravitino.rel.Table; import com.datastrato.gravitino.spark.connector.PropertiesConverter; import com.datastrato.gravitino.spark.connector.SparkTransformConverter; -import com.datastrato.gravitino.spark.connector.table.SparkBaseTable; +import com.datastrato.gravitino.spark.connector.utils.GravitinoTableInfoHelper; +import java.util.Map; +import org.apache.kyuubi.spark.connector.hive.HiveTable; +import org.apache.kyuubi.spark.connector.hive.HiveTableCatalog; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.connector.catalog.Identifier; -import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.StructType; + +/** Keep consistent behavior with the SparkIcebergTable */ +public class SparkHiveTable extends HiveTable { + + private GravitinoTableInfoHelper gravitinoTableInfoHelper; -/** May support more capabilities like partition management. */ -public class SparkHiveTable extends SparkBaseTable { public SparkHiveTable( Identifier identifier, Table gravitinoTable, - TableCatalog sparkCatalog, + HiveTable hiveTable, + HiveTableCatalog hiveTableCatalog, PropertiesConverter propertiesConverter, SparkTransformConverter sparkTransformConverter) { - super(identifier, gravitinoTable, sparkCatalog, propertiesConverter, sparkTransformConverter); + super(SparkSession.active(), hiveTable.catalogTable(), hiveTableCatalog); + this.gravitinoTableInfoHelper = + new GravitinoTableInfoHelper( + false, identifier, gravitinoTable, propertiesConverter, sparkTransformConverter); + } + + @Override + public String name() { + return gravitinoTableInfoHelper.name(); + } + + @Override + @SuppressWarnings("deprecation") + public StructType schema() { + return gravitinoTableInfoHelper.schema(); + } + + @Override + public Map properties() { + return gravitinoTableInfoHelper.properties(); } @Override - protected boolean isCaseSensitive() { - return false; + public Transform[] partitioning() { + return gravitinoTableInfoHelper.partitioning(); } } diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/GravitinoIcebergCatalog.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/GravitinoIcebergCatalog.java index 4ed21faee5b..d44dd1edb5e 100644 --- a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/GravitinoIcebergCatalog.java +++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/GravitinoIcebergCatalog.java @@ -9,11 +9,12 @@ import com.datastrato.gravitino.spark.connector.PropertiesConverter; import com.datastrato.gravitino.spark.connector.SparkTransformConverter; import com.datastrato.gravitino.spark.connector.catalog.BaseCatalog; -import com.datastrato.gravitino.spark.connector.table.SparkBaseTable; import java.util.Map; import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.source.SparkTable; import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException; import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; +import 
org.apache.spark.sql.catalyst.analysis.NoSuchTableException; import org.apache.spark.sql.connector.catalog.FunctionCatalog; import org.apache.spark.sql.connector.catalog.Identifier; import org.apache.spark.sql.connector.catalog.TableCatalog; @@ -40,14 +41,29 @@ protected TableCatalog createAndInitSparkCatalog( } @Override - protected SparkBaseTable createSparkTable( + protected org.apache.spark.sql.connector.catalog.Table createSparkTable( Identifier identifier, Table gravitinoTable, - TableCatalog sparkCatalog, + TableCatalog sparkIcebergCatalog, PropertiesConverter propertiesConverter, SparkTransformConverter sparkTransformConverter) { + org.apache.spark.sql.connector.catalog.Table sparkTable; + try { + sparkTable = sparkIcebergCatalog.loadTable(identifier); + } catch (NoSuchTableException e) { + throw new RuntimeException( + String.format( + "Failed to load the real sparkTable: %s", + String.join(".", getDatabase(identifier), identifier.name())), + e); + } return new SparkIcebergTable( - identifier, gravitinoTable, sparkCatalog, propertiesConverter, sparkTransformConverter); + identifier, + gravitinoTable, + (SparkTable) sparkTable, + (SparkCatalog) sparkIcebergCatalog, + propertiesConverter, + sparkTransformConverter); } @Override diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/SparkIcebergTable.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/SparkIcebergTable.java index 22dd0bb73a8..870ff535f88 100644 --- a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/SparkIcebergTable.java +++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/SparkIcebergTable.java @@ -8,43 +8,64 @@ import com.datastrato.gravitino.rel.Table; import com.datastrato.gravitino.spark.connector.PropertiesConverter; import com.datastrato.gravitino.spark.connector.SparkTransformConverter; -import com.datastrato.gravitino.spark.connector.table.SparkBaseTable; +import com.datastrato.gravitino.spark.connector.utils.GravitinoTableInfoHelper; +import java.lang.reflect.Field; +import java.util.Map; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.source.SparkTable; import org.apache.spark.sql.connector.catalog.Identifier; -import org.apache.spark.sql.connector.catalog.MetadataColumn; -import org.apache.spark.sql.connector.catalog.SupportsDelete; -import org.apache.spark.sql.connector.catalog.SupportsMetadataColumns; -import org.apache.spark.sql.connector.catalog.TableCatalog; -import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.StructType; -public class SparkIcebergTable extends SparkBaseTable - implements SupportsDelete, SupportsMetadataColumns { +/** + * For spark-connector in Iceberg, it explicitly uses SparkTable to identify whether it is an + * Iceberg table, so the SparkIcebergTable must extend SparkTable. 
+ */ +public class SparkIcebergTable extends SparkTable { + + private GravitinoTableInfoHelper gravitinoTableInfoHelper; public SparkIcebergTable( Identifier identifier, Table gravitinoTable, - TableCatalog sparkIcebergCatalog, + SparkTable sparkTable, + SparkCatalog sparkCatalog, PropertiesConverter propertiesConverter, SparkTransformConverter sparkTransformConverter) { - super( - identifier, - gravitinoTable, - sparkIcebergCatalog, - propertiesConverter, - sparkTransformConverter); + super(sparkTable.table(), !isCacheEnabled(sparkCatalog)); + this.gravitinoTableInfoHelper = + new GravitinoTableInfoHelper( + true, identifier, gravitinoTable, propertiesConverter, sparkTransformConverter); } @Override - public boolean canDeleteWhere(Filter[] filters) { - return ((SupportsDelete) getSparkTable()).canDeleteWhere(filters); + public String name() { + return gravitinoTableInfoHelper.name(); } @Override - public void deleteWhere(Filter[] filters) { - ((SupportsDelete) getSparkTable()).deleteWhere(filters); + @SuppressWarnings("deprecation") + public StructType schema() { + return gravitinoTableInfoHelper.schema(); } @Override - public MetadataColumn[] metadataColumns() { - return ((SupportsMetadataColumns) getSparkTable()).metadataColumns(); + public Map properties() { + return gravitinoTableInfoHelper.properties(); + } + + @Override + public Transform[] partitioning() { + return gravitinoTableInfoHelper.partitioning(); + } + + private static boolean isCacheEnabled(SparkCatalog sparkCatalog) { + try { + Field cacheEnabled = sparkCatalog.getClass().getDeclaredField("cacheEnabled"); + cacheEnabled.setAccessible(true); + return cacheEnabled.getBoolean(sparkCatalog); + } catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException("Failed to get cacheEnabled field from SparkCatalog", e); + } } } diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/plugin/GravitinoDriverPlugin.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/plugin/GravitinoDriverPlugin.java index 3f830de2cdc..3a80d7a6148 100644 --- a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/plugin/GravitinoDriverPlugin.java +++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/plugin/GravitinoDriverPlugin.java @@ -5,6 +5,9 @@ package com.datastrato.gravitino.spark.connector.plugin; +import static com.datastrato.gravitino.spark.connector.ConnectorConstants.COMMA; +import static com.datastrato.gravitino.spark.connector.utils.ConnectorUtil.removeDuplicateSparkExtensions; + import com.datastrato.gravitino.Catalog; import com.datastrato.gravitino.spark.connector.GravitinoSparkConfig; import com.datastrato.gravitino.spark.connector.catalog.GravitinoCatalogManager; @@ -15,10 +18,12 @@ import java.util.Locale; import java.util.Map; import org.apache.commons.lang3.StringUtils; +import org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions; import org.apache.spark.SparkConf; import org.apache.spark.SparkContext; import org.apache.spark.api.plugin.DriverPlugin; import org.apache.spark.api.plugin.PluginContext; +import org.apache.spark.sql.internal.StaticSQLConf; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -30,6 +35,8 @@ public class GravitinoDriverPlugin implements DriverPlugin { private static final Logger LOG = LoggerFactory.getLogger(GravitinoDriverPlugin.class); private GravitinoCatalogManager catalogManager; + private static final 
String[] GRAVITINO_DRIVER_EXTENSIONS = + new String[] {IcebergSparkSessionExtensions.class.getName()}; @Override public Map init(SparkContext sc, PluginContext pluginContext) { @@ -48,7 +55,7 @@ public Map init(SparkContext sc, PluginContext pluginContext) { catalogManager = GravitinoCatalogManager.create(gravitinoUri, metalake); catalogManager.loadRelationalCatalogs(); registerGravitinoCatalogs(conf, catalogManager.getCatalogs()); - registerSqlExtensions(); + registerSqlExtensions(conf); return Collections.emptyMap(); } @@ -103,6 +110,20 @@ private void registerCatalog(SparkConf sparkConf, String catalogName, String pro LOG.info("Register {} catalog to Spark catalog manager.", catalogName); } - // Todo inject Iceberg extensions - private void registerSqlExtensions() {} + private void registerSqlExtensions(SparkConf conf) { + String gravitinoDriverExtensions = String.join(COMMA, GRAVITINO_DRIVER_EXTENSIONS); + if (conf.contains(StaticSQLConf.SPARK_SESSION_EXTENSIONS().key())) { + String sparkSessionExtensions = conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS().key()); + if (StringUtils.isNotBlank(sparkSessionExtensions)) { + conf.set( + StaticSQLConf.SPARK_SESSION_EXTENSIONS().key(), + removeDuplicateSparkExtensions( + GRAVITINO_DRIVER_EXTENSIONS, sparkSessionExtensions.split(COMMA))); + } else { + conf.set(StaticSQLConf.SPARK_SESSION_EXTENSIONS().key(), gravitinoDriverExtensions); + } + } else { + conf.set(StaticSQLConf.SPARK_SESSION_EXTENSIONS().key(), gravitinoDriverExtensions); + } + } } diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/utils/ConnectorUtil.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/utils/ConnectorUtil.java new file mode 100644 index 00000000000..673d6cf0380 --- /dev/null +++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/utils/ConnectorUtil.java @@ -0,0 +1,27 @@ +/* + * Copyright 2024 Datastrato Pvt Ltd. + * This software is licensed under the Apache License version 2. 
+ */ + +package com.datastrato.gravitino.spark.connector.utils; + +import static com.datastrato.gravitino.spark.connector.ConnectorConstants.COMMA; + +import java.util.Arrays; +import java.util.LinkedHashSet; +import java.util.Set; +import org.apache.commons.lang3.StringUtils; + +public class ConnectorUtil { + + public static String removeDuplicateSparkExtensions( + String[] extensions, String[] addedExtensions) { + Set uniqueElements = new LinkedHashSet<>(Arrays.asList(extensions)); + if (addedExtensions != null && StringUtils.isNoneBlank(addedExtensions)) { + uniqueElements.addAll(Arrays.asList(addedExtensions)); + } + return uniqueElements.stream() + .reduce((element1, element2) -> element1 + COMMA + element2) + .orElse(""); + } +} diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/table/SparkBaseTable.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/utils/GravitinoTableInfoHelper.java similarity index 65% rename from spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/table/SparkBaseTable.java rename to spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/utils/GravitinoTableInfoHelper.java index d1333135f19..a1ab61021c4 100644 --- a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/table/SparkBaseTable.java +++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/utils/GravitinoTableInfoHelper.java @@ -3,7 +3,7 @@ * This software is licensed under the Apache License version 2. */ -package com.datastrato.gravitino.spark.connector.table; +package com.datastrato.gravitino.spark.connector.utils; import com.datastrato.gravitino.rel.expressions.distributions.Distribution; import com.datastrato.gravitino.rel.expressions.sorts.SortOrder; @@ -11,65 +11,49 @@ import com.datastrato.gravitino.spark.connector.PropertiesConverter; import com.datastrato.gravitino.spark.connector.SparkTransformConverter; import com.datastrato.gravitino.spark.connector.SparkTypeConverter; -import com.google.common.annotations.VisibleForTesting; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Set; import java.util.stream.Collectors; -import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; import org.apache.spark.sql.connector.catalog.Identifier; -import org.apache.spark.sql.connector.catalog.SupportsRead; -import org.apache.spark.sql.connector.catalog.SupportsWrite; -import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableCapability; -import org.apache.spark.sql.connector.catalog.TableCatalog; import org.apache.spark.sql.connector.expressions.Transform; -import org.apache.spark.sql.connector.read.ScanBuilder; -import org.apache.spark.sql.connector.write.LogicalWriteInfo; -import org.apache.spark.sql.connector.write.WriteBuilder; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.MetadataBuilder; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** - * Provides schema info from Gravitino, IO from the internal spark table. 
The specific catalog table - * could implement more capabilities like SupportsPartitionManagement for Hive table, SupportsIndex - * for JDBC table, SupportsRowLevelOperations for Iceberg table. + * GravitinoTableInfoHelper is a common helper class that is used to retrieve table info from the + * Gravitino Server */ -public abstract class SparkBaseTable implements Table, SupportsRead, SupportsWrite { +public class GravitinoTableInfoHelper { + + private boolean isCaseSensitive; private Identifier identifier; private com.datastrato.gravitino.rel.Table gravitinoTable; - private TableCatalog sparkCatalog; - private Table lazySparkTable; private PropertiesConverter propertiesConverter; private SparkTransformConverter sparkTransformConverter; - public SparkBaseTable( + public GravitinoTableInfoHelper( + boolean isCaseSensitive, Identifier identifier, com.datastrato.gravitino.rel.Table gravitinoTable, - TableCatalog sparkCatalog, PropertiesConverter propertiesConverter, SparkTransformConverter sparkTransformConverter) { + this.isCaseSensitive = isCaseSensitive; this.identifier = identifier; this.gravitinoTable = gravitinoTable; - this.sparkCatalog = sparkCatalog; this.propertiesConverter = propertiesConverter; this.sparkTransformConverter = sparkTransformConverter; } - @Override public String name() { return getNormalizedIdentifier(identifier, gravitinoTable.name()); } - @Override - @SuppressWarnings("deprecation") public StructType schema() { List structs = Arrays.stream(gravitinoTable.columns()) @@ -93,7 +77,6 @@ public StructType schema() { return DataTypes.createStructType(structs); } - @Override public Map properties() { Map properties = new HashMap(); if (gravitinoTable.properties() != null) { @@ -110,22 +93,6 @@ public Map properties() { return properties; } - @Override - public Set capabilities() { - return getSparkTable().capabilities(); - } - - @Override - public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { - return ((SupportsRead) getSparkTable()).newScanBuilder(options); - } - - @Override - public WriteBuilder newWriteBuilder(LogicalWriteInfo info) { - return ((SupportsWrite) getSparkTable()).newWriteBuilder(info); - } - - @Override public Transform[] partitioning() { com.datastrato.gravitino.rel.expressions.transforms.Transform[] partitions = gravitinoTable.partitioning(); @@ -134,26 +101,10 @@ public Transform[] partitioning() { return sparkTransformConverter.toSparkTransform(partitions, distribution, sortOrders); } - protected Table getSparkTable() { - if (lazySparkTable == null) { - try { - this.lazySparkTable = sparkCatalog.loadTable(identifier); - } catch (NoSuchTableException e) { - throw new RuntimeException(e); - } - } - return lazySparkTable; - } - - @VisibleForTesting public SparkTransformConverter getSparkTransformConverter() { return sparkTransformConverter; } - protected boolean isCaseSensitive() { - return true; - } - // The underlying catalogs may not case-sensitive, to keep consistent with the action of SparkSQL, // we should return normalized identifiers. 
private String getNormalizedIdentifier(Identifier tableIdentifier, String gravitinoTableName) { @@ -162,7 +113,7 @@ private String getNormalizedIdentifier(Identifier tableIdentifier, String gravit } String databaseName = tableIdentifier.namespace()[0]; - if (isCaseSensitive() == false) { + if (!isCaseSensitive) { databaseName = databaseName.toLowerCase(Locale.ROOT); } diff --git a/spark-connector/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/utils/TestConnectorUtil.java b/spark-connector/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/utils/TestConnectorUtil.java new file mode 100644 index 00000000000..4f1d73dd024 --- /dev/null +++ b/spark-connector/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/utils/TestConnectorUtil.java @@ -0,0 +1,39 @@ +/* + * Copyright 2024 Datastrato Pvt Ltd. + * This software is licensed under the Apache License version 2. + */ + +package com.datastrato.gravitino.spark.connector.utils; + +import static com.datastrato.gravitino.spark.connector.ConnectorConstants.COMMA; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public class TestConnectorUtil { + + @Test + void testRemoveDuplicateSparkExtensions() { + String[] extensions = {"a", "b", "c"}; + String addedExtensions = "a,d,e"; + String result = + ConnectorUtil.removeDuplicateSparkExtensions(extensions, addedExtensions.split(COMMA)); + Assertions.assertEquals("a,b,c,d,e", result); + + extensions = new String[] {"a", "a", "b", "c"}; + addedExtensions = ""; + result = ConnectorUtil.removeDuplicateSparkExtensions(extensions, addedExtensions.split(COMMA)); + Assertions.assertEquals("a,b,c", result); + + extensions = new String[] {"a", "a", "b", "c"}; + addedExtensions = "b"; + result = ConnectorUtil.removeDuplicateSparkExtensions(extensions, addedExtensions.split(COMMA)); + Assertions.assertEquals("a,b,c", result); + + extensions = new String[] {"a", "a", "b", "c"}; + result = ConnectorUtil.removeDuplicateSparkExtensions(extensions, null); + Assertions.assertEquals("a,b,c", result); + } +}
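As a quick reference for exercising the new row-level operations end to end, the sketch below mirrors what the new ITs and the updated `spark-catalog-iceberg.md` examples do. It is illustrative only: the catalog name `iceberg_catalog`, the schema `db`, and the table `employee_demo` are placeholders, and the table properties correspond to the `IcebergTableWriteProperties.of(false, 2, "merge-on-read")` variant covered by `SparkIcebergCatalogIT` (the copy-on-write/format-version 1 variant only differs in these property values).

```sql
-- Minimal sketch, assuming a Gravitino-managed Iceberg catalog registered as `iceberg_catalog`
-- and an existing schema `db`; run with the Gravitino Spark connector plugin enabled so that
-- IcebergSparkSessionExtensions is injected.
USE iceberg_catalog.db;

CREATE TABLE employee_demo (id INT, name STRING, age INT)
TBLPROPERTIES ('format-version'='2',
               'write.delete.mode'='merge-on-read',
               'write.update.mode'='merge-on-read',
               'write.merge.mode'='merge-on-read');

INSERT INTO employee_demo VALUES (1, '1', 1), (2, '2', 2), (3, '3', 3);

-- Row-level operations added by this patch.
UPDATE employee_demo SET name = '30', age = 30 WHERE id = 3;

DELETE FROM employee_demo WHERE id = 1;

MERGE INTO employee_demo t
USING (SELECT 4 AS id, '4' AS name, 4 AS age) s
ON t.id = s.id
WHEN MATCHED THEN UPDATE SET *
WHEN NOT MATCHED THEN INSERT *;
```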