Skip to content
152 changes: 137 additions & 15 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ private[spark] object BLAS extends Serializable with Logging {
def gemv(
alpha: Double,
A: Matrix,
x: DenseVector,
x: Vector,
beta: Double,
y: DenseVector): Unit = {
require(A.numCols == x.size,
Expand All @@ -473,44 +473,169 @@ private[spark] object BLAS extends Serializable with Logging {
if (alpha == 0.0) {
logDebug("gemv: alpha is equal to 0. Returning y.")
} else {
A match {
case sparse: SparseMatrix =>
gemv(alpha, sparse, x, beta, y)
case dense: DenseMatrix =>
gemv(alpha, dense, x, beta, y)
(A, x) match {
case (smA: SparseMatrix, dvx: DenseVector) =>
gemv(alpha, smA, dvx, beta, y)
case (smA: SparseMatrix, svx: SparseVector) =>
gemv(alpha, smA, svx, beta, y)
case (dmA: DenseMatrix, dvx: DenseVector) =>
gemv(alpha, dmA, dvx, beta, y)
case (dmA: DenseMatrix, svx: SparseVector) =>
gemv(alpha, dmA, svx, beta, y)
case _ =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about SparseMatrix and SparseVector? To keep the naming consistent, we can use dmA, smA, dvx, and svx.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you don't want to add support for SparseMatrix with SparseVector, type safety will be broken when the function is called with that combination. Previously, this function was fully type-safe at compile time, and there was no way to reach "case _".

throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.")
throw new IllegalArgumentException(s"gemv doesn't support running on matrix type " +
s"${A.getClass} and vector type ${x.getClass}.")
}
}
}

/**
* y := alpha * A * x + beta * y
* For `DenseMatrix` A.
* For `DenseMatrix` A and `DenseVector` x.
*/
private def gemv(
alpha: Double,
A: DenseMatrix,
x: DenseVector,
beta: Double,
y: DenseVector): Unit = {
y: DenseVector): Unit = {
// Decide once whether A is flagged as transposed; the flag selects both the
// transpose marker and the dimensions handed to the native routine.
val transposed = A.isTransposed
val transFlag = if (transposed) "T" else "N"
// When A is transposed the row/column counts are swapped before the call.
// NOTE(review): presumably because a transposed DenseMatrix stores its values as the
// column-major layout of the swapped shape - confirm against DenseMatrix.
val rows = if (transposed) A.numCols else A.numRows
val cols = if (transposed) A.numRows else A.numCols
// `rows` is passed twice: as the row count and as the sixth argument.
// Assuming nativeBLAS follows the reference BLAS DGEMV argument order, the sixth
// argument is the leading dimension of A, and the two `1`s are the strides of x and y.
nativeBLAS.dgemv(transFlag, rows, cols, alpha, A.values, rows, x.values, 1, beta, y.values, 1)
}

/**
* y := alpha * A * x + beta * y
* For `DenseMatrix` A and `SparseVector` x.
*/
private def gemv(
alpha: Double,
A: DenseMatrix,
x: SparseVector,
beta: Double,
y: DenseVector): Unit = {
val mA: Int = A.numRows
val nA: Int = A.numCols

val Avals = A.values

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are you using this?

val xIndices = x.indices
val xNnz = xIndices.length
val xValues = x.values
val yValues = y.values

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all should be val

if (alpha == 0.0) {
scal(beta, y)
return
}

// Entry (row, col) of A is read from Avals(row * rowStride + col * colStride):
//   - transposed A:     row * nA + col
//   - non-transposed A: col * mA + row
val transposed = A.isTransposed
val rowStride = if (transposed) nA else 1
val colStride = if (transposed) 1 else mA

var row = 0
while (row < mA) {
  val rowOffset = row * rowStride
  // Dot product of row `row` of A with x, visiting only the stored entries of x.
  var dot = 0.0
  var p = 0
  while (p < xNnz) {
    dot += xValues(p) * Avals(xIndices(p) * colStride + rowOffset)
    p += 1
  }
  // y(row) := alpha * (A * x)(row) + beta * y(row)
  yValues(row) = dot * alpha + beta * yValues(row)
  row += 1
}
}

/**
* y := alpha * A * x + beta * y
* For `SparseMatrix` A.
* For `SparseMatrix` A and `SparseVector` x.
*/
private def gemv(
alpha: Double,
A: SparseMatrix,
x: SparseVector,
beta: Double,
y: DenseVector): Unit = {
val xValues = x.values
val xIndices = x.indices
val xNnz = xIndices.length

val yValues = y.values

val mA: Int = A.numRows
val nA: Int = A.numCols

val Avals = A.values
val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs
val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices

if (alpha == 0.0) {
scal(beta, y)
return
}

if (A.isTransposed) {
// Transposed case: Arows (= A.colPtrs) delimits the stored entries of each row of A and
// Acols (= A.rowIndices) holds their column indices. For every row, the dot product with x
// is the sum over the indices present in BOTH the row and x.
var rowCounter = 0
while (rowCounter < mA) {
  var i = Arows(rowCounter)
  val indEnd = Arows(rowCounter + 1)
  var sum = 0.0
  var k = 0
  // Two-pointer intersection of the row's column indices with the active indices of x.
  // Assumes both index sequences are sorted ascending - TODO confirm for SparseMatrix /
  // SparseVector. Whichever side holds the smaller index is advanced; previously only `k`
  // moved on a mismatch, so a column stored in A but absent from x stalled `i` and every
  // later match in the row was silently dropped.
  while (k < xNnz && i < indEnd) {
    val xIndex = xIndices(k)
    val aIndex = Acols(i)
    if (xIndex == aIndex) {
      sum += Avals(i) * xValues(k)
      k += 1
      i += 1
    } else if (xIndex < aIndex) {
      // x has an entry where this row of A has none: skip it.
      k += 1
    } else {
      // This row of A has an entry where x has none: skip it.
      i += 1
    }
  }
  // y(row) := alpha * (A * x)(row) + beta * y(row)
  yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter)
  rowCounter += 1
}
} else {
scal(beta, y)

var colCounterForA = 0
var k = 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scal(beta, y)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to check if (beta != 0.0) here? I think we should do the scaling even when beta is 0.0, since that clears out y.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right.

// Non-transposed case: Acols (= A.colPtrs) delimits the stored entries of each column of A
// and Arows (= A.rowIndices) holds their row indices. `colCounterForA` walks the columns of
// A while `k` walks the active entries of x; a column contributes only when x has an entry
// at that index. y was already passed through scal(beta, y) above, so only the
// alpha * A * x contribution is accumulated here.
while (colCounterForA < nA && k < xNnz) {
  if (xIndices(k) == colCounterForA) {
    var i = Acols(colCounterForA)
    val indEnd = Acols(colCounterForA + 1)

    // Hoist alpha * x(col): it is the same for every entry in this column.
    val xTemp = xValues(k) * alpha
    while (i < indEnd) {
      yValues(Arows(i)) += Avals(i) * xTemp
      i += 1
    }
    k += 1
  }
  colCounterForA += 1
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked the logic, and it looks correct. Great.

}
}

/**
* y := alpha * A * x + beta * y
* For `SparseMatrix` A and `DenseVector` x.
*/
private def gemv(
alpha: Double,
A: SparseMatrix,
x: DenseVector,
beta: Double,
y: DenseVector): Unit = {
y: DenseVector): Unit = {
val xValues = x.values
val yValues = y.values
val mA: Int = A.numRows
Expand All @@ -534,10 +659,7 @@ private[spark] object BLAS extends Serializable with Logging {
rowCounter += 1
}
} else {
// Scale vector first if `beta` is not equal to 0.0
if (beta != 0.0) {
scal(beta, y)
}
scal(beta, y)
// Perform matrix-vector multiplication and add to y
var colCounterForA = 0
while (colCounterForA < nA) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,13 @@ sealed trait Matrix extends Serializable {
C
}

/** Convenience method for `Matrix`-`DenseVector` multiplication. */
/** Convenience method for `Matrix`-`DenseVector` multiplication. For binary compatibility. */
def multiply(y: DenseVector): DenseVector = {
  // Widen the static type with a type ascription, which the compiler checks, rather than
  // an unchecked runtime cast. The widened type makes overload resolution select the
  // `multiply(y: Vector)` variant, which does the actual work.
  multiply(y: Vector)
}

/** Convenience method for `Matrix`-`Vector` multiplication. */
def multiply(y: Vector): DenseVector = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot delete def multiply(y: DenseVector), because that would break binary compatibility. Please delegate its implementation to multiply(y: Vector) and update the MiMa excludes.

val output = new DenseVector(new Array[Double](numRows))
BLAS.gemv(1.0, this, y, 0.0, output)
output
Expand Down
96 changes: 81 additions & 15 deletions mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -256,42 +256,108 @@ class BLASSuite extends FunSuite {
val dA =
new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0))
val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0))

val x = new DenseVector(Array(1.0, 2.0, 3.0))

val dA2 =
new DenseMatrix(4, 3, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0), true)
val sA2 =
new SparseMatrix(4, 3, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0),
true)

val dx = new DenseVector(Array(1.0, 2.0, 3.0))
val sx = dx.toSparse
val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0))

assert(dA.multiply(x) ~== expected absTol 1e-15)
assert(sA.multiply(x) ~== expected absTol 1e-15)

assert(dA.multiply(dx) ~== expected absTol 1e-15)
assert(sA.multiply(dx) ~== expected absTol 1e-15)
assert(dA.multiply(sx) ~== expected absTol 1e-15)
assert(sA.multiply(sx) ~== expected absTol 1e-15)

val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0))
val y2 = y1.copy
val y3 = y1.copy
val y4 = y1.copy
val y5 = y1.copy
val y6 = y1.copy
val y7 = y1.copy
val y8 = y1.copy
val y9 = y1.copy
val y10 = y1.copy
val y11 = y1.copy
val y12 = y1.copy
val y13 = y1.copy
val y14 = y1.copy
val y15 = y1.copy
val y16 = y1.copy

val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0))
val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0))

gemv(1.0, dA, x, 2.0, y1)
gemv(1.0, sA, x, 2.0, y2)
gemv(2.0, dA, x, 2.0, y3)
gemv(2.0, sA, x, 2.0, y4)
gemv(1.0, dA, dx, 2.0, y1)
gemv(1.0, sA, dx, 2.0, y2)
gemv(1.0, dA, sx, 2.0, y3)
gemv(1.0, sA, sx, 2.0, y4)

gemv(1.0, dA2, dx, 2.0, y5)
gemv(1.0, sA2, dx, 2.0, y6)
gemv(1.0, dA2, sx, 2.0, y7)
gemv(1.0, sA2, sx, 2.0, y8)

gemv(2.0, dA, dx, 2.0, y9)
gemv(2.0, sA, dx, 2.0, y10)
gemv(2.0, dA, sx, 2.0, y11)
gemv(2.0, sA, sx, 2.0, y12)

gemv(2.0, dA2, dx, 2.0, y13)
gemv(2.0, sA2, dx, 2.0, y14)
gemv(2.0, dA2, sx, 2.0, y15)
gemv(2.0, sA2, sx, 2.0, y16)

assert(y1 ~== expected2 absTol 1e-15)
assert(y2 ~== expected2 absTol 1e-15)
assert(y3 ~== expected3 absTol 1e-15)
assert(y4 ~== expected3 absTol 1e-15)
assert(y3 ~== expected2 absTol 1e-15)
assert(y4 ~== expected2 absTol 1e-15)

assert(y5 ~== expected2 absTol 1e-15)
assert(y6 ~== expected2 absTol 1e-15)
assert(y7 ~== expected2 absTol 1e-15)
assert(y8 ~== expected2 absTol 1e-15)

assert(y9 ~== expected3 absTol 1e-15)
assert(y10 ~== expected3 absTol 1e-15)
assert(y11 ~== expected3 absTol 1e-15)
assert(y12 ~== expected3 absTol 1e-15)

assert(y13 ~== expected3 absTol 1e-15)
assert(y14 ~== expected3 absTol 1e-15)
assert(y15 ~== expected3 absTol 1e-15)
assert(y16 ~== expected3 absTol 1e-15)

withClue("columns of A don't match the rows of B") {
intercept[Exception] {
gemv(1.0, dA.transpose, x, 2.0, y1)
gemv(1.0, dA.transpose, dx, 2.0, y1)
}
intercept[Exception] {
gemv(1.0, sA.transpose, dx, 2.0, y1)
}
intercept[Exception] {
gemv(1.0, dA.transpose, sx, 2.0, y1)
}
intercept[Exception] {
gemv(1.0, sA.transpose, sx, 2.0, y1)
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check exception for

gemv(1.0, dA.transpose, dx, 2.0, y1)
gemv(1.0, sA.transpose, dx, 2.0, y1)
gemv(1.0, dA.transpose, sx, 2.0, y1)
gemv(1.0, sA.transpose, sx, 2.0, y1)


val dAT =
new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
val sAT =
new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0))

val dATT = dAT.transpose
val sATT = sAT.transpose

assert(dATT.multiply(x) ~== expected absTol 1e-15)
assert(sATT.multiply(x) ~== expected absTol 1e-15)
assert(dATT.multiply(dx) ~== expected absTol 1e-15)
assert(sATT.multiply(dx) ~== expected absTol 1e-15)
assert(dATT.multiply(sx) ~== expected absTol 1e-15)
assert(sATT.multiply(sx) ~== expected absTol 1e-15)
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add

assert(dATT.multiply(sx) ~== expected absTol 1e-15)
assert(sATT.multiply(sx) ~== expected absTol 1e-15)

}
18 changes: 16 additions & 2 deletions project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,14 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.linalg.Vector.toSparse"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.linalg.Vector.numActives")
"org.apache.spark.mllib.linalg.Vector.numActives"),
// SPARK-7681 add SparseVector support for gemv
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.linalg.Matrix.multiply"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.linalg.DenseMatrix.multiply"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.linalg.SparseMatrix.multiply")
) ++ Seq(
// Execution should never be included as its always internal.
MimaBuild.excludeSparkPackage("sql.execution"),
Expand Down Expand Up @@ -172,7 +179,14 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.linalg.Matrix.isTransposed"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.linalg.Matrix.foreachActive")
"org.apache.spark.mllib.linalg.Matrix.foreachActive"),
// SPARK-7681 add SparseVector support for gemv
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.linalg.Matrix.multiply"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.linalg.DenseMatrix.multiply"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.linalg.SparseMatrix.multiply")
) ++ Seq(
// SPARK-5540
ProblemFilters.exclude[MissingMethodProblem](
Expand Down