/// <summary>
/// One row of pipeline results: the pipeline's graph, its identifier, the metric values
/// recorded for it, and the names of its first input and predictor model variables.
/// </summary>
public sealed class PipelineResultRow
{
    /// <summary>JSON text describing the pipeline's entry point graph.</summary>
    public string GraphJson { get; }

    /// <summary>Metric value recorded for the pipeline.</summary>
    public double MetricValue { get; }

    /// <summary>Metric value recorded for the pipeline on its training data.</summary>
    public double TrainingMetricValue { get; }

    /// <summary>Identifier of the pipeline this row belongs to.</summary>
    public string PipelineId { get; }

    /// <summary>Variable name of the first input of the pipeline's graph.</summary>
    public string FirstInput { get; }

    /// <summary>Variable name of the pipeline's predictor model output.</summary>
    public string PredictorModel { get; }

    /// <summary>
    /// Creates an empty row; every property keeps its default value.
    /// </summary>
    // NOTE(review): presumably kept for reflection-based row creation -- confirm before removing.
    public PipelineResultRow()
    {
    }

    /// <summary>
    /// Creates a fully populated row.
    /// </summary>
    /// <param name="graphJson">Value for <see cref="GraphJson"/>.</param>
    /// <param name="metricValue">Value for <see cref="MetricValue"/>.</param>
    /// <param name="pipelineId">Value for <see cref="PipelineId"/>.</param>
    /// <param name="trainingMetricValue">Value for <see cref="TrainingMetricValue"/>.</param>
    /// <param name="firstInput">Value for <see cref="FirstInput"/>.</param>
    /// <param name="predictorModel">Value for <see cref="PredictorModel"/>.</param>
    public PipelineResultRow(
        string graphJson,
        double metricValue,
        string pipelineId,
        double trainingMetricValue,
        string firstInput,
        string predictorModel)
    {
        // Identity and graph description.
        PipelineId = pipelineId;
        GraphJson = graphJson;
        FirstInput = firstInput;
        PredictorModel = predictorModel;

        // Metric values.
        MetricValue = metricValue;
        TrainingMetricValue = trainingMetricValue;
    }
}
3644
@@ -111,7 +119,8 @@ public AutoInference.EntryPointGraphDef ToEntryPointGraph(Experiment experiment
/// <summary>
/// Compares this pattern with another one by <c>UniqueId</c>.
/// </summary>
/// <param name="obj">Pattern to compare against; may be null.</param>
/// <returns>True when <paramref name="obj"/> is not null and has the same <c>UniqueId</c>.</returns>
public bool Equals(PipelinePattern obj)
{
    if (obj != null)
    {
        return UniqueId == obj.UniqueId;
    }

    // A null pattern is never equal to this instance.
    return false;
}
112120
113121 // REVIEW: We may want to allow for sweeping with CV in the future, so we will need to add new methods like this, or refactor these in that case.
114- public Experiment CreateTrainTestExperiment ( IDataView trainData , IDataView testData , MacroUtils . TrainerKinds trainerKind , out Models . TrainTestEvaluator . Output resultsOutput )
122+ public Experiment CreateTrainTestExperiment ( IDataView trainData , IDataView testData , MacroUtils . TrainerKinds trainerKind ,
123+ bool includeTrainingMetrics , out Models . TrainTestEvaluator . Output resultsOutput )
115124 {
116125 var graphDef = ToEntryPointGraph ( ) ;
117126 var subGraph = graphDef . Graph ;
@@ -136,7 +145,8 @@ public Experiment CreateTrainTestExperiment(IDataView trainData, IDataView testD
136145 Model = finalOutput
137146 } ,
138147 PipelineId = UniqueId . ToString ( "N" ) ,
139- Kind = MacroUtils . TrainerKindApiValue < Models . MacroUtilsTrainerKinds > ( trainerKind )
148+ Kind = MacroUtils . TrainerKindApiValue < Models . MacroUtilsTrainerKinds > ( trainerKind ) ,
149+ IncludeTrainingMetrics = includeTrainingMetrics
140150 } ;
141151
142152 var experiment = _env . CreateExperiment ( ) ;
@@ -150,7 +160,7 @@ public Experiment CreateTrainTestExperiment(IDataView trainData, IDataView testD
150160 }
151161
152162 public Models . TrainTestEvaluator . Output AddAsTrainTest ( Var < IDataView > trainData , Var < IDataView > testData ,
153- MacroUtils . TrainerKinds trainerKind , Experiment experiment = null )
163+ MacroUtils . TrainerKinds trainerKind , Experiment experiment = null , bool includeTrainingMetrics = false )
154164 {
155165 experiment = experiment ?? _env . CreateExperiment ( ) ;
156166 var graphDef = ToEntryPointGraph ( experiment ) ;
@@ -174,7 +184,8 @@ public Models.TrainTestEvaluator.Output AddAsTrainTest(Var<IDataView> trainData,
174184 TrainingData = trainData ,
175185 TestingData = testData ,
176186 Kind = MacroUtils . TrainerKindApiValue < Models . MacroUtilsTrainerKinds > ( trainerKind ) ,
177- PipelineId = UniqueId . ToString ( "N" )
187+ PipelineId = UniqueId . ToString ( "N" ) ,
188+ IncludeTrainingMetrics = includeTrainingMetrics
178189 } ;
179190 var trainTestOutput = experiment . Add ( trainTestInput ) ;
180191 return trainTestOutput ;
@@ -183,34 +194,58 @@ public Models.TrainTestEvaluator.Output AddAsTrainTest(Var<IDataView> trainData,
/// <summary>
/// Runs a train-test experiment on the current pipeline, through entrypoints, and reports
/// the requested metric for both the test data and the training data.
/// </summary>
/// <param name="trainData">Data handed to the experiment as training data.</param>
/// <param name="testData">Data handed to the experiment as testing data.</param>
/// <param name="metric">Metric whose <c>Name</c> selects the column read from each metrics output.</param>
/// <param name="trainerKind">Trainer kind forwarded to <c>CreateTrainTestExperiment</c>.</param>
/// <param name="testMetricValue">Receives the metric read from the first row of <c>OverallMetrics</c>.</param>
/// <param name="trainMetricValue">Receives the metric read from the first row of <c>TrainingOverallMetrics</c>.</param>
/// <exception cref="System.InvalidOperationException">
/// A metrics output has no column named after the metric, or contains no rows.
/// </exception>
public void RunTrainTestExperiment(IDataView trainData, IDataView testData,
    AutoInference.SupportedMetric metric, MacroUtils.TrainerKinds trainerKind, out double testMetricValue,
    out double trainMetricValue)
{
    // Training metrics are always requested: this method reports both values to the caller.
    var experiment = CreateTrainTestExperiment(trainData, testData, trainerKind, true, out var trainTestOutput);
    experiment.Run();

    testMetricValue = ReadFirstMetricValue(
        experiment.GetOutput(trainTestOutput.OverallMetrics), metric.Name);
    trainMetricValue = ReadFirstMetricValue(
        experiment.GetOutput(trainTestOutput.TrainingOverallMetrics), metric.Name);
}

/// <summary>
/// Reads the value of the column named <paramref name="metricName"/> from the first row of
/// <paramref name="metrics"/>.
/// </summary>
private static double ReadFirstMetricValue(IDataView metrics, string metricName)
{
    // Fail loudly instead of silently reading whatever column index the failed lookup left behind.
    if (!metrics.Schema.TryGetColumnIndex(metricName, out var metricCol))
        throw new System.InvalidOperationException($"Column name {metricName} not found");

    // Only the metric column needs to be active on the cursor.
    using (var cursor = metrics.GetRowCursor(col => col == metricCol))
    {
        var getter = cursor.GetGetter<double>(metricCol);

        // The getter is only meaningful once the cursor is positioned on a row.
        if (!cursor.MoveNext())
            throw new System.InvalidOperationException($"No rows available to read metric {metricName}");

        double value = 0;
        getter(ref value);
        return value;
    }
}
203230
204- public static PipelineResultRow [ ] ExtractResults ( IHostEnvironment env , IDataView data , string graphColName , string metricColName , string idColName )
231+ public static PipelineResultRow [ ] ExtractResults ( IHostEnvironment env , IDataView data ,
232+ string graphColName , string metricColName , string idColName , string trainingMetricColName ,
233+ string firstInputColName , string predictorModelColName )
205234 {
206235 var results = new List < PipelineResultRow > ( ) ;
207236 var schema = data . Schema ;
208237 if ( ! schema . TryGetColumnIndex ( graphColName , out var graphCol ) )
209238 throw env . ExceptNotSupp ( $ "Column name { graphColName } not found") ;
210239 if ( ! schema . TryGetColumnIndex ( metricColName , out var metricCol ) )
211240 throw env . ExceptNotSupp ( $ "Column name { metricColName } not found") ;
241+ if ( ! schema . TryGetColumnIndex ( trainingMetricColName , out var trainingMetricCol ) )
242+ throw env . ExceptNotSupp ( $ "Column name { trainingMetricColName } not found") ;
212243 if ( ! schema . TryGetColumnIndex ( idColName , out var pipelineIdCol ) )
213244 throw env . ExceptNotSupp ( $ "Column name { idColName } not found") ;
245+ if ( ! schema . TryGetColumnIndex ( firstInputColName , out var firstInputCol ) )
246+ throw env . ExceptNotSupp ( $ "Column name { firstInputColName } not found") ;
247+ if ( ! schema . TryGetColumnIndex ( predictorModelColName , out var predictorModelCol ) )
248+ throw env . ExceptNotSupp ( $ "Column name { predictorModelColName } not found") ;
214249
215250 using ( var cursor = data . GetRowCursor ( col => true ) )
216251 {
@@ -225,15 +260,33 @@ public static PipelineResultRow[] ExtractResults(IHostEnvironment env, IDataView
225260 var getter3 = cursor . GetGetter < DvText > ( pipelineIdCol ) ;
226261 DvText pipelineId = new DvText ( ) ;
227262 getter3 ( ref pipelineId ) ;
228- results . Add ( new PipelineResultRow ( graphJson . ToString ( ) , metricValue , pipelineId . ToString ( ) ) ) ;
263+ var getter4 = cursor . GetGetter < double > ( trainingMetricCol ) ;
264+ double trainingMetricValue = 0 ;
265+ getter4 ( ref trainingMetricValue ) ;
266+ var getter5 = cursor . GetGetter < DvText > ( firstInputCol ) ;
267+ DvText firstInput = new DvText ( ) ;
268+ getter5 ( ref firstInput ) ;
269+ var getter6 = cursor . GetGetter < DvText > ( predictorModelCol ) ;
270+ DvText predictorModel = new DvText ( ) ;
271+ getter6 ( ref predictorModel ) ;
272+
273+ results . Add ( new PipelineResultRow ( graphJson . ToString ( ) ,
274+ metricValue , pipelineId . ToString ( ) , trainingMetricValue ,
275+ firstInput . ToString ( ) , predictorModel . ToString ( ) ) ) ;
229276 }
230277 }
231278
232279 return results . ToArray ( ) ;
233280 }
234281
/// <summary>
/// Builds the <see cref="PipelineResultRow"/> describing this pipeline: its entry point graph
/// wrapped as JSON, its metric values, its id, and the variable names of its first input and
/// its model output.
/// </summary>
/// <returns>A new row populated from the current state of this pipeline.</returns>
public PipelineResultRow ToResultRow()
{
    var entryPointGraph = ToEntryPointGraph();

    // The graph's JSON is embedded as the content of a top-level 'Nodes' array.
    var wrappedGraphJson = $"{{'Nodes' : [{entryPointGraph.Graph.ToJsonString()}]}}";

    // Each metric falls back to -1 when the null-conditional access yields null.
    var testMetric = PerformanceSummary?.MetricValue ?? -1d;
    var pipelineId = UniqueId.ToString("N");
    var trainingMetric = PerformanceSummary?.TrainingMetricValue ?? -1d;

    var firstInputName = entryPointGraph.GetSubgraphFirstNodeDataVarName(_env);
    var modelVarName = entryPointGraph.ModelOutput.VarName;

    return new PipelineResultRow(
        wrappedGraphJson,
        testMetric,
        pipelineId,
        trainingMetric,
        firstInputName,
        modelVarName);
}
238291 }
239292}
0 commit comments