@@ -105,15 +105,12 @@ class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid:
/**
 * Computes the schema produced by this estimator for the given input schema.
 *
 * The input/output column parameters are first checked against `schema` by
 * `OneHotEncoderEstimator.checkParamsValidity`. One output field is then prepared for
 * every input column and appended after the existing fields of `schema`.
 *
 * @param schema schema of the dataset to be encoded
 * @return `schema` with the prepared output fields appended
 */
override def transformSchema(schema: StructType): StructType = {
  val inNames = $(inputCols)
  val outNames = $(outputCols)

  OneHotEncoderEstimator.checkParamsValidity(inNames, outNames, schema)

  // Resolve every input column name to its field in the incoming schema.
  val inFields = inNames.map(name => schema(name))
  val encodedFields =
    OneHotEncoderEstimator.prepareOutputFields(inFields, outNames, $(dropLast))

  StructType(schema.fields ++ encodedFields)
}
118115
119116 @ Since (" 2.3.0" )
@@ -180,6 +177,16 @@ object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimat
180177 s " Output column $outputColName already exists. " )
181178 }
182179 }
180+
/**
 * Prepares one output field per input column.
 *
 * Input columns and output names are paired positionally; each pair is handed to
 * `OneHotEncoderCommon.transformOutputColumnSchema` together with `dropLast`.
 * Because the pairing uses `zip`, any surplus elements of the longer sequence are ignored.
 *
 * @param inputCols fields of the columns to be encoded
 * @param outputColNames names of the output columns, in the same order as `inputCols`
 * @param dropLast forwarded unchanged to `transformOutputColumnSchema`
 * @return the output fields, in input order
 */
private[feature] def prepareOutputFields(
    inputCols: Seq[StructField],
    outputColNames: Seq[String],
    dropLast: Boolean): Seq[StructField] = {
  val columnPairs = inputCols.zip(outputColNames)
  columnPairs.map { pair =>
    val (inputField, outputName) = pair
    OneHotEncoderCommon.transformOutputColumnSchema(inputField, dropLast, outputName)
  }
}
183190}
184191
185192@ Since (" 2.3.0" )
@@ -233,20 +240,16 @@ class OneHotEncoderModel private[ml] (
/**
 * Computes the schema produced by this model for the given input schema.
 *
 * Steps, in order:
 *  1. check the input/output column parameters against `schema` via
 *     `OneHotEncoderEstimator.checkParamsValidity`;
 *  2. require that the number of input columns equals the length of `categorySizes`;
 *  3. prepare one output field per input column and append them to the fields of `schema`;
 *  4. pass the combined schema through `verifyNumOfValues` and return its result.
 *
 * @param schema schema of the dataset to be transformed
 * @return the schema returned by `verifyNumOfValues` for the combined input and output fields
 */
override def transformSchema(schema: StructType): StructType = {
  val inputColNames = $(inputCols)
  val outputColNames = $(outputCols)

  OneHotEncoderEstimator.checkParamsValidity(inputColNames, outputColNames, schema)

  require(inputColNames.length == categorySizes.length,
    s"The number of input columns ${inputColNames.length} must be the same as the number of " +
      s"features ${categorySizes.length} during fitting.")

  // Resolve the input column names to fields before preparing the output fields.
  val resolvedInputs = inputColNames.map(name => schema(name))
  val encodedFields =
    OneHotEncoderEstimator.prepareOutputFields(resolvedInputs, outputColNames, $(dropLast))

  val combinedSchema = StructType(schema.fields ++ encodedFields)
  verifyNumOfValues(combinedSchema)
}
251254
252255 private def verifyNumOfValues (schema : StructType ): StructType = {
0 commit comments