Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,18 @@ class Pipeline @Since("1.4.0") (
@Since("1.2.0")
def getStages: Array[PipelineStage] = $(stages).clone()

/** Returns stage at index i in Pipeline */
@Since("2.0.0")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't be since 2.0.0 at this point. Also use a @return tag in the docs.

def getStage[T <: PipelineStage](i: Int): T = getStages.apply(i).asInstanceOf[T]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here and elsewhere, I would just write array(i) instead of array.apply(i)

Also, can you please document the @tparam?


/** Returns all stages of this type */
@Since("2.0.0")
def getStagesOfType[T <: PipelineStage]: Array[T] = {
getStages.collect {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: This is probably more natural as a one-liner

case stage: T => stage
}
}

/**
* Fits the pipeline to the input dataset with additional parameters. If a stage is an
* [[Estimator]], its [[Estimator#fit]] method will be called on the input dataset to fit a model.
Expand Down Expand Up @@ -309,6 +321,32 @@ class PipelineModel private[ml] (

@Since("1.6.0")
override def write: MLWriter = new PipelineModel.PipelineModelWriter(this)

/** Returns stage at index i in PipelineModel */
@Since("2.0.0")
def getStage[T <: Transformer](i: Int): T = {
stages.apply(i).asInstanceOf[T]
}

/**
* Returns stage given its parent or generating instance in PipelineModel.
* E.g., if this PipelineModel was created from a Pipeline containing a stage
* {{myStage}}, then passing {{myStage}} to this method will return the
* corresponding stage in this PipelineModel.
*/
@Since("2.0.0")
def getStage[T <: Transformer, E <: PipelineStage](stage: E): T = {
val idxInPipeline = this.parent.asInstanceOf[Pipeline].getStages.indexOf(stage)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should test against the uid, not the stage reference.

stages.apply(idxInPipeline).asInstanceOf[T]
}

/** Returns all stages of this type */
@Since("2.0.0")
def getStagesOfType[T <: Transformer]: Array[T] = {
stages.collect {
case stage: T => stage
}
}
}

@Since("1.6.0")
Expand Down