@@ -28,8 +28,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase
2828 [ Argument ( ArgumentType . Multiple , HelpText = "Trainer to use" , ShortName = "tr" ) ]
2929 public SubComponent < ITrainer , SignatureTrainer > Trainer = new SubComponent < ITrainer , SignatureTrainer > ( "AveragedPerceptron" ) ;
3030
31- [ Argument ( ArgumentType . Multiple , HelpText = "Scorer to use" , NullName = "<Auto>" , SortOrder = 101 ) ]
32- public SubComponent < IDataScorerTransform , SignatureDataScorer > Scorer ;
31+ [ Argument ( ArgumentType . Multiple , HelpText = "Scorer to use" , NullName = "<Auto>" , SortOrder = 101 , SignatureType = typeof ( SignatureDataScorer ) ) ]
32+ public IComponentFactory < IDataView , ISchemaBoundMapper , RoleMappedSchema , IDataScorerTransform > Scorer ;
3333
3434 [ Argument ( ArgumentType . Multiple , HelpText = "Evaluator to use" , ShortName = "eval" , NullName = "<Auto>" , SortOrder = 102 ) ]
3535 public SubComponent < IMamlEvaluator , SignatureMamlEvaluator > Evaluator ;
@@ -76,8 +76,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase
7676 [ Argument ( ArgumentType . AtMostOnce , IsInputFileName = true , HelpText = "The validation data file" , ShortName = "valid" ) ]
7777 public string ValidationFile ;
7878
79- [ Argument ( ArgumentType . Multiple , HelpText = "Output calibrator" , ShortName = "cali" , NullName = "<None>" ) ]
80- public SubComponent < ICalibratorTrainer , SignatureCalibrator > Calibrator = new SubComponent < ICalibratorTrainer , SignatureCalibrator > ( "PlattCalibration" ) ;
79+ [ Argument ( ArgumentType . Multiple , HelpText = "Output calibrator" , ShortName = "cali" , NullName = "<None>" , SignatureType = typeof ( SignatureCalibrator ) ) ]
80+ public IComponentFactory < ICalibratorTrainer > Calibrator = new PlattCalibratorTrainerFactory ( ) ;
8181
8282 [ Argument ( ArgumentType . LastOccurenceWins , HelpText = "Number of instances to train the calibrator" , ShortName = "numcali" ) ]
8383 public int MaxCalibrationExamples = 1000000000 ;
@@ -383,9 +383,9 @@ public FoldResult(Dictionary<string, IDataView> metrics, ISchema scoreSchema, Ro
383383 private readonly string _splitColumn ;
384384 private readonly int _numFolds ;
385385 private readonly SubComponent < ITrainer , SignatureTrainer > _trainer ;
386- private readonly SubComponent < IDataScorerTransform , SignatureDataScorer > _scorer ;
386+ private readonly IComponentFactory < IDataView , ISchemaBoundMapper , RoleMappedSchema , IDataScorerTransform > _scorer ;
387387 private readonly SubComponent < IMamlEvaluator , SignatureMamlEvaluator > _evaluator ;
388- private readonly SubComponent < ICalibratorTrainer , SignatureCalibrator > _calibrator ;
388+ private readonly IComponentFactory < ICalibratorTrainer > _calibrator ;
389389 private readonly int _maxCalibrationExamples ;
390390 private readonly bool _useThreads ;
391391 private readonly bool ? _cacheData ;
@@ -423,7 +423,7 @@ public FoldHelper(
423423 Arguments args ,
424424 Func < IHostEnvironment , IChannel , IDataView , ITrainer , RoleMappedData > createExamples ,
425425 Func < IHostEnvironment , IChannel , IDataView , RoleMappedData , IDataView , RoleMappedData > applyTransformsToTestData ,
426- SubComponent < IDataScorerTransform , SignatureDataScorer > scorer ,
426+ IComponentFactory < IDataView , ISchemaBoundMapper , RoleMappedSchema , IDataScorerTransform > scorer ,
427427 SubComponent < IMamlEvaluator , SignatureMamlEvaluator > evaluator ,
428428 Func < IDataView > getValidationDataView = null ,
429429 Func < IHostEnvironment , IChannel , IDataView , RoleMappedData , IDataView , RoleMappedData > applyTransformsToValidationData = null ,
@@ -559,11 +559,12 @@ private FoldResult RunFold(int fold)
559559
560560 // Score.
561561 ch . Trace ( "Scoring and evaluating" ) ;
562- var bindable = ScoreUtils . GetSchemaBindableMapper ( host , predictor , _scorer ) ;
562+ ch . Assert ( _scorer == null || _scorer is ICommandLineComponentFactory , "CrossValidationCommand should only be used from the command line." ) ;
563+ var bindable = ScoreUtils . GetSchemaBindableMapper ( host , predictor , scorerFactorySettings : _scorer as ICommandLineComponentFactory ) ;
563564 ch . AssertValue ( bindable ) ;
564565 var mapper = bindable . Bind ( host , testData . Schema ) ;
565- var scorerComp = _scorer . IsGood ( ) ? _scorer : ScoreUtils . GetScorerComponent ( mapper ) ;
566- IDataScorerTransform scorePipe = scorerComp . CreateInstance ( host , testData . Data , mapper , trainData . Schema ) ;
566+ var scorerComp = _scorer ?? ScoreUtils . GetScorerComponent ( mapper ) ;
567+ IDataScorerTransform scorePipe = scorerComp . CreateComponent ( host , testData . Data , mapper , trainData . Schema ) ;
567568
568569 // Save per-fold model.
569570 string modelFileName = ConstructPerFoldName ( _outputModelFile , fold ) ;
0 commit comments