Skip to content

Commit f5ebbe8

Browse files
committed
eliminate duplicated code
1 parent 4a217bc commit f5ebbe8

File tree

1 file changed

+26
-49
lines changed

1 file changed

+26
-49
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 26 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -3322,7 +3322,9 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
33223322
def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
33233323
val elem = array.getInt(idx)
33243324
if (!hsInt.contains(elem)) {
3325-
resultArray.setInt(pos, elem)
3325+
if (resultArray != null) {
3326+
resultArray.setInt(pos, elem)
3327+
}
33263328
hsInt.add(elem)
33273329
true
33283330
} else {
@@ -3333,7 +3335,9 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
33333335
def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
33343336
val elem = array.getLong(idx)
33353337
if (!hsLong.contains(elem)) {
3336-
resultArray.setLong(pos, elem)
3338+
if (resultArray != null) {
3339+
resultArray.setLong(pos, elem)
3340+
}
33373341
hsLong.add(elem)
33383342
true
33393343
} else {
@@ -3344,20 +3348,25 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
33443348
def evalIntLongPrimitiveType(
33453349
array1: ArrayData,
33463350
array2: ArrayData,
3347-
size: Int,
33483351
resultArray: ArrayData,
3349-
isLongType: Boolean): ArrayData = {
3352+
isLongType: Boolean): Int = {
33503353
// store elements into resultArray
3351-
var foundNullElement = false
3354+
var nullElementSize = 0
33523355
var pos = 0
33533356
Seq(array1, array2).foreach(array => {
33543357
var i = 0
33553358
while (i < array.numElements()) {
3359+
val size = if (!isLongType) hsInt.size else hsLong.size
3360+
if (size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
3361+
ArraySetLike.throwUnionLengthOverflowException(size)
3362+
}
33563363
if (array.isNullAt(i)) {
3357-
if (!foundNullElement) {
3358-
resultArray.setNullAt(pos)
3364+
if (nullElementSize == 0) {
3365+
if (resultArray != null) {
3366+
resultArray.setNullAt(pos)
3367+
}
33593368
pos += 1
3360-
foundNullElement = true
3369+
nullElementSize = 1
33613370
}
33623371
} else {
33633372
val assigned = if (!isLongType) {
@@ -3372,7 +3381,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
33723381
i += 1
33733382
}
33743383
})
3375-
resultArray
3384+
pos
33763385
}
33773386

33783387
override def nullSafeEval(input1: Any, input2: Any): Any = {
@@ -3384,25 +3393,8 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
33843393
case IntegerType =>
33853394
// avoid boxing of primitive int array elements
33863395
// calculate result array size
3387-
val hsSize = new OpenHashSet[Int]
3388-
var nullElementSize = 0
3389-
Seq(array1, array2).foreach { array =>
3390-
var i = 0
3391-
while (i < array.numElements()) {
3392-
if (hsSize.size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
3393-
ArraySetLike.throwUnionLengthOverflowException(hsSize.size)
3394-
}
3395-
if (array.isNullAt(i)) {
3396-
if (nullElementSize == 0) {
3397-
nullElementSize = 1
3398-
}
3399-
} else {
3400-
hsSize.add(array.getInt(i))
3401-
}
3402-
i += 1
3403-
}
3404-
}
3405-
val elements = hsSize.size + nullElementSize
3396+
hsInt = new OpenHashSet[Int]
3397+
val elements = evalIntLongPrimitiveType(array1, array2, null, false)
34063398
hsInt = new OpenHashSet[Int]
34073399
val resultArray = if (UnsafeArrayData.useGenericArrayData(
34083400
IntegerType.defaultSize, elements)) {
@@ -3411,29 +3403,13 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
34113403
UnsafeArrayData.forPrimitiveArray(
34123404
Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize)
34133405
}
3414-
evalIntLongPrimitiveType(array1, array2, elements, resultArray, false)
3406+
evalIntLongPrimitiveType(array1, array2, resultArray, false)
3407+
resultArray
34153408
case LongType =>
34163409
// avoid boxing of primitive long array elements
34173410
// calculate result array size
3418-
val hsSize = new OpenHashSet[Long]
3419-
var nullElementSize = 0
3420-
Seq(array1, array2).foreach { array =>
3421-
var i = 0
3422-
while (i < array.numElements()) {
3423-
if (hsSize.size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
3424-
ArraySetLike.throwUnionLengthOverflowException(hsSize.size)
3425-
}
3426-
if (array.isNullAt(i)) {
3427-
if (nullElementSize == 0) {
3428-
nullElementSize = 1
3429-
}
3430-
} else {
3431-
hsSize.add(array.getLong(i))
3432-
}
3433-
i += 1
3434-
}
3435-
}
3436-
val elements = hsSize.size + nullElementSize
3411+
hsLong = new OpenHashSet[Long]
3412+
val elements = evalIntLongPrimitiveType(array1, array2, null, true)
34373413
hsLong = new OpenHashSet[Long]
34383414
val resultArray = if (UnsafeArrayData.useGenericArrayData(
34393415
LongType.defaultSize, elements)) {
@@ -3442,7 +3418,8 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
34423418
UnsafeArrayData.forPrimitiveArray(
34433419
Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize)
34443420
}
3445-
evalIntLongPrimitiveType(array1, array2, elements, resultArray, true)
3421+
evalIntLongPrimitiveType(array1, array2, resultArray, true)
3422+
resultArray
34463423
case _ =>
34473424
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
34483425
val hs = new OpenHashSet[Any]

0 commit comments

Comments
 (0)