diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 1c94479ef02a..ec32e37afb79 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -537,7 +537,7 @@ class RowMatrix @Since("1.0.0") ( def tallSkinnyQR(computeQ: Boolean = false): QRDecomposition[RowMatrix, Matrix] = { val col = numCols().toInt // split rows horizontally into smaller matrices, and compute QR for each of them - val blockQRs = rows.retag(classOf[Vector]).glom().map { partRows => + val blockQRs = rows.retag(classOf[Vector]).glom().filter(_.length != 0).map { partRows => val bdm = BDM.zeros[Double](partRows.length, col) var i = 0 partRows.foreach { row => @@ -548,10 +548,11 @@ class RowMatrix @Since("1.0.0") ( } // combine the R part from previous results vertically into a tall matrix - val combinedR = blockQRs.treeReduce{ (r1, r2) => + val combinedR = blockQRs.treeReduce { (r1, r2) => val stackedR = BDM.vertcat(r1, r2) breeze.linalg.qr.reduced(stackedR).r } + val finalR = Matrices.fromBreeze(combinedR.toDenseMatrix) val finalQ = if (computeQ) { try { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 7c4c6d8409c6..7c9e14f8cee7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors} import org.apache.spark.mllib.random.RandomRDDs import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -281,6 +282,22 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { assert(cov(i, j) === cov(j, i)) } } + + test("QR decomposition should aware of empty partition (SPARK-16369)") { + val mat: RowMatrix = new RowMatrix(sc.parallelize(denseData, 1)) + val qrResult = mat.tallSkinnyQR(true) + + val matWithEmptyPartition = new RowMatrix(sc.parallelize(denseData, 8)) + val qrResult2 = matWithEmptyPartition.tallSkinnyQR(true) + + assert(qrResult.Q.numCols() === qrResult2.Q.numCols(), "Q matrix ncol not match") + assert(qrResult.Q.numRows() === qrResult2.Q.numRows(), "Q matrix nrow not match") + qrResult.Q.rows.collect().zip(qrResult2.Q.rows.collect()) + .foreach(x => assert(x._1 ~== x._2 relTol 1E-8, "Q matrix not match")) + + qrResult.R.toArray.zip(qrResult2.R.toArray) + .foreach(x => assert(x._1 ~== x._2 relTol 1E-8, "R matrix not match")) + } } class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext {