Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.optim

import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
import org.netlib.util.intW

import org.apache.spark.Logging
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.rdd.RDD

/**
* Model fitted by [[WeightedLeastSquares]].
* @param coefficients model coefficients
* @param intercept model intercept
*/
private[ml] class WeightedLeastSquaresModel(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will you merge this code into current LinearRegression.scala?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be used by other algorithms like log-linear model or Lp regression. We can discuss this later.

val coefficients: DenseVector,
val intercept: Double) extends Serializable

/**
* Weighted least squares solver via normal equation.
* Given weighted observations (w,,i,,, a,,i,,, b,,i,,), we use the following weighted least squares
* formulation:
*
* min,,x,z,, 1/2 sum,,i,, w,,i,, (a,,i,,^T^ x + z - b,,i,,)^2^ / sum,,i,, w_i
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This contains the cost function solved by ALS.LeastSquaresNESolver (and duplicates the Cholesky dppsv solver); should we make a JIRA to refactor existing code to reuse this class?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SPARK-10490. Left a TODO.

* + 1/2 lambda / delta sum,,j,, (sigma,,j,, x,,j,,)^2^,
*
* where lambda is the regularization parameter, and delta and sigma,,j,, are controlled by
* [[standardizeLabel]] and [[standardizeFeatures]], respectively.
*
* Set [[regParam]] to 0.0 and turn off both [[standardizeFeatures]] and [[standardizeLabel]] to
* match R's `lm`.
* Turn on [[standardizeLabel]] to match R's `glmnet`.
*
* @param fitIntercept whether to fit intercept. If false, z is 0.0.
* @param regParam L2 regularization parameter (lambda)
* @param standardizeFeatures whether to standardize features. If true, sigma_,,j,, is the
* population standard deviation of the j-th column of A. Otherwise,
* sigma,,j,, is 1.0.
* @param standardizeLabel whether to standardize label. If true, delta is the population standard
* deviation of the label column b. Otherwise, delta is 1.0.
*/
private[ml] class WeightedLeastSquares(
val fitIntercept: Boolean,
val regParam: Double,
val standardizeFeatures: Boolean,
val standardizeLabel: Boolean) extends Logging with Serializable {
import WeightedLeastSquares._

require(regParam >= 0.0, s"regParam cannot be negative: $regParam")
if (regParam == 0.0) {
logWarning("regParam is zero, which might cause numerical instability and overfitting.")
}

/**
* Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s.
*/
def fit(instances: RDD[Instance]): WeightedLeastSquaresModel = {
val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_))
summary.validate()
logInfo(s"Number of instances: ${summary.count}.")
val triK = summary.triK
val bBar = summary.bBar
val bStd = summary.bStd
val aBar = summary.aBar
val aVar = summary.aVar
val abBar = summary.abBar
val aaBar = summary.aaBar
val aaValues = aaBar.values

if (fitIntercept) {
// shift centers
// A^T A - aBar aBar^T
RowMatrix.dspr(-1.0, aBar, aaValues)
// A^T b - bBar aBar
BLAS.axpy(-bBar, aBar, abBar)
}

// add regularization to diagonals
var i = 0
var j = 2
while (i < triK) {
var lambda = regParam
if (standardizeFeatures) {
lambda *= aVar(j - 2)
}
if (standardizeLabel) {
// TODO: handle the case when bStd = 0
lambda /= bStd
}
aaValues(i) += lambda
i += j
j += 1
}

val x = choleskySolve(aaBar.values, abBar)

// compute intercept
val intercept = if (fitIntercept) {
bBar - BLAS.dot(aBar, x)
} else {
0.0
}

new WeightedLeastSquaresModel(x, intercept)
}

/**
* Solves a symmetric positive definite linear system via Cholesky factorization.
* The input arguments are modified in-place to store the factorization and the solution.
* @param A the upper triangular part of A
* @param bx right-hand side
* @return the solution vector
*/
// TODO: SPARK-10490 - consolidate this and the Cholesky solver in ALS
private def choleskySolve(A: Array[Double], bx: DenseVector): DenseVector = {
val k = bx.size
val info = new intW(0)
lapack.dppsv("U", k, 1, A, bx.values, k, info)
val code = info.`val`
assert(code == 0, s"lapack.dpotrs returned $code.")
bx
}
}

private[ml] object WeightedLeastSquares {

/**
* Case class for weighted observations.
* @param w weight, must be positive
* @param a features
* @param b label
*/
case class Instance(w: Double, a: Vector, b: Double) {
require(w >= 0.0, s"Weight cannot be negative: $w.")
}

/**
* Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]].
*/
// TODO: consolidate aggregates for summary statistics
private class Aggregator extends Serializable {
var initialized: Boolean = false
var k: Int = _
var count: Long = _
var triK: Int = _
private var wSum: Double = _
private var wwSum: Double = _
private var bSum: Double = _
private var bbSum: Double = _
private var aSum: DenseVector = _
private var abSum: DenseVector = _
private var aaSum: DenseVector = _

private def init(k: Int): Unit = {
require(k <= 4096, "In order to take the normal equation approach efficiently, " +
s"we set the max number of features to 4096 but got $k.")
this.k = k
triK = k * (k + 1) / 2
count = 0L
wSum = 0.0
wwSum = 0.0
bSum = 0.0
bbSum = 0.0
aSum = new DenseVector(Array.ofDim(k))
abSum = new DenseVector(Array.ofDim(k))
aaSum = new DenseVector(Array.ofDim(triK))
initialized = true
}

/**
* Adds an instance.
*/
def add(instance: Instance): this.type = {
val Instance(w, a, b) = instance
val ak = a.size
if (!initialized) {
init(ak)
initialized = true
}
assert(ak == k, s"Dimension mismatch. Expect vectors of size $k but got $ak.")
count += 1L
wSum += w
wwSum += w * w
bSum += w * b
bbSum += w * b * b
BLAS.axpy(w, a, aSum)
BLAS.axpy(w * b, a, abSum)
RowMatrix.dspr(w, a, aaSum.values)
this
}

/**
* Merges another [[Aggregator]].
*/
def merge(other: Aggregator): this.type = {
if (!other.initialized) {
this
} else {
if (!initialized) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is not initialized but other is, can we just return other?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The contract of merge in Spark is that the first argument is mutable but not the second. See https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/rdd/RDD.scala#L1068.

init(other.k)
}
assert(k == other.k, s"dimension mismatch: this.k = $k but other.k = ${other.k}")
count += other.count
wSum += other.wSum
wwSum += other.wwSum
bSum += other.bSum
bbSum += other.bbSum
BLAS.axpy(1.0, other.aSum, aSum)
BLAS.axpy(1.0, other.abSum, abSum)
BLAS.axpy(1.0, other.aaSum, aaSum)
this
}
}

/**
* Validates that we have seen observations.
*/
def validate(): Unit = {
assert(initialized, "Training dataset is empty.")
assert(wSum > 0.0, "Sum of weights cannot be zero.")
}

/**
* Weighted mean of features.
*/
def aBar: DenseVector = {
val output = aSum.copy
BLAS.scal(1.0 / wSum, output)
output
}

/**
* Weighted mean of labels.
*/
def bBar: Double = bSum / wSum

/**
* Weighted population standard deviation of labels.
*/
def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar)

/**
* Weighted mean of (label * features).
*/
def abBar: DenseVector = {
val output = abSum.copy
BLAS.scal(1.0 / wSum, output)
output
}

/**
* Weighted mean of (features * features^T^).
*/
def aaBar: DenseVector = {
val output = aaSum.copy
BLAS.scal(1.0 / wSum, output)
output
}

/**
* Weighted population variance of features.
*/
def aVar: DenseVector = {
val variance = Array.ofDim[Double](k)
var i = 0
var j = 2
val aaValues = aaSum.values
while (i < triK) {
val l = j - 2
val aw = aSum(l) / wSum
variance(l) = aaValues(i) / wSum - aw * aw
i += j
j += 1
}
new DenseVector(variance)
}
}
}
7 changes: 7 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ private[spark] object BLAS extends Serializable with Logging {
}
}

/** Y += a * x */
private[spark] def axpy(a: Double, X: DenseMatrix, Y: DenseMatrix): Unit = {
require(X.numRows == Y.numRows && X.numCols == Y.numCols, "Dimension mismatch: " +
s"size(X) = ${(X.numRows, X.numCols)} but size(Y) = ${(Y.numRows, Y.numCols)}.")
f2jBLAS.daxpy(X.numRows * X.numCols, a, X.values, 1, Y.values, 1)
}

/**
* dot(x, y)
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,8 @@ object RowMatrix {
*
* @param U the upper triangular part of the matrix packed in an array (column major)
*/
private def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
// TODO: SPARK-10491 - move this method to linalg.BLAS
private[spark] def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
// TODO: Find a better home (breeze?) for this method.
val n = v.size
v match {
Expand Down
Loading