@@ -28,6 +28,7 @@
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.mllib.linalg.BLAS;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
@@ -103,7 +104,23 @@ public static void main(String[] args) throws Exception {
* However, this should still compile and run successfully.
*/
class MyJavaLogisticRegression
extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel> {
extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel> {

public MyJavaLogisticRegression() {
init();
}

public MyJavaLogisticRegression(String uid) {
this.uid_ = uid;
init();
}

private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg");

@Override
public String uid() {
return uid_;
}

/**
* Param for max number of iterations
@@ -117,7 +134,7 @@ class MyJavaLogisticRegression

int getMaxIter() { return (Integer) getOrDefault(maxIter); }

public MyJavaLogisticRegression() {
private void init() {
setMaxIter(100);
}

@@ -137,7 +154,7 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) {
Vector weights = Vectors.zeros(numFeatures); // Learning would happen here.

// Create a model, and return it.
return new MyJavaLogisticRegressionModel(this, weights);
return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this);
}
}

@@ -149,17 +166,21 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) {
* However, this should still compile and run successfully.
*/
class MyJavaLogisticRegressionModel
extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> {

private MyJavaLogisticRegression parent_;
public MyJavaLogisticRegression parent() { return parent_; }
extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> {

private Vector weights_;
public Vector weights() { return weights_; }

public MyJavaLogisticRegressionModel(MyJavaLogisticRegression parent_, Vector weights_) {
this.parent_ = parent_;
this.weights_ = weights_;
public MyJavaLogisticRegressionModel(String uid, Vector weights) {
this.uid_ = uid;
this.weights_ = weights;
}

private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg");

@Override
public String uid() {
return uid_;
}

// This uses the default implementation of transform(), which reads column "features" and outputs
@@ -204,6 +225,6 @@ public Vector predictRaw(Vector features) {
*/
@Override
public MyJavaLogisticRegressionModel copy(ParamMap extra) {
return copyValues(new MyJavaLogisticRegressionModel(parent_, weights_), extra);
return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra);
}
}
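A note on the `Identifiable$` import in the Java file above: a Scala singleton object compiles to a class named `Identifiable$` whose single instance lives in a static `MODULE$` field, which is how Java code reaches `Identifiable.randomUID`. A minimal Scala sketch of the interop; the uid format is an assumption, only `randomUID(prefix)` is taken from the diff:

import java.util.UUID

// Compiles to a class `Identifiable$` with a static field `MODULE$`;
// Java therefore writes Identifiable$.MODULE$.randomUID("myJavaLogReg").
object Identifiable {
  // Produces uids such as "myJavaLogReg_1f3a9c0d2b4e" (format assumed).
  def randomUID(prefix: String): String =
    prefix + "_" + UUID.randomUUID().toString.takeRight(12)
}

object InteropDemo extends App {
  println(Identifiable.randomUID("myJavaLogReg"))
}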
@@ -20,6 +20,7 @@ package org.apache.spark.examples.ml
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.classification.{ClassificationModel, Classifier, ClassifierParams}
import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
@@ -106,10 +107,12 @@ private trait MyLogisticRegressionParams extends ClassifierParams {
*
* NOTE: This is private since it is an example. In practice, you may not want it to be private.
*/
private class MyLogisticRegression
private class MyLogisticRegression(override val uid: String)
extends Classifier[Vector, MyLogisticRegression, MyLogisticRegressionModel]
with MyLogisticRegressionParams {

def this() = this(Identifiable.randomUID("myLogReg"))

setMaxIter(100) // Initialize

// The parameter setter is in this class since it should return type MyLogisticRegression.
@@ -125,7 +128,7 @@ private class MyLogisticRegression
val weights = Vectors.zeros(numFeatures) // Learning would happen here.

// Create a model, and return it.
new MyLogisticRegressionModel(this, weights)
new MyLogisticRegressionModel(uid, weights).setParent(this)
}
}

@@ -135,7 +138,7 @@ private class MyLogisticRegression
* NOTE: This is private since it is an example. In practice, you may not want it to be private.
*/
private class MyLogisticRegressionModel(
override val parent: MyLogisticRegression,
override val uid: String,
val weights: Vector)
extends ClassificationModel[Vector, MyLogisticRegressionModel]
with MyLogisticRegressionParams {
@@ -173,6 +176,6 @@ private class MyLogisticRegressionModel(
* This is used for the default implementation of [[transform()]].
*/
override def copy(extra: ParamMap): MyLogisticRegressionModel = {
copyValues(new MyLogisticRegressionModel(parent, weights), extra)
copyValues(new MyLogisticRegressionModel(uid, weights), extra)
}
}
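Both example files follow the same recipe: the primary constructor takes an explicit `uid`, and a no-argument auxiliary constructor delegates to it with a generated one. A standalone sketch of the recipe, with toy class names:

import java.util.UUID

class MyEstimatorSketch(val uid: String) {
  // Callers that do not care about identity get a random uid.
  def this() = this("myEst_" + UUID.randomUUID().toString.takeRight(12))
}

object ConstructorDemo extends App {
  val generated = new MyEstimatorSketch()          // uid picked for you
  val explicit  = new MyEstimatorSketch("myEst_7") // uid fixed, e.g. in tests
  println(generated.uid)
  println(explicit.uid)
}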
mllib/src/main/scala/org/apache/spark/ml/Model.scala (10 changes: 9 additions & 1 deletion)
@@ -32,7 +32,15 @@ abstract class Model[M <: Model[M]] extends Transformer {
* The parent estimator that produced this model.
* Note: For ensembles' component Models, this value can be null.
*/
val parent: Estimator[M]
var parent: Estimator[M] = _

/**
* Sets the parent of this model (Java API).
*/
def setParent(parent: Estimator[M]): M = {
this.parent = parent
this.asInstanceOf[M]
}

override def copy(extra: ParamMap): M = {
// The default implementation of Params.copy doesn't work for models.
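Turning `parent` into a `var` with a fluent `setParent` is what lets `train` drop the parent from the model constructor: the framework attaches it after the model is built. The `asInstanceOf[M]` cast is safe because every concrete model binds `M` to itself. A minimal sketch of the mechanism, assuming stripped-down `Estimator`/`Model` stand-ins:

abstract class EstimatorSketch[M <: ModelSketch[M]] {
  def fit(): M
}

abstract class ModelSketch[M <: ModelSketch[M]] {
  // Mutable so the parent can be attached after construction
  // (the Java API cannot easily override a Scala val).
  var parent: EstimatorSketch[M] = _

  // Fluent setter: returns the model typed as M so calls can chain.
  def setParent(p: EstimatorSketch[M]): M = {
    this.parent = p
    this.asInstanceOf[M] // safe: each concrete model declares itself as M
  }
}

class ToyModel(val uid: String) extends ModelSketch[ToyModel]

class ToyEstimator extends EstimatorSketch[ToyModel] {
  override def fit(): ToyModel = new ToyModel("toy_1").setParent(this)
}

object ParentDemo extends App {
  val est = new ToyEstimator
  assert(est.fit().parent eq est) // the fitted model points back at est
}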
mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala (11 changes: 7 additions & 4 deletions)
@@ -22,6 +22,7 @@ import scala.collection.mutable.ListBuffer
import org.apache.spark.Logging
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

@@ -80,7 +81,9 @@
* an identity transformer.
*/
@AlphaComponent
class Pipeline extends Estimator[PipelineModel] {
class Pipeline(override val uid: String) extends Estimator[PipelineModel] {

def this() = this(Identifiable.randomUID("pipeline"))

/**
* param for pipeline stages
@@ -148,7 +151,7 @@ class Pipeline extends Estimator[PipelineModel] {
}
}

new PipelineModel(this, transformers.toArray)
new PipelineModel(uid, transformers.toArray).setParent(this)
}

override def copy(extra: ParamMap): Pipeline = {
@@ -171,7 +174,7 @@ class Pipeline extends Estimator[PipelineModel] {
*/
@AlphaComponent
class PipelineModel private[ml] (
override val parent: Pipeline,
override val uid: String,
val stages: Array[Transformer])
extends Model[PipelineModel] with Logging {

@@ -190,6 +193,6 @@ class PipelineModel private[ml] (
}

override def copy(extra: ParamMap): PipelineModel = {
new PipelineModel(parent, stages)
new PipelineModel(uid, stages)
}
}
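With `parent` gone from the constructor, `copy` now preserves identity through the uid alone: the copy is a new object that is still recognizably the same pipeline stage. A small sketch of that property (a toy class, not the real PipelineModel):

class PipelineModelSketch(val uid: String, val stages: Array[String]) {
  // A copy keeps the uid, so anything keyed by uid still matches it.
  def copy(): PipelineModelSketch = new PipelineModelSketch(uid, stages)
}

object CopyDemo extends App {
  val m = new PipelineModelSketch("pipeline_3", Array("tokenizer", "lr"))
  val c = m.copy()
  assert(c.uid == m.uid) // same identity
  assert(c ne m)         // distinct object
}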
mllib/src/main/scala/org/apache/spark/ml/Predictor.scala (2 changes: 1 addition & 1 deletion)
@@ -88,7 +88,7 @@ abstract class Predictor[
// This handles a few items such as schema validation.
// Developers only need to implement train().
transformSchema(dataset.schema, logging = true)
copyValues(train(dataset))
copyValues(train(dataset).setParent(this))
}

override def copy(extra: ParamMap): Learner = {
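This one-line change in `fit` is what makes the constructor changes above safe: developer code implements `train` and never touches `parent`; the framework attaches it, then copies the estimator's param values onto the model. A sketch of the call order, assuming a bare-bones estimator with `copyValues` elided:

class WeightsModel(val uid: String, val weights: Array[Double]) {
  var parent: AnyRef = _
  def setParent(p: AnyRef): WeightsModel = { parent = p; this }
}

abstract class PredictorSketch(val uid: String) {
  // Developers implement train(); it only builds the model.
  protected def train(data: Seq[Double]): WeightsModel

  // The framework wires the parent (and, in Spark, copyValues) afterwards.
  final def fit(data: Seq[Double]): WeightsModel =
    train(data).setParent(this)
}

object FitDemo extends App {
  val p = new PredictorSketch("pred_1") {
    def train(data: Seq[Double]) = new WeightsModel(uid, data.toArray)
  }
  assert(p.fit(Seq(1.0, 2.0)).parent eq p)
}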
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{TreeClassifierParams, DecisionTreeParams, DecisionTreeModel, Node}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
@@ -39,10 +39,12 @@ import org.apache.spark.sql.DataFrame
* features.
*/
@AlphaComponent
final class DecisionTreeClassifier
final class DecisionTreeClassifier(override val uid: String)
extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
with DecisionTreeParams with TreeClassifierParams {

def this() = this(Identifiable.randomUID("dtc"))

// Override parameter setters from parent trait for Java API compatibility.

override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
@@ -101,7 +103,7 @@ object DecisionTreeClassifier {
*/
@AlphaComponent
final class DecisionTreeClassificationModel private[ml] (
override val parent: DecisionTreeClassifier,
override val uid: String,
override val rootNode: Node)
extends PredictionModel[Vector, DecisionTreeClassificationModel]
with DecisionTreeModel with Serializable {
@@ -114,7 +116,7 @@ final class DecisionTreeClassificationModel private[ml] (
}

override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
copyValues(new DecisionTreeClassificationModel(parent, rootNode), extra)
copyValues(new DecisionTreeClassificationModel(uid, rootNode), extra)
}

override def toString: String = {
Expand All @@ -138,6 +140,7 @@ private[ml] object DecisionTreeClassificationModel {
s"Cannot convert non-classification DecisionTreeModel (old API) to" +
s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
new DecisionTreeClassificationModel(parent, rootNode)
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
new DecisionTreeClassificationModel(uid, rootNode)
}
}
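The converter's fallback is worth a note: component trees of an ensemble arrive with `parent == null`, so the uid must be generated rather than read off the parent (the GBT converter below has the same shape). A sketch of the null-safe pattern:

import java.util.UUID

object FromOldSketch extends App {
  class DtcSketch(val uid: String) // stand-in for the classifier

  def randomUID(prefix: String): String =
    prefix + "_" + UUID.randomUUID().toString.takeRight(12)

  // parent may legitimately be null (e.g. trees inside an ensemble),
  // so never dereference it before the null check.
  def uidFrom(parent: DtcSketch): String =
    if (parent != null) parent.uid else randomUID("dtc")

  println(uidFrom(new DtcSketch("dtc_fixed")))
  println(uidFrom(null))
}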
@@ -25,7 +25,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
@@ -44,10 +44,12 @@ import org.apache.spark.sql.DataFrame
* Note: Multiclass labels are not currently supported.
*/
@AlphaComponent
final class GBTClassifier
final class GBTClassifier(override val uid: String)
extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
with GBTParams with TreeClassifierParams with Logging {

def this() = this(Identifiable.randomUID("gbtc"))

// Override parameter setters from parent trait for Java API compatibility.

// Parameters from TreeClassifierParams:
@@ -160,7 +162,7 @@ object GBTClassifier {
*/
@AlphaComponent
final class GBTClassificationModel(
override val parent: GBTClassifier,
override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
private val _treeWeights: Array[Double])
extends PredictionModel[Vector, GBTClassificationModel]
@@ -184,7 +186,7 @@ final class GBTClassificationModel(
}

override def copy(extra: ParamMap): GBTClassificationModel = {
copyValues(new GBTClassificationModel(parent, _trees, _treeWeights), extra)
copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra)
}

override def toString: String = {
Expand All @@ -210,6 +212,7 @@ private[ml] object GBTClassificationModel {
// parent, fittingParamMap for each tree is null since there are no good ways to set these.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
new GBTClassificationModel(parent, newTrees, oldModel.treeWeights)
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
new GBTClassificationModel(uid, newTrees, oldModel.treeWeights)
}
}
@@ -26,6 +26,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction}
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.regression.LabeledPoint
@@ -50,10 +51,12 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
* Currently, this class only supports binary classification.
*/
@AlphaComponent
class LogisticRegression
class LogisticRegression(override val uid: String)
extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
with LogisticRegressionParams with Logging {

def this() = this(Identifiable.randomUID("logreg"))

/**
* Set the regularization parameter.
* Default is 0.0.
@@ -213,7 +216,7 @@ class LogisticRegression
(weightsWithIntercept, 0.0)
}

new LogisticRegressionModel(this, weights.compressed, intercept)
new LogisticRegressionModel(uid, weights.compressed, intercept)
}
}

Expand All @@ -224,7 +227,7 @@ class LogisticRegression
*/
@AlphaComponent
class LogisticRegressionModel private[ml] (
override val parent: LogisticRegression,
override val uid: String,
val weights: Vector,
val intercept: Double)
extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
@@ -276,7 +279,7 @@ class LogisticRegressionModel private[ml] (
}

override def copy(extra: ParamMap): LogisticRegressionModel = {
copyValues(new LogisticRegressionModel(parent, weights, intercept), extra)
copyValues(new LogisticRegressionModel(uid, weights, intercept), extra)
}

override protected def raw2prediction(rawPrediction: Vector): Double = {
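Taken together, the contract after this change set reads: a fitted model carries its estimator's uid, points back at the estimator through `parent`, and keeps its uid across `copy`. A compact end-to-end sketch of that contract with toy classes (not the Spark API):

import java.util.UUID

object UidContractDemo extends App {
  def randomUID(prefix: String): String =
    prefix + "_" + UUID.randomUUID().toString.takeRight(12)

  class LRModelSketch(val uid: String, val weights: Array[Double]) {
    var parent: LRSketch = _
    def setParent(p: LRSketch): LRModelSketch = { parent = p; this }
    def copy(): LRModelSketch = new LRModelSketch(uid, weights) // uid kept
  }

  class LRSketch(val uid: String) {
    def this() = this(randomUID("logreg"))
    def fit(): LRModelSketch =
      new LRModelSketch(uid, Array(0.0)).setParent(this) // uid flows down
  }

  val lr = new LRSketch()
  val model = lr.fit()
  assert(model.uid == lr.uid)           // model inherits estimator uid
  assert(model.parent eq lr)            // and records its producer
  assert(model.copy().uid == model.uid) // identity survives copy
}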