Skip to content
Closed
20 changes: 20 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2233,6 +2233,26 @@ def flatten(col):
return Column(sc._jvm.functions.flatten(_to_java_column(col)))


@since(2.4)
def zip_with_index(col, indexFirst=False, startFromZero=False):
"""
Collection function: transforms the input array by encapsulating elements into pairs
with indexes indicating the order.

:param col: name of column or expression

>>> df = spark.createDataFrame([([2, 5, 3],), ([],)], ['d'])
>>> df.select(zip_with_index(df.d).alias('r')).collect()
[Row(r=[Row(value=2, index=1), Row(value=5, index=2), Row(value=3, index=3)]), Row(r=[])]
>>> df.select(zip_with_index(df.d, indexFirst=True, startFromZero=False).alias('r')).collect()
[Row(r=[Row(index=1, value=2), Row(index=2, value=5), Row(index=3, value=3)]), Row(r=[])]
>>> df.select(zip_with_index(df.d, indexFirst=True, startFromZero=True).alias('r')).collect()
[Row(r=[Row(index=0, value=2), Row(index=1, value=5), Row(index=2, value=3)]), Row(r=[])]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.zip_with_index(_to_java_column(col), indexFirst, startFromZero))


@since(2.3)
def map_keys(col):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ object FunctionRegistry {
expression[Reverse]("reverse"),
expression[Concat]("concat"),
expression[Flatten]("flatten"),
expression[ZipWithIndex]("zip_with_index"),
CreateStruct.registryEntry,

// misc functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions

import java.util.Comparator

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
Expand Down Expand Up @@ -1228,3 +1229,158 @@ case class Flatten(child: Expression) extends UnaryExpression {

override def prettyName: String = "flatten"
}

/**
* Transforms an array by assigning an order number to each element.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(array[, indexFirst, startFromZero]) - Transforms the input array by encapsulating elements into pairs with indexes indicating the order.",
examples = """
Examples:
> SELECT _FUNC_(array("d", "a", null, "b"));
[("d",1),("a",2),(null,3),("b",4)]
> SELECT _FUNC_(array("d", "a", null, "b"), true, false);
[(1,"d"),(2,"a"),(3,null),(4,"b")]
> SELECT _FUNC_(array("d", "a", null, "b"), true, true);
[(0,"d"),(1,"a"),(2,null),(3,"b")]
""",
since = "2.4.0")
// scalastyle:on line.size.limit
case class ZipWithIndex(child: Expression, indexFirst: Expression, startFromZero: Expression)
extends UnaryExpression with ExpectsInputTypes {

def this(e: Expression) = this(e, Literal.FalseLiteral, Literal.FalseLiteral)

def exprToFlag(e: Expression, order: String): Boolean = e match {
case Literal(v: Boolean, BooleanType) => v
case _ => throw new AnalysisException(s"The $order argument has to be a boolean constant.")
}

private val idxFirst: Boolean = exprToFlag(indexFirst, "second")

private val (idxShift, idxGen): (Int, String) = if (exprToFlag(startFromZero, "third")) {
(0, "z")
} else {
(1, "z + 1")
}

private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

lazy val childArrayType: ArrayType = child.dataType.asInstanceOf[ArrayType]

override def dataType: DataType = {
val elementField = StructField("value", childArrayType.elementType, childArrayType.containsNull)
val indexField = StructField("index", IntegerType, false)

val fields = if (idxFirst) Seq(indexField, elementField) else Seq(elementField, indexField)

ArrayType(StructType(fields), false)
}

override protected def nullSafeEval(input: Any): Any = {
val array = input.asInstanceOf[ArrayData].toObjectArray(childArrayType.elementType)

val makeStruct = (v: Any, i: Int) => if (idxFirst) InternalRow(i, v) else InternalRow(v, i)
val resultData = array.zipWithIndex.map{case (v, i) => makeStruct(v, i + idxShift)}

new GenericArrayData(resultData)
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, c => {
val numElements = ctx.freshName("numElements")
val code = if (CodeGenerator.isPrimitiveType(childArrayType.elementType)) {
genCodeForPrimitiveElements(ctx, c, ev.value, numElements)
} else {
genCodeForAnyElements(ctx, c, ev.value, numElements)
}
s"""
|final int $numElements = $c.numElements();
|$code
""".stripMargin
})
}

private def genCodeForPrimitiveElements(
ctx: CodegenContext,
childVariableName: String,
arrayData: String,
numElements: String): String = {
val byteArraySize = ctx.freshName("byteArraySize")
val data = ctx.freshName("byteArray")
val unsafeRow = ctx.freshName("unsafeRow")
val structSize = ctx.freshName("structSize")
val unsafeArrayData = ctx.freshName("unsafeArrayData")
val structsOffset = ctx.freshName("structsOffset")
val calculateArraySize = "UnsafeArrayData.calculateSizeOfUnderlyingByteArray"
val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"
Copy link
Member

@viirya viirya Apr 22, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we sure the input is always unsafe-backed array? If it is GenericArrayData?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. You just use unsafe-backed array as output.


val baseOffset = Platform.BYTE_ARRAY_OFFSET
val longSize = LongType.defaultSize
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(childArrayType.elementType)
val (valuePosition, indexPosition) = if (idxFirst) ("1", "0") else ("0", "1")

s"""
|final int $structSize = ${UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2};
|final long $byteArraySize = $calculateArraySize($numElements, $longSize + $structSize);
|final int $structsOffset = $calculateHeader($numElements) + $numElements * $longSize;
|if ($byteArraySize > $MAX_ARRAY_LENGTH) {
| ${genCodeForAnyElements(ctx, childVariableName, arrayData, numElements)}
|} else {
| final byte[] $data = new byte[(int)$byteArraySize];
| UnsafeArrayData $unsafeArrayData = new UnsafeArrayData();
| Platform.putLong($data, $baseOffset, $numElements);
| $unsafeArrayData.pointTo($data, $baseOffset, (int)$byteArraySize);
| UnsafeRow $unsafeRow = new UnsafeRow(2);
| for (int z = 0; z < $numElements; z++) {
| long offset = $structsOffset + z * $structSize;
| $unsafeArrayData.setLong(z, (offset << 32) + $structSize);
| $unsafeRow.pointTo($data, $baseOffset + offset, $structSize);
| if (${childArrayType.containsNull} && $childVariableName.isNullAt(z)) {
| $unsafeRow.setNullAt($valuePosition);
| } else {
| $unsafeRow.set$primitiveValueTypeName(
| $valuePosition,
| ${CodeGenerator.getValue(childVariableName, childArrayType.elementType, "z")}
| );
| }
| $unsafeRow.setInt($indexPosition, $idxGen);
| }
| $arrayData = $unsafeArrayData;
|}
""".stripMargin
}

private def genCodeForAnyElements(
ctx: CodegenContext,
childVariableName: String,
arrayData: String,
numElements: String): String = {
val genericArrayClass = classOf[GenericArrayData].getName
val rowClass = classOf[GenericInternalRow].getName
val data = ctx.freshName("internalRowArray")

val getElement = CodeGenerator.getValue(childVariableName, childArrayType.elementType, "z")
val isPrimitiveType = CodeGenerator.isPrimitiveType(childArrayType.elementType)
val elementValue = if (childArrayType.containsNull && isPrimitiveType) {
s"$childVariableName.isNullAt(z) ? null : (Object)$getElement"
} else {
getElement
}
val arguments = if (idxFirst) s"$idxGen, $elementValue" else s"$elementValue, $idxGen"

s"""
|final Object[] $data = new Object[$numElements];
|for (int z = 0; z < $numElements; z++) {
| $data[z] = new $rowClass(new Object[]{$arguments});
|}
|$arrayData = new $genericArrayClass($data);
""".stripMargin
}

override def prettyName: String = "zip_with_index"
}

Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._

class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Expand Down Expand Up @@ -410,4 +411,57 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Flatten(asa3), null)
checkEvaluation(Flatten(asa4), null)
}

test("Zip With Index") {
def r(values: Any*): InternalRow = create_row(values: _*)
val t = Literal.TrueLiteral
val f = Literal.FalseLiteral

// Primitive-type elements
val ai0 = Literal.create(Seq(2, 8, 4, 7), ArrayType(IntegerType))
val ai1 = Literal.create(Seq(null, 4, null, 2), ArrayType(IntegerType))
val ai2 = Literal.create(Seq(null, null, null), ArrayType(IntegerType))
val ai3 = Literal.create(Seq(2), ArrayType(IntegerType))
val ai4 = Literal.create(Seq.empty, ArrayType(IntegerType))
val ai5 = Literal.create(null, ArrayType(IntegerType))

checkEvaluation(ZipWithIndex(ai0, f, f), Seq(r(2, 1), r(8, 2), r(4, 3), r(7, 4)))
checkEvaluation(ZipWithIndex(ai1, f, f), Seq(r(null, 1), r(4, 2), r(null, 3), r(2, 4)))
checkEvaluation(ZipWithIndex(ai2, f, f), Seq(r(null, 1), r(null, 2), r(null, 3)))
checkEvaluation(ZipWithIndex(ai3, f, f), Seq(r(2, 1)))
checkEvaluation(ZipWithIndex(ai4, f, f), Seq.empty)
checkEvaluation(ZipWithIndex(ai5, f, f), null)

checkEvaluation(ZipWithIndex(ai0, t, t), Seq(r(0, 2), r(1, 8), r(2, 4), r(3, 7)))
checkEvaluation(ZipWithIndex(ai1, t, t), Seq(r(0, null), r(1, 4), r(2, null), r(3, 2)))
checkEvaluation(ZipWithIndex(ai2, t, t), Seq(r(0, null), r(1, null), r(2, null)))
checkEvaluation(ZipWithIndex(ai3, t, t), Seq(r(0, 2)))
checkEvaluation(ZipWithIndex(ai4, t, t), Seq.empty)
checkEvaluation(ZipWithIndex(ai5, t, t), null)

// Non-primitive-type elements
val as0 = Literal.create(Seq("b", "a", "y", "z"), ArrayType(StringType))
val as1 = Literal.create(Seq(null, "x", null, "y"), ArrayType(StringType))
val as2 = Literal.create(Seq(null, null, null), ArrayType(StringType))
val as3 = Literal.create(Seq("a"), ArrayType(StringType))
val as4 = Literal.create(Seq.empty, ArrayType(StringType))
val as5 = Literal.create(null, ArrayType(StringType))
val aas = Literal.create(Seq(Seq("e"), Seq("c", "d")), ArrayType(ArrayType(StringType)))

checkEvaluation(ZipWithIndex(as0, f, f), Seq(r("b", 1), r("a", 2), r("y", 3), r("z", 4)))
checkEvaluation(ZipWithIndex(as1, f, f), Seq(r(null, 1), r("x", 2), r(null, 3), r("y", 4)))
checkEvaluation(ZipWithIndex(as2, f, f), Seq(r(null, 1), r(null, 2), r(null, 3)))
checkEvaluation(ZipWithIndex(as3, f, f), Seq(r("a", 1)))
checkEvaluation(ZipWithIndex(as4, f, f), Seq.empty)
checkEvaluation(ZipWithIndex(as5, f, f), null)
checkEvaluation(ZipWithIndex(aas, f, f), Seq(r(Seq("e"), 1), r(Seq("c", "d"), 2)))

checkEvaluation(ZipWithIndex(as0, t, t), Seq(r(0, "b"), r(1, "a"), r(2, "y"), r(3, "z")))
checkEvaluation(ZipWithIndex(as1, t, t), Seq(r(0, null), r(1, "x"), r(2, null), r(3, "y")))
checkEvaluation(ZipWithIndex(as2, t, t), Seq(r(0, null), r(1, null), r(2, null)))
checkEvaluation(ZipWithIndex(as3, t, t), Seq(r(0, "a")))
checkEvaluation(ZipWithIndex(as4, t, t), Seq.empty)
checkEvaluation(ZipWithIndex(as5, t, t), null)
checkEvaluation(ZipWithIndex(aas, t, t), Seq(r(0, Seq("e")), r(1, Seq("c", "d"))))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
if (expected.isNaN) result.isNaN else expected == result
case (result: Float, expected: Float) =>
if (expected.isNaN) result.isNaN else expected == result
case (result: UnsafeRow, expected: GenericInternalRow) =>
val structType = exprDataType.asInstanceOf[StructType]
result.toSeq(structType) == expected.toSeq(structType)
case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema)
case _ =>
result == expected
Expand Down
24 changes: 24 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3378,6 +3378,30 @@ object functions {
*/
def flatten(e: Column): Column = withExpr { Flatten(e.expr) }

/**
* Transforms the input array by encapsulating elements into pairs
* with indexes indicating the order.
*
* Note: The array index is placed second and starts from one.
*
* @group collection_funcs
* @since 2.4.0
*/
def zip_with_index(e: Column): Column = withExpr {
ZipWithIndex(e.expr, Literal.FalseLiteral, Literal.FalseLiteral)
}

/**
* Transforms the input array by encapsulating elements into pairs
* with indexes indicating the order.
*
* @group collection_funcs
* @since 2.4.0
*/
def zip_with_index(e: Column, indexFirst: Boolean, startFromZero: Boolean): Column = withExpr {
ZipWithIndex(e.expr, Literal(indexFirst), Literal(startFromZero))
}

/**
* Returns an unordered array containing the keys of the map.
* @group collection_funcs
Expand Down
Loading