
Commit df12506

WeichenXu123 authored and mengxr committed
[SPARK-20114][ML][FOLLOW-UP] spark.ml parity for sequential pattern mining - PrefixSpan
## What changes were proposed in this pull request? Change `PrefixSpan` into a class with param setter/getters. This address issues mentioned here: #20973 (comment) ## How was this patch tested? UT. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: WeichenXu <weichen.xu@databricks.com> Closes #21393 from WeichenXu123/fix_prefix_span.
1 parent a40ffc6 commit df12506
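
For context, here is a minimal usage sketch of the new setter-based API, mirroring the updated tests in this patch. The toy data and the `spark` session (with its implicits imported for `toDF`) are illustrative assumptions, not part of the change:

```scala
import org.apache.spark.ml.fpm.PrefixSpan

// Assumes an existing SparkSession named `spark` (illustrative only).
import spark.implicits._

// Toy input: a single "sequence" column of type Seq[Seq[Int]] (hypothetical data).
val df = Seq(
  Seq(Seq(1, 2), Seq(3)),
  Seq(Seq(1), Seq(3, 2), Seq(1, 2))
).toDF("sequence")

// Parameters are now configured through setters on a PrefixSpan instance
// rather than passed as arguments to the former object-level method.
val result = new PrefixSpan()
  .setMinSupport(0.5)              // default 0.1
  .setMaxPatternLength(5)          // default 10
  .setMaxLocalProjDBSize(32000000) // default 32000000
  .setSequenceCol("sequence")      // default "sequence"
  .findFrequentSequentialPatterns(df)

// Result schema: sequence: Seq[Seq[T]], freq: Long
result.show()
```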

2 files changed (+119, -36 lines)


mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala

Lines changed: 99 additions & 28 deletions
@@ -18,6 +18,8 @@
 package org.apache.spark.ml.fpm
 
 import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan}
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.col
@@ -29,68 +31,137 @@ import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType}
  * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
  * Efficiently by Prefix-Projected Pattern Growth
  * (see <a href="http://doi.org/10.1109/ICDE.2001.914830">here</a>).
+ * This class is not yet an Estimator/Transformer, use `findFrequentSequentialPatterns` method to
+ * run the PrefixSpan algorithm.
  *
  * @see <a href="https://en.wikipedia.org/wiki/Sequential_Pattern_Mining">Sequential Pattern Mining
  * (Wikipedia)</a>
  */
 @Since("2.4.0")
 @Experimental
-object PrefixSpan {
+final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params {
+
+  @Since("2.4.0")
+  def this() = this(Identifiable.randomUID("prefixSpan"))
+
+  /**
+   * Param for the minimal support level (default: `0.1`).
+   * Sequential patterns that appear more than (minSupport * size-of-the-dataset) times are
+   * identified as frequent sequential patterns.
+   * @group param
+   */
+  @Since("2.4.0")
+  val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " +
+    "sequential pattern. Sequential pattern that appears more than " +
+    "(minSupport * size-of-the-dataset)." +
+    "times will be output.", ParamValidators.gtEq(0.0))
+
+  /** @group getParam */
+  @Since("2.4.0")
+  def getMinSupport: Double = $(minSupport)
+
+  /** @group setParam */
+  @Since("2.4.0")
+  def setMinSupport(value: Double): this.type = set(minSupport, value)
+
+  /**
+   * Param for the maximal pattern length (default: `10`).
+   * @group param
+   */
+  @Since("2.4.0")
+  val maxPatternLength = new IntParam(this, "maxPatternLength",
+    "The maximal length of the sequential pattern.",
+    ParamValidators.gt(0))
+
+  /** @group getParam */
+  @Since("2.4.0")
+  def getMaxPatternLength: Int = $(maxPatternLength)
+
+  /** @group setParam */
+  @Since("2.4.0")
+  def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value)
+
+  /**
+   * Param for the maximum number of items (including delimiters used in the internal storage
+   * format) allowed in a projected database before local processing (default: `32000000`).
+   * If a projected database exceeds this size, another iteration of distributed prefix growth
+   * is run.
+   * @group param
+   */
+  @Since("2.4.0")
+  val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize",
+    "The maximum number of items (including delimiters used in the internal storage format) " +
+    "allowed in a projected database before local processing. If a projected database exceeds " +
+    "this size, another iteration of distributed prefix growth is run.",
+    ParamValidators.gt(0))
+
+  /** @group getParam */
+  @Since("2.4.0")
+  def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize)
+
+  /** @group setParam */
+  @Since("2.4.0")
+  def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value)
+
+  /**
+   * Param for the name of the sequence column in dataset (default "sequence"), rows with
+   * nulls in this column are ignored.
+   * @group param
+   */
+  @Since("2.4.0")
+  val sequenceCol = new Param[String](this, "sequenceCol", "The name of the sequence column in " +
+    "dataset, rows with nulls in this column are ignored.")
+
+  /** @group getParam */
+  @Since("2.4.0")
+  def getSequenceCol: String = $(sequenceCol)
+
+  /** @group setParam */
+  @Since("2.4.0")
+  def setSequenceCol(value: String): this.type = set(sequenceCol, value)
+
+  setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000,
+    sequenceCol -> "sequence")
 
   /**
    * :: Experimental ::
    * Finds the complete set of frequent sequential patterns in the input sequences of itemsets.
    *
    * @param dataset A dataset or a dataframe containing a sequence column which is
    *                {{{Seq[Seq[_]]}}} type
-   * @param sequenceCol the name of the sequence column in dataset, rows with nulls in this column
-   *                    are ignored
-   * @param minSupport the minimal support level of the sequential pattern, any pattern that
-   *                   appears more than (minSupport * size-of-the-dataset) times will be output
-   *                   (recommended value: `0.1`).
-   * @param maxPatternLength the maximal length of the sequential pattern
-   *                         (recommended value: `10`).
-   * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the
-   *                           internal storage format) allowed in a projected database before
-   *                           local processing. If a projected database exceeds this size, another
-   *                           iteration of distributed prefix growth is run
-   *                           (recommended value: `32000000`).
    * @return A `DataFrame` that contains columns of sequence and corresponding frequency.
    *         The schema of it will be:
    *          - `sequence: Seq[Seq[T]]` (T is the item type)
   *          - `freq: Long`
    */
   @Since("2.4.0")
-  def findFrequentSequentialPatterns(
-      dataset: Dataset[_],
-      sequenceCol: String,
-      minSupport: Double,
-      maxPatternLength: Int,
-      maxLocalProjDBSize: Long): DataFrame = {
-
-    val inputType = dataset.schema(sequenceCol).dataType
+  def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = {
+    val sequenceColParam = $(sequenceCol)
+    val inputType = dataset.schema(sequenceColParam).dataType
     require(inputType.isInstanceOf[ArrayType] &&
       inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType],
       s"The input column must be ArrayType and the array element type must also be ArrayType, " +
       s"but got $inputType.")
 
-
-    val data = dataset.select(sequenceCol)
-    val sequences = data.where(col(sequenceCol).isNotNull).rdd
+    val data = dataset.select(sequenceColParam)
+    val sequences = data.where(col(sequenceColParam).isNotNull).rdd
       .map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray)
 
     val mllibPrefixSpan = new mllibPrefixSpan()
-      .setMinSupport(minSupport)
-      .setMaxPatternLength(maxPatternLength)
-      .setMaxLocalProjDBSize(maxLocalProjDBSize)
+      .setMinSupport($(minSupport))
+      .setMaxPatternLength($(maxPatternLength))
+      .setMaxLocalProjDBSize($(maxLocalProjDBSize))
 
     val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq))
     val schema = StructType(Seq(
-      StructField("sequence", dataset.schema(sequenceCol).dataType, nullable = false),
+      StructField("sequence", dataset.schema(sequenceColParam).dataType, nullable = false),
      StructField("freq", LongType, nullable = false)))
     val freqSequences = dataset.sparkSession.createDataFrame(rows, schema)
 
     freqSequences
   }
 
+  @Since("2.4.0")
+  override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra)
+
 }

mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala

Lines changed: 20 additions & 8 deletions
@@ -29,8 +29,11 @@ class PrefixSpanSuite extends MLTest {
 
   test("PrefixSpan projections with multiple partial starts") {
     val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence")
-    val result = PrefixSpan.findFrequentSequentialPatterns(smallDataset, "sequence",
-      minSupport = 1.0, maxPatternLength = 2, maxLocalProjDBSize = 32000000)
+    val result = new PrefixSpan()
+      .setMinSupport(1.0)
+      .setMaxPatternLength(2)
+      .setMaxLocalProjDBSize(32000000)
+      .findFrequentSequentialPatterns(smallDataset)
       .as[(Seq[Seq[Int]], Long)].collect()
     val expected = Array(
       (Seq(Seq(1)), 1L),
@@ -90,17 +93,23 @@ class PrefixSpanSuite extends MLTest {
 
   test("PrefixSpan Integer type, variable-size itemsets") {
     val df = smallTestData.toDF("sequence")
-    val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
-      minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
+    val result = new PrefixSpan()
+      .setMinSupport(0.5)
+      .setMaxPatternLength(5)
+      .setMaxLocalProjDBSize(32000000)
+      .findFrequentSequentialPatterns(df)
       .as[(Seq[Seq[Int]], Long)].collect()
 
     compareResults[Int](smallTestDataExpectedResult, result)
   }
 
   test("PrefixSpan input row with nulls") {
     val df = (smallTestData :+ null).toDF("sequence")
-    val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
-      minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
+    val result = new PrefixSpan()
+      .setMinSupport(0.5)
+      .setMaxPatternLength(5)
+      .setMaxLocalProjDBSize(32000000)
+      .findFrequentSequentialPatterns(df)
       .as[(Seq[Seq[Int]], Long)].collect()
 
     compareResults[Int](smallTestDataExpectedResult, result)
@@ -111,8 +120,11 @@ class PrefixSpanSuite extends MLTest {
     val df = smallTestData
      .map(seq => seq.map(itemSet => itemSet.map(intToString)))
       .toDF("sequence")
-    val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
-      minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
+    val result = new PrefixSpan()
+      .setMinSupport(0.5)
+      .setMaxPatternLength(5)
+      .setMaxLocalProjDBSize(32000000)
+      .findFrequentSequentialPatterns(df)
       .as[(Seq[Seq[String]], Long)].collect()
 
     val expected = smallTestDataExpectedResult.map { case (seq, freq) =>
