Commit a0a32ef

Fix GaussianMixture training failure due to feature column type mistake
1 parent ef0ccbc commit a0a32ef

File tree

10 files changed: +19 −7 lines changed
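Taken together, the diffs below do two things: GaussianMixture's schema validation switches from the old mllib VectorUDT to the new ml VectorUDT for the features and probability columns, and fit()/transform() across several estimators and transformers now call transformSchema(dataset.schema) up front so a mistyped column fails fast. A minimal sketch of the failure mode the title describes, assuming a local SparkSession (the data and column name are illustrative):

    import org.apache.spark.ml.clustering.GaussianMixture
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local").appName("gmm-check").getOrCreate()
    import spark.implicits._

    // The features column carries the new ml.linalg vector type. Before this
    // commit, validateAndTransformSchema compared it against the old mllib
    // VectorUDT, so schema validation rejected a valid features column.
    val df = Seq(
      Tuple1(Vectors.dense(0.0, 0.1)),
      Tuple1(Vectors.dense(0.2, 0.0)),
      Tuple1(Vectors.dense(9.0, 8.8)),
      Tuple1(Vectors.dense(8.9, 9.1))
    ).toDF("features")

    val model = new GaussianMixture().setK(2).fit(df)  // succeeds after the fix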

mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala

Lines changed: 2 additions & 0 deletions

@@ -99,6 +99,7 @@ class BisectingKMeansModel private[ml] (

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema)
     val predictUDF = udf((vector: Vector) => predict(vector))
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }
@@ -222,6 +223,7 @@ class BisectingKMeans @Since("2.0.0") (

   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): BisectingKMeansModel = {
+    transformSchema(dataset.schema)
     val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
       case Row(point: Vector) => OldVectors.fromML(point)
     }
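The transformSchema(dataset.schema) calls added at the top of fit() and transform() here (and in the files below) route through the public PipelineStage validation hook, so a wrong features column type raises an error before any distributed work starts. A small sketch of that fail-fast behavior, with a deliberately wrong schema (names are illustrative):

    import org.apache.spark.ml.clustering.BisectingKMeans
    import org.apache.spark.sql.types.{StringType, StructField, StructType}

    // A features column declared as StringType instead of a vector type.
    val badSchema = StructType(Seq(StructField("features", StringType)))

    // The same check that fit() now runs eagerly: the type mismatch surfaces
    // immediately as an IllegalArgumentException.
    new BisectingKMeans().setFeaturesCol("features").transformSchema(badSchema)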

mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala

Lines changed: 5 additions & 3 deletions

@@ -30,7 +30,7 @@ import org.apache.spark.ml.stat.distribution.MultivariateGaussian
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM}
 import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix,
-  Vector => OldVector, Vectors => OldVectors, VectorUDT => OldVectorUDT}
+  Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions.{col, udf}
@@ -61,9 +61,9 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(featuresCol), new OldVectorUDT)
+    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
-    SchemaUtils.appendColumn(schema, $(probabilityCol), new OldVectorUDT)
+    SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT)
   }
 }

@@ -95,6 +95,7 @@ class GaussianMixtureModel private[ml] (

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema)
     val predUDF = udf((vector: Vector) => predict(vector))
     val probUDF = udf((vector: Vector) => predictProbability(vector))
     dataset.withColumn($(predictionCol), predUDF(col($(featuresCol))))
@@ -317,6 +318,7 @@ class GaussianMixture @Since("2.0.0") (

   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): GaussianMixtureModel = {
+    transformSchema(dataset.schema)
     val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
       case Row(point: Vector) => OldVectors.fromML(point)
     }
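With the checks above using the ml VectorUDT, transformSchema both accepts ml vectors in the features column and declares the columns the model will append. A hedged sketch, reusing the SparkSession and implicits from the first snippet:

    import org.apache.spark.ml.clustering.GaussianMixture
    import org.apache.spark.ml.linalg.Vectors

    val data = Seq(Tuple1(Vectors.dense(1.0, 2.0))).toDF("features")

    // Passes the VectorUDT check and appends the declared output columns:
    // prediction (IntegerType) and probability (ml VectorUDT).
    val outSchema = new GaussianMixture().transformSchema(data.schema)
    outSchema.fields.foreach(f => println(s"${f.name}: ${f.dataType}"))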

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 2 additions & 0 deletions

@@ -120,6 +120,7 @@ class KMeansModel private[ml] (

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema)
     val predictUDF = udf((vector: Vector) => predict(vector))
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }
@@ -304,6 +305,7 @@ class KMeans @Since("1.5.0") (

   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): KMeansModel = {
+    transformSchema(dataset.schema)
     val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
       case Row(point: Vector) => OldVectors.fromML(point)
     }

mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala

Lines changed: 1 addition & 0 deletions

@@ -68,6 +68,7 @@ class Interaction @Since("1.6.0") (@Since("1.6.0") override val uid: String) ext

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema)
     val inputFeatures = $(inputCols).map(c => dataset.schema(c))
     val featureEncoders = getFeatureEncoders(inputFeatures)
     val featureAttrs = getFeatureAttrs(inputFeatures)

mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala

Lines changed: 2 additions & 1 deletion

@@ -111,7 +111,7 @@ class MinMaxScaler @Since("1.5.0") (@Since("1.5.0") override val uid: String)

   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): MinMaxScalerModel = {
-    transformSchema(dataset.schema, logging = true)
+    transformSchema(dataset.schema)
     val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map {
       case Row(v: Vector) => OldVectors.fromML(v)
     }
@@ -170,6 +170,7 @@ class MinMaxScalerModel private[ml] (

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema)
     val originalRange = (originalMax.asBreeze - originalMin.asBreeze).toArray
     val minArray = originalMin.toArray

mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala

Lines changed: 2 additions & 1 deletion

@@ -97,7 +97,7 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui

   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
+    SchemaUtils.checkNumericType(schema, $(inputCol))
     val inputFields = schema.fields
     require(inputFields.forall(_.name != $(outputCol)),
       s"Output column ${$(outputCol)} already exists.")
@@ -108,6 +108,7 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui

   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): Bucketizer = {
+    transformSchema(dataset.schema)
     val splits = dataset.stat.approxQuantile($(inputCol),
       (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError))
     splits(0) = Double.NegativeInfinity
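Switching from checkColumnType(..., DoubleType) to checkNumericType widens what the discretizer accepts at validation time: any numeric input column passes, not just DoubleType. A small sketch (column names are illustrative):

    import org.apache.spark.ml.feature.QuantileDiscretizer
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    // An IntegerType input column, which the old DoubleType check rejected.
    val schema = StructType(Seq(StructField("hour", IntegerType)))

    val discretizer = new QuantileDiscretizer()
      .setInputCol("hour")
      .setOutputCol("bucket")
      .setNumBuckets(4)

    // checkNumericType accepts the column and appends "bucket" to the schema.
    discretizer.transformSchema(schema)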

mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala

Lines changed: 1 addition & 0 deletions

@@ -112,6 +112,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)

   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): RFormulaModel = {
+    transformSchema(dataset.schema)
     require(isDefined(formula), "Formula must be defined first.")
     val parsedFormula = RFormulaParser.parse($(formula))
     val resolvedFormula = parsedFormula.resolve(dataset.schema)

mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala

Lines changed: 1 addition & 0 deletions

@@ -63,6 +63,7 @@ class SQLTransformer @Since("1.6.0") (@Since("1.6.0") override val uid: String)

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema)
     val tableName = Identifiable.randomUID(uid)
     dataset.createOrReplaceTempView(tableName)
     val realStatement = $(statement).replace(tableIdentifier, tableName)

mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala

Lines changed: 1 addition & 1 deletion

@@ -196,7 +196,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S

   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = {
-    validateAndTransformSchema(dataset.schema, fitting = true)
+    transformSchema(dataset.schema)
     val instances = extractAFTPoints(dataset)
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
     if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)

mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala

Lines changed: 2 additions & 1 deletion

@@ -164,7 +164,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri

   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): IsotonicRegressionModel = {
-    validateAndTransformSchema(dataset.schema, fitting = true)
+    transformSchema(dataset.schema)
     // Extract columns from data. If dataset is persisted, do not persist oldDataset.
     val instances = extractWeightedLabeledPoints(dataset)
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
@@ -234,6 +234,7 @@ class IsotonicRegressionModel private[ml] (

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema)
     val predict = dataset.schema($(featuresCol)).dataType match {
       case DoubleType =>
         udf { feature: Double => oldModel.predict(feature) }
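The context above shows why IsotonicRegressionModel.transform dispatches on the features column type; the DoubleType branch shown handles a plain Double column. A minimal end-to-end sketch with a Double features column, assuming the local SparkSession and implicits from the first snippet (data is illustrative):

    import org.apache.spark.ml.regression.IsotonicRegression

    val train = Seq((1.0, 0.0), (2.0, 1.0), (3.0, 2.0)).toDF("label", "features")

    // fit() and transform() both validate the schema eagerly now; a Double
    // features column passes and "prediction" is appended on transform.
    val irModel = new IsotonicRegression().fit(train)
    irModel.transform(train).show()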
