diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
new file mode 100644
index 000000000000..820320982627
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.fpm
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan}
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * Common params for PrefixSpan and PrefixSpanModel
+ */
+private[fpm] trait PrefixSpanParams extends Params {
+
+ /**
+ * Sequence column name.
+ * Default: "sequence"
+ * @group param
+ */
+ @Since("2.4.0")
+ val sequenceCol: Param[String] = new Param[String](this, "sequenceCol", "sequence column name")
+ setDefault(sequenceCol -> "sequence")
+
+ /** @group getParam */
+ @Since("2.2.0")
+ def getSequenceCol: String = $(sequenceCol)
+
+ /**
+   * Minimal support level of the sequential pattern. Any pattern that appears
+   * at least ceil(minSupport * size-of-the-dataset) times will be output.
+ * Default: 0.1
+ * @group param
+ */
+ @Since("2.4.0")
+ val minSupport: DoubleParam = new DoubleParam(this, "minSupport",
+ "the minimal support level of a sequential pattern",
+ ParamValidators.inRange(0.0, 1.0))
+ setDefault(minSupport -> 0.1)
+
+ /** @group getParam */
+ @Since("2.4.0")
+ def getMinSupport: Double = $(minSupport)
+
+  /**
+   * The maximal length of the sequential pattern. Any frequent pattern exceeding
+   * this length will not be included in the results.
+   * Default: 10
+ * @group param
+ */
+ @Since("2.4.0")
+ val maxPatternLength: IntParam = new IntParam(this, "maxPatternLength",
+ "the maximal length of the sequential pattern",
+ ParamValidators.inRange(1, Int.MaxValue))
+ setDefault(maxPatternLength -> 10)
+
+ /**
+ * The maximum number of items (including delimiters used in the internal
+ * storage format) allowed in a projected database before local
+ * processing. If a projected database exceeds this size, another
+ * iteration of distributed prefix growth is run.
+ * Default: 32000000L
+ * @group param
+ */
+ @Since("2.4.0")
+ val maxLocalProjDBSize: LongParam = new LongParam(this, "maxLocalProjDBSize",
+ "The maximum number of items (including delimiters used in the internal " +
+ "storage format) allowed in a projected database before local processin")
+ setDefault(maxLocalProjDBSize -> 32000000L)
+
+ /**
+ * Validates and transforms the input schema.
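+   * The sequence column must be of type `ArrayType(ArrayType(...))`; for example, a
+   * Spark SQL column of type `array<array<int>>` or `array<array<string>>` is accepted.
+   *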
+ * @param schema input schema
+ * @return output schema
+ */
+ @Since("2.2.0")
+ protected def validateAndTransformSchema(schema: StructType): StructType = {
+ val inputType = schema($(sequenceCol)).dataType
+ require(inputType.isInstanceOf[ArrayType] &&
+ inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType],
+ s"The input column must be ArrayType[ArrayType], but got $inputType.")
+ schema
+ }
+}
+
+/**
+ * :: Experimental ::
+ * A parallel PrefixSpan algorithm to mine frequent sequential patterns.
+ * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
+ * Efficiently by Prefix-Projected Pattern Growth
+ * (see <a href="https://doi.org/10.1109/ICDE.2001.914830">here</a>).
+ *
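+ * A minimal usage sketch (assumes `spark.implicits._` is in scope; the data and
+ * parameter values are illustrative only):
+ * {{{
+ *   val df = Seq(Seq(Seq(1, 2), Seq(3))).toDF("sequence")
+ *   val model = new PrefixSpan().setMinSupport(0.5).fit(df)
+ *   model.freqSequences.show()
+ * }}}
+ *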
+ * @see <a href="https://en.wikipedia.org/wiki/Sequential_pattern_mining">Sequential pattern mining
+ * (Wikipedia)</a>
+ */
+@Since("2.4.0")
+@Experimental
+class PrefixSpan @Since("2.4.0") (@Since("2.4.0") override val uid: String)
+ extends Estimator[PrefixSpanModel] with PrefixSpanParams with DefaultParamsWritable {
+
+ @Since("2.4.0")
+ def this() = this(Identifiable.randomUID("prefixspan"))
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setMinSupport(value: Double): this.type = set(minSupport, value)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value)
+
+ @Since("2.4.0")
+ override def fit(dataset: Dataset[_]): PrefixSpanModel = {
+ transformSchema(dataset.schema, logging = true)
+ genericFit(dataset)
+ }
+
+  private def genericFit(dataset: Dataset[_]): PrefixSpanModel = {
+ val handlePersistence = dataset.storageLevel == StorageLevel.NONE
+
+ val data = dataset.select($(sequenceCol))
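+    // Drop rows with null sequences, then convert each row to the Array[Array[Any]]
+    // representation expected by the mllib PrefixSpan implementation.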
+ val sequences = data.where(col($(sequenceCol)).isNotNull).rdd
+ .map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray)
+ val mllibPrefixSpan = new mllibPrefixSpan()
+ .setMinSupport($(minSupport))
+ .setMaxPatternLength($(maxPatternLength))
+ .setMaxLocalProjDBSize($(maxLocalProjDBSize))
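+    // Cache the converted RDD only if the caller has not persisted the input dataset;
+    // it is unpersisted again after the frequent sequences are computed.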
+ if (handlePersistence) {
+ sequences.persist(StorageLevel.MEMORY_AND_DISK)
+ }
+ val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq))
+ val schema = StructType(Seq(
+ StructField("sequence", dataset.schema($(sequenceCol)).dataType, nullable = false),
+ StructField("freq", LongType, nullable = false)))
+ val freqSequences = dataset.sparkSession.createDataFrame(rows, schema)
+
+ if (handlePersistence) {
+ sequences.unpersist()
+ }
+
+ copyValues(new PrefixSpanModel(uid, freqSequences)).setParent(this)
+ }
+
+ @Since("2.4.0")
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
+ }
+
+ @Since("2.4.0")
+ override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra)
+}
+
+@Since("2.4.0")
+object PrefixSpan extends DefaultParamsReadable[PrefixSpan] {
+
+ @Since("2.4.0")
+ override def load(path: String): PrefixSpan = super.load(path)
+}
+
+/**
+ * :: Experimental ::
+ * Model fitted by PrefixSpan.
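+ *
+ * @param freqSequences frequent sequential patterns mined from the training data,
+ *                      stored as a DataFrame with columns "sequence" and "freq"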
+ */
+@Since("2.4.0")
+@Experimental
+class PrefixSpanModel private[ml] (
+ @Since("2.4.0") override val uid: String,
+ @Since("2.4.0") @transient val freqSequences: DataFrame)
+ extends Model[PrefixSpanModel] with PrefixSpanParams with MLWritable {
+
+ @Since("2.4.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
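+    // No per-row transformation is defined yet: validate the schema and return the
+    // input unchanged. The mined patterns are exposed through `freqSequences`.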
+ dataset.toDF()
+ }
+
+ @Since("2.4.0")
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
+ }
+
+ @Since("2.4.0")
+ override def copy(extra: ParamMap): PrefixSpanModel = {
+ val copied = new PrefixSpanModel(uid, freqSequences)
+ copyValues(copied, extra).setParent(this.parent)
+ }
+
+ @Since("2.4.0")
+ override def write: MLWriter = new PrefixSpanModel.PrefixSpanModelWriter(this)
+}
+
+@Since("2.4.0")
+object PrefixSpanModel extends MLReadable[PrefixSpanModel] {
+
+ @Since("2.4.0")
+ override def read: MLReader[PrefixSpanModel] = new PrefixSpanModelReader
+
+ @Since("2.4.0")
+ override def load(path: String): PrefixSpanModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[PrefixSpanModel]] */
+ private[PrefixSpanModel]
+ class PrefixSpanModelWriter(instance: PrefixSpanModel) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val dataPath = new Path(path, "data").toString
+ instance.freqSequences.write.parquet(dataPath)
+ }
+ }
+
+ private class PrefixSpanModelReader extends MLReader[PrefixSpanModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[PrefixSpanModel].getName
+
+ override def load(path: String): PrefixSpanModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val dataPath = new Path(path, "data").toString
+      val frequentSequences = sparkSession.read.parquet(dataPath)
+      val model = new PrefixSpanModel(metadata.uid, frequentSequences)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala
new file mode 100644
index 000000000000..23fae3ae9758
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.fpm
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.DataFrame
+
+class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+ import testImplicits._
+
+  @transient var smallDataset: DataFrame = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence")
+  }
+
+ test("PrefixSpan parameter check") {
+ val prefixSpan = new PrefixSpan()
+ val model = prefixSpan.fit(smallDataset)
+ ParamsSuite.checkParams(prefixSpan)
+ ParamsSuite.checkParams(model)
+ }
+
+ test("PrefixSpan projections with multiple partial starts") {
+ val prefixSpan = new PrefixSpan()
+ .setMinSupport(1.0)
+ .setMaxPatternLength(2)
+ val model = prefixSpan.fit(smallDataset)
+ val result = model.freqSequences.as[(Seq[Seq[Int]], Long)].collect()
+ val expected = Array(
+ (Seq(Seq(1)), 1L),
+ (Seq(Seq(1, 2)), 1L),
+ (Seq(Seq(1), Seq(1)), 1L),
+ (Seq(Seq(1), Seq(2)), 1L),
+ (Seq(Seq(1), Seq(3)), 1L),
+ (Seq(Seq(1, 3)), 1L),
+ (Seq(Seq(2)), 1L),
+ (Seq(Seq(2, 3)), 1L),
+ (Seq(Seq(2), Seq(1)), 1L),
+ (Seq(Seq(2), Seq(2)), 1L),
+ (Seq(Seq(2), Seq(3)), 1L),
+ (Seq(Seq(3)), 1L))
+ compareResults[Int](expected, result)
+ }
+
+ val testData = Seq(
+ Seq(Seq(1, 2), Seq(3)),
+ Seq(Seq(1), Seq(3, 2), Seq(1, 2)),
+ Seq(Seq(1, 2), Seq(5)),
+ Seq(Seq(6)))
+
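+  // With minSupport = 0.5 over the 4 sequences above, a pattern is frequent if it
+  // occurs in at least ceil(0.5 * 4) = 2 distinct sequences; "freq" is that count.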
+ val expectedResult = Array(
+ (Seq(Seq(1)), 3L),
+ (Seq(Seq(2)), 3L),
+ (Seq(Seq(3)), 2L),
+ (Seq(Seq(1), Seq(3)), 2L),
+ (Seq(Seq(1, 2)), 3L)
+ )
+
+ test("PrefixSpan Integer type, variable-size itemsets") {
+ val df = testData.toDF("sequence")
+ val prefixSpan = new PrefixSpan()
+ .setMinSupport(0.5)
+ .setMaxPatternLength(5)
+ val model = prefixSpan.fit(df)
+ val result = model.freqSequences.as[(Seq[Seq[Int]], Long)].collect()
+
+ compareResults[Int](expectedResult, result)
+ }
+
+ test("PrefixSpan String type, variable-size itemsets") {
+ val intToString = (1 to 6).zip(Seq("a", "b", "c", "d", "e", "f")).toMap
+ val df = testData
+ .map(seq => seq.map(itemSet => itemSet.map(intToString)))
+ .toDF("sequence")
+ val prefixSpan = new PrefixSpan()
+ .setMinSupport(0.5)
+ .setMaxPatternLength(5)
+ val model = prefixSpan.fit(df)
+ val result = model.freqSequences.as[(Seq[Seq[String]], Long)].collect()
+ val expected = expectedResult.map { case (seq, freq) =>
+ (seq.map(itemSet => itemSet.map(intToString)), freq)
+ }
+ compareResults[String](expected, result)
+ }
+
+ test("read/write") {
+ def checkModelData(model: PrefixSpanModel, model2: PrefixSpanModel): Unit = {
+ compareResults(
+ model.freqSequences.as[(Seq[Seq[Int]], Long)].collect(),
+ model2.freqSequences.as[(Seq[Seq[Int]], Long)].collect()
+ )
+ }
+ val prefixSpan = new PrefixSpan()
+ testEstimatorAndModelReadWrite(prefixSpan, smallDataset, PrefixSpanSuite.allParamSettings,
+ PrefixSpanSuite.allParamSettings, checkModelData)
+ }
+
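+  // Compares the two result sets ignoring both the order of patterns and the order
+  // of items within each itemset (itemsets are converted to Sets below).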
+ private def compareResults[Item](
+ expectedValue: Array[(Seq[Seq[Item]], Long)],
+ actualValue: Array[(Seq[Seq[Item]], Long)]): Unit = {
+ val expectedSet = expectedValue.map { x =>
+ (x._1.map(_.toSet), x._2)
+ }.toSet
+ val actualSet = actualValue.map { x =>
+ (x._1.map(_.toSet), x._2)
+ }.toSet
+ assert(expectedSet === actualSet)
+ }
+
+}
+
+object PrefixSpanSuite {
+
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allParamSettings: Map[String, Any] = Map(
+ "minSupport" -> 0.2,
+ "maxPatternLength" -> 20,
+ "maxLocalProjDBSize" -> 50000000L
+ )
+
+}