Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -425,22 +425,27 @@ class BlockMatrix @Since("1.3.0") (
*/
private[distributed] def simulateMultiply(
other: BlockMatrix,
partitioner: GridPartitioner): (BlockDestinations, BlockDestinations) = {
val leftMatrix = blockInfo.keys.collect() // blockInfo should already be cached
val rightMatrix = other.blocks.keys.collect()
partitioner: GridPartitioner,
midDimSplitNum: Int): (BlockDestinations, BlockDestinations) = {
val leftMatrix = blockInfo.keys.collect()
val rightMatrix = other.blockInfo.keys.collect()

val rightCounterpartsHelper = rightMatrix.groupBy(_._1).mapValues(_.map(_._2))
val leftDestinations = leftMatrix.map { case (rowIndex, colIndex) =>
val rightCounterparts = rightCounterpartsHelper.getOrElse(colIndex, Array.empty[Int])
val partitions = rightCounterparts.map(b => partitioner.getPartition((rowIndex, b)))
((rowIndex, colIndex), partitions.toSet)
val midDimSplitIndex = colIndex % midDimSplitNum
((rowIndex, colIndex),
partitions.toSet.map((pid: Int) => pid * midDimSplitNum + midDimSplitIndex))
}.toMap

val leftCounterpartsHelper = leftMatrix.groupBy(_._2).mapValues(_.map(_._1))
val rightDestinations = rightMatrix.map { case (rowIndex, colIndex) =>
val leftCounterparts = leftCounterpartsHelper.getOrElse(rowIndex, Array.empty[Int])
val partitions = leftCounterparts.map(b => partitioner.getPartition((b, colIndex)))
((rowIndex, colIndex), partitions.toSet)
val midDimSplitIndex = rowIndex % midDimSplitNum
((rowIndex, colIndex),
partitions.toSet.map((pid: Int) => pid * midDimSplitNum + midDimSplitIndex))
}.toMap

(leftDestinations, rightDestinations)
Expand All @@ -459,14 +464,39 @@ class BlockMatrix @Since("1.3.0") (
*/
@Since("1.3.0")
def multiply(other: BlockMatrix): BlockMatrix = {
multiply(other, 1)
}

/**
* Left multiplies this [[BlockMatrix]] to `other`, another [[BlockMatrix]]. The `colsPerBlock`
* of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains
* `SparseMatrix`, they will have to be converted to a `DenseMatrix`. The output
* [[BlockMatrix]] will only consist of blocks of `DenseMatrix`. This may cause
* some performance issues until support for multiplying two sparse matrices is added.
* Blocks with duplicate indices will be added with each other.
*
* @param other Matrix `B` in `A * B = C`
* @param numMidDimSplits Number of splits to cut on the middle dimension when doing
* multiplication. For example, when multiplying a Matrix `A` of
* size `m x n` with Matrix `B` of size `n x k`, this parameter
* configures the parallelism to use when grouping the matrices. The
* parallelism will increase from `m x k` to `m x k x numMidDimSplits`,
* which in some cases also reduces total shuffled data.
*/
@Since("2.2.0")
def multiply(
other: BlockMatrix,
numMidDimSplits: Int): BlockMatrix = {
require(numCols() == other.numRows(), "The number of columns of A and the number of rows " +
s"of B must be equal. A.numCols: ${numCols()}, B.numRows: ${other.numRows()}. If you " +
"think they should be equal, try setting the dimensions of A and B explicitly while " +
"initializing them.")
require(numMidDimSplits > 0, "numMidDimSplits should be a positive integer.")
if (colsPerBlock == other.rowsPerBlock) {
val resultPartitioner = GridPartitioner(numRowBlocks, other.numColBlocks,
math.max(blocks.partitions.length, other.blocks.partitions.length))
val (leftDestinations, rightDestinations) = simulateMultiply(other, resultPartitioner)
val (leftDestinations, rightDestinations)
= simulateMultiply(other, resultPartitioner, numMidDimSplits)
// Each block of A must be multiplied with the corresponding blocks in the columns of B.
val flatA = blocks.flatMap { case ((blockRowIndex, blockColIndex), block) =>
val destinations = leftDestinations.getOrElse((blockRowIndex, blockColIndex), Set.empty)
Expand All @@ -477,7 +507,11 @@ class BlockMatrix @Since("1.3.0") (
val destinations = rightDestinations.getOrElse((blockRowIndex, blockColIndex), Set.empty)
destinations.map(j => (j, (blockRowIndex, blockColIndex, block)))
}
val newBlocks = flatA.cogroup(flatB, resultPartitioner).flatMap { case (pId, (a, b)) =>
val intermediatePartitioner = new Partitioner {
override def numPartitions: Int = resultPartitioner.numPartitions * numMidDimSplits
override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}
val newBlocks = flatA.cogroup(flatB, intermediatePartitioner).flatMap { case (pId, (a, b)) =>
a.flatMap { case (leftRowIndex, leftColIndex, leftBlock) =>
b.filter(_._1 == leftColIndex).map { case (rightRowIndex, rightColIndex, rightBlock) =>
val C = rightBlock match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,15 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(sparseBM.subtract(sparseBM).toBreeze() === sparseBM.subtract(denseBM).toBreeze())
}

def testMultiply(A: BlockMatrix, B: BlockMatrix, expectedResult: Matrix,
numMidDimSplits: Int): Unit = {
val C = A.multiply(B, numMidDimSplits)
val localC = C.toLocalMatrix()
assert(C.numRows() === A.numRows())
assert(C.numCols() === B.numCols())
assert(localC ~== expectedResult absTol 1e-8)
}

test("multiply") {
// identity matrix
val blocks: Seq[((Int, Int), Matrix)] = Seq(
Expand Down Expand Up @@ -302,12 +311,13 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
// Try it with increased number of partitions
val largeA = new BlockMatrix(sc.parallelize(largerAblocks, 10), 6, 4)
val largeB = new BlockMatrix(sc.parallelize(largerBblocks, 8), 4, 4)
val largeC = largeA.multiply(largeB)
val localC = largeC.toLocalMatrix()

val result = largeA.toLocalMatrix().multiply(largeB.toLocalMatrix().asInstanceOf[DenseMatrix])
assert(largeC.numRows() === largeA.numRows())
assert(largeC.numCols() === largeB.numCols())
assert(localC ~== result absTol 1e-8)

testMultiply(largeA, largeB, result, 1)
testMultiply(largeA, largeB, result, 2)
testMultiply(largeA, largeB, result, 3)
testMultiply(largeA, largeB, result, 4)
}

test("simulate multiply") {
Expand All @@ -318,7 +328,7 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
val B = new BlockMatrix(rdd, colPerPart, rowPerPart)
val resultPartitioner = GridPartitioner(gridBasedMat.numRowBlocks, B.numColBlocks,
math.max(numPartitions, 2))
val (destinationsA, destinationsB) = gridBasedMat.simulateMultiply(B, resultPartitioner)
val (destinationsA, destinationsB) = gridBasedMat.simulateMultiply(B, resultPartitioner, 1)
assert(destinationsA((0, 0)) === Set(0))
assert(destinationsA((0, 1)) === Set(2))
assert(destinationsA((1, 0)) === Set(0))
Expand Down