-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-7681][MLlib] Add SparseVector support for gemv #6209
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c069507
5d6d07a
4616696
410381a
054f05d
458d1ae
57a8c1e
b890e63
ce0bb8b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -463,7 +463,7 @@ private[spark] object BLAS extends Serializable with Logging { | |
| def gemv( | ||
| alpha: Double, | ||
| A: Matrix, | ||
| x: DenseVector, | ||
| x: Vector, | ||
| beta: Double, | ||
| y: DenseVector): Unit = { | ||
| require(A.numCols == x.size, | ||
|
|
@@ -473,44 +473,169 @@ private[spark] object BLAS extends Serializable with Logging { | |
| if (alpha == 0.0) { | ||
| logDebug("gemv: alpha is equal to 0. Returning y.") | ||
| } else { | ||
| A match { | ||
| case sparse: SparseMatrix => | ||
| gemv(alpha, sparse, x, beta, y) | ||
| case dense: DenseMatrix => | ||
| gemv(alpha, dense, x, beta, y) | ||
| (A, x) match { | ||
| case (smA: SparseMatrix, dvx: DenseVector) => | ||
| gemv(alpha, smA, dvx, beta, y) | ||
| case (smA: SparseMatrix, svx: SparseVector) => | ||
| gemv(alpha, smA, svx, beta, y) | ||
| case (dmA: DenseMatrix, dvx: DenseVector) => | ||
| gemv(alpha, dmA, dvx, beta, y) | ||
| case (dmA: DenseMatrix, svx: SparseVector) => | ||
| gemv(alpha, dmA, svx, beta, y) | ||
| case _ => | ||
| throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.") | ||
| throw new IllegalArgumentException(s"gemv doesn't support running on matrix type " + | ||
| s"${A.getClass} and vector type ${x.getClass}.") | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * y := alpha * A * x + beta * y | ||
| * For `DenseMatrix` A. | ||
| * For `DenseMatrix` A and `DenseVector` x. | ||
| */ | ||
| private def gemv( | ||
| alpha: Double, | ||
| A: DenseMatrix, | ||
| x: DenseVector, | ||
| beta: Double, | ||
| y: DenseVector): Unit = { | ||
| y: DenseVector): Unit = { | ||
| val tStrA = if (A.isTransposed) "T" else "N" | ||
| val mA = if (!A.isTransposed) A.numRows else A.numCols | ||
| val nA = if (!A.isTransposed) A.numCols else A.numRows | ||
| nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta, | ||
| y.values, 1) | ||
| } | ||
|
|
||
| /** | ||
| * y := alpha * A * x + beta * y | ||
| * For `DenseMatrix` A and `SparseVector` x. | ||
| */ | ||
| private def gemv( | ||
| alpha: Double, | ||
| A: DenseMatrix, | ||
| x: SparseVector, | ||
| beta: Double, | ||
| y: DenseVector): Unit = { | ||
| // Updates y in place to alpha * A * x + beta * y, visiting only the stored entries of x. | ||
| val mA: Int = A.numRows | ||
| val nA: Int = A.numCols | ||
|
|
||
| val Avals = A.values | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. are you using this? |
||
| val xIndices = x.indices | ||
| val xNnz = xIndices.length | ||
| val xValues = x.values | ||
| val yValues = y.values | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. all should be |
||
| if (alpha == 0.0) { | ||
| // No contribution from A * x; only the beta scaling of y remains. | ||
| scal(beta, y) | ||
| } else { | ||
| // Entry (row, col) of A sits at col * colStride + row * rowStride in Avals; | ||
| // the two strides swap roles when A is flagged as transposed. | ||
| val (colStride, rowStride) = if (A.isTransposed) (1, nA) else (mA, 1) | ||
| var row = 0 | ||
| while (row < mA) { | ||
| val rowOffset = row * rowStride | ||
| var acc = 0.0 | ||
| var j = 0 | ||
| while (j < xNnz) { | ||
| acc += xValues(j) * Avals(xIndices(j) * colStride + rowOffset) | ||
| j += 1 | ||
| } | ||
| yValues(row) = acc * alpha + beta * yValues(row) | ||
| row += 1 | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * y := alpha * A * x + beta * y | ||
| * For `SparseMatrix` A. | ||
| * For `SparseMatrix` A and `SparseVector` x. | ||
| */ | ||
| private def gemv( | ||
| alpha: Double, | ||
| A: SparseMatrix, | ||
| x: SparseVector, | ||
| beta: Double, | ||
| y: DenseVector): Unit = { | ||
| // Updates y in place to alpha * A * x + beta * y, touching only stored entries of A and x. | ||
| val xValues = x.values | ||
| val xIndices = x.indices | ||
| val xNnz = xIndices.length | ||
|
|
||
| val yValues = y.values | ||
|
|
||
| val mA: Int = A.numRows | ||
| val nA: Int = A.numCols | ||
|
|
||
| val Avals = A.values | ||
| // When A is flagged as transposed the roles of rowIndices and colPtrs are swapped. | ||
| val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs | ||
| val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices | ||
|
|
||
| if (alpha == 0.0) { | ||
| scal(beta, y) | ||
| return | ||
| } | ||
|
|
||
| if (A.isTransposed) { | ||
| var rowCounter = 0 | ||
| while (rowCounter < mA) { | ||
| var i = Arows(rowCounter) | ||
| val indEnd = Arows(rowCounter + 1) | ||
| var sum = 0.0 | ||
| var k = 0 | ||
| // Sorted-merge intersection of x's indices with this row's column indices. | ||
| // Whichever side holds the smaller index must advance; advancing only k (as before) | ||
| // skipped matches whenever x had an index absent from the row. | ||
| // NOTE(review): relies on both index arrays being sorted ascending, the same | ||
| // assumption the previous single-pass loop made; confirm against the constructors. | ||
| while (k < xNnz && i < indEnd) { | ||
| val xIdx = xIndices(k) | ||
| val aIdx = Acols(i) | ||
| if (xIdx == aIdx) { | ||
| sum += Avals(i) * xValues(k) | ||
| k += 1 | ||
| i += 1 | ||
| } else if (xIdx < aIdx) { | ||
| k += 1 | ||
| } else { | ||
| i += 1 | ||
| } | ||
| } | ||
| yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter) | ||
| rowCounter += 1 | ||
| } | ||
| } else { | ||
| // Scale y once up front; the column contributions are then accumulated on top of it. | ||
| scal(beta, y) | ||
|
|
||
| var colCounterForA = 0 | ||
| var k = 0 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. scal(beta, y)
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to check
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You are right. |
||
| while (colCounterForA < nA && k < xNnz) { | ||
| if (xIndices(k) == colCounterForA) { | ||
| var i = Acols(colCounterForA) | ||
| val indEnd = Acols(colCounterForA + 1) | ||
|
|
||
| // Fold alpha into the x entry once per column rather than once per stored value. | ||
| val xTemp = xValues(k) * alpha | ||
| while (i < indEnd) { | ||
| yValues(Arows(i)) += Avals(i) * xTemp | ||
| i += 1 | ||
| } | ||
| k += 1 | ||
| } | ||
| colCounterForA += 1 | ||
| } | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I checked the logic, and it looks correct. Great. |
||
| } | ||
| } | ||
|
|
||
| /** | ||
| * y := alpha * A * x + beta * y | ||
| * For `SparseMatrix` A and `DenseVector` x. | ||
| */ | ||
| private def gemv( | ||
| alpha: Double, | ||
| A: SparseMatrix, | ||
| x: DenseVector, | ||
| beta: Double, | ||
| y: DenseVector): Unit = { | ||
| y: DenseVector): Unit = { | ||
| val xValues = x.values | ||
| val yValues = y.values | ||
| val mA: Int = A.numRows | ||
|
|
@@ -534,10 +659,7 @@ private[spark] object BLAS extends Serializable with Logging { | |
| rowCounter += 1 | ||
| } | ||
| } else { | ||
| // Scale vector first if `beta` is not equal to 0.0 | ||
| if (beta != 0.0) { | ||
| scal(beta, y) | ||
| } | ||
| scal(beta, y) | ||
| // Perform matrix-vector multiplication and add to y | ||
| var colCounterForA = 0 | ||
| while (colCounterForA < nA) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -77,8 +77,13 @@ sealed trait Matrix extends Serializable { | |
| C | ||
| } | ||
|
|
||
| /** Convenience method for `Matrix`-`DenseVector` multiplication. */ | ||
| /** Convenience method for `Matrix`-`DenseVector` multiplication. For binary compatibility. */ | ||
| def multiply(y: DenseVector): DenseVector = { | ||
| // Widen with a type ascription (checked at compile time, unlike asInstanceOf) so the | ||
| // call resolves to the `Vector` overload instead of recursing into this method. | ||
| multiply(y: Vector) | ||
| } | ||
|
|
||
| /** Convenience method for `Matrix`-`Vector` multiplication. */ | ||
| def multiply(y: Vector): DenseVector = { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We cannot delete |
||
| val output = new DenseVector(new Array[Double](numRows)) | ||
| BLAS.gemv(1.0, this, y, 0.0, output) | ||
| output | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -256,42 +256,108 @@ class BLASSuite extends FunSuite { | |
| val dA = | ||
| new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) | ||
| val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) | ||
|
|
||
| val x = new DenseVector(Array(1.0, 2.0, 3.0)) | ||
|
|
||
| val dA2 = | ||
| new DenseMatrix(4, 3, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0), true) | ||
| val sA2 = | ||
| new SparseMatrix(4, 3, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0), | ||
| true) | ||
|
|
||
| val dx = new DenseVector(Array(1.0, 2.0, 3.0)) | ||
| val sx = dx.toSparse | ||
| val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0)) | ||
|
|
||
| assert(dA.multiply(x) ~== expected absTol 1e-15) | ||
| assert(sA.multiply(x) ~== expected absTol 1e-15) | ||
|
|
||
| assert(dA.multiply(dx) ~== expected absTol 1e-15) | ||
| assert(sA.multiply(dx) ~== expected absTol 1e-15) | ||
| assert(dA.multiply(sx) ~== expected absTol 1e-15) | ||
| assert(sA.multiply(sx) ~== expected absTol 1e-15) | ||
|
|
||
| val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0)) | ||
| val y2 = y1.copy | ||
| val y3 = y1.copy | ||
| val y4 = y1.copy | ||
| val y5 = y1.copy | ||
| val y6 = y1.copy | ||
| val y7 = y1.copy | ||
| val y8 = y1.copy | ||
| val y9 = y1.copy | ||
| val y10 = y1.copy | ||
| val y11 = y1.copy | ||
| val y12 = y1.copy | ||
| val y13 = y1.copy | ||
| val y14 = y1.copy | ||
| val y15 = y1.copy | ||
| val y16 = y1.copy | ||
|
|
||
| val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0)) | ||
| val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0)) | ||
|
|
||
| gemv(1.0, dA, x, 2.0, y1) | ||
| gemv(1.0, sA, x, 2.0, y2) | ||
| gemv(2.0, dA, x, 2.0, y3) | ||
| gemv(2.0, sA, x, 2.0, y4) | ||
| gemv(1.0, dA, dx, 2.0, y1) | ||
| gemv(1.0, sA, dx, 2.0, y2) | ||
| gemv(1.0, dA, sx, 2.0, y3) | ||
| gemv(1.0, sA, sx, 2.0, y4) | ||
|
|
||
| gemv(1.0, dA2, dx, 2.0, y5) | ||
| gemv(1.0, sA2, dx, 2.0, y6) | ||
| gemv(1.0, dA2, sx, 2.0, y7) | ||
| gemv(1.0, sA2, sx, 2.0, y8) | ||
|
|
||
| gemv(2.0, dA, dx, 2.0, y9) | ||
| gemv(2.0, sA, dx, 2.0, y10) | ||
| gemv(2.0, dA, sx, 2.0, y11) | ||
| gemv(2.0, sA, sx, 2.0, y12) | ||
|
|
||
| gemv(2.0, dA2, dx, 2.0, y13) | ||
| gemv(2.0, sA2, dx, 2.0, y14) | ||
| gemv(2.0, dA2, sx, 2.0, y15) | ||
| gemv(2.0, sA2, sx, 2.0, y16) | ||
|
|
||
| assert(y1 ~== expected2 absTol 1e-15) | ||
| assert(y2 ~== expected2 absTol 1e-15) | ||
| assert(y3 ~== expected3 absTol 1e-15) | ||
| assert(y4 ~== expected3 absTol 1e-15) | ||
| assert(y3 ~== expected2 absTol 1e-15) | ||
| assert(y4 ~== expected2 absTol 1e-15) | ||
|
|
||
| assert(y5 ~== expected2 absTol 1e-15) | ||
| assert(y6 ~== expected2 absTol 1e-15) | ||
| assert(y7 ~== expected2 absTol 1e-15) | ||
| assert(y8 ~== expected2 absTol 1e-15) | ||
|
|
||
| assert(y9 ~== expected3 absTol 1e-15) | ||
| assert(y10 ~== expected3 absTol 1e-15) | ||
| assert(y11 ~== expected3 absTol 1e-15) | ||
| assert(y12 ~== expected3 absTol 1e-15) | ||
|
|
||
| assert(y13 ~== expected3 absTol 1e-15) | ||
| assert(y14 ~== expected3 absTol 1e-15) | ||
| assert(y15 ~== expected3 absTol 1e-15) | ||
| assert(y16 ~== expected3 absTol 1e-15) | ||
|
|
||
| withClue("columns of A don't match the rows of B") { | ||
| intercept[Exception] { | ||
| gemv(1.0, dA.transpose, x, 2.0, y1) | ||
| gemv(1.0, dA.transpose, dx, 2.0, y1) | ||
| } | ||
| intercept[Exception] { | ||
| gemv(1.0, sA.transpose, dx, 2.0, y1) | ||
| } | ||
| intercept[Exception] { | ||
| gemv(1.0, dA.transpose, sx, 2.0, y1) | ||
| } | ||
| intercept[Exception] { | ||
| gemv(1.0, sA.transpose, sx, 2.0, y1) | ||
| } | ||
| } | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. check exception for |
||
|
|
||
| val dAT = | ||
| new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) | ||
| val sAT = | ||
| new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) | ||
|
|
||
| val dATT = dAT.transpose | ||
| val sATT = sAT.transpose | ||
|
|
||
| assert(dATT.multiply(x) ~== expected absTol 1e-15) | ||
| assert(sATT.multiply(x) ~== expected absTol 1e-15) | ||
| assert(dATT.multiply(dx) ~== expected absTol 1e-15) | ||
| assert(sATT.multiply(dx) ~== expected absTol 1e-15) | ||
| assert(dATT.multiply(sx) ~== expected absTol 1e-15) | ||
| assert(sATT.multiply(sx) ~== expected absTol 1e-15) | ||
| } | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add |
||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about `SparseMatrix` and `SparseVector`? To keep the naming consistent, we can use `dmA`, `smA`, `dvx`, and `svx`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you don't really want to add `SparseMatrix` and `SparseVector`, type safety will be broken when you call with this configuration. Previously, this function was fully type-safe at compile time, and there was no way to reach "case _".