@@ -163,7 +163,7 @@ private void RunCore(IChannel ch, string cmd)
163163 RoleMappedData validData = null ;
164164 if ( ! string . IsNullOrWhiteSpace ( Args . ValidationFile ) )
165165 {
166- if ( ! TrainUtils . CanUseValidationData ( trainer ) )
166+ if ( ! trainer . Info . SupportsValidation )
167167 {
168168 ch . Warning ( "Ignoring validationFile: Trainer does not accept validation dataset." ) ;
169169 }
@@ -242,39 +242,32 @@ public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData
242242 }
243243
244244 private static IPredictor TrainCore ( IHostEnvironment env , IChannel ch , RoleMappedData data , ITrainer trainer , string name , RoleMappedData validData ,
245- ICalibratorTrainer calibrator , int maxCalibrationExamples , bool ? cacheData , IPredictor inpPredictor = null )
245+ ICalibratorTrainer calibrator , int maxCalibrationExamples , bool ? cacheData , IPredictor inputPredictor = null )
246246 {
247247 Contracts . CheckValue ( env , nameof ( env ) ) ;
248248 env . CheckValue ( ch , nameof ( ch ) ) ;
249249 ch . CheckValue ( data , nameof ( data ) ) ;
250250 ch . CheckValue ( trainer , nameof ( trainer ) ) ;
251251 ch . CheckNonEmpty ( name , nameof ( name ) ) ;
252252 ch . CheckValueOrNull ( validData ) ;
253- ch . CheckValueOrNull ( inpPredictor ) ;
253+ ch . CheckValueOrNull ( inputPredictor ) ;
254254
255255 AddCacheIfWanted ( env , ch , trainer , ref data , cacheData ) ;
256256 ch . Trace ( "Training" ) ;
257257 if ( validData != null )
258258 AddCacheIfWanted ( env , ch , trainer , ref validData , cacheData ) ;
259259
260- var trainerEx = trainer as ITrainerEx ;
261- if ( inpPredictor != null && trainerEx ? . SupportsIncrementalTraining != true )
260+ if ( inputPredictor != null && ! trainer . Info . SupportsIncrementalTraining )
262261 {
263262 ch . Warning ( "Ignoring " + nameof ( TrainCommand . Arguments . InputModelFile ) +
264263 ": Trainer does not support incremental training." ) ;
265- inpPredictor = null ;
264+ inputPredictor = null ;
266265 }
267- ch . Assert ( validData == null || CanUseValidationData ( trainer ) ) ;
268- var predictor = trainer . Train ( new TrainContext ( data , validData , inpPredictor ) ) ;
266+ ch . Assert ( validData == null || trainer . Info . SupportsValidation ) ;
267+ var predictor = trainer . Train ( new TrainContext ( data , validData , inputPredictor ) ) ;
269268 return CalibratorUtils . TrainCalibratorIfNeeded ( env , ch , calibrator , maxCalibrationExamples , trainer , predictor , data ) ;
270269 }
271270
272- public static bool CanUseValidationData ( ITrainer trainer )
273- {
274- Contracts . CheckValue ( trainer , nameof ( trainer ) ) ;
275- return ( trainer as ITrainerEx ) ? . SupportsValidation ?? false ;
276- }
277-
278271 public static bool TryLoadPredictor ( IChannel ch , IHostEnvironment env , string inputModelFile , out IPredictor inputPredictor )
279272 {
280273 Contracts . AssertValue ( env ) ;
@@ -388,9 +381,8 @@ public static void SaveDataPipe(IHostEnvironment env, RepositoryWriter repositor
388381 IDataView pipeStart ;
389382 var xfs = BacktrackPipe ( dataPipe , out pipeStart ) ;
390383
391- IDataLoader loader ;
392384 Action < ModelSaveContext > saveAction ;
393- if ( ! blankLoader && ( loader = pipeStart as IDataLoader ) != null )
385+ if ( ! blankLoader && pipeStart is IDataLoader loader )
394386 saveAction = loader . Save ;
395387 else
396388 {
@@ -460,7 +452,7 @@ public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITra
460452 if ( autoNorm != NormalizeOption . Yes )
461453 {
462454 DvBool isNormalized = DvBool . False ;
463- if ( trainer . NeedNormalization ( ) != true || schema . IsNormalized ( featCol ) )
455+ if ( ! trainer . Info . NeedNormalization || schema . IsNormalized ( featCol ) )
464456 {
465457 ch . Info ( "Not adding a normalizer." ) ;
466458 return false ;
@@ -491,8 +483,7 @@ private static bool AddCacheIfWanted(IHostEnvironment env, IChannel ch, ITrainer
491483 ch . AssertValue ( trainer , nameof ( trainer ) ) ;
492484 ch . AssertValue ( data , nameof ( data ) ) ;
493485
494- ITrainerEx trainerEx = trainer as ITrainerEx ;
495- bool shouldCache = cacheData ?? ( ! ( data . Data is BinaryLoader ) && ( trainerEx == null || trainerEx . WantCaching ) ) ;
486+ bool shouldCache = cacheData ?? ! ( data . Data is BinaryLoader ) && trainer . Info . WantCaching ;
496487
497488 if ( shouldCache )
498489 {
0 commit comments