Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit. Hold shift + click to select a range.
3140ebf
Fix escaping issue for mysql
mihailomilosevic2001 May 7, 2024
1fda363
Remove unrelated changes
mihailomilosevic2001 May 7, 2024
81b6459
Remove unrelated changes
mihailomilosevic2001 May 7, 2024
339e8a9
Fix tests
mihailomilosevic2001 May 8, 2024
b6ebebe
Remove unused imports
mihailomilosevic2001 May 8, 2024
4b7e3f5
Fix ' escaping
mihailomilosevic2001 May 8, 2024
e117da9
Add MySQL tests
mihailomilosevic2001 May 8, 2024
dfbc210
Fix line length
mihailomilosevic2001 May 8, 2024
9062037
Add tests for expression pushdown to different JDBCs
mihailomilosevic2001 May 9, 2024
d88fcbb
Move ' escaping to LiteralValue
mihailomilosevic2001 May 10, 2024
dc3c91b
Remove unused imports
mihailomilosevic2001 May 10, 2024
4be5786
Move tests to use v2 push down
mihailomilosevic2001 May 13, 2024
7f70b67
Revert unnecessary changes
mihailomilosevic2001 May 13, 2024
13f6a81
Fix test
mihailomilosevic2001 May 13, 2024
9d07698
Fix closing bracket
mihailomilosevic2001 May 13, 2024
b3e32f5
conn -> connection
mihailomilosevic2001 May 13, 2024
45fcb60
Rename variable
mihailomilosevic2001 May 14, 2024
3310d09
Merge remote-tracking branch 'upstream/master' into SPARK-48172
mihailomilosevic2001 May 15, 2024
4d485ee
Fix tests
mihailomilosevic2001 May 15, 2024
36c6f02
Fix spark escape of '
mihailomilosevic2001 May 15, 2024
afbeaae
Fix test
mihailomilosevic2001 May 15, 2024
950f9f0
Fix tests to include namespace
mihailomilosevic2001 May 15, 2024
1c20c91
Fix caseConvert in tests
mihailomilosevic2001 May 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
connection.prepareStatement(
"CREATE TABLE employee (dept INTEGER, name VARCHAR(10), salary DECIMAL(20, 2), bonus DOUBLE)")
.executeUpdate()
connection.prepareStatement(
s"""CREATE TABLE pattern_testing_table (
|pattern_testing_col VARCHAR(50)
|)
""".stripMargin
).executeUpdate()
}

override def testUpdateColumnType(tbl: String): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ abstract class DockerJDBCIntegrationV2Suite extends DockerJDBCIntegrationSuite {
.executeUpdate()
connection.prepareStatement("INSERT INTO employee VALUES (6, 'jen', 12000, 1200)")
.executeUpdate()

connection.prepareStatement(
s"""
|INSERT INTO pattern_testing_table VALUES
|('special_character_quote''_present'),
|('special_character_quote_not_present'),
|('special_character_percent%_present'),
|('special_character_percent_not_present'),
|('special_character_underscore_present'),
|('special_character_underscorenot_present')
""".stripMargin).executeUpdate()
}

def tablePreparation(connection: Connection): Unit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
connection.prepareStatement(
"CREATE TABLE employee (dept INT, name VARCHAR(32), salary NUMERIC(20, 2), bonus FLOAT)")
.executeUpdate()
connection.prepareStatement(
s"""CREATE TABLE pattern_testing_table (
|pattern_testing_col VARCHAR(50)
|)
""".stripMargin
).executeUpdate()
}

override def notSupportsTableComment: Boolean = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
connection.prepareStatement(
"CREATE TABLE employee (dept INT, name VARCHAR(32), salary DECIMAL(20, 2)," +
" bonus DOUBLE)").executeUpdate()
connection.prepareStatement(
s"""CREATE TABLE pattern_testing_table (
|pattern_testing_col LONGTEXT
|)
""".stripMargin
).executeUpdate()
}

override def testUpdateColumnType(tbl: String): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes
connection.prepareStatement(
"CREATE TABLE employee (dept NUMBER(32), name VARCHAR2(32), salary NUMBER(20, 2)," +
" bonus BINARY_DOUBLE)").executeUpdate()
connection.prepareStatement(
s"""CREATE TABLE pattern_testing_table (
|pattern_testing_col VARCHAR(50)
|)
""".stripMargin
).executeUpdate()
}

override def testUpdateColumnType(tbl: String): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
connection.prepareStatement(
"CREATE TABLE employee (dept INTEGER, name VARCHAR(32), salary NUMERIC(20, 2)," +
" bonus double precision)").executeUpdate()
connection.prepareStatement(
s"""CREATE TABLE pattern_testing_table (
|pattern_testing_col VARCHAR(50)
|)
""".stripMargin
).executeUpdate()
}

override def testUpdateColumnType(tbl: String): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,235 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
assert(scan.schema.names.sameElements(Seq(col)))
}

test("SPARK-48172: Test CONTAINS") {
  // Verifies contains() predicates against values holding quote, percent and
  // underscore special characters. The quote in the Spark SQL literal is
  // escaped as \\' so the single quote reaches the pushed-down filter intact.
  // NOTE: the leftover debug call df1.explain("formatted") was removed — no
  // sibling test prints a plan, and it only wrote noise to stdout.
  val df1 = spark.sql(
    s"""
       |SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE contains(pattern_testing_col, 'quote\\'')""".stripMargin)
  val rows1 = df1.collect()
  assert(rows1.length === 1)
  assert(rows1(0).getString(0) === "special_character_quote'_present")

  // Percent is a LIKE wildcard on the remote engine; it must match literally.
  val df2 = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE contains(pattern_testing_col, 'percent%')""".stripMargin)
  val rows2 = df2.collect()
  assert(rows2.length === 1)
  assert(rows2(0).getString(0) === "special_character_percent%_present")

  // Underscore is a single-character LIKE wildcard; it must match literally.
  val df3 = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE contains(pattern_testing_col, 'underscore_')""".stripMargin)
  val rows3 = df3.collect()
  assert(rows3.length === 1)
  assert(rows3(0).getString(0) === "special_character_underscore_present")

  // A plain substring present in every row: all six rows come back, ordered.
  val df4 = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE contains(pattern_testing_col, 'character')
       |ORDER BY pattern_testing_col""".stripMargin)
  val rows4 = df4.collect()
  assert(rows4.length === 6)
  assert(rows4(0).getString(0) === "special_character_percent%_present")
  assert(rows4(1).getString(0) === "special_character_percent_not_present")
  assert(rows4(2).getString(0) === "special_character_quote'_present")
  assert(rows4(3).getString(0) === "special_character_quote_not_present")
  assert(rows4(4).getString(0) === "special_character_underscore_present")
  assert(rows4(5).getString(0) === "special_character_underscorenot_present")
}

test("SPARK-48172: Test ENDSWITH") {
  // Verifies endswith() predicates with suffixes exercising the quote,
  // percent and underscore special characters. Each query's result column
  // is collected and compared against the expected value sequence.
  val quoteSuffix = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE endswith(pattern_testing_col, 'quote\\'_present')""".stripMargin)
  assert(quoteSuffix.collect().map(_.getString(0)).toSeq ===
    Seq("special_character_quote'_present"))

  val percentSuffix = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE endswith(pattern_testing_col, 'percent%_present')""".stripMargin)
  assert(percentSuffix.collect().map(_.getString(0)).toSeq ===
    Seq("special_character_percent%_present"))

  val underscoreSuffix = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE endswith(pattern_testing_col, 'underscore_present')""".stripMargin)
  assert(underscoreSuffix.collect().map(_.getString(0)).toSeq ===
    Seq("special_character_underscore_present"))

  // Shared suffix: every row matches; ORDER BY makes the result deterministic.
  val commonSuffix = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE endswith(pattern_testing_col, 'present')
       |ORDER BY pattern_testing_col""".stripMargin)
  assert(commonSuffix.collect().map(_.getString(0)).toSeq === Seq(
    "special_character_percent%_present",
    "special_character_percent_not_present",
    "special_character_quote'_present",
    "special_character_quote_not_present",
    "special_character_underscore_present",
    "special_character_underscorenot_present"))
}

test("SPARK-48172: Test STARTSWITH") {
  // Verifies startswith() predicates with prefixes exercising the quote,
  // percent and underscore special characters. Each query's result column
  // is collected and compared against the expected value sequence.
  val quotePrefix = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE startswith(pattern_testing_col, 'special_character_quote\\'')""".stripMargin)
  assert(quotePrefix.collect().map(_.getString(0)).toSeq ===
    Seq("special_character_quote'_present"))

  val percentPrefix = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE startswith(pattern_testing_col, 'special_character_percent%')""".stripMargin)
  assert(percentPrefix.collect().map(_.getString(0)).toSeq ===
    Seq("special_character_percent%_present"))

  val underscorePrefix = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE startswith(pattern_testing_col, 'special_character_underscore_')""".stripMargin)
  assert(underscorePrefix.collect().map(_.getString(0)).toSeq ===
    Seq("special_character_underscore_present"))

  // Shared prefix: every row matches; ORDER BY makes the result deterministic.
  val commonPrefix = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE startswith(pattern_testing_col, 'special_character')
       |ORDER BY pattern_testing_col""".stripMargin)
  assert(commonPrefix.collect().map(_.getString(0)).toSeq === Seq(
    "special_character_percent%_present",
    "special_character_percent_not_present",
    "special_character_quote'_present",
    "special_character_quote_not_present",
    "special_character_underscore_present",
    "special_character_underscorenot_present"))
}

test("SPARK-48172: Test LIKE") {
  // LIKE '%x%' (wildcard on both sides) should be pushed down as contains.
  val df1 = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE pattern_testing_col LIKE '%quote\\'%'""".stripMargin)
  val rows1 = df1.collect()
  assert(rows1.length === 1)
  assert(rows1(0).getString(0) === "special_character_quote'_present")

  // The escaped percent (\\%) must match a literal '%', not act as wildcard.
  val df2 = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE pattern_testing_col LIKE '%percent\\%%'""".stripMargin)
  val rows2 = df2.collect()
  assert(rows2.length === 1)
  assert(rows2(0).getString(0) === "special_character_percent%_present")

  // The escaped underscore (\\_) must match a literal '_', not any character.
  val df3 = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE pattern_testing_col LIKE '%underscore\\_%'""".stripMargin)
  val rows3 = df3.collect()
  assert(rows3.length === 1)
  assert(rows3(0).getString(0) === "special_character_underscore_present")

  // Unescaped substring present in every row: all six rows, ordered.
  val df4 = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE pattern_testing_col LIKE '%character%'
       |ORDER BY pattern_testing_col""".stripMargin)
  val rows4 = df4.collect()
  assert(rows4.length === 6)
  assert(rows4(0).getString(0) === "special_character_percent%_present")
  assert(rows4(1).getString(0) === "special_character_percent_not_present")
  assert(rows4(2).getString(0) === "special_character_quote'_present")
  assert(rows4(3).getString(0) === "special_character_quote_not_present")
  assert(rows4(4).getString(0) === "special_character_underscore_present")
  assert(rows4(5).getString(0) === "special_character_underscorenot_present")

  // LIKE 'prefix%' (trailing wildcard only) should be pushed down as startsWith.
  val df5 = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE pattern_testing_col LIKE 'special_character_quote\\'%'""".stripMargin)
  val rows5 = df5.collect()
  assert(rows5.length === 1)
  assert(rows5(0).getString(0) === "special_character_quote'_present")

  val df6 = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE pattern_testing_col LIKE 'special_character_percent\\%%'""".stripMargin)
  val rows6 = df6.collect()
  assert(rows6.length === 1)
  assert(rows6(0).getString(0) === "special_character_percent%_present")

  val df7 = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE pattern_testing_col LIKE 'special_character_underscore\\_%'""".stripMargin)
  val rows7 = df7.collect()
  assert(rows7.length === 1)
  assert(rows7(0).getString(0) === "special_character_underscore_present")

  val df8 = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE pattern_testing_col LIKE 'special_character%'
       |ORDER BY pattern_testing_col""".stripMargin)
  val rows8 = df8.collect()
  assert(rows8.length === 6)
  assert(rows8(0).getString(0) === "special_character_percent%_present")
  assert(rows8(1).getString(0) === "special_character_percent_not_present")
  assert(rows8(2).getString(0) === "special_character_quote'_present")
  assert(rows8(3).getString(0) === "special_character_quote_not_present")
  assert(rows8(4).getString(0) === "special_character_underscore_present")
  assert(rows8(5).getString(0) === "special_character_underscorenot_present")

  // LIKE '%suffix' (leading wildcard only) should be pushed down as endsWith.
  val df9 = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE pattern_testing_col LIKE '%quote\\'_present'""".stripMargin)
  val rows9 = df9.collect()
  assert(rows9.length === 1)
  assert(rows9(0).getString(0) === "special_character_quote'_present")

  val df10 = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE pattern_testing_col LIKE '%percent\\%_present'""".stripMargin)
  val rows10 = df10.collect()
  assert(rows10.length === 1)
  assert(rows10(0).getString(0) === "special_character_percent%_present")

  val df11 = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE pattern_testing_col LIKE '%underscore\\_present'""".stripMargin)
  val rows11 = df11.collect()
  assert(rows11.length === 1)
  assert(rows11(0).getString(0) === "special_character_underscore_present")

  val df12 = spark.sql(
    s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")}
       |WHERE pattern_testing_col LIKE '%present' ORDER BY pattern_testing_col""".stripMargin)
  val rows12 = df12.collect()
  assert(rows12.length === 6)
  assert(rows12(0).getString(0) === "special_character_percent%_present")
  assert(rows12(1).getString(0) === "special_character_percent_not_present")
  assert(rows12(2).getString(0) === "special_character_quote'_present")
  assert(rows12(3).getString(0) === "special_character_quote_not_present")
  assert(rows12(4).getString(0) === "special_character_underscore_present")
  assert(rows12(5).getString(0) === "special_character_underscorenot_present")
}

test("SPARK-37038: Test TABLESAMPLE") {
if (supportsTableSample) {
withTable(s"$catalogName.new_table") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ protected String escapeSpecialCharsForLikePattern(String str) {
switch (c) {
case '_' -> builder.append("\\_");
case '%' -> builder.append("\\%");
case '\'' -> builder.append("\\\'");
default -> builder.append(c);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.connector.expressions

import org.apache.commons.lang3.StringUtils

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
Expand Down Expand Up @@ -388,7 +390,7 @@ private[sql] object HoursTransform {
private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] {
override def toString: String = {
if (dataType.isInstanceOf[StringType]) {
s"'$value'"
s"'${StringUtils.replace(s"$value", "'", "''")}'"
} else {
s"$value"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,13 +259,6 @@ private[sql] case class H2Dialect() extends JdbcDialect {
}

class H2SQLBuilder extends JDBCSQLBuilder {
override def escapeSpecialCharsForLikePattern(str: String): String = {
str.map {
case '_' => "\\_"
case '%' => "\\%"
case c => c.toString
}.mkString
}

override def visitAggregateFunction(
funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,21 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper {
}
}

override def visitStartsWith(l: String, r: String): String = {
  // Builds a MySQL LIKE 'value%' predicate for a starts-with pushdown.
  // r is assumed to arrive as a quoted SQL literal; strip the first and last
  // characters (the surrounding quotes) to get the raw search value.
  val value = r.substring(1, r.length() - 1)
  // escapeSpecialCharsForLikePattern backslash-escapes the LIKE wildcards
  // '_' and '%' in the value. The Scala literal '\\\\' emits \\ into the
  // generated SQL; NOTE(review): MySQL's string parsing presumably reduces
  // that to a single backslash so ESCAPE matches the escape prefix — confirm.
  s"$l LIKE '${escapeSpecialCharsForLikePattern(value)}%' ESCAPE '\\\\'"
}

override def visitEndsWith(l: String, r: String): String = {
  // Builds a MySQL LIKE '%value' predicate for an ends-with pushdown.
  // r is assumed to be a quoted SQL literal; the surrounding quotes are
  // stripped and LIKE wildcards in the value are backslash-escaped.
  val value = r.substring(1, r.length() - 1)
  // ESCAPE '\\\\' renders as \\ in the SQL text; NOTE(review): relies on
  // MySQL collapsing it to one backslash — confirm against server settings
  // (e.g. NO_BACKSLASH_ESCAPES changes this behavior).
  s"$l LIKE '%${escapeSpecialCharsForLikePattern(value)}' ESCAPE '\\\\'"
}

override def visitContains(l: String, r: String): String = {
  // Builds a MySQL LIKE '%value%' predicate for a contains pushdown.
  // r is assumed to be a quoted SQL literal; the surrounding quotes are
  // stripped and LIKE wildcards in the value are backslash-escaped.
  val value = r.substring(1, r.length() - 1)
  // ESCAPE '\\\\' renders as \\ in the SQL text; NOTE(review): relies on
  // MySQL collapsing it to one backslash — confirm against server settings.
  s"$l LIKE '%${escapeSpecialCharsForLikePattern(value)}%' ESCAPE '\\\\'"
}

override def visitAggregateFunction(
funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) {
Expand Down
Loading