@@ -12,32 +12,28 @@ namespace Microsoft.ML.Auto
1212{ 
1313    public  static   class  RegressionExtensions 
1414    { 
15-         public  static   RegressionResult  AutoFit ( this  RegressionContext  context , 
15+         public  static   IEnumerable < IterationResult < RegressionMetrics > >  AutoFit ( this  RegressionContext  context , 
1616            IDataView  trainData , 
1717            string  label  =  DefaultColumnNames . Label , 
1818            IDataView  validationData  =  null , 
1919            uint  timeoutInMinutes  =  AutoFitDefaults . TimeOutInMinutes , 
2020            IEstimator < ITransformer >  preFeaturizers  =  null , 
21-             IEnumerable < ( string ,  ColumnPurpose ) >  columnPurposes  =  null , 
22-             CancellationToken  cancellationToken  =  default , 
23-             IProgress < RegressionIterationResult >  iterationCallback  =  null ) 
21+             IEnumerable < ( string ,  ColumnPurpose ) >  columnPurposes  =  null ) 
2422        { 
2523            var  settings  =  new  AutoFitSettings ( ) ; 
2624            settings . StoppingCriteria . TimeOutInMinutes  =  timeoutInMinutes ; 
2725
2826            return  AutoFit ( context ,  trainData ,  label ,  validationData ,  settings , 
29-                 preFeaturizers ,  columnPurposes ,  cancellationToken ,   iterationCallback ,   null ) ; 
27+                 preFeaturizers ,  columnPurposes ,  null ) ; 
3028        } 
3129
32-         internal  static   RegressionResult  AutoFit ( this  RegressionContext  context , 
30+         internal  static   IEnumerable < IterationResult < RegressionMetrics > >  AutoFit ( this  RegressionContext  context , 
3331            IDataView  trainData , 
3432            string  label  =  DefaultColumnNames . Label , 
3533            IDataView  validationData  =  null , 
3634            AutoFitSettings  settings  =  null , 
3735            IEstimator < ITransformer >  preFeaturizers  =  null , 
3836            IEnumerable < ( string ,  ColumnPurpose ) >  columnPurposes  =  null , 
39-             CancellationToken  cancellationToken  =  default , 
40-             IProgress < RegressionIterationResult >  iterationCallback  =  null , 
4137            IDebugLogger  debugLogger  =  null ) 
4238        { 
4339            UserInputValidationUtil . ValidateAutoFitArgs ( trainData ,  label ,  validationData ,  settings ,  columnPurposes ) ; 
@@ -48,49 +44,38 @@ internal static RegressionResult AutoFit(this RegressionContext context,
4844            } 
4945
5046            // run autofit & get all pipelines run in that process 
51-             var  ( allPipelines ,  bestPipeline )  =  AutoFitApi . Fit ( trainData ,  validationData ,  label , 
52-                 settings ,  preFeaturizers ,  TaskKind . Regression ,  OptimizingMetric . RSquared ,  columnPurposes ,  debugLogger ) ; 
47+             var  autoFitter  =  new  AutoFitter < RegressionMetrics > ( TaskKind . Regression ,  trainData ,  label ,  validationData , 
48+                 settings ,  preFeaturizers ,  columnPurposes , 
49+                 OptimizingMetric . RSquared ,  debugLogger ) ; 
5350
54-             var  results  =  new  RegressionIterationResult [ allPipelines . Length ] ; 
55-             for  ( var  i  =  0 ;  i  <  results . Length ;  i ++ ) 
56-             { 
57-                 var  iterationResult  =  allPipelines [ i ] ; 
58-                 var  result  =  new  RegressionIterationResult ( iterationResult . Model ,  ( RegressionMetrics ) iterationResult . EvaluatedMetrics ,  iterationResult . ScoredValidationData ,  iterationResult . Pipeline . ToPipeline ( ) ) ; 
59-                 results [ i ]  =  result ; 
60-             } 
61-             var  bestResult  =  new  RegressionIterationResult ( bestPipeline . Model ,  ( RegressionMetrics ) bestPipeline . EvaluatedMetrics ,  bestPipeline . ScoredValidationData ,  bestPipeline . Pipeline . ToPipeline ( ) ) ; 
62-             return  new  RegressionResult ( bestResult ,  results ) ; 
51+             return  autoFitter . Fit ( ) ; 
6352        } 
6453    } 
6554
6655    public  static   class  BinaryClassificationExtensions 
6756    { 
68-         public  static   BinaryClassificationResult  AutoFit ( this  BinaryClassificationContext  context , 
57+         public  static   IEnumerable < IterationResult < BinaryClassificationMetrics > >  AutoFit ( this  BinaryClassificationContext  context , 
6958            IDataView  trainData , 
7059            string  label  =  DefaultColumnNames . Label , 
7160            IDataView  validationData  =  null , 
7261            uint  timeoutInMinutes  =  AutoFitDefaults . TimeOutInMinutes , 
7362            IEstimator < ITransformer >  preFeaturizers  =  null , 
74-             IEnumerable < ( string ,  ColumnPurpose ) >  columnPurposes  =  null , 
75-             CancellationToken  cancellationToken  =  default , 
76-             IProgress < BinaryClassificationItertionResult >  iterationCallback  =  null ) 
63+             IEnumerable < ( string ,  ColumnPurpose ) >  columnPurposes  =  null ) 
7764        { 
7865            var  settings  =  new  AutoFitSettings ( ) ; 
7966            settings . StoppingCriteria . TimeOutInMinutes  =  timeoutInMinutes ; 
8067
8168            return  AutoFit ( context ,  trainData ,  label ,  validationData ,  settings , 
82-                 preFeaturizers ,  columnPurposes ,  cancellationToken ,   iterationCallback ,   null ) ; 
69+                 preFeaturizers ,  columnPurposes ,  null ) ; 
8370        } 
8471
85-         internal  static   BinaryClassificationResult  AutoFit ( this  BinaryClassificationContext  context , 
72+         internal  static   IEnumerable < IterationResult < BinaryClassificationMetrics > >  AutoFit ( this  BinaryClassificationContext  context , 
8673            IDataView  trainData , 
8774            string  label  =  DefaultColumnNames . Label , 
8875            IDataView  validationData  =  null , 
8976            AutoFitSettings  settings  =  null , 
9077            IEstimator < ITransformer >  preFeaturizers  =  null , 
9178            IEnumerable < ( string ,  ColumnPurpose ) >  columnPurposes  =  null , 
92-             CancellationToken  cancellationToken  =  default , 
93-             IProgress < BinaryClassificationItertionResult >  iterationCallback  =  null , 
9479            IDebugLogger  debugLogger  =  null ) 
9580        { 
9681            UserInputValidationUtil . ValidateAutoFitArgs ( trainData ,  label ,  validationData ,  settings ,  columnPurposes ) ; 
@@ -101,159 +86,67 @@ internal static BinaryClassificationResult AutoFit(this BinaryClassificationCont
10186            } 
10287
10388            // run autofit & get all pipelines run in that process 
104-             var  ( allPipelines ,  bestPipeline )  =  AutoFitApi . Fit ( trainData ,  validationData ,  label , 
105-                 settings ,  preFeaturizers ,  TaskKind . BinaryClassification ,  OptimizingMetric . Accuracy , 
106-                 columnPurposes ,  debugLogger ) ; 
107- 
108-             var  results  =  new  BinaryClassificationItertionResult [ allPipelines . Length ] ; 
109-             for  ( var  i  =  0 ;  i  <  results . Length ;  i ++ ) 
110-             { 
111-                 var  iterationResult  =  allPipelines [ i ] ; 
112-                 var  result  =  new  BinaryClassificationItertionResult ( iterationResult . Model ,  ( BinaryClassificationMetrics ) iterationResult . EvaluatedMetrics ,  iterationResult . ScoredValidationData ,  iterationResult . Pipeline . ToPipeline ( ) ) ; 
113-                 results [ i ]  =  result ; 
114-             } 
115-             var  bestResult  =  new  BinaryClassificationItertionResult ( bestPipeline . Model ,  ( BinaryClassificationMetrics ) bestPipeline . EvaluatedMetrics ,  bestPipeline . ScoredValidationData ,  bestPipeline . Pipeline . ToPipeline ( ) ) ; 
116-             return  new  BinaryClassificationResult ( bestResult ,  results ) ; 
89+             var  autoFitter  =  new  AutoFitter < BinaryClassificationMetrics > ( TaskKind . BinaryClassification ,  trainData ,  label ,  validationData , 
90+                 settings ,  preFeaturizers ,  columnPurposes ,  
91+                 OptimizingMetric . RSquared ,  debugLogger ) ; 
92+             
93+             return  autoFitter . Fit ( ) ; 
11794        } 
11895    } 
11996
12097    public  static   class  MulticlassExtensions 
12198    { 
122-         public  static   MulticlassClassificationResult  AutoFit ( this  MulticlassClassificationContext  context , 
99+         public  static   IEnumerable < IterationResult < MultiClassClassifierMetrics > >  AutoFit ( this  MulticlassClassificationContext  context , 
123100            IDataView  trainData , 
124101            string  label  =  DefaultColumnNames . Label , 
125102            IDataView  validationData  =  null , 
126103            uint  timeoutInMinutes  =  AutoFitDefaults . TimeOutInMinutes , 
127104            IEstimator < ITransformer >  preFeaturizers  =  null , 
128-             IEnumerable < ( string ,  ColumnPurpose ) >  columnPurposes  =  null , 
129-             CancellationToken  cancellationToken  =  default , 
130-             IProgress < MulticlassClassificationIterationResult >  iterationCallback  =  null ) 
105+             IEnumerable < ( string ,  ColumnPurpose ) >  columnPurposes  =  null ) 
131106        { 
132107            var  settings  =  new  AutoFitSettings ( ) ; 
133108            settings . StoppingCriteria . TimeOutInMinutes  =  timeoutInMinutes ; 
134109
135110            return  AutoFit ( context ,  trainData ,  label ,  validationData ,  settings , 
136-                 preFeaturizers ,  columnPurposes ,  cancellationToken ,   iterationCallback ,   null ) ; 
111+                 preFeaturizers ,  columnPurposes ,  null ) ; 
137112        } 
138113
139-         internal  static   MulticlassClassificationResult  AutoFit ( this  MulticlassClassificationContext  context , 
114+         internal  static   IEnumerable < IterationResult < MultiClassClassifierMetrics > >  AutoFit ( this  MulticlassClassificationContext  context , 
140115            IDataView  trainData , 
141116            string  label  =  DefaultColumnNames . Label , 
142117            IDataView  validationData  =  null , 
143118            AutoFitSettings  settings  =  null , 
144119            IEstimator < ITransformer >  preFeaturizers  =  null , 
145120            IEnumerable < ( string ,  ColumnPurpose ) >  columnPurposes  =  null , 
146-             CancellationToken  cancellationToken  =  default , 
147-             IProgress < MulticlassClassificationIterationResult >  iterationCallback  =  null ,  IDebugLogger  debugLogger  =  null ) 
121+             IDebugLogger  debugLogger  =  null ) 
148122        { 
149123            UserInputValidationUtil . ValidateAutoFitArgs ( trainData ,  label ,  validationData ,  settings ,  columnPurposes ) ; 
150124
151125            if  ( validationData  ==  null ) 
152126            { 
153127                ( trainData ,  validationData )  =  context . TestValidateSplit ( trainData ) ; 
154128            } 
155- 
129+              
156130            // run autofit & get all pipelines run in that process 
157-             var  ( allPipelines ,  bestPipeline )  =  AutoFitApi . Fit ( trainData ,  validationData ,  label , 
158-                 settings ,  preFeaturizers ,  TaskKind . MulticlassClassification ,  OptimizingMetric . Accuracy , 
159-                 columnPurposes ,  debugLogger ) ; 
160- 
161-             var  results  =  new  MulticlassClassificationIterationResult [ allPipelines . Length ] ; 
162-             for  ( var  i  =  0 ;  i  <  results . Length ;  i ++ ) 
163-             { 
164-                 var  iterationResult  =  allPipelines [ i ] ; 
165-                 var  result  =  new  MulticlassClassificationIterationResult ( iterationResult . Model ,  ( MultiClassClassifierMetrics ) iterationResult . EvaluatedMetrics ,  iterationResult . ScoredValidationData ,  iterationResult . Pipeline . ToPipeline ( ) ) ; 
166-                 results [ i ]  =  result ; 
167-             } 
168-             var  bestResult  =  new  MulticlassClassificationIterationResult ( bestPipeline . Model ,  ( MultiClassClassifierMetrics ) bestPipeline . EvaluatedMetrics ,  bestPipeline . ScoredValidationData ,  bestPipeline . Pipeline . ToPipeline ( ) ) ; 
169-             return  new  MulticlassClassificationResult ( bestResult ,  results ) ; 
170-         } 
171-     } 
172- 
173-     public  class  BinaryClassificationResult 
174-     { 
175-         public  readonly  BinaryClassificationItertionResult  BestIteration ; 
176-         public  readonly  BinaryClassificationItertionResult [ ]  IterationResults ; 
177- 
178-         public  BinaryClassificationResult ( BinaryClassificationItertionResult  bestPipeline , 
179-             BinaryClassificationItertionResult [ ]  iterationResults ) 
180-         { 
181-             BestIteration  =  bestPipeline ; 
182-             IterationResults  =  iterationResults ; 
183-         } 
184-     } 
185- 
186-     public  class  MulticlassClassificationResult 
187-     { 
188-         public  readonly  MulticlassClassificationIterationResult  BestIteration ; 
189-         public  readonly  MulticlassClassificationIterationResult [ ]  IterationResults ; 
190- 
191-         public  MulticlassClassificationResult ( MulticlassClassificationIterationResult  bestPipeline , 
192-             MulticlassClassificationIterationResult [ ]  iterationResults ) 
193-         { 
194-             BestIteration  =  bestPipeline ; 
195-             IterationResults  =  iterationResults ; 
196-         } 
197-     } 
198- 
199-     public  class  RegressionResult 
200-     { 
201-         public  readonly  RegressionIterationResult  BestIteration ; 
202-         public  readonly  RegressionIterationResult [ ]  IterationResults ; 
203- 
204-         public  RegressionResult ( RegressionIterationResult  bestPipeline , 
205-             RegressionIterationResult [ ]  iterationResults ) 
206-         { 
207-             BestIteration  =  bestPipeline ; 
208-             IterationResults  =  iterationResults ; 
209-         } 
210-     } 
211- 
212-     public  class  BinaryClassificationItertionResult 
213-     { 
214-         public  readonly  BinaryClassificationMetrics  Metrics ; 
215-         public  readonly  ITransformer  Model ; 
216-         public  readonly  IDataView  ScoredValidationData ; 
217-         internal  readonly  Pipeline  Pipeline ; 
218- 
219-         internal  BinaryClassificationItertionResult ( ITransformer  model ,  BinaryClassificationMetrics  metrics ,  IDataView  scoredValidationData ,  Pipeline  pipeline ) 
220-         { 
221-             Model  =  model ; 
222-             ScoredValidationData  =  scoredValidationData ; 
223-             Metrics  =  metrics ; 
224-             Pipeline  =  pipeline ; 
225-         } 
226-     } 
227- 
228-     public  class  MulticlassClassificationIterationResult 
229-     { 
230-         public  readonly  MultiClassClassifierMetrics  Metrics ; 
231-         public  readonly  ITransformer  Model ; 
232-         public  readonly  IDataView  ScoredValidationData ; 
233-         internal  readonly  Pipeline  Pipeline ; 
234- 
235-         internal  MulticlassClassificationIterationResult ( ITransformer  model ,  MultiClassClassifierMetrics  metrics ,  IDataView  scoredValidationData ,  Pipeline  pipeline ) 
236-         { 
237-             Model  =  model ; 
238-             Metrics  =  metrics ; 
239-             ScoredValidationData  =  scoredValidationData ; 
240-             Pipeline  =  pipeline ; 
131+             var  autoFitter  =  new  AutoFitter < MultiClassClassifierMetrics > ( TaskKind . MulticlassClassification ,  trainData ,  label ,  validationData , 
132+                 settings ,  preFeaturizers ,  columnPurposes ,  OptimizingMetric . RSquared ,  debugLogger ) ; 
133+             return  autoFitter . Fit ( ) ; 
241134        } 
242135    } 
243136
244-     public  class  RegressionIterationResult 
137+     public  class  IterationResult < T > 
245138    { 
246-         public  readonly  RegressionMetrics  Metrics ; 
139+         public  readonly  T  Metrics ; 
247140        public  readonly  ITransformer  Model ; 
248-         public  readonly  IDataView   ScoredValidationData ; 
141+         public  readonly  Exception   Exception ; 
249142        internal  readonly  Pipeline  Pipeline ; 
250143
251-         internal  RegressionIterationResult ( ITransformer  model ,  RegressionMetrics  metrics ,  IDataView   scoredValidationData ,   Pipeline   pipeline ) 
144+         internal  IterationResult ( ITransformer  model ,  T  metrics ,  Pipeline   pipeline ,   Exception   exception ) 
252145        { 
253146            Model  =  model ; 
254147            Metrics  =  metrics ; 
255-             ScoredValidationData  =  scoredValidationData ; 
256148            Pipeline  =  pipeline ; 
149+             Exception  =  exception ; 
257150        } 
258151    } 
259152} 
0 commit comments