Commit f2b3525

maropu authored and gatorsmile committed
[SPARK-22771][SQL] Concatenate binary inputs into a binary output

## What changes were proposed in this pull request?

This PR modifies `concat` to concatenate binary inputs into a single binary output. In the current master, `concat` always outputs data as a string; but in some databases (e.g., PostgreSQL), if all inputs are binary, `concat` also outputs binary.

## How was this patch tested?

Added tests in `SQLQueryTestSuite` and `TypeCoercionSuite`.

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #19977 from maropu/SPARK-22771.
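
A minimal spark-shell sketch of the new behavior (assuming a `SparkSession` named `spark` on a build that includes this patch; the queries are illustrative):

    // All-binary inputs now stay binary; any non-binary input still coerces
    // the whole result to string.
    spark.sql("SELECT concat(CAST('a' AS BINARY), CAST('b' AS BINARY))").schema.head.dataType
    // expected: BinaryType
    spark.sql("SELECT concat(CAST('a' AS BINARY), 'b')").schema.head.dataType
    // expected: StringType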
1 parent: 2ea17af

File tree: 16 files changed, +587 −20 lines changed

R/pkg/R/functions.R (2 additions, 1 deletion)

@@ -2133,7 +2133,8 @@ setMethod("countDistinct",
           })
 
 #' @details
-#' \code{concat}: Concatenates multiple input string columns together into a single string column.
+#' \code{concat}: Concatenates multiple input columns together into a single column.
+#' If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
 #'
 #' @rdname column_string_functions
 #' @aliases concat concat,Column-method

common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java (25 additions, 0 deletions)

@@ -74,4 +74,29 @@ public static byte[] subStringSQL(byte[] bytes, int pos, int len) {
     }
     return Arrays.copyOfRange(bytes, start, end);
   }
+
+  public static byte[] concat(byte[]... inputs) {
+    // Compute the total length of the result
+    int totalLength = 0;
+    for (int i = 0; i < inputs.length; i++) {
+      if (inputs[i] != null) {
+        totalLength += inputs[i].length;
+      } else {
+        return null;
+      }
+    }
+
+    // Allocate a new byte array, and copy the inputs one by one into it
+    final byte[] result = new byte[totalLength];
+    int offset = 0;
+    for (int i = 0; i < inputs.length; i++) {
+      int len = inputs[i].length;
+      Platform.copyMemory(
+        inputs[i], Platform.BYTE_ARRAY_OFFSET,
+        result, Platform.BYTE_ARRAY_OFFSET + offset,
+        len);
+      offset += len;
+    }
+    return result;
+  }
 }
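
The new `ByteArray.concat` mirrors `UTF8String.concat`: any null input nulls out the whole result, and the copy happens in a single allocation. A standalone Scala sketch of the same logic, using `System.arraycopy` in place of Spark's `Platform.copyMemory` (the name `concatBytes` is illustrative, not a Spark API):

    def concatBytes(inputs: Array[Byte]*): Array[Byte] = {
      // Any null input makes the whole result null, as in ByteArray.concat.
      if (inputs.exists(_ == null)) return null
      val result = new Array[Byte](inputs.map(_.length).sum)
      var offset = 0
      for (in <- inputs) {
        System.arraycopy(in, 0, result, offset, in.length)
        offset += in.length
      }
      result
    }

    concatBytes(Array[Byte](1, 2), Array[Byte](3))  // Array(1, 2, 3)
    concatBytes(Array[Byte](1), null)               // null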

docs/sql-programming-guide.md (2 additions, 0 deletions)

@@ -1780,6 +1780,8 @@ options.
 
 - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489).
 
+- Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`.
+
 ## Upgrading From Spark SQL 2.1 to 2.2
 
 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access.
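
A sketch of the opt-out described in the migration note above (assumes a `SparkSession` named `spark`):

    // Restore the pre-2.3 behavior: all-binary concat yields a string again.
    spark.conf.set("spark.sql.function.concatBinaryAsString", "true")
    spark.sql("SELECT concat(CAST('a' AS BINARY), CAST('b' AS BINARY))").schema.head.dataType
    // expected: StringType (BinaryType under the default of false)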

python/pyspark/sql/functions.py (2 additions, 1 deletion)

@@ -1374,7 +1374,8 @@ def hash(*cols):
 @ignore_unicode_prefix
 def concat(*cols):
     """
-    Concatenates multiple input string columns together into a single string column.
+    Concatenates multiple input columns together into a single column.
+    If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
 
     >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
     >>> df.select(concat(df.s, df.d).alias('s')).collect()

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala (1 addition, 1 deletion)

@@ -150,7 +150,7 @@ class Analyzer(
       TimeWindowing ::
       ResolveInlineTables(conf) ::
       ResolveTimeZone(conf) ::
-      TypeCoercion.typeCoercionRules ++
+      TypeCoercion.typeCoercionRules(conf) ++
       extendedResolutionRules : _*),
     Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
     Batch("View", Once,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala (25 additions, 1 deletion)

@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 
@@ -45,13 +46,14 @@ import org.apache.spark.sql.types._
  */
 object TypeCoercion {
 
-  val typeCoercionRules =
+  def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] =
     InConversion ::
       WidenSetOperationTypes ::
       PromoteStrings ::
       DecimalPrecision ::
       BooleanEquality ::
       FunctionArgumentConversion ::
+      ConcatCoercion(conf) ::
       CaseWhenCoercion ::
       IfCoercion ::
       StackCoercion ::
@@ -660,6 +662,28 @@ object TypeCoercion {
     }
   }
 
+  /**
+   * Coerces the types of [[Concat]] children to expected ones.
+   *
+   * If `spark.sql.function.concatBinaryAsString` is false and all children types are binary,
+   * the expected types are binary. Otherwise, the expected ones are strings.
+   */
+  case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule {
+
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p =>
+      p transformExpressionsUp {
+        // Skip nodes if unresolved or empty children
+        case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c
+        case c @ Concat(children) if conf.concatBinaryAsString ||
+            !children.map(_.dataType).forall(_ == BinaryType) =>
+          val newChildren = c.children.map { e =>
+            ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e)
+          }
+          c.copy(children = newChildren)
+      }
+    }
+  }
+
   /**
    * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType
    * to TimeAdd/TimeSub
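
The rule's choice reduces to a small predicate over the children's types and the flag; a standalone sketch over stand-in types (a toy ADT, not Catalyst's `DataType`):

    sealed trait SqlType
    case object StringT extends SqlType
    case object BinaryT extends SqlType

    // The target type ConcatCoercion effectively picks for concat's children.
    def concatTargetType(childTypes: Seq[SqlType], concatBinaryAsString: Boolean): SqlType =
      if (!concatBinaryAsString && childTypes.nonEmpty && childTypes.forall(_ == BinaryT)) BinaryT
      else StringT

    concatTargetType(Seq(BinaryT, BinaryT), concatBinaryAsString = false)  // BinaryT
    concatTargetType(Seq(BinaryT, StringT), concatBinaryAsString = false)  // StringT (mixed inputs)
    concatTargetType(Seq(BinaryT, BinaryT), concatBinaryAsString = true)   // StringT (legacy flag)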

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala (41 additions, 11 deletions)

@@ -24,11 +24,10 @@ import java.util.regex.Pattern
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
 
@@ -38,7 +37,8 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
 
 
 /**
- * An expression that concatenates multiple input strings into a single string.
+ * An expression that concatenates multiple inputs into a single output.
+ * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
  * If any input is null, concat returns null.
  */
 @ExpressionDescription(
@@ -48,17 +48,37 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
       > SELECT _FUNC_('Spark', 'SQL');
        SparkSQL
   """)
-case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes {
+case class Concat(children: Seq[Expression]) extends Expression {
 
-  override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
-  override def dataType: DataType = StringType
+  private lazy val isBinaryMode: Boolean = dataType == BinaryType
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.isEmpty) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      val childTypes = children.map(_.dataType)
+      if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) {
+        return TypeCheckResult.TypeCheckFailure(
+          s"input to function $prettyName should have StringType or BinaryType, but it's " +
+            childTypes.map(_.simpleString).mkString("[", ", ", "]"))
+      }
+      TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName")
+    }
+  }
+
+  override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)
 
   override def nullable: Boolean = children.exists(_.nullable)
   override def foldable: Boolean = children.forall(_.foldable)
 
   override def eval(input: InternalRow): Any = {
-    val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
-    UTF8String.concat(inputs : _*)
+    if (isBinaryMode) {
+      val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
+      ByteArray.concat(inputs: _*)
+    } else {
+      val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
+      UTF8String.concat(inputs : _*)
+    }
   }
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -73,17 +93,27 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
       }
      """
     }
+
+    val (concatenator, initCode) = if (isBinaryMode) {
+      (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
+    } else {
+      ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
+    }
     val codes = ctx.splitExpressionsWithCurrentInputs(
       expressions = inputs,
       funcName = "valueConcat",
-      extraArguments = ("UTF8String[]", args) :: Nil)
+      extraArguments = (s"${ctx.javaType(dataType)}[]", args) :: Nil)
     ev.copy(s"""
-      UTF8String[] $args = new UTF8String[${evals.length}];
+      $initCode
       $codes
-      UTF8String ${ev.value} = UTF8String.concat($args);
+      ${ctx.javaType(dataType)} ${ev.value} = $concatenator.concat($args);
       boolean ${ev.isNull} = ${ev.value} == null;
     """)
   }
+
+  override def toString: String = s"concat(${children.mkString(", ")})"
+
+  override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
 }
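
Because `dataType` now follows the (coerced) children, both modes are observable from the DataFrame API; a sketch assuming a `SparkSession` named `spark`:

    import spark.implicits._
    import org.apache.spark.sql.functions.concat

    // Binary mode: two BinaryType columns concatenate into one BinaryType column.
    val df = Seq(("ab".getBytes, "cd".getBytes)).toDF("x", "y")
    df.select(concat($"x", $"y")).schema.head.dataType  // expected: BinaryType

    // Null propagation holds in binary mode, just as in the string path.
    spark.sql("SELECT concat(CAST(NULL AS BINARY), CAST('a' AS BINARY))").first.isNullAt(0)
    // expected: true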

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala (14 additions, 1 deletion)

@@ -21,6 +21,7 @@ import scala.collection.immutable.HashSet
 import scala.collection.mutable.{ArrayBuffer, Stack}
 
 import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -645,15 +646,27 @@ object CombineConcats extends Rule[LogicalPlan] {
       stack.pop() match {
         case Concat(children) =>
           stack.pushAll(children.reverse)
+        // If `spark.sql.function.concatBinaryAsString` is false, nested `Concat` exprs possibly
+        // have `Concat`s with binary output. Since `TypeCoercion` casts them into strings,
+        // we need to handle the case to combine all nested `Concat`s.
+        case c @ Cast(Concat(children), StringType, _) =>
+          val newChildren = children.map { e => c.copy(child = e) }
+          stack.pushAll(newChildren.reverse)
         case child =>
           flattened += child
       }
     }
     Concat(flattened)
   }
 
+  private def hasNestedConcats(concat: Concat): Boolean = concat.children.exists {
+    case c: Concat => true
+    case c @ Cast(Concat(children), StringType, _) => true
+    case _ => false
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown {
-    case concat: Concat if concat.children.exists(_.isInstanceOf[Concat]) =>
+    case concat: Concat if hasNestedConcats(concat) =>
       flattenConcats(concat)
   }
 }
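
The new `Cast` case pushes the string cast down onto the inner concat's children, so every leaf lands in one flat `Concat`. A standalone sketch of that stack-based flattening over a toy expression ADT (not Catalyst's classes):

    import scala.collection.mutable.{ArrayBuffer, Stack}

    sealed trait Expr
    case class Leaf(name: String) extends Expr
    case class ConcatE(children: Seq[Expr]) extends Expr
    case class CastToString(child: Expr) extends Expr

    def flatten(root: ConcatE): ConcatE = {
      val stack = Stack[Expr](root.children: _*)
      val flattened = ArrayBuffer.empty[Expr]
      while (stack.nonEmpty) stack.pop() match {
        case ConcatE(cs)               => stack.pushAll(cs.reverse)
        // Push the cast down onto each child of a string-cast inner concat.
        case CastToString(ConcatE(cs)) => stack.pushAll(cs.map(CastToString(_)).reverse)
        case leaf                      => flattened += leaf
      }
      ConcatE(flattened)
    }

    flatten(ConcatE(Seq(CastToString(ConcatE(Seq(Leaf("a"), Leaf("b")))), Leaf("c"))))
    // ConcatE(List(CastToString(Leaf(a)), CastToString(Leaf(b)), Leaf(c)))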

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala (8 additions, 0 deletions)

@@ -1044,6 +1044,12 @@ object SQLConf {
       "When this conf is not set, the value from `spark.redaction.string.regex` is used.")
     .fallbackConf(org.apache.spark.internal.config.STRING_REDACTION_PATTERN)
 
+  val CONCAT_BINARY_AS_STRING = buildConf("spark.sql.function.concatBinaryAsString")
+    .doc("When this option is set to false and all inputs are binary, `functions.concat` returns " +
+      "an output as binary. Otherwise, it returns as a string. ")
+    .booleanConf
+    .createWithDefault(false)
+
   val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE =
     buildConf("spark.sql.streaming.continuous.executorQueueSize")
       .internal()
@@ -1378,6 +1384,8 @@ class SQLConf extends Serializable with Logging {
   def continuousStreamingExecutorPollIntervalMs: Long =
     getConf(CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS)
 
+  def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala (54 additions, 0 deletions)

@@ -869,6 +869,60 @@ class TypeCoercionSuite extends AnalysisTest {
       Literal.create(null, IntegerType), Literal.create(null, StringType))))
   }
 
+  test("type coercion for Concat") {
+    val rule = TypeCoercion.ConcatCoercion(conf)
+
+    ruleTest(rule,
+      Concat(Seq(Literal("ab"), Literal("cde"))),
+      Concat(Seq(Literal("ab"), Literal("cde"))))
+    ruleTest(rule,
+      Concat(Seq(Literal(null), Literal("abc"))),
+      Concat(Seq(Cast(Literal(null), StringType), Literal("abc"))))
+    ruleTest(rule,
+      Concat(Seq(Literal(1), Literal("234"))),
+      Concat(Seq(Cast(Literal(1), StringType), Literal("234"))))
+    ruleTest(rule,
+      Concat(Seq(Literal("1"), Literal("234".getBytes()))),
+      Concat(Seq(Literal("1"), Cast(Literal("234".getBytes()), StringType))))
+    ruleTest(rule,
+      Concat(Seq(Literal(1L), Literal(2.toByte), Literal(0.1))),
+      Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType),
+        Cast(Literal(0.1), StringType))))
+    ruleTest(rule,
+      Concat(Seq(Literal(true), Literal(0.1f), Literal(3.toShort))),
+      Concat(Seq(Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType),
+        Cast(Literal(3.toShort), StringType))))
+    ruleTest(rule,
+      Concat(Seq(Literal(1L), Literal(0.1))),
+      Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType))))
+    ruleTest(rule,
+      Concat(Seq(Literal(Decimal(10)))),
+      Concat(Seq(Cast(Literal(Decimal(10)), StringType))))
+    ruleTest(rule,
+      Concat(Seq(Literal(BigDecimal.valueOf(10)))),
+      Concat(Seq(Cast(Literal(BigDecimal.valueOf(10)), StringType))))
+    ruleTest(rule,
+      Concat(Seq(Literal(java.math.BigDecimal.valueOf(10)))),
+      Concat(Seq(Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType))))
+    ruleTest(rule,
+      Concat(Seq(Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))),
+      Concat(Seq(Cast(Literal(new java.sql.Date(0)), StringType),
+        Cast(Literal(new Timestamp(0)), StringType))))
+
+    withSQLConf("spark.sql.function.concatBinaryAsString" -> "true") {
+      ruleTest(rule,
+        Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))),
+        Concat(Seq(Cast(Literal("123".getBytes), StringType),
+          Cast(Literal("456".getBytes), StringType))))
+    }
+
+    withSQLConf("spark.sql.function.concatBinaryAsString" -> "false") {
+      ruleTest(rule,
+        Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))),
+        Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))))
+    }
+  }
+
   test("BooleanEquality type cast") {
     val be = TypeCoercion.BooleanEquality
     // Use something more than a literal to avoid triggering the simplification rules.
