@@ -3322,7 +3322,9 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
33223322 def assignInt (array : ArrayData , idx : Int , resultArray : ArrayData , pos : Int ): Boolean = {
33233323 val elem = array.getInt(idx)
33243324 if (! hsInt.contains(elem)) {
3325- resultArray.setInt(pos, elem)
3325+ if (resultArray != null ) {
3326+ resultArray.setInt(pos, elem)
3327+ }
33263328 hsInt.add(elem)
33273329 true
33283330 } else {
@@ -3333,7 +3335,9 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
33333335 def assignLong (array : ArrayData , idx : Int , resultArray : ArrayData , pos : Int ): Boolean = {
33343336 val elem = array.getLong(idx)
33353337 if (! hsLong.contains(elem)) {
3336- resultArray.setLong(pos, elem)
3338+ if (resultArray != null ) {
3339+ resultArray.setLong(pos, elem)
3340+ }
33373341 hsLong.add(elem)
33383342 true
33393343 } else {
@@ -3344,20 +3348,25 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
33443348 def evalIntLongPrimitiveType (
33453349 array1 : ArrayData ,
33463350 array2 : ArrayData ,
3347- size : Int ,
33483351 resultArray : ArrayData ,
3349- isLongType : Boolean ): ArrayData = {
3352+ isLongType : Boolean ): Int = {
33503353 // store elements into resultArray
3351- var foundNullElement = false
3354+ var nullElementSize = 0
33523355 var pos = 0
33533356 Seq (array1, array2).foreach(array => {
33543357 var i = 0
33553358 while (i < array.numElements()) {
3359+ val size = if (! isLongType) hsInt.size else hsLong.size
3360+ if (size + nullElementSize > ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH ) {
3361+ ArraySetLike .throwUnionLengthOverflowException(size)
3362+ }
33563363 if (array.isNullAt(i)) {
3357- if (! foundNullElement) {
3358- resultArray.setNullAt(pos)
3364+ if (nullElementSize == 0 ) {
3365+ if (resultArray != null ) {
3366+ resultArray.setNullAt(pos)
3367+ }
33593368 pos += 1
3360- foundNullElement = true
3369+ nullElementSize = 1
33613370 }
33623371 } else {
33633372 val assigned = if (! isLongType) {
@@ -3372,7 +3381,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
33723381 i += 1
33733382 }
33743383 })
3375- resultArray
3384+ pos
33763385 }
33773386
33783387 override def nullSafeEval (input1 : Any , input2 : Any ): Any = {
@@ -3384,25 +3393,8 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
33843393 case IntegerType =>
33853394 // avoid boxing of primitive int array elements
33863395 // calculate result array size
3387- val hsSize = new OpenHashSet [Int ]
3388- var nullElementSize = 0
3389- Seq (array1, array2).foreach { array =>
3390- var i = 0
3391- while (i < array.numElements()) {
3392- if (hsSize.size + nullElementSize > ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH ) {
3393- ArraySetLike .throwUnionLengthOverflowException(hsSize.size)
3394- }
3395- if (array.isNullAt(i)) {
3396- if (nullElementSize == 0 ) {
3397- nullElementSize = 1
3398- }
3399- } else {
3400- hsSize.add(array.getInt(i))
3401- }
3402- i += 1
3403- }
3404- }
3405- val elements = hsSize.size + nullElementSize
3396+ hsInt = new OpenHashSet [Int ]
3397+ val elements = evalIntLongPrimitiveType(array1, array2, null , false )
34063398 hsInt = new OpenHashSet [Int ]
34073399 val resultArray = if (UnsafeArrayData .useGenericArrayData(
34083400 IntegerType .defaultSize, elements)) {
@@ -3411,29 +3403,13 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
34113403 UnsafeArrayData .forPrimitiveArray(
34123404 Platform .INT_ARRAY_OFFSET , elements, IntegerType .defaultSize)
34133405 }
3414- evalIntLongPrimitiveType(array1, array2, elements, resultArray, false )
3406+ evalIntLongPrimitiveType(array1, array2, resultArray, false )
3407+ resultArray
34153408 case LongType =>
34163409 // avoid boxing of primitive long array elements
34173410 // calculate result array size
3418- val hsSize = new OpenHashSet [Long ]
3419- var nullElementSize = 0
3420- Seq (array1, array2).foreach { array =>
3421- var i = 0
3422- while (i < array.numElements()) {
3423- if (hsSize.size + nullElementSize > ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH ) {
3424- ArraySetLike .throwUnionLengthOverflowException(hsSize.size)
3425- }
3426- if (array.isNullAt(i)) {
3427- if (nullElementSize == 0 ) {
3428- nullElementSize = 1
3429- }
3430- } else {
3431- hsSize.add(array.getLong(i))
3432- }
3433- i += 1
3434- }
3435- }
3436- val elements = hsSize.size + nullElementSize
3411+ hsLong = new OpenHashSet [Long ]
3412+ val elements = evalIntLongPrimitiveType(array1, array2, null , true )
34373413 hsLong = new OpenHashSet [Long ]
34383414 val resultArray = if (UnsafeArrayData .useGenericArrayData(
34393415 LongType .defaultSize, elements)) {
@@ -3442,7 +3418,8 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
34423418 UnsafeArrayData .forPrimitiveArray(
34433419 Platform .LONG_ARRAY_OFFSET , elements, LongType .defaultSize)
34443420 }
3445- evalIntLongPrimitiveType(array1, array2, elements, resultArray, true )
3421+ evalIntLongPrimitiveType(array1, array2, resultArray, true )
3422+ resultArray
34463423 case _ =>
34473424 val arrayBuffer = new scala.collection.mutable.ArrayBuffer [Any ]
34483425 val hs = new OpenHashSet [Any ]
0 commit comments