Skip to content

Commit 77bf12d

Browse files
committed
Replace featuresCol with itemsCol in ml.fpm.FPGrowth
1 parent ee91a0d commit 77bf12d

File tree

2 files changed

+31
-18
lines changed

2 files changed

+31
-18
lines changed

mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path
2525
import org.apache.spark.annotation.{Experimental, Since}
2626
import org.apache.spark.ml.{Estimator, Model}
2727
import org.apache.spark.ml.param._
28-
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol}
28+
import org.apache.spark.ml.param.shared.HasPredictionCol
2929
import org.apache.spark.ml.util._
3030
import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules,
3131
FPGrowth => MLlibFPGrowth}
@@ -37,7 +37,20 @@ import org.apache.spark.sql.types._
3737
/**
3838
* Common params for FPGrowth and FPGrowthModel
3939
*/
40-
private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPredictionCol {
40+
private[fpm] trait FPGrowthParams extends Params with HasPredictionCol {
41+
42+
/**
43+
* Items column name.
44+
* Default: "items"
45+
* @group param
46+
*/
47+
@Since("2.2.0")
48+
val itemsCol: Param[String] = new Param[String](this, "itemsCol", "items column name")
49+
setDefault(itemsCol -> "items")
50+
51+
/** @group getParam */
52+
@Since("2.2.0")
53+
def getItemsCol: String = $(itemsCol)
4154

4255
/**
4356
* Minimal support level of the frequent pattern. [0.0, 1.0]. Any pattern that appears
@@ -91,10 +104,10 @@ private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPre
91104
*/
92105
@Since("2.2.0")
93106
protected def validateAndTransformSchema(schema: StructType): StructType = {
94-
val inputType = schema($(featuresCol)).dataType
107+
val inputType = schema($(itemsCol)).dataType
95108
require(inputType.isInstanceOf[ArrayType],
96109
s"The input column must be ArrayType, but got $inputType.")
97-
SchemaUtils.appendColumn(schema, $(predictionCol), schema($(featuresCol)).dataType)
110+
SchemaUtils.appendColumn(schema, $(predictionCol), schema($(itemsCol)).dataType)
98111
}
99112
}
100113

@@ -133,7 +146,7 @@ class FPGrowth @Since("2.2.0") (
133146

134147
/** @group setParam */
135148
@Since("2.2.0")
136-
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
149+
def setItemsCol(value: String): this.type = set(itemsCol, value)
137150

138151
/** @group setParam */
139152
@Since("2.2.0")
@@ -146,8 +159,8 @@ class FPGrowth @Since("2.2.0") (
146159
}
147160

148161
private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = {
149-
val data = dataset.select($(featuresCol))
150-
val items = data.where(col($(featuresCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray)
162+
val data = dataset.select($(itemsCol))
163+
val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray)
151164
val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport))
152165
if (isSet(numPartitions)) {
153166
mllibFP.setNumPartitions($(numPartitions))
@@ -156,7 +169,7 @@ class FPGrowth @Since("2.2.0") (
156169
val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq))
157170

158171
val schema = StructType(Seq(
159-
StructField("items", dataset.schema($(featuresCol)).dataType, nullable = false),
172+
StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false),
160173
StructField("freq", LongType, nullable = false)))
161174
val frequentItems = dataset.sparkSession.createDataFrame(rows, schema)
162175
copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this)
@@ -198,7 +211,7 @@ class FPGrowthModel private[ml] (
198211

199212
/** @group setParam */
200213
@Since("2.2.0")
201-
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
214+
def setItemsCol(value: String): this.type = set(itemsCol, value)
202215

203216
/** @group setParam */
204217
@Since("2.2.0")
@@ -235,7 +248,7 @@ class FPGrowthModel private[ml] (
235248
.collect().asInstanceOf[Array[(Seq[Any], Seq[Any])]]
236249
val brRules = dataset.sparkSession.sparkContext.broadcast(rules)
237250

238-
val dt = dataset.schema($(featuresCol)).dataType
251+
val dt = dataset.schema($(itemsCol)).dataType
239252
// For each rule, examine the input items and summarize the consequents
240253
val predictUDF = udf((items: Seq[_]) => {
241254
if (items != null) {
@@ -249,7 +262,7 @@ class FPGrowthModel private[ml] (
249262
} else {
250263
Seq.empty
251264
}}, dt)
252-
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
265+
dataset.withColumn($(predictionCol), predictUDF(col($(itemsCol))))
253266
}
254267

255268
@Since("2.2.0")

mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
3434

3535
test("FPGrowth fit and transform with different data types") {
3636
Array(IntegerType, StringType, ShortType, LongType, ByteType).foreach { dt =>
37-
val data = dataset.withColumn("features", col("features").cast(ArrayType(dt)))
37+
val data = dataset.withColumn("items", col("items").cast(ArrayType(dt)))
3838
val model = new FPGrowth().setMinSupport(0.5).fit(data)
3939
val generatedRules = model.setMinConfidence(0.5).associationRules
4040
val expectedRules = spark.createDataFrame(Seq(
@@ -52,8 +52,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
5252
(0, Array("1", "2"), Array.emptyIntArray),
5353
(0, Array("1", "2"), Array.emptyIntArray),
5454
(0, Array("1", "3"), Array(2))
55-
)).toDF("id", "features", "prediction")
56-
.withColumn("features", col("features").cast(ArrayType(dt)))
55+
)).toDF("id", "items", "prediction")
56+
.withColumn("items", col("items").cast(ArrayType(dt)))
5757
.withColumn("prediction", col("prediction").cast(ArrayType(dt)))
5858
assert(expectedTransformed.collect().toSet.equals(
5959
transformed.collect().toSet))
@@ -79,7 +79,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
7979
(1, Array("1", "2", "3", "5")),
8080
(2, Array("1", "2", "3", "4")),
8181
(3, null.asInstanceOf[Array[String]])
82-
)).toDF("id", "features")
82+
)).toDF("id", "items")
8383
val model = new FPGrowth().setMinSupport(0.7).fit(dataset)
8484
val prediction = model.transform(df)
8585
assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty)
@@ -108,11 +108,11 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
108108
val dataset = spark.createDataFrame(Seq(
109109
Array("1", "3"),
110110
Array("2", "3")
111-
).map(Tuple1(_))).toDF("features")
111+
).map(Tuple1(_))).toDF("items")
112112
val model = new FPGrowth().fit(dataset)
113113

114114
val prediction = model.transform(
115-
spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features")
115+
spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items")
116116
).first().getAs[Seq[String]]("prediction")
117117

118118
assert(prediction === Seq("3"))
@@ -127,7 +127,7 @@ object FPGrowthSuite {
127127
(0, Array("1", "2")),
128128
(0, Array("1", "2")),
129129
(0, Array("1", "3"))
130-
)).toDF("id", "features")
130+
)).toDF("id", "items")
131131
}
132132

133133
/**

0 commit comments

Comments
 (0)