@@ -20,28 +20,28 @@ namespace Microsoft.ML.Runtime.Data
2020 /// <include file='doc.xml' path='doc/members/member[@name="NAHandle"]'/>
2121 public static class NAHandleTransform
2222 {
23- public enum ReplacementKind
23+ public enum ReplacementKind : byte
2424 {
2525 /// <summary>
2626 /// Replace with the default value of the column based on it's type. For example, 'zero' for numeric and 'empty' for string/text columns.
2727 /// </summary>
2828 [ EnumValueDisplay ( "Zero/empty" ) ]
29- DefaultValue ,
29+ DefaultValue = 0 ,
3030
3131 /// <summary>
3232 /// Replace with the mean value of the column. Supports only numeric/time span/ DateTime columns.
3333 /// </summary>
34- Mean ,
34+ Mean = 1 ,
3535
3636 /// <summary>
3737 /// Replace with the minimum value of the column. Supports only numeric/time span/ DateTime columns.
3838 /// </summary>
39- Minimum ,
39+ Minimum = 2 ,
4040
4141 /// <summary>
4242 /// Replace with the maximum value of the column. Supports only numeric/time span/ DateTime columns.
4343 /// </summary>
44- Maximum ,
44+ Maximum = 3 ,
4545
4646 [ HideEnumValue ]
4747 Def = DefaultValue ,
@@ -135,7 +135,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
135135 h . CheckValue ( input , nameof ( input ) ) ;
136136 h . CheckUserArg ( Utils . Size ( args . Column ) > 0 , nameof ( args . Column ) ) ;
137137
138- var replaceCols = new List < NAReplaceTransform . Column > ( ) ;
138+ var replaceCols = new List < NAReplaceTransform . ColumnInfo > ( ) ;
139139 var naIndicatorCols = new List < NAIndicatorTransform . Column > ( ) ;
140140 var naConvCols = new List < ConvertTransform . Column > ( ) ;
141141 var concatCols = new List < ConcatTransform . TaggedColumn > ( ) ;
@@ -149,26 +149,16 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
149149 var addInd = column . ConcatIndicator ?? args . Concat ;
150150 if ( ! addInd )
151151 {
152- replaceCols . Add (
153- new NAReplaceTransform . Column ( )
154- {
155- Kind = ( NAReplaceTransform . ReplacementKind ? ) column . Kind ,
156- Name = column . Name ,
157- Source = column . Source ,
158- Slot = column . ImputeBySlot
159- } ) ;
152+ replaceCols . Add ( new NAReplaceTransform . ColumnInfo ( column . Source , column . Name , ( NAReplaceTransform . ColumnInfo . ReplacementMode ) ( column . Kind ?? args . ReplaceWith ) , column . ImputeBySlot ?? args . ImputeBySlot ) ) ;
160153 continue ;
161154 }
162155
163156 // Check that the indicator column has a type that can be converted to the NAReplaceTransform output type,
164157 // so that they can be concatenated.
165- int inputCol ;
166- if ( ! input . Schema . TryGetColumnIndex ( column . Source , out inputCol ) )
158+ if ( ! input . Schema . TryGetColumnIndex ( column . Source , out int inputCol ) )
167159 throw h . Except ( "Column '{0}' does not exist" , column . Source ) ;
168160 var replaceType = input . Schema . GetColumnType ( inputCol ) ;
169- Delegate conv ;
170- bool identity ;
171- if ( ! Conversions . Instance . TryGetStandardConversion ( BoolType . Instance , replaceType . ItemType , out conv , out identity ) )
161+ if ( ! Conversions . Instance . TryGetStandardConversion ( BoolType . Instance , replaceType . ItemType , out Delegate conv , out bool identity ) )
172162 {
173163 throw h . Except ( "Cannot concatenate indicator column of type '{0}' to input column of type '{1}'" ,
174164 BoolType . Instance , replaceType . ItemType ) ;
@@ -186,14 +176,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
186176 naConvCols . Add ( new ConvertTransform . Column ( ) { Name = tmpIsMissingColName , Source = tmpIsMissingColName , ResultType = replaceType . ItemType . RawKind } ) ;
187177
188178 // Add the NAReplaceTransform column.
189- replaceCols . Add (
190- new NAReplaceTransform . Column ( )
191- {
192- Kind = ( NAReplaceTransform . ReplacementKind ? ) column . Kind ,
193- Name = tmpReplacementColName ,
194- Source = column . Source ,
195- Slot = column . ImputeBySlot
196- } ) ;
179+ replaceCols . Add ( new NAReplaceTransform . ColumnInfo ( column . Source , tmpReplacementColName , ( NAReplaceTransform . ColumnInfo . ReplacementMode ) ( column . Kind ?? args . ReplaceWith ) , column . ImputeBySlot ?? args . ImputeBySlot ) ) ;
197180
198181 // Add the ConcatTransform column.
199182 if ( replaceType . IsVector )
@@ -237,15 +220,8 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
237220 h . AssertValue ( output ) ;
238221 output = new ConvertTransform ( h , new ConvertTransform . Arguments ( ) { Column = naConvCols . ToArray ( ) } , output ) ;
239222 }
240-
241223 // Create the NAReplace transform.
242- output = new NAReplaceTransform ( h ,
243- new NAReplaceTransform . Arguments ( )
244- {
245- Column = replaceCols . ToArray ( ) ,
246- ReplacementKind = ( NAReplaceTransform . ReplacementKind ) args . ReplaceWith ,
247- ImputeBySlot = args . ImputeBySlot
248- } , output ?? input ) ;
224+ output = NAReplaceTransform . Create ( env , output ?? input , replaceCols . ToArray ( ) ) ;
249225
250226 // Concat the NAReplaceTransform output and the NAIndicatorTransform output.
251227 if ( naIndicatorCols . Count > 0 )
0 commit comments