Skip to content

Commit

Permalink
Quote table name with backticks to escape special characters in spark…
Browse files Browse the repository at this point in the history
….read.table() (#224)

* quote table name for special character in df read

Signed-off-by: Sean Kao <seankao@amazon.com>

* add UT

Signed-off-by: Sean Kao <seankao@amazon.com>

* update comments; scalafmtAll

Signed-off-by: Sean Kao <seankao@amazon.com>

* add a missed table name to quote

Signed-off-by: Sean Kao <seankao@amazon.com>

---------

Signed-off-by: Sean Kao <seankao@amazon.com>
  • Loading branch information
seankao-az authored Jan 12, 2024
1 parent 9e2475f commit 6e163e7
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import org.opensearch.flint.core.{FlintClient, FlintClientBuilder}
import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState._
import org.opensearch.flint.core.metadata.log.OptimisticTransaction.NO_LOG_ENTRY
import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL, RefreshMode}
import org.opensearch.flint.spark.FlintSparkIndex.{ID_COLUMN, StreamingRefresh}
import org.opensearch.flint.spark.FlintSparkIndex.{quotedTableName, ID_COLUMN, StreamingRefresh}
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
Expand Down Expand Up @@ -372,7 +372,7 @@ class FlintSpark(val spark: SparkSession) extends Logging {
logInfo("Start refreshing index in foreach streaming style")
val job = spark.readStream
.options(options.extraSourceOptions(tableName))
.table(tableName)
.table(quotedTableName(tableName))
.writeStream
.queryName(indexName)
.addSinkOptions(options)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,21 @@ object FlintSparkIndex {
s"flint_${parts(0)}_${parts(1)}_${parts.drop(2).mkString(".")}"
}

/**
* Add backticks to table name to escape special character
*
* @param fullTableName
* source full table name
* @return
* quoted table name
*/
def quotedTableName(fullTableName: String): String = {
require(fullTableName.split('.').length >= 3, s"Table name $fullTableName is not qualified")

val parts = fullTableName.split('.')
s"${parts(0)}.${parts(1)}.`${parts.drop(2).mkString(".")}`"
}

/**
* Populate environment variables to persist in Flint metadata.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter

import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.spark._
import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchemaJSON, metadataBuilder}
import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchemaJSON, metadataBuilder, quotedTableName}
import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.{getFlintIndexName, COVERING_INDEX_TYPE}

Expand Down Expand Up @@ -60,7 +60,7 @@ case class FlintSparkCoveringIndex(

override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = {
val colNames = indexedColumns.keys.toSeq
val job = df.getOrElse(spark.read.table(tableName))
val job = df.getOrElse(spark.read.table(quotedTableName(tableName)))

// Add optional filtering condition
filterCondition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ case class FlintSparkSkippingIndex(
new Column(aggFunc.as(name))
}

df.getOrElse(spark.read.table(tableName))
df.getOrElse(spark.read.table(quotedTableName(tableName)))
.groupBy(input_file_name().as(FILE_PATH_COLUMN))
.agg(namedAggFuncs.head, namedAggFuncs.tail: _*)
.withColumn(ID_COLUMN, sha1(col(FILE_PATH_COLUMN)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.flint.spark.covering

import org.scalatest.matchers.must.Matchers.contain
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

import org.apache.spark.FlintSuite
Expand All @@ -30,6 +31,24 @@ class FlintSparkCoveringIndexSuite extends FlintSuite {
}
}

test("can build index building job with unique ID column") {
val index =
new FlintSparkCoveringIndex("ci", "spark_catalog.default.test", Map("name" -> "string"))

val df = spark.createDataFrame(Seq(("hello", 20))).toDF("name", "age")
val indexDf = index.build(spark, Some(df))
indexDf.schema.fieldNames should contain only ("name")
}

test("can build index on table name with special characters") {
val testTableSpecial = "spark_catalog.default.test/2023/10"
val index = new FlintSparkCoveringIndex("ci", testTableSpecial, Map("name" -> "string"))

val df = spark.createDataFrame(Seq(("hello", 20))).toDF("name", "age")
val indexDf = index.build(spark, Some(df))
indexDf.schema.fieldNames should contain only ("name")
}

test("should fail if no indexed column given") {
assertThrows[IllegalArgumentException] {
new FlintSparkCoveringIndex("ci", "default.test", Map.empty)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,19 @@ class FlintSparkSkippingIndexSuite extends FlintSuite {
indexDf.schema.fieldNames should contain only ("name", FILE_PATH_COLUMN, ID_COLUMN)
}

test("can build index on table name with special characters") {
val testTableSpecial = "spark_catalog.default.test/2023/10"
val indexCol = mock[FlintSparkSkippingStrategy]
when(indexCol.outputSchema()).thenReturn(Map("name" -> "string"))
when(indexCol.getAggregators).thenReturn(
Seq(CollectSet(col("name").expr).toAggregateExpression()))
val index = new FlintSparkSkippingIndex(testTableSpecial, Seq(indexCol))

val df = spark.createDataFrame(Seq(("hello", 20))).toDF("name", "age")
val indexDf = index.build(spark, Some(df))
indexDf.schema.fieldNames should contain only ("name", FILE_PATH_COLUMN, ID_COLUMN)
}

// Test index build for different column type
Seq(
(
Expand Down

0 comments on commit 6e163e7

Please sign in to comment.