Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2412,6 +2412,26 @@ def map_entries(col):
return Column(sc._jvm.functions.map_entries(_to_java_column(col)))


@since(2.4)
def map_from_entries(col):
"""
Collection function: Returns a map created from the given array of entries.

:param col: name of column or expression

>>> from pyspark.sql.functions import map_from_entries
>>> df = spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) as data")
>>> df.select(map_from_entries("data").alias("map")).show()
+----------------+
| map|
+----------------+
|[1 -> a, 2 -> b]|
+----------------+
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.map_from_entries(_to_java_column(col)))


@ignore_unicode_prefix
@since(2.4)
def array_repeat(col, count):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ object FunctionRegistry {
expression[MapKeys]("map_keys"),
expression[MapValues]("map_values"),
expression[MapEntries]("map_entries"),
expression[MapFromEntries]("map_from_entries"),
expression[Size]("size"),
expression[Slice]("slice"),
expression[Size]("cardinality"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,36 @@ class CodegenContext {
}
}

/**
* Generates code to do null safe execution when accessing properties of complex
* ArrayData elements.
*
* @param nullElements used to decide whether the ArrayData might contain null or not.
* @param isNull a variable indicating whether the result will be evaluated to null or not.
* @param arrayData a variable name representing the ArrayData.
* @param execute the code that should be executed only if the ArrayData doesn't contain
* any null.
*/
def nullArrayElementsSaveExec(
nullElements: Boolean,
isNull: String,
arrayData: String)(
execute: String): String = {
val i = freshName("idx")
if (nullElements) {
s"""
|for (int $i = 0; !$isNull && $i < $arrayData.numElements(); $i++) {
| $isNull |= $arrayData.isNullAt($i);
|}
|if (!$isNull) {
| $execute
|}
""".stripMargin
} else {
execute
}
}

/**
* Splits the generated code of expressions into multiple functions, because function has
* 64kb code size limit in JVM. If the class to which the function would be inlined would grow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
Expand Down Expand Up @@ -475,6 +475,223 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
override def prettyName: String = "map_entries"
}

/**
* Returns a map created from the given array of entries.
*/
@ExpressionDescription(
usage = "_FUNC_(arrayOfEntries) - Returns a map created from the given array of entries.",
examples = """
Examples:
> SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b')));
{1:"a",2:"b"}
""",
since = "2.4.0")
case class MapFromEntries(child: Expression) extends UnaryExpression {

@transient
private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match {
case ArrayType(
StructType(Array(
StructField(_, keyType, keyNullable, _),
StructField(_, valueType, valueNullable, _))),
containsNull) => Some((MapType(keyType, valueType, valueNullable), keyNullable, containsNull))
case _ => None
}

private def nullEntries: Boolean = dataTypeDetails.get._3

override def nullable: Boolean = child.nullable || nullEntries

override def dataType: MapType = dataTypeDetails.get._1

override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match {
case Some(_) => TypeCheckResult.TypeCheckSuccess
case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " +
s"${child.dataType.simpleString} type. $prettyName accepts only arrays of pair structs.")
}

override protected def nullSafeEval(input: Any): Any = {
val arrayData = input.asInstanceOf[ArrayData]
val numEntries = arrayData.numElements()
var i = 0
if(nullEntries) {
while (i < numEntries) {
if (arrayData.isNullAt(i)) return null
i += 1
}
}
val keyArray = new Array[AnyRef](numEntries)
val valueArray = new Array[AnyRef](numEntries)
i = 0
while (i < numEntries) {
val entry = arrayData.getStruct(i, 2)
val key = entry.get(0, dataType.keyType)
if (key == null) {
throw new RuntimeException("The first field from a struct (key) can't be null.")
}
keyArray.update(i, key)
val value = entry.get(1, dataType.valueType)
valueArray.update(i, value)
i += 1
}
ArrayBasedMapData(keyArray, valueArray)
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, c => {
val numEntries = ctx.freshName("numEntries")
val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType)
val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType)
val code = if (isKeyPrimitive && isValuePrimitive) {
genCodeForPrimitiveElements(ctx, c, ev.value, numEntries)
} else {
genCodeForAnyElements(ctx, c, ev.value, numEntries)
}
ctx.nullArrayElementsSaveExec(nullEntries, ev.isNull, c) {
s"""
|final int $numEntries = $c.numElements();
|$code
""".stripMargin
}
})
}

private def genCodeForAssignmentLoop(
ctx: CodegenContext,
childVariable: String,
mapData: String,
numEntries: String,
keyAssignment: (String, String) => String,
valueAssignment: (String, String) => String): String = {
val entry = ctx.freshName("entry")
val i = ctx.freshName("idx")

val nullKeyCheck = if (dataTypeDetails.get._2) {
s"""
|if ($entry.isNullAt(0)) {
| throw new RuntimeException("The first field from a struct (key) can't be null.");
|}
""".stripMargin
} else {
""
}

s"""
|for (int $i = 0; $i < $numEntries; $i++) {
| InternalRow $entry = $childVariable.getStruct($i, 2);
| $nullKeyCheck
| ${keyAssignment(CodeGenerator.getValue(entry, dataType.keyType, "0"), i)}
| ${valueAssignment(entry, i)}
|}
""".stripMargin
}

private def genCodeForPrimitiveElements(
ctx: CodegenContext,
childVariable: String,
mapData: String,
numEntries: String): String = {
val byteArraySize = ctx.freshName("byteArraySize")
val keySectionSize = ctx.freshName("keySectionSize")
val valueSectionSize = ctx.freshName("valueSectionSize")
val data = ctx.freshName("byteArray")
val unsafeMapData = ctx.freshName("unsafeMapData")
val keyArrayData = ctx.freshName("keyArrayData")
val valueArrayData = ctx.freshName("valueArrayData")

val baseOffset = Platform.BYTE_ARRAY_OFFSET
val keySize = dataType.keyType.defaultSize
val valueSize = dataType.valueType.defaultSize
val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)"
val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $valueSize)"
val keyTypeName = CodeGenerator.primitiveTypeName(dataType.keyType)
val valueTypeName = CodeGenerator.primitiveTypeName(dataType.valueType)

val keyAssignment = (key: String, idx: String) => s"$keyArrayData.set$keyTypeName($idx, $key);"
val valueAssignment = (entry: String, idx: String) => {
val value = CodeGenerator.getValue(entry, dataType.valueType, "1")
val valueNullUnsafeAssignment = s"$valueArrayData.set$valueTypeName($idx, $value);"
if (dataType.valueContainsNull) {
s"""
|if ($entry.isNullAt(1)) {
| $valueArrayData.setNullAt($idx);
|} else {
| $valueNullUnsafeAssignment
|}
""".stripMargin
} else {
valueNullUnsafeAssignment
}
}
val assignmentLoop = genCodeForAssignmentLoop(
ctx,
childVariable,
mapData,
numEntries,
keyAssignment,
valueAssignment
)

s"""
|final long $keySectionSize = $kByteSize;
|final long $valueSectionSize = $vByteSize;
|final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize;
|if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| ${genCodeForAnyElements(ctx, childVariable, mapData, numEntries)}
|} else {
| final byte[] $data = new byte[(int)$byteArraySize];
| UnsafeMapData $unsafeMapData = new UnsafeMapData();
| Platform.putLong($data, $baseOffset, $keySectionSize);
| Platform.putLong($data, ${baseOffset + 8}, $numEntries);
| Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numEntries);
| $unsafeMapData.pointTo($data, $baseOffset, (int)$byteArraySize);
| ArrayData $keyArrayData = $unsafeMapData.keyArray();
| ArrayData $valueArrayData = $unsafeMapData.valueArray();
| $assignmentLoop
| $mapData = $unsafeMapData;
|}
""".stripMargin
}

private def genCodeForAnyElements(
ctx: CodegenContext,
childVariable: String,
mapData: String,
numEntries: String): String = {
val keys = ctx.freshName("keys")
val values = ctx.freshName("values")
val mapDataClass = classOf[ArrayBasedMapData].getName()

val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType)
val valueAssignment = (entry: String, idx: String) => {
val value = CodeGenerator.getValue(entry, dataType.valueType, "1")
if (dataType.valueContainsNull && isValuePrimitive) {
s"$values[$idx] = $entry.isNullAt(1) ? null : (Object)$value;"
} else {
s"$values[$idx] = $value;"
}
}
val keyAssignment = (key: String, idx: String) => s"$keys[$idx] = $key;"
val assignmentLoop = genCodeForAssignmentLoop(
ctx,
childVariable,
mapData,
numEntries,
keyAssignment,
valueAssignment)

s"""
|final Object[] $keys = new Object[$numEntries];
|final Object[] $values = new Object[$numEntries];
|$assignmentLoop
|$mapData = $mapDataClass.apply($keys, $values);
""".stripMargin
}

override def prettyName: String = "map_from_entries"
}


/**
* Common base class for [[SortArray]] and [[ArraySort]].
*/
Expand Down Expand Up @@ -1990,24 +2207,10 @@ case class Flatten(child: Expression) extends UnaryExpression {
} else {
genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value)
}
if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code
ctx.nullArrayElementsSaveExec(childDataType.containsNull, ev.isNull, c)(code)
})
}

private def nullElementsProtection(
ev: ExprCode,
childVariableName: String,
coreLogic: String): String = {
s"""
|for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++) {
| ${ev.isNull} |= $childVariableName.isNullAt(z);
|}
|if (!${ev.isNull}) {
| $coreLogic
|}
""".stripMargin
}

private def genCodeForNumberOfElements(
ctx: CodegenContext,
childVariableName: String) : (String, String) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,57 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MapEntries(ms2), null)
}

test("MapFromEntries") {
def arrayType(keyType: DataType, valueType: DataType) : DataType = {
ArrayType(
StructType(Seq(
StructField("a", keyType),
StructField("b", valueType))),
true)
}
def r(values: Any*): InternalRow = create_row(values: _*)

// Primitive-type keys and values
val aiType = arrayType(IntegerType, IntegerType)
val ai0 = Literal.create(Seq(r(1, 10), r(2, 20), r(3, 20)), aiType)
val ai1 = Literal.create(Seq(r(1, null), r(2, 20), r(3, null)), aiType)
val ai2 = Literal.create(Seq.empty, aiType)
val ai3 = Literal.create(null, aiType)
val ai4 = Literal.create(Seq(r(1, 10), r(1, 20)), aiType)
val ai5 = Literal.create(Seq(r(1, 10), r(null, 20)), aiType)
val ai6 = Literal.create(Seq(null, r(2, 20), null), aiType)

checkEvaluation(MapFromEntries(ai0), Map(1 -> 10, 2 -> 20, 3 -> 20))
checkEvaluation(MapFromEntries(ai1), Map(1 -> null, 2 -> 20, 3 -> null))
checkEvaluation(MapFromEntries(ai2), Map.empty)
checkEvaluation(MapFromEntries(ai3), null)
checkEvaluation(MapKeys(MapFromEntries(ai4)), Seq(1, 1))
checkExceptionInExpression[RuntimeException](
MapFromEntries(ai5),
"The first field from a struct (key) can't be null.")
checkEvaluation(MapFromEntries(ai6), null)

// Non-primitive-type keys and values
val asType = arrayType(StringType, StringType)
val as0 = Literal.create(Seq(r("a", "aa"), r("b", "bb"), r("c", "bb")), asType)
val as1 = Literal.create(Seq(r("a", null), r("b", "bb"), r("c", null)), asType)
val as2 = Literal.create(Seq.empty, asType)
val as3 = Literal.create(null, asType)
val as4 = Literal.create(Seq(r("a", "aa"), r("a", "bb")), asType)
val as5 = Literal.create(Seq(r("a", "aa"), r(null, "bb")), asType)
val as6 = Literal.create(Seq(null, r("b", "bb"), null), asType)

checkEvaluation(MapFromEntries(as0), Map("a" -> "aa", "b" -> "bb", "c" -> "bb"))
checkEvaluation(MapFromEntries(as1), Map("a" -> null, "b" -> "bb", "c" -> null))
checkEvaluation(MapFromEntries(as2), Map.empty)
checkEvaluation(MapFromEntries(as3), null)
checkEvaluation(MapKeys(MapFromEntries(as4)), Seq("a", "a"))
checkExceptionInExpression[RuntimeException](
MapFromEntries(as5),
"The first field from a struct (key) can't be null.")
checkEvaluation(MapFromEntries(as6), null)
}

test("Sort Array") {
val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
Expand Down
Loading