Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,52 @@ sealed trait Matrix extends Serializable {
* corresponding value in the matrix with type `Double`.
*/
private[spark] def foreachActive(f: (Int, Int, Double) => Unit)

/**
 * Hash code derived from the matrix dimensions and its non-zero active entries.
 *
 * The accumulator is seeded with `numRows` and `numCols`; then, for every entry visited by
 * `foreachActive` whose value is not 0, the row index, the column index and the folded bits
 * of the value are mixed in with the usual 31-multiplier scheme.
 *
 * NOTE(review): the fold is order-sensitive. If `foreachActive` visits entries in a different
 * order for different storage layouts (dense vs. sparse, transposed vs. not), two matrices that
 * compare equal could produce different hashes -- TODO confirm against the implementations.
 *
 * @return the computed hash code
 */
override def hashCode(): Int = {
var result: Int = 31 + numRows
result = 31 * result + numCols
this.foreachActive { case (rowInd, colInd, value) =>
// Ignore explicit 0 entries so that the hash does not depend on whether a zero is stored.
if (value != 0) {
result = 31 * result + rowInd
result = 31 * result + colInd
// Fold the 64-bit representation of the double into 32 bits (high half XOR low half),
// the same scheme used by java.lang.Double.hashCode / java.util.Arrays.hashCode.
val bits = java.lang.Double.doubleToLongBits(value)
result = 31 * result + (bits ^ (bits >>> 32)).toInt
}
}
result
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hash isTransposed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this is be done? Even if isTransposed is not same, two matrices can be equal.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please verify this?
There might be cases where a CSR matrix is the same as a CSC matrix. See (https://github.com/apache/spark/pull/5081/files#diff-8fbb9a5e1adf997a37f4d05521b8a1acR488)
If I hash isTransposed I would give different hashes for both of these, which is not desirable.

}

/**
 * Value-based equality between matrices.
 *
 * Anything that is not a `Matrix` is unequal. Two matrices with different `numRows` or
 * `numCols` are unequal. Otherwise the comparison is dispatched on the concrete types:
 * dense/dense compares the `toArray` contents, sparse/sparse compares either entry by entry
 * (different `isTransposed` flags) or structure by structure (same flag), and the mixed
 * dense/sparse cases delegate to `Matrices.equals`.
 *
 * @param other the object to compare with
 * @return true if `other` is considered equal to this matrix
 */
override def equals(other: Any): Boolean = {
other match {
case mat: Matrix =>
// Shape mismatch can never be equal; bail out before looking at any entry.
if (mat.numRows != this.numRows || mat.numCols != this.numCols) return false
(this, mat) match {
case (dm1: DenseMatrix, dm2: DenseMatrix) =>
// Element-wise comparison of the two arrays returned by toArray.
Arrays.equals(dm1.toArray, dm2.toArray)
case (sm1: SparseMatrix, sm2: SparseMatrix) =>
// For the case in which one matrix is CSC and the other is CSR
// the values, colPtrs and rowIndices need not be the same.
// When both matrices are of the same type, it is sufficient to check that
// the values, colPtrs and rowIndices are the same.
if (sm1.isTransposed != sm2.isTransposed) {
// NOTE(review): this compares the number of stored values, so two matrices that differ
// only by an explicitly stored zero are reported as unequal, unlike hashCode, which
// skips zero values -- TODO confirm this is intended.
if (sm1.values.length != sm2.values.length) return false
sm1.foreachActive {
// `return` inside this closure is a non-local return out of equals.
case (i, j, value) => if (value != sm2(i, j)) return false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sm2(i, j) is quite expensive. Could you leave a TODO here to reduce the cost?

}
} else {
// NOTE(review): if values, colPtrs and rowIndices are Arrays (as the constructor calls in
// the test suite suggest), `!=` compares references rather than contents, so two distinct
// but identical sparse matrices would be reported as unequal -- TODO confirm and switch to
// a content comparison such as java.util.Arrays.equals.
if (sm1.values != sm2.values) return false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not correct because one matrix may contain explicit zeros. Please include this case in the unit test.

if (sm1.colPtrs != sm2.colPtrs) return false
if (sm1.rowIndices != sm2.rowIndices) return false
}
true
// Mixed cases: always pass the dense matrix first and the sparse matrix second.
case (dm1: DenseMatrix, sm1: SparseMatrix) => Matrices.equals(dm1, sm1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sm1 -> sm2?

case (sm1: SparseMatrix, dm1: DenseMatrix) => Matrices.equals(dm1, sm1)
}
case _ => false
}
}
}

@DeveloperApi
Expand Down Expand Up @@ -814,6 +860,16 @@ object Matrices {
}
}

/**
 * Check equality between a dense and a sparse matrix.
 *
 * The matrices are equal when they have the same shape and hold the same value at every
 * position. Entries that the sparse matrix does not store are treated as zero, so a dense
 * matrix with a non-zero value at such a position is NOT equal to the sparse matrix.
 *
 * @param denseMat the dense operand
 * @param sparseMat the sparse operand
 * @return true if both matrices have identical dimensions and identical entries
 */
private[mllib] def equals(denseMat: DenseMatrix, sparseMat: SparseMatrix): Boolean = {
  if (denseMat.numRows != sparseMat.numRows || denseMat.numCols != sparseMat.numCols) {
    false
  } else {
    // A flag is used instead of `return` inside the closure to avoid a non-local return.
    var allActiveMatch = true
    var matchedNonZeros = 0
    sparseMat.foreachActive { (row, col, value) =>
      if (allActiveMatch) {
        if (value != denseMat(row, col)) {
          allActiveMatch = false
        } else if (value != 0.0) {
          matchedNonZeros += 1
        }
      }
    }
    // Visiting only the sparse active entries cannot detect a non-zero dense entry at a
    // position the sparse matrix does not store. Every matched non-zero sits on a distinct
    // non-zero dense entry, so the counts agree exactly when no dense non-zero is left over.
    allActiveMatch && matchedNonZeros == denseMat.toArray.count(_ != 0.0)
  }
}

/**
* Generate a `Matrix` consisting of zeros.
* @param numRows number of rows of the matrix
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,4 +455,43 @@ class MatricesSuite extends SparkFunSuite {
lines = mat.toString(5, 100).lines.toArray
assert(lines.size == 5 && lines.forall(_.size <= 100))
}

test("equals") {
  // Dense fixtures: each matrix is also built a second time as the transpose of its
  // explicitly written-out transpose, so that A == (A.T).T can be checked.
  val denseA = new DenseMatrix(2, 2, Array(0.0, 1.4, 0.3, 3.5))
  val denseB = new DenseMatrix(2, 3, Array(0.0, 1.4, 2.9, 0.0, 0.0, 0.0))
  val denseAFromT = new DenseMatrix(2, 2, Array(0.0, 0.3, 1.4, 3.5)).transpose
  val denseBFromT = new DenseMatrix(3, 2, Array(0.0, 2.9, 0.0, 1.4, 0.0, 0.0)).transpose
  assert(denseA == denseAFromT)
  assert(denseB == denseBFromT)
  assert(denseA != denseBFromT)
  assert(denseB != denseAFromT)

  // Dense vs. sparse: sparseA / sparseB hold the same entries as denseA / denseB, while the
  // "Off" variants differ from them in exactly one stored value.
  val sparseA = new SparseMatrix(2, 2, Array(0, 1, 3), Array(1, 0, 1), Array(1.4, 0.3, 3.5))
  val sparseB = new SparseMatrix(2, 3, Array(0, 1, 2, 2), Array(1, 0), Array(1.4, 2.9))
  val sparseAOff = new SparseMatrix(2, 2, Array(0, 1, 3), Array(1, 0, 1), Array(1.4, 0.9, 3.5))
  val sparseBOff = new SparseMatrix(2, 3, Array(0, 1, 2, 2), Array(1, 0), Array(1.4, 2.2))
  assert(denseA == sparseA)
  assert(denseB == sparseB)
  assert(denseA != sparseAOff)
  assert(denseB != sparseBOff)

  // Dense vs. transposed sparse: transposing these matrices yields the CSR-backed
  // counterparts of denseA and denseB.
  val sparseAT = new SparseMatrix(2, 2, Array(0, 1, 3), Array(1, 0, 1), Array(0.3, 1.4, 3.5))
  val sparseBT = new SparseMatrix(3, 2, Array(0, 1, 2), Array(1, 0), Array(2.9, 1.4))
  assert(denseA == sparseAT.transpose)
  assert(denseB == sparseBT.transpose)
  assert(denseA != sparseBT.transpose)
  assert(denseB != sparseAT.transpose)

  // Sparse vs. transposed sparse (CSC against CSR).
  assert(sparseA == sparseAT.transpose)
  assert(sparseB == sparseBT.transpose)
  assert(sparseA != sparseBT.transpose)
  assert(sparseB != sparseAT.transpose)

  // Every matrix equals itself when viewed through the Matrix trait.
  for (m <- Seq(denseA, denseB, sparseA, sparseB)) {
    assert(m == m.asInstanceOf[Matrix])
  }

}
}