Skip to content

Commit a9e9262

Browse files
committed
Extract common method for preparing output fields.
1 parent 66d46ac commit a9e9262

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,12 @@ class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid:
105105
override def transformSchema(schema: StructType): StructType = {
106106
val inputColNames = $(inputCols)
107107
val outputColNames = $(outputCols)
108-
val inputFields = schema.fields
109108

110109
OneHotEncoderEstimator.checkParamsValidity(inputColNames, outputColNames, schema)
111110

112-
val outputFields = inputColNames.zip(outputColNames).map { case (inputColName, outputColName) =>
113-
OneHotEncoderCommon.transformOutputColumnSchema(
114-
schema(inputColName), $(dropLast), outputColName)
115-
}
116-
StructType(inputFields ++ outputFields)
111+
val outputFields = OneHotEncoderEstimator.prepareOutputFields(
112+
inputColNames.map(schema(_)), outputColNames, $(dropLast))
113+
StructType(schema.fields ++ outputFields)
117114
}
118115

119116
@Since("2.3.0")
@@ -180,6 +177,16 @@ object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimat
180177
s"Output column $outputColName already exists.")
181178
}
182179
}
180+
181+
private[feature] def prepareOutputFields(
182+
inputCols: Seq[StructField],
183+
outputColNames: Seq[String],
184+
dropLast: Boolean): Seq[StructField] = {
185+
inputCols.zip(outputColNames).map { case (inputCol, outputColName) =>
186+
OneHotEncoderCommon.transformOutputColumnSchema(
187+
inputCol, dropLast, outputColName)
188+
}
189+
}
183190
}
184191

185192
@Since("2.3.0")
@@ -233,20 +240,16 @@ class OneHotEncoderModel private[ml] (
233240
override def transformSchema(schema: StructType): StructType = {
234241
val inputColNames = $(inputCols)
235242
val outputColNames = $(outputCols)
236-
val inputFields = schema.fields
237243

238244
OneHotEncoderEstimator.checkParamsValidity(inputColNames, outputColNames, schema)
239245

240246
require(inputColNames.length == categorySizes.length,
241247
s"The number of input columns ${inputColNames.length} must be the same as the number of " +
242248
s"features ${categorySizes.length} during fitting.")
243249

244-
val inputOutputPairs = inputColNames.zip(outputColNames)
245-
val outputFields = inputOutputPairs.map { case (inputColName, outputColName) =>
246-
OneHotEncoderCommon.transformOutputColumnSchema(
247-
schema(inputColName), $(dropLast), outputColName)
248-
}
249-
verifyNumOfValues(StructType(inputFields ++ outputFields))
250+
val outputFields = OneHotEncoderEstimator.prepareOutputFields(
251+
inputColNames.map(schema(_)), outputColNames, $(dropLast))
252+
verifyNumOfValues(StructType(schema.fields ++ outputFields))
250253
}
251254

252255
private def verifyNumOfValues(schema: StructType): StructType = {

0 commit comments

Comments
 (0)