diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
index 02168fee16db..41716c621ca9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
@@ -18,6 +18,8 @@
package org.apache.spark.ml.fpm
import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.col
@@ -29,13 +31,97 @@ import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType}
* The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
* Efficiently by Prefix-Projected Pattern Growth
* (see here).
+ * This class is not yet an Estimator/Transformer, use `findFrequentSequentialPatterns` method to
+ * run the PrefixSpan algorithm.
*
* @see Sequential Pattern Mining
* (Wikipedia)
*/
@Since("2.4.0")
@Experimental
-object PrefixSpan {
+final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params {
+
+ @Since("2.4.0")
+ def this() = this(Identifiable.randomUID("prefixSpan"))
+
+ /**
+ * Param for the minimal support level (default: `0.1`).
+ * Sequential patterns that appear more than (minSupport * size-of-the-dataset) times are
+ * identified as frequent sequential patterns.
+ * @group param
+ */
+ @Since("2.4.0")
+ val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " +
+ "sequential pattern. Sequential pattern that appears more than " +
+ "(minSupport * size-of-the-dataset)." +
+ "times will be output.", ParamValidators.gtEq(0.0))
+
+ /** @group getParam */
+ @Since("2.4.0")
+ def getMinSupport: Double = $(minSupport)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setMinSupport(value: Double): this.type = set(minSupport, value)
+
+ /**
+ * Param for the maximal pattern length (default: `10`).
+ * @group param
+ */
+ @Since("2.4.0")
+ val maxPatternLength = new IntParam(this, "maxPatternLength",
+ "The maximal length of the sequential pattern.",
+ ParamValidators.gt(0))
+
+ /** @group getParam */
+ @Since("2.4.0")
+ def getMaxPatternLength: Int = $(maxPatternLength)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value)
+
+ /**
+ * Param for the maximum number of items (including delimiters used in the internal storage
+ * format) allowed in a projected database before local processing (default: `32000000`).
+ * If a projected database exceeds this size, another iteration of distributed prefix growth
+ * is run.
+ * @group param
+ */
+ @Since("2.4.0")
+ val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize",
+ "The maximum number of items (including delimiters used in the internal storage format) " +
+ "allowed in a projected database before local processing. If a projected database exceeds " +
+ "this size, another iteration of distributed prefix growth is run.",
+ ParamValidators.gt(0))
+
+ /** @group getParam */
+ @Since("2.4.0")
+ def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value)
+
+ /**
+ * Param for the name of the sequence column in dataset (default "sequence"), rows with
+ * nulls in this column are ignored.
+ * @group param
+ */
+ @Since("2.4.0")
+ val sequenceCol = new Param[String](this, "sequenceCol", "The name of the sequence column in " +
+ "dataset, rows with nulls in this column are ignored.")
+
+ /** @group getParam */
+ @Since("2.4.0")
+ def getSequenceCol: String = $(sequenceCol)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setSequenceCol(value: String): this.type = set(sequenceCol, value)
+
+ setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000,
+ sequenceCol -> "sequence")
/**
* :: Experimental ::
@@ -43,54 +129,39 @@ object PrefixSpan {
*
* @param dataset A dataset or a dataframe containing a sequence column which is
* {{{Seq[Seq[_]]}}} type
- * @param sequenceCol the name of the sequence column in dataset, rows with nulls in this column
- * are ignored
- * @param minSupport the minimal support level of the sequential pattern, any pattern that
- * appears more than (minSupport * size-of-the-dataset) times will be output
- * (recommended value: `0.1`).
- * @param maxPatternLength the maximal length of the sequential pattern
- * (recommended value: `10`).
- * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the
- * internal storage format) allowed in a projected database before
- * local processing. If a projected database exceeds this size, another
- * iteration of distributed prefix growth is run
- * (recommended value: `32000000`).
* @return A `DataFrame` that contains columns of sequence and corresponding frequency.
* The schema of it will be:
* - `sequence: Seq[Seq[T]]` (T is the item type)
* - `freq: Long`
*/
@Since("2.4.0")
- def findFrequentSequentialPatterns(
- dataset: Dataset[_],
- sequenceCol: String,
- minSupport: Double,
- maxPatternLength: Int,
- maxLocalProjDBSize: Long): DataFrame = {
-
- val inputType = dataset.schema(sequenceCol).dataType
+ def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = {
+ val sequenceColParam = $(sequenceCol)
+ val inputType = dataset.schema(sequenceColParam).dataType
require(inputType.isInstanceOf[ArrayType] &&
inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType],
s"The input column must be ArrayType and the array element type must also be ArrayType, " +
s"but got $inputType.")
-
- val data = dataset.select(sequenceCol)
- val sequences = data.where(col(sequenceCol).isNotNull).rdd
+ val data = dataset.select(sequenceColParam)
+ val sequences = data.where(col(sequenceColParam).isNotNull).rdd
.map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray)
val mllibPrefixSpan = new mllibPrefixSpan()
- .setMinSupport(minSupport)
- .setMaxPatternLength(maxPatternLength)
- .setMaxLocalProjDBSize(maxLocalProjDBSize)
+ .setMinSupport($(minSupport))
+ .setMaxPatternLength($(maxPatternLength))
+ .setMaxLocalProjDBSize($(maxLocalProjDBSize))
val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq))
val schema = StructType(Seq(
- StructField("sequence", dataset.schema(sequenceCol).dataType, nullable = false),
+ StructField("sequence", dataset.schema(sequenceColParam).dataType, nullable = false),
StructField("freq", LongType, nullable = false)))
val freqSequences = dataset.sparkSession.createDataFrame(rows, schema)
freqSequences
}
+ @Since("2.4.0")
+ override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra)
+
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala
index 9e538696cbcf..2252151af306 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala
@@ -29,8 +29,11 @@ class PrefixSpanSuite extends MLTest {
test("PrefixSpan projections with multiple partial starts") {
val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence")
- val result = PrefixSpan.findFrequentSequentialPatterns(smallDataset, "sequence",
- minSupport = 1.0, maxPatternLength = 2, maxLocalProjDBSize = 32000000)
+ val result = new PrefixSpan()
+ .setMinSupport(1.0)
+ .setMaxPatternLength(2)
+ .setMaxLocalProjDBSize(32000000)
+ .findFrequentSequentialPatterns(smallDataset)
.as[(Seq[Seq[Int]], Long)].collect()
val expected = Array(
(Seq(Seq(1)), 1L),
@@ -90,8 +93,11 @@ class PrefixSpanSuite extends MLTest {
test("PrefixSpan Integer type, variable-size itemsets") {
val df = smallTestData.toDF("sequence")
- val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
- minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
+ val result = new PrefixSpan()
+ .setMinSupport(0.5)
+ .setMaxPatternLength(5)
+ .setMaxLocalProjDBSize(32000000)
+ .findFrequentSequentialPatterns(df)
.as[(Seq[Seq[Int]], Long)].collect()
compareResults[Int](smallTestDataExpectedResult, result)
@@ -99,8 +105,11 @@ class PrefixSpanSuite extends MLTest {
test("PrefixSpan input row with nulls") {
val df = (smallTestData :+ null).toDF("sequence")
- val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
- minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
+ val result = new PrefixSpan()
+ .setMinSupport(0.5)
+ .setMaxPatternLength(5)
+ .setMaxLocalProjDBSize(32000000)
+ .findFrequentSequentialPatterns(df)
.as[(Seq[Seq[Int]], Long)].collect()
compareResults[Int](smallTestDataExpectedResult, result)
@@ -111,8 +120,11 @@ class PrefixSpanSuite extends MLTest {
val df = smallTestData
.map(seq => seq.map(itemSet => itemSet.map(intToString)))
.toDF("sequence")
- val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
- minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
+ val result = new PrefixSpan()
+ .setMinSupport(0.5)
+ .setMaxPatternLength(5)
+ .setMaxLocalProjDBSize(32000000)
+ .findFrequentSequentialPatterns(df)
.as[(Seq[Seq[String]], Long)].collect()
val expected = smallTestDataExpectedResult.map { case (seq, freq) =>