diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ad4bd6f5089e9..f74ee06f66df3 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2233,6 +2233,26 @@ def flatten(col): return Column(sc._jvm.functions.flatten(_to_java_column(col))) +@since(2.4) +def zip_with_index(col, indexFirst=False, startFromZero=False): + """ + Collection function: transforms the input array by encapsulating elements into pairs + with indexes indicating the order. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 5, 3],), ([],)], ['d']) + >>> df.select(zip_with_index(df.d).alias('r')).collect() + [Row(r=[Row(value=2, index=1), Row(value=5, index=2), Row(value=3, index=3)]), Row(r=[])] + >>> df.select(zip_with_index(df.d, indexFirst=True, startFromZero=False).alias('r')).collect() + [Row(r=[Row(index=1, value=2), Row(index=2, value=5), Row(index=3, value=3)]), Row(r=[])] + >>> df.select(zip_with_index(df.d, indexFirst=True, startFromZero=True).alias('r')).collect() + [Row(r=[Row(index=0, value=2), Row(index=1, value=5), Row(index=2, value=3)]), Row(r=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.zip_with_index(_to_java_column(col), indexFirst, startFromZero)) + + @since(2.3) def map_keys(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6bc7b4e4f7cb3..2aa20307bca92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -415,6 +415,7 @@ object FunctionRegistry { expression[Reverse]("reverse"), expression[Concat]("concat"), expression[Flatten]("flatten"), + expression[ZipWithIndex]("zip_with_index"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 90223b9126555..eb41629245cce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -1228,3 +1229,158 @@ case class Flatten(child: Expression) extends UnaryExpression { override def prettyName: String = "flatten" } + +/** + * Transforms an array by assigning an order number to each element. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(array[, indexFirst, startFromZero]) - Transforms the input array by encapsulating elements into pairs with indexes indicating the order.", + examples = """ + Examples: + > SELECT _FUNC_(array("d", "a", null, "b")); + [("d",1),("a",2),(null,3),("b",4)] + > SELECT _FUNC_(array("d", "a", null, "b"), true, false); + [(1,"d"),(2,"a"),(3,null),(4,"b")] + > SELECT _FUNC_(array("d", "a", null, "b"), true, true); + [(0,"d"),(1,"a"),(2,null),(3,"b")] + """, + since = "2.4.0") +// scalastyle:on line.size.limit +case class ZipWithIndex(child: Expression, indexFirst: Expression, startFromZero: Expression) + extends UnaryExpression with ExpectsInputTypes { + + def this(e: Expression) = this(e, Literal.FalseLiteral, Literal.FalseLiteral) + + def exprToFlag(e: Expression, order: String): Boolean = e match { + case Literal(v: Boolean, BooleanType) => v + case _ => throw new AnalysisException(s"The $order argument has to be a boolean constant.") + } + + private val idxFirst: Boolean = exprToFlag(indexFirst, "second") + + private val (idxShift, idxGen): (Int, String) = if (exprToFlag(startFromZero, "third")) { + (0, "z") + } else { + (1, "z + 1") + } + + private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + lazy val childArrayType: ArrayType = child.dataType.asInstanceOf[ArrayType] + + override def dataType: DataType = { + val elementField = StructField("value", childArrayType.elementType, childArrayType.containsNull) + val indexField = StructField("index", IntegerType, false) + + val fields = if (idxFirst) Seq(indexField, elementField) else Seq(elementField, indexField) + + ArrayType(StructType(fields), false) + } + + override protected def nullSafeEval(input: Any): Any = { + val array = input.asInstanceOf[ArrayData].toObjectArray(childArrayType.elementType) + + val makeStruct = (v: Any, i: Int) => if (idxFirst) InternalRow(i, v) else InternalRow(v, i) + val resultData = array.zipWithIndex.map{case (v, i) => makeStruct(v, i + idxShift)} + + new GenericArrayData(resultData) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val numElements = ctx.freshName("numElements") + val code = if (CodeGenerator.isPrimitiveType(childArrayType.elementType)) { + genCodeForPrimitiveElements(ctx, c, ev.value, numElements) + } else { + genCodeForAnyElements(ctx, c, ev.value, numElements) + } + s""" + |final int $numElements = $c.numElements(); + |$code + """.stripMargin + }) + } + + private def genCodeForPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayData: String, + numElements: String): String = { + val byteArraySize = ctx.freshName("byteArraySize") + val data = ctx.freshName("byteArray") + val unsafeRow = ctx.freshName("unsafeRow") + val structSize = ctx.freshName("structSize") + val unsafeArrayData = ctx.freshName("unsafeArrayData") + val structsOffset = ctx.freshName("structsOffset") + val calculateArraySize = "UnsafeArrayData.calculateSizeOfUnderlyingByteArray" + val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes" + + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val longSize = LongType.defaultSize + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(childArrayType.elementType) + val (valuePosition, indexPosition) = if (idxFirst) ("1", "0") else ("0", "1") + + s""" + |final int $structSize = ${UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2}; + |final long $byteArraySize = $calculateArraySize($numElements, $longSize + $structSize); + |final int $structsOffset = $calculateHeader($numElements) + $numElements * $longSize; + |if ($byteArraySize > $MAX_ARRAY_LENGTH) { + | ${genCodeForAnyElements(ctx, childVariableName, arrayData, numElements)} + |} else { + | final byte[] $data = new byte[(int)$byteArraySize]; + | UnsafeArrayData $unsafeArrayData = new UnsafeArrayData(); + | Platform.putLong($data, $baseOffset, $numElements); + | $unsafeArrayData.pointTo($data, $baseOffset, (int)$byteArraySize); + | UnsafeRow $unsafeRow = new UnsafeRow(2); + | for (int z = 0; z < $numElements; z++) { + | long offset = $structsOffset + z * $structSize; + | $unsafeArrayData.setLong(z, (offset << 32) + $structSize); + | $unsafeRow.pointTo($data, $baseOffset + offset, $structSize); + | if (${childArrayType.containsNull} && $childVariableName.isNullAt(z)) { + | $unsafeRow.setNullAt($valuePosition); + | } else { + | $unsafeRow.set$primitiveValueTypeName( + | $valuePosition, + | ${CodeGenerator.getValue(childVariableName, childArrayType.elementType, "z")} + | ); + | } + | $unsafeRow.setInt($indexPosition, $idxGen); + | } + | $arrayData = $unsafeArrayData; + |} + """.stripMargin + } + + private def genCodeForAnyElements( + ctx: CodegenContext, + childVariableName: String, + arrayData: String, + numElements: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val rowClass = classOf[GenericInternalRow].getName + val data = ctx.freshName("internalRowArray") + + val getElement = CodeGenerator.getValue(childVariableName, childArrayType.elementType, "z") + val isPrimitiveType = CodeGenerator.isPrimitiveType(childArrayType.elementType) + val elementValue = if (childArrayType.containsNull && isPrimitiveType) { + s"$childVariableName.isNullAt(z) ? null : (Object)$getElement" + } else { + getElement + } + val arguments = if (idxFirst) s"$idxGen, $elementValue" else s"$elementValue, $idxGen" + + s""" + |final Object[] $data = new Object[$numElements]; + |for (int z = 0; z < $numElements; z++) { + | $data[z] = new $rowClass(new Object[]{$arguments}); + |} + |$arrayData = new $genericArrayClass($data); + """.stripMargin + } + + override def prettyName: String = "zip_with_index" +} + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 7048d93fd5649..27c2aa6e8be64 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -410,4 +411,57 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Flatten(asa3), null) checkEvaluation(Flatten(asa4), null) } + + test("Zip With Index") { + def r(values: Any*): InternalRow = create_row(values: _*) + val t = Literal.TrueLiteral + val f = Literal.FalseLiteral + + // Primitive-type elements + val ai0 = Literal.create(Seq(2, 8, 4, 7), ArrayType(IntegerType)) + val ai1 = Literal.create(Seq(null, 4, null, 2), ArrayType(IntegerType)) + val ai2 = Literal.create(Seq(null, null, null), ArrayType(IntegerType)) + val ai3 = Literal.create(Seq(2), ArrayType(IntegerType)) + val ai4 = Literal.create(Seq.empty, ArrayType(IntegerType)) + val ai5 = Literal.create(null, ArrayType(IntegerType)) + + checkEvaluation(ZipWithIndex(ai0, f, f), Seq(r(2, 1), r(8, 2), r(4, 3), r(7, 4))) + checkEvaluation(ZipWithIndex(ai1, f, f), Seq(r(null, 1), r(4, 2), r(null, 3), r(2, 4))) + checkEvaluation(ZipWithIndex(ai2, f, f), Seq(r(null, 1), r(null, 2), r(null, 3))) + checkEvaluation(ZipWithIndex(ai3, f, f), Seq(r(2, 1))) + checkEvaluation(ZipWithIndex(ai4, f, f), Seq.empty) + checkEvaluation(ZipWithIndex(ai5, f, f), null) + + checkEvaluation(ZipWithIndex(ai0, t, t), Seq(r(0, 2), r(1, 8), r(2, 4), r(3, 7))) + checkEvaluation(ZipWithIndex(ai1, t, t), Seq(r(0, null), r(1, 4), r(2, null), r(3, 2))) + checkEvaluation(ZipWithIndex(ai2, t, t), Seq(r(0, null), r(1, null), r(2, null))) + checkEvaluation(ZipWithIndex(ai3, t, t), Seq(r(0, 2))) + checkEvaluation(ZipWithIndex(ai4, t, t), Seq.empty) + checkEvaluation(ZipWithIndex(ai5, t, t), null) + + // Non-primitive-type elements + val as0 = Literal.create(Seq("b", "a", "y", "z"), ArrayType(StringType)) + val as1 = Literal.create(Seq(null, "x", null, "y"), ArrayType(StringType)) + val as2 = Literal.create(Seq(null, null, null), ArrayType(StringType)) + val as3 = Literal.create(Seq("a"), ArrayType(StringType)) + val as4 = Literal.create(Seq.empty, ArrayType(StringType)) + val as5 = Literal.create(null, ArrayType(StringType)) + val aas = Literal.create(Seq(Seq("e"), Seq("c", "d")), ArrayType(ArrayType(StringType))) + + checkEvaluation(ZipWithIndex(as0, f, f), Seq(r("b", 1), r("a", 2), r("y", 3), r("z", 4))) + checkEvaluation(ZipWithIndex(as1, f, f), Seq(r(null, 1), r("x", 2), r(null, 3), r("y", 4))) + checkEvaluation(ZipWithIndex(as2, f, f), Seq(r(null, 1), r(null, 2), r(null, 3))) + checkEvaluation(ZipWithIndex(as3, f, f), Seq(r("a", 1))) + checkEvaluation(ZipWithIndex(as4, f, f), Seq.empty) + checkEvaluation(ZipWithIndex(as5, f, f), null) + checkEvaluation(ZipWithIndex(aas, f, f), Seq(r(Seq("e"), 1), r(Seq("c", "d"), 2))) + + checkEvaluation(ZipWithIndex(as0, t, t), Seq(r(0, "b"), r(1, "a"), r(2, "y"), r(3, "z"))) + checkEvaluation(ZipWithIndex(as1, t, t), Seq(r(0, null), r(1, "x"), r(2, null), r(3, "y"))) + checkEvaluation(ZipWithIndex(as2, t, t), Seq(r(0, null), r(1, null), r(2, null))) + checkEvaluation(ZipWithIndex(as3, t, t), Seq(r(0, "a"))) + checkEvaluation(ZipWithIndex(as4, t, t), Seq.empty) + checkEvaluation(ZipWithIndex(as5, t, t), null) + checkEvaluation(ZipWithIndex(aas, t, t), Seq(r(0, Seq("e")), r(1, Seq("c", "d")))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index b4bf6d7107d7e..e739f1a6b4cfd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -98,6 +98,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { if (expected.isNaN) result.isNaN else expected == result case (result: Float, expected: Float) => if (expected.isNaN) result.isNaN else expected == result + case (result: UnsafeRow, expected: GenericInternalRow) => + val structType = exprDataType.asInstanceOf[StructType] + result.toSeq(structType) == expected.toSeq(structType) case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema) case _ => result == expected diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 25afaacc38d6f..35115e89263cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3378,6 +3378,30 @@ object functions { */ def flatten(e: Column): Column = withExpr { Flatten(e.expr) } + /** + * Transforms the input array by encapsulating elements into pairs + * with indexes indicating the order. + * + * Note: The array index is placed second and starts from one. + * + * @group collection_funcs + * @since 2.4.0 + */ + def zip_with_index(e: Column): Column = withExpr { + ZipWithIndex(e.expr, Literal.FalseLiteral, Literal.FalseLiteral) + } + + /** + * Transforms the input array by encapsulating elements into pairs + * with indexes indicating the order. + * + * @group collection_funcs + * @since 2.4.0 + */ + def zip_with_index(e: Column, indexFirst: Boolean, startFromZero: Boolean): Column = withExpr { + ZipWithIndex(e.expr, Literal(indexFirst), Literal(startFromZero)) + } + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index c216d1322a06c..7712c355bdb08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -793,6 +793,131 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } + test("zip_with_index function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on + val oneRowDF = Seq(("Spark", 3215, true)).toDF("s", "i", "b") + + // Test cases with primitive-type elements + val idf = Seq( + Seq(1, 9, 8, 7), + Seq.empty, + null + ).toDF("i") + + checkAnswer( + idf.select(zip_with_index('i)), + Seq(Row(Seq(Row(1, 1), Row(9, 2), Row(8, 3), Row(7, 4))), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.filter(dummyFilter('i)).select(zip_with_index('i)), + Seq(Row(Seq(Row(1, 1), Row(9, 2), Row(8, 3), Row(7, 4))), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.select(zip_with_index('i, true, false)), + Seq(Row(Seq(Row(1, 1), Row(2, 9), Row(3, 8), Row(4, 7))), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.select(zip_with_index('i, true, true)), + Seq(Row(Seq(Row(0, 1), Row(1, 9), Row(2, 8), Row(3, 7))), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.selectExpr("zip_with_index(i)"), + Seq(Row(Seq(Row(1, 1), Row(9, 2), Row(8, 3), Row(7, 4))), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.selectExpr("zip_with_index(i, true, false)"), + Seq(Row(Seq(Row(1, 1), Row(2, 9), Row(3, 8), Row(4, 7))), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.selectExpr("zip_with_index(i, false, true)"), + Seq(Row(Seq(Row(1, 0), Row(9, 1), Row(8, 2), Row(7, 3))), Row(Seq.empty), Row(null)) + ) + checkAnswer( + oneRowDF.selectExpr("zip_with_index(array(null, 2, null), false, true)"), + Seq(Row(Seq(Row(null, 0), Row(2, 1), Row(null, 2)))) + ) + checkAnswer( + oneRowDF.selectExpr("zip_with_index(array(null, 2, null), true, true)"), + Seq(Row(Seq(Row(0, null), Row(1, 2), Row(2, null)))) + ) + + // Test cases with non-primitive-type elements + val sdf = Seq( + Seq("c", "a", "d", "b"), + Seq(null, "x", null), + Seq.empty, + null + ).toDF("s") + + checkAnswer( + sdf.select(zip_with_index('s)), + Seq( + Row(Seq(Row("c", 1), Row("a", 2), Row("d", 3), Row("b", 4))), + Row(Seq(Row(null, 1), Row("x", 2), Row(null, 3))), + Row(Seq.empty), + Row(null)) + ) + checkAnswer( + sdf.filter(dummyFilter('s)).select(zip_with_index('s)), + Seq( + Row(Seq(Row("c", 1), Row("a", 2), Row("d", 3), Row("b", 4))), + Row(Seq(Row(null, 1), Row("x", 2), Row(null, 3))), + Row(Seq.empty), + Row(null)) + ) + checkAnswer( + sdf.select(zip_with_index('s, true, false)), + Seq( + Row(Seq(Row(1, "c"), Row(2, "a"), Row(3, "d"), Row(4, "b"))), + Row(Seq(Row(1, null), Row(2, "x"), Row(3, null))), + Row(Seq.empty), + Row(null)) + ) + checkAnswer( + sdf.select(zip_with_index('s, true, true)), + Seq( + Row(Seq(Row(0, "c"), Row(1, "a"), Row(2, "d"), Row(3, "b"))), + Row(Seq(Row(0, null), Row(1, "x"), Row(2, null))), + Row(Seq.empty), + Row(null)) + ) + checkAnswer( + sdf.selectExpr("zip_with_index(s)"), + Seq( + Row(Seq(Row("c", 1), Row("a", 2), Row("d", 3), Row("b", 4))), + Row(Seq(Row(null, 1), Row("x", 2), Row(null, 3))), + Row(Seq.empty), + Row(null)) + ) + checkAnswer( + sdf.selectExpr("zip_with_index(s, false, true)"), + Seq( + Row(Seq(Row("c", 0), Row("a", 1), Row("d", 2), Row("b", 3))), + Row(Seq(Row(null, 0), Row("x", 1), Row(null, 2))), + Row(Seq.empty), + Row(null)) + ) + checkAnswer( + sdf.selectExpr("zip_with_index(s, true, false)"), + Seq( + Row(Seq(Row(1, "c"), Row(2, "a"), Row(3, "d"), Row(4, "b"))), + Row(Seq(Row(1, null), Row(2, "x"), Row(3, null))), + Row(Seq.empty), + Row(null)) + ) + + // Error test cases + intercept[AnalysisException] { + oneRowDF.select(zip_with_index('s)) + } + intercept[AnalysisException] { + oneRowDF.selectExpr("zip_with_index(array(1, 2, 3), b, false)") + } + intercept[AnalysisException] { + oneRowDF.selectExpr("zip_with_index(array(1, 2, 3), true, 1)") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {