Skip to content

Commit d4adf4e

Browse files
author
Marcelo Vanzin
committed
Merge branch 'master' into SPARK-8297
2 parents 3b262e8 + 3744b7f commit d4adf4e

File tree

50 files changed

+1050
-530
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+1050
-530
lines changed

R/pkg/R/mllib.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ setClass("PipelineModel", representation(model = "jobj"))
2727
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
2828
#'
2929
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
30-
#' operators are supported, including '~' and '+'.
30+
#' operators are supported, including '~', '+', '-', and '.'.
3131
#' @param data DataFrame for training
3232
#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
3333
#' @param lambda Regularization parameter

R/pkg/inst/tests/test_mllib.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,11 @@ test_that("predictions match with native glm", {
4040
rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
4141
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
4242
})
43+
44+
test_that("dot minus and intercept vs native glm", {
45+
training <- createDataFrame(sqlContext, iris)
46+
model <- glm(Sepal_Width ~ . - Species + 0, data = training)
47+
vals <- collect(select(predict(model, training), "prediction"))
48+
rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
49+
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
50+
})

docs/configuration.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ Apart from these, the following properties are also available, and may be useful
203203
<td><code>spark.driver.extraClassPath</code></td>
204204
<td>(none)</td>
205205
<td>
206-
Extra classpath entries to append to the classpath of the driver.
206+
Extra classpath entries to prepend to the classpath of the driver.
207207

208208
<br /><em>Note:</em> In client mode, this config must not be set through the <code>SparkConf</code>
209209
directly in your application, because the driver JVM has already started at that point.
@@ -250,7 +250,7 @@ Apart from these, the following properties are also available, and may be useful
250250
<td><code>spark.executor.extraClassPath</code></td>
251251
<td>(none)</td>
252252
<td>
253-
Extra classpath entries to append to the classpath of executors. This exists primarily for
253+
Extra classpath entries to prepend to the classpath of executors. This exists primarily for
254254
backwards-compatibility with older versions of Spark. Users typically should not need to set
255255
this option.
256256
</td>

mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala

Lines changed: 18 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,20 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
7878
/** @group getParam */
7979
def getFormula: String = $(formula)
8080

81+
/** Whether the formula specifies fitting an intercept. */
82+
private[ml] def hasIntercept: Boolean = {
83+
require(parsedFormula.isDefined, "Must call setFormula() first.")
84+
parsedFormula.get.hasIntercept
85+
}
86+
8187
override def fit(dataset: DataFrame): RFormulaModel = {
8288
require(parsedFormula.isDefined, "Must call setFormula() first.")
89+
val resolvedFormula = parsedFormula.get.resolve(dataset.schema)
8390
// StringType terms and terms representing interactions need to be encoded before assembly.
8491
// TODO(ekl) add support for feature interactions
85-
var encoderStages = ArrayBuffer[PipelineStage]()
86-
var tempColumns = ArrayBuffer[String]()
87-
val encodedTerms = parsedFormula.get.terms.map { term =>
92+
val encoderStages = ArrayBuffer[PipelineStage]()
93+
val tempColumns = ArrayBuffer[String]()
94+
val encodedTerms = resolvedFormula.terms.map { term =>
8895
dataset.schema(term) match {
8996
case column if column.dataType == StringType =>
9097
val indexCol = term + "_idx_" + uid
@@ -103,7 +110,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
103110
.setOutputCol($(featuresCol))
104111
encoderStages += new ColumnPruner(tempColumns.toSet)
105112
val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
106-
copyValues(new RFormulaModel(uid, parsedFormula.get, pipelineModel).setParent(this))
113+
copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this))
107114
}
108115

109116
// optimistic schema; does not contain any ML attributes
@@ -124,13 +131,13 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
124131
/**
125132
* :: Experimental ::
126133
* A fitted RFormula. Fitting is required to determine the factor levels of formula terms.
127-
* @param parsedFormula a pre-parsed R formula.
134+
* @param resolvedFormula the fitted R formula.
128135
* @param pipelineModel the fitted feature model, including factor to index mappings.
129136
*/
130137
@Experimental
131138
class RFormulaModel private[feature](
132139
override val uid: String,
133-
parsedFormula: ParsedRFormula,
140+
resolvedFormula: ResolvedRFormula,
134141
pipelineModel: PipelineModel)
135142
extends Model[RFormulaModel] with RFormulaBase {
136143

@@ -144,8 +151,8 @@ class RFormulaModel private[feature](
144151
val withFeatures = pipelineModel.transformSchema(schema)
145152
if (hasLabelCol(schema)) {
146153
withFeatures
147-
} else if (schema.exists(_.name == parsedFormula.label)) {
148-
val nullable = schema(parsedFormula.label).dataType match {
154+
} else if (schema.exists(_.name == resolvedFormula.label)) {
155+
val nullable = schema(resolvedFormula.label).dataType match {
149156
case _: NumericType | BooleanType => false
150157
case _ => true
151158
}
@@ -158,12 +165,12 @@ class RFormulaModel private[feature](
158165
}
159166

160167
override def copy(extra: ParamMap): RFormulaModel = copyValues(
161-
new RFormulaModel(uid, parsedFormula, pipelineModel))
168+
new RFormulaModel(uid, resolvedFormula, pipelineModel))
162169

163-
override def toString: String = s"RFormulaModel(${parsedFormula})"
170+
override def toString: String = s"RFormulaModel(${resolvedFormula})"
164171

165172
private def transformLabel(dataset: DataFrame): DataFrame = {
166-
val labelName = parsedFormula.label
173+
val labelName = resolvedFormula.label
167174
if (hasLabelCol(dataset.schema)) {
168175
dataset
169176
} else if (dataset.schema.exists(_.name == labelName)) {
@@ -207,26 +214,3 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
207214

208215
override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)
209216
}
210-
211-
/**
212-
* Represents a parsed R formula.
213-
*/
214-
private[ml] case class ParsedRFormula(label: String, terms: Seq[String])
215-
216-
/**
217-
* Limited implementation of R formula parsing. Currently supports: '~', '+'.
218-
*/
219-
private[ml] object RFormulaParser extends RegexParsers {
220-
def term: Parser[String] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r
221-
222-
def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list }
223-
224-
def formula: Parser[ParsedRFormula] =
225-
(term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t.distinct) }
226-
227-
def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
228-
case Success(result, _) => result
229-
case failure: NoSuccess => throw new IllegalArgumentException(
230-
"Could not parse formula: " + value)
231-
}
232-
}
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.feature
19+
20+
import scala.util.parsing.combinator.RegexParsers
21+
22+
import org.apache.spark.mllib.linalg.VectorUDT
23+
import org.apache.spark.sql.types._
24+
25+
/**
26+
* Represents a parsed R formula.
27+
*/
28+
private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
29+
/**
30+
* Resolves formula terms into column names. A schema is necessary for inferring the meaning
31+
* of the special '.' term. Duplicate terms will be removed during resolution.
32+
*/
33+
def resolve(schema: StructType): ResolvedRFormula = {
34+
var includedTerms = Seq[String]()
35+
terms.foreach {
36+
case Dot =>
37+
includedTerms ++= simpleTypes(schema).filter(_ != label.value)
38+
case ColumnRef(value) =>
39+
includedTerms :+= value
40+
case Deletion(term: Term) =>
41+
term match {
42+
case ColumnRef(value) =>
43+
includedTerms = includedTerms.filter(_ != value)
44+
case Dot =>
45+
// e.g. "- .", which removes all first-order terms
46+
val fromSchema = simpleTypes(schema)
47+
includedTerms = includedTerms.filter(fromSchema.contains(_))
48+
case _: Deletion =>
49+
assert(false, "Deletion terms cannot be nested")
50+
case _: Intercept =>
51+
}
52+
case _: Intercept =>
53+
}
54+
ResolvedRFormula(label.value, includedTerms.distinct)
55+
}
56+
57+
/** Whether this formula specifies fitting with an intercept term. */
58+
def hasIntercept: Boolean = {
59+
var intercept = true
60+
terms.foreach {
61+
case Intercept(enabled) =>
62+
intercept = enabled
63+
case Deletion(Intercept(enabled)) =>
64+
intercept = !enabled
65+
case _ =>
66+
}
67+
intercept
68+
}
69+
70+
// the dot operator excludes complex column types
71+
private def simpleTypes(schema: StructType): Seq[String] = {
72+
schema.fields.filter(_.dataType match {
73+
case _: NumericType | StringType | BooleanType | _: VectorUDT => true
74+
case _ => false
75+
}).map(_.name)
76+
}
77+
}
78+
79+
/**
80+
* Represents a fully evaluated and simplified R formula.
81+
*/
82+
private[ml] case class ResolvedRFormula(label: String, terms: Seq[String])
83+
84+
/**
85+
* R formula terms. See the R formula docs here for more information:
86+
* http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
87+
*/
88+
private[ml] sealed trait Term
89+
90+
/* R formula reference to all available columns, e.g. "." in a formula */
91+
private[ml] case object Dot extends Term
92+
93+
/* R formula reference to a column, e.g. "+ Species" in a formula */
94+
private[ml] case class ColumnRef(value: String) extends Term
95+
96+
/* R formula intercept toggle, e.g. "+ 0" in a formula */
97+
private[ml] case class Intercept(enabled: Boolean) extends Term
98+
99+
/* R formula deletion of a variable, e.g. "- Species" in a formula */
100+
private[ml] case class Deletion(term: Term) extends Term
101+
102+
/**
103+
* Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.'.
104+
*/
105+
private[ml] object RFormulaParser extends RegexParsers {
106+
def intercept: Parser[Intercept] =
107+
"([01])".r ^^ { case a => Intercept(a == "1") }
108+
109+
def columnRef: Parser[ColumnRef] =
110+
"([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) }
111+
112+
def term: Parser[Term] = intercept | columnRef | "\\.".r ^^ { case _ => Dot }
113+
114+
def terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ {
115+
case op ~ list => list.foldLeft(List(op)) {
116+
case (left, "+" ~ right) => left ++ Seq(right)
117+
case (left, "-" ~ right) => left ++ Seq(Deletion(right))
118+
}
119+
}
120+
121+
def formula: Parser[ParsedRFormula] =
122+
(columnRef ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
123+
124+
def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
125+
case Success(result, _) => result
126+
case failure: NoSuccess => throw new IllegalArgumentException(
127+
"Could not parse formula: " + value)
128+
}
129+
}

mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,14 @@ private[r] object SparkRWrappers {
3232
alpha: Double): PipelineModel = {
3333
val formula = new RFormula().setFormula(value)
3434
val estimator = family match {
35-
case "gaussian" => new LinearRegression().setRegParam(lambda).setElasticNetParam(alpha)
36-
case "binomial" => new LogisticRegression().setRegParam(lambda).setElasticNetParam(alpha)
35+
case "gaussian" => new LinearRegression()
36+
.setRegParam(lambda)
37+
.setElasticNetParam(alpha)
38+
.setFitIntercept(formula.hasIntercept)
39+
case "binomial" => new LogisticRegression()
40+
.setRegParam(lambda)
41+
.setElasticNetParam(alpha)
42+
.setFitIntercept(formula.hasIntercept)
3743
}
3844
val pipeline = new Pipeline().setStages(Array(formula, estimator))
3945
pipeline.fit(df)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.api.python
19+
20+
import java.util.{List => JList}
21+
22+
import scala.collection.JavaConverters._
23+
import scala.collection.mutable.ArrayBuffer
24+
25+
import org.apache.spark.SparkContext
26+
import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrix}
27+
import org.apache.spark.mllib.clustering.GaussianMixtureModel
28+
29+
/**
30+
* Wrapper around GaussianMixtureModel to provide helper methods in Python
31+
*/
32+
private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) {
33+
val weights: Vector = Vectors.dense(model.weights)
34+
val k: Int = weights.size
35+
36+
/**
37+
* Returns gaussians as a List of Vectors and Matrices corresponding each MultivariateGaussian
38+
*/
39+
val gaussians: JList[Object] = {
40+
val modelGaussians = model.gaussians
41+
var i = 0
42+
var mu = ArrayBuffer.empty[Vector]
43+
var sigma = ArrayBuffer.empty[Matrix]
44+
while (i < k) {
45+
mu += modelGaussians(i).mu
46+
sigma += modelGaussians(i).sigma
47+
i += 1
48+
}
49+
List(mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
50+
}
51+
52+
def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
53+
}

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ private[python] class PythonMLLibAPI extends Serializable {
364364
seed: java.lang.Long,
365365
initialModelWeights: java.util.ArrayList[Double],
366366
initialModelMu: java.util.ArrayList[Vector],
367-
initialModelSigma: java.util.ArrayList[Matrix]): JList[Object] = {
367+
initialModelSigma: java.util.ArrayList[Matrix]): GaussianMixtureModelWrapper = {
368368
val gmmAlg = new GaussianMixture()
369369
.setK(k)
370370
.setConvergenceTol(convergenceTol)
@@ -382,16 +382,7 @@ private[python] class PythonMLLibAPI extends Serializable {
382382
if (seed != null) gmmAlg.setSeed(seed)
383383

384384
try {
385-
val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
386-
var wt = ArrayBuffer.empty[Double]
387-
var mu = ArrayBuffer.empty[Vector]
388-
var sigma = ArrayBuffer.empty[Matrix]
389-
for (i <- 0 until model.k) {
390-
wt += model.weights(i)
391-
mu += model.gaussians(i).mu
392-
sigma += model.gaussians(i).sigma
393-
}
394-
List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
385+
new GaussianMixtureModelWrapper(gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)))
395386
} finally {
396387
data.rdd.unpersist(blocking = false)
397388
}

0 commit comments

Comments
 (0)