Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
9872b1b
Add string locate functionality
miland-db Apr 1, 2024
092b44a
Add tests for StringLocate
miland-db Apr 1, 2024
37e0796
Remove repeated code for getting collationId
miland-db Apr 1, 2024
4bd974c
Improve performance by directly calling methods for UTF8_BINARY_COLLA…
miland-db Apr 1, 2024
84da295
Improve collatedIndexOf
miland-db Apr 1, 2024
88c7338
Improve naming for collation aware methods
miland-db Apr 2, 2024
851e480
Improve java style
miland-db Apr 3, 2024
a35d95e
Merge branch 'master' into string-locate
miland-db Apr 3, 2024
c5ecebf
Improve collationAwareIndexOf, and tests style
miland-db Apr 3, 2024
4e053a5
Improve naming and add doc comment
miland-db Apr 3, 2024
875d0ec
Improve doc comments
miland-db Apr 3, 2024
9771f02
Merge branch 'master' into string-locate
miland-db Apr 4, 2024
a581557
Merge latest master and add StringLocate to CollationTypeCasts transf…
miland-db Apr 4, 2024
c956a92
Add empty lines between imports
miland-db Apr 4, 2024
3ae8c31
Handle all collationIds in getStringSearch
miland-db Apr 4, 2024
b93bcc5
Merge branch 'master' into string-locate
miland-db Apr 12, 2024
5e76c85
Add StringLocate and fix test errors
miland-db Apr 12, 2024
9883254
Break line at 100 chars
miland-db Apr 13, 2024
277e9c6
Add StringLocate to CollationTypeCasts
miland-db Apr 15, 2024
39e3be0
Improve ordering of arguments
miland-db Apr 15, 2024
3ea9b5b
Refactor tests
miland-db Apr 15, 2024
7336897
Merge branch 'master' into string-locate
miland-db Apr 15, 2024
a44c3d2
Break line at 100 chars
miland-db Apr 16, 2024
cab932d
Merge branch 'master' into string-locate
miland-db Apr 17, 2024
fd06d8b
Merge branch 'master' into string-locate
miland-db Apr 17, 2024
782b539
Add new test cases with variable length characters
miland-db Apr 23, 2024
827ff99
Merge branch 'master' into string-locate
miland-db Apr 23, 2024
600da6a
Update expected result for Case-variable character length test
miland-db Apr 24, 2024
2e7e9c3
Merge branch 'master' into string-locate
miland-db Apr 24, 2024
b4292ef
Merge branch 'master' into string-locate
miland-db Apr 25, 2024
9eb7fd9
Merge branch 'master' into string-locate
miland-db Apr 26, 2024
dcfd645
Merge branch 'master' into string-locate
miland-db Apr 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,44 @@ public static UTF8String execICU(final UTF8String src, final UTF8String search,
}
}

/**
 * Collation-aware lookup of the position of {@code substring} inside {@code string}.
 *
 * The entry points pick one of three implementations based on the flags of the collation
 * resolved from {@code collationId}: binary equality, lowercase equality, or (for every other
 * collation) the ICU-backed search. The {@code start} argument and the returned index are
 * passed to and taken from the underlying {@code indexOf} calls unchanged; the callers visible
 * in this change pass {@code start - 1} and add 1 to the result.
 */
public static class StringLocate {

  /**
   * Resolves the collation and runs the matching {@code exec*} implementation.
   *
   * @param string the string to search in
   * @param substring the string to search for
   * @param start the index handed to the underlying {@code indexOf} as the search start
   * @param collationId id used to fetch the collation from {@link CollationFactory}
   * @return the index reported by the selected implementation
   */
  public static int exec(final UTF8String string, final UTF8String substring, final int start,
      final int collationId) {
    final CollationFactory.Collation resolved = CollationFactory.fetchCollation(collationId);
    if (resolved.supportsBinaryEquality) {
      return execBinary(string, substring, start);
    }
    if (resolved.supportsLowercaseEquality) {
      return execLowercase(string, substring, start);
    }
    return execICU(string, substring, start, collationId);
  }

  /**
   * Builds the source text of a call to the {@code exec*} method that matches the collation.
   *
   * @param string source text of the expression producing the string to search in
   * @param substring source text of the expression producing the string to search for
   * @param start search start, formatted into the generated call as an integer literal
   * @param collationId id used to fetch the collation; also emitted as the last argument of
   *                    the ICU variant
   * @return a Java expression string of the form
   *         {@code CollationSupport.StringLocate.exec<Variant>(...)}
   */
  public static String genCode(final String string, final String substring, final int start,
      final int collationId) {
    final CollationFactory.Collation resolved = CollationFactory.fetchCollation(collationId);
    final String callPrefix = "CollationSupport.StringLocate.exec";
    if (resolved.supportsBinaryEquality) {
      final String binaryTemplate = callPrefix + "Binary(%s, %s, %d)";
      return String.format(binaryTemplate, string, substring, start);
    }
    if (resolved.supportsLowercaseEquality) {
      final String lowercaseTemplate = callPrefix + "Lowercase(%s, %s, %d)";
      return String.format(lowercaseTemplate, string, substring, start);
    }
    final String icuTemplate = callPrefix + "ICU(%s, %s, %d, %d)";
    return String.format(icuTemplate, string, substring, start, collationId);
  }

  /** Binary-equality variant: delegates directly to {@code UTF8String.indexOf}. */
  public static int execBinary(final UTF8String string, final UTF8String substring,
      final int start) {
    return string.indexOf(substring, start);
  }

  /**
   * Lowercase-equality variant: lowercases both operands and searches the lowercased copies.
   *
   * NOTE(review): both {@code start} and the returned index refer to the lowercased string;
   * presumably they can differ from positions in the original string when lowercasing changes
   * the number of characters - confirm.
   */
  public static int execLowercase(final UTF8String string, final UTF8String substring,
      final int start) {
    final UTF8String loweredString = string.toLowerCase();
    final UTF8String loweredSubstring = substring.toLowerCase();
    return loweredString.indexOf(loweredSubstring, start);
  }

  /** ICU variant: delegates to {@code CollationAwareUTF8String.indexOf} with the collation id. */
  public static int execICU(final UTF8String string, final UTF8String substring, final int start,
      final int collationId) {
    return CollationAwareUTF8String.indexOf(string, substring, start, collationId);
  }
}

// TODO: Add more collation-aware string expressions.

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,71 @@ public void testReplace() throws SparkException {
assertReplace("abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy");
}

/**
 * Checks the 1-based position that StringLocate reports for {@code substring} inside
 * {@code string} under the named collation.
 *
 * @param substring the text to search for
 * @param string the text to search in
 * @param start 1-based search start; converted to 0-based before calling {@code exec}
 * @param collationName collation name resolved to an id through {@link CollationFactory}
 * @param expected expected 1-based position (the value returned by {@code exec} plus one)
 */
private void assertLocate(String substring, String string, Integer start, String collationName,
    Integer expected) throws SparkException {
  UTF8String needle = UTF8String.fromString(substring);
  UTF8String haystack = UTF8String.fromString(string);
  int id = CollationFactory.collationNameToId(collationName);
  // exec works with 0-based indices, while the test table is written with 1-based ones.
  int zeroBasedStart = start - 1;
  int oneBasedResult =
    CollationSupport.StringLocate.exec(haystack, needle, zeroBasedStart, id) + 1;
  assertEquals(expected, oneBasedResult);
}

/**
 * Table of StringLocate expectations, grouped by collation. Each call passes
 * (substring, string, 1-based start, collation name, expected 1-based position); assertLocate
 * adds 1 to the value returned by exec, so an expected value of 0 corresponds to exec
 * returning -1.
 */
@Test
public void testLocate() throws SparkException {
// If you add tests with start < 1 be careful to understand the behavior of the indexOf method
// and usage of indexOf in the StringLocate class.
// UTF8_BINARY: the expectations below are case-sensitive.
assertLocate("aa", "aaads", 1, "UTF8_BINARY", 1);
assertLocate("aa", "aaads", 2, "UTF8_BINARY", 2);
assertLocate("aa", "aaads", 3, "UTF8_BINARY", 0);
assertLocate("Aa", "aaads", 1, "UTF8_BINARY", 0);
assertLocate("Aa", "aAads", 1, "UTF8_BINARY", 2);
assertLocate("界x", "test大千世界X大千世界", 1, "UTF8_BINARY", 0);
assertLocate("界X", "test大千世界X大千世界", 1, "UTF8_BINARY", 8);
assertLocate("界", "test大千世界X大千世界", 13, "UTF8_BINARY", 13);
// UTF8_BINARY_LCASE: the expectations below ignore letter case.
assertLocate("AA", "aaads", 1, "UTF8_BINARY_LCASE", 1);
assertLocate("aa", "aAads", 2, "UTF8_BINARY_LCASE", 2);
assertLocate("aa", "aaAds", 3, "UTF8_BINARY_LCASE", 0);
assertLocate("abC", "abcabc", 1, "UTF8_BINARY_LCASE", 1);
assertLocate("abC", "abCabc", 2, "UTF8_BINARY_LCASE", 4);
assertLocate("abc", "abcabc", 4, "UTF8_BINARY_LCASE", 4);
assertLocate("界x", "test大千世界X大千世界", 1, "UTF8_BINARY_LCASE", 8);
assertLocate("界X", "test大千世界Xtest大千世界", 1, "UTF8_BINARY_LCASE", 8);
assertLocate("界", "test大千世界X大千世界", 13, "UTF8_BINARY_LCASE", 13);
assertLocate("大千", "test大千世界大千世界", 1, "UTF8_BINARY_LCASE", 5);
assertLocate("大千", "test大千世界大千世界", 9, "UTF8_BINARY_LCASE", 9);
assertLocate("大千", "大千世界大千世界", 1, "UTF8_BINARY_LCASE", 1);
// UNICODE: the expectations below are case-sensitive.
assertLocate("aa", "Aaads", 1, "UNICODE", 2);
assertLocate("AA", "aaads", 1, "UNICODE", 0);
assertLocate("aa", "aAads", 2, "UNICODE", 0);
assertLocate("aa", "aaAds", 3, "UNICODE", 0);
assertLocate("abC", "abcabc", 1, "UNICODE", 0);
assertLocate("abC", "abCabc", 2, "UNICODE", 0);
assertLocate("abC", "abCabC", 2, "UNICODE", 4);
assertLocate("abc", "abcabc", 1, "UNICODE", 1);
assertLocate("abc", "abcabc", 3, "UNICODE", 4);
assertLocate("界x", "test大千世界X大千世界", 1, "UNICODE", 0);
assertLocate("界X", "test大千世界X大千世界", 1, "UNICODE", 8);
assertLocate("界", "test大千世界X大千世界", 13, "UNICODE", 13);
// UNICODE_CI: the expectations below ignore letter case.
assertLocate("AA", "aaads", 1, "UNICODE_CI", 1);
assertLocate("aa", "aAads", 2, "UNICODE_CI", 2);
assertLocate("aa", "aaAds", 3, "UNICODE_CI", 0);
assertLocate("abC", "abcabc", 1, "UNICODE_CI", 1);
assertLocate("abC", "abCabc", 2, "UNICODE_CI", 4);
assertLocate("abc", "abcabc", 4, "UNICODE_CI", 4);
assertLocate("界x", "test大千世界X大千世界", 1, "UNICODE_CI", 8);
assertLocate("界", "test大千世界X大千世界", 13, "UNICODE_CI", 13);
assertLocate("大千", "test大千世界大千世界", 1, "UNICODE_CI", 5);
assertLocate("大千", "test大千世界大千世界", 9, "UNICODE_CI", 9);
assertLocate("大千", "大千世界大千世界", 1, "UNICODE_CI", 1);
// Case-variable character length
// The first three cases search for the lowercase form inside strings holding the uppercase
// form; the last three do the reverse.
assertLocate("i̇o", "İo世界大千世界", 1, "UNICODE_CI", 1);
assertLocate("i̇o", "大千İo世界大千世界", 1, "UNICODE_CI", 3);
assertLocate("i̇o", "世界İo大千世界大千İo", 4, "UNICODE_CI", 11);
assertLocate("İo", "i̇o世界大千世界", 1, "UNICODE_CI", 1);
assertLocate("İo", "大千i̇o世界大千世界", 1, "UNICODE_CI", 3);
assertLocate("İo", "世界i̇o大千世界大千i̇o", 4, "UNICODE_CI", 12); // 12 instead of 11
}

// TODO: Test more collation-aware string expressions.

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ object CollationTypeCasts extends TypeCoercionRule {
caseWhenExpr.elseValue.map(e => castStringType(e, outputStringType).getOrElse(e))
CaseWhen(newBranches, newElseValue)

case stringLocate: StringLocate =>
stringLocate.withNewChildren(collateToSingleType(
Seq(stringLocate.first, stringLocate.second)) :+ stringLocate.third)

case eltExpr: Elt =>
eltExpr.withNewChildren(eltExpr.children.head +: collateToSingleType(eltExpr.children.tail))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1457,12 +1457,15 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
this(substr, str, Literal(1))
}

// Collation id of this expression, read from the data type of `first` (the `substr` child).
// The cast assumes that data type is a StringType - presumably ensured by `inputTypes`
// during analysis; confirm.
final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId

override def first: Expression = substr
override def second: Expression = str
override def third: Expression = start
override def nullable: Boolean = substr.nullable || str.nullable
override def dataType: DataType = IntegerType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType)

override def eval(input: InternalRow): Any = {
val s = start.eval(input)
Expand All @@ -1482,9 +1485,8 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
if (sVal < 1) {
0
} else {
l.asInstanceOf[UTF8String].indexOf(
r.asInstanceOf[UTF8String],
s.asInstanceOf[Int] - 1) + 1
CollationSupport.StringLocate.exec(l.asInstanceOf[UTF8String],
r.asInstanceOf[UTF8String], s.asInstanceOf[Int] - 1, collationId) + 1;
}
}
}
Expand All @@ -1505,8 +1507,8 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
${strGen.code}
if (!${strGen.isNull}) {
if (${startGen.value} > 0) {
${ev.value} = ${strGen.value}.indexOf(${substrGen.value},
${startGen.value} - 1) + 1;
${ev.value} = CollationSupport.StringLocate.exec(${strGen.value},
${substrGen.value}, ${startGen.value} - 1, $collationId) + 1;
}
} else {
${ev.isNull} = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,40 @@ class CollationStringExpressionsSuite
assert(sql(query).schema.fields.head.dataType.sameType(StringType(0)))
}

test("Support Locate string expression with collation") {
  // One locate() scenario: look for `substring` in `string` from `start` under collation `c`
  // and expect `result`.
  case class StringLocateTestCase[R](substring: String, string: String, start: Integer,
    c: String, result: R)

  // SQL text of a string literal wrapped in an explicit collate(...) call.
  def collated(literal: String, collation: String): String = s"collate('$literal','$collation')"
  // SQL text of a plain string literal, which takes the session's default collation.
  def plain(literal: String): String = s"'$literal'"
  // SQL text of the full locate query for already rendered arguments.
  def locateQuery(substrArg: String, strArg: String, start: Integer): String =
    s"SELECT locate($substrArg,$strArg,$start)"

  val scenarios = Seq(
    // scalastyle:off
    StringLocateTestCase("aa", "aaads", 0, "UTF8_BINARY", 0),
    StringLocateTestCase("aa", "Aaads", 0, "UTF8_BINARY_LCASE", 0),
    StringLocateTestCase("界x", "test大千世界X大千世界", 1, "UTF8_BINARY_LCASE", 8),
    StringLocateTestCase("aBc", "abcabc", 4, "UTF8_BINARY_LCASE", 4),
    StringLocateTestCase("aa", "Aaads", 0, "UNICODE", 0),
    StringLocateTestCase("abC", "abCabC", 2, "UNICODE", 4),
    StringLocateTestCase("aa", "Aaads", 0, "UNICODE_CI", 0),
    StringLocateTestCase("界x", "test大千世界X大千世界", 1, "UNICODE_CI", 8)
    // scalastyle:on
  )

  for (scenario <- scenarios) {
    val expectedRow = Row(scenario.result)

    // Both arguments explicitly collated: check the result and the result data type.
    val explicitQuery = locateQuery(
      collated(scenario.substring, scenario.c),
      collated(scenario.string, scenario.c),
      scenario.start)
    checkAnswer(sql(explicitQuery), expectedRow)
    assert(sql(explicitQuery).schema.fields.head.dataType.sameType(IntegerType))

    // Only one argument explicitly collated: the other one is cast implicitly.
    val substrCollatedQuery = locateQuery(
      collated(scenario.substring, scenario.c),
      plain(scenario.string),
      scenario.start)
    checkAnswer(sql(substrCollatedQuery), expectedRow)

    val strCollatedQuery = locateQuery(
      plain(scenario.substring),
      collated(scenario.string, scenario.c),
      scenario.start)
    checkAnswer(sql(strCollatedQuery), expectedRow)
  }

  // Two different explicit collations must be rejected during analysis.
  val mismatchError = intercept[AnalysisException] {
    sql("SELECT locate(collate('aBc', 'UTF8_BINARY'),collate('abcabc', 'UTF8_BINARY_LCASE'),4)")
  }
  assert(mismatchError.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
}

// TODO: Add more tests for other string expressions

}
Expand Down