1818package org .apache .spark .ml .fpm
1919
2020import org .apache .spark .annotation .{Experimental , Since }
21+ import org .apache .spark .ml .param ._
22+ import org .apache .spark .ml .util .Identifiable
2123import org .apache .spark .mllib .fpm .{PrefixSpan => mllibPrefixSpan }
2224import org .apache .spark .sql .{DataFrame , Dataset , Row }
2325import org .apache .spark .sql .functions .col
@@ -29,68 +31,137 @@ import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType}
2931 * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
3032 * Efficiently by Prefix-Projected Pattern Growth
3133 * (see <a href="http://doi.org/10.1109/ICDE.2001.914830">here</a>).
34+ * This class is not yet an Estimator/Transformer; use the `findFrequentSequentialPatterns`
35+ * method to run the PrefixSpan algorithm.
3236 *
3337 * @see <a href="https://en.wikipedia.org/wiki/Sequential_Pattern_Mining">Sequential Pattern Mining
3438 * (Wikipedia)</a>
3539 */
@Since("2.4.0")
@Experimental
final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params {

  /** Creates an instance whose `uid` is obtained from `Identifiable.randomUID("prefixSpan")`. */
  @Since("2.4.0")
  def this() = this(Identifiable.randomUID("prefixSpan"))

  /**
   * Param for the minimal support level (default: `0.1`).
   * Sequential patterns that appear more than (minSupport * size-of-the-dataset) times are
   * identified as frequent sequential patterns.
   * Must be greater than or equal to `0.0`.
   * @group param
   */
  @Since("2.4.0")
  val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " +
    "sequential pattern. Sequential pattern that appears more than " +
    "(minSupport * size-of-the-dataset) " +
    "times will be output.", ParamValidators.gtEq(0.0))

  /** @group getParam */
  @Since("2.4.0")
  def getMinSupport: Double = $(minSupport)

  /** @group setParam */
  @Since("2.4.0")
  def setMinSupport(value: Double): this.type = set(minSupport, value)

  /**
   * Param for the maximal pattern length (default: `10`).
   * Must be greater than `0`.
   * @group param
   */
  @Since("2.4.0")
  val maxPatternLength = new IntParam(this, "maxPatternLength",
    "The maximal length of the sequential pattern.",
    ParamValidators.gt(0))

  /** @group getParam */
  @Since("2.4.0")
  def getMaxPatternLength: Int = $(maxPatternLength)

  /** @group setParam */
  @Since("2.4.0")
  def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value)

  /**
   * Param for the maximum number of items (including delimiters used in the internal storage
   * format) allowed in a projected database before local processing (default: `32000000`).
   * If a projected database exceeds this size, another iteration of distributed prefix growth
   * is run.
   * Must be greater than `0`.
   * @group param
   */
  @Since("2.4.0")
  val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize",
    "The maximum number of items (including delimiters used in the internal storage format) " +
    "allowed in a projected database before local processing. If a projected database exceeds " +
    "this size, another iteration of distributed prefix growth is run.",
    ParamValidators.gt(0))

  /** @group getParam */
  @Since("2.4.0")
  def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize)

  /** @group setParam */
  @Since("2.4.0")
  def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value)

  /**
   * Param for the name of the sequence column in dataset (default "sequence"), rows with
   * nulls in this column are ignored.
   * @group param
   */
  @Since("2.4.0")
  val sequenceCol = new Param[String](this, "sequenceCol", "The name of the sequence column in " +
    "dataset, rows with nulls in this column are ignored.")

  /** @group getParam */
  @Since("2.4.0")
  def getSequenceCol: String = $(sequenceCol)

  /** @group setParam */
  @Since("2.4.0")
  def setSequenceCol(value: String): this.type = set(sequenceCol, value)

  setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000,
    sequenceCol -> "sequence")

  /**
   * :: Experimental ::
   * Finds the complete set of frequent sequential patterns in the input sequences of itemsets.
   *
   * The sequence column name, minimal support, maximal pattern length and maximal local
   * projected database size are read from this instance's params (`sequenceCol`, `minSupport`,
   * `maxPatternLength`, `maxLocalProjDBSize`).
   *
   * @param dataset A dataset or a dataframe containing a sequence column which is
   *                {{{Seq[Seq[_]]}}} type
   * @return A `DataFrame` that contains columns of sequence and corresponding frequency.
   *         The schema of it will be:
   *          - `sequence: Seq[Seq[T]]` (T is the item type)
   *          - `freq: Long`
   * @throws IllegalArgumentException if the sequence column is not an array of arrays
   */
  @Since("2.4.0")
  def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = {
    val sequenceColParam = $(sequenceCol)
    val inputType = dataset.schema(sequenceColParam).dataType

    // The input column must be an array whose elements are themselves arrays.
    val isNestedArray = inputType match {
      case ArrayType(_: ArrayType, _) => true
      case _ => false
    }
    require(isNestedArray,
      "The input column must be ArrayType and the array element type must also be ArrayType, " +
      s"but got $inputType.")

    // Rows with a null sequence are dropped before conversion to the RDD form mllib expects.
    val sequences = dataset.select(sequenceColParam)
      .where(col(sequenceColParam).isNotNull)
      .rdd
      .map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray)

    val miner = new mllibPrefixSpan()
      .setMinSupport($(minSupport))
      .setMaxPatternLength($(maxPatternLength))
      .setMaxLocalProjDBSize($(maxLocalProjDBSize))

    val rows = miner.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq))

    // The output "sequence" column reuses the data type of the input sequence column.
    val schema = StructType(Seq(
      StructField("sequence", inputType, nullable = false),
      StructField("freq", LongType, nullable = false)))

    dataset.sparkSession.createDataFrame(rows, schema)
  }

  @Since("2.4.0")
  override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra)

}
0 commit comments