77using  Microsoft . ML . Runtime . CommandLine ; 
88using  Microsoft . ML . Runtime . Data ; 
99using  Microsoft . ML . Runtime . Data . Conversion ; 
10+ using  Microsoft . ML . Runtime . EntryPoints ; 
1011using  Microsoft . ML . Runtime . Internal . Calibration ; 
1112using  Microsoft . ML . Runtime . Internal . Internallearn ; 
1213using  Microsoft . ML . Runtime . Internal . Utilities ; 
@@ -45,7 +46,7 @@ internal static class FastTreeShared
4546    } 
4647
4748    public  abstract  class  FastTreeTrainerBase < TArgs ,  TTransformer ,  TModel >  : 
48-         TrainerEstimatorBase < TTransformer ,  TModel > 
49+         TrainerEstimatorBaseWithGroupId < TTransformer ,  TModel > 
4950        where  TTransformer :  ISingleFeaturePredictionTransformer < TModel > 
5051        where  TArgs  :  TreeArgs ,  new ( ) 
5152        where  TModel  :  IPredictorProducing < Float > 
@@ -92,26 +93,36 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
9293        /// <summary> 
9394        /// Constructor to use when instantiating the classes deriving from here through the API. 
9495        /// </summary> 
95-         private  protected  FastTreeTrainerBase ( IHostEnvironment  env ,  SchemaShape . Column  label ,  string  featureColumn , 
96-             string  weightColumn  =  null ,  string  groupIdColumn  =  null ,  Action < TArgs >  advancedSettings  =  null ) 
97-             :  base ( Contracts . CheckRef ( env ,  nameof ( env ) ) . Register ( RegisterName ) ,  TrainerUtils . MakeR4VecFeature ( featureColumn ) ,  label ,  TrainerUtils . MakeR4ScalarWeightColumn ( weightColumn ) ) 
96+         private  protected  FastTreeTrainerBase ( IHostEnvironment  env , 
97+             SchemaShape . Column  label , 
98+             string  featureColumn , 
99+             string  weightColumn , 
100+             string  groupIdColumn , 
101+             int  numLeaves , 
102+             int  numTrees , 
103+             int  minDocumentsInLeafs , 
104+             Action < TArgs >  advancedSettings ) 
105+             :  base ( Contracts . CheckRef ( env ,  nameof ( env ) ) . Register ( RegisterName ) ,  TrainerUtils . MakeR4VecFeature ( featureColumn ) ,  label ,  TrainerUtils . MakeR4ScalarWeightColumn ( weightColumn ) ,  TrainerUtils . MakeU4ScalarColumn ( groupIdColumn ) ) 
98106        { 
99107            Args  =  new  TArgs ( ) ; 
100108
109+             // Seed the args with the directly provided values first;
110+             // the advanced settings applied below may override them.
111+             Args . NumLeaves  =  numLeaves ; 
112+             Args . NumTrees  =  numTrees ; 
113+             Args . MinDocumentsInLeafs  =  minDocumentsInLeafs ; 
114+ 
101115            // Apply the advanced settings, if the user supplied any.
102116            advancedSettings ? . Invoke ( Args ) ; 
103117
104-             // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly 
105-             TrainerUtils . CheckArgsHaveDefaultColNames ( Host ,  Args ) ; 
106- 
107118            Args . LabelColumn  =  label . Name ; 
108119            Args . FeatureColumn  =  featureColumn ; 
109120
110121            if  ( weightColumn  !=  null ) 
111-                 Args . WeightColumn  =  weightColumn ; 
122+                 Args . WeightColumn  =  Optional < string > . Explicit ( weightColumn ) ;   ; 
112123
113124            if  ( groupIdColumn  !=  null ) 
114-                 Args . GroupIdColumn  =  groupIdColumn ; 
125+                 Args . GroupIdColumn  =  Optional < string > . Explicit ( groupIdColumn ) ;   ; 
115126
116127            // The discretization step renders this trainer non-parametric, and therefore it does not need normalization. 
117128            // Also since it builds its own internal discretized columnar structures, it cannot benefit from caching. 
@@ -128,7 +139,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column l
128139        /// Legacy constructor that is used when invoking the classes deriving from this, through maml. 
129140        /// </summary> 
130141        private  protected  FastTreeTrainerBase ( IHostEnvironment  env ,  TArgs  args ,  SchemaShape . Column  label ) 
131-             :  base ( Contracts . CheckRef ( env ,  nameof ( env ) ) . Register ( RegisterName ) ,  TrainerUtils . MakeR4VecFeature ( args . FeatureColumn ) ,  label ,  TrainerUtils . MakeR4ScalarWeightColumn ( args . WeightColumn ) ) 
142+             :  base ( Contracts . CheckRef ( env ,  nameof ( env ) ) . Register ( RegisterName ) ,  TrainerUtils . MakeR4VecFeature ( args . FeatureColumn ) ,  label ,  TrainerUtils . MakeR4ScalarWeightColumn ( args . WeightColumn ,   args . WeightColumn . IsExplicit ) ) 
132143        { 
133144            Host . CheckValue ( args ,  nameof ( args ) ) ; 
134145            Args  =  args ; 
@@ -159,32 +170,6 @@ protected virtual Float GetMaxLabel()
159170            return  Float . PositiveInfinity ; 
160171        } 
161172
162-         /// <summary> 
163-         /// If, after applying the advancedSettings delegate, the args are different from the default value
164-         /// and are also different from the value supplied directly to the extension method, warn the user
165-         /// about which value is being used. 
166-         /// The parameters that appear here, numTrees, minDocumentsInLeafs, numLeaves, learningRate are the ones the users are most likely to tune. 
167-         /// This list should follow the one in the constructor, and the extension methods on the <see cref="TrainContextBase"/>. 
168-         /// REVIEW: we should somehow mark the arguments that are set apart in those two places. Currently they stand out by their sort order annotation. 
169-         /// </summary> 
170-         protected  void  CheckArgsAndAdvancedSettingMismatch ( int  numLeaves , 
171-             int  numTrees , 
172-             int  minDocumentsInLeafs , 
173-             double  learningRate , 
174-             BoostedTreeArgs  snapshot , 
175-             BoostedTreeArgs  currentArgs ) 
176-         { 
177-             using  ( var  ch  =  Host . Start ( "Comparing advanced settings with the directly provided values." ) ) 
178-             { 
179- 
180-                 // Check that the user didn't supply different parameters in the args, from what it specified directly. 
181-                 TrainerUtils . CheckArgsAndAdvancedSettingMismatch ( ch ,  numLeaves ,  snapshot . NumLeaves ,  currentArgs . NumLeaves ,  nameof ( numLeaves ) ) ; 
182-                 TrainerUtils . CheckArgsAndAdvancedSettingMismatch ( ch ,  numTrees ,  snapshot . NumTrees ,  currentArgs . NumTrees ,  nameof ( numTrees ) ) ; 
183-                 TrainerUtils . CheckArgsAndAdvancedSettingMismatch ( ch ,  minDocumentsInLeafs ,  snapshot . MinDocumentsInLeafs ,  currentArgs . MinDocumentsInLeafs ,  nameof ( minDocumentsInLeafs ) ) ; 
184-                 TrainerUtils . CheckArgsAndAdvancedSettingMismatch ( ch ,  learningRate ,  snapshot . LearningRates ,  currentArgs . LearningRates ,  nameof ( learningRate ) ) ; 
185-             } 
186-         } 
187- 
188173        private  void  Initialize ( IHostEnvironment  env ) 
189174        { 
190175            int  numThreads  =  Args . NumThreads  ??  Environment . ProcessorCount ; 
0 commit comments