Skip to content

Commit cd04300

Browse files
committed
FastTreeRankingTrainer expose non-advanced args (#1246)
1 parent 67360d9 commit cd04300

File tree

9 files changed

+48
-14
lines changed

9 files changed

+48
-14
lines changed

src/Microsoft.ML.FastTree/BoostingFastTree.cs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,20 @@ protected BoostingFastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaSh
2121
{
2222
}
2323

24-
protected BoostingFastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
25-
string weightColumn = null, string groupIdColumn = null, Action<TArgs> advancedSettings = null)
26-
: base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings)
24+
protected BoostingFastTreeTrainerBase(IHostEnvironment env,
25+
SchemaShape.Column label,
26+
string featureColumn,
27+
string weightColumn = null,
28+
string groupIdColumn = null,
29+
int numLeaves = Defaults.NumLeaves,
30+
int numTrees = Defaults.NumTrees,
31+
int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs,
32+
double learningRate = Defaults.LearningRates,
33+
Action<TArgs> advancedSettings = null)
34+
: base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDocumentsInLeafs, advancedSettings)
2735
{
36+
// Override with the directly provided values.
37+
Args.LearningRates = learningRate;
2838
}
2939

3040
protected override void CheckArgs(IChannel ch)

src/Microsoft.ML.FastTree/FastTree.cs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,15 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
9292
/// <summary>
9393
/// Constructor to use when instantiating the classes deriving from here through the API.
9494
/// </summary>
95-
private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
96-
string weightColumn = null, string groupIdColumn = null, Action<TArgs> advancedSettings = null)
95+
private protected FastTreeTrainerBase(IHostEnvironment env,
96+
SchemaShape.Column label,
97+
string featureColumn,
98+
string weightColumn = null,
99+
string groupIdColumn = null,
100+
int numLeaves = Defaults.NumLeaves,
101+
int numTrees = Defaults.NumTrees,
102+
int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs,
103+
Action<TArgs> advancedSettings = null)
97104
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
98105
{
99106
Args = new TArgs();
@@ -113,6 +120,11 @@ private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column l
113120
if (groupIdColumn != null)
114121
Args.GroupIdColumn = groupIdColumn;
115122

123+
// Override with the directly provided values.
124+
Args.NumLeaves = numLeaves;
125+
Args.NumTrees = numTrees;
126+
Args.MinDocumentsInLeafs = minDocumentsInLeafs;
127+
116128
// The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
117129
// Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
118130
// Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration.

src/Microsoft.ML.FastTree/FastTreeCatalog.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainer
9494
{
9595
Contracts.CheckValue(ctx, nameof(ctx));
9696
var env = CatalogUtils.GetEnvironment(ctx);
97-
return new FastTreeRankingTrainer(env, label, features, groupId, weights, advancedSettings);
97+
return new FastTreeRankingTrainer(env, label, features, groupId, weights, advancedSettings: advancedSettings);
9898
}
9999
}
100100
}

src/Microsoft.ML.FastTree/FastTreeClassification.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env,
136136
int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs,
137137
double learningRate = Defaults.LearningRates,
138138
Action<Arguments> advancedSettings = null)
139-
: base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings)
139+
: base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings)
140140
{
141141
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
142142
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));

src/Microsoft.ML.FastTree/FastTreeRanking.cs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,22 @@ public sealed partial class FastTreeRankingTrainer
6767
/// <param name="featureColumn">The name of the feature column.</param>
6868
/// <param name="groupIdColumn">The name for the column containing the group ID.</param>
6969
/// <param name="weightColumn">The name for the column containing the initial weight.</param>
70+
/// <param name="numLeaves">The max number of leaves in each regression tree.</param>
71+
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
72+
/// <param name="minDocumentsInLeafs">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
73+
/// <param name="learningRate">The learning rate.</param>
7074
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
71-
public FastTreeRankingTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string groupIdColumn,
72-
string weightColumn = null, Action<Arguments> advancedSettings = null)
73-
: base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings)
75+
public FastTreeRankingTrainer(IHostEnvironment env,
76+
string labelColumn,
77+
string featureColumn,
78+
string groupIdColumn,
79+
string weightColumn = null,
80+
int numLeaves = Defaults.NumLeaves,
81+
int numTrees = Defaults.NumTrees,
82+
int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs,
83+
double learningRate = Defaults.LearningRates,
84+
Action<Arguments> advancedSettings = null)
85+
: base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings: advancedSettings)
7486
{
7587
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
7688
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));

src/Microsoft.ML.FastTree/FastTreeRegression.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ public FastTreeRegressionTrainer(IHostEnvironment env,
7272
int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs,
7373
double learningRate = Defaults.LearningRates,
7474
Action<Arguments> advancedSettings = null)
75-
: base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings)
75+
: base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings)
7676
{
7777
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
7878
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));

src/Microsoft.ML.FastTree/FastTreeStatic.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ public static Scalar<float> FastTree<TVal>(this RankingContext.RankingTrainers c
153153
var rec = new TrainerEstimatorReconciler.Ranker<TVal>(
154154
(env, labelName, featuresName, groupIdName, weightsName) =>
155155
{
156-
var trainer = new FastTreeRankingTrainer(env, labelName, featuresName, groupIdName, weightsName, advancedSettings);
156+
var trainer = new FastTreeRankingTrainer(env, labelName, featuresName, groupIdName, weightsName, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings);
157157
if (onFit != null)
158158
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
159159
return trainer;

src/Microsoft.ML.FastTree/FastTreeTweedie.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public sealed partial class FastTreeTweedieTrainer
5959
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
6060
public FastTreeTweedieTrainer(IHostEnvironment env, string labelColumn, string featureColumn,
6161
string groupIdColumn = null, string weightColumn = null, Action<Arguments> advancedSettings = null)
62-
: base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings)
62+
: base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings)
6363
{
6464
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
6565
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));

src/Microsoft.ML.FastTree/RandomForest.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ protected RandomForestTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.
3030
/// </summary>
3131
protected RandomForestTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
3232
string weightColumn = null, string groupIdColumn = null, bool quantileEnabled = false, Action<TArgs> advancedSettings = null)
33-
: base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings)
33+
: base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings)
3434
{
3535
_quantileEnabled = quantileEnabled;
3636
}

0 commit comments

Comments
 (0)