Skip to content

Commit 9dfad1b

Browse files
author
Nick Pentreath
committed
Expose {ml, mllib}-private f2jBLAS and use that
1 parent 0b1eaa3 commit 9dfad1b

File tree

4 files changed

+6
-10
lines changed

4 files changed

+6
-10
lines changed

mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ private[spark] object BLAS extends Serializable {
2929
@transient private var _nativeBLAS: NetlibBLAS = _
3030

3131
// For level-1 routines, we use Java implementation.
32-
private def f2jBLAS: NetlibBLAS = {
32+
private[ml] def f2jBLAS: NetlibBLAS = {
3333
if (_f2jBLAS == null) {
3434
_f2jBLAS = new F2jBLAS
3535
}

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ import scala.util.{Sorting, Try}
2727
import scala.util.hashing.byteswap64
2828

2929
import com.github.fommil.netlib.BLAS.{getInstance => blas}
30-
import com.github.fommil.netlib.F2jBLAS
3130
import org.apache.hadoop.fs.Path
3231
import org.json4s.DefaultFormats
3332
import org.json4s.JsonDSL._
@@ -36,6 +35,7 @@ import org.apache.spark.{Dependency, Partitioner, ShuffleDependency, SparkContex
3635
import org.apache.spark.annotation.{DeveloperApi, Since}
3736
import org.apache.spark.internal.Logging
3837
import org.apache.spark.ml.{Estimator, Model}
38+
import org.apache.spark.ml.linalg.BLAS
3939
import org.apache.spark.ml.param._
4040
import org.apache.spark.ml.param.shared._
4141
import org.apache.spark.ml.util._
@@ -399,7 +399,7 @@ class ALSModel private[ml] (
399399
srcIter.foreach { case (srcId, srcFactor) =>
400400
dstIter.foreach { case (dstId, dstFactor) =>
401401
// We use F2jBLAS which is faster than a call to native BLAS for vector dot product
402-
val score = ALSModel._f2jBLAS.sdot(rank, srcFactor, 1, dstFactor, 1)
402+
val score = BLAS.f2jBLAS.sdot(rank, srcFactor, 1, dstFactor, 1)
403403
pq += dstId -> score
404404
}
405405
pq.foreach { case (dstId, score) =>
@@ -439,8 +439,6 @@ class ALSModel private[ml] (
439439
@Since("1.6.0")
440440
object ALSModel extends MLReadable[ALSModel] {
441441

442-
@transient private[recommendation] val _f2jBLAS = new F2jBLAS
443-
444442
private val NaN = "nan"
445443
private val Drop = "drop"
446444
private[recommendation] final val supportedColdStartStrategies = Array(NaN, Drop)

mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ private[spark] object BLAS extends Serializable with Logging {
3131
@transient private var _nativeBLAS: NetlibBLAS = _
3232

3333
// For level-1 routines, we use Java implementation.
34-
private def f2jBLAS: NetlibBLAS = {
34+
private[mllib] def f2jBLAS: NetlibBLAS = {
3535
if (_f2jBLAS == null) {
3636
_f2jBLAS = new F2jBLAS
3737
}

mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import java.lang.{Integer => JavaInteger}
2222

2323
import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
2424
import com.github.fommil.netlib.BLAS.{getInstance => blas}
25-
import com.github.fommil.netlib.F2jBLAS
2625
import org.apache.hadoop.fs.Path
2726
import org.json4s._
2827
import org.json4s.JsonDSL._
@@ -32,6 +31,7 @@ import org.apache.spark.SparkContext
3231
import org.apache.spark.annotation.Since
3332
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
3433
import org.apache.spark.internal.Logging
34+
import org.apache.spark.mllib.linalg.BLAS
3535
import org.apache.spark.mllib.rdd.MLPairRDDFunctions._
3636
import org.apache.spark.mllib.util.{Loader, Saveable}
3737
import org.apache.spark.rdd.RDD
@@ -246,8 +246,6 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
246246

247247
import org.apache.spark.mllib.util.Loader._
248248

249-
@transient private val _f2jBLAS = new F2jBLAS
250-
251249
/**
252250
* Makes recommendations for a single user (or product).
253251
*/
@@ -299,7 +297,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
299297
srcIter.foreach { case (srcId, srcFactor) =>
300298
dstIter.foreach { case (dstId, dstFactor) =>
301299
// We use F2jBLAS which is faster than a call to native BLAS for vector dot product
302-
val score = _f2jBLAS.ddot(rank, srcFactor, 1, dstFactor, 1)
300+
val score = BLAS.f2jBLAS.ddot(rank, srcFactor, 1, dstFactor, 1)
303301
pq += dstId -> score
304302
}
305303
pq.foreach { case (dstId, score) =>

0 commit comments

Comments
 (0)