Commit ee0d2af

Nick Pentreath authored and committed
[SPARK-20677][MLLIB][ML] Follow-up to ALS recommend-all performance PRs
Small clean ups from #17742 and #17845.

## How was this patch tested?

Existing unit tests.

Author: Nick Pentreath <nickp@za.ibm.com>

Closes #17919 from MLnick/SPARK-20677-als-perf-followup.

(cherry picked from commit 25b4f41)
Signed-off-by: Nick Pentreath <nickp@za.ibm.com>
1 parent b8d37ac commit ee0d2af

File tree (4 files changed: +28 -53 lines)

mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala

mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ private[spark] object BLAS extends Serializable {
   @transient private var _nativeBLAS: NetlibBLAS = _
 
   // For level-1 routines, we use Java implementation.
-  private def f2jBLAS: NetlibBLAS = {
+  private[ml] def f2jBLAS: NetlibBLAS = {
     if (_f2jBLAS == null) {
       _f2jBLAS = new F2jBLAS
     }
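This visibility change (mirrored for the mllib package further down) exposes the pure-Java F2jBLAS instance so ALS can compute dot products without going through native BLAS. As a minimal standalone sketch (not Spark code; the object name and input values are invented for illustration), a call such as BLAS.f2jBLAS.sdot(rank, srcFactor, 1, dstFactor, 1) amounts to:

import com.github.fommil.netlib.F2jBLAS

object F2jDotSketch {
  def main(args: Array[String]): Unit = {
    val f2j = new F2jBLAS  // same class that backs BLAS.f2jBLAS in the diff above
    val rank = 4
    val srcFactor = Array(0.1f, 0.2f, 0.3f, 0.4f)
    val dstFactor = Array(1.0f, 0.5f, 0.25f, 0.125f)
    // sdot(n, x, incx, y, incy): single-precision dot product over n elements
    val score = f2j.sdot(rank, srcFactor, 1, dstFactor, 1)
    println(score)  // 0.1*1.0 + 0.2*0.5 + 0.3*0.25 + 0.4*0.125 = 0.325
  }
}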

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 7 additions & 19 deletions
@@ -35,6 +35,7 @@ import org.apache.spark.{Dependency, Partitioner, ShuffleDependency, SparkContext}
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.linalg.BLAS
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
@@ -363,7 +364,7 @@ class ALSModel private[ml] (
    * relatively efficient, the approach implemented here is significantly more efficient.
    *
    * This approach groups factors into blocks and computes the top-k elements per block,
-   * using a simple dot product (instead of gemm) and an efficient [[BoundedPriorityQueue]].
+   * using dot product and an efficient [[BoundedPriorityQueue]] (instead of gemm).
    * It then computes the global top-k by aggregating the per block top-k elements with
    * a [[TopByKeyAggregator]]. This significantly reduces the size of intermediate and shuffle data.
    * This is the DataFrame equivalent to the approach used in
@@ -393,31 +394,18 @@ class ALSModel private[ml] (
         val m = srcIter.size
         val n = math.min(dstIter.size, num)
         val output = new Array[(Int, Int, Float)](m * n)
-        var j = 0
+        var i = 0
         val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2))
         srcIter.foreach { case (srcId, srcFactor) =>
           dstIter.foreach { case (dstId, dstFactor) =>
-            /*
-             * The below code is equivalent to
-             * `val score = blas.sdot(rank, srcFactor, 1, dstFactor, 1)`
-             * This handwritten version is as or more efficient as BLAS calls in this case.
-             */
-            var score = 0.0f
-            var k = 0
-            while (k < rank) {
-              score += srcFactor(k) * dstFactor(k)
-              k += 1
-            }
+            // We use F2jBLAS which is faster than a call to native BLAS for vector dot product
+            val score = BLAS.f2jBLAS.sdot(rank, srcFactor, 1, dstFactor, 1)
             pq += dstId -> score
           }
-          val pqIter = pq.iterator
-          var i = 0
-          while (i < n) {
-            val (dstId, score) = pqIter.next()
-            output(j + i) = (srcId, dstId, score)
+          pq.foreach { case (dstId, score) =>
+            output(i) = (srcId, dstId, score)
             i += 1
           }
-          j += n
           pq.clear()
         }
         output.toSeq
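The rewritten loop computes, for each source factor in a block, the top-k destination ids with an F2jBLAS dot product and a bounded priority queue, then copies the queue contents into a flat output array indexed by a single counter i. Below is a minimal, self-contained sketch of the same idea (an illustration, not the Spark implementation: a plain scala.collection.mutable.PriorityQueue capped by hand stands in for Spark's internal BoundedPriorityQueue, and plain Seqs stand in for the blocked iterators):

import scala.collection.mutable

import com.github.fommil.netlib.F2jBLAS

object TopKPerBlockSketch {
  private val f2j = new F2jBLAS

  // For one (srcBlock, dstBlock) pair, return (srcId, dstId, score) triples:
  // the top `num` dst ids per src id, scored by dot product of factor vectors.
  def topKPerBlock(
      srcBlock: Seq[(Int, Array[Float])],
      dstBlock: Seq[(Int, Array[Float])],
      rank: Int,
      num: Int): Seq[(Int, Int, Float)] = {
    val n = math.min(dstBlock.size, num)
    val output = mutable.ArrayBuffer.empty[(Int, Int, Float)]
    // Min-heap on score, so the weakest candidate is evicted once capacity is exceeded.
    val pq = mutable.PriorityQueue.empty[(Int, Float)](Ordering.by[(Int, Float), Float](_._2).reverse)
    srcBlock.foreach { case (srcId, srcFactor) =>
      dstBlock.foreach { case (dstId, dstFactor) =>
        val score = f2j.sdot(rank, srcFactor, 1, dstFactor, 1)
        pq.enqueue(dstId -> score)
        if (pq.size > n) pq.dequeue()  // drop the lowest-scoring candidate
      }
      pq.foreach { case (dstId, score) => output += ((srcId, dstId, score)) }
      pq.clear()
    }
    output.toSeq
  }

  def main(args: Array[String]): Unit = {
    val src = Seq(1 -> Array(1.0f, 0.0f))
    val dst = Seq(10 -> Array(0.9f, 0.1f), 11 -> Array(0.2f, 0.8f), 12 -> Array(0.5f, 0.5f))
    // Expect dst ids 10 (score 0.9) and 12 (score 0.5) to be kept for src id 1.
    topKPerBlock(src, dst, rank = 2, num = 2).foreach(println)
  }
}

In the actual ALS code the per-block results are then aggregated into a global top-k with a TopByKeyAggregator, which is what keeps the intermediate and shuffle data small.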

mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ private[spark] object BLAS extends Serializable with Logging {
   @transient private var _nativeBLAS: NetlibBLAS = _
 
   // For level-1 routines, we use Java implementation.
-  private def f2jBLAS: NetlibBLAS = {
+  private[mllib] def f2jBLAS: NetlibBLAS = {
     if (_f2jBLAS == null) {
       _f2jBLAS = new F2jBLAS
     }

mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala

Lines changed: 19 additions & 32 deletions
@@ -20,8 +20,6 @@ package org.apache.spark.mllib.recommendation
 import java.io.IOException
 import java.lang.{Integer => JavaInteger}
 
-import scala.collection.mutable
-
 import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
 import org.apache.hadoop.fs.Path
@@ -33,7 +31,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Since
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
 import org.apache.spark.internal.Logging
-import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.linalg.BLAS
 import org.apache.spark.mllib.rdd.MLPairRDDFunctions._
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
@@ -263,6 +261,19 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
 
   /**
    * Makes recommendations for all users (or products).
+   *
+   * Note: the previous approach used for computing top-k recommendations aimed to group
+   * individual factor vectors into blocks, so that Level 3 BLAS operations (gemm) could
+   * be used for efficiency. However, this causes excessive GC pressure due to the large
+   * arrays required for intermediate result storage, as well as a high sensitivity to the
+   * block size used.
+   *
+   * The following approach still groups factors into blocks, but instead computes the
+   * top-k elements per block, using dot product and an efficient [[BoundedPriorityQueue]]
+   * (instead of gemm). This avoids any large intermediate data structures and results
+   * in significantly reduced GC pressure as well as shuffle data, which far outweighs
+   * any cost incurred from not using Level 3 BLAS operations.
+   *
    * @param rank rank
    * @param srcFeatures src features to receive recommendations
    * @param dstFeatures dst features used to make recommendations
@@ -277,46 +288,22 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
       num: Int): RDD[(Int, Array[(Int, Double)])] = {
     val srcBlocks = blockify(srcFeatures)
     val dstBlocks = blockify(dstFeatures)
-    /**
-     * The previous approach used for computing top-k recommendations aimed to group
-     * individual factor vectors into blocks, so that Level 3 BLAS operations (gemm) could
-     * be used for efficiency. However, this causes excessive GC pressure due to the large
-     * arrays required for intermediate result storage, as well as a high sensitivity to the
-     * block size used.
-     * The following approach still groups factors into blocks, but instead computes the
-     * top-k elements per block, using a simple dot product (instead of gemm) and an efficient
-     * [[BoundedPriorityQueue]]. This avoids any large intermediate data structures and results
-     * in significantly reduced GC pressure as well as shuffle data, which far outweighs
-     * any cost incurred from not using Level 3 BLAS operations.
-     */
     val ratings = srcBlocks.cartesian(dstBlocks).flatMap { case (srcIter, dstIter) =>
       val m = srcIter.size
       val n = math.min(dstIter.size, num)
       val output = new Array[(Int, (Int, Double))](m * n)
-      var j = 0
+      var i = 0
       val pq = new BoundedPriorityQueue[(Int, Double)](n)(Ordering.by(_._2))
       srcIter.foreach { case (srcId, srcFactor) =>
        dstIter.foreach { case (dstId, dstFactor) =>
-          /*
-           * The below code is equivalent to
-           * `val score = blas.ddot(rank, srcFactor, 1, dstFactor, 1)`
-           * This handwritten version is as or more efficient as BLAS calls in this case.
-           */
-          var score: Double = 0
-          var k = 0
-          while (k < rank) {
-            score += srcFactor(k) * dstFactor(k)
-            k += 1
-          }
+          // We use F2jBLAS which is faster than a call to native BLAS for vector dot product
+          val score = BLAS.f2jBLAS.ddot(rank, srcFactor, 1, dstFactor, 1)
          pq += dstId -> score
        }
-        val pqIter = pq.iterator
-        var i = 0
-        while (i < n) {
-          output(j + i) = (srcId, pqIter.next())
+        pq.foreach { case (dstId, score) =>
+          output(i) = (srcId, (dstId, score))
          i += 1
        }
-        j += n
        pq.clear()
      }
      output.toSeq
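For context, the method changed here backs the recommend-all APIs on MatrixFactorizationModel (recommendProductsForUsers and recommendUsersForProducts). A hypothetical usage sketch follows; the input path, parsing, and ALS parameter values are invented for illustration and are not part of this commit:

import org.apache.spark.SparkContext
import org.apache.spark.mllib.recommendation.{ALS, Rating}

object RecommendAllExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[*]", "RecommendAllExample")
    // Hypothetical input file with "user,product,rating" lines
    val ratings = sc.textFile("data/ratings.csv").map { line =>
      val Array(user, product, rating) = line.split(',')
      Rating(user.toInt, product.toInt, rating.toDouble)
    }
    // rank = 10 latent factors, 10 iterations, lambda = 0.01 (illustrative values)
    val model = ALS.train(ratings, 10, 10, 0.01)
    // Top 5 products per user; internally this goes through the blocked
    // dot-product / BoundedPriorityQueue path shown in the diff above.
    val recs = model.recommendProductsForUsers(5)
    recs.take(3).foreach { case (user, recos) => println(s"$user -> ${recos.mkString(", ")}") }
    sc.stop()
  }
}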
