-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-14891][ML] Add schema validation for ALS #12762
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c794317
878baf5
a6965ce
3e84f4c
48e7b6b
1b49801
d663a16
676daf9
34a0700
b0aa98f
1fd1f87
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -42,7 +42,7 @@ import org.apache.spark.mllib.optimization.NNLS | |
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.sql.{DataFrame, Dataset} | ||
| import org.apache.spark.sql.functions._ | ||
| import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType} | ||
| import org.apache.spark.sql.types._ | ||
| import org.apache.spark.storage.StorageLevel | ||
| import org.apache.spark.util.Utils | ||
| import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} | ||
|
|
@@ -53,24 +53,43 @@ import org.apache.spark.util.random.XORShiftRandom | |
| */ | ||
| private[recommendation] trait ALSModelParams extends Params with HasPredictionCol { | ||
| /** | ||
| * Param for the column name for user ids. | ||
| * Param for the column name for user ids. Ids must be integers. Other | ||
| * numeric types are supported for this column, but will be cast to integers as long as they | ||
| * fall within the integer value range. | ||
| * Default: "user" | ||
| * @group param | ||
| */ | ||
| val userCol = new Param[String](this, "userCol", "column name for user ids") | ||
| val userCol = new Param[String](this, "userCol", "column name for user ids. Ids must be within " + | ||
| "the integer value range.") | ||
|
|
||
| /** @group getParam */ | ||
| def getUserCol: String = $(userCol) | ||
|
|
||
| /** | ||
| * Param for the column name for item ids. | ||
| * Param for the column name for item ids. Ids must be integers. Other | ||
| * numeric types are supported for this column, but will be cast to integers as long as they | ||
| * fall within the integer value range. | ||
| * Default: "item" | ||
| * @group param | ||
| */ | ||
| val itemCol = new Param[String](this, "itemCol", "column name for item ids") | ||
| val itemCol = new Param[String](this, "itemCol", "column name for item ids. Ids must be within " + | ||
| "the integer value range.") | ||
|
|
||
| /** @group getParam */ | ||
| def getItemCol: String = $(itemCol) | ||
|
|
||
| /** | ||
| * Attempts to safely cast a user/item id to an Int. Throws an exception if the value is | ||
| * out of integer range. | ||
| */ | ||
| protected val checkedCast = udf { (n: Double) => | ||
| if (n > Int.MaxValue || n < Int.MinValue) { | ||
| throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " + | ||
| s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.") | ||
| } else { | ||
| n.toInt | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -193,10 +212,11 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w | |
| * @return output schema | ||
| */ | ||
| protected def validateAndTransformSchema(schema: StructType): StructType = { | ||
| SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) | ||
| SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) | ||
| val ratingType = schema($(ratingCol)).dataType | ||
| require(ratingType == FloatType || ratingType == DoubleType) | ||
| // user and item will be cast to Int | ||
| SchemaUtils.checkNumericType(schema, $(userCol)) | ||
| SchemaUtils.checkNumericType(schema, $(itemCol)) | ||
| // rating will be cast to Float | ||
| SchemaUtils.checkNumericType(schema, $(ratingCol)) | ||
| SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) | ||
| } | ||
| } | ||
|
|
@@ -232,6 +252,7 @@ class ALSModel private[ml] ( | |
|
|
||
| @Since("2.0.0") | ||
| override def transform(dataset: Dataset[_]): DataFrame = { | ||
| transformSchema(dataset.schema) | ||
| // Register a UDF for DataFrame, and then | ||
| // create a new column named map(predictionCol) by running the predict UDF. | ||
| val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => | ||
|
|
@@ -242,16 +263,19 @@ class ALSModel private[ml] ( | |
| } | ||
| } | ||
| dataset | ||
| .join(userFactors, dataset($(userCol)) === userFactors("id"), "left") | ||
| .join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left") | ||
| .join(userFactors, | ||
| checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shouldn't this and the next line's cast both use IntegerType?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @jkbradley the existing code did the cast to IntegerType. This is so we can allow any numeric type for the user/item columns (providing some form of "backward compatibility" with the old version that didn't check types at all), but we can still only support actual values that are within the integer value range.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Right, makes sense, thanks! |
||
| .join(itemFactors, | ||
| checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left") | ||
| .select(dataset("*"), | ||
| predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) | ||
| } | ||
|
|
||
| @Since("1.3.0") | ||
| override def transformSchema(schema: StructType): StructType = { | ||
| SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) | ||
| SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) | ||
| // user and item will be cast to Int | ||
| SchemaUtils.checkNumericType(schema, $(userCol)) | ||
| SchemaUtils.checkNumericType(schema, $(itemCol)) | ||
| SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) | ||
| } | ||
|
|
||
|
|
@@ -430,10 +454,13 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] | |
|
|
||
| @Since("2.0.0") | ||
| override def fit(dataset: Dataset[_]): ALSModel = { | ||
| transformSchema(dataset.schema) | ||
| import dataset.sparkSession.implicits._ | ||
|
|
||
| val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) | ||
| val ratings = dataset | ||
| .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r) | ||
| .select(checkedCast(col($(userCol)).cast(DoubleType)), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. DoubleType -> IntegerType |
||
| checkedCast(col($(itemCol)).cast(DoubleType)), r) | ||
| .rdd | ||
| .map { row => | ||
| Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So it seems like the only other numeric type we support is Long, maybe it would be better to say that? Someone might try and pass in BigInts or Doubles and expect this to work.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We "support" all numeric types in the sense that the input col can be any numeric type. But it is cast to Int. It is a "safe" cast though, if it is > Int.MaxValue or < Int.MinValue it throws an exception. "Safe" in the sense that it won't mangle the user's input ids (e.g. if Longs are passed in they will now get a failure on
fit rather than a silent cast of those Long ids into Ints). There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah yes, I didn't notice the first cast from input type to Long - it seems like that would be OK[ish] most of the time (except with floats/doubles), but also with certain BigDecimal you could end up throwing away the high bits when going to a Long and a very out of range value would pass the range check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah ok. Could cast to double or float here... I was just concerned about any
storage / performance impact, but if everything is pipelined through the
cast -> udf then no problem
On Thu, 28 Apr 2016 at 21:27, Holden Karau notifications@github.com wrote:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think keeping cast to Integer is good for performance - but maybe just avoiding supporting input types that we might silently fail on/produce junk results for.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @jkbradley
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Personally as per the related JIRA, I would actually like to support Int, Long and String for ids in ALS (with appropriate warnings about performance impact for Long/String ids). For the vast majority of use cases I believe the user-friendliness of supporting String in particular outweighs the performance impact. For those users who need performance at scale, they can stick to Int.
But for now, since only Int ids are supported in the DF API, some validation is better than nothing. I am actually slightly more in favor of only supporting Int or Long for the id columns in this PR, since the real-world occurrence of a Double or other more esoteric numeric type for the id column is, IMO, highly unlikely, and in that case requiring users to do the cast explicitly themselves is ok I would say.
So we can support only Int and Longs (within Integer range) as a simpler alternative here - it would just require updating the type checks in
transformSchema and the tests. @jkbradley @srowen @holdenk thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ping @mengxr also in case you have a chance to take a look, and consider this question of whether to only support Int/Long for ids or support all numeric types (with "safe" cast to Int in both cases)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the slow answer. I like supporting all Numeric types, with checking. I agree we should support String IDs at some point, with automatic indexing; that can be part of this discussion: [https://issues.apache.org/jira/browse/SPARK-11106]