From e4171e1d62b3e34baf4cb1805e12b4aa31a4a811 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 7 May 2018 12:06:56 +0100 Subject: [PATCH 01/13] initial commit --- python/pyspark/sql/functions.py | 16 +++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/complexTypeCreator.scala | 65 ++++++++++++++++++- .../expressions/ComplexTypeSuite.scala | 31 +++++++++ .../org/apache/spark/sql/functions.scala | 11 ++++ .../spark/sql/DataFrameFunctionsSuite.scala | 28 ++++++++ 6 files changed, 151 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1759195c6fcc..b5ea0f8a4e7d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1819,6 +1819,22 @@ def create_map(*cols): return Column(jc) +@ignore_unicode_prefix +@since(2.4) +def create_map_fromarray(col1, col2): + """Creates a new map from two arrays. + + :param col1: name of column containing a set of keys. All elements should not be null + :param col2: name of column containing a set of values + + >>> df = spark.createDataFrame([([2, 5], ["Alice", "Bob"])], ['k', 'v']) + >>> df.select(create_map(df.k, df.v).alias("map")).collect() + [Row(map={2: u'Alice', 5: u'Bob'})] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map(_to_java_column(col1), _to_java_column(col2))) + + @since(1.4) def array(*cols): """Creates a new array column. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 49fb35b08358..9fec11644f43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -415,6 +415,7 @@ object FunctionRegistry { expression[ArrayPosition]("array_position"), expression[ArraySort]("array_sort"), expression[CreateMap]("map"), + expression[CreateMapFromArray]("map_fromarray"), expression[CreateNamedStruct]("named_struct"), expression[ElementAt]("element_at"), expression[MapKeys]("map_keys"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index a9867aaeb0cf..199e0468c9e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods @@ -236,6 +236,69 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def prettyName: String = "map" } +/** + * Returns a catalyst Map containing the two arrays in children expressions as keys and values. + */ +@ExpressionDescription( + usage = """ + _FUNC_(keys, values) - Creates a map with a pair of the given key/value arrays. 
All elements + in keys should not be null""", + examples = """ + Examples: + > SELECT _FUNC_([1.0, 3.0], ['2', '4']); + {1.0:"2",3.0:"4"} + """, since = "2.4.0") +case class CreateMapFromArray(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (ArrayType(_, cn), ArrayType(_, _)) => + if (!cn) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure("All of the given keys should be non-null") + } + case _ => + TypeCheckResult.TypeCheckFailure("The given two arguments should be an array") + } + } + + override def dataType: DataType = { + MapType( + keyType = left.dataType.asInstanceOf[ArrayType].elementType, + valueType = right.dataType.asInstanceOf[ArrayType].elementType, + valueContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull) + } + + override def nullable: Boolean = false + + override def nullSafeEval(keyArray: Any, valueArray: Any): Any = { + val keyArrayData = keyArray.asInstanceOf[ArrayData] + val valueArrayData = valueArray.asInstanceOf[ArrayData] + if (keyArrayData.numElements != valueArrayData.numElements) { + throw new RuntimeException("The given two arrays should have the same length") + } + new ArrayBasedMapData(keyArrayData.copy(), valueArrayData.copy()) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (keyArrayData, valueArrayData) => { + val arrayBasedMapData = classOf[ArrayBasedMapData].getName + s""" + |if ($keyArrayData.numElements() != $valueArrayData.numElements()) { + | throw new RuntimeException("The given two arrays should have the same length"); + |} + |${ev.value} = new $arrayBasedMapData($keyArrayData.copy(), $valueArrayData.copy()); + """ + }) + } + + override def prettyName: String = "map" +} + /** * An expression representing a not yet available attribute name. This expression is unevaluable * and as its name suggests it is a temporary place holder until we're able to determine the diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index b4138ce366b3..4ef78022dc44 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -186,6 +186,37 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("CreateMapFromArray") { + def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { + // catalyst map is order-sensitive, so we create ListMap here to preserve the elements order. 
+ scala.collection.immutable.ListMap(keys.zip(values): _*) + } + + val intSeq = Seq(5, 10, 15, 20, 25) + val longSeq = intSeq.map(_.toLong) + val strSeq = intSeq.map(_.toString) + val intWithNullSeq = Seq[java.lang.Integer](5, 10, null, 20, 25) + val longWithNullSeq = intSeq.map(java.lang.Long.valueOf(_)) + + val intArray = Literal.create(intSeq, ArrayType(IntegerType, false)) + val longArray = Literal.create(longSeq, ArrayType(LongType, false)) + val strArray = Literal.create(strSeq, ArrayType(StringType, false)) + + val intwithNullArray = Literal.create(intWithNullSeq, ArrayType(IntegerType, true)) + val longwithNullArray = Literal.create(longWithNullSeq, ArrayType(LongType, true)) + + checkEvaluation(CreateMapFromArray(intArray, longArray), createMap(intSeq, longSeq)) + checkEvaluation(CreateMapFromArray(intArray, strArray), createMap(intSeq, strSeq)) + checkEvaluation( + CreateMapFromArray(strArray, intwithNullArray), createMap(strSeq, intWithNullSeq)) + checkEvaluation( + CreateMapFromArray(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq)) + intercept[RuntimeException] { + checkEvaluation( + CreateMapFromArray(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null) + } + } + test("CreateStruct") { val row = create_row(1, 2, 3) val c1 = 'a.int.at(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a2aae9a708ff..188b9025f907 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1070,6 +1070,17 @@ object functions { @scala.annotation.varargs def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) } + /** + * Creates a new map column. The array in the first column is used for keys. The array in the + * second column is used for values. All elements in the array for key should not be null. + * + * @group normal_funcs + * @since 2.4 + */ + def map_fromarray(keys: Column, values: Column): Column = withExpr { + CreateMapFromArray(keys.expr, values.expr) + } + /** * Marks a DataFrame as small enough for use in broadcast joins. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 59119bbbd8a2..d06d0b110e42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -62,6 +62,34 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(row.getMap[Int, String](0) === Map(2 -> "a")) } + test("map with array") { + val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("k", "v") + checkAnswer(df1.select(map_fromarray($"k", $"v")), Seq(Row(Map(1 -> "a", 2 -> "b")))) + + val df2 = Seq((Seq(1, 2), Seq(null, "b"))).toDF("k", "v") + checkAnswer(df2.select(map_fromarray($"k", $"v")), Seq(Row(Map(1 -> null, 2 -> "b")))) + + val df3 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v") + intercept[AnalysisException] { + df3.select(map_fromarray($"k", $"v")) + } + + val df4 = Seq((1, "a")).toDF("k", "v") + intercept[AnalysisException] { + df4.select(map_fromarray($"k", $"v")) + } + + val df5 = Seq((null, null)).toDF("k", "v") + intercept[AnalysisException] { + df5.select(map_fromarray($"k", $"v")) + } + + val df6 = Seq((Seq(1, 2), Seq("a"))).toDF("k", "v") + intercept[RuntimeException] { + df6.select(map_fromarray($"k", $"v")).collect + } + } + test("struct with column name") { val df = Seq((1, "str")).toDF("a", "b") val row = df.select(struct("a", "b")).first() From 95d92d857013892eca8b8b943e56dfb319bc838b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 7 May 2018 20:18:47 +0100 Subject: [PATCH 02/13] fix pyspark test failure --- python/pyspark/sql/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b5ea0f8a4e7d..8c79ce369140 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1828,7 +1828,7 @@ def create_map_fromarray(col1, col2): :param col2: name of column containing a set of values >>> df = spark.createDataFrame([([2, 5], ["Alice", "Bob"])], ['k', 'v']) - >>> df.select(create_map(df.k, df.v).alias("map")).collect() + >>> df.select(create_map_fromarray(df.k, df.v).alias("map")).collect() [Row(map={2: u'Alice', 5: u'Bob'})] """ sc = SparkContext._active_spark_context From 1df6bb513868926df7713d3e277346d21059475c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 7 May 2018 20:19:15 +0100 Subject: [PATCH 03/13] address review comments --- .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 4 ++-- .../spark/sql/catalyst/expressions/ComplexTypeSuite.scala | 6 ++++++ .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 4 ++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 199e0468c9e7..c50bc5c26f06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -270,10 +270,10 @@ case class CreateMapFromArray(left: Expression, right: Expression) MapType( keyType = left.dataType.asInstanceOf[ArrayType].elementType, valueType = right.dataType.asInstanceOf[ArrayType].elementType, - valueContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull) + valueContainsNull 
= right.dataType.asInstanceOf[ArrayType].containsNull) } - override def nullable: Boolean = false + override def nullable: Boolean = left.nullable || right.nullable override def nullSafeEval(keyArray: Any, valueArray: Any): Any = { val keyArrayData = keyArray.asInstanceOf[ArrayData] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 4ef78022dc44..69f3cf36e3a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -205,12 +205,18 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val intwithNullArray = Literal.create(intWithNullSeq, ArrayType(IntegerType, true)) val longwithNullArray = Literal.create(longWithNullSeq, ArrayType(LongType, true)) + val nullArray = Literal.create(null, ArrayType(StringType, true)) + checkEvaluation(CreateMapFromArray(intArray, longArray), createMap(intSeq, longSeq)) checkEvaluation(CreateMapFromArray(intArray, strArray), createMap(intSeq, strSeq)) checkEvaluation( CreateMapFromArray(strArray, intwithNullArray), createMap(strSeq, intWithNullSeq)) checkEvaluation( CreateMapFromArray(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq)) + checkEvaluation( + CreateMapFromArray(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq)) + checkEvaluation(CreateMapFromArray(nullArray, nullArray), null) + intercept[RuntimeException] { checkEvaluation( CreateMapFromArray(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index d06d0b110e42..b6c6098c2ae9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -64,6 +64,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("map with array") { val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("k", "v") + val expectedType = MapType(IntegerType, StringType, valueContainsNull = true) + val row = df1.select(map_fromarray($"k", $"v")).first() + assert(row.schema(0).dataType === expectedType) + assert(row.getMap[Int, String](0) === Map(1 -> "a", 2 -> "b")) checkAnswer(df1.select(map_fromarray($"k", $"v")), Seq(Row(Map(1 -> "a", 2 -> "b")))) val df2 = Seq((Seq(1, 2), Seq(null, "b"))).toDF("k", "v") From 2075770efc17e2481e2a539adfac8bd72e421ee3 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 15 May 2018 18:53:15 +0100 Subject: [PATCH 04/13] address review comments fix test failure --- python/pyspark/sql/functions.py | 7 ++++--- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../expressions/complexTypeCreator.scala | 4 ++-- .../expressions/ComplexTypeSuite.scala | 20 ++++++++++--------- .../org/apache/spark/sql/functions.scala | 4 ++-- .../spark/sql/DataFrameFunctionsSuite.scala | 16 +++++++-------- 6 files changed, 28 insertions(+), 25 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 8c79ce369140..f0f52af95ca0 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1821,18 +1821,19 @@ def create_map(*cols): @ignore_unicode_prefix @since(2.4) -def create_map_fromarray(col1, 
col2): +def create_map_from_arrays(col1, col2): """Creates a new map from two arrays. :param col1: name of column containing a set of keys. All elements should not be null :param col2: name of column containing a set of values >>> df = spark.createDataFrame([([2, 5], ["Alice", "Bob"])], ['k', 'v']) - >>> df.select(create_map_fromarray(df.k, df.v).alias("map")).collect() + >>> df.select(create_map_from_arrays(df.k, df.v).alias("map")).collect() [Row(map={2: u'Alice', 5: u'Bob'})] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.map(_to_java_column(col1), _to_java_column(col2))) + return Column(sc._jvm.functions.create_map_from_arrays( + _to_java_column(col1), _to_java_column(col2))) @since(1.4) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 9fec11644f43..626fa83a1c4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -415,7 +415,7 @@ object FunctionRegistry { expression[ArrayPosition]("array_position"), expression[ArraySort]("array_sort"), expression[CreateMap]("map"), - expression[CreateMapFromArray]("map_fromarray"), + expression[CreateMapFromArrays]("map_from_arrays"), expression[CreateNamedStruct]("named_struct"), expression[ElementAt]("element_at"), expression[MapKeys]("map_keys"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index c50bc5c26f06..17a5055ae5d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -248,7 +248,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression { > SELECT _FUNC_([1.0, 3.0], ['2', '4']); {1.0:"2",3.0:"4"} """, since = "2.4.0") -case class CreateMapFromArray(left: Expression, right: Expression) +case class CreateMapFromArrays(left: Expression, right: Expression) extends BinaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) @@ -296,7 +296,7 @@ case class CreateMapFromArray(left: Expression, right: Expression) }) } - override def prettyName: String = "map" + override def prettyName: String = "create_map_from_arrays" } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 69f3cf36e3a1..92ef0d011e73 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -186,7 +186,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } } - test("CreateMapFromArray") { + test("CreateMapFromArrays") { def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { // catalyst map is order-sensitive, so we create ListMap here to preserve the elements order. 
scala.collection.immutable.ListMap(keys.zip(values): _*) @@ -195,6 +195,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val intSeq = Seq(5, 10, 15, 20, 25) val longSeq = intSeq.map(_.toLong) val strSeq = intSeq.map(_.toString) + val intDupSeq = Seq(5, 10, 15, 15, 5) val intWithNullSeq = Seq[java.lang.Integer](5, 10, null, 20, 25) val longWithNullSeq = intSeq.map(java.lang.Long.valueOf(_)) @@ -205,21 +206,22 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val intwithNullArray = Literal.create(intWithNullSeq, ArrayType(IntegerType, true)) val longwithNullArray = Literal.create(longWithNullSeq, ArrayType(LongType, true)) - val nullArray = Literal.create(null, ArrayType(StringType, true)) + val nullArray = Literal.create(null, ArrayType(StringType, false)) + + checkEvaluation(CreateMapFromArrays(intArray, longArray), createMap(intSeq, longSeq)) + checkEvaluation(CreateMapFromArrays(intArray, strArray), createMap(intSeq, strSeq)) - checkEvaluation(CreateMapFromArray(intArray, longArray), createMap(intSeq, longSeq)) - checkEvaluation(CreateMapFromArray(intArray, strArray), createMap(intSeq, strSeq)) checkEvaluation( - CreateMapFromArray(strArray, intwithNullArray), createMap(strSeq, intWithNullSeq)) + CreateMapFromArrays(strArray, intwithNullArray), createMap(strSeq, intWithNullSeq)) checkEvaluation( - CreateMapFromArray(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq)) + CreateMapFromArrays(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq)) checkEvaluation( - CreateMapFromArray(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq)) - checkEvaluation(CreateMapFromArray(nullArray, nullArray), null) + CreateMapFromArrays(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq)) + checkEvaluation(CreateMapFromArrays(nullArray, nullArray), null) intercept[RuntimeException] { checkEvaluation( - CreateMapFromArray(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null) + CreateMapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 188b9025f907..c77f3e46d70f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1077,8 +1077,8 @@ object functions { * @group normal_funcs * @since 2.4 */ - def map_fromarray(keys: Column, values: Column): Column = withExpr { - CreateMapFromArray(keys.expr, values.expr) + def map_from_arrays(keys: Column, values: Column): Column = withExpr { + CreateMapFromArrays(keys.expr, values.expr) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index b6c6098c2ae9..355a48f330da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -62,35 +62,35 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(row.getMap[Int, String](0) === Map(2 -> "a")) } - test("map with array") { + test("map with arrays") { val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("k", "v") val expectedType = MapType(IntegerType, StringType, valueContainsNull = true) - val row = df1.select(map_fromarray($"k", $"v")).first() + val row = df1.select(map_from_arrays($"k", $"v")).first() 
assert(row.schema(0).dataType === expectedType) assert(row.getMap[Int, String](0) === Map(1 -> "a", 2 -> "b")) - checkAnswer(df1.select(map_fromarray($"k", $"v")), Seq(Row(Map(1 -> "a", 2 -> "b")))) + checkAnswer(df1.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> "a", 2 -> "b")))) val df2 = Seq((Seq(1, 2), Seq(null, "b"))).toDF("k", "v") - checkAnswer(df2.select(map_fromarray($"k", $"v")), Seq(Row(Map(1 -> null, 2 -> "b")))) + checkAnswer(df2.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> null, 2 -> "b")))) val df3 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v") intercept[AnalysisException] { - df3.select(map_fromarray($"k", $"v")) + df3.select(map_from_arrays($"k", $"v")) } val df4 = Seq((1, "a")).toDF("k", "v") intercept[AnalysisException] { - df4.select(map_fromarray($"k", $"v")) + df4.select(map_from_arrays($"k", $"v")) } val df5 = Seq((null, null)).toDF("k", "v") intercept[AnalysisException] { - df5.select(map_fromarray($"k", $"v")) + df5.select(map_from_arrays($"k", $"v")) } val df6 = Seq((Seq(1, 2), Seq("a"))).toDF("k", "v") intercept[RuntimeException] { - df6.select(map_fromarray($"k", $"v")).collect + df6.select(map_from_arrays($"k", $"v")).collect } } From d5ff7becfaebd034daa12bda6206e38bf8272456 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 16 May 2018 03:28:49 +0100 Subject: [PATCH 05/13] fix pyspark test failure --- python/pyspark/sql/functions.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f0f52af95ca0..09ad207a2fa8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1832,8 +1832,7 @@ def create_map_from_arrays(col1, col2): [Row(map={2: u'Alice', 5: u'Bob'})] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.create_map_from_arrays( - _to_java_column(col1), _to_java_column(col2))) + return Column(sc._jvm.functions.map_from_arrays(_to_java_column(col1), _to_java_column(col2))) @since(1.4) From 4eee89da3a80f679b8ff0a631c4374ae6fa0de86 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 17 May 2018 15:32:24 +0100 Subject: [PATCH 06/13] fix pyspark test failure --- .../expressions/complexTypeCreator.scala | 31 ++++++++++++++----- .../expressions/ComplexTypeSuite.scala | 7 ++++- .../spark/sql/DataFrameFunctionsSuite.scala | 12 +++---- 3 files changed, 35 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 17a5055ae5d1..0e0c0636c73a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -255,12 +255,8 @@ case class CreateMapFromArrays(left: Expression, right: Expression) override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { - case (ArrayType(_, cn), ArrayType(_, _)) => - if (!cn) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure("All of the given keys should be non-null") - } + case (ArrayType(_, _), ArrayType(_, _)) => + TypeCheckResult.TypeCheckSuccess case _ => TypeCheckResult.TypeCheckFailure("The given two arguments should be an array") } @@ -281,18 +277,39 @@ case class CreateMapFromArrays(left: Expression, right: Expression) if (keyArrayData.numElements != 
valueArrayData.numElements) { throw new RuntimeException("The given two arrays should have the same length") } + val leftArrayType = left.dataType.asInstanceOf[ArrayType] + if (leftArrayType.containsNull) { + if (keyArrayData.toArray(leftArrayType.elementType).contains(null)) { + throw new RuntimeException("Cannot use null as map key!") + } + } new ArrayBasedMapData(keyArrayData.copy(), valueArrayData.copy()) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (keyArrayData, valueArrayData) => { val arrayBasedMapData = classOf[ArrayBasedMapData].getName + val leftArrayType = left.dataType.asInstanceOf[ArrayType] + val keyArrayElemNullCheck = if (!leftArrayType.containsNull) "" else { + val leftArrayTypeTerm = ctx.addReferenceObj("leftArrayType", leftArrayType.elementType) + val array = ctx.freshName("array") + val i = ctx.freshName("i") + s""" + |Object[] $array = $keyArrayData.toObjectArray($leftArrayTypeTerm); + |for (int $i = 0; $i < $array.length; $i++) { + | if ($array[$i] == null) { + | throw new RuntimeException("Cannot use null as map key!"); + | } + |} + """.stripMargin + } s""" |if ($keyArrayData.numElements() != $valueArrayData.numElements()) { | throw new RuntimeException("The given two arrays should have the same length"); |} + |$keyArrayElemNullCheck |${ev.value} = new $arrayBasedMapData($keyArrayData.copy(), $valueArrayData.copy()); - """ + """.stripMargin }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 92ef0d011e73..9bc883646725 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -195,7 +195,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val intSeq = Seq(5, 10, 15, 20, 25) val longSeq = intSeq.map(_.toLong) val strSeq = intSeq.map(_.toString) - val intDupSeq = Seq(5, 10, 15, 15, 5) + val integerSeq = Seq[java.lang.Integer](5, 10, 15, 20, 25) val intWithNullSeq = Seq[java.lang.Integer](5, 10, null, 20, 25) val longWithNullSeq = intSeq.map(java.lang.Long.valueOf(_)) @@ -203,6 +203,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val longArray = Literal.create(longSeq, ArrayType(LongType, false)) val strArray = Literal.create(strSeq, ArrayType(StringType, false)) + val integerArray = Literal.create(integerSeq, ArrayType(IntegerType, true)) val intwithNullArray = Literal.create(intWithNullSeq, ArrayType(IntegerType, true)) val longwithNullArray = Literal.create(longWithNullSeq, ArrayType(LongType, true)) @@ -210,6 +211,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(CreateMapFromArrays(intArray, longArray), createMap(intSeq, longSeq)) checkEvaluation(CreateMapFromArrays(intArray, strArray), createMap(intSeq, strSeq)) + checkEvaluation(CreateMapFromArrays(integerArray, strArray), createMap(integerSeq, strSeq)) checkEvaluation( CreateMapFromArrays(strArray, intwithNullArray), createMap(strSeq, intWithNullSeq)) @@ -219,6 +221,9 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { CreateMapFromArrays(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq)) checkEvaluation(CreateMapFromArrays(nullArray, nullArray), null) + intercept[RuntimeException] { + 
checkEvaluation(CreateMapFromArrays(intwithNullArray, strArray), null)
+    }
     intercept[RuntimeException] {
       checkEvaluation(
         CreateMapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 355a48f330da..ad26e1f72dab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -73,19 +73,17 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     val df2 = Seq((Seq(1, 2), Seq(null, "b"))).toDF("k", "v")
     checkAnswer(df2.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> null, 2 -> "b"))))
 
-    val df3 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v")
-    intercept[AnalysisException] {
-      df3.select(map_from_arrays($"k", $"v"))
-    }
+    val df3 = Seq((null, null)).toDF("k", "v")
+    checkAnswer(df3.select(map_from_arrays($"k", $"v")), Seq(Row(null)))
 
     val df4 = Seq((1, "a")).toDF("k", "v")
     intercept[AnalysisException] {
       df4.select(map_from_arrays($"k", $"v"))
     }
 
-    val df5 = Seq((null, null)).toDF("k", "v")
-    intercept[AnalysisException] {
-      df5.select(map_from_arrays($"k", $"v"))
+    val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v")
+    intercept[RuntimeException] {
+      df5.select(map_from_arrays($"k", $"v")).collect
     }
 
     val df6 = Seq((Seq(1, 2), Seq("a"))).toDF("k", "v")

From 7b66ab4d14f9840f4ca708f532c4145461618ef3 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Thu, 17 May 2018 19:55:17 +0100
Subject: [PATCH 07/13] fix pyspark test failure

---
 python/pyspark/sql/functions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 09ad207a2fa8..d03de4207cd6 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1829,7 +1829,7 @@ def create_map_from_arrays(col1, col2):
     >>> df = spark.createDataFrame([([2, 5], ["Alice", "Bob"])], ['k', 'v'])
     >>> df.select(create_map_from_arrays(df.k, df.v).alias("map")).collect()
-    [Row(map={2: u'Alice', 5: u'Bob'})]
+    [Row(map={5: u'Bob', 2: u'Alice'})]

From 2fcbb805801ebe94616e5c3ae2e4f76cd73e5f02 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Fri, 18 May 2018 20:41:08 +0100
Subject: [PATCH 08/13] fix pyspark test failure

---
 python/pyspark/sql/functions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d03de4207cd6..09ad207a2fa8 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1829,7 +1829,7 @@ def create_map_from_arrays(col1, col2):
     >>> df = spark.createDataFrame([([2, 5], ["Alice", "Bob"])], ['k', 'v'])
     >>> df.select(create_map_from_arrays(df.k, df.v).alias("map")).collect()
-    [Row(map={5: u'Bob', 2: u'Alice'})]
+    [Row(map={2: u'Alice', 5: u'Bob'})]

From 228fcc66e2b85b957833da739a20229867d51cbc Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Fri, 8 Jun 2018 01:54:25 +0100
Subject: [PATCH 09/13] address review comments

---
 python/pyspark/sql/functions.py | 4 +--
 .../catalyst/analysis/FunctionRegistry.scala | 2 +-
.../expressions/complexTypeCreator.scala | 28 ++++++------------- .../expressions/ComplexTypeSuite.scala | 20 ++++++------- .../org/apache/spark/sql/functions.scala | 2 +- 5 files changed, 23 insertions(+), 33 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 09ad207a2fa8..d83db64a127e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1821,14 +1821,14 @@ def create_map(*cols): @ignore_unicode_prefix @since(2.4) -def create_map_from_arrays(col1, col2): +def map_from_arrays(col1, col2): """Creates a new map from two arrays. :param col1: name of column containing a set of keys. All elements should not be null :param col2: name of column containing a set of values >>> df = spark.createDataFrame([([2, 5], ["Alice", "Bob"])], ['k', 'v']) - >>> df.select(create_map_from_arrays(df.k, df.v).alias("map")).collect() + >>> df.select(map_from_arrays(df.k, df.v).alias("map")).collect() [Row(map={2: u'Alice', 5: u'Bob'})] """ sc = SparkContext._active_spark_context diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 626fa83a1c4a..f5a5c66c5b46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -415,9 +415,9 @@ object FunctionRegistry { expression[ArrayPosition]("array_position"), expression[ArraySort]("array_sort"), expression[CreateMap]("map"), - expression[CreateMapFromArrays]("map_from_arrays"), expression[CreateNamedStruct]("named_struct"), expression[ElementAt]("element_at"), + expression[MapFromArrays]("map_from_arrays"), expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[MapEntries]("map_entries"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 0e0c0636c73a..8c60ca323a3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -248,20 +248,11 @@ case class CreateMap(children: Seq[Expression]) extends Expression { > SELECT _FUNC_([1.0, 3.0], ['2', '4']); {1.0:"2",3.0:"4"} """, since = "2.4.0") -case class CreateMapFromArrays(left: Expression, right: Expression) +case class MapFromArrays(left: Expression, right: Expression) extends BinaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) - override def checkInputDataTypes(): TypeCheckResult = { - (left.dataType, right.dataType) match { - case (ArrayType(_, _), ArrayType(_, _)) => - TypeCheckResult.TypeCheckSuccess - case _ => - TypeCheckResult.TypeCheckFailure("The given two arguments should be an array") - } - } - override def dataType: DataType = { MapType( keyType = left.dataType.asInstanceOf[ArrayType].elementType, @@ -269,8 +260,6 @@ case class CreateMapFromArrays(left: Expression, right: Expression) valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull) } - override def nullable: Boolean = left.nullable || right.nullable - override def nullSafeEval(keyArray: Any, valueArray: Any): Any = { val keyArrayData = keyArray.asInstanceOf[ArrayData] val 
valueArrayData = valueArray.asInstanceOf[ArrayData] @@ -279,8 +268,12 @@ case class CreateMapFromArrays(left: Expression, right: Expression) } val leftArrayType = left.dataType.asInstanceOf[ArrayType] if (leftArrayType.containsNull) { - if (keyArrayData.toArray(leftArrayType.elementType).contains(null)) { - throw new RuntimeException("Cannot use null as map key!") + var i = 0 + while (i < keyArrayData.numElements) { + if (keyArrayData.isNullAt(i)) { + throw new RuntimeException("Cannot use null as map key!") + } + i += 1 } } new ArrayBasedMapData(keyArrayData.copy(), valueArrayData.copy()) @@ -291,13 +284,10 @@ case class CreateMapFromArrays(left: Expression, right: Expression) val arrayBasedMapData = classOf[ArrayBasedMapData].getName val leftArrayType = left.dataType.asInstanceOf[ArrayType] val keyArrayElemNullCheck = if (!leftArrayType.containsNull) "" else { - val leftArrayTypeTerm = ctx.addReferenceObj("leftArrayType", leftArrayType.elementType) - val array = ctx.freshName("array") val i = ctx.freshName("i") s""" - |Object[] $array = $keyArrayData.toObjectArray($leftArrayTypeTerm); - |for (int $i = 0; $i < $array.length; $i++) { - | if ($array[$i] == null) { + |for (int $i = 0; $i < $keyArrayData.numElements(); $i++) { + | if ($keyArrayData.isNullAt($i)) { | throw new RuntimeException("Cannot use null as map key!"); | } |} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 9bc883646725..8e180c553049 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -186,7 +186,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } } - test("CreateMapFromArrays") { + test("MapFromArrays") { def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { // catalyst map is order-sensitive, so we create ListMap here to preserve the elements order. 
      scala.collection.immutable.ListMap(keys.zip(values): _*)
     }

@@ -209,24 +209,24 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
 
     val nullArray = Literal.create(null, ArrayType(StringType, false))
 
-    checkEvaluation(CreateMapFromArrays(intArray, longArray), createMap(intSeq, longSeq))
-    checkEvaluation(CreateMapFromArrays(intArray, strArray), createMap(intSeq, strSeq))
-    checkEvaluation(CreateMapFromArrays(integerArray, strArray), createMap(integerSeq, strSeq))
+    checkEvaluation(MapFromArrays(intArray, longArray), createMap(intSeq, longSeq))
+    checkEvaluation(MapFromArrays(intArray, strArray), createMap(intSeq, strSeq))
+    checkEvaluation(MapFromArrays(integerArray, strArray), createMap(integerSeq, strSeq))
 
     checkEvaluation(
-      CreateMapFromArrays(strArray, intwithNullArray), createMap(strSeq, intWithNullSeq))
+      MapFromArrays(strArray, intwithNullArray), createMap(strSeq, intWithNullSeq))
     checkEvaluation(
-      CreateMapFromArrays(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq))
+      MapFromArrays(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq))
     checkEvaluation(
-      CreateMapFromArrays(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq))
-    checkEvaluation(CreateMapFromArrays(nullArray, nullArray), null)
+      MapFromArrays(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq))
+    checkEvaluation(MapFromArrays(nullArray, nullArray), null)
 
     intercept[RuntimeException] {
-      checkEvaluation(CreateMapFromArrays(intwithNullArray, strArray), null)
+      checkEvaluation(MapFromArrays(intwithNullArray, strArray), null)
     }
     intercept[RuntimeException] {
       checkEvaluation(
-        CreateMapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null)
+        MapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null)
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index c77f3e46d70f..f6b631e07bbd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1078,7 +1078,7 @@ object functions {
    * @since 2.4
    */
   def map_from_arrays(keys: Column, values: Column): Column = withExpr {
-    CreateMapFromArrays(keys.expr, values.expr)
+    MapFromArrays(keys.expr, values.expr)
   }
 
   /**

From 6d53a96a0f7ec122fc44a05213256e92856eb3cf Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Fri, 8 Jun 2018 04:09:31 +0100
Subject: [PATCH 10/13] address review comments

---
 .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 8c60ca323a3a..0a5f8a907b50 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -249,7 +249,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
     {1.0:"2",3.0:"4"}
   """, since = "2.4.0")
 case class MapFromArrays(left: Expression, right: Expression)
-    extends BinaryExpression with ExpectsInputTypes {
+  extends BinaryExpression with ExpectsInputTypes {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType)
 
@@ -303,7 +303,7 @@ case class MapFromArrays(left: Expression, right: Expression)
     })
   }
 
-  override def prettyName: String = "create_map_from_arrays"
+  override def prettyName: String = "map_from_arrays"
 }

From a4b3ec2eed67aa9152a0e5cf8bad4d84ff9bc9b0 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Sun, 10 Jun 2018 09:18:50 +0100
Subject: [PATCH 11/13] use show() to check pyspark result

---
 python/pyspark/sql/functions.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d83db64a127e..96f183616666 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1819,7 +1819,6 @@ def create_map(*cols):
     return Column(jc)
 
 
-@ignore_unicode_prefix
 @since(2.4)
 def map_from_arrays(col1, col2):
     """Creates a new map from two arrays.
@@ -1828,8 +1827,12 @@ def map_from_arrays(col1, col2):
     :param col2: name of column containing a set of values
 
     >>> df = spark.createDataFrame([([2, 5], ["Alice", "Bob"])], ['k', 'v'])
-    >>> df.select(map_from_arrays(df.k, df.v).alias("map")).collect()
-    [Row(map={2: u'Alice', 5: u'Bob'})]
+    >>> df.select(map_from_arrays(df.k, df.v).alias("map")).show()
+    +----------------------+
+    |                   map|
+    +----------------------+
+    |[2 -> Alice, 5 -> Bob]|
+    +----------------------+
     """
     sc = SparkContext._active_spark_context
     return Column(sc._jvm.functions.map_from_arrays(_to_java_column(col1), _to_java_column(col2)))

From a0b4ac52ff4b6d270431b5e5f1e5d5be2c56108b Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Sun, 10 Jun 2018 16:05:49 +0100
Subject: [PATCH 12/13] fix pyspark test error

---
 python/pyspark/sql/functions.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 96f183616666..577740016a26 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1826,13 +1826,13 @@ def map_from_arrays(col1, col2):
     :param col1: name of column containing a set of keys. All elements should not be null
     :param col2: name of column containing a set of values
 
-    >>> df = spark.createDataFrame([([2, 5], ["Alice", "Bob"])], ['k', 'v'])
+    >>> df = spark.createDataFrame([([2, 5], ['a', 'b'])], ['k', 'v'])
     >>> df.select(map_from_arrays(df.k, df.v).alias("map")).show()
-    +----------------------+
-    |                   map|
-    +----------------------+
-    |[2 -> Alice, 5 -> Bob]|
-    +----------------------+
+    +----------------+
+    |             map|
+    +----------------+
+    |[2 -> a, 5 -> b]|
+    +----------------+
     """
     sc = SparkContext._active_spark_context
     return Column(sc._jvm.functions.map_from_arrays(_to_java_column(col1), _to_java_column(col2)))

From 38d086877385324ae872652e9dbeb484a0915557 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Mon, 11 Jun 2018 22:49:57 +0100
Subject: [PATCH 13/13] address review comments

---
 .../sql/catalyst/expressions/ComplexTypeSuite.scala | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 8e180c553049..726193b41173 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -204,8 +204,8 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     val strArray = Literal.create(strSeq, ArrayType(StringType, false))
 
     val integerArray = Literal.create(integerSeq, ArrayType(IntegerType, true))
-    val intwithNullArray = Literal.create(intWithNullSeq, ArrayType(IntegerType, true))
-    val longwithNullArray = Literal.create(longWithNullSeq, ArrayType(LongType, true))
+    val intWithNullArray = Literal.create(intWithNullSeq, ArrayType(IntegerType, true))
+    val longWithNullArray = Literal.create(longWithNullSeq, ArrayType(LongType, true))
 
     val nullArray = Literal.create(null, ArrayType(StringType, false))
 
@@ -214,15 +214,15 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(
-      MapFromArrays(strArray, intwithNullArray), createMap(strSeq, intWithNullSeq))
+      MapFromArrays(strArray, intWithNullArray), createMap(strSeq, intWithNullSeq))
     checkEvaluation(
-      MapFromArrays(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq))
+      MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq))
     checkEvaluation(
-      MapFromArrays(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq))
+      MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq))
     checkEvaluation(MapFromArrays(nullArray, nullArray), null)
 
     intercept[RuntimeException] {
-      checkEvaluation(MapFromArrays(intwithNullArray, strArray), null)
+      checkEvaluation(MapFromArrays(intWithNullArray, strArray), null)
     }
     intercept[RuntimeException] {
       checkEvaluation(
         MapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null)
     }
   }
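For reviewers who want to exercise the feature end to end once the series is applied, here is a minimal sketch of the expected behavior. The SparkSession setup and object name are illustrative only, not part of the patches; the function names (`map_from_arrays` in both the Scala API and SQL) are the ones registered above.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.map_from_arrays

object MapFromArraysDemo {
  def main(args: Array[String]): Unit = {
    // Local session for a quick smoke test; builder settings are illustrative.
    val spark = SparkSession.builder()
      .appName("map_from_arrays-demo")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    val df = Seq((Seq(2, 5), Seq("a", "b"))).toDF("k", "v")

    // DataFrame API: zips the key array with the value array into one map column.
    // Expected: MapType(IntegerType, StringType, valueContainsNull = true).
    df.select(map_from_arrays($"k", $"v").alias("map")).show()
    // +----------------+
    // |             map|
    // +----------------+
    // |[2 -> a, 5 -> b]|
    // +----------------+

    // SQL form registered in FunctionRegistry under the same name.
    spark.sql("SELECT map_from_arrays(array(1.0, 3.0), array('2', '4'))").show()

    spark.stop()
  }
}
```

Per the expression's contract in this series, a null key array element or mismatched array lengths raise a RuntimeException at evaluation time, while non-array arguments fail analysis.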