@@ -107,6 +107,7 @@ public sealed class Arguments : ArgumentsBase
107107 internal const string ShortName = "gam" ;
108108
109109 public override PredictionKind PredictionKind => PredictionKind . BinaryClassification ;
110+ private protected override bool NeedCalibration => true ;
110111
111112 public BinaryClassificationGamTrainer ( IHostEnvironment env , Arguments args )
112113 : base ( env , args ) { }
@@ -225,6 +226,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight
225226 protected int [ ] FeatureMap ;
226227
227228 public override TrainerInfo Info { get ; }
229+ private protected virtual bool NeedCalibration => false ;
228230
229231 private protected GamTrainerBase ( IHostEnvironment env , TArgs args )
230232 : base ( env , RegisterName )
@@ -240,7 +242,7 @@ private protected GamTrainerBase(IHostEnvironment env, TArgs args)
240242 Host . CheckParam ( 0 < args . NumIterations , nameof ( args . NumIterations ) , "Must be positive." ) ;
241243
242244 Args = args ;
243- Info = new TrainerInfo ( normalization : false , calibration : this is BinaryClassificationGamTrainer , caching : false ) ;
245+ Info = new TrainerInfo ( normalization : false , calibration : NeedCalibration , caching : false ) ;
244246 _gainConfidenceInSquaredStandardDeviations = Math . Pow ( ProbabilityFunctions . Probit ( 1 - ( 1 - Args . GainConfidenceLevel ) * 0.5 ) , 2 ) ;
245247 _entropyCoefficient = Args . EntropyCoefficient * 1e-6 ;
246248 int numThreads = args . NumThreads ?? Environment . ProcessorCount ;
0 commit comments