Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
330 changes: 329 additions & 1 deletion mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@
package org.apache.spark.mllib.linalg

import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS}
import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS}

import org.apache.spark.Logging

/**
* BLAS routines for MLlib's vectors and matrices.
*/
private[mllib] object BLAS extends Serializable {
private[mllib] object BLAS extends Serializable with Logging {

@transient private var _f2jBLAS: NetlibBLAS = _
@transient private var _nativeBLAS: NetlibBLAS = _

// For level-1 routines, we use Java implementation.
private def f2jBLAS: NetlibBLAS = {
Expand Down Expand Up @@ -197,4 +201,328 @@ private[mllib] object BLAS extends Serializable {
throw new IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.")
}
}

// For level-3 routines, we use the native BLAS.
private def nativeBLAS: NetlibBLAS = {
if (_nativeBLAS == null) {
_nativeBLAS = NativeBLAS
}
_nativeBLAS
}

/**
* C := alpha * A * B + beta * C
* @param transA whether to use the transpose of matrix A (true), or A itself (false).
* @param transB whether to use the transpose of matrix B (true), or B itself (false).
* @param alpha a scalar to scale the multiplication A * B.
* @param A the matrix A that will be left multiplied to B. Size of m x k.
* @param B the matrix B that will be left multiplied by A. Size of k x n.
* @param beta a scalar that can be used to scale matrix C.
* @param C the resulting matrix C. Size of m x n.
*/
def gemm(
transA: Boolean,
transB: Boolean,
alpha: Double,
A: Matrix,
B: DenseMatrix,
beta: Double,
C: DenseMatrix): Unit = {
if (alpha == 0.0) {
logDebug("gemm: alpha is equal to 0. Returning C.")
} else {
A match {
case sparse: SparseMatrix =>
gemm(transA, transB, alpha, sparse, B, beta, C)
case dense: DenseMatrix =>
gemm(transA, transB, alpha, dense, B, beta, C)
case _ =>
throw new IllegalArgumentException(s"gemm doesn't support matrix type ${A.getClass}.")
}
}
}

/**
* C := alpha * A * B + beta * C
*
* @param alpha a scalar to scale the multiplication A * B.
* @param A the matrix A that will be left multiplied to B. Size of m x k.
* @param B the matrix B that will be left multiplied by A. Size of k x n.
* @param beta a scalar that can be used to scale matrix C.
* @param C the resulting matrix C. Size of m x n.
*/
def gemm(
alpha: Double,
A: Matrix,
B: DenseMatrix,
beta: Double,
C: DenseMatrix): Unit = {
gemm(false, false, alpha, A, B, beta, C)
}

/**
* C := alpha * A * B + beta * C
* For `DenseMatrix` A.
*/
private def gemm(
transA: Boolean,
transB: Boolean,
alpha: Double,
A: DenseMatrix,
B: DenseMatrix,
beta: Double,
C: DenseMatrix): Unit = {
val mA: Int = if (!transA) A.numRows else A.numCols
val nB: Int = if (!transB) B.numCols else B.numRows
val kA: Int = if (!transA) A.numCols else A.numRows
val kB: Int = if (!transB) B.numRows else B.numCols
val tAstr = if (!transA) "N" else "T"
val tBstr = if (!transB) "N" else "T"

require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB")
require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA")
require(nB == C.numCols,
s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB")

nativeBLAS.dgemm(tAstr, tBstr, mA, nB, kA, alpha, A.values, A.numRows, B.values, B.numRows,
beta, C.values, C.numRows)
}

/**
* C := alpha * A * B + beta * C
* For `SparseMatrix` A.
*/
private def gemm(
transA: Boolean,
transB: Boolean,
alpha: Double,
A: SparseMatrix,
B: DenseMatrix,
beta: Double,
C: DenseMatrix): Unit = {
val mA: Int = if (!transA) A.numRows else A.numCols
val nB: Int = if (!transB) B.numCols else B.numRows
val kA: Int = if (!transA) A.numCols else A.numRows
val kB: Int = if (!transB) B.numRows else B.numCols

require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB")
require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA")
require(nB == C.numCols,
s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB")

val Avals = A.values
val Arows = if (!transA) A.rowIndices else A.colPtrs
val Acols = if (!transA) A.colPtrs else A.rowIndices

// Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices
if (transA){
var colCounterForB = 0
if (!transB) { // Expensive to put the check inside the loop
while (colCounterForB < nB) {
var rowCounterForA = 0
val Cstart = colCounterForB * mA
val Bstart = colCounterForB * kA
while (rowCounterForA < mA) {
var i = Arows(rowCounterForA)
val indEnd = Arows(rowCounterForA + 1)
var sum = 0.0
while (i < indEnd) {
sum += Avals(i) * B.values(Bstart + Acols(i))
i += 1
}
val Cindex = Cstart + rowCounterForA
C.values(Cindex) = beta * C.values(Cindex) + sum * alpha
rowCounterForA += 1
}
colCounterForB += 1
}
} else {
while (colCounterForB < nB) {
var rowCounter = 0
val Cstart = colCounterForB * mA
while (rowCounter < mA) {
var i = Arows(rowCounter)
val indEnd = Arows(rowCounter + 1)
var sum = 0.0
while (i < indEnd) {
sum += Avals(i) * B(colCounterForB, Acols(i))
i += 1
}
val Cindex = Cstart + rowCounter
C.values(Cindex) = beta * C.values(Cindex) + sum * alpha
rowCounter += 1
}
colCounterForB += 1
}
}
} else {
// Scale matrix first if `beta` is not equal to 0.0
if (beta != 0.0){
f2jBLAS.dscal(C.values.length, beta, C.values, 1)
}
// Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of
// B, and added to C.
var colCounterForB = 0 // the column to be updated in C
if (!transB) { // Expensive to put the check inside the loop
while (colCounterForB < nB) {
var colCounterForA = 0 // The column of A to multiply with the row of B
val Bstart = colCounterForB * kB
val Cstart = colCounterForB * mA
while (colCounterForA < kA) {
var i = Acols(colCounterForA)
val indEnd = Acols(colCounterForA + 1)
val Bval = B.values(Bstart + colCounterForA) * alpha
while (i < indEnd){
C.values(Cstart + Arows(i)) += Avals(i) * Bval
i += 1
}
colCounterForA += 1
}
colCounterForB += 1
}
} else {
while (colCounterForB < nB) {
var colCounterForA = 0 // The column of A to multiply with the row of B
val Cstart = colCounterForB * mA
while (colCounterForA < kA){
var i = Acols(colCounterForA)
val indEnd = Acols(colCounterForA + 1)
val Bval = B(colCounterForB, colCounterForA) * alpha
while (i < indEnd){
C.values(Cstart + Arows(i)) += Avals(i) * Bval
i += 1
}
colCounterForA += 1
}
colCounterForB += 1
}
}
}
}

/**
* y := alpha * A * x + beta * y
* @param trans whether to use the transpose of matrix A (true), or A itself (false).
* @param alpha a scalar to scale the multiplication A * x.
* @param A the matrix A that will be left multiplied to x. Size of m x n.
* @param x the vector x that will be left multiplied by A. Size of n x 1.
* @param beta a scalar that can be used to scale vector y.
* @param y the resulting vector y. Size of m x 1.
*/
def gemv(
trans: Boolean,
alpha: Double,
A: Matrix,
x: DenseVector,
beta: Double,
y: DenseVector): Unit = {

val mA: Int = if (!trans) A.numRows else A.numCols
val nx: Int = x.size
val nA: Int = if (!trans) A.numCols else A.numRows

require(nA == nx, s"The columns of A don't match the number of elements of x. A: $nA, x: $nx")
require(mA == y.size,
s"The rows of A don't match the number of elements of y. A: $mA, y:${y.size}}")
if (alpha == 0.0) {
logDebug("gemv: alpha is equal to 0. Returning y.")
} else {
A match {
case sparse: SparseMatrix =>
gemv(trans, alpha, sparse, x, beta, y)
case dense: DenseMatrix =>
gemv(trans, alpha, dense, x, beta, y)
case _ =>
throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.")
}
}
}

/**
* y := alpha * A * x + beta * y
*
* @param alpha a scalar to scale the multiplication A * x.
* @param A the matrix A that will be left multiplied to x. Size of m x n.
* @param x the vector x that will be left multiplied by A. Size of n x 1.
* @param beta a scalar that can be used to scale vector y.
* @param y the resulting vector y. Size of m x 1.
*/
def gemv(
alpha: Double,
A: Matrix,
x: DenseVector,
beta: Double,
y: DenseVector): Unit = {
gemv(false, alpha, A, x, beta, y)
}

/**
* y := alpha * A * x + beta * y
* For `DenseMatrix` A.
*/
private def gemv(
trans: Boolean,
alpha: Double,
A: DenseMatrix,
x: DenseVector,
beta: Double,
y: DenseVector): Unit = {
val tStrA = if (!trans) "N" else "T"
nativeBLAS.dgemv(tStrA, A.numRows, A.numCols, alpha, A.values, A.numRows, x.values, 1, beta,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is worth testing whether we should use f2jBLAS or nativeBLAS for level-2.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll take a look at it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nativeBLAS was faster across the board. Here are the results (in ms):

m n f2j native
1000 500 0.7 0.1
1000 1000 2.3 0.7
1000 5000 7.7 3.6
5000 500 3.8 2.4
5000 1000 7.7 4.0
5000 5000 38.4 17.4
10000 500 6.5 4.5
10000 1000 13.0 8.9
10000 5000 75.4 36.7

y.values, 1)
}

/**
* y := alpha * A * x + beta * y
* For `SparseMatrix` A.
*/
private def gemv(
trans: Boolean,
alpha: Double,
A: SparseMatrix,
x: DenseVector,
beta: Double,
y: DenseVector): Unit = {

val mA: Int = if(!trans) A.numRows else A.numCols
val nA: Int = if(!trans) A.numCols else A.numRows

val Avals = A.values
val Arows = if (!trans) A.rowIndices else A.colPtrs
val Acols = if (!trans) A.colPtrs else A.rowIndices

// Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices
if (trans){
var rowCounter = 0
while (rowCounter < mA){
var i = Arows(rowCounter)
val indEnd = Arows(rowCounter + 1)
var sum = 0.0
while(i < indEnd){
sum += Avals(i) * x.values(Acols(i))
i += 1
}
y.values(rowCounter) = beta * y.values(rowCounter) + sum * alpha
rowCounter += 1
}
} else {
// Scale vector first if `beta` is not equal to 0.0
if (beta != 0.0){
scal(beta, y)
}
// Perform matrix-vector multiplication and add to y
var colCounterForA = 0
while (colCounterForA < nA){
var i = Acols(colCounterForA)
val indEnd = Acols(colCounterForA + 1)
val xVal = x.values(colCounterForA) * alpha
while (i < indEnd){
val rowIndex = Arows(i)
y.values(rowIndex) += Avals(i) * xVal
i += 1
}
colCounterForA += 1
}
}
}
}
Loading