Tree estimators #855
Changes from 11 commits
```diff
@@ -57,7 +57,7 @@ public Arguments()
                 env => new Ova(env, new Ova.Arguments()
                 {
                     PredictorType = ComponentFactoryUtils.CreateFromFunction(
-                        e => new AveragedPerceptronTrainer(e, new AveragedPerceptronTrainer.Arguments()))
+                        e => new FastTreeBinaryClassificationTrainer(e, DefaultColumnNames.Label, DefaultColumnNames.Features))
                 }));
         }
     }
```

**Comment:** I'd really rather we didn't. This seems to fall into the same bucket as the discussion on #682: that ensembling should take a dependency on FastTree merely because we have a default does not make sense to me. If someone wants to use stacking, that's great, but they need to specify the learners. #Pending

**Reply:** Yes, let's do that separately, when we shape the ensembles to take in the arguments in the constructor.
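To make the review point concrete, here is a hedged sketch of what caller-specified learners could look like. Only `Ova`, `ComponentFactoryUtils`, and `AveragedPerceptronTrainer` come from the diff above; the surrounding context is assumed.

```csharp
// Hypothetical caller-side configuration: the ensemble takes whichever
// binary learner the user names, rather than a hard-coded FastTree default.
// 'env' is assumed to be an IHostEnvironment already in scope.
var ova = new Ova(env, new Ova.Arguments()
{
    PredictorType = ComponentFactoryUtils.CreateFromFunction(
        e => new AveragedPerceptronTrainer(e, new AveragedPerceptronTrainer.Arguments()))
});
```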
```diff
@@ -25,6 +25,8 @@
 using Microsoft.ML.Runtime.Training;
 using Microsoft.ML.Runtime.TreePredictor;
 using Newtonsoft.Json.Linq;
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.Runtime.EntryPoints;

 // All of these reviews apply in general to fast tree and random forest implementations.
 //REVIEW: Decouple train method in Application.cs to have boosting and random forest logic seperate.
```
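These new usings presumably bring in `SchemaShape` (Microsoft.ML.Core.Data) and `Optional<string>` (Microsoft.ML.Runtime.EntryPoints), both of which the constructor changes below rely on.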
```diff
@@ -43,10 +45,11 @@ internal static class FastTreeShared
         public static readonly object TrainLock = new object();
     }

-    public abstract class FastTreeTrainerBase<TArgs, TPredictor> :
-        TrainerBase<TPredictor>
+    public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
+        TrainerEstimatorBase<TTransformer, TModel>
+        where TTransformer: IPredictionTransformer<TModel>
         where TArgs : TreeArgs, new()
-        where TPredictor : IPredictorProducing<Float>
+        where TModel : IPredictorProducing<Float>
     {
         protected readonly TArgs Args;
         protected readonly bool AllowGC;
```
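For orientation, a self-contained toy mirroring the new generic shape: a concrete trainer pins down `TArgs`, `TTransformer`, and `TModel` together, with the transformer constrained to the model type it produces. These are stand-in types, not the real ML.NET definitions.

```csharp
// Toy stand-ins so the constraint chain compiles in isolation.
public interface IPredictorProducing<TOut> { }
public interface IPredictionTransformer<TModel> { }
public abstract class TrainerEstimatorBase<TTransformer, TModel> { }
public class TreeArgs { }

// Mirrors the new base-class signature from the diff above.
public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel>
    : TrainerEstimatorBase<TTransformer, TModel>
    where TTransformer : IPredictionTransformer<TModel>
    where TArgs : TreeArgs, new()
    where TModel : IPredictorProducing<float>
{ }

// A concrete trainer picks one arguments/transformer/model triple.
public sealed class ToyModel : IPredictorProducing<float> { }
public sealed class ToyTransformer : IPredictionTransformer<ToyModel> { }
public sealed class ToyArgs : TreeArgs { }
public sealed class ToyTrainer : FastTreeTrainerBase<ToyArgs, ToyTransformer, ToyModel> { }
```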
```diff
@@ -87,34 +90,53 @@ public abstract class FastTreeTrainerBase<TArgs, TPredictor> :

         private protected virtual bool NeedCalibration => false;

-        private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args)
-            : base(env, RegisterName)
+        /// <summary>
+        /// Constructor to use when instantiating the classing deriving from here through the API.
+        /// </summary>
+        private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
+            string weightColumn = null, string groupIdColumn = null, Action<TArgs> advancedSettings = null)
+            : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), MakeFeatureColumn(featureColumn), label, MakeWeightColumn(weightColumn))
         {
+            Args = new TArgs();
+
+            //apply the advanced args, if the user supplied any
+            advancedSettings?.Invoke(Args);
+            Args.LabelColumn = label.Name;
+
+            if (weightColumn != null)
+                Args.WeightColumn = Optional<string>.Explicit(weightColumn);
+
+            if (groupIdColumn != null)
+                Args.GroupIdColumn = Optional<string>.Explicit(groupIdColumn);
+
+            // The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
+            // Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
+            // Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration.
+            Info = new TrainerInfo(normalization: false, caching: false, calibration: NeedCalibration, supportValid: true);
+            // REVIEW: CLR 4.6 has a bug that is only exposed in Scope, and if we trigger GC.Collect in scope environment
+            // with memory consumption more than 5GB, GC get stuck in infinite loop. So for now let's call GC only if we call things from TlcEnvironment.
+            AllowGC = (env is HostEnvironmentBase<TlcEnvironment>);
+
+            Initialize(env);
+        }
+
+        /// <summary>
+        /// Legacy constructor that is used when invoking the classsing deriving from this, through maml.
+        /// </summary>
+        private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label)
+            : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), MakeFeatureColumn(args.FeatureColumn), label, MakeWeightColumn(args.WeightColumn))
+        {
             Host.CheckValue(args, nameof(args));
             Args = args;
             // The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
             // Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
             // Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration.
             Info = new TrainerInfo(normalization: false, caching: false, calibration: NeedCalibration, supportValid: true);
-            int numThreads = Args.NumThreads ?? Environment.ProcessorCount;
-            if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor)
-            {
-                using (var ch = Host.Start("FastTreeTrainerBase"))
-                {
-                    numThreads = Host.ConcurrencyFactor;
-                    ch.Warning("The number of threads specified in trainer arguments is larger than the concurrency factor "
-                        + "setting of the environment. Using {0} training threads instead.", numThreads);
-                    ch.Done();
-                }
-            }
-            ParallelTraining = Args.ParallelTrainer != null ? Args.ParallelTrainer.CreateComponent(env) : new SingleTrainer();
-            ParallelTraining.InitEnvironment();
             // REVIEW: CLR 4.6 has a bug that is only exposed in Scope, and if we trigger GC.Collect in scope environment
             // with memory consumption more than 5GB, GC get stuck in infinite loop. So for now let's call GC only if we call things from TlcEnvironment.
             AllowGC = (env is HostEnvironmentBase<TlcEnvironment>);
-            Tests = new List<Test>();

-            InitializeThreads(numThreads);
+            Initialize(env);
         }

         protected abstract void PrepareLabels(IChannel ch);
```
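The ordering in the new API constructor matters: `advancedSettings` runs against a freshly created `TArgs` before the label, weight, and group-ID columns are written, so the explicit constructor parameters win over anything the callback sets. A hedged usage sketch follows; `FastTreeBinaryClassificationTrainer` and `DefaultColumnNames` appear in this PR's first diff, while the `advancedSettings` pass-through and the argument field names are assumptions.

```csharp
// Hypothetical API-style construction: the callback customizes the
// auto-created arguments object, after which the trainer overwrites
// the column-name fields from the constructor parameters.
var trainer = new FastTreeBinaryClassificationTrainer(
    env,
    DefaultColumnNames.Label,
    DefaultColumnNames.Features,
    advancedSettings: args =>
    {
        args.NumTrees = 50;                // assumed field names on the Arguments type
        args.NumLeaves = 20;
        args.LabelColumn = "IgnoredLabel"; // overwritten by the constructor afterwards
    });
```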
```diff
@@ -133,6 +155,39 @@ protected virtual Float GetMaxLabel()
             return Float.PositiveInfinity;
         }

+        private static SchemaShape.Column MakeWeightColumn(string weightColumn)
+        {
+            if (weightColumn == null)
+                return null;
+            return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
+        }
+
+        private static SchemaShape.Column MakeFeatureColumn(string featureColumn)
+        {
+            return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
+        }
+
+        private void Initialize(IHostEnvironment env)
+        {
+            int numThreads = Args.NumThreads ?? Environment.ProcessorCount;
+            if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor)
+            {
+                using (var ch = Host.Start("FastTreeTrainerBase"))
+                {
+                    numThreads = Host.ConcurrencyFactor;
+                    ch.Warning("The number of threads specified in trainer arguments is larger than the concurrency factor "
+                        + "setting of the environment. Using {0} training threads instead.", numThreads);
+                    ch.Done();
+                }
+            }
+            ParallelTraining = Args.ParallelTrainer != null ? Args.ParallelTrainer.CreateComponent(env) : new SingleTrainer();
+            ParallelTraining.InitEnvironment();
+
+            Tests = new List<Test>();
+
+            InitializeThreads(numThreads);
+        }
+
         protected void ConvertData(RoleMappedData trainData)
         {
             trainData.Schema.Schema.TryGetColumnIndex(DefaultColumnNames.Features, out int featureIndex);
```
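The thread-count logic that moved into `Initialize` reduces to a simple clamp. A standalone restatement, for reference (the method name and wrapper class are illustrative; the logic mirrors the diff above):

```csharp
using System;

static class ThreadCap
{
    // Take the requested thread count (or the processor count when
    // unspecified), then cap it by the host's concurrency factor
    // whenever that factor is positive.
    public static int EffectiveThreads(int? requested, int concurrencyFactor)
    {
        int numThreads = requested ?? Environment.ProcessorCount;
        if (concurrencyFactor > 0 && numThreads > concurrencyFactor)
            numThreads = concurrencyFactor; // the real code also warns through a channel
        return numThreads;
    }
}
```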
**Comment:** Is the reason we have two types that are identical in practically everything but name so that we can identify ranking estimators vs. regression estimators in a statically typed way?

**Reply:** I think this transformer should also expose the group ID column name; at least, that would be my belief.

**Reply:** Actually, having thought about this: like labels, group IDs are only needed for training, right? So for prediction I don't think they should be exposed.

**Reply:** So keep it, or make the Regression one generic and use it for both?
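A toy rendering of the two options under discussion; the names are illustrative, not the real ML.NET types.

```csharp
// Option A: two near-identical types; the ranking/regression task kind
// is visible in the static type system.
public sealed class RegressionToyTransformer<TModel> { /* shared body */ }
public sealed class RankingToyTransformer<TModel> { /* shared body */ }

// Option B: a single generic type reused for both tasks; simpler, but the
// ranking/regression distinction disappears from the transformer's type.
public sealed class TreeToyTransformer<TModel> { /* shared body */ }
```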