diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
index 41716c621ca9..bd1c1a888520 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
@@ -53,7 +53,7 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
@Since("2.4.0")
val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " +
"sequential pattern. Sequential pattern that appears more than " +
- "(minSupport * size-of-the-dataset)." +
+ "(minSupport * size-of-the-dataset) " +
"times will be output.", ParamValidators.gtEq(0.0))
/** @group getParam */
@@ -128,10 +128,10 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
* Finds the complete set of frequent sequential patterns in the input sequences of itemsets.
*
* @param dataset A dataset or a dataframe containing a sequence column which is
- * {{{Seq[Seq[_]]}}} type
+ * {{{ArrayType(ArrayType(T))}}} type, T is the item type for the input dataset.
* @return A `DataFrame` that contains columns of sequence and corresponding frequency.
* The schema of it will be:
- * - `sequence: Seq[Seq[T]]` (T is the item type)
+ * - `sequence: ArrayType(ArrayType(T))` (T is the item type)
* - `freq: Long`
*/
@Since("2.4.0")
diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py
index b8dafd49d354..fd19fd96c4df 100644
--- a/python/pyspark/ml/fpm.py
+++ b/python/pyspark/ml/fpm.py
@@ -16,8 +16,9 @@
#
from pyspark import keyword_only, since
+from pyspark.sql import DataFrame
from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, _jvm
from pyspark.ml.param.shared import *
__all__ = ["FPGrowth", "FPGrowthModel"]
@@ -243,3 +244,104 @@ def setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items",
def _create_model(self, java_model):
return FPGrowthModel(java_model)
+
+
+class PrefixSpan(JavaParams):
+ """
+ .. note:: Experimental
+
+ A parallel PrefixSpan algorithm to mine frequent sequential patterns.
+ The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
+ Efficiently by Prefix-Projected Pattern Growth
+ (see here).
+ This class is not yet an Estimator/Transformer, use :py:func:`findFrequentSequentialPatterns`
+ method to run the PrefixSpan algorithm.
+
+ @see Sequential Pattern Mining
+ (Wikipedia)
+ .. versionadded:: 2.4.0
+
+ """
+
+ minSupport = Param(Params._dummy(), "minSupport", "The minimal support level of the " +
+ "sequential pattern. Sequential pattern that appears more than " +
+ "(minSupport * size-of-the-dataset) times will be output. Must be >= 0.",
+ typeConverter=TypeConverters.toFloat)
+
+ maxPatternLength = Param(Params._dummy(), "maxPatternLength",
+ "The maximal length of the sequential pattern. Must be > 0.",
+ typeConverter=TypeConverters.toInt)
+
+ maxLocalProjDBSize = Param(Params._dummy(), "maxLocalProjDBSize",
+ "The maximum number of items (including delimiters used in the " +
+ "internal storage format) allowed in a projected database before " +
+ "local processing. If a projected database exceeds this size, " +
+ "another iteration of distributed prefix growth is run. " +
+ "Must be > 0.",
+ typeConverter=TypeConverters.toInt)
+
+ sequenceCol = Param(Params._dummy(), "sequenceCol", "The name of the sequence column in " +
+ "dataset, rows with nulls in this column are ignored.",
+ typeConverter=TypeConverters.toString)
+
+ @keyword_only
+ def __init__(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000,
+ sequenceCol="sequence"):
+ """
+ __init__(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \
+ sequenceCol="sequence")
+ """
+ super(PrefixSpan, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.PrefixSpan", self.uid)
+ self._setDefault(minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000,
+ sequenceCol="sequence")
+ kwargs = self._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ @since("2.4.0")
+ def setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000,
+ sequenceCol="sequence"):
+ """
+ setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \
+ sequenceCol="sequence")
+ """
+ kwargs = self._input_kwargs
+ return self._set(**kwargs)
+
+ @since("2.4.0")
+ def findFrequentSequentialPatterns(self, dataset):
+ """
+ .. note:: Experimental
+ Finds the complete set of frequent sequential patterns in the input sequences of itemsets.
+
+ :param dataset: A dataframe containing a sequence column which is
+ `ArrayType(ArrayType(T))` type, T is the item type for the input dataset.
+ :return: A `DataFrame` that contains columns of sequence and corresponding frequency.
+ The schema of it will be:
+ - `sequence: ArrayType(ArrayType(T))` (T is the item type)
+ - `freq: Long`
+
+ >>> from pyspark.ml.fpm import PrefixSpan
+ >>> from pyspark.sql import Row
+ >>> df = sc.parallelize([Row(sequence=[[1, 2], [3]]),
+ ... Row(sequence=[[1], [3, 2], [1, 2]]),
+ ... Row(sequence=[[1, 2], [5]]),
+ ... Row(sequence=[[6]])]).toDF()
+ >>> prefixSpan = PrefixSpan(minSupport=0.5, maxPatternLength=5)
+ >>> prefixSpan.findFrequentSequentialPatterns(df).sort("sequence").show(truncate=False)
+ +----------+----+
+ |sequence |freq|
+ +----------+----+
+ |[[1]] |3 |
+ |[[1], [3]]|2 |
+ |[[1, 2]] |3 |
+ |[[2]] |3 |
+ |[[3]] |2 |
+ +----------+----+
+
+ .. versionadded:: 2.4.0
+ """
+ self._transfer_params_to_java()
+ jdf = self._java_obj.findFrequentSequentialPatterns(dataset._jdf)
+ return DataFrame(jdf, dataset.sql_ctx)