From 60efd0520a3af52995c2d6b1a2abaeebe658bb32 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 14 Jun 2016 14:27:21 +0800 Subject: [PATCH 1/3] add support for association rule --- .../org/apache/spark/mllib/fpm/AssociationRules.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala index 9a63cc29dacb..391fc76b2f9f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala @@ -120,6 +120,13 @@ object AssociationRules { @Since("1.5.0") def confidence: Double = freqUnion.toDouble / freqAntecedent + /** + * Returns the support of the rule. Current implementation would return the number of + * co-occurrence of antecedent and consequent. + */ + @Since("2.1.0") + def support: Double = freqUnion.toDouble + require(antecedent.toSet.intersect(consequent.toSet).isEmpty, { val sharedItems = antecedent.toSet.intersect(consequent.toSet) s"A valid association rule must have disjoint antecedent and " + From 8b166761024c1b5bed9f90aa8f550eb2103b9b64 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 23 Jun 2016 17:30:47 -0400 Subject: [PATCH 2/3] add data size --- .../mllib/JavaAssociationRulesExample.java | 2 +- .../mllib/AssociationRulesExample.scala | 2 +- .../api/python/FPGrowthModelWrapper.scala | 2 +- .../spark/mllib/fpm/AssociationRules.scala | 13 +++++++------ .../org/apache/spark/mllib/fpm/FPGrowth.scala | 19 +++++++++++-------- .../mllib/fpm/JavaAssociationRulesSuite.java | 2 +- .../mllib/fpm/AssociationRulesSuite.scala | 4 ++-- 7 files changed, 24 insertions(+), 20 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java index 189560e3fe1f..1e622f153c98 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java @@ -45,7 +45,7 @@ public static void main(String[] args) { AssociationRules arules = new AssociationRules() .setMinConfidence(0.8); - JavaRDD> results = arules.run(freqItemsets); + JavaRDD> results = arules.run(freqItemsets, 50L); for (AssociationRules.Rule rule : results.collect()) { System.out.println( diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala index 11e18c9f040b..23633bfff2d2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala @@ -39,7 +39,7 @@ object AssociationRulesExample { val ar = new AssociationRules() .setMinConfidence(0.8) - val results = ar.run(freqItemsets) + val results = ar.run(freqItemsets, 50L) results.collect().foreach { rule => println("[" + rule.antecedent.mkString(",") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala index e6d1dceebed4..fb0034cfa03e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala @@ -24,7 +24,7 @@ import org.apache.spark.rdd.RDD * A Wrapper of FPGrowthModel to provide helper method for Python */ private[python] class FPGrowthModelWrapper(model: FPGrowthModel[Any]) - extends FPGrowthModel(model.freqItemsets) { + extends FPGrowthModel(model.freqItemsets, model.dataSize) { def getFreqItemsets: RDD[Array[Any]] = { SerDe.fromTuple2RDD(model.freqItemsets.map(x => (x.javaItems, x.freq))) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala index 391fc76b2f9f..c6650e09975c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala @@ -63,7 +63,7 @@ class AssociationRules private[fpm] ( * */ @Since("1.5.0") - def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): RDD[Rule[Item]] = { + def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]], dataSize: Long): RDD[Rule[Item]] = { // For candidate rule X => Y, generate (X, (Y, freq(X union Y))) val candidates = freqItemsets.flatMap { itemset => val items = itemset.items @@ -79,15 +79,15 @@ class AssociationRules private[fpm] ( // Join to get (X, ((Y, freq(X union Y)), freq(X))), generate rules, and filter by confidence candidates.join(freqItemsets.map(x => (x.items.toSeq, x.freq))) .map { case (antecendent, ((consequent, freqUnion), freqAntecedent)) => - new Rule(antecendent.toArray, consequent.toArray, freqUnion, freqAntecedent) + new Rule(antecendent.toArray, consequent.toArray, freqUnion, freqAntecedent, dataSize) }.filter(_.confidence >= minConfidence) } /** Java-friendly version of [[run]]. */ @Since("1.5.0") - def run[Item](freqItemsets: JavaRDD[FreqItemset[Item]]): JavaRDD[Rule[Item]] = { + def run[Item](freqItemsets: JavaRDD[FreqItemset[Item]], dataSize: Long): JavaRDD[Rule[Item]] = { val tag = fakeClassTag[Item] - run(freqItemsets.rdd)(tag) + run(freqItemsets.rdd, dataSize)(tag) } } @@ -111,7 +111,8 @@ object AssociationRules { @Since("1.5.0") val antecedent: Array[Item], @Since("1.5.0") val consequent: Array[Item], freqUnion: Double, - freqAntecedent: Double) extends Serializable { + freqAntecedent: Double, + dataSize: Long) extends Serializable { /** * Returns the confidence of the rule. @@ -125,7 +126,7 @@ object AssociationRules { * co-occurrence of antecedent and consequent. */ @Since("2.1.0") - def support: Double = freqUnion.toDouble + def support: Double = freqUnion.toDouble / dataSize require(antecedent.toSet.intersect(consequent.toSet).isEmpty, { val sharedItems = antecedent.toSet.intersect(consequent.toSet) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 0f7fbe9556c5..c15c536a8a9e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -49,7 +49,8 @@ import org.apache.spark.storage.StorageLevel */ @Since("1.3.0") class FPGrowthModel[Item: ClassTag] @Since("1.3.0") ( - @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]]) + @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]], + @Since("2.0.0") val dataSize: Long) extends Saveable with Serializable { /** * Generates association rules for the [[Item]]s in [[freqItemsets]]. @@ -58,7 +59,7 @@ class FPGrowthModel[Item: ClassTag] @Since("1.3.0") ( @Since("1.5.0") def generateAssociationRules(confidence: Double): RDD[AssociationRules.Rule[Item]] = { val associationRules = new AssociationRules(confidence) - associationRules.run(freqItemsets) + associationRules.run(freqItemsets, dataSize) } /** @@ -102,7 +103,8 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( - ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("dataSize" -> model.dataSize))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) // Get the type of item class @@ -128,19 +130,20 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) - + val dataSize = (metadata \ "dataSize").extract[Long] val freqItemsets = spark.read.parquet(Loader.dataPath(path)) val sample = freqItemsets.select("items").head().get(0) - loadImpl(freqItemsets, sample) + loadImpl(freqItemsets, sample, dataSize) } - def loadImpl[Item: ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = { + def loadImpl[Item: ClassTag](freqItemsets: DataFrame, sample: Item, + dataSize: Long): FPGrowthModel[Item] = { val freqItemsetsRDD = freqItemsets.select("items", "freq").rdd.map { x => val items = x.getAs[Seq[Item]](0).toArray val freq = x.getLong(1) new FreqItemset(items, freq) } - new FPGrowthModel(freqItemsetsRDD) + new FPGrowthModel(freqItemsetsRDD, dataSize) } } } @@ -215,7 +218,7 @@ class FPGrowth private ( val partitioner = new HashPartitioner(numParts) val freqItems = genFreqItems(data, minCount, partitioner) val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner) - new FPGrowthModel(freqItemsets) + new FPGrowthModel(freqItemsets, count) } /** Java-friendly version of [[run]]. */ diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java index 3451e0773759..1e1562e87383 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java @@ -36,6 +36,6 @@ public void runAssociationRules() { new FreqItemset(new String[]{"a", "b"}, 12L) )); - JavaRDD> results = (new AssociationRules()).run(freqItemsets); + JavaRDD> results = (new AssociationRules()).run(freqItemsets, 50L); } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala index dcb1f398b04b..9786049b1d01 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala @@ -38,7 +38,7 @@ class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext { val results1 = ar .setMinConfidence(0.9) - .run(freqItemsets) + .run(freqItemsets, 10L) .collect() /* Verify results using the `R` code: @@ -67,7 +67,7 @@ class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext { val results2 = ar .setMinConfidence(0) - .run(freqItemsets) + .run(freqItemsets, 10L) .collect() /* Verify results using the `R` code: From ed384c7f81c65725a64180b0e7da5267d5173913 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 23 Jun 2016 17:50:12 -0400 Subject: [PATCH 3/3] java style --- .../org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java index 1e1562e87383..e4624bb67410 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java @@ -36,6 +36,7 @@ public void runAssociationRules() { new FreqItemset(new String[]{"a", "b"}, 12L) )); - JavaRDD> results = (new AssociationRules()).run(freqItemsets, 50L); + JavaRDD> results = (new AssociationRules()).run( + freqItemsets, 50L); } }