diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 85e63b1382b5e..6d8db7e656561 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -114,6 +114,65 @@ sealed trait Matrix extends Serializable {
    * corresponding value in the matrix with type `Double`.
    */
   private[spark] def foreachActive(f: (Int, Int, Double) => Unit)
+
+  /**
+   * Hash code consistent with `equals`: it depends only on the dimensions and on the
+   * non-zero entries, not on the storage format or on the order in which
+   * `foreachActive` visits them.
+   */
+  override def hashCode(): Int = {
+    // Per-entry hashes are combined by addition, which is commutative, because equal
+    // matrices may enumerate their entries in different orders (e.g. CSC vs. CSR).
+    var entriesHash = 0
+    foreachActive { (rowInd, colInd, value) =>
+      // Ignore explicit zeros so that sparse and dense representations hash alike.
+      if (value != 0) {
+        // Same bit folding as java.lang.Double.hashCode.
+        val bits = java.lang.Double.doubleToLongBits(value)
+        var entryHash = 31 + rowInd
+        entryHash = 31 * entryHash + colInd
+        entryHash = 31 * entryHash + (bits ^ (bits >>> 32)).toInt
+        entriesHash += entryHash
+      }
+    }
+    var result: Int = 31 + numRows
+    result = 31 * result + numCols
+    31 * result + entriesHash
+  }
+
+  /**
+   * Two matrices are equal when they have the same dimensions and hold the same value at
+   * every position, regardless of storage (dense or sparse, transposed or not). Explicitly
+   * stored zeros are treated like implicit ones, and NaN is considered equal to NaN so that
+   * every matrix equals itself.
+   */
+  override def equals(other: Any): Boolean = other match {
+    case mat: Matrix if mat.numRows == numRows && mat.numCols == numCols =>
+      // Element lookup on the concrete type of the other matrix.
+      val otherAt: (Int, Int) => Double = mat match {
+        case dm: DenseMatrix => (i, j) => dm(i, j)
+        case sm: SparseMatrix => (i, j) => sm(i, j)
+      }
+      // Every non-zero entry of this matrix must be matched by the other one ...
+      var matches = true
+      var nonZeros = 0
+      foreachActive { (i, j, value) =>
+        if (matches && value != 0) {
+          nonZeros += 1
+          val otherValue = otherAt(i, j)
+          matches = value == otherValue ||
+            (java.lang.Double.isNaN(value) && java.lang.Double.isNaN(otherValue))
+        }
+      }
+      // ... and the other matrix must not hold additional non-zero entries. Comparing the
+      // counts suffices, assuming no position is stored twice.
+      matches && {
+        var otherNonZeros = 0
+        mat.foreachActive { (_, _, value) => if (value != 0) otherNonZeros += 1 }
+        nonZeros == otherNonZeros
+      }
+    case _ => false
+  }
 }
 
 @DeveloperApi
@@ -814,6 +873,14 @@ object Matrices {
     }
   }
 
+  /**
+   * Check equality between a dense and a sparse matrix.
+   * Delegates to the `equals` of `Matrix` so there is a single definition of matrix equality.
+   */
+  private[mllib] def equals(denseMat: DenseMatrix, sparseMat: SparseMatrix): Boolean = {
+    denseMat.equals(sparseMat)
+  }
+
   /**
    * Generate a `Matrix` consisting of zeros.
    * @param numRows number of rows of the matrix
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index 8dbb70f5d1c4c..b92f3f2962054 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -455,4 +455,65 @@ class MatricesSuite extends SparkFunSuite {
     lines = mat.toString(5, 100).lines.toArray
     assert(lines.size == 5 && lines.forall(_.size <= 100))
   }
+
+  test("equals") {
+    // A == (A.T).T
+    val dm = new DenseMatrix(2, 2, Array(0.0, 1.4, 0.3, 3.5))
+    val dm2 = new DenseMatrix(2, 3, Array(0.0, 1.4, 2.9, 0.0, 0.0, 0.0))
+    val dmt = new DenseMatrix(2, 2, Array(0.0, 0.3, 1.4, 3.5)).transpose
+    val dmt2 = new DenseMatrix(3, 2, Array(0.0, 2.9, 0.0, 1.4, 0.0, 0.0)).transpose
+    assert(dm == dmt)
+    assert(dm2 == dmt2)
+    assert(dm != dmt2)
+    assert(dm2 != dmt)
+
+    // Check that dense matrix == corresponding sparse matrix.
+    val sm = new SparseMatrix(2, 2, Array(0, 1, 3), Array(1, 0, 1), Array(1.4, 0.3, 3.5))
+    val sm2 = new SparseMatrix(2, 3, Array(0, 1, 2, 2), Array(1, 0), Array(1.4, 2.9))
+    val sm3 = new SparseMatrix(2, 2, Array(0, 1, 3), Array(1, 0, 1), Array(1.4, 0.9, 3.5))
+    val sm4 = new SparseMatrix(2, 3, Array(0, 1, 2, 2), Array(1, 0), Array(1.4, 2.2))
+    assert(dm == sm)
+    assert(dm2 == sm2)
+    assert(dm != sm3)
+    assert(dm2 != sm4)
+
+    // Check that dense matrix == corresponding CSR matrix (a transposed CSC matrix).
+    val csr = new SparseMatrix(2, 2, Array(0, 1, 3), Array(1, 0, 1), Array(0.3, 1.4, 3.5))
+    val csr2 = new SparseMatrix(3, 2, Array(0, 1, 2), Array(1, 0), Array(2.9, 1.4))
+    assert(dm == csr.transpose)
+    assert(dm2 == csr2.transpose)
+    assert(dm != csr2.transpose)
+    assert(dm2 != csr.transpose)
+
+    // Check equality between csr and csc matrices
+    assert(sm == csr.transpose)
+    assert(sm2 == csr2.transpose)
+    assert(sm != csr2.transpose)
+    assert(sm2 != csr.transpose)
+
+    // Entries stored by only one of the two matrices make them unequal.
+    val smEmpty = new SparseMatrix(2, 2, Array(0, 0, 0), Array.empty[Int], Array.empty[Double])
+    assert(dm != smEmpty)
+    assert(smEmpty != dm)
+    assert(sm != smEmpty)
+
+    // Explicitly stored zeros do not affect equality.
+    val smZero = new SparseMatrix(
+      2, 2, Array(0, 2, 4), Array(0, 1, 0, 1), Array(0.0, 1.4, 0.3, 3.5))
+    assert(sm == smZero)
+    assert(dm == smZero)
+
+    Seq(dm, dm2, sm, sm2).foreach { mat => assert(mat == mat.asInstanceOf[Matrix]) }
+  }
+
+  test("hashCode is consistent with equals") {
+    val dm = new DenseMatrix(2, 2, Array(0.0, 1.4, 0.3, 3.5))
+    val dmt = new DenseMatrix(2, 2, Array(0.0, 0.3, 1.4, 3.5)).transpose
+    val sm = new SparseMatrix(2, 2, Array(0, 1, 3), Array(1, 0, 1), Array(1.4, 0.3, 3.5))
+    val csr = new SparseMatrix(2, 2, Array(0, 1, 3), Array(1, 0, 1), Array(0.3, 1.4, 3.5))
+    Seq(dmt, sm, csr.transpose).foreach { mat =>
+      assert(dm == mat)
+      assert(dm.hashCode == mat.hashCode)
+    }
+  }
 }