From c794317354998a45184ae8c2869a723d11be76ba Mon Sep 17 00:00:00 2001
From: Nick Pentreath
Date: Thu, 28 Apr 2016 13:49:16 +0200
Subject: [PATCH 01/11] Initial work

---
 .../apache/spark/ml/recommendation/ALS.scala  | 42 ++++++++++++++-----
 .../spark/ml/recommendation/ALSSuite.scala    | 36 ++++++++++++++++
 .../apache/spark/ml/util/MLTestingUtils.scala | 42 +++++++++++++++++++
 3 files changed, 110 insertions(+), 10 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 509c944fed74..aaa768ffa487 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -42,7 +42,7 @@ import org.apache.spark.mllib.optimization.NNLS
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
@@ -71,6 +71,20 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo
 
   /** @group getParam */
   def getItemCol: String = $(itemCol)
+
+  /**
+   * Attempts to safely cast a Long user/item id to an Int. Throws an exception if the value is
+   * out of integer range.
+   * @return
+   */
+  protected val checkedCast = udf { (n: Long) =>
+    if (n > Int.MaxValue.toLong || n < Int.MinValue.toLong) {
+      throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " +
+        s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
+    } else {
+      n.toInt
+    }
+  }
 }
 
 /**
@@ -193,10 +207,11 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
    * @return output schema
    */
  protected def validateAndTransformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
-    SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
-    val ratingType = schema($(ratingCol)).dataType
-    require(ratingType == FloatType || ratingType == DoubleType)
+    // user and item will be cast to Int
+    SchemaUtils.checkNumericType(schema, $(userCol))
+    SchemaUtils.checkNumericType(schema, $(itemCol))
+    // rating will be cast to Float
+    SchemaUtils.checkNumericType(schema, $(ratingCol))
     SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
   }
 }
@@ -232,6 +247,7 @@ class ALSModel private[ml] (
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema)
     // Register a UDF for DataFrame, and then
     // create a new column named map(predictionCol) by running the predict UDF.
     val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
@@ -242,16 +258,19 @@ class ALSModel private[ml] (
       }
     }
     dataset
-      .join(userFactors, dataset($(userCol)) === userFactors("id"), "left")
-      .join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left")
+      .join(userFactors,
+        checkedCast(dataset($(userCol)).cast(LongType)) === userFactors("id"), "left")
+      .join(itemFactors,
+        checkedCast(dataset($(itemCol)).cast(LongType)) === itemFactors("id"), "left")
       .select(dataset("*"),
         predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
   }
 
   @Since("1.3.0")
   override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
-    SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
+    // user and item will be cast to Int
+    SchemaUtils.checkNumericType(schema, $(userCol))
+    SchemaUtils.checkNumericType(schema, $(itemCol))
     SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
   }
 
@@ -430,10 +449,13 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): ALSModel = {
+    transformSchema(dataset.schema)
     import dataset.sparkSession.implicits._
+
     val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
     val ratings = dataset
-      .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r)
+      .select(checkedCast(col($(userCol)).cast(LongType)),
+        checkedCast(col($(itemCol)).cast(LongType)), r)
       .rdd
       .map { row =>
         Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 9da0c32deede..2694ce2a9ee6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -42,6 +42,8 @@ import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
+import org.apache.spark.sql.types.{FloatType, IntegerType, StringType}
+
 class ALSSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging {
 
@@ -486,6 +488,40 @@ class ALSSuite
     assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
     assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
   }
+
+  test("input type validation") {
+    val sqlContext = this.sqlContext
+    import sqlContext.implicits._
+
+    val als = new ALS().setMaxIter(1).setRank(1)
+    Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach {
+      case (colName, sqlType) =>
+        MLTestingUtils.checkNumericTypesALS[ALSModel, ALS](als, sqlContext, colName, sqlType) {
+          (ex, act) =>
+            ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1)
+        }
+    }
+
+    withClue("Should fail when ids exceed integer range. ") {
+      val df = Seq(
+        (0, 0d, 0d, 1, 1d, 1d, 3.0),
+        (0, 2e10, -2e10, 0, 2e10, -2e10, 2.0),
+        (1, 1d, 1d, 0, 0d, 0d, 5.0)
+      ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating")
+      intercept[IllegalArgumentException] {
+        als.fit(df.select("user_big", "item", "rating"))
+      }
+      intercept[IllegalArgumentException] {
+        als.fit(df.select("user_small", "item", "rating"))
+      }
+      intercept[IllegalArgumentException] {
+        als.fit(df.select("user", "item_big", "rating"))
+      }
+      intercept[IllegalArgumentException] {
+        als.fit(df.select("user", "item_small", "rating"))
+      }
+    }
+  }
 }
 
 class ALSCleanerSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index ad7d2c9b8d40..069ee423723c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -58,6 +58,27 @@ object MLTestingUtils extends SparkFunSuite {
       "Column label must be of type NumericType but was actually of type StringType"))
   }
 
+  def checkNumericTypesALS[M <: Model[M], T <: Estimator[M]](
+    estimator: T,
+    sqlContext: SQLContext,
+    column: String,
+    baseType: NumericType)(check: (M, M) => Unit): Unit = {
+    val dfs = genRatingsDFWithNumericCols(sqlContext, column)
+    val expected = estimator.fit(dfs(baseType))
+    val actuals = dfs.keys.filter(_ != baseType).map(t => estimator.fit(dfs(t)))
+    actuals.foreach(actual => check(expected, actual))
+
+    val baseDF = dfs(baseType)
+    val others = baseDF.columns.toSeq.diff(Seq(column)).map(col(_))
+    val cols = Seq(col(column).cast(StringType)) ++ others
+    val strDF = baseDF.select(cols: _*)
+    val thrown = intercept[IllegalArgumentException] {
+      estimator.fit(strDF)
+    }
+    assert(thrown.getMessage.contains(
+      s"$column must be of type NumericType but was actually of type StringType"))
+  }
+
   def checkNumericTypes[T <: Evaluator](evaluator: T, spark: SparkSession): Unit = {
     val dfs = genEvaluatorDFWithNumericLabelCol(spark, "label", "prediction")
     val expected = evaluator.evaluate(dfs(DoubleType))
@@ -116,6 +137,27 @@ object MLTestingUtils extends SparkFunSuite {
     }.toMap
   }
 
+  def genRatingsDFWithNumericCols(
+      sqlContext: SQLContext,
+      column: String): Map[NumericType, DataFrame] = {
+    val df = sqlContext.createDataFrame(Seq(
+      (0, 10, 1.0),
+      (1, 20, 2.0),
+      (2, 30, 3.0),
+      (3, 40, 4.0),
+      (4, 50, 5.0)
+    )).toDF("user", "item", "rating")
+
+    val others = df.columns.toSeq.diff(Seq(column)).map(col(_))
+    val types: Seq[NumericType] =
+      Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
+    types.map { t =>
+      val toCast = col(column).cast(t)
+      val cols = Seq(toCast) ++ others
+      t -> df.select(cols: _*)
+    }.toMap
+  }
+
   def genEvaluatorDFWithNumericLabelCol(
       spark: SparkSession,
       labelColName: String = "label",

From 878baf5d4bc5eeb968ce0a81fafc4ebdaa1c0166 Mon Sep 17 00:00:00 2001
From: Nick Pentreath
Date: Thu, 28 Apr 2016 15:47:45 +0200
Subject: [PATCH 02/11] Improve tests and add transform test cases + small cleanup

---
 .../apache/spark/ml/recommendation/ALS.scala  |  1 -
 .../spark/ml/recommendation/ALSSuite.scala    | 62 ++++++++++++-------
 .../apache/spark/ml/util/MLTestingUtils.scala | 14 +++--
 3 files changed, 47 insertions(+), 30 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index aaa768ffa487..c3d86a20d20b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -75,7 +75,6 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo
   /**
    * Attempts to safely cast a Long user/item id to an Int. Throws an exception if the value is
    * out of integer range.
-   * @return
    */
   protected val checkedCast = udf { (n: Long) =>
     if (n > Int.MaxValue.toLong || n < Int.MinValue.toLong) {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 2694ce2a9ee6..d1f469023330 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -207,7 +207,6 @@ class ALSSuite
 
   /**
    * Generates an explicit feedback dataset for testing ALS.
-   *
    * @param numUsers number of users
    * @param numItems number of items
    * @param rank rank
@@ -248,7 +247,6 @@ class ALSSuite
 
   /**
    * Generates an implicit feedback dataset for testing ALS.
-   *
    * @param numUsers number of users
    * @param numItems number of items
   * @param rank rank
@@ -267,7 +265,6 @@ class ALSSuite
 
   /**
    * Generates random user/item factors, with i.i.d. values drawn from U(a, b).
-   *
    * @param size number of users/items
    * @param rank number of features
    * @param random random number generator
@@ -286,7 +283,6 @@ class ALSSuite
 
   /**
    * Test ALS using the given training/test splits and parameters.
-   *
   * @param training training dataset
    * @param test test dataset
    * @param rank rank of the matrix factorization
@@ -489,33 +485,53 @@ class ALSSuite
     val sqlContext = this.sqlContext
     import sqlContext.implicits._
 
+    val test = Seq()
     val als = new ALS().setMaxIter(1).setRank(1)
     Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach {
       case (colName, sqlType) =>
         MLTestingUtils.checkNumericTypesALS[ALSModel, ALS](als, sqlContext, colName, sqlType) {
           (ex, act) =>
             ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1)
+        } { (ex, act, _) =>
+          ex.transform(_: DataFrame).select("prediction").first.getDouble(0) ~==
+            act.transform(_: DataFrame).select("prediction").first.getDouble(0) absTol 1e-6
         }
     }
-
-    withClue("Should fail when ids exceed integer range. ") {
-      val df = Seq(
-        (0, 0d, 0d, 1, 1d, 1d, 3.0),
-        (0, 2e10, -2e10, 0, 2e10, -2e10, 2.0),
-        (1, 1d, 1d, 0, 0d, 0d, 5.0)
-      ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating")
-      intercept[IllegalArgumentException] {
-        als.fit(df.select("user_big", "item", "rating"))
-      }
-      intercept[IllegalArgumentException] {
-        als.fit(df.select("user_small", "item", "rating"))
-      }
-      intercept[IllegalArgumentException] {
-        als.fit(df.select("user", "item_big", "rating"))
-      }
-      intercept[IllegalArgumentException] {
-        als.fit(df.select("user", "item_small", "rating"))
-      }
+    val big = Int.MaxValue.toLong + 1
+    val small = Int.MinValue.toDouble - 1
+    val df = Seq(
+      (0, 0L, 0d, 1, 1L, 1d, 3.0),
+      (0, big, small, 0, big, small, 2.0),
+      (1, 1L, 1d, 0, 0L, 0d, 5.0)
+    ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating")
+    withClue("fit should fail when ids exceed integer range. ") {
+      assert(intercept[IllegalArgumentException] {
+        als.fit(df.select(df("user_big").as("user"), df("item"), df("rating")))
+      }.getMessage.contains("was out of Integer range"))
+      assert(intercept[IllegalArgumentException] {
+        als.fit(df.select(df("user_small").as("user"), df("item"), df("rating")))
+      }.getMessage.contains("was out of Integer range"))
+      assert(intercept[IllegalArgumentException] {
+        als.fit(df.select(df("item_big").as("item"), df("user"), df("rating")))
+      }.getMessage.contains("was out of Integer range"))
+      assert(intercept[IllegalArgumentException] {
+        als.fit(df.select(df("item_small").as("item"), df("user"), df("rating")))
+      }.getMessage.contains("was out of Integer range"))
+    }
+    withClue("transform should fail when ids exceed integer range. ") {
+      val model = als.fit(df)
+      assert(intercept[SparkException] {
+        model.transform(df.select(df("user_big").as("user"), df("item"))).first
+      }.getMessage.contains("was out of Integer range"))
+      assert(intercept[SparkException] {
+        model.transform(df.select(df("user_small").as("user"), df("item"))).first
+      }.getMessage.contains("was out of Integer range"))
+      assert(intercept[SparkException] {
+        model.transform(df.select(df("item_big").as("item"), df("user"))).first
+      }.getMessage.contains("was out of Integer range"))
+      assert(intercept[SparkException] {
+        model.transform(df.select(df("item_small").as("item"), df("user"))).first
+      }.getMessage.contains("was out of Integer range"))
     }
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index 069ee423723c..4d0f8c06567c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -62,11 +62,14 @@ object MLTestingUtils extends SparkFunSuite {
     estimator: T,
     sqlContext: SQLContext,
     column: String,
-    baseType: NumericType)(check: (M, M) => Unit): Unit = {
+    baseType: NumericType)
+    (check: (M, M) => Unit)
+    (check2: (M, M, DataFrame) => Unit): Unit = {
     val dfs = genRatingsDFWithNumericCols(sqlContext, column)
     val expected = estimator.fit(dfs(baseType))
-    val actuals = dfs.keys.filter(_ != baseType).map(t => estimator.fit(dfs(t)))
-    actuals.foreach(actual => check(expected, actual))
+    val actuals = dfs.keys.filter(_ != baseType).map(t => (t, estimator.fit(dfs(t))))
+    actuals.foreach { case (_, actual) => check(expected, actual) }
+    actuals.foreach { case (t, actual) => check2(expected, actual, dfs(t)) }
 
     val baseDF = dfs(baseType)
     val others = baseDF.columns.toSeq.diff(Seq(column)).map(col(_))
@@ -152,8 +155,7 @@ object MLTestingUtils extends SparkFunSuite {
     val types: Seq[NumericType] =
       Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
     types.map { t =>
-      val toCast = col(column).cast(t)
-      val cols = Seq(toCast) ++ others
+      val cols = Seq(col(column).cast(t)) ++ others
       t -> df.select(cols: _*)
     }.toMap
   }
@@ -173,7 +175,7 @@ object MLTestingUtils extends SparkFunSuite {
     val types = Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType,
       DecimalType(10, 0))
     types
-      .map(t => t -> df.select(col(labelColName).cast(t), col(predictionColName)))
+      .map(t => t -> df.select(col(labelColName).cast(t) * 1000, col(predictionColName)))
       .toMap
   }
 }

From a6965ceb2bd9f74f5c234fbc4abbfa2bb0f957b5 Mon Sep 17 00:00:00 2001
From: Nick Pentreath
Date: Thu, 28 Apr 2016 16:12:54 +0200
Subject: [PATCH 03/11] revert small erroneous addition

---
 .../test/scala/org/apache/spark/ml/util/MLTestingUtils.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index 4d0f8c06567c..4da0280875c5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -175,7 +175,7 @@ object MLTestingUtils extends SparkFunSuite {
     val types = Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType,
       DecimalType(10, 0))
     types
-      .map(t => t -> df.select(col(labelColName).cast(t) * 1000, col(predictionColName)))
+      .map(t => t -> df.select(col(labelColName).cast(t), col(predictionColName)))
       .toMap
   }
 }

From 3e84f4c5288c8e58dd5141575db931c8c2e2d8c4 Mon Sep 17 00:00:00 2001
From: Nick Pentreath
Date: Thu, 28 Apr 2016 20:47:41 +0200
Subject: [PATCH 04/11] user/item col param docs

---
 .../org/apache/spark/ml/recommendation/ALS.scala | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index c3d86a20d20b..0694dbe17f57 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -53,21 +53,27 @@ import org.apache.spark.util.random.XORShiftRandom
 */
 private[recommendation] trait ALSModelParams extends Params with HasPredictionCol {
   /**
-   * Param for the column name for user ids.
+   * Param for the column name for user ids. Ids must be integers. Other
+   * numeric types are supported for this column, but will be cast to integers as long as they
+   * fall within the integer value range.
    * Default: "user"
    * @group param
    */
-  val userCol = new Param[String](this, "userCol", "column name for user ids")
+  val userCol = new Param[String](this, "userCol", "column name for user ids. Must be within " +
+    "the integer value range.")
 
   /** @group getParam */
   def getUserCol: String = $(userCol)
 
   /**
-   * Param for the column name for item ids.
+   * Param for the column name for item ids. Ids must be integers. Other
+   * numeric types are supported for this column, but will be cast to integers as long as they
+   * fall within the integer value range.
    * Default: "item"
    * @group param
    */
-  val itemCol = new Param[String](this, "itemCol", "column name for item ids")
+  val itemCol = new Param[String](this, "itemCol", "column name for item ids. Must be within " +
+    "the integer value range.")
 
   /** @group getParam */
   def getItemCol: String = $(itemCol)

From 48e7b6b70edb57209aa02bedd49b2ce73fc19f08 Mon Sep 17 00:00:00 2001
From: Nick Pentreath
Date: Fri, 29 Apr 2016 08:45:07 +0200
Subject: [PATCH 05/11] indentation and unused val

---
 .../apache/spark/ml/recommendation/ALSSuite.scala |  1 -
 .../org/apache/spark/ml/util/MLTestingUtils.scala | 12 ++++++------
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index d1f469023330..a52246c50a87 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -489,7 +489,6 @@ class ALSSuite
     val sqlContext = this.sqlContext
     import sqlContext.implicits._
 
-    val test = Seq()
     val als = new ALS().setMaxIter(1).setRank(1)
     Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach {
       case (colName, sqlType) =>
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index 4da0280875c5..7ceef34b7629 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -59,12 +59,12 @@ object MLTestingUtils extends SparkFunSuite {
   }
 
   def checkNumericTypesALS[M <: Model[M], T <: Estimator[M]](
-    estimator: T,
-    sqlContext: SQLContext,
-    column: String,
-    baseType: NumericType)
-    (check: (M, M) => Unit)
-    (check2: (M, M, DataFrame) => Unit): Unit = {
+      estimator: T,
+      sqlContext: SQLContext,
+      column: String,
+      baseType: NumericType)
+      (check: (M, M) => Unit)
+      (check2: (M, M, DataFrame) => Unit): Unit = {
     val dfs = genRatingsDFWithNumericCols(sqlContext, column)
     val expected = estimator.fit(dfs(baseType))
     val actuals = dfs.keys.filter(_ != baseType).map(t => (t, estimator.fit(dfs(t))))

From 1b49801a96e379e05063d681b67b6a67d9bd6535 Mon Sep 17 00:00:00 2001
From: Nick Pentreath
Date: Wed, 4 May 2016 15:01:12 +0200
Subject: [PATCH 06/11] user/item col cast to DoubleType before checkedCast

---
 .../apache/spark/ml/recommendation/ALS.scala | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 0694dbe17f57..3fc7ac7edb25 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -79,11 +79,11 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo
   def getItemCol: String = $(itemCol)
 
   /**
-   * Attempts to safely cast a Long user/item id to an Int. Throws an exception if the value is
+   * Attempts to safely cast a user/item id to an Int. Throws an exception if the value is
    * out of integer range.
    */
-  protected val checkedCast = udf { (n: Long) =>
-    if (n > Int.MaxValue.toLong || n < Int.MinValue.toLong) {
+  protected val checkedCast = udf { (n: Double) =>
+    if (n > Int.MaxValue || n < Int.MinValue) {
       throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " +
         s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
     } else {
@@ -264,9 +264,9 @@ class ALSModel private[ml] (
     }
     dataset
       .join(userFactors,
-        checkedCast(dataset($(userCol)).cast(LongType)) === userFactors("id"), "left")
+        checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left")
       .join(itemFactors,
-        checkedCast(dataset($(itemCol)).cast(LongType)) === itemFactors("id"), "left")
+        checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left")
       .select(dataset("*"),
         predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
   }
@@ -459,11 +459,10 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
 
     val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
     val ratings = dataset
-      .select(checkedCast(col($(userCol)).cast(LongType)),
-        checkedCast(col($(itemCol)).cast(LongType)), r)
+      .select(checkedCast(col($(userCol)).cast(DoubleType)),
+        checkedCast(col($(itemCol)).cast(DoubleType)), r)
       .rdd
-      .map { row =>
-        Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
+      .map { row => Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
       }
     val instrLog = Instrumentation.create(this, ratings)
     instrLog.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha,

From d663a1697755845e092db40cb97ff3167ff91baa Mon Sep 17 00:00:00 2001
From: Nick Pentreath
Date: Wed, 4 May 2016 15:05:26 +0200
Subject: [PATCH 07/11] Fix erroneous indentation change

---
 .../main/scala/org/apache/spark/ml/recommendation/ALS.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 3fc7ac7edb25..5757fd7e605c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -462,7 +462,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
       .select(checkedCast(col($(userCol)).cast(DoubleType)),
         checkedCast(col($(itemCol)).cast(DoubleType)), r)
       .rdd
-      .map { row => Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
+      .map { row =>
+        Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
       }
     val instrLog = Instrumentation.create(this, ratings)
     instrLog.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha,

From 676daf94889dd0cc1e732176ef0ff66a43be21e1 Mon Sep 17 00:00:00 2001
From: Nick Pentreath
Date: Wed, 11 May 2016 21:09:49 +0200
Subject: [PATCH 08/11] update to SparkSession

---
 .../org/apache/spark/ml/recommendation/ALSSuite.scala   | 9 ++++-----
 .../scala/org/apache/spark/ml/util/MLTestingUtils.scala | 8 ++++----
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index a52246c50a87..87f671eccb0a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -39,11 +39,10 @@ import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
 import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.types.{FloatType, IntegerType}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
-import org.apache.spark.sql.types.{FloatType, IntegerType, StringType}
-
 class ALSSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging {
 
@@ -486,13 +485,13 @@ class ALSSuite
   }
 
   test("input type validation") {
-    val sqlContext = this.sqlContext
-    import sqlContext.implicits._
+    val spark = this.spark
+    import spark.implicits._
 
     val als = new ALS().setMaxIter(1).setRank(1)
     Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach {
       case (colName, sqlType) =>
-        MLTestingUtils.checkNumericTypesALS[ALSModel, ALS](als, sqlContext, colName, sqlType) {
+        MLTestingUtils.checkNumericTypesALS[ALSModel, ALS](als, spark, colName, sqlType) {
           (ex, act) =>
             ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1)
         } { (ex, act, _) =>
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index 7ceef34b7629..c31a81fc5db7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -60,12 +60,12 @@ object MLTestingUtils extends SparkFunSuite {
 
   def checkNumericTypesALS[M <: Model[M], T <: Estimator[M]](
       estimator: T,
-      sqlContext: SQLContext,
+      spark: SparkSession,
       column: String,
       baseType: NumericType)
       (check: (M, M) => Unit)
       (check2: (M, M, DataFrame) => Unit): Unit = {
-    val dfs = genRatingsDFWithNumericCols(sqlContext, column)
+    val dfs = genRatingsDFWithNumericCols(spark, column)
     val expected = estimator.fit(dfs(baseType))
     val actuals = dfs.keys.filter(_ != baseType).map(t => (t, estimator.fit(dfs(t))))
     actuals.foreach { case (_, actual) => check(expected, actual) }
@@ -141,9 +141,9 @@ object MLTestingUtils extends SparkFunSuite {
   }
 
   def genRatingsDFWithNumericCols(
-      sqlContext: SQLContext,
+      spark: SparkSession,
       column: String): Map[NumericType, DataFrame] = {
-    val df = sqlContext.createDataFrame(Seq(
+    val df = spark.createDataFrame(Seq(
       (0, 10, 1.0),
       (1, 20, 2.0),
       (2, 30, 3.0),

From 34a0700a051ae22e945ccce8c05196281348a8e8 Mon Sep 17 00:00:00 2001
From: Nick Pentreath
Date: Thu, 12 May 2016 08:51:56 +0200
Subject: [PATCH 09/11] Get rid of unnecessary generic test method for ALS

---
 .../org/apache/spark/ml/recommendation/ALSSuite.scala   | 5 ++++-
 .../scala/org/apache/spark/ml/util/MLTestingUtils.scala | 9 +++++----
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 87f671eccb0a..97061f9835c7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -488,10 +488,12 @@ class ALSSuite
     val spark = this.spark
     import spark.implicits._
 
+    // check that ALS can handle all numeric types for rating column
+    // and user/item columns (when the user/item ids are within Int range)
     val als = new ALS().setMaxIter(1).setRank(1)
     Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach {
       case (colName, sqlType) =>
-        MLTestingUtils.checkNumericTypesALS[ALSModel, ALS](als, spark, colName, sqlType) {
+        MLTestingUtils.checkNumericTypesALS(als, spark, colName, sqlType) {
           (ex, act) =>
             ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1)
         } { (ex, act, _) =>
@@ -499,6 +501,7 @@ class ALSSuite
             act.transform(_: DataFrame).select("prediction").first.getDouble(0) absTol 1e-6
         }
     }
+    // check user/item ids falling outside of Int range
     val big = Int.MaxValue.toLong + 1
     val small = Int.MinValue.toDouble - 1
     val df = Seq(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index c31a81fc5db7..7560d26db119 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.recommendation.{ALS, ALSModel}
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.sql.{DataFrame, SparkSession}
@@ -58,13 +59,13 @@ object MLTestingUtils extends SparkFunSuite {
       "Column label must be of type NumericType but was actually of type StringType"))
   }
 
-  def checkNumericTypesALS[M <: Model[M], T <: Estimator[M]](
-      estimator: T,
+  def checkNumericTypesALS(
+      estimator: ALS,
       spark: SparkSession,
       column: String,
       baseType: NumericType)
-      (check: (M, M) => Unit)
-      (check2: (M, M, DataFrame) => Unit): Unit = {
+      (check: (ALSModel, ALSModel) => Unit)
+      (check2: (ALSModel, ALSModel, DataFrame) => Unit): Unit = {
     val dfs = genRatingsDFWithNumericCols(spark, column)
     val expected = estimator.fit(dfs(baseType))
     val actuals = dfs.keys.filter(_ != baseType).map(t => (t, estimator.fit(dfs(t))))

From b0aa98fa4daf5cbed2df742ee2ab5f322f77a4f6 Mon Sep 17 00:00:00 2001
From: Nick Pentreath
Date: Mon, 16 May 2016 11:09:03 +0200
Subject: [PATCH 10/11] test using float

---
 .../scala/org/apache/spark/ml/recommendation/ALSSuite.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 97061f9835c7..d45acaf489d5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -497,8 +497,8 @@ class ALSSuite
           (ex, act) =>
             ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1)
         } { (ex, act, _) =>
-          ex.transform(_: DataFrame).select("prediction").first.getDouble(0) ~==
-            act.transform(_: DataFrame).select("prediction").first.getDouble(0) absTol 1e-6
+          ex.transform(_: DataFrame).select("prediction").first.getFloat(0) ~==
+            act.transform(_: DataFrame).select("prediction").first.getFloat(0) absTol 1e-6
         }
     }
     // check user/item ids falling outside of Int range

From 1fd1f873f0926cace39c6ee16a830c16710bc6d0 Mon Sep 17 00:00:00 2001
From: Nick Pentreath
Date: Mon, 16 May 2016 11:09:40 +0200
Subject: [PATCH 11/11] Update user/item col doc (also in PySpark)

---
 .../scala/org/apache/spark/ml/recommendation/ALS.scala | 4 ++--
 python/pyspark/ml/recommendation.py                    | 8 ++++----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 5757fd7e605c..f257382d2205 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -59,7 +59,7 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo
    * Default: "user"
    * @group param
    */
-  val userCol = new Param[String](this, "userCol", "column name for user ids. Must be within " +
+  val userCol = new Param[String](this, "userCol", "column name for user ids. Ids must be within " +
     "the integer value range.")
 
   /** @group getParam */
@@ -72,7 +72,7 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo
    * Default: "item"
    * @group param
   */
-  val itemCol = new Param[String](this, "itemCol", "column name for item ids. Must be within " +
+  val itemCol = new Param[String](this, "itemCol", "column name for item ids. Ids must be within " +
     "the integer value range.")
 
   /** @group getParam */
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index d7cb65846574..86c00d91652d 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -110,10 +110,10 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
                            typeConverter=TypeConverters.toBoolean)
     alpha = Param(Params._dummy(), "alpha", "alpha for implicit preference",
                   typeConverter=TypeConverters.toFloat)
-    userCol = Param(Params._dummy(), "userCol", "column name for user ids",
-                    typeConverter=TypeConverters.toString)
-    itemCol = Param(Params._dummy(), "itemCol", "column name for item ids",
-                    typeConverter=TypeConverters.toString)
+    userCol = Param(Params._dummy(), "userCol", "column name for user ids. Ids must be within " +
+                    "the integer value range.", typeConverter=TypeConverters.toString)
+    itemCol = Param(Params._dummy(), "itemCol", "column name for item ids. Ids must be within " +
+                    "the integer value range.", typeConverter=TypeConverters.toString)
     ratingCol = Param(Params._dummy(), "ratingCol", "column name for ratings",
                       typeConverter=TypeConverters.toString)
     nonnegative = Param(Params._dummy(), "nonnegative",