@@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol}
+import org.apache.spark.ml.param.shared.HasPredictionCol
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules,
   FPGrowth => MLlibFPGrowth}
@@ -37,7 +37,20 @@ import org.apache.spark.sql.types._
 /**
  * Common params for FPGrowth and FPGrowthModel
  */
-private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPredictionCol {
+private[fpm] trait FPGrowthParams extends Params with HasPredictionCol {
+
+  /**
+   * Items column name.
+   * Default: "items"
+   * @group param
+   */
+  @Since("2.2.0")
+  val itemsCol: Param[String] = new Param[String](this, "itemsCol", "items column name")
+  setDefault(itemsCol -> "items")
+
+  /** @group getParam */
+  @Since("2.2.0")
+  def getItemsCol: String = $(itemsCol)
 
   /**
    * Minimal support level of the frequent pattern. [0.0, 1.0]. Any pattern that appears
@@ -91,10 +104,10 @@ private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPredictionCol {
   */
  @Since("2.2.0")
  protected def validateAndTransformSchema(schema: StructType): StructType = {
-    val inputType = schema($(featuresCol)).dataType
+    val inputType = schema($(itemsCol)).dataType
     require(inputType.isInstanceOf[ArrayType],
       s"The input column must be ArrayType, but got $inputType.")
-    SchemaUtils.appendColumn(schema, $(predictionCol), schema($(featuresCol)).dataType)
+    SchemaUtils.appendColumn(schema, $(predictionCol), schema($(itemsCol)).dataType)
   }
 }
 
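The net effect of this hunk: the estimator now reads transactions from the new `itemsCol` param (default `"items"`) instead of the shared `featuresCol`, and `validateAndTransformSchema` rejects any non-`ArrayType` input column. A minimal usage sketch, assuming a live `SparkSession` bound to `spark`:

```scala
import org.apache.spark.ml.fpm.FPGrowth

// Transactions live in an ArrayType column. Naming it "items" matches the
// new default of itemsCol, so the setItemsCol call below is only for clarity.
val dataset = spark.createDataFrame(Seq(
  (0, Seq("a", "b", "c")),
  (1, Seq("a", "b")),
  (2, Seq("b", "c"))
)).toDF("id", "items")

val fpgrowth = new FPGrowth()
  .setItemsCol("items") // replaces the removed setFeaturesCol
  .setMinSupport(0.5)
```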
@@ -133,7 +146,7 @@ class FPGrowth @Since("2.2.0") (
 
   /** @group setParam */
   @Since("2.2.0")
-  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+  def setItemsCol(value: String): this.type = set(itemsCol, value)
 
   /** @group setParam */
   @Since("2.2.0")
@@ -146,8 +159,8 @@ class FPGrowth @Since("2.2.0") (
   }
 
   private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = {
-    val data = dataset.select($(featuresCol))
-    val items = data.where(col($(featuresCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray)
+    val data = dataset.select($(itemsCol))
+    val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray)
     val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport))
     if (isSet(numPartitions)) {
       mllibFP.setNumPartitions($(numPartitions))
@@ -156,7 +169,7 @@ class FPGrowth @Since("2.2.0") (
     val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq))
 
     val schema = StructType(Seq(
-      StructField("items", dataset.schema($(featuresCol)).dataType, nullable = false),
+      StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false),
       StructField("freq", LongType, nullable = false)))
     val frequentItems = dataset.sparkSession.createDataFrame(rows, schema)
     copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this)
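Continuing the sketch above: `genericFit` builds the `items`/`freq` schema from the (renamed) input column's `dataType`, and the fitted model surfaces those rows through `freqItemsets`:

```scala
val model = fpgrowth.fit(dataset)

// A DataFrame with an "items" column (same ArrayType as the input column)
// and a "freq" column (LongType), per the schema assembled in genericFit.
model.freqItemsets.show()
```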
@@ -198,7 +211,7 @@ class FPGrowthModel private[ml] (
 
   /** @group setParam */
   @Since("2.2.0")
-  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+  def setItemsCol(value: String): this.type = set(itemsCol, value)
 
   /** @group setParam */
   @Since("2.2.0")
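Because the model mixes in the same params trait, a fitted (or loaded) model can be re-pointed at a differently named column before `transform`. A one-line sketch; the column name `"basket"` is hypothetical and must still be `ArrayType`:

```scala
// Hypothetical column name; itemsCol on the model behaves like the estimator's.
model.setItemsCol("basket")
```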
@@ -235,7 +248,7 @@ class FPGrowthModel private[ml] (
       .collect().asInstanceOf[Array[(Seq[Any], Seq[Any])]]
     val brRules = dataset.sparkSession.sparkContext.broadcast(rules)
 
-    val dt = dataset.schema($(featuresCol)).dataType
+    val dt = dataset.schema($(itemsCol)).dataType
     // For each rule, examine the input items and summarize the consequents
     val predictUDF = udf((items: Seq[_]) => {
       if (items != null) {
@@ -249,7 +262,7 @@ class FPGrowthModel private[ml] (
     } else {
       Seq.empty
     }}, dt)
-    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+    dataset.withColumn($(predictionCol), predictUDF(col($(itemsCol))))
   }
 
   @Since("2.2.0")
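Finally, `transform` appends `predictionCol` by running the broadcast association rules over each row's items: for every rule whose antecedent is contained in the row, the consequent items not already present are collected. A sketch reusing `model` and `dataset` from the earlier snippets:

```scala
// Default prediction column name is "prediction"; rows with null items
// map to Seq.empty, per the UDF above.
model
  .setItemsCol("items") // point back at the demo column
  .transform(dataset)
  .show(truncate = false)
```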