mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -22,10 +22,11 @@ import java.util.{Arrays, Random}
import scala.collection.mutable.{ArrayBuffer, ArrayBuilder => MArrayBuilder, HashSet => MHashSet}

import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}
import com.github.fommil.netlib.BLAS.{getInstance => blas}

import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._

@@ -58,6 +59,20 @@ sealed trait Matrix extends Serializable {
    newArray
  }

  /**
   * Returns an iterator of column vectors.
   * This operation could be expensive, depending on the underlying storage.
   */
  @Since("2.0.0")
  def colIter: Iterator[Vector]

  /**
   * Returns an iterator of row vectors.
   * This operation could be expensive, depending on the underlying storage.
   */
  @Since("2.0.0")
  def rowIter: Iterator[Vector] = this.transpose.colIter

  /** Converts to a breeze matrix. */
  private[mllib] def toBreeze: BM[Double]

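The two new trait members give callers a storage-agnostic way to traverse a local matrix. A minimal usage sketch (hypothetical driver code, not part of this patch, assuming spark-mllib on the classpath):

```scala
import org.apache.spark.mllib.linalg.Matrices

// A 2 x 2 column-major dense matrix: columns are (1, 2) and (3, 4).
val m = Matrices.dense(2, 2, Array(1.0, 2.0, 3.0, 4.0))

// colIter walks columns directly; rowIter delegates to this.transpose.colIter,
// so both carry the "could be expensive" caveat from the scaladoc.
m.colIter.zipWithIndex.foreach { case (col, j) => println(s"col $j: $col") }
m.rowIter.zipWithIndex.foreach { case (row, i) => println(s"row $i: $row") }
```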
@@ -386,6 +401,21 @@ class DenseMatrix @Since("1.3.0") (
    }
    new SparseMatrix(numRows, numCols, colPtrs, rowIndices.result(), spVals.result())
  }

  @Since("2.0.0")
  override def colIter: Iterator[Vector] = {
    if (isTransposed) {
      Iterator.tabulate(numCols) { j =>
        val col = new Array[Double](numRows)
        blas.dcopy(numRows, values, j, numCols, col, 0, 1)
        new DenseVector(col)
      }
    } else {
      Iterator.tabulate(numCols) { j =>
        new DenseVector(values.slice(j * numRows, (j + 1) * numRows))
      }
    }
  }
}

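In the transposed case the backing array is row-major, so column j lives at the strided offsets j, j + numCols, j + 2 * numCols, and so on; the blas.dcopy call gathers exactly that stride into a fresh array. A plain-Scala sketch of the same gather (hypothetical helper, for illustration only):

```scala
// Equivalent of blas.dcopy(numRows, values, j, numCols, col, 0, 1):
// copy numRows elements starting at offset j with stride numCols.
def gatherColumn(values: Array[Double], numRows: Int, numCols: Int, j: Int): Array[Double] = {
  val col = new Array[Double](numRows)
  var i = 0
  while (i < numRows) {
    col(i) = values(j + i * numCols) // element (i, j) of the row-major data
    i += 1
  }
  col
}
```

In the non-transposed (column-major) branch a column is already contiguous, so a simple values.slice suffices.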
/**
@@ -656,6 +686,38 @@ class SparseMatrix @Since("1.3.0") (
@Since("1.5.0")
override def numActives: Int = values.length

@Since("2.0.0")
override def colIter: Iterator[Vector] = {
if (isTransposed) {
val indicesArray = Array.fill(numCols)(MArrayBuilder.make[Int])
val valuesArray = Array.fill(numCols)(MArrayBuilder.make[Double])
var i = 0
while (i < numRows) {
var k = colPtrs(i)
val rowEnd = colPtrs(i + 1)
while (k < rowEnd) {
val j = rowIndices(k)
indicesArray(j) += i
valuesArray(j) += values(k)
k += 1
}
i += 1
}
Iterator.tabulate(numCols) { j =>
val ii = indicesArray(j).result()
val vv = valuesArray(j).result()
new SparseVector(numRows, ii, vv)
}
} else {
Iterator.tabulate(numCols) { j =>
val colStart = colPtrs(j)
val colEnd = colPtrs(j + 1)
val ii = rowIndices.slice(colStart, colEnd)
val vv = values.slice(colStart, colEnd)
new SparseVector(numRows, ii, vv)
}
}
}
}

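When a SparseMatrix is transposed its arrays are effectively CSR (colPtrs holds row pointers and rowIndices holds column indices), so the branch above performs a one-pass scatter of CSR entries into per-column builders; because rows are scanned in increasing order, each column's indices come out sorted, as SparseVector expects. A standalone sketch of the same idea (hypothetical names, plain CSR inputs assumed):

```scala
import scala.collection.mutable.ArrayBuffer

// Scatter CSR storage into per-column (indices, values) pairs in one pass.
def csrToColumns(
    numRows: Int,
    numCols: Int,
    rowPtrs: Array[Int],     // length numRows + 1
    colIndices: Array[Int],  // column index of each stored entry
    values: Array[Double]): IndexedSeq[(Array[Int], Array[Double])] = {
  val ind = Array.fill(numCols)(ArrayBuffer.empty[Int])
  val vs = Array.fill(numCols)(ArrayBuffer.empty[Double])
  var i = 0
  while (i < numRows) {
    var k = rowPtrs(i)
    while (k < rowPtrs(i + 1)) {
      ind(colIndices(k)) += i        // rows scanned in order => sorted per column
      vs(colIndices(k)) += values(k)
      k += 1
    }
    i += 1
  }
  (0 until numCols).map(j => (ind(j).toArray, vs(j).toArray))
}
```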
/**
mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -494,4 +494,17 @@ class MatricesSuite extends SparkFunSuite {
    assert(sm1.numNonzeros === 1)
    assert(sm1.numActives === 3)
  }

  test("row/col iterator") {
    val dm = new DenseMatrix(3, 2, Array(0, 1, 2, 3, 4, 0))
    val sm = dm.toSparse
    val rows = Seq(Vectors.dense(0, 3), Vectors.dense(1, 4), Vectors.dense(2, 0))
    val cols = Seq(Vectors.dense(0, 1, 2), Vectors.dense(3, 4, 0))
    for (m <- Seq(dm, sm)) {
      assert(m.rowIter.toSeq === rows)
      assert(m.colIter.toSeq === cols)
      assert(m.transpose.rowIter.toSeq === cols)
      assert(m.transpose.colIter.toSeq === rows)
    }
  }
}
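For reference, the fixture is column-major: Array(0, 1, 2, 3, 4, 0) builds the 3 x 2 matrix with columns (0, 1, 2) and (3, 4, 0), hence the expected rows (0, 3), (1, 4), (2, 0). The sparse variant's iterators yield SparseVectors, which still compare equal to the dense expectations because mllib Vector equality is representation-agnostic.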
4 changes: 4 additions & 0 deletions project/MimaExcludes.scala
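The excludes below exist because MiMa checks the new binaries against the previous release; the freshly added Matrix.rowIter and Matrix.colIter would otherwise be flagged as binary-compatibility problems, so they are filtered explicitly, with a comment pointing at the JIRA ticket.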
@@ -531,6 +531,10 @@ object MimaExcludes {
        ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.scheduler.SparkListener.onOtherEvent"),
        ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.CreatableRelationProvider.createRelation"),
        ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.InsertableRelation.insert")
      ) ++ Seq(
        // SPARK-13927: add row/column iterator to local matrices
        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.rowIter"),
        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.colIter")
      )
    case v if v.startsWith("1.6") =>
      Seq(