Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
f68c929
Collation support for: contains, startswith, endswith
uros-db Feb 22, 2024
d6e4c7e
Remove unused import
uros-db Feb 22, 2024
e566fdc
Remove unused test cases
uros-db Feb 22, 2024
593abe4
Remove unused test cases
uros-db Feb 22, 2024
eeacec1
Move collation support to UTF8String class
uros-db Feb 22, 2024
a9f8932
Fix import formatting
uros-db Feb 22, 2024
865d0cc
Update error handling
uros-db Feb 26, 2024
f37c73b
New error classes
uros-db Feb 27, 2024
5c462a7
Refactoring function names
uros-db Feb 27, 2024
eb2bfe8
Fix scalastyle
uros-db Feb 28, 2024
bdaf100
Fix scalastyle
uros-db Feb 28, 2024
1bb302f
Refactoring function names
uros-db Feb 28, 2024
4f39b15
Regenerate golden files
uros-db Feb 28, 2024
0b735a2
Merge branch 'master' into string-functions
uros-db Mar 1, 2024
bd761bd
Avoid unneeded changes
uros-db Mar 1, 2024
5196a2e
Fix function name in error
uros-db Mar 1, 2024
4bf7c9a
Merge branch 'apache:master' into string-functions
uros-db Mar 1, 2024
4a36b17
Check collationIds only once in analysis
uros-db Mar 1, 2024
7ccf429
Move tests in suite
uros-db Mar 1, 2024
6f2b100
Update comments
uros-db Mar 1, 2024
68e5b0f
Make collationId a field of StringPredicate
uros-db Mar 1, 2024
2f1a175
Make COLLATION_MISMATCH a subclass of DATATYPE_MISMATCH
uros-db Mar 1, 2024
0a8d0ac
Fix tests and error handling in analysis
uros-db Mar 4, 2024
fb33d44
Remove deprecated error class
uros-db Mar 4, 2024
c993fd8
Merge branch 'master' into string-functions
uros-db Mar 4, 2024
8cd6287
Lazy collationId eval in StringPredicate
uros-db Mar 4, 2024
deed8d9
Fix merge conflict error
uros-db Mar 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ public Collation(
private static final Collation[] collationTable = new Collation[4];
private static final HashMap<String, Integer> collationNameToIdMap = new HashMap<>();

public static final int DEFAULT_COLLATION_ID = 0;
public static final int LOWERCASE_COLLATION_ID = 1;

static {
// Binary comparison. This is the default collation.
// No custom comparators will be used for this collation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Pattern;

Expand All @@ -30,6 +31,7 @@
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;

import org.apache.spark.SparkException;
import org.apache.spark.sql.catalyst.util.CollationFactory;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.UTF8StringBuilder;
Expand Down Expand Up @@ -341,6 +343,21 @@ public boolean contains(final UTF8String substring) {
return false;
}

/**
 * Collation-aware variant of {@link #contains(UTF8String)}.
 *
 * Binary collations delegate to the single-argument overload, and the lowercase collation
 * compares the lower-cased forms of both strings. Every other collation is rejected.
 *
 * @param substring the string to look for inside this string
 * @param collationId id of the collation, resolved through CollationFactory.fetchCollation
 * @return true if this string contains `substring` under the given collation
 * @throws SparkException with error class UNSUPPORTED_COLLATION.FOR_FUNCTION when the
 *         collation is neither binary nor the lowercase collation
 */
public boolean contains(final UTF8String substring, int collationId) throws SparkException {
  final boolean binaryCollation = CollationFactory.fetchCollation(collationId).isBinaryCollation;
  if (binaryCollation) {
    return this.contains(substring);
  } else if (collationId == CollationFactory.LOWERCASE_COLLATION_ID) {
    final UTF8String haystack = this.toLowerCase();
    final UTF8String needle = substring.toLowerCase();
    return haystack.contains(needle);
  }
  // TODO: enable ICU collation support for "contains" (SPARK-47248)
  final Map<String, String> messageArgs = new HashMap<>();
  messageArgs.put("functionName", "contains");
  final String unsupportedName = CollationFactory.fetchCollation(collationId).collationName;
  messageArgs.put("collationName", unsupportedName);
  throw new SparkException("UNSUPPORTED_COLLATION.FOR_FUNCTION",
    SparkException.constructMessageParams(messageArgs), null);
}

/**
* Returns the byte at position `i`.
*/
Expand All @@ -355,14 +372,41 @@ public boolean matchAt(final UTF8String s, int pos) {
return ByteArrayMethods.arrayEquals(base, offset + pos, s.base, s.offset, s.numBytes);
}

/**
 * Collation-aware counterpart of {@link #matchAt(UTF8String, int)}: checks whether the
 * `s.numBytes` bytes of this string that start at byte offset `pos` compare equal to `s`
 * under the given collation.
 *
 * NOTE(review): only a slice with exactly the same byte length as `s` is compared, so two
 * strings that a collation treats as equal but that differ in byte length will not match
 * here - confirm this is acceptable for the collations routed to this method.
 *
 * @param s the string to match
 * @param pos byte offset into this string at which the match must start
 * @param collationId id of the collation passed on to semanticCompare
 * @return false if the slice would fall outside this string, otherwise whether the slice
 *         and `s` compare as equal
 */
private boolean matchAt(final UTF8String s, int pos, int collationId) {
  if (s.numBytes + pos > numBytes || pos < 0) {
    return false;
  }
  // `pos` and `s.numBytes` are byte quantities (see the bounds check above), whereas
  // substring(start, until) indexes by code point. Slice by bytes instead so that strings
  // containing multi-byte UTF-8 characters are compared over the intended range.
  final UTF8String slice = fromAddress(base, offset + pos, s.numBytes);
  return slice.semanticCompare(s, collationId) == 0;
}

/**
 * Returns whether this string starts with `prefix`, i.e. whether `prefix` matches at
 * position 0 according to {@link #matchAt(UTF8String, int)}.
 */
public boolean startsWith(final UTF8String prefix) {
return matchAt(prefix, 0);
}

/**
 * Collation-aware variant of {@link #startsWith(UTF8String)}.
 *
 * Binary collations delegate to the single-argument overload, the lowercase collation
 * compares the lower-cased forms of both strings, and any other collation is handled by
 * the collation-aware matchAt at position 0.
 *
 * @param prefix the candidate prefix
 * @param collationId id of the collation, resolved through CollationFactory.fetchCollation
 * @return true if this string starts with `prefix` under the given collation
 */
public boolean startsWith(final UTF8String prefix, int collationId) {
  final boolean binaryCollation = CollationFactory.fetchCollation(collationId).isBinaryCollation;
  if (binaryCollation) {
    return this.startsWith(prefix);
  } else if (collationId == CollationFactory.LOWERCASE_COLLATION_ID) {
    final UTF8String loweredSelf = this.toLowerCase();
    final UTF8String loweredPrefix = prefix.toLowerCase();
    return loweredSelf.startsWith(loweredPrefix);
  } else {
    return matchAt(prefix, 0, collationId);
  }
}

/**
 * Returns whether this string ends with `suffix`, i.e. whether `suffix` matches at
 * position `numBytes - suffix.numBytes` according to {@link #matchAt(UTF8String, int)}.
 */
public boolean endsWith(final UTF8String suffix) {
return matchAt(suffix, numBytes - suffix.numBytes);
}

/**
 * Collation-aware variant of {@link #endsWith(UTF8String)}.
 *
 * Binary collations delegate to the single-argument overload, the lowercase collation
 * compares the lower-cased forms of both strings, and any other collation is handled by
 * the collation-aware matchAt at position `numBytes - suffix.numBytes`.
 *
 * @param suffix the candidate suffix
 * @param collationId id of the collation, resolved through CollationFactory.fetchCollation
 * @return true if this string ends with `suffix` under the given collation
 */
public boolean endsWith(final UTF8String suffix, int collationId) {
  final boolean binaryCollation = CollationFactory.fetchCollation(collationId).isBinaryCollation;
  if (binaryCollation) {
    return this.endsWith(suffix);
  } else if (collationId == CollationFactory.LOWERCASE_COLLATION_ID) {
    final UTF8String loweredSelf = this.toLowerCase();
    final UTF8String loweredSuffix = suffix.toLowerCase();
    return loweredSelf.endsWith(loweredSuffix);
  } else {
    final int startByte = numBytes - suffix.numBytes;
    return matchAt(suffix, startByte, collationId);
  }
}

/**
* Returns the upper case of this string
*/
Expand Down
18 changes: 18 additions & 0 deletions common/utils/src/main/resources/error/error-classes.json
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,11 @@
"To convert values from <srcType> to <targetType>, you can use the functions <functionNames> instead."
]
},
"COLLATION_MISMATCH" : {
"message" : [
"Collations <collationNameLeft> and <collationNameRight> are not compatible. Please use the same collation for both strings."
]
},
"CREATE_MAP_KEY_DIFF_TYPES" : {
"message" : [
"The given keys of function <functionName> should all be the same type, but they are <dataType>."
Expand Down Expand Up @@ -3756,6 +3761,19 @@
],
"sqlState" : "0A000"
},
"UNSUPPORTED_COLLATION" : {
"message" : [
"Collation <collationName> is not supported for:"
],
"subClass" : {
"FOR_FUNCTION" : {
"message" : [
"function <functionName>. Please try to use a different collation."
]
}
},
"sqlState" : "0A000"
},
"UNSUPPORTED_DATASOURCE_FOR_DIRECT_QUERY" : {
"message" : [
"Unsupported data source type for direct query on files: <dataSourceType>"
Expand Down
4 changes: 4 additions & 0 deletions docs/sql-error-conditions-datatype-mismatch-error-class.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ If you have to cast `<srcType>` to `<targetType>`, you can set `<config>` as `<c
cannot cast `<srcType>` to `<targetType>`.
To convert values from `<srcType>` to `<targetType>`, you can use the functions `<functionNames>` instead.

## COLLATION_MISMATCH

Collations `<collationNameLeft>` and `<collationNameRight>` are not compatible. Please use the same collation for both strings.

## CREATE_MAP_KEY_DIFF_TYPES

The given keys of function `<functionName>` should all be the same type, but they are `<dataType>`.
Expand Down
37 changes: 37 additions & 0 deletions docs/sql-error-conditions-unsupported-collation-error-class.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
---
layout: global
title: UNSUPPORTED_COLLATION error class
displayTitle: UNSUPPORTED_COLLATION error class
license: |
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
---

<!--
DO NOT EDIT THIS FILE.
It was generated automatically by `org.apache.spark.SparkThrowableSuite`.
-->

[SQLSTATE: 0A000](sql-error-conditions-sqlstates.html#class-0A-feature-not-supported)

Collation `<collationName>` is not supported for:

This error class has the following derived error classes:

## FOR_FUNCTION

function `<functionName>`. Please try to use a different collation.


8 changes: 8 additions & 0 deletions docs/sql-error-conditions.md
Original file line number Diff line number Diff line change
Expand Up @@ -2450,6 +2450,14 @@ For more details see [UNSUPPORTED_CALL](sql-error-conditions-unsupported-call-er
The char/varchar type can't be used in the table schema.
If you want Spark treat them as string type as same as Spark 3.0 and earlier, please set "spark.sql.legacy.charVarcharAsString" to "true".

### [UNSUPPORTED_COLLATION](sql-error-conditions-unsupported-collation-error-class.html)

[SQLSTATE: 0A000](sql-error-conditions-sqlstates.html#class-0A-feature-not-supported)

Collation `<collationName>` is not supported for:

For more details see [UNSUPPORTED_COLLATION](sql-error-conditions-unsupported-collation-error-class.html)

### UNSUPPORTED_DATASOURCE_FOR_DIRECT_QUERY

[SQLSTATE: 0A000](sql-error-conditions-sqlstates.html#class-0A-feature-not-supported)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa
/**
* Returns whether the assigned collation is the default Spark collation (UCS_BASIC).
*/
def isDefaultCollation: Boolean = collationId == StringType.DEFAULT_COLLATION_ID
def isDefaultCollation: Boolean = collationId == CollationFactory.DEFAULT_COLLATION_ID

/**
* Binary collation implies that strings are considered equal only if they are
Expand Down Expand Up @@ -69,6 +69,5 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa
*/
@Stable
case object StringType extends StringType(0) {
val DEFAULT_COLLATION_ID = 0
def apply(collationId: Int): StringType = new StringType(collationId)
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import java.util.{HashMap, Locale, Map => JMap}

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.QueryContext
import org.apache.spark.{QueryContext, SparkException}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
Expand All @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, GenericArrayData, TypeUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -497,10 +497,32 @@ case class Lower(child: Expression)
abstract class StringPredicate extends BinaryExpression
with Predicate with ImplicitCastInputTypes with NullIntolerant {

// Collation id read from the left operand's StringType. Declared lazy, so the cast to
// StringType is only performed on first access rather than when the expression is built.
final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId

def compare(l: UTF8String, r: UTF8String): Boolean

override def inputTypes: Seq[DataType] = Seq(StringType, StringType)

/**
 * Runs the inherited input type check and, when it passes, additionally requires both
 * operands to carry the same collation.
 *
 * @return the inherited failure if there is one, a COLLATION_MISMATCH data type mismatch
 *         if the two collation ids differ, and success otherwise
 */
override def checkInputDataTypes(): TypeCheckResult = {
  val baseCheck = super.checkInputDataTypes()
  if (baseCheck.isFailure) {
    baseCheck
  } else {
    // Additional check needed for collation compatibility. The casts to StringType rely
    // on the inherited check above having passed.
    val otherCollationId: Int = right.dataType.asInstanceOf[StringType].collationId
    if (collationId == otherCollationId) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      DataTypeMismatch(
        errorSubClass = "COLLATION_MISMATCH",
        messageParameters = Map(
          "collationNameLeft" -> CollationFactory.fetchCollation(collationId).collationName,
          "collationNameRight" -> CollationFactory.fetchCollation(otherCollationId).collationName
        )
      )
    }
  }
}

protected override def nullSafeEval(input1: Any, input2: Any): Any =
compare(input1.asInstanceOf[UTF8String], input2.asInstanceOf[UTF8String])

Expand Down Expand Up @@ -586,9 +608,38 @@ object ContainsExpressionBuilder extends StringBinaryPredicateExpressionBuilderB
}

case class Contains(left: Expression, right: Expression) extends StringPredicate {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
/**
 * Runs the inherited checks and then verifies that the operands' collation is one this
 * expression accepts: a binary collation or the lowercase collation.
 *
 * Throws a SparkException with error class UNSUPPORTED_COLLATION.FOR_FUNCTION for any
 * other collation; inherited failures are returned unchanged.
 */
override def checkInputDataTypes(): TypeCheckResult = {
  val baseCheck = super.checkInputDataTypes()
  if (baseCheck.isFailure) {
    baseCheck
  } else {
    // Additional check needed for collation support
    val collation = CollationFactory.fetchCollation(collationId)
    val supported = collation.isBinaryCollation ||
      collationId == CollationFactory.LOWERCASE_COLLATION_ID
    if (supported) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      throw new SparkException(
        errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION",
        messageParameters = Map(
          "functionName" -> "contains",
          "collationName" -> collation.collationName),
        cause = null
      )
    }
  }
}
// Binary collations use the single-argument `contains`; every other collation is passed
// to the overload that also takes the collation id.
override def compare(l: UTF8String, r: UTF8String): Boolean = {
  val isBinary = CollationFactory.fetchCollation(collationId).isBinaryCollation
  if (isBinary) l.contains(r) else l.contains(r, collationId)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)")
if (CollationFactory.fetchCollation(collationId).isBinaryCollation) {
defineCodeGen(ctx, ev, (c1, c2) => s"$c1.contains($c2)")
} else {
defineCodeGen(ctx, ev, (c1, c2) => s"$c1.contains($c2, $collationId)")
}
}
override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): Contains = copy(left = newLeft, right = newRight)
Expand Down Expand Up @@ -623,9 +674,20 @@ object StartsWithExpressionBuilder extends StringBinaryPredicateExpressionBuilde
}

case class StartsWith(left: Expression, right: Expression) extends StringPredicate {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
// Binary collations use the single-argument `startsWith`; every other collation is passed
// to the overload that also takes the collation id.
override def compare(l: UTF8String, r: UTF8String): Boolean = {
  val isBinary = CollationFactory.fetchCollation(collationId).isBinaryCollation
  if (isBinary) l.startsWith(r) else l.startsWith(r, collationId)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)")
if (CollationFactory.fetchCollation(collationId).isBinaryCollation) {
defineCodeGen(ctx, ev, (c1, c2) => s"$c1.startsWith($c2)")
} else {
defineCodeGen(ctx, ev, (c1, c2) => s"$c1.startsWith($c2, $collationId)")
}
}
override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): StartsWith = copy(left = newLeft, right = newRight)
Expand Down Expand Up @@ -660,9 +722,20 @@ object EndsWithExpressionBuilder extends StringBinaryPredicateExpressionBuilderB
}

case class EndsWith(left: Expression, right: Expression) extends StringPredicate {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
// Binary collations use the single-argument `endsWith`; every other collation is passed
// to the overload that also takes the collation id.
override def compare(l: UTF8String, r: UTF8String): Boolean = {
  val isBinary = CollationFactory.fetchCollation(collationId).isBinaryCollation
  if (isBinary) l.endsWith(r) else l.endsWith(r, collationId)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)")
if (CollationFactory.fetchCollation(collationId).isBinaryCollation) {
defineCodeGen(ctx, ev, (c1, c2) => s"$c1.endsWith($c2)")
} else {
defineCodeGen(ctx, ev, (c1, c2) => s"$c1.endsWith($c2, $collationId)")
}
}
override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): EndsWith = copy(left = newLeft, right = newRight)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ object PhysicalDataType {
case ShortType => PhysicalShortType
case IntegerType => PhysicalIntegerType
case LongType => PhysicalLongType
case VarcharType(_) => PhysicalStringType(StringType.DEFAULT_COLLATION_ID)
case CharType(_) => PhysicalStringType(StringType.DEFAULT_COLLATION_ID)
case VarcharType(_) => PhysicalStringType(CollationFactory.DEFAULT_COLLATION_ID)
case CharType(_) => PhysicalStringType(CollationFactory.DEFAULT_COLLATION_ID)
case s: StringType => PhysicalStringType(s.collationId)
case FloatType => PhysicalFloatType
case DoubleType => PhysicalDoubleType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import scala.annotation.tailrec
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.{PhysicalArrayType, PhysicalBinaryType, PhysicalBooleanType, PhysicalByteType, PhysicalCalendarIntervalType, PhysicalDataType, PhysicalDecimalType, PhysicalDoubleType, PhysicalFloatType, PhysicalIntegerType, PhysicalLongType, PhysicalMapType, PhysicalNullType, PhysicalShortType, PhysicalStringType, PhysicalStructType}
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.errors.ExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
Expand Down Expand Up @@ -492,7 +493,7 @@ private[columnar] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType
}

private[columnar] object STRING
extends NativeColumnType(PhysicalStringType(StringType.DEFAULT_COLLATION_ID), 8)
extends NativeColumnType(PhysicalStringType(CollationFactory.DEFAULT_COLLATION_ID), 8)
with DirectCopyColumnType[UTF8String] {

override def actualSize(row: InternalRow, ordinal: Int): Int = {
Expand Down
Loading