diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala index 41716c621ca9..bd1c1a888520 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -53,7 +53,7 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params @Since("2.4.0") val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " + "sequential pattern. Sequential pattern that appears more than " + - "(minSupport * size-of-the-dataset)." + + "(minSupport * size-of-the-dataset) " + "times will be output.", ParamValidators.gtEq(0.0)) /** @group getParam */ @@ -128,10 +128,10 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params * Finds the complete set of frequent sequential patterns in the input sequences of itemsets. * * @param dataset A dataset or a dataframe containing a sequence column which is - * {{{Seq[Seq[_]]}}} type + * {{{ArrayType(ArrayType(T))}}} type, T is the item type for the input dataset. * @return A `DataFrame` that contains columns of sequence and corresponding frequency. * The schema of it will be: - * - `sequence: Seq[Seq[T]]` (T is the item type) + * - `sequence: ArrayType(ArrayType(T))` (T is the item type) * - `freq: Long` */ @Since("2.4.0") diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index b8dafd49d354..fd19fd96c4df 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -16,8 +16,9 @@ # from pyspark import keyword_only, since +from pyspark.sql import DataFrame from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, _jvm from pyspark.ml.param.shared import * __all__ = ["FPGrowth", "FPGrowthModel"] @@ -243,3 +244,104 @@ def setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", def _create_model(self, java_model): return FPGrowthModel(java_model) + + +class PrefixSpan(JavaParams): + """ + .. note:: Experimental + + A parallel PrefixSpan algorithm to mine frequent sequential patterns. + The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns + Efficiently by Prefix-Projected Pattern Growth + (see here). + This class is not yet an Estimator/Transformer, use :py:func:`findFrequentSequentialPatterns` + method to run the PrefixSpan algorithm. + + @see Sequential Pattern Mining + (Wikipedia) + .. versionadded:: 2.4.0 + + """ + + minSupport = Param(Params._dummy(), "minSupport", "The minimal support level of the " + + "sequential pattern. Sequential pattern that appears more than " + + "(minSupport * size-of-the-dataset) times will be output. Must be >= 0.", + typeConverter=TypeConverters.toFloat) + + maxPatternLength = Param(Params._dummy(), "maxPatternLength", + "The maximal length of the sequential pattern. Must be > 0.", + typeConverter=TypeConverters.toInt) + + maxLocalProjDBSize = Param(Params._dummy(), "maxLocalProjDBSize", + "The maximum number of items (including delimiters used in the " + + "internal storage format) allowed in a projected database before " + + "local processing. If a projected database exceeds this size, " + + "another iteration of distributed prefix growth is run. " + + "Must be > 0.", + typeConverter=TypeConverters.toInt) + + sequenceCol = Param(Params._dummy(), "sequenceCol", "The name of the sequence column in " + + "dataset, rows with nulls in this column are ignored.", + typeConverter=TypeConverters.toString) + + @keyword_only + def __init__(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence"): + """ + __init__(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \ + sequenceCol="sequence") + """ + super(PrefixSpan, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.PrefixSpan", self.uid) + self._setDefault(minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence") + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.4.0") + def setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence"): + """ + setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \ + sequenceCol="sequence") + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @since("2.4.0") + def findFrequentSequentialPatterns(self, dataset): + """ + .. note:: Experimental + Finds the complete set of frequent sequential patterns in the input sequences of itemsets. + + :param dataset: A dataframe containing a sequence column which is + `ArrayType(ArrayType(T))` type, T is the item type for the input dataset. + :return: A `DataFrame` that contains columns of sequence and corresponding frequency. + The schema of it will be: + - `sequence: ArrayType(ArrayType(T))` (T is the item type) + - `freq: Long` + + >>> from pyspark.ml.fpm import PrefixSpan + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(sequence=[[1, 2], [3]]), + ... Row(sequence=[[1], [3, 2], [1, 2]]), + ... Row(sequence=[[1, 2], [5]]), + ... Row(sequence=[[6]])]).toDF() + >>> prefixSpan = PrefixSpan(minSupport=0.5, maxPatternLength=5) + >>> prefixSpan.findFrequentSequentialPatterns(df).sort("sequence").show(truncate=False) + +----------+----+ + |sequence |freq| + +----------+----+ + |[[1]] |3 | + |[[1], [3]]|2 | + |[[1, 2]] |3 | + |[[2]] |3 | + |[[3]] |2 | + +----------+----+ + + .. versionadded:: 2.4.0 + """ + self._transfer_params_to_java() + jdf = self._java_obj.findFrequentSequentialPatterns(dataset._jdf) + return DataFrame(jdf, dataset.sql_ctx)