55 changes: 41 additions & 14 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -42,7 +42,7 @@ import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
@@ -53,24 +53,43 @@ import org.apache.spark.util.random.XORShiftRandom
*/
private[recommendation] trait ALSModelParams extends Params with HasPredictionCol {
/**
* Param for the column name for user ids.
* Param for the column name for user ids. Ids must be integers. Other
* numeric types are supported for this column, but will be cast to integers as long as they
* fall within the integer value range.
* Default: "user"
* @group param
*/
val userCol = new Param[String](this, "userCol", "column name for user ids")
val userCol = new Param[String](this, "userCol", "column name for user ids. Ids must be within " +
"the integer value range.")

/** @group getParam */
def getUserCol: String = $(userCol)

/**
* Param for the column name for item ids.
* Param for the column name for item ids. Ids must be integers. Other
* numeric types are supported for this column, but will be cast to integers as long as they
Contributor:

So it seems like the only other numeric type we support is Long; maybe it would be better to say that. Someone might try to pass in BigInts or Doubles and expect this to work.

Contributor Author (@MLnick, Apr 28, 2016):

We "support" all numeric types in the sense that the input col can be any numeric type, but it is cast to Int. It is a "safe" cast, though: if the value is > Int.MaxValue or < Int.MinValue it throws an exception. "Safe" in the sense that it won't mangle the user's input ids (e.g. if Longs are passed in, they will now get a failure on fit rather than a silent cast of those Long ids into Ints).
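To make that behaviour concrete, here is a minimal plain-Scala sketch (an illustration, not code from this PR) of the difference between a silent narrowing cast and the range-checked cast described above; the Long id is made up.

object CheckedCastSketch extends App {
  // A perfectly valid Long id that does not fit in an Int.
  val longId: Long = Int.MaxValue.toLong + 1   // 2147483648

  // Plain narrowing silently wraps around: no error, just a wrong id.
  val silentlyCast: Int = longId.toInt         // -2147483648

  // The "safe" cast described above: range-check before narrowing.
  def checkedToInt(n: Double): Int =
    if (n > Int.MaxValue || n < Int.MinValue) {
      throw new IllegalArgumentException(s"Value $n was out of Integer range.")
    } else {
      n.toInt
    }

  println(silentlyCast)            // -2147483648
  println(checkedToInt(1234.0))    // 1234
  // checkedToInt(longId.toDouble) // would throw IllegalArgumentException
}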

Contributor:

Ah yes, I didn't notice the first cast from the input type to Long - it seems like that would be OK[ish] most of the time (except with floats/doubles), but with certain BigDecimals you could end up throwing away the high bits when going to a Long, and a very out-of-range value would pass the range check.
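A tiny sketch of the concern above (the value is illustrative): longValue() on a BigDecimal/BigInteger keeps only the low-order 64 bits, so a wildly out-of-range id can wrap into something that looks valid. The approach ultimately used in the diff above - cast to Double, then range-check - keeps the magnitude (approximately) and so still trips the check.

object BigDecimalTruncationSketch extends App {
  val huge = BigDecimal("18446744073709551621") // 2^64 + 5, far outside Long range
  val wrapped: Long = huge.toLong               // only the low 64 bits survive
  println(wrapped)                              // prints 5, which would pass an Int range check
}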

Contributor Author (@MLnick):

Ah ok. Could cast to double or float here... I was just concerned about any storage / performance impact, but if everything is pipelined through the cast -> udf then no problem.

Contributor:

I think keeping the cast to Integer is good for performance - but maybe just avoid supporting input types that we might silently fail on or produce junk results for.

Contributor Author (@MLnick, Apr 29, 2016):

Personally as per the related JIRA, I would actually like to support Int, Long and String for ids in ALS (with appropriate warnings about performance impact for Long/String ids). For the vast majority of use cases I believe the user-friendliness of supporting String in particular outweighs the performance impact. For those users who need performance at scale, they can stick to Int.

But for now, since only Int ids are supported in the DF API, some validation is better than nothing. I am actually slightly more in favor of only supporting Int or Long for the id columns in this PR, since the real-world occurrence of a Double or other more esoteric numeric type for the id column is, IMO, highly unlikely, and in that case requiring users to do the cast explicitly themselves is ok I would say.

So we can support only Ints and Longs (within Integer range) as a simpler alternative here - it would just require updating the type checks in transformSchema and the tests.

@jkbradley @srowen @holdenk thoughts?
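For illustration, a rough sketch of what that stricter Int/Long-only alternative might look like - this is not what the PR implements, and the helper name is made up:

object StrictIdTypeCheckSketch {
  import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StructType}

  // Would stand in for the SchemaUtils.checkNumericType calls on the id columns.
  def checkIntOrLongType(schema: StructType, colName: String): Unit = {
    val actual: DataType = schema(colName).dataType
    require(actual == IntegerType || actual == LongType,
      s"Column $colName must be of type IntegerType or LongType but was actually of type $actual.")
  }
}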

Contributor Author (@MLnick):

ping @mengxr also in case you have a chance to take a look, and consider this question of whether to only support Int/Long for ids or support all numeric types (with "safe" cast to Int in both cases)

Member:

Sorry for the slow answer. I like supporting all Numeric types, with checking. I agree we should support String IDs at some point, with automatic indexing; that can be part of this discussion: [https://issues.apache.org/jira/browse/SPARK-11106]
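Until native String id support exists, the "automatic indexing" idea can already be approximated by putting a StringIndexer in front of ALS. A hedged sketch - column names and parameters are illustrative, not from this PR:

object StringIdAlsSketch {
  import org.apache.spark.ml.Pipeline
  import org.apache.spark.ml.feature.StringIndexer
  import org.apache.spark.ml.recommendation.ALS

  // StringIndexer outputs DoubleType, which the relaxed numeric check in this PR
  // accepts and safely casts to Int (indices start at 0, so for realistic
  // cardinalities they stay well within Int range).
  val userIndexer = new StringIndexer().setInputCol("userId").setOutputCol("userIdx")
  val itemIndexer = new StringIndexer().setInputCol("itemId").setOutputCol("itemIdx")
  val als = new ALS()
    .setUserCol("userIdx")
    .setItemCol("itemIdx")
    .setRatingCol("rating")
  val pipeline = new Pipeline().setStages(Array(userIndexer, itemIndexer, als))
  // pipeline.fit(ratingsWithStringIds) would then train end to end.
}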

* fall within the integer value range.
* Default: "item"
* @group param
*/
val itemCol = new Param[String](this, "itemCol", "column name for item ids")
val itemCol = new Param[String](this, "itemCol", "column name for item ids. Ids must be within " +
"the integer value range.")

/** @group getParam */
def getItemCol: String = $(itemCol)

/**
* Attempts to safely cast a user/item id to an Int. Throws an exception if the value is
* out of integer range.
*/
protected val checkedCast = udf { (n: Double) =>
if (n > Int.MaxValue || n < Int.MinValue) {
throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " +
s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
} else {
n.toInt
}
}
}

/**
@@ -193,10 +212,11 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
val ratingType = schema($(ratingCol)).dataType
require(ratingType == FloatType || ratingType == DoubleType)
// user and item will be cast to Int
SchemaUtils.checkNumericType(schema, $(userCol))
SchemaUtils.checkNumericType(schema, $(itemCol))
// rating will be cast to Float
SchemaUtils.checkNumericType(schema, $(ratingCol))
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
}
}
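A hedged usage sketch of what the relaxed schema check above means for callers, assuming an existing SparkSession named spark; the data and column names are illustrative, not taken from the PR:

object LongIdUsageSketch {
  import org.apache.spark.ml.recommendation.ALS
  import org.apache.spark.sql.SparkSession

  def demo(spark: SparkSession): Unit = {
    import spark.implicits._

    // Long id columns now pass the schema check (any numeric type does)...
    val inRange = Seq((1L, 10L, 4.0f), (2L, 20L, 3.0f), (2L, 10L, 5.0f))
      .toDF("user", "item", "rating")
    new ALS().setMaxIter(1).setRank(1).fit(inRange) // trains: these ids fit in an Int

    // ...but values outside Int range fail fast instead of being silently mangled.
    val outOfRange = Seq((Int.MaxValue.toLong + 1, 10L, 4.0f)).toDF("user", "item", "rating")
    // new ALS().fit(outOfRange) // fails with "... was out of Integer range."
  }
}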
@@ -232,6 +252,7 @@ class ALSModel private[ml] (

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema)
// Register a UDF for DataFrame, and then
// create a new column named map(predictionCol) by running the predict UDF.
val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
@@ -242,16 +263,19 @@
}
}
dataset
.join(userFactors, dataset($(userCol)) === userFactors("id"), "left")
.join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left")
.join(userFactors,
checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left")
Member:

Shouldn't this and the next line's cast both use IntegerType?

Contributor Author (@MLnick, May 27, 2016):

@jkbradley the existing code did the cast to Int - that means passing in Long or Double (say) would silently cast, potentially lose precision, and give weird results, with no exception or warning. That's why here we cast to DoubleType and use the checkedCast udf to do a safe cast to Int if the value is within the Int value range. If not, we throw an exception with a helpful message.

This is so we can allow any numeric type for the user/item columns (providing some form of "backward compatibility" with the old version that didn't check types at all), but we can still only support actual values that are Ints.

Member:

Right, makes sense, thanks!

.join(itemFactors,
checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left")
.select(dataset("*"),
predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
}

@Since("1.3.0")
override def transformSchema(schema: StructType): StructType = {
SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
// user and item will be cast to Int
SchemaUtils.checkNumericType(schema, $(userCol))
SchemaUtils.checkNumericType(schema, $(itemCol))
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
}

@@ -430,10 +454,13 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]

@Since("2.0.0")
override def fit(dataset: Dataset[_]): ALSModel = {
transformSchema(dataset.schema)
import dataset.sparkSession.implicits._

val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
val ratings = dataset
.select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r)
.select(checkedCast(col($(userCol)).cast(DoubleType)),
Member:

DoubleType -> IntegerType

checkedCast(col($(itemCol)).cast(DoubleType)), r)
.rdd
.map { row =>
Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -39,6 +39,7 @@ import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.{FloatType, IntegerType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils

@@ -205,7 +206,6 @@ class ALSSuite

/**
* Generates an explicit feedback dataset for testing ALS.
*
* @param numUsers number of users
* @param numItems number of items
* @param rank rank
@@ -246,7 +246,6 @@ class ALSSuite

/**
* Generates an implicit feedback dataset for testing ALS.
*
* @param numUsers number of users
* @param numItems number of items
* @param rank rank
@@ -265,7 +264,6 @@ class ALSSuite

/**
* Generates random user/item factors, with i.i.d. values drawn from U(a, b).
*
* @param size number of users/items
* @param rank number of features
* @param random random number generator
@@ -284,7 +282,6 @@ class ALSSuite

/**
* Test ALS using the given training/test splits and parameters.
*
* @param training training dataset
* @param test test dataset
* @param rank rank of the matrix factorization
@@ -486,6 +483,62 @@ class ALSSuite
assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
}

test("input type validation") {
val spark = this.spark
import spark.implicits._

// check that ALS can handle all numeric types for rating column
// and user/item columns (when the user/item ids are within Int range)
val als = new ALS().setMaxIter(1).setRank(1)
Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach {
case (colName, sqlType) =>
MLTestingUtils.checkNumericTypesALS(als, spark, colName, sqlType) {
(ex, act) =>
ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1)
} { (ex, act, _) =>
ex.transform(_: DataFrame).select("prediction").first.getFloat(0) ~==
act.transform(_: DataFrame).select("prediction").first.getFloat(0) absTol 1e-6
}
}
// check user/item ids falling outside of Int range
val big = Int.MaxValue.toLong + 1
val small = Int.MinValue.toDouble - 1
val df = Seq(
(0, 0L, 0d, 1, 1L, 1d, 3.0),
(0, big, small, 0, big, small, 2.0),
(1, 1L, 1d, 0, 0L, 0d, 5.0)
).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating")
withClue("fit should fail when ids exceed integer range. ") {
assert(intercept[IllegalArgumentException] {
als.fit(df.select(df("user_big").as("user"), df("item"), df("rating")))
}.getMessage.contains("was out of Integer range"))
assert(intercept[IllegalArgumentException] {
als.fit(df.select(df("user_small").as("user"), df("item"), df("rating")))
}.getMessage.contains("was out of Integer range"))
assert(intercept[IllegalArgumentException] {
als.fit(df.select(df("item_big").as("item"), df("user"), df("rating")))
}.getMessage.contains("was out of Integer range"))
assert(intercept[IllegalArgumentException] {
als.fit(df.select(df("item_small").as("item"), df("user"), df("rating")))
}.getMessage.contains("was out of Integer range"))
}
withClue("transform should fail when ids exceed integer range. ") {
val model = als.fit(df)
assert(intercept[SparkException] {
model.transform(df.select(df("user_big").as("user"), df("item"))).first
}.getMessage.contains("was out of Integer range"))
assert(intercept[SparkException] {
model.transform(df.select(df("user_small").as("user"), df("item"))).first
}.getMessage.contains("was out of Integer range"))
assert(intercept[SparkException] {
model.transform(df.select(df("item_big").as("item"), df("user"))).first
}.getMessage.contains("was out of Integer range"))
assert(intercept[SparkException] {
model.transform(df.select(df("item_small").as("item"), df("user"))).first
}.getMessage.contains("was out of Integer range"))
}
}
}

class ALSCleanerSuite extends SparkFunSuite {
45 changes: 45 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.{DataFrame, SparkSession}
@@ -58,6 +59,30 @@ object MLTestingUtils extends SparkFunSuite {
"Column label must be of type NumericType but was actually of type StringType"))
}

def checkNumericTypesALS(
estimator: ALS,
spark: SparkSession,
column: String,
baseType: NumericType)
(check: (ALSModel, ALSModel) => Unit)
(check2: (ALSModel, ALSModel, DataFrame) => Unit): Unit = {
val dfs = genRatingsDFWithNumericCols(spark, column)
val expected = estimator.fit(dfs(baseType))
val actuals = dfs.keys.filter(_ != baseType).map(t => (t, estimator.fit(dfs(t))))
actuals.foreach { case (_, actual) => check(expected, actual) }
actuals.foreach { case (t, actual) => check2(expected, actual, dfs(t)) }

val baseDF = dfs(baseType)
val others = baseDF.columns.toSeq.diff(Seq(column)).map(col(_))
val cols = Seq(col(column).cast(StringType)) ++ others
val strDF = baseDF.select(cols: _*)
val thrown = intercept[IllegalArgumentException] {
estimator.fit(strDF)
}
assert(thrown.getMessage.contains(
s"$column must be of type NumericType but was actually of type StringType"))
}

def checkNumericTypes[T <: Evaluator](evaluator: T, spark: SparkSession): Unit = {
val dfs = genEvaluatorDFWithNumericLabelCol(spark, "label", "prediction")
val expected = evaluator.evaluate(dfs(DoubleType))
@@ -116,6 +141,26 @@ object MLTestingUtils extends SparkFunSuite {
}.toMap
}

def genRatingsDFWithNumericCols(
spark: SparkSession,
column: String): Map[NumericType, DataFrame] = {
val df = spark.createDataFrame(Seq(
(0, 10, 1.0),
(1, 20, 2.0),
(2, 30, 3.0),
(3, 40, 4.0),
(4, 50, 5.0)
)).toDF("user", "item", "rating")

val others = df.columns.toSeq.diff(Seq(column)).map(col(_))
val types: Seq[NumericType] =
Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
types.map { t =>
val cols = Seq(col(column).cast(t)) ++ others
t -> df.select(cols: _*)
}.toMap
}

def genEvaluatorDFWithNumericLabelCol(
spark: SparkSession,
labelColName: String = "label",
8 changes: 4 additions & 4 deletions python/pyspark/ml/recommendation.py
@@ -110,10 +110,10 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
typeConverter=TypeConverters.toBoolean)
alpha = Param(Params._dummy(), "alpha", "alpha for implicit preference",
typeConverter=TypeConverters.toFloat)
userCol = Param(Params._dummy(), "userCol", "column name for user ids",
typeConverter=TypeConverters.toString)
itemCol = Param(Params._dummy(), "itemCol", "column name for item ids",
typeConverter=TypeConverters.toString)
userCol = Param(Params._dummy(), "userCol", "column name for user ids. Ids must be within " +
"the integer value range.", typeConverter=TypeConverters.toString)
itemCol = Param(Params._dummy(), "itemCol", "column name for item ids. Ids must be within " +
"the integer value range.", typeConverter=TypeConverters.toString)
ratingCol = Param(Params._dummy(), "ratingCol", "column name for ratings",
typeConverter=TypeConverters.toString)
nonnegative = Param(Params._dummy(), "nonnegative",