Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ sealed trait Vector extends Serializable {
def copy: Vector = {
throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.")
}

/**
* Applies a function `f` to all the active elements of dense and sparse vector.
*
* @param f the function takes two parameters where the first parameter is the index of
* the vector with type `Int`, and the second parameter is the corresponding value
* with type `Double`.
*/
private[spark] def foreachActive(f: (Int, Double) => Unit)
}

/**
Expand Down Expand Up @@ -273,6 +282,17 @@ class DenseVector(val values: Array[Double]) extends Vector {
override def copy: DenseVector = {
new DenseVector(values.clone())
}

private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
var i = 0
val localValuesSize = values.size
val localValues = values

while (i < localValuesSize) {
f(i, localValues(i))
i += 1
}
}
}

/**
Expand Down Expand Up @@ -309,4 +329,16 @@ class SparseVector(
}

private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)

private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
var i = 0
val localValuesSize = values.size
val localIndices = indices
val localValues = values

while (i < localValuesSize) {
f(localIndices(i), localValues(i))
i += 1
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@

package org.apache.spark.mllib.stat

import breeze.linalg.{DenseVector => BDV}

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector}
import org.apache.spark.mllib.linalg.{Vectors, Vector}

/**
* :: DeveloperApi ::
Expand All @@ -40,37 +38,14 @@ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector
class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable {

private var n = 0
private var currMean: BDV[Double] = _
private var currM2n: BDV[Double] = _
private var currM2: BDV[Double] = _
private var currL1: BDV[Double] = _
private var currMean: Array[Double] = _
private var currM2n: Array[Double] = _
private var currM2: Array[Double] = _
private var currL1: Array[Double] = _
private var totalCnt: Long = 0
private var nnz: BDV[Double] = _
private var currMax: BDV[Double] = _
private var currMin: BDV[Double] = _

/**
* Adds input value to position i.
*/
private[this] def add(i: Int, value: Double) = {
if (value != 0.0) {
if (currMax(i) < value) {
currMax(i) = value
}
if (currMin(i) > value) {
currMin(i) = value
}

val prevMean = currMean(i)
val diff = value - prevMean
currMean(i) = prevMean + diff / (nnz(i) + 1.0)
currM2n(i) += (value - currMean(i)) * diff
currM2(i) += value * value
currL1(i) += math.abs(value)

nnz(i) += 1.0
}
}
private var nnz: Array[Double] = _
private var currMax: Array[Double] = _
private var currMin: Array[Double] = _

/**
* Add a new sample to this summarizer, and update the statistical summary.
Expand All @@ -83,33 +58,36 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(sample.size > 0, s"Vector should have dimension larger than zero.")
n = sample.size

currMean = BDV.zeros[Double](n)
currM2n = BDV.zeros[Double](n)
currM2 = BDV.zeros[Double](n)
currL1 = BDV.zeros[Double](n)
nnz = BDV.zeros[Double](n)
currMax = BDV.fill(n)(Double.MinValue)
currMin = BDV.fill(n)(Double.MaxValue)
currMean = Array.ofDim[Double](n)
currM2n = Array.ofDim[Double](n)
currM2 = Array.ofDim[Double](n)
currL1 = Array.ofDim[Double](n)
nnz = Array.ofDim[Double](n)
currMax = Array.fill[Double](n)(Double.MinValue)
currMin = Array.fill[Double](n)(Double.MaxValue)
}

require(n == sample.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $n but got ${sample.size}.")

sample match {
case dv: DenseVector => {
var j = 0
while (j < dv.size) {
add(j, dv.values(j))
j += 1
sample.foreachActive { (index, value) =>
if (value != 0.0) {
if (currMax(index) < value) {
currMax(index) = value
}
}
case sv: SparseVector =>
var j = 0
while (j < sv.indices.size) {
add(sv.indices(j), sv.values(j))
j += 1
if (currMin(index) > value) {
currMin(index) = value
}
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)

val prevMean = currMean(index)
val diff = value - prevMean
currMean(index) = prevMean + diff / (nnz(index) + 1.0)
currM2n(index) += (value - currMean(index)) * diff
currM2(index) += value * value
currL1(index) += math.abs(value)

nnz(index) += 1.0
}
}

totalCnt += 1
Expand Down Expand Up @@ -152,34 +130,34 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
}
} else if (totalCnt == 0 && other.totalCnt != 0) {
this.n = other.n
this.currMean = other.currMean.copy
this.currM2n = other.currM2n.copy
this.currM2 = other.currM2.copy
this.currL1 = other.currL1.copy
this.currMean = other.currMean.clone
this.currM2n = other.currM2n.clone
this.currM2 = other.currM2.clone
this.currL1 = other.currL1.clone
this.totalCnt = other.totalCnt
this.nnz = other.nnz.copy
this.currMax = other.currMax.copy
this.currMin = other.currMin.copy
this.nnz = other.nnz.clone
this.currMax = other.currMax.clone
this.currMin = other.currMin.clone
}
this
}

override def mean: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")

val realMean = BDV.zeros[Double](n)
val realMean = Array.ofDim[Double](n)
var i = 0
while (i < n) {
realMean(i) = currMean(i) * (nnz(i) / totalCnt)
i += 1
}
Vectors.fromBreeze(realMean)
Vectors.dense(realMean)
}

override def variance: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")

val realVariance = BDV.zeros[Double](n)
val realVariance = Array.ofDim[Double](n)

val denominator = totalCnt - 1.0

Expand All @@ -194,16 +172,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
i += 1
}
}

Vectors.fromBreeze(realVariance)
Vectors.dense(realVariance)
}

override def count: Long = totalCnt

override def numNonzeros: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")

Vectors.fromBreeze(nnz)
Vectors.dense(nnz)
}

override def max: Vector = {
Expand All @@ -214,7 +191,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
i += 1
}
Vectors.fromBreeze(currMax)
Vectors.dense(currMax)
}

override def min: Vector = {
Expand All @@ -225,25 +202,25 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
i += 1
}
Vectors.fromBreeze(currMin)
Vectors.dense(currMin)
}

override def normL2: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")

val realMagnitude = BDV.zeros[Double](n)
val realMagnitude = Array.ofDim[Double](n)

var i = 0
while (i < currM2.size) {
realMagnitude(i) = math.sqrt(currM2(i))
i += 1
}

Vectors.fromBreeze(realMagnitude)
Vectors.dense(realMagnitude)
}

override def normL1: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
Vectors.fromBreeze(currL1)

Vectors.dense(currL1)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -173,4 +173,28 @@ class VectorsSuite extends FunSuite {
val v = Vectors.fromBreeze(x(::, 0))
assert(v.size === x.rows)
}

test("foreachActive") {
val dv = Vectors.dense(0.0, 1.2, 3.1, 0.0)
val sv = Vectors.sparse(4, Seq((1, 1.2), (2, 3.1), (3, 0.0)))

val dvMap = scala.collection.mutable.Map[Int, Double]()
dv.foreachActive { (index, value) =>
dvMap.put(index, value)
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling foreach without parenthesis

dv.foreach {
  case (index: Int, value: Double) => dvMap0.put(index, value)
}

will cause

Error:(182, 16) missing parameter type for expanded function
The argument types of an anonymous function must be fully known. (SLS 8.5)
Expected type was: Boolean
    dv.foreach {
               ^

This is scala curry function overloading issue. It seems that unless we change the signature to

private[spark] def foreach(skippingZeros: Boolean = false, f: ((Int, Double)) => Unit)

we need to explicitly call it with parenthesis when we want to call it with default value of skippingZeros.

assert(dvMap.size === 4)
assert(dvMap.get(0) === Some(0.0))
assert(dvMap.get(1) === Some(1.2))
assert(dvMap.get(2) === Some(3.1))
assert(dvMap.get(3) === Some(0.0))

val svMap = scala.collection.mutable.Map[Int, Double]()
sv.foreachActive { (index, value) =>
svMap.put(index, value)
}
assert(svMap.size === 3)
assert(svMap.get(1) === Some(1.2))
assert(svMap.get(2) === Some(3.1))
assert(svMap.get(3) === Some(0.0))
}
}