diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 1443584ccbcb..41793593e995 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -460,6 +460,7 @@ def __hash__(self): "pyspark.ml.evaluation", "pyspark.ml.feature", "pyspark.ml.fpm", + "pyspark.ml.functions", "pyspark.ml.image", "pyspark.ml.linalg.__init__", "pyspark.ml.recommendation", diff --git a/mllib/src/main/scala/org/apache/spark/ml/functions.scala b/mllib/src/main/scala/org/apache/spark/ml/functions.scala new file mode 100644 index 000000000000..1faf562c4d89 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/functions.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.annotation.Since +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.mllib.linalg.{Vector => OldVector} +import org.apache.spark.sql.Column +import org.apache.spark.sql.functions.udf + +// scalastyle:off +@Since("3.0.0") +object functions { +// scalastyle:on + + private val vectorToArrayUdf = udf { vec: Any => + vec match { + case v: Vector => v.toArray + case v: OldVector => v.toArray + case v => throw new IllegalArgumentException( + "function vector_to_array requires a non-null input argument and input type must be " + + "`org.apache.spark.ml.linalg.Vector` or `org.apache.spark.mllib.linalg.Vector`, " + + s"but got ${ if (v == null) "null" else v.getClass.getName }.") + } + }.asNonNullable() + + /** + * Converts a column of MLlib sparse/dense vectors into a column of dense arrays. + * + * @since 3.0.0 + */ + def vector_to_array(v: Column): Column = vectorToArrayUdf(v) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala new file mode 100644 index 000000000000..2f5062c689fc --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.SparkException +import org.apache.spark.ml.functions.vector_to_array +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.util.MLTest +import org.apache.spark.mllib.linalg.{Vectors => OldVectors} +import org.apache.spark.sql.functions.col + +class FunctionsSuite extends MLTest { + + import testImplicits._ + + test("test vector_to_array") { + val df = Seq( + (Vectors.dense(1.0, 2.0, 3.0), OldVectors.dense(10.0, 20.0, 30.0)), + (Vectors.sparse(3, Seq((0, 2.0), (2, 3.0))), OldVectors.sparse(3, Seq((0, 20.0), (2, 30.0)))) + ).toDF("vec", "oldVec") + + val result = df.select(vector_to_array('vec), vector_to_array('oldVec)) + .as[(Seq[Double], Seq[Double])] + .collect().toSeq + + val expected = Seq( + (Seq(1.0, 2.0, 3.0), Seq(10.0, 20.0, 30.0)), + (Seq(2.0, 0.0, 3.0), Seq(20.0, 0.0, 30.0)) + ) + assert(result === expected) + + val df2 = Seq( + (Vectors.dense(1.0, 2.0, 3.0), + OldVectors.dense(10.0, 20.0, 30.0), 1), + (null, null, 0) + ).toDF("vec", "oldVec", "label") + + + for ((colName, valType) <- Seq( + ("vec", "null"), ("oldVec", "null"), ("label", "java.lang.Integer"))) { + val thrown1 = intercept[SparkException] { + df2.select(vector_to_array(col(colName))).count + } + assert(thrown1.getCause.getMessage.contains( + "function vector_to_array requires a non-null input argument and input type must be " + + "`org.apache.spark.ml.linalg.Vector` or `org.apache.spark.mllib.linalg.Vector`, " + + s"but got ${valType}")) + } + } +} diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst index 6a5d81706f07..e31dfddd5988 100644 --- a/python/docs/pyspark.ml.rst +++ b/python/docs/pyspark.ml.rst @@ -41,6 +41,14 @@ pyspark.ml.clustering module :undoc-members: :inherited-members: +pyspark.ml.functions module +---------------------------- + +.. automodule:: pyspark.ml.functions + :members: + :undoc-members: + :inherited-members: + pyspark.ml.linalg module ---------------------------- diff --git a/python/pyspark/ml/functions.py b/python/pyspark/ml/functions.py new file mode 100644 index 000000000000..2b4d8ddcd00a --- /dev/null +++ b/python/pyspark/ml/functions.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import since, SparkContext +from pyspark.sql.column import Column, _to_java_column + + +@since(3.0) +def vector_to_array(col): + """ + Converts a column of MLlib sparse/dense vectors into a column of dense arrays. + + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.ml.functions import vector_to_array + >>> from pyspark.mllib.linalg import Vectors as OldVectors + >>> df = spark.createDataFrame([ + ... (Vectors.dense(1.0, 2.0, 3.0), OldVectors.dense(10.0, 20.0, 30.0)), + ... (Vectors.sparse(3, [(0, 2.0), (2, 3.0)]), + ... OldVectors.sparse(3, [(0, 20.0), (2, 30.0)]))], + ... ["vec", "oldVec"]) + >>> df.select(vector_to_array("vec").alias("vec"), + ... vector_to_array("oldVec").alias("oldVec")).collect() + [Row(vec=[1.0, 2.0, 3.0], oldVec=[10.0, 20.0, 30.0]), + Row(vec=[2.0, 0.0, 3.0], oldVec=[20.0, 0.0, 30.0])] + """ + sc = SparkContext._active_spark_context + return Column( + sc._jvm.org.apache.spark.ml.functions.vector_to_array(_to_java_column(col))) + + +def _test(): + import doctest + from pyspark.sql import SparkSession + import pyspark.ml.functions + import sys + globs = pyspark.ml.functions.__dict__.copy() + spark = SparkSession.builder \ + .master("local[2]") \ + .appName("ml.functions tests") \ + .getOrCreate() + sc = spark.sparkContext + globs['sc'] = sc + globs['spark'] = spark + + (failure_count, test_count) = doctest.testmod( + pyspark.ml.functions, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + spark.stop() + if failure_count: + sys.exit(-1) + + +if __name__ == "__main__": + _test()