diff --git a/src/Microsoft.ML.Core/Prediction/ITrainer.cs b/src/Microsoft.ML.Core/Prediction/ITrainer.cs
index 6e04a30e6f..b38a742d9a 100644
--- a/src/Microsoft.ML.Core/Prediction/ITrainer.cs
+++ b/src/Microsoft.ML.Core/Prediction/ITrainer.cs
@@ -2,9 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
using System.Collections.Generic;
-using System.IO;
+using Microsoft.ML.Runtime.Data;
namespace Microsoft.ML.Runtime
{
@@ -27,151 +26,79 @@ namespace Microsoft.ML.Runtime
public delegate void SignatureSequenceTrainer();
public delegate void SignatureMatrixRecommendingTrainer();
- ///
- /// Interface to provide extra information about a trainer.
- ///
- public interface ITrainerEx : ITrainer
- {
- // REVIEW: Ideally trainers should be able to communicate
- // something about the type of data they are capable of being trained
- // on, e.g., what ColumnKinds they want, how many of each, of what type,
- // etc. This interface seems like the most natural conduit for that sort
- // of extra information.
-
- // REVIEW: Can we please have consistent naming here?
- // 'Need' vs. 'Want' looks arbitrary to me, and it's grammatically more correct to
- // be 'Needs' / 'Wants' anyway.
-
- ///
- /// Whether the trainer needs to see data in normalized form.
- ///
- bool NeedNormalization { get; }
-
- ///
- /// Whether the trainer needs calibration to produce probabilities.
- ///
- bool NeedCalibration { get; }
-
- ///
- /// Whether this trainer could benefit from a cached view of the data.
- ///
- bool WantCaching { get; }
- }
-
- public interface ITrainerHost
- {
- Random Rand { get; }
- int Verbosity { get; }
-
- TextWriter StdOut { get; }
- TextWriter StdErr { get; }
- }
-
- // The Trainer (of Factory) can optionally implement this.
- public interface IModelCombiner
- where TPredictor : IPredictor
- {
- TPredictor CombineModels(IEnumerable models);
- }
-
public delegate void SignatureModelCombiner(PredictionKind kind);
///
- /// Weakly typed interface for a trainer "session" that produces a predictor.
+ /// The base interface for trainers. Implementors should not implement this interface directly,
+ /// but rather implement the more specific ITrainer{TPredictor}.
///
public interface ITrainer
{
///
- /// Return the type of prediction task for the produced predictor.
+ /// Auxiliary information about the trainer in terms of its capabilities
+ /// and requirements.
///
- PredictionKind PredictionKind { get; }
+ TrainerInfo Info { get; }
///
- /// Returns the trained predictor.
- /// REVIEW: Consider removing this.
+ /// Return the type of prediction task for the produced predictor.
///
- IPredictor CreatePredictor();
- }
-
- ///
- /// Interface implemented by the MetalinearLearners base class.
- /// Used to distinguish the MetaLinear Learners from the other learners
- ///
- public interface IMetaLinearTrainer
- {
-
- }
+ PredictionKind PredictionKind { get; }
- public interface ITrainer : ITrainer
- {
///
- /// Trains a predictor using the specified dataset.
+ /// Trains a predictor.
///
- /// Training dataset
- void Train(TDataSet data);
+ /// A context containing at least the training data
+ /// The trained predictor
+ ///
+ IPredictor Train(TrainContext context);
}
///
- /// Strongly typed generic interface for a trainer. A trainer object takes
- /// supervision data and produces a predictor.
+ /// Strongly typed generic interface for a trainer. A trainer object takes training data
+ /// and produces a predictor.
///
- /// Type of the training dataset
/// Type of predictor produced
- public interface ITrainer : ITrainer
+ public interface ITrainer : ITrainer
where TPredictor : IPredictor
{
///
- /// Returns the trained predictor.
- ///
- /// Trained predictor ready to make predictions
- new TPredictor CreatePredictor();
- }
-
- ///
- /// Trainers that want data to do their own validation implement this interface.
- ///
- public interface IValidatingTrainer : ITrainer
- {
- ///
- /// Trains a predictor using the specified dataset.
+ /// Trains a predictor.
///
- /// Training dataset
- /// Validation dataset
- void Train(TDataSet data, TDataSet validData);
+ /// A context containing at least the training data
+ /// The trained predictor
+ new TPredictor Train(TrainContext context);
}
- public interface IIncrementalTrainer : ITrainer
+ public static class TrainerExtensions
{
///
- /// Trains a predictor using the specified dataset and a trained predictor.
+ /// Convenience train extension for the case where one has only a training set with no auxiliary information.
+ /// Equivalent to calling ITrainer.Train(TrainContext)
+ /// on a TrainContext constructed with the training data.
///
- /// Training dataset
- /// A trained predictor
- void Train(TDataSet data, TPredictor predictor);
- }
+ /// The trainer
+ /// The training data.
+ /// The trained predictor
+ public static IPredictor Train(this ITrainer trainer, RoleMappedData trainData)
+ => trainer.Train(new TrainContext(trainData));
- public interface IIncrementalValidatingTrainer : ITrainer
- {
///
- /// Trains a predictor using the specified dataset and a trained predictor.
+ /// Convenience train extension for the case where one has only a training set with no auxiliary information.
+ /// Equivalent to calling ITrainer{TPredictor}.Train(TrainContext)
+ /// on a TrainContext constructed with the training data.
///
- /// Training dataset
- /// Validation dataset
- /// A trained predictor
- void Train(TDataSet data, TDataSet validData, TPredictor predictor);
+ /// The trainer
+ /// The training data.
+ /// The trained predictor
+ public static TPredictor Train(this ITrainer trainer, RoleMappedData trainData) where TPredictor : IPredictor
+ => trainer.Train(new TrainContext(trainData));
}
-#if FUTURE
- public interface IMultiTrainer :
- IMultiTrainer
- {
- }
-
- public interface IMultiTrainer :
- ITrainer
+ // A trainer can optionally implement this to indicate it can combine multiple models into a single predictor.
+ public interface IModelCombiner
+ where TPredictor : IPredictor
{
- void UpdatePredictor(TDataBatch trainInstance);
- IPredictor GetCurrentPredictor();
+ TPredictor CombineModels(IEnumerable models);
}
-#endif
}
diff --git a/src/Microsoft.ML.Core/Prediction/TrainContext.cs b/src/Microsoft.ML.Core/Prediction/TrainContext.cs
new file mode 100644
index 0000000000..3464aa4bc9
--- /dev/null
+++ b/src/Microsoft.ML.Core/Prediction/TrainContext.cs
@@ -0,0 +1,57 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Data;
+
+namespace Microsoft.ML.Runtime
+{
+ ///
+ /// Holds information relevant to trainers. Instances of this class are meant to be constructed and passed
+ /// into ITrainer.Train(TrainContext) or ITrainer{TPredictor}.Train(TrainContext).
+ /// This holds at least a training set, as well as optionally a predictor.
+ ///
+ public sealed class TrainContext
+ {
+ ///
+ /// The training set. Cannot be null.
+ ///
+ public RoleMappedData TrainingSet { get; }
+
+ ///
+ /// The validation set. Can be null. Note that passing a non-null validation set into
+ /// a trainer that does not support validation sets should not be considered an error condition. It
+ /// should simply be ignored in that case.
+ ///
+ public RoleMappedData ValidationSet { get; }
+
+ ///
+ /// The initial predictor, for incremental training. Note that if an implementor
+ /// does not support incremental training, then it can ignore it similarly to how one would ignore
+ /// ValidationSet. However, if the trainer does support incremental training and there
+ /// is something wrong with a non-null value of this, then the trainer ought to throw an exception.
+ ///
+ public IPredictor InitialPredictor { get; }
+
+
+ ///
+ /// Constructor, given a training set and optional other arguments.
+ ///
+ /// Will set to this value. This must be specified
+ /// Will set to this value if specified
+ /// Will set to this value if specified
+ public TrainContext(RoleMappedData trainingSet, RoleMappedData validationSet = null, IPredictor initialPredictor = null)
+ {
+ Contracts.CheckValue(trainingSet, nameof(trainingSet));
+ Contracts.CheckValueOrNull(validationSet);
+ Contracts.CheckValueOrNull(initialPredictor);
+
+ // REVIEW: Should there be code here to ensure that the role mappings between the two are compatible?
+ // That is, all the role mappings are the same and the columns between them have identical types?
+
+ TrainingSet = trainingSet;
+ ValidationSet = validationSet;
+ InitialPredictor = initialPredictor;
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs b/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs
new file mode 100644
index 0000000000..cce728e09a
--- /dev/null
+++ b/src/Microsoft.ML.Core/Prediction/TrainerInfo.cs
@@ -0,0 +1,71 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.ML.Runtime
+{
+ ///
+ /// Instances of this class possess information about trainers, in terms of their requirements and capabilities.
+ /// The intended usage is as the value for ITrainer.Info.
+ ///
+ public sealed class TrainerInfo
+ {
+ // REVIEW: Ideally trainers should be able to communicate
+ // something about the type of data they are capable of being trained
+ // on, e.g., what ColumnKinds they want, how many of each, of what type,
+ // etc. This interface seems like the most natural conduit for that sort
+ // of extra information.
+
+ ///
+ /// Whether the trainer needs to see data in normalized form. Only non-parametric learners will tend to produce
+ /// normalization here.
+ ///
+ public bool NeedNormalization { get; }
+
+ ///
+ /// Whether the trainer needs calibration to produce probabilities. As a general rule only trainers that produce
+ /// binary classifier predictors that also do not have a natural probabilistic interpretation should have a
+ /// true value here.
+ ///
+ public bool NeedCalibration { get; }
+
+ ///
+ /// Whether this trainer could benefit from a cached view of the data. Trainers that have few passes over the
+ /// data, or that need to build their own custom data structure over the data, will have a false here.
+ ///
+ public bool WantCaching { get; }
+
+ ///
+ /// Whether the trainer supports validation sets via TrainContext.ValidationSet. Returning
+ /// false from this property is an indication the trainer does not support
+ /// that.
+ ///
+ public bool SupportsValidation { get; }
+
+ ///
+ /// Whether the trainer supports incremental training via TrainContext.InitialPredictor. Returning
+ /// false from this property is an indication the trainer does
+ /// not support that.
+ ///
+ public bool SupportsIncrementalTraining { get; }
+
+ ///
+ /// Initializes with the given parameters. The parameters have default values for the most typical values
+ /// for most classical trainers.
+ ///
+ /// The value for the property
+ /// The value for the property
+ /// The value for the property
+ /// The value for the property
+ /// The value for the property
+ public TrainerInfo(bool normalization = true, bool calibration = false, bool caching = true,
+ bool supportValid = false, bool supportIncrementalTrain = false)
+ {
+ NeedNormalization = normalization;
+ NeedCalibration = calibration;
+ WantCaching = caching;
+ SupportsValidation = supportValid;
+ SupportsIncrementalTraining = supportIncrementalTrain;
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Core/Utilities/ObjectPool.cs b/src/Microsoft.ML.Core/Utilities/ObjectPool.cs
index 46486dc937..4a65286551 100644
--- a/src/Microsoft.ML.Core/Utilities/ObjectPool.cs
+++ b/src/Microsoft.ML.Core/Utilities/ObjectPool.cs
@@ -39,7 +39,7 @@ public abstract class ObjectPoolBase
public int Count => _pool.Count;
public int NumCreated { get { return _numCreated; } }
- protected internal ObjectPoolBase()
+ private protected ObjectPoolBase()
{
_pool = new ConcurrentBag();
}
diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
index 23b1c601c9..26ec32d3fe 100644
--- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
@@ -538,7 +538,7 @@ private FoldResult RunFold(int fold)
if (_getValidationDataView != null)
{
ch.Assert(_applyTransformsToValidationData != null);
- if (!TrainUtils.CanUseValidationData(trainer))
+ if (!trainer.Info.SupportsValidation)
ch.Warning("Trainer does not accept validation dataset.");
else
{
diff --git a/src/Microsoft.ML.Data/Commands/TrainCommand.cs b/src/Microsoft.ML.Data/Commands/TrainCommand.cs
index b5e3964157..1c25275c3e 100644
--- a/src/Microsoft.ML.Data/Commands/TrainCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TrainCommand.cs
@@ -163,7 +163,7 @@ private void RunCore(IChannel ch, string cmd)
RoleMappedData validData = null;
if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
{
- if (!TrainUtils.CanUseValidationData(trainer))
+ if (!trainer.Info.SupportsValidation)
{
ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
}
@@ -235,14 +235,14 @@ public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData
}
public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData,
- SubComponent calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inpPredictor = null)
+ SubComponent calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null)
{
ICalibratorTrainer caliTrainer = !calibrator.IsGood() ? null : calibrator.CreateInstance(env);
- return TrainCore(env, ch, data, trainer, name, validData, caliTrainer, maxCalibrationExamples, cacheData, inpPredictor);
+ return TrainCore(env, ch, data, trainer, name, validData, caliTrainer, maxCalibrationExamples, cacheData, inputPredictor);
}
private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData,
- ICalibratorTrainer calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inpPredictor = null)
+ ICalibratorTrainer calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ch, nameof(ch));
@@ -250,79 +250,22 @@ private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappe
ch.CheckValue(trainer, nameof(trainer));
ch.CheckNonEmpty(name, nameof(name));
ch.CheckValueOrNull(validData);
- ch.CheckValueOrNull(inpPredictor);
+ ch.CheckValueOrNull(inputPredictor);
- var trainerRmd = trainer as ITrainer;
- if (trainerRmd == null)
- throw ch.ExceptUserArg(nameof(TrainCommand.Arguments.Trainer), "Trainer '{0}' does not accept known training data type", name);
-
- Action, object, object, object> trainCoreAction = TrainCore;
- IPredictor predictor;
AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
ch.Trace("Training");
if (validData != null)
AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);
- var genericExam = trainCoreAction.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(
- typeof(RoleMappedData),
- inpPredictor != null ? inpPredictor.GetType() : typeof(IPredictor));
- Action trainExam = trainerRmd.Train;
- genericExam.Invoke(null, new object[] { ch, trainerRmd, trainExam, data, validData, inpPredictor });
-
- ch.Trace("Constructing predictor");
- predictor = trainerRmd.CreatePredictor();
- return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibrator, maxCalibrationExamples, trainer, predictor, data);
- }
-
- public static bool CanUseValidationData(ITrainer trainer)
- {
- Contracts.CheckValue(trainer, nameof(trainer));
-
- if (trainer is ITrainer)
- return trainer is IValidatingTrainer;
-
- return false;
- }
-
- private static void TrainCore(IChannel ch, ITrainer trainer, Action train, TDataSet data, TDataSet validData = null, TPredictor predictor = null)
- where TDataSet : class
- where TPredictor : class
- {
- const string inputModelArg = nameof(TrainCommand.Arguments.InputModelFile);
- if (validData != null)
+ if (inputPredictor != null && !trainer.Info.SupportsIncrementalTraining)
{
- if (predictor != null)
- {
- var incValidTrainer = trainer as IIncrementalValidatingTrainer;
- if (incValidTrainer != null)
- {
- incValidTrainer.Train(data, validData, predictor);
- return;
- }
-
- ch.Warning("Ignoring " + inputModelArg + ": Trainer is not an incremental trainer.");
- }
-
- var validTrainer = trainer as IValidatingTrainer;
- ch.AssertValue(validTrainer);
- validTrainer.Train(data, validData);
- }
- else
- {
- if (predictor != null)
- {
- var incTrainer = trainer as IIncrementalTrainer;
- if (incTrainer != null)
- {
- incTrainer.Train(data, predictor);
- return;
- }
-
- ch.Warning("Ignoring " + inputModelArg + ": Trainer is not an incremental trainer.");
- }
-
- train(data);
+ ch.Warning("Ignoring " + nameof(TrainCommand.Arguments.InputModelFile) +
+ ": Trainer does not support incremental training.");
+ inputPredictor = null;
}
+ ch.Assert(validData == null || trainer.Info.SupportsValidation);
+ var predictor = trainer.Train(new TrainContext(data, validData, inputPredictor));
+ return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibrator, maxCalibrationExamples, trainer, predictor, data);
}
public static bool TryLoadPredictor(IChannel ch, IHostEnvironment env, string inputModelFile, out IPredictor inputPredictor)
@@ -438,9 +381,8 @@ public static void SaveDataPipe(IHostEnvironment env, RepositoryWriter repositor
IDataView pipeStart;
var xfs = BacktrackPipe(dataPipe, out pipeStart);
- IDataLoader loader;
Action saveAction;
- if (!blankLoader && (loader = pipeStart as IDataLoader) != null)
+ if (!blankLoader && pipeStart is IDataLoader loader)
saveAction = loader.Save;
else
{
@@ -510,7 +452,7 @@ public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITra
if (autoNorm != NormalizeOption.Yes)
{
DvBool isNormalized = DvBool.False;
- if (trainer.NeedNormalization() != true || schema.IsNormalized(featCol))
+ if (!trainer.Info.NeedNormalization || schema.IsNormalized(featCol))
{
ch.Info("Not adding a normalizer.");
return false;
@@ -541,8 +483,7 @@ private static bool AddCacheIfWanted(IHostEnvironment env, IChannel ch, ITrainer
ch.AssertValue(trainer, nameof(trainer));
ch.AssertValue(data, nameof(data));
- ITrainerEx trainerEx = trainer as ITrainerEx;
- bool shouldCache = cacheData ?? (!(data.Data is BinaryLoader) && (trainerEx == null || trainerEx.WantCaching));
+ bool shouldCache = cacheData ?? !(data.Data is BinaryLoader) && trainer.Info.WantCaching;
if (shouldCache)
{
diff --git a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
index 7c4249c6ee..03ee7cdf12 100644
--- a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
@@ -152,7 +152,7 @@ private void RunCore(IChannel ch, string cmd)
RoleMappedData validData = null;
if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
{
- if (!TrainUtils.CanUseValidationData(trainer))
+ if (!trainer.Info.SupportsValidation)
{
ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
}
diff --git a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs
index bc45a929d4..94b67af670 100644
--- a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs
@@ -164,9 +164,8 @@ public static TOut Train(IHost host, TArg input,
}
case CachingOptions.Auto:
{
- ITrainerEx trainerEx = trainer as ITrainerEx;
// REVIEW: we should switch to hybrid caching in future.
- if (!(input.TrainingData is BinaryLoader) && (trainerEx == null || trainerEx.WantCaching))
+ if (!(input.TrainingData is BinaryLoader) && trainer.Info.WantCaching)
// default to Memory so mml is on par with maml
cachingType = Cache.CachingType.Memory;
break;
diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
index 835ba5d99a..1328bb006a 100644
--- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs
+++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
@@ -687,8 +687,7 @@ public static class CalibratorUtils
private static bool NeedCalibration(IHostEnvironment env, IChannel ch, ICalibratorTrainer calibrator,
ITrainer trainer, IPredictor predictor, RoleMappedSchema schema)
{
- var trainerEx = trainer as ITrainerEx;
- if (trainerEx == null || !trainerEx.NeedCalibration)
+ if (!trainer.Info.NeedCalibration)
{
ch.Info("Not training a calibrator because it is not needed.");
return false;
diff --git a/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs b/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs
index 13cdb126ee..285db8bfe1 100644
--- a/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs
+++ b/src/Microsoft.ML.Data/Training/EarlyStoppingCriteria.cs
@@ -139,9 +139,9 @@ public class Arguments : ArgumentsBase
public int WindowSize = 5;
}
- protected internal Queue PastScores;
+ protected Queue PastScores;
- internal MovingWindowEarlyStoppingCriterion(Arguments args, bool lowerIsBetter)
+ private protected MovingWindowEarlyStoppingCriterion(Arguments args, bool lowerIsBetter)
: base(args, lowerIsBetter)
{
Contracts.CheckUserArg(0 <= Args.Threshold && args.Threshold <= 1, nameof(args.Threshold), "Must be in range [0,1].");
diff --git a/src/Microsoft.ML.Data/Training/TrainerBase.cs b/src/Microsoft.ML.Data/Training/TrainerBase.cs
index 90f8b64a7c..ca2f2c7b64 100644
--- a/src/Microsoft.ML.Data/Training/TrainerBase.cs
+++ b/src/Microsoft.ML.Data/Training/TrainerBase.cs
@@ -4,59 +4,32 @@
namespace Microsoft.ML.Runtime.Training
{
- public abstract class TrainerBase : ITrainer, ITrainerEx
+ public abstract class TrainerBase : ITrainer
+ where TPredictor : IPredictor
{
- public const string NoTrainingInstancesMessage = "No valid training instances found, all instances have missing features.";
+ ///
+ /// A standard string to use in errors or warnings by subclasses, to communicate the idea that no valid
+ /// instances were able to be found.
+ ///
+ protected const string NoTrainingInstancesMessage = "No valid training instances found, all instances have missing features.";
- protected readonly IHost Host;
+ protected IHost Host { get; }
public string Name { get; }
public abstract PredictionKind PredictionKind { get; }
- public abstract bool NeedNormalization { get; }
- public abstract bool NeedCalibration { get; }
- public abstract bool WantCaching { get; }
+ public abstract TrainerInfo Info { get; }
protected TrainerBase(IHostEnvironment env, string name)
{
Contracts.CheckValue(env, nameof(env));
- Contracts.CheckNonEmpty(name, nameof(name));
+ env.CheckNonEmpty(name, nameof(name));
Name = name;
Host = env.Register(name);
}
- IPredictor ITrainer.CreatePredictor()
- {
- return CreatePredictorCore();
- }
-
- protected abstract IPredictor CreatePredictorCore();
- }
-
- public abstract class TrainerBase : TrainerBase
- where TPredictor : IPredictor
- {
- protected TrainerBase(IHostEnvironment env, string name)
- : base(env, name)
- {
- }
-
- public abstract TPredictor CreatePredictor();
-
- protected sealed override IPredictor CreatePredictorCore()
- {
- return CreatePredictor();
- }
- }
-
- public abstract class TrainerBase : TrainerBase, ITrainer
- where TPredictor : IPredictor
- {
- protected TrainerBase(IHostEnvironment env, string name)
- : base(env, name)
- {
- }
+ IPredictor ITrainer.Train(TrainContext context) => Train(context);
- public abstract void Train(TDataSet data);
+ public abstract TPredictor Train(TrainContext context);
}
}
diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs
index f3c560c8dc..7a4738d5e4 100644
--- a/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs
@@ -205,9 +205,8 @@ private NormalizeTransform(IHost host, ArgumentsBase args, IDataView input,
///
/// The host environment to use to potentially instantiate the transform
/// The role-mapped data that is potentially going to be modified by this method.
- /// The trainer to query with .
- /// This method will not modify if the return from that is null or
- /// false.
+ /// The trainer to query as to whether it wants normalization. If the
+ /// trainer's Info.NeedNormalization property is true, normalization may be applied.
/// True if the normalizer was applied and was modified
public static bool CreateIfNeeded(IHostEnvironment env, ref RoleMappedData data, ITrainer trainer)
{
@@ -215,14 +214,12 @@ public static bool CreateIfNeeded(IHostEnvironment env, ref RoleMappedData data,
env.CheckValue(data, nameof(data));
env.CheckValue(trainer, nameof(trainer));
- // If this is false or null, we do not want to normalize.
- if (trainer.NeedNormalization() != true)
- return false;
- // If this is true or null, we do not want to normalize.
- if (data.Schema.FeaturesAreNormalized() != false)
+ // If the trainer does not need normalization, or if the features either don't exist
+ // or are not normalized, return false.
+ if (!trainer.Info.NeedNormalization || data.Schema.FeaturesAreNormalized() != false)
return false;
var featInfo = data.Schema.Feature;
- env.AssertValue(featInfo); // Should be defined, if FEaturesAreNormalized returned a definite value.
+ env.AssertValue(featInfo); // Should be defined, if FeaturesAreNormalized returned a definite value.
var view = CreateMinMaxNormalizer(env, data.Data, name: featInfo.Name);
data = new RoleMappedData(view, data.Schema.GetColumnRoleNames());
@@ -363,20 +360,6 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou
public static class NormalizeUtils
{
- ///
- /// Tells whether the trainer wants normalization.
- ///
- /// This method works via testing whether the trainer implements the optional interface
- /// , via the Boolean property.
- /// If does not implement that interface, then we return null
- /// The trainer to query
- /// Whether the trainer wants normalization
- public static bool? NeedNormalization(this ITrainer trainer)
- {
- Contracts.CheckValue(trainer, nameof(trainer));
- return (trainer as ITrainerEx)?.NeedNormalization;
- }
-
///
/// Returns whether the feature column in the schema is indicated to be normalized. If the features column is not
/// specified on the schema, then this will return null.
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
index f49e3af81c..f30174a31d 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
@@ -27,10 +27,10 @@ public abstract class ArgumentsBase
[Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
[TGUI(Label = "Base predictor")]
- public SubComponent>, TSigBase> BasePredictorType;
+ public SubComponent>, TSigBase> BasePredictorType;
}
- protected readonly SubComponent>, TSigBase> BasePredictorType;
+ protected readonly SubComponent>, TSigBase> BasePredictorType;
protected readonly IHost Host;
protected IPredictorProducing Meta;
@@ -188,10 +188,9 @@ public void Train(List>> models,
var rmd = new RoleMappedData(view, DefaultColumnNames.Label, DefaultColumnNames.Features);
var trainer = BasePredictorType.CreateInstance(host);
- if (trainer is ITrainerEx ex && ex.NeedNormalization)
+ if (trainer.Info.NeedNormalization)
ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this.");
- trainer.Train(rmd);
- Meta = trainer.CreatePredictor();
+ Meta = trainer.Train(rmd);
CheckMeta();
ch.Done();
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
index 588dd89508..2ef74c8169 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
@@ -43,7 +43,7 @@ public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerF
public Arguments()
{
// REVIEW: Perhaps we can have a better non-parametetric learner.
- BasePredictorType = new SubComponent, SignatureMultiClassClassifierTrainer>(
+ BasePredictorType = new SubComponent, SignatureMultiClassClassifierTrainer>(
"OVA", "p=FastTreeBinaryClassification");
}
}
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
index aeb011a51b..0b5f8e6057 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
@@ -39,7 +39,7 @@ public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerF
{
public Arguments()
{
- BasePredictorType = new SubComponent, SignatureRegressorTrainer>("FastTreeRegression");
+ BasePredictorType = new SubComponent, SignatureRegressorTrainer>("FastTreeRegression");
}
public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new RegressionStacking(env, this);
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
index afd6e3f958..f3481e9936 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
@@ -37,7 +37,7 @@ public sealed class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFacto
{
public Arguments()
{
- BasePredictorType = new SubComponent, SignatureBinaryClassifierTrainer>("FastTreeBinaryClassification");
+ BasePredictorType = new SubComponent, SignatureBinaryClassifierTrainer>("FastTreeBinaryClassification");
}
public IBinaryOutputCombiner CreateComponent(IHostEnvironment env) => new Stacking(env, this);
diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs
index 56f7e1edfc..5e31b2c8f5 100644
--- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs
+++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs
@@ -23,7 +23,7 @@ public abstract class DiverseSelectorArguments : ArgumentsBase
private readonly IComponentFactory> _diversityMetricType;
private ConcurrentDictionary>, TOutput[]> _predictions;
- protected internal BaseDiverseSelector(IHostEnvironment env, DiverseSelectorArguments args, string name,
+ private protected BaseDiverseSelector(IHostEnvironment env, DiverseSelectorArguments args, string name,
IComponentFactory> diversityMetricType)
: base(args, env, name)
{
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
index 80fe4cbdee..ab2dff5045 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
@@ -46,7 +46,7 @@ public sealed class Arguments : ArgumentsBase
public Arguments()
{
- BasePredictors = new[] { new SubComponent, SignatureBinaryClassifierTrainer>("LinearSVM") };
+ BasePredictors = new[] { new SubComponent, SignatureBinaryClassifierTrainer>("LinearSVM") };
}
}
@@ -60,16 +60,13 @@ public EnsembleTrainer(IHostEnvironment env, Arguments args)
Combiner = args.OutputCombiner.CreateComponent(Host);
}
- public override PredictionKind PredictionKind
- {
- get { return PredictionKind.BinaryClassification; }
- }
+ public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
- public override TScalarPredictor CreatePredictor()
+ private protected override TScalarPredictor CreatePredictor(List> models)
{
- if (Models.All(m => m.Predictor is TDistPredictor))
- return new EnsembleDistributionPredictor(Host, PredictionKind, CreateModels(), Combiner);
- return new EnsemblePredictor(Host, PredictionKind, CreateModels(), Combiner);
+ if (models.All(m => m.Predictor is TDistPredictor))
+ return new EnsembleDistributionPredictor(Host, PredictionKind, CreateModels(models), Combiner);
+ return new EnsemblePredictor(Host, PredictionKind, CreateModels(models), Combiner);
}
public TScalarPredictor CombineModels(IEnumerable<TScalarPredictor> models)
@@ -77,19 +74,13 @@ public TScalarPredictor CombineModels(IEnumerable models)
var combiner = _outputCombiner.CreateComponent(Host);
var p = models.First();
- TScalarPredictor predictor = null;
if (p is TDistPredictor)
{
- predictor = new EnsembleDistributionPredictor(Host, p.PredictionKind,
+ return new EnsembleDistributionPredictor(Host, p.PredictionKind,
models.Select(k => new FeatureSubsetModel<TDistPredictor>((TDistPredictor)k)).ToArray(), combiner);
}
- else
- {
- predictor = new EnsemblePredictor(Host, p.PredictionKind,
+ return new EnsemblePredictor(Host, p.PredictionKind,
models.Select(k => new FeatureSubsetModel<TScalarPredictor>(k)).ToArray(), combiner);
- }
-
- return predictor;
}
}
}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
index 776b1f5f53..0a350ef8ee 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
@@ -20,7 +20,7 @@ namespace Microsoft.ML.Runtime.Ensemble
{
using Stopwatch = System.Diagnostics.Stopwatch;
- public abstract class EnsembleTrainerBase<TOutput, TPredictor, TSelector, TCombiner, TSig> : TrainerBase<RoleMappedData, TPredictor>
+ public abstract class EnsembleTrainerBase<TOutput, TPredictor, TSelector, TCombiner, TSig> : TrainerBase<TPredictor>
where TPredictor : class, IPredictorProducing<TOutput>
where TSelector : class, ISubModelSelector<TOutput>
where TCombiner : class, IOutputCombiner<TOutput>
@@ -54,27 +54,24 @@ public abstract class ArgumentsBase : LearnerInputBaseWithLabel
public bool ShowMetrics;
[Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
- public SubComponent<ITrainer<RoleMappedData, IPredictorProducing<TOutput>>, TSig>[] BasePredictors;
+ public SubComponent<ITrainer<IPredictorProducing<TOutput>>, TSig>[] BasePredictors;
}
private const int DefaultNumModels = 50;
/// Command-line arguments
- protected readonly ArgumentsBase Args;
- protected readonly int NumModels;
+ private protected readonly ArgumentsBase Args;
+ private protected readonly int NumModels;
/// Ensemble members
- protected readonly ITrainer<RoleMappedData, IPredictorProducing<TOutput>>[] Trainers;
+ private protected readonly ITrainer<IPredictorProducing<TOutput>>[] Trainers;
private readonly ISubsetSelector _subsetSelector;
- protected ISubModelSelector<TOutput> SubModelSelector;
- protected IOutputCombiner<TOutput> Combiner;
+ private protected ISubModelSelector<TOutput> SubModelSelector;
+ private protected IOutputCombiner<TOutput> Combiner;
- protected List<FeatureSubsetModel<IPredictorProducing<TOutput>>> Models;
+ public override TrainerInfo Info { get; }
- private readonly bool _needNorm;
- private readonly bool _needCalibration;
-
- internal EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env, string name)
+ private protected EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env, string name)
: base(env, name)
{
Args = args;
@@ -93,41 +90,31 @@ internal EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env, string na
_subsetSelector = Args.SamplingType.CreateComponent(Host);
- Trainers = new ITrainer<RoleMappedData, IPredictorProducing<TOutput>>[NumModels];
+ Trainers = new ITrainer<IPredictorProducing<TOutput>>[NumModels];
for (int i = 0; i < Trainers.Length; i++)
Trainers[i] = Args.BasePredictors[i % Args.BasePredictors.Length].CreateInstance(Host);
- _needNorm = Trainers.Any(
- t =>
- {
- return t is ITrainerEx nn && nn.NeedNormalization;
- });
- _needCalibration = Trainers.Any(
- t =>
- {
- return t is ITrainerEx nn && nn.NeedCalibration;
- });
+ // We infer normalization and calibration preferences from the trainers. However, even if the internal trainers
+ // don't need caching we are performing multiple passes over the data, so it is probably appropriate to always cache.
+ Info = new TrainerInfo(
+ normalization: Trainers.Any(t => t.Info.NeedNormalization),
+ calibration: Trainers.Any(t => t.Info.NeedCalibration));
ch.Done();
}
}
- public override bool NeedNormalization => _needNorm;
-
- public override bool NeedCalibration => _needCalibration;
-
- // No matter the internal predictors, we are performing multiple passes over the data
- // so it is probably appropriate to always cache.
- public override bool WantCaching => true;
-
- public override void Train(RoleMappedData data)
+ public sealed override TPredictor Train(TrainContext context)
{
+ Host.CheckValue(context, nameof(context));
+
using (var ch = Host.Start("Training"))
{
- TrainCore(ch, data);
+ var pred = TrainCore(ch, context.TrainingSet);
ch.Done();
+ return pred;
}
}
- private void TrainCore(IChannel ch, RoleMappedData data)
+ private TPredictor TrainCore(IChannel ch, RoleMappedData data)
{
Host.AssertValue(ch);
ch.AssertValue(data);
@@ -143,6 +130,7 @@ private void TrainCore(IChannel ch, RoleMappedData data)
validationDataSetProportion = Math.Max(validationDataSetProportion, stackingTrainer.ValidationDatasetProportion);
var needMetrics = Args.ShowMetrics || Combiner is IWeightedAverager;
+ var models = new List<FeatureSubsetModel<IPredictorProducing<TOutput>>>();
_subsetSelector.Initialize(data, NumModels, Args.BatchSize, validationDataSetProportion);
int batchNumber = 1;
@@ -150,7 +138,7 @@ private void TrainCore(IChannel ch, RoleMappedData data)
{
// 2. Core train
ch.Info("Training {0} learners for the batch {1}", Trainers.Length, batchNumber++);
- var models = new FeatureSubsetModel<IPredictorProducing<TOutput>>[Trainers.Length];
+ var batchModels = new FeatureSubsetModel<IPredictorProducing<TOutput>>[Trainers.Length];
Parallel.ForEach(_subsetSelector.GetSubsets(batch, Host.Rand),
new ParallelOptions() { MaxDegreeOfParallelism = Args.TrainParallel ? -1 : 1 },
@@ -162,26 +150,24 @@ private void TrainCore(IChannel ch, RoleMappedData data)
{
if (EnsureMinimumFeaturesSelected(subset))
{
- Trainers[(int)index].Train(subset.Data);
-
var model = new FeatureSubsetModel<IPredictorProducing<TOutput>>(
- Trainers[(int)index].CreatePredictor(),
+ Trainers[(int)index].Train(subset.Data),
subset.SelectedFeatures,
null);
SubModelSelector.CalculateMetrics(model, _subsetSelector, subset, batch, needMetrics);
- models[(int)index] = model;
+ batchModels[(int)index] = model;
}
}
catch (Exception ex)
{
- ch.Assert(models[(int)index] == null);
+ ch.Assert(batchModels[(int)index] == null);
ch.Warning(ex.Sensitivity(), "Trainer {0} of {1} was not learned properly due to the exception '{2}' and will not be added to models.",
index + 1, Trainers.Length, ex.Message);
}
ch.Info("Trainer {0} of {1} finished in {2}", index + 1, Trainers.Length, sw.Elapsed);
});
- var modelsList = models.Where(m => m != null).ToList();
+ var modelsList = batchModels.Where(m => m != null).ToList();
if (Args.ShowMetrics)
PrintMetrics(ch, modelsList);
@@ -190,15 +176,17 @@ private void TrainCore(IChannel ch, RoleMappedData data)
if (stackingTrainer != null)
stackingTrainer.Train(modelsList, _subsetSelector.GetTestData(null, batch), Host);
- foreach (var model in modelsList)
- Utils.Add(ref Models, model);
- int modelSize = Utils.Size(Models);
+ models.AddRange(modelsList);
+ int modelSize = Utils.Size(models);
if (modelSize < Utils.Size(Trainers))
ch.Warning("{0} of {1} trainings failed.", Utils.Size(Trainers) - modelSize, Utils.Size(Trainers));
ch.Check(modelSize > 0, "Ensemble training resulted in no valid models.");
}
+ return CreatePredictor(models);
}
+ private protected abstract TPredictor CreatePredictor(List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models);
+
private bool EnsureMinimumFeaturesSelected(Subset subset)
{
if (subset.SelectedFeatures == null)
@@ -212,7 +200,7 @@ private bool EnsureMinimumFeaturesSelected(Subset subset)
return false;
}
- protected virtual void PrintMetrics(IChannel ch, List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models)
+ private protected virtual void PrintMetrics(IChannel ch, List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models)
{
// REVIEW: The formatting of this method is bizarre and seemingly not even self-consistent
// w.r.t. its usage of |. Is this intentional?
@@ -225,17 +213,17 @@ protected virtual void PrintMetrics(IChannel ch, List string.Format("| {0} |", m.Value))), model.Predictor.GetType().Name);
}
- protected FeatureSubsetModel<T>[] CreateModels<T>() where T : IPredictor
+ private protected static FeatureSubsetModel<T>[] CreateModels<T>(List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models) where T : IPredictor
{
- var models = new FeatureSubsetModel<T>[Models.Count];
- for (int i = 0; i < Models.Count; i++)
+ var subsetModels = new FeatureSubsetModel<T>[models.Count];
+ for (int i = 0; i < models.Count; i++)
{
- models[i] = new FeatureSubsetModel<T>(
- (T)Models[i].Predictor,
- Models[i].SelectedFeatures,
- Models[i].Metrics);
+ subsetModels[i] = new FeatureSubsetModel<T>(
+ (T)models[i].Predictor,
+ models[i].SelectedFeatures,
+ models[i].Metrics);
}
- return models;
+ return subsetModels;
}
}
}
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
index 0e6b4f6a53..4421cd5838 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
@@ -47,7 +47,7 @@ public sealed class Arguments : ArgumentsBase
public Arguments()
{
- BasePredictors = new[] { new SubComponent<ITrainer<RoleMappedData, TVectorPredictor>, SignatureMultiClassClassifierTrainer>("MultiClassLogisticRegression") };
+ BasePredictors = new[] { new SubComponent<ITrainer<TVectorPredictor>, SignatureMultiClassClassifierTrainer>("MultiClassLogisticRegression") };
}
}
@@ -61,12 +61,11 @@ public MulticlassDataPartitionEnsembleTrainer(IHostEnvironment env, Arguments ar
Combiner = args.OutputCombiner.CreateComponent(Host);
}
- public override PredictionKind PredictionKind { get { return PredictionKind.MultiClassClassification; } }
+ public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification;
- public override EnsembleMultiClassPredictor CreatePredictor()
+ private protected override EnsembleMultiClassPredictor CreatePredictor(List<FeatureSubsetModel<IPredictorProducing<VBuffer<Single>>>> models)
{
- var combiner = Combiner;
- return new EnsembleMultiClassPredictor(Host, CreateModels(), combiner as IMultiClassOutputCombiner);
+ return new EnsembleMultiClassPredictor(Host, CreateModels(models), Combiner as IMultiClassOutputCombiner);
}
public TVectorPredictor CombineModels(IEnumerable<TVectorPredictor> models)
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
index 322c1e02a1..1cc36f20cd 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
@@ -41,7 +41,7 @@ public sealed class Arguments : ArgumentsBase
public Arguments()
{
- BasePredictors = new[] { new SubComponent<ITrainer<RoleMappedData, TScalarPredictor>, SignatureRegressorTrainer>("OnlineGradientDescent") };
+ BasePredictors = new[] { new SubComponent<ITrainer<TScalarPredictor>, SignatureRegressorTrainer>("OnlineGradientDescent") };
}
}
@@ -55,14 +55,11 @@ public RegressionEnsembleTrainer(IHostEnvironment env, Arguments args)
Combiner = args.OutputCombiner.CreateComponent(Host);
}
- public override PredictionKind PredictionKind
- {
- get { return PredictionKind.Regression; }
- }
+ public override PredictionKind PredictionKind => PredictionKind.Regression;
- public override TScalarPredictor CreatePredictor()
+ private protected override TScalarPredictor CreatePredictor(List<FeatureSubsetModel<IPredictorProducing<Single>>> models)
{
- return new EnsemblePredictor(Host, PredictionKind, CreateModels<TScalarPredictor>(), Combiner);
+ return new EnsemblePredictor(Host, PredictionKind, CreateModels<TScalarPredictor>(models), Combiner);
}
public TScalarPredictor CombineModels(IEnumerable<TScalarPredictor> models)
diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs
index e7432139c5..29fe439e0a 100644
--- a/src/Microsoft.ML.FastTree/FastTree.cs
+++ b/src/Microsoft.ML.FastTree/FastTree.cs
@@ -44,8 +44,7 @@ internal static class FastTreeShared
}
public abstract class FastTreeTrainerBase<TArgs, TPredictor> :
- TrainerBase<RoleMappedData, TPredictor>,
- IValidatingTrainer<RoleMappedData>
+ TrainerBase<TPredictor>
where TArgs : TreeArgs, new()
where TPredictor : IPredictorProducing<Float>
{
@@ -82,17 +81,21 @@ public abstract class FastTreeTrainerBase :
protected string InnerArgs => CmdParser.GetSettings(Host, Args, new TArgs());
- public override bool NeedNormalization => false;
-
- public override bool WantCaching => false;
+ public override TrainerInfo Info { get; }
public bool HasCategoricalFeatures => Utils.Size(CategoricalFeatures) > 0;
- protected internal FastTreeTrainerBase(IHostEnvironment env, TArgs args)
+ private protected virtual bool NeedCalibration => false;
+
+ private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args)
: base(env, RegisterName)
{
Host.CheckValue(args, nameof(args));
Args = args;
+ // The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
+ // Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
+ // Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration.
+ Info = new TrainerInfo(normalization: false, caching: false, calibration: NeedCalibration);
int numThreads = Args.NumThreads ?? Environment.ProcessorCount;
if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor)
{
@@ -125,14 +128,6 @@ protected internal FastTreeTrainerBase(IHostEnvironment env, TArgs args)
protected abstract ObjectiveFunctionBase ConstructObjFunc(IChannel ch);
- public void Train(RoleMappedData trainData, RoleMappedData validationData)
- {
- // REVIEW: Idiotic. This should be reversed... the other train method should
- // be put in here, rather than having this "hidden argument" through an instance field.
- ValidData = validationData;
- Train(trainData);
- }
-
protected virtual Float GetMaxLabel()
{
return Float.PositiveInfinity;
@@ -1887,7 +1882,7 @@ private void MakeBoundariesAndCheckLabels(out long missingInstances, out long to
missingInstances = cursor.BadFeaturesRowCount;
}
- ch.Check(totalInstances > 0, TrainerBase.NoTrainingInstancesMessage);
+ ch.Check(totalInstances > 0, "All instances skipped due to missing features.");
if (missingInstances > 0)
ch.Warning("Skipped {0} instances with missing features during training", missingInstances);
@@ -2813,7 +2808,7 @@ public abstract class FastTreePredictionWrapper :
public bool CanSavePfa => true;
public bool CanSaveOnnx => true;
- protected internal FastTreePredictionWrapper(IHostEnvironment env, string name, Ensemble trainedEnsemble, int numFeatures, string innerArgs)
+ protected FastTreePredictionWrapper(IHostEnvironment env, string name, Ensemble trainedEnsemble, int numFeatures, string innerArgs)
: base(env, name)
{
Host.CheckValue(trainedEnsemble, nameof(trainedEnsemble));
diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs
index 18f61e6dbe..4eca3e15fb 100644
--- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs
@@ -62,11 +62,11 @@ private static VersionInfo GetVersionInfo()
loaderSignature: LoaderSignature);
}
- protected override uint VerNumFeaturesSerialized { get { return 0x00010002; } }
+ protected override uint VerNumFeaturesSerialized => 0x00010002;
- protected override uint VerDefaultValueSerialized { get { return 0x00010004; } }
+ protected override uint VerDefaultValueSerialized => 0x00010004;
- protected override uint VerCategoricalSplitSerialized { get { return 0x00010005; } }
+ protected override uint VerCategoricalSplitSerialized => 0x00010005;
internal FastTreeBinaryPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs)
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
@@ -97,7 +97,7 @@ public static IPredictorProducing Create(IHostEnvironment env, ModelLoadC
return new SchemaBindableCalibratedPredictor(env, predictor, calibrator);
}
- public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } }
+ public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
}
///
@@ -116,12 +116,14 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, Arguments args)
{
}
- public override bool NeedCalibration => false;
+ public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
- public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } }
-
- public override void Train(RoleMappedData trainData)
+ public override IPredictorWithFeatureWeights<Float> Train(TrainContext context)
{
+ Host.CheckValue(context, nameof(context));
+ var trainData = context.TrainingSet;
+ ValidData = context.ValidationSet;
+
using (var ch = Host.Start("Training"))
{
ch.CheckValue(trainData, nameof(trainData));
@@ -133,12 +135,6 @@ public override void Train(RoleMappedData trainData)
TrainCore(ch);
ch.Done();
}
- }
-
- public override IPredictorWithFeatureWeights<Float> CreatePredictor()
- {
- Host.Check(TrainedEnsemble != null,
- "The predictor cannot be created before training is complete");
// The FastTree binary classification boosting is naturally calibrated to
// output probabilities when transformed using a scaled logistic function,
diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs
index e0b6933726..0e44553dee 100644
--- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs
@@ -51,8 +51,6 @@ public sealed partial class FastTreeRankingTrainer : BoostingFastTreeTrainerBase
private Test _specialTrainSetTest;
private TestHistory _firstTestSetHistory;
- public override bool NeedCalibration => false;
-
public override PredictionKind PredictionKind => PredictionKind.Ranking;
public FastTreeRankingTrainer(IHostEnvironment env, Arguments args)
@@ -65,8 +63,12 @@ protected override float GetMaxLabel()
return GetLabelGains().Length - 1;
}
- public override void Train(RoleMappedData trainData)
+ public override FastTreeRankingPredictor Train(TrainContext context)
{
+ Host.CheckValue(context, nameof(context));
+ var trainData = context.TrainingSet;
+ ValidData = context.ValidationSet;
+
using (var ch = Host.Start("Training"))
{
var maxLabel = GetLabelGains().Length - 1;
@@ -75,11 +77,6 @@ public override void Train(RoleMappedData trainData)
FeatureCount = trainData.Schema.Feature.Type.ValueCount;
ch.Done();
}
- }
-
- public override FastTreeRankingPredictor CreatePredictor()
- {
- Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete");
return new FastTreeRankingPredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs);
}
@@ -1063,11 +1060,11 @@ private static VersionInfo GetVersionInfo()
loaderSignature: LoaderSignature);
}
- protected override uint VerNumFeaturesSerialized { get { return 0x00010002; } }
+ protected override uint VerNumFeaturesSerialized => 0x00010002;
- protected override uint VerDefaultValueSerialized { get { return 0x00010004; } }
+ protected override uint VerDefaultValueSerialized => 0x00010004;
- protected override uint VerCategoricalSplitSerialized { get { return 0x00010005; } }
+ protected override uint VerCategoricalSplitSerialized => 0x00010005;
internal FastTreeRankingPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs)
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
@@ -1090,7 +1087,7 @@ public static FastTreeRankingPredictor Create(IHostEnvironment env, ModelLoadCon
return new FastTreeRankingPredictor(env, ctx);
}
- public override PredictionKind PredictionKind { get { return PredictionKind.Ranking; } }
+ public override PredictionKind PredictionKind => PredictionKind.Ranking;
}
public static partial class FastTree
diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs
index 308437440a..1e78f4c473 100644
--- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs
@@ -43,23 +43,21 @@ public sealed partial class FastTreeRegressionTrainer : BoostingFastTreeTrainerB
private Test _trainRegressionTest;
private Test _testRegressionTest;
+ public override PredictionKind PredictionKind => PredictionKind.Regression;
+
public FastTreeRegressionTrainer(IHostEnvironment env, Arguments args)
: base(env, args)
{
}
- public override bool NeedCalibration
+ public override FastTreeRegressionPredictor Train(TrainContext context)
{
- get { return false; }
- }
+ Host.CheckValue(context, nameof(context));
+ var trainData = context.TrainingSet;
+ ValidData = context.ValidationSet;
- public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } }
-
- public override void Train(RoleMappedData trainData)
- {
using (var ch = Host.Start("Training"))
{
- ch.CheckValue(trainData, nameof(trainData));
trainData.CheckRegressionLabel();
trainData.CheckFeatureFloatVector();
trainData.CheckOptFloatWeight();
@@ -68,12 +66,6 @@ public override void Train(RoleMappedData trainData)
TrainCore(ch);
ch.Done();
}
- }
-
- public override FastTreeRegressionPredictor CreatePredictor()
- {
- Host.Check(TrainedEnsemble != null,
- "The predictor cannot be created before training is complete");
return new FastTreeRegressionPredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs);
}
@@ -414,11 +406,11 @@ private static VersionInfo GetVersionInfo()
loaderSignature: LoaderSignature);
}
- protected override uint VerNumFeaturesSerialized { get { return 0x00010002; } }
+ protected override uint VerNumFeaturesSerialized => 0x00010002;
- protected override uint VerDefaultValueSerialized { get { return 0x00010004; } }
+ protected override uint VerDefaultValueSerialized => 0x00010004;
- protected override uint VerCategoricalSplitSerialized { get { return 0x00010005; } }
+ protected override uint VerCategoricalSplitSerialized => 0x00010005;
internal FastTreeRegressionPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs)
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
@@ -444,7 +436,7 @@ public static FastTreeRegressionPredictor Create(IHostEnvironment env, ModelLoad
return new FastTreeRegressionPredictor(env, ctx);
}
- public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } }
+ public override PredictionKind PredictionKind => PredictionKind.Regression;
}
public static partial class FastTree
diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
index b43c499a44..6c1b56b1eb 100644
--- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
@@ -42,12 +42,7 @@ public sealed partial class FastTreeTweedieTrainer : BoostingFastTreeTrainerBase
private Test _trainRegressionTest;
private Test _testRegressionTest;
- public override bool NeedCalibration
- {
- get { return false; }
- }
-
- public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } }
+ public override PredictionKind PredictionKind => PredictionKind.Regression;
public FastTreeTweedieTrainer(IHostEnvironment env, Arguments args)
: base(env, args)
@@ -55,8 +50,12 @@ public FastTreeTweedieTrainer(IHostEnvironment env, Arguments args)
Host.CheckUserArg(1 <= Args.Index && Args.Index <= 2, nameof(Args.Index), "Must be in the range [1, 2]");
}
- public override void Train(RoleMappedData trainData)
+ public override FastTreeTweediePredictor Train(TrainContext context)
{
+ Host.CheckValue(context, nameof(context));
+ var trainData = context.TrainingSet;
+ ValidData = context.ValidationSet;
+
using (var ch = Host.Start("Training"))
{
ch.CheckValue(trainData, nameof(trainData));
@@ -68,12 +67,6 @@ public override void Train(RoleMappedData trainData)
TrainCore(ch);
ch.Done();
}
- }
-
- public override FastTreeTweediePredictor CreatePredictor()
- {
- Host.Check(TrainedEnsemble != null,
- "The predictor cannot be created before training is complete");
return new FastTreeTweediePredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs);
}
@@ -409,11 +402,11 @@ private static VersionInfo GetVersionInfo()
loaderSignature: LoaderSignature);
}
- protected override uint VerNumFeaturesSerialized { get { return 0x00010001; } }
+ protected override uint VerNumFeaturesSerialized => 0x00010001;
- protected override uint VerDefaultValueSerialized { get { return 0x00010002; } }
+ protected override uint VerDefaultValueSerialized => 0x00010002;
- protected override uint VerCategoricalSplitSerialized { get { return 0x00010003; } }
+ protected override uint VerCategoricalSplitSerialized => 0x00010003;
internal FastTreeTweediePredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs)
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
@@ -452,7 +445,7 @@ protected override void Map(ref VBuffer src, ref float dst)
dst = MathUtils.ExpSlow(dst);
}
- public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } }
+ public override PredictionKind PredictionKind => PredictionKind.Regression;
}
public static partial class FastTree
diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs
index 931dd2335b..3b3ca9e92f 100644
--- a/src/Microsoft.ML.FastTree/GamTrainer.cs
+++ b/src/Microsoft.ML.FastTree/GamTrainer.cs
@@ -75,7 +75,7 @@ internal override void CheckLabel(RoleMappedData data)
data.CheckRegressionLabel();
}
- public override RegressionGamPredictor CreatePredictor()
+ private protected override RegressionGamPredictor CreatePredictor()
{
return new RegressionGamPredictor(Host, InputLength, TrainSet, BinEffects, FeatureMap);
}
@@ -107,7 +107,7 @@ public sealed class Arguments : ArgumentsBase
internal const string ShortName = "gam";
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
- public override bool NeedCalibration => true;
+ private protected override bool NeedCalibration => true;
public BinaryClassificationGamTrainer(IHostEnvironment env, Arguments args)
: base(env, args) { }
@@ -137,7 +137,7 @@ private bool[] ConvertTargetsToBool(double[] targets)
return boolArray;
}
- public override BinaryClassGamPredictor CreatePredictor()
+ private protected override BinaryClassGamPredictor CreatePredictor()
{
return new BinaryClassGamPredictor(Host, InputLength, TrainSet, BinEffects, FeatureMap);
}
@@ -152,9 +152,7 @@ protected override ObjectiveFunctionBase CreateObjectiveFunction()
///
/// Generalized Additive Model Learner.
///
- public abstract partial class GamTrainerBase<TArgs, TPredictor> :
- TrainerBase<RoleMappedData, TPredictor>,
- ITrainer<RoleMappedData, TPredictor>
+ public abstract partial class GamTrainerBase<TArgs, TPredictor> : TrainerBase<TPredictor>
where TArgs : GamTrainerBase<TArgs, TPredictor>.ArgumentsBase, new()
where TPredictor : GamPredictorBase
{
@@ -227,13 +225,10 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight
protected double[][] BinEffects;
protected int[] FeatureMap;
- public override bool NeedCalibration => false;
+ public override TrainerInfo Info { get; }
+ private protected virtual bool NeedCalibration => false;
- public override bool NeedNormalization => false;
-
- public override bool WantCaching => false;
-
- public GamTrainerBase(IHostEnvironment env, TArgs args)
+ private protected GamTrainerBase(IHostEnvironment env, TArgs args)
: base(env, RegisterName)
{
Contracts.CheckValue(env, nameof(env));
@@ -247,6 +242,7 @@ public GamTrainerBase(IHostEnvironment env, TArgs args)
Host.CheckParam(0 < args.NumIterations, nameof(args.NumIterations), "Must be positive.");
Args = args;
+ Info = new TrainerInfo(normalization: false, calibration: NeedCalibration, caching: false);
_gainConfidenceInSquaredStandardDeviations = Math.Pow(ProbabilityFunctions.Probit(1 - (1 - Args.GainConfidenceLevel) * 0.5), 2);
_entropyCoefficient = Args.EntropyCoefficient * 1e-6;
int numThreads = args.NumThreads ?? Environment.ProcessorCount;
@@ -264,18 +260,22 @@ public GamTrainerBase(IHostEnvironment env, TArgs args)
InitializeThreads(numThreads);
}
- public override void Train(RoleMappedData trainData)
+ public sealed override TPredictor Train(TrainContext context)
{
using (var ch = Host.Start("Training"))
{
- ch.CheckValue(trainData, nameof(trainData));
- ConvertData(trainData);
- InputLength = trainData.Schema.Feature.Type.ValueCount;
+ ch.CheckValue(context, nameof(context));
+ ConvertData(context.TrainingSet);
+ InputLength = context.TrainingSet.Schema.Feature.Type.ValueCount;
TrainCore(ch);
+ var pred = CreatePredictor();
ch.Done();
+ return pred;
}
}
+ private protected abstract TPredictor CreatePredictor();
+
internal abstract void CheckLabel(RoleMappedData data);
private void ConvertData(RoleMappedData trainData)
@@ -569,7 +569,7 @@ public abstract class GamPredictorBase : PredictorBase,
public ColumnType OutputType => NumberType.Float;
- protected internal GamPredictorBase(IHostEnvironment env, string name, int inputLength, Dataset trainSet, double[][] binEffects, int[] featureMap)
+ private protected GamPredictorBase(IHostEnvironment env, string name, int inputLength, Dataset trainSet, double[][] binEffects, int[] featureMap)
: base(env, name)
{
Host.CheckValue(trainSet, nameof(trainSet));
diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs
index 8cd62ceb77..612313fb93 100644
--- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs
+++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs
@@ -67,13 +67,13 @@ private static VersionInfo GetVersionInfo()
loaderSignature: LoaderSignature);
}
- protected override uint VerNumFeaturesSerialized { get { return 0x00010003; } }
+ protected override uint VerNumFeaturesSerialized => 0x00010003;
- protected override uint VerDefaultValueSerialized { get { return 0x00010005; } }
+ protected override uint VerDefaultValueSerialized => 0x00010005;
- protected override uint VerCategoricalSplitSerialized { get { return 0x00010006; } }
+ protected override uint VerCategoricalSplitSerialized => 0x00010006;
- public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } }
+ public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
internal FastForestClassificationPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount,
string innerArgs)
@@ -129,20 +129,20 @@ public sealed class Arguments : FastForestArgumentsBase
private bool[] _trainSetLabels;
+ public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
+ private protected override bool NeedCalibration => true;
+
public FastForestClassification(IHostEnvironment env, Arguments args)
: base(env, args)
{
}
- public override bool NeedCalibration
+ public override IPredictorWithFeatureWeights<Float> Train(TrainContext context)
{
- get { return true; }
- }
+ Host.CheckValue(context, nameof(context));
+ var trainData = context.TrainingSet;
+ ValidData = context.ValidationSet;
- public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } }
-
- public override void Train(RoleMappedData trainData)
- {
using (var ch = Host.Start("Training"))
{
ch.CheckValue(trainData, nameof(trainData));
@@ -154,13 +154,6 @@ public override void Train(RoleMappedData trainData)
TrainCore(ch);
ch.Done();
}
- }
-
- public override IPredictorWithFeatureWeights<Float> CreatePredictor()
- {
- Host.Check(TrainedEnsemble != null,
- "The predictor cannot be created before training is complete");
-
// LogitBoost is naturally calibrated to
// output probabilities when transformed using
// the logistic function, so if we have trained no
diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs
index f501037df3..7c45af3e54 100644
--- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs
+++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs
@@ -53,11 +53,11 @@ private static VersionInfo GetVersionInfo()
loaderSignature: LoaderSignature);
}
- protected override uint VerNumFeaturesSerialized { get { return 0x00010003; } }
+ protected override uint VerNumFeaturesSerialized => 0x00010003;
- protected override uint VerDefaultValueSerialized { get { return 0x00010005; } }
+ protected override uint VerDefaultValueSerialized => 0x00010005;
- protected override uint VerCategoricalSplitSerialized { get { return 0x00010006; } }
+ protected override uint VerCategoricalSplitSerialized => 0x00010006;
internal FastForestRegressionPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount,
string innerArgs, int samplesCount)
@@ -99,7 +99,7 @@ public static FastForestRegressionPredictor Create(IHostEnvironment env, ModelLo
return new FastForestRegressionPredictor(env, ctx);
}
- public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } }
+ public override PredictionKind PredictionKind => PredictionKind.Regression;
protected override void Map(ref VBuffer src, ref Float dst)
{
@@ -158,15 +158,14 @@ public FastForestRegression(IHostEnvironment env, Arguments args)
{
}
- public override bool NeedCalibration
- {
- get { return false; }
- }
+ public override PredictionKind PredictionKind => PredictionKind.Regression;
- public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } }
-
- public override void Train(RoleMappedData trainData)
+ public override FastForestRegressionPredictor Train(TrainContext context)
{
+ Host.CheckValue(context, nameof(context));
+ var trainData = context.TrainingSet;
+ ValidData = context.ValidationSet;
+
using (var ch = Host.Start("Training"))
{
ch.CheckValue(trainData, nameof(trainData));
@@ -178,13 +177,6 @@ public override void Train(RoleMappedData trainData)
TrainCore(ch);
ch.Done();
}
- }
-
- public override FastForestRegressionPredictor CreatePredictor()
- {
- Host.Check(TrainedEnsemble != null,
- "The predictor cannot be created before training is complete");
-
return new FastForestRegressionPredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs, Args.QuantileSampleCount);
}
diff --git a/src/Microsoft.ML.FastTree/Training/Test.cs b/src/Microsoft.ML.FastTree/Training/Test.cs
index f000a1fa32..4e72f46372 100644
--- a/src/Microsoft.ML.FastTree/Training/Test.cs
+++ b/src/Microsoft.ML.FastTree/Training/Test.cs
@@ -10,10 +10,8 @@
namespace Microsoft.ML.Runtime.FastTree.Internal
{
- public class TestResult : IComparable
+ public sealed class TestResult : IComparable
{
- private double _finalValue;
-
public enum ValueOperator : int
{
None = 0, // the final value will be the raw value,
@@ -36,33 +34,31 @@ public enum ValueOperator : int
// the raw value should be the same constant for all test results.
}
- public string LossFunctionName { get; private set; }
+ public string LossFunctionName { get; }
///
/// Raw value used for calculating final test result value.
///
- public double RawValue { get; private set; }
+ public double RawValue { get; }
///
/// The factor used for calculating final test result value.
///
- public double Factor { get; private set; }
+ public double Factor { get; }
///
/// The operator used for calculating final test result value.
/// Final value = Operator(RawValue, Factor)
///
- public ValueOperator Operator { get; private set; }
+ public ValueOperator Operator { get; }
///
/// Indicates that the lower value of this metric is better
/// This is used for early stopping (with TestHistory and TestWindowWithTolerance)
///
- public bool LowerIsBetter { get; private set; }
+ public bool LowerIsBetter { get; }
- public double FinalValue {
- get { return _finalValue; }
- }
+ public double FinalValue { get; }
public TestResult(string lossFunctionName, double rawValue, double factor, bool lowerIsBetter, ValueOperator valueOperator)
{
@@ -72,7 +68,7 @@ public TestResult(string lossFunctionName, double rawValue, double factor, bool
Operator = valueOperator;
LowerIsBetter = lowerIsBetter;
- CalculateFinalValue();
+ FinalValue = CalculateFinalValue();
}
public int CompareTo(TestResult o)
@@ -124,7 +120,7 @@ public static TestResult FromByteArray(byte[] buffer, ref int offset)
(ValueOperator)valueOperator);
}
- private void CalculateFinalValue()
+ private double CalculateFinalValue()
{
switch (Operator)
{
@@ -133,14 +129,11 @@ private void CalculateFinalValue()
case ValueOperator.Min:
case ValueOperator.None:
case ValueOperator.Sum:
- _finalValue = RawValue;
- break;
+ return RawValue;
case ValueOperator.Average:
- _finalValue = RawValue / Factor;
- break;
+ return RawValue / Factor;
case ValueOperator.SqrtAverage:
- _finalValue = Math.Sqrt(RawValue / Factor);
- break;
+ return Math.Sqrt(RawValue / Factor);
default:
throw Contracts.Except("Unsupported value operator: {0}", Operator);
}
@@ -157,7 +150,7 @@ public abstract class Test
//The method returns one or more losses on a given Dataset
public abstract IEnumerable ComputeTests(double[] scores);
- public Test(ScoreTracker scoreTracker)
+ private protected Test(ScoreTracker scoreTracker)
{
ScoreTracker = scoreTracker;
if (ScoreTracker != null)
@@ -207,13 +200,13 @@ public class TestHistory : Test
protected IList History;
protected int Iteration { get; private set; }
- public TestResult BestResult { get; protected internal set; }
- public int BestIteration { get; protected internal set; }
+ public TestResult BestResult { get; private protected set; }
+ public int BestIteration { get; private protected set; }
// scenarioWithoutHistory - simple test scenario we want to track the history and look for best iteration
// lossIndex - index of lossFunction in case Test returns more than one loss (default should be 0)
// lower is better: are we looking for minimum or maximum of loss function?
- public TestHistory(Test scenarioWithoutHistory, int lossIndex)
+ internal TestHistory(Test scenarioWithoutHistory, int lossIndex)
: base(null)
{
History = new List();
diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs
index dce7be48d2..3049e6a54c 100644
--- a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs
+++ b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs
@@ -29,7 +29,7 @@
namespace Microsoft.ML.Runtime.KMeans
{
///
- public class KMeansPlusPlusTrainer : TrainerBase
+ public class KMeansPlusPlusTrainer : TrainerBase
{
public const string LoadNameValue = "KMeansPlusPlus";
internal const string UserNameValue = "KMeans++ Clustering";
@@ -74,11 +74,6 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight
}
private readonly int _k;
- private int _dimensionality;
-
- // The coordinates of the final centroids at the end of the training. During training
- // it holds the centroids of the previous iteration.
- private readonly VBuffer[] _centroids;
private readonly int _maxIterations; // max number of iterations to train
private readonly Float _convergenceThreshold; // convergence thresholds
@@ -87,6 +82,9 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight
private readonly InitAlgorithm _initAlgorithm;
private readonly int _numThreads;
+ public override TrainerInfo Info { get; }
+ public override PredictionKind PredictionKind => PredictionKind.Clustering;
+
public KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args)
: base(env, LoadNameValue)
{
@@ -101,8 +99,6 @@ public KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args)
Host.CheckUserArg(args.OptTol > 0, nameof(args.OptTol), "Tolerance must be positive");
_convergenceThreshold = args.OptTol;
- _centroids = new VBuffer[_k];
-
Host.CheckUserArg(args.AccelMemBudgetMb > 0, nameof(args.AccelMemBudgetMb), "Must be positive");
_accelMemBudgetMb = args.AccelMemBudgetMb;
@@ -111,36 +107,38 @@ public KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args)
Host.CheckUserArg(!args.NumThreads.HasValue || args.NumThreads > 0, nameof(args.NumThreads),
"Must be either null or a positive integer.");
_numThreads = ComputeNumThreads(Host, args.NumThreads);
+ Info = new TrainerInfo();
}
- public override bool NeedNormalization => true;
- public override bool NeedCalibration => false;
- public override bool WantCaching => true;
- public override PredictionKind PredictionKind => PredictionKind.Clustering;
-
- public override void Train(RoleMappedData data)
+ public override KMeansPredictor Train(TrainContext context)
{
- Host.CheckValue(data, nameof(data));
+ Host.CheckValue(context, nameof(context));
+ var data = context.TrainingSet;
- data.CheckFeatureFloatVector(out _dimensionality);
- Contracts.Assert(_dimensionality > 0);
+ data.CheckFeatureFloatVector(out int dimensionality);
+ Contracts.Assert(dimensionality > 0);
using (var ch = Host.Start("Training"))
{
- TrainCore(ch, data);
+ var pred = TrainCore(ch, data, dimensionality);
ch.Done();
+ return pred;
}
}
- private void TrainCore(IChannel ch, RoleMappedData data)
+ private KMeansPredictor TrainCore(IChannel ch, RoleMappedData data, int dimensionality)
{
Host.AssertValue(ch);
ch.AssertValue(data);
- // REVIEW: In high-dimensionality cases this is less than ideal
- // and we should consider using sparse buffers.
+ // REVIEW: In high-dimensionality cases this is less than ideal and we should consider
+ // using sparse buffers for the centroids.
+
+ // The coordinates of the final centroids at the end of the training. During training
+ // it holds the centroids of the previous iteration.
+ var centroids = new VBuffer[_k];
for (int i = 0; i < _k; i++)
- _centroids[i] = VBufferUtils.CreateDense(_dimensionality);
+ centroids[i] = VBufferUtils.CreateDense(dimensionality);
ch.Info("Initializing centroids");
long missingFeatureCount;
@@ -154,29 +152,29 @@ private void TrainCore(IChannel ch, RoleMappedData data)
// pay attention to their incoming set of centroids and incrementally train.
if (_initAlgorithm == InitAlgorithm.KMeansPlusPlus)
{
- KMeansPlusPlusInit.Initialize(Host, _numThreads, ch, cursorFactory, _k, _dimensionality,
- _centroids, out missingFeatureCount, out totalTrainingInstances);
+ KMeansPlusPlusInit.Initialize(Host, _numThreads, ch, cursorFactory, _k, dimensionality,
+ centroids, out missingFeatureCount, out totalTrainingInstances);
}
else if (_initAlgorithm == InitAlgorithm.Random)
{
KMeansRandomInit.Initialize(Host, _numThreads, ch, cursorFactory, _k,
- _centroids, out missingFeatureCount, out totalTrainingInstances);
+ centroids, out missingFeatureCount, out totalTrainingInstances);
}
else
{
// Defaulting to KMeans|| initialization.
- KMeansBarBarInitialization.Initialize(Host, _numThreads, ch, cursorFactory, _k, _dimensionality,
- _centroids, _accelMemBudgetMb, out missingFeatureCount, out totalTrainingInstances);
+ KMeansBarBarInitialization.Initialize(Host, _numThreads, ch, cursorFactory, _k, dimensionality,
+ centroids, _accelMemBudgetMb, out missingFeatureCount, out totalTrainingInstances);
}
- KMeansUtils.VerifyModelConsistency(_centroids);
+ KMeansUtils.VerifyModelConsistency(centroids);
ch.Info("Centroids initialized, starting main trainer");
KMeansLloydsYinYangTrain.Train(
- Host, _numThreads, ch, cursorFactory, totalTrainingInstances, _k, _dimensionality, _maxIterations,
- _accelMemBudgetMb, _convergenceThreshold, _centroids);
+ Host, _numThreads, ch, cursorFactory, totalTrainingInstances, _k, dimensionality, _maxIterations,
+ _accelMemBudgetMb, _convergenceThreshold, centroids);
- KMeansUtils.VerifyModelConsistency(_centroids);
+ KMeansUtils.VerifyModelConsistency(centroids);
ch.Info("Model trained successfully on {0} instances", totalTrainingInstances);
if (missingFeatureCount > 0)
{
@@ -184,11 +182,7 @@ private void TrainCore(IChannel ch, RoleMappedData data)
"{0} instances with missing features detected and ignored. Consider using MissingHandler.",
missingFeatureCount);
}
- }
-
- public override KMeansPredictor CreatePredictor()
- {
- return new KMeansPredictor(Host, _k, _centroids, copyIn: true);
+ return new KMeansPredictor(Host, _k, centroids, copyIn: true);
}
private static int ComputeNumThreads(IHost host, int? argNumThreads)
diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs
index 54cd523e72..c25b196b0e 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs
@@ -43,11 +43,10 @@ private static VersionInfo GetVersionInfo()
loaderSignature: LoaderSignature);
}
- protected override uint VerNumFeaturesSerialized { get { return 0x00010002; } }
-
- protected override uint VerDefaultValueSerialized { get { return 0x00010004; } }
-
- protected override uint VerCategoricalSplitSerialized { get { return 0x00010005; } }
+ protected override uint VerNumFeaturesSerialized => 0x00010002;
+ protected override uint VerDefaultValueSerialized => 0x00010004;
+ protected override uint VerCategoricalSplitSerialized => 0x00010005;
+ public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
internal LightGbmBinaryPredictor(IHostEnvironment env, FastTree.Internal.Ensemble trainedEnsemble, int featureCount, string innerArgs)
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
@@ -77,8 +76,6 @@ public static IPredictorProducing Create(IHostEnvironment env, ModelLoadC
return predictor;
return new CalibratedPredictor(env, predictor, calibrator);
}
-
- public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } }
}
///
@@ -89,12 +86,14 @@ public sealed class LightGbmBinaryTrainer : LightGbmTrainerBase PredictionKind.BinaryClassification;
+
public LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args)
- : base(env, args, PredictionKind.BinaryClassification, "LGBBINCL")
+ : base(env, args, LoadNameValue)
{
}
- public override IPredictorWithFeatureWeights CreatePredictor()
+ private protected override IPredictorWithFeatureWeights CreatePredictor()
{
Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete");
var innerArgs = LightGbmInterfaceUtils.JoinParameters(Options);
diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
index 2a84bad0e8..97649feb2a 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
@@ -29,9 +29,10 @@ public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase PredictionKind.MultiClassClassification;
public LightGbmMulticlassTrainer(IHostEnvironment env, LightGbmArguments args)
- : base(env, args, PredictionKind.MultiClassClassification, "LightGBMMulticlass")
+ : base(env, args, LoadNameValue)
{
_numClass = -1;
}
@@ -53,7 +54,7 @@ private LightGbmBinaryPredictor CreateBinaryPredictor(int classID, string innerA
return new LightGbmBinaryPredictor(Host, GetBinaryEnsemble(classID), FeatureCount, innerArgs);
}
- public override OvaPredictor CreatePredictor()
+ private protected override OvaPredictor CreatePredictor()
{
Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete.");
diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs
index 4a1d1634a8..2f44a0ba9d 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs
@@ -41,11 +41,10 @@ private static VersionInfo GetVersionInfo()
loaderSignature: LoaderSignature);
}
- protected override uint VerNumFeaturesSerialized { get { return 0x00010002; } }
-
- protected override uint VerDefaultValueSerialized { get { return 0x00010004; } }
-
- protected override uint VerCategoricalSplitSerialized { get { return 0x00010005; } }
+ protected override uint VerNumFeaturesSerialized => 0x00010002;
+ protected override uint VerDefaultValueSerialized => 0x00010004;
+ protected override uint VerCategoricalSplitSerialized => 0x00010005;
+ public override PredictionKind PredictionKind => PredictionKind.Ranking;
internal LightGbmRankingPredictor(IHostEnvironment env, FastTree.Internal.Ensemble trainedEnsemble, int featureCount, string innerArgs)
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
@@ -67,8 +66,6 @@ public static LightGbmRankingPredictor Create(IHostEnvironment env, ModelLoadCon
{
return new LightGbmRankingPredictor(env, ctx);
}
-
- public override PredictionKind PredictionKind { get { return PredictionKind.Ranking; } }
}
///
@@ -78,8 +75,10 @@ public sealed class LightGbmRankingTrainer : LightGbmTrainerBase PredictionKind.Ranking;
+
public LightGbmRankingTrainer(IHostEnvironment env, LightGbmArguments args)
- : base(env, args, PredictionKind.Ranking, "LightGBMRanking")
+ : base(env, args, LoadNameValue)
{
}
@@ -103,7 +102,7 @@ protected override void CheckDataValid(IChannel ch, RoleMappedData data)
}
}
- public override LightGbmRankingPredictor CreatePredictor()
+ private protected override LightGbmRankingPredictor CreatePredictor()
{
Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete");
var innerArgs = LightGbmInterfaceUtils.JoinParameters(Options);
diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs
index 6ae3da792a..db2f1f268a 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs
@@ -41,11 +41,10 @@ private static VersionInfo GetVersionInfo()
loaderSignature: LoaderSignature);
}
- protected override uint VerNumFeaturesSerialized { get { return 0x00010002; } }
-
- protected override uint VerDefaultValueSerialized { get { return 0x00010004; } }
-
- protected override uint VerCategoricalSplitSerialized { get { return 0x00010005; } }
+ protected override uint VerNumFeaturesSerialized => 0x00010002;
+ protected override uint VerDefaultValueSerialized => 0x00010004;
+ protected override uint VerCategoricalSplitSerialized => 0x00010005;
+ public override PredictionKind PredictionKind => PredictionKind.Regression;
internal LightGbmRegressionPredictor(IHostEnvironment env, FastTree.Internal.Ensemble trainedEnsemble, int featureCount, string innerArgs)
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
@@ -70,8 +69,6 @@ public static LightGbmRegressionPredictor Create(IHostEnvironment env, ModelLoad
ctx.CheckAtModel(GetVersionInfo());
return new LightGbmRegressionPredictor(env, ctx);
}
-
- public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } }
}
public sealed class LightGbmRegressorTrainer : LightGbmTrainerBase
@@ -81,12 +78,14 @@ public sealed class LightGbmRegressorTrainer : LightGbmTrainerBase PredictionKind.Regression;
+
public LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args)
- : base(env, args, PredictionKind.Regression, "LightGBMRegressor")
+ : base(env, args, LoadNameValue)
{
}
- public override LightGbmRegressionPredictor CreatePredictor()
+ private protected override LightGbmRegressionPredictor CreatePredictor()
{
Host.Check(TrainedEnsemble != null,
"The predictor cannot be created before training is complete");
diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
index c778c4ee23..83e0f7803b 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
@@ -24,9 +24,7 @@ internal static class LightGbmShared
///
/// Base class for all training with LightGBM.
///
- public abstract class LightGbmTrainerBase :
- ITrainer,
- IValidatingTrainer
+ public abstract class LightGbmTrainerBase : TrainerBase
where TPredictor : IPredictorProducing
{
private sealed class CategoricalMetaData
@@ -39,86 +37,67 @@ private sealed class CategoricalMetaData
public bool[] IsCategoricalFeature;
}
- #region members
- private readonly IHostEnvironment _env;
- private readonly PredictionKind _predictionKind;
-
- protected readonly IHost Host;
- protected readonly LightGbmArguments Args;
+ private protected readonly LightGbmArguments Args;
///
/// Stores argumments as objects to convert them to invariant string type in the end so that
/// the code is culture agnostic. When retrieving key value from this dictionary as string
/// please convert to string invariant by string.Format(CultureInfo.InvariantCulture, "{0}", Option[key]).
///
- protected readonly Dictionary Options;
- protected readonly IParallel ParallelTraining;
+ private protected readonly Dictionary Options;
+ private protected readonly IParallel ParallelTraining;
// Store _featureCount and _trainedEnsemble to construct predictor.
- protected int FeatureCount;
- protected FastTree.Internal.Ensemble TrainedEnsemble;
+ private protected int FeatureCount;
+ private protected FastTree.Internal.Ensemble TrainedEnsemble;
- #endregion
+ private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false, supportValid: true);
+ public override TrainerInfo Info => _info;
- protected LightGbmTrainerBase(IHostEnvironment env, LightGbmArguments args, PredictionKind predictionKind, string name)
+ private protected LightGbmTrainerBase(IHostEnvironment env, LightGbmArguments args, string name)
+ : base(env, name)
{
- Contracts.CheckValue(env, nameof(env));
- env.CheckNonWhiteSpace(name, nameof(name));
-
- Host = env.Register(name);
Host.CheckValue(args, nameof(args));
Args = args;
Options = Args.ToDictionary(Host);
- _predictionKind = predictionKind;
- _env = env;
ParallelTraining = Args.ParallelTrainer != null ? Args.ParallelTrainer.CreateComponent(env) : new SingleTrainer();
InitParallelTraining();
}
- public void Train(RoleMappedData data)
+ public override TPredictor Train(TrainContext context)
{
- Dataset dtrain;
- CategoricalMetaData catMetaData;
- using (var ch = Host.Start("Loading data for LightGBM"))
- {
- using (var pch = Host.StartProgressChannel("Loading data for LightGBM"))
- dtrain = LoadTrainingData(ch, data, out catMetaData);
- ch.Done();
- }
- using (var ch = Host.Start("Training with LightGBM"))
- {
- using (var pch = Host.StartProgressChannel("Training with LightGBM"))
- TrainCore(ch, pch, dtrain, catMetaData);
- ch.Done();
- }
- dtrain.Dispose();
- DisposeParallelTraining();
- }
+ Host.CheckValue(context, nameof(context));
- public void Train(RoleMappedData data, RoleMappedData validData)
- {
- Dataset dtrain;
- Dataset dvalid;
+ Dataset dtrain = null;
+ Dataset dvalid = null;
CategoricalMetaData catMetaData;
- using (var ch = Host.Start("Loading data for LightGBM"))
+ try
{
- using (var pch = Host.StartProgressChannel("Loading data for LightGBM"))
+ using (var ch = Host.Start("Loading data for LightGBM"))
{
- dtrain = LoadTrainingData(ch, data, out catMetaData);
- dvalid = LoadValidationData(ch, dtrain, validData, catMetaData);
+ using (var pch = Host.StartProgressChannel("Loading data for LightGBM"))
+ {
+ dtrain = LoadTrainingData(ch, context.TrainingSet, out catMetaData);
+ if (context.ValidationSet != null)
+ dvalid = LoadValidationData(ch, dtrain, context.ValidationSet, catMetaData);
+ }
+ ch.Done();
+ }
+ using (var ch = Host.Start("Training with LightGBM"))
+ {
+ using (var pch = Host.StartProgressChannel("Training with LightGBM"))
+ TrainCore(ch, pch, dtrain, catMetaData, dvalid);
+ ch.Done();
}
- ch.Done();
}
- using (var ch = Host.Start("Training with LightGBM"))
+ finally
{
- using (var pch = Host.StartProgressChannel("Training with LightGBM"))
- TrainCore(ch, pch, dtrain, catMetaData, dvalid);
- ch.Done();
+ dtrain?.Dispose();
+ dvalid?.Dispose();
+ DisposeParallelTraining();
}
- dtrain.Dispose();
- dvalid.Dispose();
- DisposeParallelTraining();
+ return CreatePredictor();
}
private void InitParallelTraining()
@@ -178,7 +157,7 @@ protected virtual void GetDefaultParameters(IChannel ch, int numRow, bool hasCat
private FloatLabelCursor.Factory CreateCursorFactory(RoleMappedData data)
{
var loadFlags = CursOpt.AllLabels | CursOpt.AllWeights | CursOpt.Features;
- if (_predictionKind == PredictionKind.Ranking)
+ if (PredictionKind == PredictionKind.Ranking)
loadFlags |= CursOpt.Group;
var factory = new FloatLabelCursor.Factory(data, loadFlags);
@@ -392,7 +371,7 @@ private void GetMetainfo(IChannel ch, FloatLabelCursor.Factory factory,
List labelList = new List();
bool hasWeights = factory.Data.Schema.Weight != null;
bool hasGroup = false;
- if (_predictionKind == PredictionKind.Ranking)
+ if (PredictionKind == PredictionKind.Ranking)
{
ch.Check(factory.Data.Schema != null, "The data for ranking task should have group field.");
hasGroup = true;
@@ -870,14 +849,7 @@ private static int GetNumSampleRow(int numRow, int numCol)
return ret;
}
- public PredictionKind PredictionKind => _predictionKind;
-
- IPredictor ITrainer.CreatePredictor()
- {
- return CreatePredictor();
- }
-
- public abstract TPredictor CreatePredictor();
+ private protected abstract TPredictor CreatePredictor();
///
/// This function will be called before training. It will check the label/group and add parameters for specific applications.
diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs
index 23e7351a86..bebaa49691 100644
--- a/src/Microsoft.ML.PCA/PcaTrainer.cs
+++ b/src/Microsoft.ML.PCA/PcaTrainer.cs
@@ -41,7 +41,7 @@ namespace Microsoft.ML.Runtime.PCA
///
/// This PCA can be made into Kernel PCA by using Random Fourier Features transform
///
- public sealed class RandomizedPcaTrainer : TrainerBase
+ public sealed class RandomizedPcaTrainer : TrainerBase
{
public const string LoadNameValue = "pcaAnomaly";
internal const string UserNameValue = "PCA Anomaly Detector";
@@ -69,13 +69,16 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight
public int? Seed;
}
- private int _dimension;
private readonly int _rank;
private readonly int _oversampling;
private readonly bool _center;
private readonly int _seed;
- private VBuffer[] _eigenvectors; // top eigenvectors of the covariance matrix
- private VBuffer _mean;
+
+ public override PredictionKind PredictionKind => PredictionKind.AnomalyDetection;
+
+ // The training performs two passes, only. Probably not worth caching.
+ private static readonly TrainerInfo _info = new TrainerInfo(caching: false);
+ public override TrainerInfo Info => _info;
public RandomizedPcaTrainer(IHostEnvironment env, Arguments args)
: base(env, LoadNameValue)
@@ -90,65 +93,43 @@ public RandomizedPcaTrainer(IHostEnvironment env, Arguments args)
_seed = args.Seed ?? Host.Rand.Next();
}
- public override bool NeedNormalization
- {
- get { return true; }
- }
-
- public override bool NeedCalibration
- {
- get { return false; }
- }
-
- public override bool WantCaching
- {
- // Two passes, only. Probably not worth caching.
- get { return false; }
- }
-
- public override PcaPredictor CreatePredictor()
- {
- return new PcaPredictor(Host, _rank, _eigenvectors, ref _mean);
- }
-
- public override PredictionKind PredictionKind { get { return PredictionKind.AnomalyDetection; } }
-
//Note: the notations used here are the same as in http://web.stanford.edu/group/mmds/slides2010/Martinsson.pdf (pg. 9)
- public override void Train(RoleMappedData data)
+ public override PcaPredictor Train(TrainContext context)
{
- Host.CheckValue(data, nameof(data));
+ Host.CheckValue(context, nameof(context));
- data.CheckFeatureFloatVector(out _dimension);
+ context.TrainingSet.CheckFeatureFloatVector(out int dimension);
using (var ch = Host.Start("Training"))
{
- TrainCore(ch, data);
+ var pred = TrainCore(ch, context.TrainingSet, dimension);
ch.Done();
+ return pred;
}
}
- private void TrainCore(IChannel ch, RoleMappedData data)
+ private PcaPredictor TrainCore(IChannel ch, RoleMappedData data, int dimension)
{
Host.AssertValue(ch);
ch.AssertValue(data);
- if (_rank > _dimension)
- throw ch.Except("Rank ({0}) cannot be larger than the original dimension ({1})", _rank, _dimension);
- int oversampledRank = Math.Min(_rank + _oversampling, _dimension);
+ if (_rank > dimension)
+ throw ch.Except("Rank ({0}) cannot be larger than the original dimension ({1})", _rank, dimension);
+ int oversampledRank = Math.Min(_rank + _oversampling, dimension);
//exact: (size of the 2 big matrices + other minor allocations) / (2^30)
- Double memoryUsageEstimate = 2.0 * _dimension * oversampledRank * sizeof(Float) / 1e9;
+ Double memoryUsageEstimate = 2.0 * dimension * oversampledRank * sizeof(Float) / 1e9;
if (memoryUsageEstimate > 2)
ch.Info("Estimate memory usage: {0:G2} GB. If running out of memory, reduce rank and oversampling factor.", memoryUsageEstimate);
- var y = Zeros(oversampledRank, _dimension);
- _mean = _center ? VBufferUtils.CreateDense(_dimension) : VBufferUtils.CreateEmpty(_dimension);
+ var y = Zeros(oversampledRank, dimension);
+ var mean = _center ? VBufferUtils.CreateDense(dimension) : VBufferUtils.CreateEmpty(dimension);
- var omega = GaussianMatrix(oversampledRank, _dimension, _seed);
+ var omega = GaussianMatrix(oversampledRank, dimension, _seed);
var cursorFactory = new FeatureFloatVectorCursor.Factory(data, CursOpt.Features | CursOpt.Weight);
long numBad;
- Project(Host, cursorFactory, ref _mean, omega, y, out numBad);
+ Project(Host, cursorFactory, ref mean, omega, y, out numBad);
if (numBad > 0)
ch.Warning("Skipped {0} instances with missing features/weights during training", numBad);
@@ -166,7 +147,7 @@ private void TrainCore(IChannel ch, RoleMappedData data)
var q = y; // q in QR decomposition.
var b = omega; // reuse the memory allocated by Omega.
- Project(Host, cursorFactory, ref _mean, q, b, out numBad);
+ Project(Host, cursorFactory, ref mean, q, b, out numBad);
//Compute B2 = B' * B
var b2 = new Float[oversampledRank * oversampledRank];
@@ -179,8 +160,9 @@ private void TrainCore(IChannel ch, RoleMappedData data)
Float[] smallEigenvalues;// eigenvectors and eigenvalues of the small matrix B2.
Float[] smallEigenvectors;
EigenUtils.EigenDecomposition(b2, out smallEigenvalues, out smallEigenvectors);
- PostProcess(b, smallEigenvalues, smallEigenvectors, _dimension, oversampledRank);
- _eigenvectors = b;
+ PostProcess(b, smallEigenvalues, smallEigenvectors, dimension, oversampledRank);
+
+ return new PcaPredictor(Host, _rank, b, ref mean);
}
private static VBuffer[] Zeros(int k, int d)
diff --git a/src/Microsoft.ML.PipelineInference/ExperimentsGenerator.cs b/src/Microsoft.ML.PipelineInference/ExperimentsGenerator.cs
index 3c7441c87f..6b184e7b61 100644
--- a/src/Microsoft.ML.PipelineInference/ExperimentsGenerator.cs
+++ b/src/Microsoft.ML.PipelineInference/ExperimentsGenerator.cs
@@ -110,8 +110,7 @@ public static List GenerateCandidates(IHostEnvironment env, string dataFi
//get all the trainers for this task, and generate the initial set of candidates.
// Exclude the hidden learners, and the metalinear learners.
- var trainers = ComponentCatalog.GetAllDerivedClasses(typeof(ITrainer), predictorType)
- .Where(cls => !cls.IsHidden && !typeof(IMetaLinearTrainer).IsAssignableFrom(cls.Type));
+ var trainers = ComponentCatalog.GetAllDerivedClasses(typeof(ITrainer), predictorType).Where(cls => !cls.IsHidden);
var loaderSubComponent = new SubComponent("TextLoader", loaderSettings);
string loader = $" loader={loaderSubComponent}";
diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
index b967bd9f95..58c28b8712 100644
--- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
@@ -30,9 +30,7 @@ namespace Microsoft.ML.Runtime.FactorizationMachine
[3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
*/
///
- public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase,
- IIncrementalTrainer, IValidatingTrainer,
- IIncrementalValidatingTrainer
+ public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase
{
public const string Summary = "Train a field-aware factorization machine for binary classification";
public const string UserName = "Field-aware Factorization Machine";
@@ -76,9 +74,7 @@ public sealed class Arguments : LearnerInputBaseWithLabel
}
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
- public override bool NeedNormalization => true;
- public override bool NeedCalibration => false;
- public override bool WantCaching => true;
+ public override TrainerInfo Info { get; }
private readonly int _latentDim;
private readonly int _latentDimAligned;
private readonly float _lambdaLinear;
@@ -89,7 +85,6 @@ public sealed class Arguments : LearnerInputBaseWithLabel
private readonly bool _shuffle;
private readonly bool _verbose;
private readonly float _radius;
- private FieldAwareFactorizationMachinePredictor _pred;
public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments args) : base(env, LoadName)
{
@@ -108,6 +103,7 @@ public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments arg
_shuffle = args.Shuffle;
_verbose = args.Verbose;
_radius = args.Radius;
+ Info = new TrainerInfo();
}
private void InitializeTrainingState(int fieldCount, int featureCount, FieldAwareFactorizationMachinePredictor predictor, out float[] linearWeights,
@@ -216,7 +212,7 @@ private static double CalculateAvgLoss(IChannel ch, RoleMappedData data, bool no
return loss / exampleCount;
}
- private void TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data, RoleMappedData validData, FieldAwareFactorizationMachinePredictor predictor)
+ private FieldAwareFactorizationMachinePredictor TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data, RoleMappedData validData, FieldAwareFactorizationMachinePredictor predictor)
{
Host.AssertValue(ch);
Host.AssertValue(pch);
@@ -346,63 +342,25 @@ private void TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data, R
ch.Warning($"Skipped {badExampleCount} examples with bad label/weight/features in training set");
if (validBadExampleCount != 0)
ch.Warning($"Skipped {validBadExampleCount} examples with bad label/weight/features in validation set");
- _pred = new FieldAwareFactorizationMachinePredictor(Host, _norm, fieldCount, totalFeatureCount, _latentDim, linearWeights, latentWeightsAligned);
+ return new FieldAwareFactorizationMachinePredictor(Host, _norm, fieldCount, totalFeatureCount, _latentDim, linearWeights, latentWeightsAligned);
}
- public override void Train(RoleMappedData data)
+ public override FieldAwareFactorizationMachinePredictor Train(TrainContext context)
{
- Host.CheckValue(data, nameof(data));
- using (var ch = Host.Start("Training"))
- using (var pch = Host.StartProgressChannel("Training"))
- {
- TrainCore(ch, pch, data, null, null);
- ch.Done();
- }
- }
-
- public void Train(RoleMappedData data, RoleMappedData validData)
- {
- Host.CheckValue(data, nameof(data));
- Host.CheckValue(validData, nameof(validData));
- using (var ch = Host.Start("Training"))
- using (var pch = Host.StartProgressChannel("Training"))
- {
- TrainCore(ch, pch, data, validData, null);
- ch.Done();
- }
- }
+ Host.CheckValue(context, nameof(context));
+ var initPredictor = context.InitialPredictor as FieldAwareFactorizationMachinePredictor;
+ Host.CheckParam(context.InitialPredictor == null || initPredictor != null, nameof(context),
+ "Initial predictor should have been " + nameof(FieldAwareFactorizationMachinePredictor));
- public void Train(RoleMappedData data, FieldAwareFactorizationMachinePredictor predictor)
- {
- Host.CheckValue(data, nameof(data));
- Host.CheckValue(predictor, nameof(predictor));
using (var ch = Host.Start("Training"))
using (var pch = Host.StartProgressChannel("Training"))
{
- TrainCore(ch, pch, data, null, predictor);
+ var pred = TrainCore(ch, pch, context.TrainingSet, context.ValidationSet, initPredictor);
ch.Done();
+ return pred;
}
}
- public void Train(RoleMappedData data, RoleMappedData validData, FieldAwareFactorizationMachinePredictor predictor)
- {
- Host.CheckValue(data, nameof(data));
- Host.CheckValue(data, nameof(validData));
- Host.CheckValue(predictor, nameof(predictor));
- using (var ch = Host.Start("Training"))
- using (var pch = Host.StartProgressChannel("Training"))
- {
- TrainCore(ch, pch, data, validData, predictor);
- ch.Done();
- }
- }
-
- public override FieldAwareFactorizationMachinePredictor CreatePredictor()
- {
- Host.Check(_pred != null, nameof(Train) + " has not yet been called");
- return _pred;
- }
-
[TlcModule.EntryPoint(Name = "Trainers.FieldAwareFactorizationMachineBinaryClassifier",
Desc = Summary,
UserName = UserName,
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
index 554babf1ce..5ae866c5d7 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
@@ -44,62 +44,57 @@ namespace Microsoft.ML.Runtime.Learners
using Stopwatch = System.Diagnostics.Stopwatch;
using TScalarPredictor = IPredictorWithFeatureWeights;
- public abstract class LinearTrainerBase : TrainerBase
+ public abstract class LinearTrainerBase : TrainerBase
where TPredictor : IPredictor
{
- protected int NumFeatures;
- protected VBuffer[] Weights;
- protected Float[] Bias;
protected bool NeedShuffle;
- public override bool NeedNormalization => true;
-
- public override bool WantCaching => true;
+ private static readonly TrainerInfo _info = new TrainerInfo();
+ public override TrainerInfo Info => _info;
///
/// Whether data is to be shuffled every epoch.
///
protected abstract bool ShuffleData { get; }
- protected LinearTrainerBase(IHostEnvironment env, string name)
+ private protected LinearTrainerBase(IHostEnvironment env, string name)
: base(env, name)
{
}
- protected void TrainEx(RoleMappedData data, LinearPredictor predictor)
+ public override TPredictor Train(TrainContext context)
{
+ Host.CheckValue(context, nameof(context));
+ TPredictor pred;
using (var ch = Host.Start("Training"))
{
- ch.AssertValue(data, nameof(data));
- ch.AssertValueOrNull(predictor);
- var preparedData = PrepareDataFromTrainingExamples(ch, data);
- TrainCore(ch, preparedData, predictor);
+ var preparedData = PrepareDataFromTrainingExamples(ch, context.TrainingSet, out int weightSetCount);
+ var initPred = context.InitialPredictor;
+ var linInitPred = (initPred as CalibratedPredictorBase)?.SubPredictor as LinearPredictor;
+ linInitPred = linInitPred ?? initPred as LinearPredictor;
+ Host.CheckParam(context.InitialPredictor == null || linInitPred != null, nameof(context),
+ "Initial predictor was not a linear predictor.");
+ pred = TrainCore(ch, preparedData, linInitPred, weightSetCount);
ch.Done();
}
+ return pred;
}
- public override void Train(RoleMappedData examples)
- {
- Host.CheckValue(examples, nameof(examples));
- TrainEx(examples, null);
- }
-
- protected abstract void TrainCore(IChannel ch, RoleMappedData data, LinearPredictor predictor);
-
- ///
- /// Gets the size of weights and bias array. For binary classification and regression, this is 1.
- /// For multi-class classification, this equals the number of classes.
- ///
- protected abstract int WeightArraySize { get; }
+ protected abstract TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearPredictor predictor, int weightSetCount);
///
/// This method ensures that the data meets the requirements of this trainer and its
/// subclasses, injects necessary transforms, and throws if it couldn't meet them.
///
- protected RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedData examples)
+ /// The channel
+ /// The training examples
+ /// Gets the length of weights and bias array. For binary classification and regression,
+ /// this is 1. For multi-class classification, this equals the number of classes on the label.
+ /// A potentially modified version of
+ protected RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedData examples, out int weightSetCount)
{
ch.AssertValue(examples);
- CheckLabel(examples);
+ CheckLabel(examples, out weightSetCount);
examples.CheckFeatureFloatVector();
var idvToShuffle = examples.Data;
IDataView idvToFeedTrain;
@@ -120,17 +115,17 @@ protected RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMapped
var roles = examples.Schema.GetColumnRoleNames();
var examplesToFeedTrain = new RoleMappedData(idvToFeedTrain, roles);
- ch.Assert(examplesToFeedTrain.Schema.Label != null);
- ch.Assert(examplesToFeedTrain.Schema.Feature != null);
+ ch.AssertValue(examplesToFeedTrain.Schema.Label);
+ ch.AssertValue(examplesToFeedTrain.Schema.Feature);
if (examples.Schema.Weight != null)
- ch.Assert(examplesToFeedTrain.Schema.Weight != null);
+ ch.AssertValue(examplesToFeedTrain.Schema.Weight);
- NumFeatures = examplesToFeedTrain.Schema.Feature.Type.VectorSize;
- ch.Check(NumFeatures > 0, "Training set has 0 instances, aborting training.");
+ int numFeatures = examplesToFeedTrain.Schema.Feature.Type.VectorSize;
+ ch.Check(numFeatures > 0, "Training set has no features, aborting training.");
return examplesToFeedTrain;
}
- protected abstract void CheckLabel(RoleMappedData examples);
+ protected abstract void CheckLabel(RoleMappedData examples, out int weightSetCount);
protected Float WDot(ref VBuffer features, ref VBuffer weights, Float bias)
{
@@ -165,13 +160,13 @@ public abstract class ArgumentsBase : LearnerInputBaseWithLabel
{
[Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularizer constant. By default the l2 constant is automatically inferred based on data set.", NullName = "", ShortName = "l2", SortOrder = 1)]
[TGUI(Label = "L2 Regularizer Constant", SuggestedSweeps = ",1e-7,1e-6,1e-5,1e-4,1e-3,1e-2")]
- [TlcModule.SweepableDiscreteParamAttribute("L2Const", new object[] { "", 1e-7f, 1e-6f, 1e-5f, 1e-4f, 1e-3f, 1e-2f })]
+ [TlcModule.SweepableDiscreteParam("L2Const", new object[] { "", 1e-7f, 1e-6f, 1e-5f, 1e-4f, 1e-3f, 1e-2f })]
public Float? L2Const;
// REVIEW: make the default positive when we know how to consume a sparse model
[Argument(ArgumentType.AtMostOnce, HelpText = "L1 soft threshold (L1/L2). Note that it is easier to control and sweep using the threshold parameter than the raw L1-regularizer constant. By default the l1 threshold is automatically inferred based on data set.", NullName = "", ShortName = "l1", SortOrder = 2)]
[TGUI(Label = "L1 Soft Threshold", SuggestedSweeps = ",0,0.25,0.5,0.75,1")]
- [TlcModule.SweepableDiscreteParamAttribute("L1Threshold", new object[] { "", 0f, 0.25f, 0.5f, 0.75f, 1f })]
+ [TlcModule.SweepableDiscreteParam("L1Threshold", new object[] { "", 0f, 0.25f, 0.5f, 0.75f, 1f })]
public Float? L1Threshold;
[Argument(ArgumentType.AtMostOnce, HelpText = "Degree of lock-free parallelism. Defaults to automatic. Determinism not guaranteed.", NullName = "", ShortName = "nt,t,threads", SortOrder = 50)]
@@ -180,16 +175,16 @@ public abstract class ArgumentsBase : LearnerInputBaseWithLabel
[Argument(ArgumentType.AtMostOnce, HelpText = "The tolerance for the ratio between duality gap and primal loss for convergence checking.", ShortName = "tol")]
[TGUI(SuggestedSweeps = "0.001, 0.01, 0.1, 0.2")]
- [TlcModule.SweepableDiscreteParamAttribute("ConvergenceTolerance", new object[] { 0.001f, 0.01f, 0.1f, 0.2f })]
+ [TlcModule.SweepableDiscreteParam("ConvergenceTolerance", new object[] { 0.001f, 0.01f, 0.1f, 0.2f })]
public Float ConvergenceTolerance = 0.1f;
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of iterations; set to 1 to simulate online learning. Defaults to automatic.", NullName = "", ShortName = "iter")]
[TGUI(Label = "Max number of iterations", SuggestedSweeps = ",10,20,100")]
- [TlcModule.SweepableDiscreteParamAttribute("MaxIterations", new object[] { "", 10, 20, 100 })]
+ [TlcModule.SweepableDiscreteParam("MaxIterations", new object[] { "", 10, 20, 100 })]
public int? MaxIterations;
[Argument(ArgumentType.AtMostOnce, HelpText = "Shuffle data every epoch?", ShortName = "shuf")]
- [TlcModule.SweepableDiscreteParamAttribute("Shuffle", null, isBool: true)]
+ [TlcModule.SweepableDiscreteParam("Shuffle", null, isBool: true)]
public bool Shuffle = true;
[Argument(ArgumentType.AtMostOnce, HelpText = "Convergence check frequency (in terms of number of iterations). Set as negative or zero for not checking at all. If left blank, it defaults to check after every 'numThreads' iterations.", NullName = "", ShortName = "checkFreq")]
@@ -197,7 +192,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithLabel
[Argument(ArgumentType.AtMostOnce, HelpText = "The learning rate for adjusting bias from being regularized.", ShortName = "blr")]
[TGUI(SuggestedSweeps = "0, 0.01, 0.1, 1")]
- [TlcModule.SweepableDiscreteParamAttribute("BiasLearningRate", new object[] { 0.0f, 0.01f, 0.1f, 1f })]
+ [TlcModule.SweepableDiscreteParam("BiasLearningRate", new object[] { 0.0f, 0.01f, 0.1f, 1f })]
public Float BiasLearningRate = 0;
internal virtual void Check(IHostEnvironment env)
@@ -217,6 +212,7 @@ internal virtual void Check(IHostEnvironment env)
"could drastically slow down the convergence. So using l2Const = {1} instead.", L2Const);
L2Const = L2LowerBound;
+ ch.Done();
}
}
}
@@ -243,12 +239,7 @@ protected enum MetricKind
private readonly ArgumentsBase _args;
protected ISupportSdcaLoss Loss;
- public override bool NeedNormalization
- {
- get { return true; }
- }
-
- protected override bool ShuffleData { get { return _args.Shuffle; } }
+ protected override bool ShuffleData => _args.Shuffle;
protected SdcaTrainerBase(ArgumentsBase args, IHostEnvironment env, string name)
: base(env, name)
@@ -257,13 +248,13 @@ protected SdcaTrainerBase(ArgumentsBase args, IHostEnvironment env, string name)
_args.Check(env);
}
- protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredictor predictor)
+ protected sealed override TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearPredictor predictor, int weightSetCount)
{
Contracts.Assert(predictor == null, "SDCA based trainers don't support continuous training.");
- Contracts.Assert(NumFeatures > 0, "Number of features must be assigned prior to passing into TrainCore.");
- int weightArraySize = WeightArraySize;
- Contracts.Assert(weightArraySize >= 1);
- long maxTrainingExamples = MaxDualTableSize / weightArraySize;
+ Contracts.Assert(weightSetCount >= 1);
+
+ int numFeatures = data.Schema.Feature.Type.VectorSize;
+ long maxTrainingExamples = MaxDualTableSize / weightSetCount;
var cursorFactory = new FloatLabelCursor.Factory(data, CursOpt.Label | CursOpt.Features | CursOpt.Weight | CursOpt.Id);
int numThreads;
if (_args.NumThreads.HasValue)
@@ -301,8 +292,7 @@ protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredic
ch.Assert(checkFrequency > 0);
var pOptions = new ParallelOptions { MaxDegreeOfParallelism = numThreads };
- var converged = false;
- var watch = new Stopwatch();
+ bool converged = false;
// Getting the total count of rows in data. Ignore rows with bad label and feature values.
long count = 0;
@@ -398,24 +388,24 @@ protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredic
Contracts.Assert(_args.L2Const.HasValue);
if (_args.L1Threshold == null)
- _args.L1Threshold = TuneDefaultL1(ch, NumFeatures);
+ _args.L1Threshold = TuneDefaultL1(ch, numFeatures);
ch.Assert(_args.L1Threshold.HasValue);
var l1Threshold = _args.L1Threshold.Value;
var l1ThresholdZero = l1Threshold == 0;
- VBuffer[] weights = new VBuffer[weightArraySize];
- VBuffer[] bestWeights = new VBuffer[weightArraySize];
- VBuffer[] l1IntermediateWeights = l1ThresholdZero ? null : new VBuffer[weightArraySize];
- Float[] biasReg = new Float[weightArraySize];
- Float[] bestBiasReg = new Float[weightArraySize];
- Float[] biasUnreg = new Float[weightArraySize];
- Float[] bestBiasUnreg = new Float[weightArraySize];
- Float[] l1IntermediateBias = l1ThresholdZero ? null : new Float[weightArraySize];
-
- for (int i = 0; i < weightArraySize; i++)
+ var weights = new VBuffer[weightSetCount];
+ var bestWeights = new VBuffer[weightSetCount];
+ var l1IntermediateWeights = l1ThresholdZero ? null : new VBuffer[weightSetCount];
+ var biasReg = new Float[weightSetCount];
+ var bestBiasReg = new Float[weightSetCount];
+ var biasUnreg = new Float[weightSetCount];
+ var bestBiasUnreg = new Float[weightSetCount];
+ var l1IntermediateBias = l1ThresholdZero ? null : new Float[weightSetCount];
+
+ for (int i = 0; i < weightSetCount; i++)
{
- weights[i] = VBufferUtils.CreateDense(NumFeatures);
- bestWeights[i] = VBufferUtils.CreateDense(NumFeatures);
+ weights[i] = VBufferUtils.CreateDense(numFeatures);
+ bestWeights[i] = VBufferUtils.CreateDense(numFeatures);
biasReg[i] = 0;
bestBiasReg[i] = 0;
biasUnreg[i] = 0;
@@ -423,7 +413,7 @@ protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredic
if (!l1ThresholdZero)
{
- l1IntermediateWeights[i] = VBufferUtils.CreateDense(NumFeatures);
+ l1IntermediateWeights[i] = VBufferUtils.CreateDense(numFeatures);
l1IntermediateBias[i] = 0;
}
}
@@ -441,7 +431,7 @@ protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredic
if (idToIdx == null)
{
Contracts.Assert(!needLookup);
- long dualsLength = ((long)idLoMax + 1) * WeightArraySize;
+ long dualsLength = ((long)idLoMax + 1) * weightSetCount;
if (dualsLength <= Utils.ArrayMaxSize)
{
// The dual variables fit into a standard float[].
@@ -465,7 +455,7 @@ protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredic
{
// Similar logic as above when using the id-to-index lookup.
Contracts.Assert(needLookup);
- long dualsLength = count * WeightArraySize;
+ long dualsLength = count * weightSetCount;
if (dualsLength <= Utils.ArrayMaxSize)
{
duals = new StandardArrayDualsTable((int)dualsLength);
@@ -497,8 +487,6 @@ protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredic
ch.Assert(_args.MaxIterations.HasValue);
var maxIterations = _args.MaxIterations.Value;
- watch.Start();
-
var rands = new IRandom[maxIterations];
for (int i = 0; i < maxIterations; i++)
rands[i] = RandomUtils.Create(Host.Rand.Next());
@@ -506,9 +494,9 @@ protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredic
// If we favor storing the invariants, precompute the invariants now.
if (invariants != null)
{
- Contracts.Assert((idToIdx == null & ((long)idLoMax + 1) * WeightArraySize <= Utils.ArrayMaxSize) | (idToIdx != null & count * WeightArraySize <= Utils.ArrayMaxSize));
- Func getIndexFromIdAndRow = GetIndexFromIdAndRowGetter(idToIdx);
- int invariantCoeff = WeightArraySize == 1 ? 1 : 2;
+ Contracts.Assert((idToIdx == null & ((long)idLoMax + 1) * weightSetCount <= Utils.ArrayMaxSize) | (idToIdx != null & count * weightSetCount <= Utils.ArrayMaxSize));
+ Func getIndexFromIdAndRow = GetIndexFromIdAndRowGetter(idToIdx, biasReg.Length);
+ int invariantCoeff = weightSetCount == 1 ? 1 : 2;
using (var cursor = cursorFactory.Create())
using (var pch = Host.StartProgressChannel("SDCA invariants initialization"))
{
@@ -599,23 +587,25 @@ protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredic
}
}
- Bias = new Float[weightArraySize];
+ var bias = new Float[weightSetCount];
if (bestIter > 0)
{
ch.Info("Using best model from iteration {0}.", bestIter);
- Weights = bestWeights;
- for (int i = 0; i < weightArraySize; i++)
- Bias[i] = bestBiasReg[i] + bestBiasUnreg[i];
+ weights = bestWeights;
+ for (int i = 0; i < weightSetCount; i++)
+ bias[i] = bestBiasReg[i] + bestBiasUnreg[i];
}
else
{
ch.Info("Using model from last iteration.");
- Weights = weights;
- for (int i = 0; i < weightArraySize; i++)
- Bias[i] = biasReg[i] + biasUnreg[i];
+ for (int i = 0; i < weightSetCount; i++)
+ bias[i] = biasReg[i] + biasUnreg[i];
}
+ return CreatePredictor(weights, bias);
}
+ protected abstract TPredictor CreatePredictor(VBuffer[] weights, Float[] bias);
+
// Assign an upper bound for number of iterations based on data set size first.
// This ensures SDCA will not run forever...
// Based on empirical estimation of max iterations needed.
@@ -746,7 +736,7 @@ protected virtual void TrainWithoutLock(IProgressChannelProvider progress, Float
if (pch != null)
pch.SetHeader(new ProgressHeader("examples"), e => e.SetProgress(0, rowCount));
- Func getIndexFromId = GetIndexFromIdGetter(idToIdx);
+ Func getIndexFromId = GetIndexFromIdGetter(idToIdx, biasReg.Length);
while (cursor.MoveNext())
{
long idx = getIndexFromId(cursor.Id);
@@ -901,7 +891,7 @@ protected virtual bool CheckConvergence(
using (var cursor = cursorFactory.Create())
{
long row = 0;
- Func getIndexFromIdAndRow = GetIndexFromIdAndRowGetter(idToIdx);
+ Func getIndexFromIdAndRow = GetIndexFromIdAndRowGetter(idToIdx, biasReg.Length);
// Iterates through data to compute loss function.
while (cursor.MoveNext())
{
@@ -994,8 +984,8 @@ public StandardArrayDualsTable(int length)
public override Float this[long index]
{
- get { return _duals[(int)index]; }
- set { _duals[(int)index] = value; }
+ get => _duals[(int)index];
+ set => _duals[(int)index] = value;
}
public override void ApplyAt(long index, Visitor manip)
@@ -1011,7 +1001,7 @@ private sealed class BigArrayDualsTable : DualsTableBase
{
private BigArray _duals;
- public override long Length { get { return _duals.Length; } }
+ public override long Length => _duals.Length;
public BigArrayDualsTable(long length)
{
@@ -1021,14 +1011,8 @@ public BigArrayDualsTable(long length)
public override Float this[long index]
{
- get
- {
- return _duals[index];
- }
- set
- {
- _duals[index] = value;
- }
+ get => _duals[index];
+ set => _duals[index] = value;
}
public override void ApplyAt(long index, Visitor manip)
@@ -1042,10 +1026,10 @@ public override void ApplyAt(long index, Visitor manip)
/// Returns a function delegate to retrieve index from id.
/// This is to avoid redundant conditional branches in the tight loop of training.
///
- protected Func GetIndexFromIdGetter(IdToIdxLookup idToIdx)
+ protected Func GetIndexFromIdGetter(IdToIdxLookup idToIdx, int biasLength)
{
Contracts.AssertValueOrNull(idToIdx);
- long maxTrainingExamples = MaxDualTableSize / WeightArraySize;
+ long maxTrainingExamples = MaxDualTableSize / biasLength;
if (idToIdx == null)
{
return (UInt128 id) =>
@@ -1073,10 +1057,10 @@ protected Func GetIndexFromIdGetter(IdToIdxLookup idToIdx)
/// Only works if the cursor is not shuffled.
/// This is to avoid redundant conditional branches in the tight loop of training.
///
- protected Func GetIndexFromIdAndRowGetter(IdToIdxLookup idToIdx)
+ protected Func GetIndexFromIdAndRowGetter(IdToIdxLookup idToIdx, int biasLength)
{
Contracts.AssertValueOrNull(idToIdx);
- long maxTrainingExamples = MaxDualTableSize / WeightArraySize;
+ long maxTrainingExamples = MaxDualTableSize / biasLength;
if (idToIdx == null)
{
return (UInt128 id, long row) =>
@@ -1115,7 +1099,7 @@ protected Func GetIndexFromIdAndRowGetter(IdToIdxLookup idT
/// the table growing operation initializes a new larger bucket and rehash the existing entries to
/// the new bucket. Such operation has an expected complexity proportional to the size.
///
- protected internal sealed class IdToIdxLookup
+ protected sealed class IdToIdxLookup
{
// Utilizing this struct gives better cache behavior than using parallel arrays.
private struct Entry
@@ -1142,7 +1126,7 @@ public Entry(long itNext, UInt128 value)
///
/// Gets the count of id entries.
///
- public long Count { get { return _count; } }
+ public long Count => _count;
///
/// Initializes an instance of the class with the specified size.
@@ -1368,7 +1352,7 @@ public void Add(Double summand)
}
}
- public sealed class LinearClassificationTrainer : SdcaTrainerBase, ITrainer, ITrainerEx
+ public sealed class LinearClassificationTrainer : SdcaTrainerBase
{
public const string LoadNameValue = "SDCA";
public const string UserNameValue = "Fast Linear (SA-SDCA)";
@@ -1402,57 +1386,48 @@ internal override void Check(IHostEnvironment env)
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
- public override bool NeedCalibration => !(_loss is LogLoss);
-
- protected override int WeightArraySize => 1;
+ public override TrainerInfo Info { get; }
public LinearClassificationTrainer(IHostEnvironment env, Arguments args)
: base(args, env, LoadNameValue)
{
_loss = args.LossFunction.CreateComponent(env);
base.Loss = _loss;
+ Info = new TrainerInfo(calibration: !(_loss is LogLoss));
NeedShuffle = args.Shuffle;
_args = args;
_positiveInstanceWeight = _args.PositiveInstanceWeight;
}
- public override IPredictor CreatePredictor()
+ protected override TScalarPredictor CreatePredictor(VBuffer[] weights, Float[] bias)
{
- Contracts.Assert(WeightArraySize == 1);
- Contracts.Assert(Utils.Size(Weights) == 1);
- Contracts.Assert(Utils.Size(Bias) == 1);
- Host.Check(Weights[0].Length > 0);
- VBuffer maybeSparseWeights = VBufferUtils.CreateEmpty(Weights[0].Length);
- VBufferUtils.CreateMaybeSparseCopy(ref Weights[0], ref maybeSparseWeights, Conversions.Instance.GetIsDefaultPredicate(NumberType.Float));
- var predictor = new LinearBinaryPredictor(Host, ref maybeSparseWeights, Bias[0]);
+ Host.CheckParam(Utils.Size(weights) == 1, nameof(weights));
+ Host.CheckParam(Utils.Size(bias) == 1, nameof(bias));
+ Host.CheckParam(weights[0].Length > 0, nameof(weights));
+
+ VBuffer maybeSparseWeights = default;
+ VBufferUtils.CreateMaybeSparseCopy(ref weights[0], ref maybeSparseWeights,
+ Conversions.Instance.GetIsDefaultPredicate(NumberType.Float));
+ var predictor = new LinearBinaryPredictor(Host, ref maybeSparseWeights, bias[0]);
if (!(_loss is LogLoss))
return predictor;
return new ParameterMixingCalibratedPredictor(Host, predictor, new PlattCalibrator(Host, -1, 0));
}
- TScalarPredictor ITrainer.CreatePredictor()
- {
- var predictor = CreatePredictor() as TScalarPredictor;
- Contracts.AssertValue(predictor);
- return predictor;
- }
-
protected override Float GetInstanceWeight(FloatLabelCursor cursor)
{
return cursor.Label > 0 ? cursor.Weight * _positiveInstanceWeight : cursor.Weight;
}
- protected override void CheckLabel(RoleMappedData examples)
+ protected override void CheckLabel(RoleMappedData examples, out int weightSetCount)
{
examples.CheckBinaryLabel();
+ weightSetCount = 1;
}
}
public sealed class StochasticGradientDescentClassificationTrainer :
- LinearTrainerBase,
- IIncrementalTrainer,
- ITrainer,
- ITrainerEx
+ LinearTrainerBase
{
public const string LoadNameValue = "BinarySGD";
public const string UserNameValue = "Hogwild SGD (binary)";
@@ -1465,7 +1440,7 @@ public sealed class Arguments : LearnerInputBaseWithWeight
[Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularizer constant", ShortName = "l2", SortOrder = 50)]
[TGUI(Label = "L2 Regularizer Constant", SuggestedSweeps = "1e-7,5e-7,1e-6,5e-6,1e-5")]
- [TlcModule.SweepableDiscreteParamAttribute("L2Const", new object[] { 1e-7f, 5e-7f, 1e-6f, 5e-6f, 1e-5f })]
+ [TlcModule.SweepableDiscreteParam("L2Const", new object[] { 1e-7f, 5e-7f, 1e-6f, 5e-6f, 1e-5f })]
public Float L2Const = (Float)1e-6;
[Argument(ArgumentType.AtMostOnce, HelpText = "Degree of lock-free parallelism. Defaults to automatic depending on data sparseness. Determinism not guaranteed.", ShortName = "nt,t,threads", SortOrder = 50)]
@@ -1474,12 +1449,12 @@ public sealed class Arguments : LearnerInputBaseWithWeight
[Argument(ArgumentType.AtMostOnce, HelpText = "Exponential moving averaged improvement tolerance for convergence", ShortName = "tol")]
[TGUI(SuggestedSweeps = "1e-2,1e-3,1e-4,1e-5")]
- [TlcModule.SweepableDiscreteParamAttribute("ConvergenceTolerance", new object[] { 1e-2f, 1e-3f, 1e-4f, 1e-5f })]
+ [TlcModule.SweepableDiscreteParam("ConvergenceTolerance", new object[] { 1e-2f, 1e-3f, 1e-4f, 1e-5f })]
public Double ConvergenceTolerance = 1e-4;
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of iterations; set to 1 to simulate online learning.", ShortName = "iter")]
[TGUI(Label = "Max number of iterations", SuggestedSweeps = "1,5,10,20")]
- [TlcModule.SweepableDiscreteParamAttribute("MaxIterations", new object[] { 1, 5, 10, 20 })]
+ [TlcModule.SweepableDiscreteParam("MaxIterations", new object[] { 1, 5, 10, 20 })]
public int MaxIterations = 20;
[Argument(ArgumentType.AtMostOnce, HelpText = "Initial learning rate (only used by SGD)", ShortName = "ilr,lr")]
@@ -1487,7 +1462,7 @@ public sealed class Arguments : LearnerInputBaseWithWeight
public Double InitLearningRate = 0.01;
[Argument(ArgumentType.AtMostOnce, HelpText = "Shuffle data every epoch?", ShortName = "shuf")]
- [TlcModule.SweepableDiscreteParamAttribute("Shuffle", null, isBool: true)]
+ [TlcModule.SweepableDiscreteParam("Shuffle", null, isBool: true)]
public bool Shuffle = true;
[Argument(ArgumentType.AtMostOnce, HelpText = "Apply weight to the positive class, for imbalanced data", ShortName = "piw")]
@@ -1502,15 +1477,23 @@ public sealed class Arguments : LearnerInputBaseWithWeight
[Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public int MaxCalibrationExamples = 1000000;
- public void Check(ITrainerHost host)
+ internal void Check(IHostEnvironment env)
{
- Contracts.CheckUserArg(L2Const >= 0, nameof(L2Const), "L2 constant must be non-negative.");
- Contracts.CheckUserArg(InitLearningRate > 0, nameof(InitLearningRate), "Initial learning rate must be positive.");
- Contracts.CheckUserArg(MaxIterations > 0, nameof(MaxIterations), "Max number of iterations must be positive.");
- Contracts.CheckUserArg(PositiveInstanceWeight > 0, nameof(PositiveInstanceWeight), "Weight for positive instances must be positive");
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckUserArg(L2Const >= 0, nameof(L2Const), "Must be non-negative.");
+ env.CheckUserArg(InitLearningRate > 0, nameof(InitLearningRate), "Must be positive.");
+ env.CheckUserArg(MaxIterations > 0, nameof(MaxIterations), "Must be positive.");
+ env.CheckUserArg(PositiveInstanceWeight > 0, nameof(PositiveInstanceWeight), "Must be positive");
if (InitLearningRate * L2Const >= 1)
- host.StdOut.WriteLine("Learning rate {0} set too high; reducing to {1}", InitLearningRate, InitLearningRate = (Float)0.5 / L2Const);
+ {
+ using (var ch = env.Start("Argument Adjustment"))
+ {
+ ch.Warning("{0} {1} set too high; reducing to {2}", nameof(InitLearningRate),
+ InitLearningRate, InitLearningRate = (Float)0.5 / L2Const);
+ ch.Done();
+ }
+ }
if (ConvergenceTolerance <= 0)
ConvergenceTolerance = Float.Epsilon;
@@ -1520,63 +1503,34 @@ public void Check(ITrainerHost host)
private readonly IClassificationLoss _loss;
private readonly Arguments _args;
- protected override bool ShuffleData { get { return _args.Shuffle; } }
-
- protected override int WeightArraySize { get { return 1; } }
+ protected override bool ShuffleData => _args.Shuffle;
- public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } }
+ public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
- public override bool NeedCalibration
- {
- get { return !(_loss is LogLoss); }
- }
+ public override TrainerInfo Info { get; }
public StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Arguments args)
: base(env, LoadNameValue)
{
+ args.Check(env);
_loss = args.LossFunction.CreateComponent(env);
+ Info = new TrainerInfo(calibration: !(_loss is LogLoss), supportIncrementalTrain: true);
NeedShuffle = args.Shuffle;
_args = args;
}
- public override IPredictor CreatePredictor()
- {
- Contracts.Assert(WeightArraySize == 1);
- Contracts.Assert(Utils.Size(Weights) == 1);
- Contracts.Assert(Utils.Size(Bias) == 1);
- Host.Check(Weights[0].Length > 0);
- VBuffer maybeSparseWeights = VBufferUtils.CreateEmpty(Weights[0].Length);
- VBufferUtils.CreateMaybeSparseCopy(ref Weights[0], ref maybeSparseWeights, Conversions.Instance.GetIsDefaultPredicate(NumberType.Float));
- var predictor = new LinearBinaryPredictor(Host, ref maybeSparseWeights, Bias[0]);
- if (!(_loss is LogLoss))
- return predictor;
- return new ParameterMixingCalibratedPredictor(Host, predictor, new PlattCalibrator(Host, -1, 0));
- }
-
- TScalarPredictor ITrainer.CreatePredictor()
- {
- var predictor = CreatePredictor() as TScalarPredictor;
- Contracts.AssertValue(predictor);
- return predictor;
- }
-
- public void Train(RoleMappedData data, IPredictor predictor)
- {
- Host.CheckValue(data, nameof(data));
- Host.CheckValue(predictor, nameof(predictor));
- LinearPredictor pred = (predictor as CalibratedPredictorBase)?.SubPredictor as LinearPredictor;
- pred = pred ?? predictor as LinearPredictor;
- Host.CheckParam(pred != null, nameof(predictor), "Not a linear predictor.");
- TrainEx(data, pred);
- }
-
//For complexity analysis, we assume that
// - The number of features is N
// - Average number of non-zero per instance is k
- protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredictor predictor)
+ protected override TScalarPredictor TrainCore(IChannel ch, RoleMappedData data, LinearPredictor predictor, int weightSetCount)
{
- ch.Assert(NumFeatures > 0, "Number of features must be assigned prior to passing into TrainCore.");
+ Contracts.AssertValue(data);
+ Contracts.Assert(weightSetCount == 1);
+ Contracts.AssertValueOrNull(predictor);
+
+ int numFeatures = data.Schema.Feature.Type.VectorSize;
var cursorFactory = new FloatLabelCursor.Factory(data, CursOpt.Label | CursOpt.Features | CursOpt.Weight);
+
int numThreads;
if (_args.NumThreads.HasValue)
{
@@ -1603,7 +1557,7 @@ protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredic
bias = predictor.Bias;
}
else
- weights = VBufferUtils.CreateDense(NumFeatures);
+ weights = VBufferUtils.CreateDense(numFeatures);
var weightsSync = new object();
double weightScaling = 1;
@@ -1742,15 +1696,18 @@ protected override void TrainCore(IChannel ch, RoleMappedData data, LinearPredic
VectorUtils.ScaleBy(ref weights, (Float)weightScaling); // restore the true weights
- Weights = new VBuffer[1];
- Bias = new Float[1];
- Weights[0] = weights;
- Bias[0] = bias;
+ VBuffer maybeSparseWeights = default;
+ VBufferUtils.CreateMaybeSparseCopy(ref weights, ref maybeSparseWeights, Conversions.Instance.GetIsDefaultPredicate(NumberType.Float));
+ var pred = new LinearBinaryPredictor(Host, ref maybeSparseWeights, bias);
+ if (!(_loss is LogLoss))
+ return pred;
+ return new ParameterMixingCalibratedPredictor(Host, pred, new PlattCalibrator(Host, -1, 0));
}
- protected override void CheckLabel(RoleMappedData examples)
+ protected override void CheckLabel(RoleMappedData examples, out int weightSetCount)
{
examples.CheckBinaryLabel();
+ weightSetCount = 1;
}
[TlcModule.EntryPoint(Name = "Trainers.StochasticGradientDescentBinaryClassifier", Desc = "Train an Hogwild SGD binary model.", UserName = UserNameValue, ShortName = ShortName)]
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs
index 89f4866228..537364a56a 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs
@@ -17,9 +17,7 @@
namespace Microsoft.ML.Runtime.Learners
{
- public abstract class LbfgsTrainerBase :
- TrainerBase,
- IIncrementalTrainer
+ public abstract class LbfgsTrainerBase : TrainerBase
where TPredictor : class, IPredictorProducing
{
public abstract class ArgumentsBase : LearnerInputBaseWithWeight
@@ -134,6 +132,11 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight
private VBuffer[] _localGradients;
private Float[] _localLosses;
+ // REVIEW: It's pointless to request caching when we're going to load everything into
+ // memory, that is, when using multiple threads. So should caching not be requested?
+ private static readonly TrainerInfo _info = new TrainerInfo(caching: true, supportIncrementalTrain: true);
+ public override TrainerInfo Info => _info;
+
internal LbfgsTrainerBase(ArgumentsBase args, IHostEnvironment env, string name, bool showTrainingStats = false)
: base(env, name)
{
@@ -172,16 +175,9 @@ internal LbfgsTrainerBase(ArgumentsBase args, IHostEnvironment env, string name,
}
}
- public override bool NeedNormalization => true;
-
- // REVIEW: It's pointless to request caching when we're going to load everything into
- // memory, that is, when using multiple threads.
- public override bool WantCaching => true;
-
protected virtual int ClassCount => 1;
protected int BiasCount => ClassCount;
protected int WeightCount => ClassCount * NumFeatures;
-
protected virtual Optimizer InitializeOptimizer(IChannel ch, FloatLabelCursor.Factory cursorFactory,
out VBuffer init, out ITerminationCriterion terminationCriterion)
{
@@ -289,28 +285,23 @@ protected virtual VBuffer InitializeWeightsSgd(IChannel ch, FloatLabelCur
protected abstract VBuffer InitializeWeightsFromPredictor(TPredictor srcPredictor);
- public void Train(RoleMappedData data, TPredictor predictor)
- {
- Contracts.CheckValue(data, nameof(data));
- Contracts.CheckValue(predictor, nameof(predictor));
-
- _srcPredictor = predictor;
- Train(data);
- }
-
protected abstract void CheckLabel(RoleMappedData data);
protected virtual void PreTrainingProcessInstance(Float label, ref VBuffer feat, Float weight)
{
}
+ protected abstract TPredictor CreatePredictor();
+
///
/// The basic training calls the optimizer
///
- public override void Train(RoleMappedData data)
+ public override TPredictor Train(TrainContext context)
{
- Contracts.CheckValue(data, nameof(data));
+ Contracts.CheckValue(context, nameof(context));
+ var data = context.TrainingSet;
+ _srcPredictor = context.InitialPredictor as TPredictor;
data.CheckFeatureFloatVector(out NumFeatures);
CheckLabel(data);
data.CheckOptFloatWeight();
@@ -318,16 +309,18 @@ public override void Train(RoleMappedData data)
if (NumFeatures >= Utils.ArrayMaxSize / ClassCount)
{
throw Contracts.ExceptParam(nameof(data),
- String.Format("The number of model parameters which is equal to ('# of features' + 1) * '# of classes' should be less than or equal to {0}.", Utils.ArrayMaxSize));
+ "The number of model parameters which is equal to ('# of features' + 1) * '# of classes' should be less than or equal to {0}.", Utils.ArrayMaxSize);
}
using (var ch = Host.Start("Training"))
{
TrainCore(ch, data);
+ var pred = CreatePredictor();
ch.Done();
+ return pred;
}
}
-
+
private void TrainCore(IChannel ch, RoleMappedData data)
{
Host.AssertValue(ch);
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs
index 3cf97ea801..f720b960a4 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs
@@ -55,9 +55,7 @@ public LogisticRegression(IHostEnvironment env, Arguments args)
_posWeight = 0;
}
- public override bool NeedCalibration { get { return false; } }
-
- public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } }
+ public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
protected override void CheckLabel(RoleMappedData data)
{
@@ -373,7 +371,7 @@ protected override VBuffer InitializeWeightsFromPredictor(ParameterMixing
return InitializeWeights(pred.Weights2, new[] { pred.Bias });
}
- public override ParameterMixingCalibratedPredictor CreatePredictor()
+ protected override ParameterMixingCalibratedPredictor CreatePredictor()
{
// Logistic regression is naturally calibrated to
// output probabilities when transformed using
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
index 66c1d41084..5f7712843f 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
@@ -67,16 +67,14 @@ public sealed class Arguments : ArgumentsBase
private LinearModelStatistics _stats;
- protected override int ClassCount { get { return _numClasses; } }
+ protected override int ClassCount => _numClasses;
public MulticlassLogisticRegression(IHostEnvironment env, Arguments args)
: base(args, env, LoadNameValue, Contracts.CheckRef(args, nameof(args)).ShowTrainingStats)
{
}
- public override bool NeedCalibration { get { return false; } }
-
- public override PredictionKind PredictionKind { get { return PredictionKind.MultiClassClassification; } }
+ public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification;
protected override void CheckLabel(RoleMappedData data)
{
@@ -203,7 +201,7 @@ protected override VBuffer InitializeWeightsFromPredictor(MulticlassLogis
return InitializeWeights(srcPredictor.DenseWeightsEnumerable(), srcPredictor.BiasesEnumerable());
}
- public override MulticlassLogisticRegressionPredictor CreatePredictor()
+ protected override MulticlassLogisticRegressionPredictor CreatePredictor()
{
if (_numClasses < 1)
throw Contracts.Except("Cannot create a multiclass predictor with {0} classes", _numClasses);
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs
index 9a3552f74b..52cd025370 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs
@@ -13,9 +13,9 @@
namespace Microsoft.ML.Runtime.Learners
{
- using TScalarTrainer = ITrainer>;
+ using TScalarTrainer = ITrainer<IPredictorProducing<Float>>;
- public abstract class MetaMulticlassTrainer : TrainerBase
+ public abstract class MetaMulticlassTrainer : TrainerBase
where TPred : IPredictor
where TArgs : MetaMulticlassTrainer.ArgumentsBase
{
@@ -38,15 +38,9 @@ public abstract class ArgumentsBase
protected readonly TArgs Args;
private TScalarTrainer _trainer;
- private TPred _pred;
public sealed override PredictionKind PredictionKind => PredictionKind.MultiClassClassification;
- public sealed override bool NeedNormalization { get; }
- public sealed override bool NeedCalibration => false;
-
- // No matter what the internal predictor, we're performing many passes
- // simply by virtue of this being a meta-trainer.
- public sealed override bool WantCaching => true;
+ public override TrainerInfo Info { get; }
internal MetaMulticlassTrainer(IHostEnvironment env, TArgs args, string name)
: base(env, name)
@@ -56,8 +50,9 @@ internal MetaMulticlassTrainer(IHostEnvironment env, TArgs args, string name)
Host.CheckUserArg(Args.PredictorType.IsGood(), nameof(Args.PredictorType));
// Create the first trainer so errors in the args surface early.
_trainer = Args.PredictorType.CreateInstance(Host);
- var ex = _trainer as ITrainerEx;
- NeedNormalization = ex != null && ex.NeedNormalization;
+ // Regarding caching, no matter what the internal predictor, we're performing many passes
+ // simply by virtue of this being a meta-trainer, so we will still cache.
+ Info = new TrainerInfo(normalization: _trainer.Info.NeedNormalization);
}
protected IDataView MapLabelsCore(ColumnType type, RefPredicate equalsTarget, RoleMappedData data, string dstName)
@@ -96,9 +91,11 @@ protected TScalarTrainer GetTrainer()
protected abstract TPred TrainCore(IChannel ch, RoleMappedData data, int count);
- public override void Train(RoleMappedData data)
+ public override TPred Train(TrainContext context)
{
- Host.CheckValue(data, nameof(data));
+ Host.CheckValue(context, nameof(context));
+ var data = context.TrainingSet;
+
data.CheckFeatureFloatVector();
int count;
@@ -107,16 +104,11 @@ public override void Train(RoleMappedData data)
using (var ch = Host.Start("Training"))
{
- _pred = TrainCore(ch, data, count);
- ch.Check(_pred != null, "Training did not result in a predictor");
+ var pred = TrainCore(ch, data, count);
+ ch.Check(pred != null, "Training did not result in a predictor");
ch.Done();
+ return pred;
}
}
-
- public override TPred CreatePredictor()
- {
- Host.Check(_pred != null, nameof(CreatePredictor) + " called before " + nameof(Train));
- return _pred;
- }
}
}
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
index 94efc8cf05..94bac1d9ec 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
@@ -26,7 +26,7 @@
namespace Microsoft.ML.Runtime.Learners
{
- public sealed class MultiClassNaiveBayesTrainer : TrainerBase
+ public sealed class MultiClassNaiveBayesTrainer : TrainerBase
{
public const string LoadName = "MultiClassNaiveBayes";
internal const string UserName = "Multiclass Naive Bayes";
@@ -43,24 +43,21 @@ public sealed class Arguments : LearnerInputBaseWithLabel
{
}
- private MultiClassNaiveBayesPredictor _predictor;
-
public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification;
- public override bool NeedNormalization => false;
-
- public override bool NeedCalibration => false;
-
- public override bool WantCaching => false;
+ private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false);
+ public override TrainerInfo Info => _info;
public MultiClassNaiveBayesTrainer(IHostEnvironment env, Arguments args)
: base(env, LoadName)
{
+ Host.CheckValue(args, nameof(args));
}
- public override void Train(RoleMappedData data)
+ public override MultiClassNaiveBayesPredictor Train(TrainContext context)
{
- Host.CheckValue(data, nameof(data));
+ Host.CheckValue(context, nameof(context));
+ var data = context.TrainingSet;
Host.Check(data.Schema.Label != null, "Missing Label column");
Host.Check(data.Schema.Label.Type == NumberType.Float || data.Schema.Label.Type is KeyType,
"Invalid type for Label column, only floats and known-size keys are supported");
@@ -89,6 +86,7 @@ public override void Train(RoleMappedData data)
if (cursor.Row.Position > int.MaxValue)
{
ch.Warning("Stopping training because maximum number of rows have been traversed");
break;
}
@@ -118,16 +116,12 @@ public override void Train(RoleMappedData data)
examplesProcessed += 1;
}
+ ch.Done();
}
Array.Resize(ref labelHistogram, labelCount);
Array.Resize(ref featureHistogram, labelCount);
- _predictor = new MultiClassNaiveBayesPredictor(Host, labelHistogram, featureHistogram, featureCount);
- }
-
- public override MultiClassNaiveBayesPredictor CreatePredictor()
- {
- return _predictor;
+ return new MultiClassNaiveBayesPredictor(Host, labelHistogram, featureHistogram, featureCount);
}
[TlcModule.EntryPoint(Name = "Trainers.NaiveBayesClassifier",
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
index 7b5fcc8a93..62e3a79631 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
@@ -34,7 +34,7 @@ namespace Microsoft.ML.Runtime.Learners
{
using CR = RoleMappedSchema.ColumnRole;
using TScalarPredictor = IPredictorProducing;
- using TScalarTrainer = ITrainer>;
+ using TScalarTrainer = ITrainer<IPredictorProducing<Float>>;
public sealed class Ova : MetaMulticlassTrainer
{
@@ -81,9 +81,10 @@ private TScalarPredictor TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappe
.Prepend(CR.Label.Bind(dstName));
var td = new RoleMappedData(view, roles);
- trainer.Train(td);
+ // REVIEW: In principle we could support validation sets and the like via the train context, but
+ // this is currently unsupported.
+ var predictor = trainer.Train(td);
- var predictor = trainer.CreatePredictor();
if (Args.UseProbabilities)
{
ICalibratorTrainer calibrator;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs
index cf1e7c062b..1c4700ccdd 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs
@@ -26,7 +26,7 @@
namespace Microsoft.ML.Runtime.Learners
{
- using TScalarTrainer = ITrainer>;
+ using TScalarTrainer = ITrainer<IPredictorProducing<Float>>;
using TScalarPredictor = IPredictorProducing;
using TDistPredictor = IDistPredictorProducing;
using CR = RoleMappedSchema.ColumnRole;
@@ -78,14 +78,13 @@ private TDistPredictor TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedD
.Prepend(CR.Label.Bind(dstName));
var td = new RoleMappedData(view, roles);
- trainer.Train(td);
+ var predictor = trainer.Train(td);
ICalibratorTrainer calibrator;
if (!Args.Calibrator.IsGood())
calibrator = null;
else
calibrator = Args.Calibrator.CreateInstance(Host);
- TScalarPredictor predictor = trainer.CreatePredictor();
var res = CalibratorUtils.TrainCalibratorIfNeeded(Host, ch, calibrator, Args.MaxCalibrationExamples,
trainer, predictor, td);
var dist = res as TDistPredictor;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs
index db271ff858..7f47271f68 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs
@@ -30,7 +30,7 @@
namespace Microsoft.ML.Runtime.Learners
{
- public sealed class OlsLinearRegressionTrainer : TrainerBase
+ public sealed class OlsLinearRegressionTrainer : TrainerBase
{
public sealed class Arguments : LearnerInputBaseWithWeight
{
@@ -57,20 +57,15 @@ It assumes that the conditional mean of the dependent variable follows a linear
By minimizing the squares of the difference between observed values and the predictions, the parameters of the regressor can be estimated.
";
- private VBuffer _weights;
- private Float _bias;
-
- // These have length equal to the number of model parameters, i.e., one for bias plus length of weights.
- private Double[] _standardErrors;
- private Double[] _tValues;
- private Double[] _pValues;
-
- private Double _rSquared;
- private Double _rSquaredAdjusted;
-
private readonly Float _l2Weight;
private readonly bool _perParameterSignificance;
+ public override PredictionKind PredictionKind => PredictionKind.Regression;
+
+ // The training performs two passes, only. Probably not worth caching.
+ private static readonly TrainerInfo _info = new TrainerInfo(caching: false);
+ public override TrainerInfo Info => _info;
+
public OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args)
: base(env, LoadNameValue)
{
@@ -80,25 +75,6 @@ public OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args)
_perParameterSignificance = args.PerParameterSignificance;
}
- public override bool NeedNormalization
- {
- get { return true; }
- }
-
- public override bool NeedCalibration
- {
- get { return false; }
- }
-
- public override bool WantCaching
- {
- // Two passes, only. Probably not worth caching.
- get { return false; }
- }
-
- public override PredictionKind PredictionKind
- { get { return PredictionKind.Regression; } }
-
///
/// In several calculations, we calculate probabilities or other quantities that should range
/// from 0 to 1, but because of numerical imprecision may, in entirely innocent circumstances,
@@ -107,15 +83,14 @@ public override PredictionKind PredictionKind
/// The quantity that should be clamped from 0 to 1
/// Either p, or 0 or 1 if it was outside the range 0 to 1
private static Double ProbClamp(Double p)
- {
- return Math.Max(0, Math.Min(p, 1));
- }
+ => Math.Max(0, Math.Min(p, 1));
- public override void Train(RoleMappedData examples)
+ public override OlsLinearRegressionPredictor Train(TrainContext context)
{
using (var ch = Host.Start("Training"))
{
- ch.CheckValue(examples, nameof(examples));
+ ch.CheckValue(context, nameof(context));
+ var examples = context.TrainingSet;
ch.CheckParam(examples.Schema.Feature != null, nameof(examples), "Need a feature column");
ch.CheckParam(examples.Schema.Label != null, nameof(examples), "Need a label column");
@@ -133,12 +108,13 @@ public override void Train(RoleMappedData examples)
var cursorFactory = new FloatLabelCursor.Factory(examples, CursOpt.Label | CursOpt.Features);
- TrainCore(ch, cursorFactory, typeFeat.VectorSize);
+ var pred = TrainCore(ch, cursorFactory, typeFeat.VectorSize);
ch.Done();
+ return pred;
}
}
- private void TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount)
+ private OlsLinearRegressionPredictor TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount)
{
Host.AssertValue(ch);
ch.AssertValue(cursorFactory);
@@ -262,26 +238,21 @@ private void TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int
var weights = VBufferUtils.CreateDense(beta.Length - 1);
for (int i = 1; i < beta.Length; ++i)
weights.Values[i - 1] = (Float)beta[i];
- _weights = weights;
- _bias = (Float)beta[0];
- _standardErrors = _tValues = _pValues = null;
+ var bias = (Float)beta[0];
if (!(_l2Weight > 0) && m == n)
{
// We would expect the solution to the problem to be exact in this case.
- _rSquared = 1;
- _rSquaredAdjusted = Float.NaN;
ch.Info("Number of examples equals number of parameters, solution is exact but no statistics can be derived");
- ch.Done();
- return;
+ return new OlsLinearRegressionPredictor(Host, ref weights, bias, null, null, null, 1, Float.NaN);
}
Double rss = 0; // residual sum of squares
Double tss = 0; // total sum of squares
using (var cursor = cursorFactory.Create())
{
- var lrPredictor = new LinearRegressionPredictor(Host, ref _weights, _bias);
+ var lrPredictor = new LinearRegressionPredictor(Host, ref weights, bias);
var lrMap = lrPredictor.GetMapper, Float>();
- Float yh = default(Float);
+ Float yh = default;
while (cursor.MoveNext())
{
var features = cursor.Features;
@@ -292,27 +263,28 @@ private void TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int
tss += ydm * ydm;
}
}
- _rSquared = ProbClamp(1 - (rss / tss));
+ var rSquared = ProbClamp(1 - (rss / tss));
// R^2 adjusted differs from the normal formula on account of the bias term, by Said's reckoning.
+ double rSquaredAdjusted;
if (n > m)
{
- _rSquaredAdjusted = ProbClamp(1 - (1 - _rSquared) * (n - 1) / (n - m));
+ rSquaredAdjusted = ProbClamp(1 - (1 - rSquared) * (n - 1) / (n - m));
ch.Info("Coefficient of determination R2 = {0:g}, or {1:g} (adjusted)",
- _rSquared, _rSquaredAdjusted);
+ rSquared, rSquaredAdjusted);
}
else
- _rSquaredAdjusted = Double.NaN;
+ rSquaredAdjusted = Double.NaN;
// The per parameter significance is compute intensive and may not be required for all practitioners.
// Also we can't estimate it, unless we can estimate the variance, which requires more examples than
// parameters.
if (!_perParameterSignificance || m >= n)
- return;
+ return new OlsLinearRegressionPredictor(Host, ref weights, bias, null, null, null, rSquared, rSquaredAdjusted);
- ch.Assert(!Double.IsNaN(_rSquaredAdjusted));
- _standardErrors = new Double[m];
- _tValues = new Double[m];
- _pValues = new Double[m];
+ ch.Assert(!Double.IsNaN(rSquaredAdjusted));
+ var standardErrors = new Double[m];
+ var tValues = new Double[m];
+ var pValues = new Double[m];
// Invert X'X:
Mkl.Pptri(Mkl.Layout.RowMajor, Mkl.UpLo.Lo, m, xtx);
var s2 = rss / (n - m); // estimate of variance of y
@@ -320,7 +292,7 @@ private void TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int
for (int i = 0; i < m; i++)
{
// Initialize with inverse Hessian.
- _standardErrors[i] = (Single)xtx[i * (i + 1) / 2 + i];
+ standardErrors[i] = (Single)xtx[i * (i + 1) / 2 + i];
}
if (_l2Weight > 0)
@@ -334,9 +306,9 @@ private void TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int
{
var entry = (Single)xtx[ioffset];
var adjustment = -reg * entry * entry;
- _standardErrors[iRow] -= adjustment;
+ standardErrors[iRow] -= adjustment;
if (0 < iCol && iCol < iRow)
- _standardErrors[iCol] -= adjustment;
+ standardErrors[iCol] -= adjustment;
ioffset++;
}
}
@@ -347,17 +319,14 @@ private void TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int
for (int i = 0; i < m; i++)
{
// sqrt of diagonal entries of s2 * inverse(X'X + reg * I) * X'X * inverse(X'X + reg * I).
- _standardErrors[i] = Math.Sqrt(s2 * _standardErrors[i]);
- ch.Check(FloatUtils.IsFinite(_standardErrors[i]), "Non-finite standard error detected from OLS solution");
- _tValues[i] = beta[i] / _standardErrors[i];
- _pValues[i] = (Float)MathUtils.TStatisticToPValue(_tValues[i], n - m);
- ch.Check(0 <= _pValues[i] && _pValues[i] <= 1, "p-Value calculated outside expected [0,1] range");
+ standardErrors[i] = Math.Sqrt(s2 * standardErrors[i]);
+ ch.Check(FloatUtils.IsFinite(standardErrors[i]), "Non-finite standard error detected from OLS solution");
+ tValues[i] = beta[i] / standardErrors[i];
+ pValues[i] = (Float)MathUtils.TStatisticToPValue(tValues[i], n - m);
+ ch.Check(0 <= pValues[i] && pValues[i] <= 1, "p-Value calculated outside expected [0,1] range");
}
- }
- public override OlsLinearRegressionPredictor CreatePredictor()
- {
- return new OlsLinearRegressionPredictor(Host, ref _weights, _bias, _standardErrors, _tValues, _pValues, _rSquared, _rSquaredAdjusted);
+ return new OlsLinearRegressionPredictor(Host, ref weights, bias, standardErrors, tValues, pValues, rSquared, rSquaredAdjusted);
}
internal static class Mkl
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs
index aa5ecb67a5..a2e09b6905 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs
@@ -49,17 +49,14 @@ public class Arguments : AveragedLinearArguments
public int MaxCalibrationExamples = 1000000;
}
+ protected override bool NeedCalibration => true;
+
public AveragedPerceptronTrainer(IHostEnvironment env, Arguments args)
: base(args, env, UserNameValue)
{
LossFunction = Args.LossFunction.CreateComponent(env);
}
- public override bool NeedCalibration
- {
- get { return true; }
- }
-
public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } }
protected override void CheckLabel(RoleMappedData data)
@@ -68,7 +65,7 @@ protected override void CheckLabel(RoleMappedData data)
data.CheckBinaryLabel();
}
- public override LinearBinaryPredictor CreatePredictor()
+ protected override LinearBinaryPredictor CreatePredictor()
{
Contracts.Assert(WeightsScale == 1);
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs
index d2b8f0b30f..d435539e95 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs
@@ -80,6 +80,8 @@ public sealed class Arguments : OnlineLinearArguments
private Float _weightsUpdateScale;
private Float _biasUpdate;
+ protected override bool NeedCalibration => true;
+
public LinearSvm(IHostEnvironment env, Arguments args)
: base(args, env, UserNameValue)
{
@@ -87,11 +89,6 @@ public LinearSvm(IHostEnvironment env, Arguments args)
Contracts.CheckUserArg(args.BatchSize > 0, nameof(args.BatchSize), UserErrorPositive);
}
- public override bool NeedCalibration
- {
- get { return true; }
- }
-
public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } }
protected override void CheckLabel(RoleMappedData data)
@@ -221,7 +218,7 @@ private void UpdateWeights(ref VBuffer weightsUpdate, Float weightsUpdate
}
}
- public override TPredictor CreatePredictor()
+ protected override TPredictor CreatePredictor()
{
Contracts.Assert(WeightsScale == 1);
return new LinearBinaryPredictor(Host, ref Weights, Bias);
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs
index f345466e19..5cbdc01478 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs
@@ -58,11 +58,6 @@ public OnlineGradientDescentTrainer(IHostEnvironment env, Arguments args)
LossFunction = args.LossFunction.CreateComponent(env);
}
- public override bool NeedCalibration
- {
- get { return false; }
- }
-
public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } }
protected override void CheckLabel(RoleMappedData data)
@@ -70,7 +65,7 @@ protected override void CheckLabel(RoleMappedData data)
data.CheckRegressionLabel();
}
- public override TPredictor CreatePredictor()
+ protected override TPredictor CreatePredictor()
{
Contracts.Assert(WeightsScale == 1);
VBuffer weights = default(VBuffer);
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs
index bcd4b33d58..60fe7f9705 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs
@@ -21,7 +21,7 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations", ShortName = "iter", SortOrder = 50)]
[TGUI(Label = "Number of Iterations", Description = "Number of training iterations through data", SuggestedSweeps = "1,10,100")]
- [TlcModule.SweepableLongParamAttribute("NumIterations", 1, 100, stepSize:10, isLogScale:true)]
+ [TlcModule.SweepableLongParamAttribute("NumIterations", 1, 100, stepSize: 10, isLogScale: true)]
public int NumIterations = 1;
[Argument(ArgumentType.AtMostOnce, HelpText = "Initial Weights and bias, comma-separated", ShortName = "initweights")]
@@ -34,16 +34,14 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel
public Float InitWtsDiameter = 0;
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to shuffle for each training iteration", ShortName = "shuf")]
- [TlcModule.SweepableDiscreteParamAttribute("Shuffle", new object[] {false, true})]
+ [TlcModule.SweepableDiscreteParamAttribute("Shuffle", new object[] { false, true })]
public bool Shuffle = true;
[Argument(ArgumentType.AtMostOnce, HelpText = "Size of cache when trained in Scope", ShortName = "cache")]
public int StreamingCacheSize = 1000000;
}
- public abstract class OnlineLinearTrainer :
- TrainerBase,
- IIncrementalTrainer
+ public abstract class OnlineLinearTrainer : TrainerBase
where TArguments : OnlineLinearArguments
where TPredictor : IPredictorProducing
{
@@ -72,6 +70,10 @@ public abstract class OnlineLinearTrainer :
protected const string UserErrorPositive = "must be positive";
protected const string UserErrorNonNegative = "must be non-negative";
+ public override TrainerInfo Info { get; }
+
+ protected virtual bool NeedCalibration => false;
+
protected OnlineLinearTrainer(TArguments args, IHostEnvironment env, string name)
: base(env, name)
{
@@ -81,22 +83,12 @@ protected OnlineLinearTrainer(TArguments args, IHostEnvironment env, string name
Contracts.CheckUserArg(args.StreamingCacheSize > 0, nameof(args.StreamingCacheSize), UserErrorPositive);
Args = args;
- }
-
- public override bool NeedNormalization
- {
- get { return true; }
- }
-
- public override bool WantCaching
- {
- // REVIEW: This could return true if there are more than 0 iterations,
- // if we got around the whole shuffling issue.
- get { return true; }
+ // REVIEW: Caching could be false for one iteration, if we got around the whole shuffling issue.
+ Info = new TrainerInfo(calibration: NeedCalibration, supportIncrementalTrain: true);
}
///
- /// Propagates the _weightsScale to the weights vector.
+ /// Propagates the to the vector.
///
protected void ScaleWeights()
{
@@ -108,9 +100,9 @@ protected void ScaleWeights()
}
///
- /// Conditionally propagates the _weightsScale to the weights vector when
- /// it reaches a scale where additions to weights would start dropping too much
- /// precision. ("Too much" is mostly empirically defined.)
+ /// Conditionally propagates the to the vector
+ /// when it reaches a scale where additions to weights would start dropping too much precision.
+ /// ("Too much" is mostly empirically defined.)
///
protected void ScaleWeightsIfNeeded()
{
@@ -119,18 +111,20 @@ protected void ScaleWeightsIfNeeded()
ScaleWeights();
}
- private void TrainEx(RoleMappedData data, LinearPredictor predictor)
+ public override TPredictor Train(TrainContext context)
{
- Contracts.AssertValue(data, nameof(data));
- Contracts.AssertValueOrNull(predictor);
+ Host.CheckValue(context, nameof(context));
+ var initPredictor = context.InitialPredictor;
+ var initLinearPred = initPredictor as LinearPredictor ?? (initPredictor as CalibratedPredictorBase)?.SubPredictor as LinearPredictor;
+ Host.CheckParam(initPredictor == null || initLinearPred != null, nameof(context), "Not a linear predictor.");
+ var data = context.TrainingSet;
- int numFeatures;
- data.CheckFeatureFloatVector(out numFeatures);
+ data.CheckFeatureFloatVector(out int numFeatures);
CheckLabel(data);
using (var ch = Host.Start("Training"))
{
- InitCore(ch, numFeatures, predictor);
+ InitCore(ch, numFeatures, initLinearPred);
// InitCore should set the number of features field.
Contracts.Assert(NumFeatures > 0);
@@ -150,23 +144,11 @@ private void TrainEx(RoleMappedData data, LinearPredictor predictor)
ch.Done();
}
- }
- public override void Train(RoleMappedData data)
- {
- Host.CheckValue(data, nameof(data));
- TrainEx(data, null);
+ return CreatePredictor();
}
- public void Train(RoleMappedData data, IPredictor predictor)
- {
- Host.CheckValue(data, nameof(data));
- Host.CheckValue(predictor, nameof(predictor));
- LinearPredictor pred = (predictor as CalibratedPredictorBase)?.SubPredictor as LinearPredictor;
- pred = pred ?? predictor as LinearPredictor;
- Host.CheckParam(pred != null, nameof(predictor), "Not a linear predictor.");
- TrainEx(data, pred);
- }
+ protected abstract TPredictor CreatePredictor();
protected abstract void CheckLabel(RoleMappedData data);
diff --git a/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs
index 9322c2cc75..94dbb42325 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/PoissonRegression.cs
@@ -45,9 +45,7 @@ public PoissonRegression(IHostEnvironment env, Arguments args)
{
}
- public override bool NeedCalibration { get { return false; } }
-
- public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } }
+ public override PredictionKind PredictionKind => PredictionKind.Regression;
protected override void CheckLabel(RoleMappedData data)
{
@@ -106,7 +104,7 @@ protected override Float AccumulateOneGradient(ref VBuffer feat, Float la
return -(y * dot - lambda) * weight;
}
- public override PoissonRegressionPredictor CreatePredictor()
+ protected override PoissonRegressionPredictor CreatePredictor()
{
VBuffer weights = default(VBuffer);
CurrentWeights.CopyTo(ref weights, 1, CurrentWeights.Length - 1);
diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
index 20bc349a7c..facceffd70 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
@@ -30,7 +30,7 @@ namespace Microsoft.ML.Runtime.Learners
// SDCA linear multiclass trainer.
///
- public class SdcaMultiClassTrainer : SdcaTrainerBase, ITrainerEx
+ public class SdcaMultiClassTrainer : SdcaTrainerBase
{
public const string LoadNameValue = "SDCAMC";
public const string UserNameValue = "Fast Linear Multi-class Classification (SA-SDCA)";
@@ -45,21 +45,8 @@ public sealed class Arguments : ArgumentsBase
private readonly ISupportSdcaClassificationLoss _loss;
private readonly Arguments _args;
- private int _numClasses;
- public override PredictionKind PredictionKind
- {
- get { return PredictionKind.MultiClassClassification; }
- }
-
- protected override int WeightArraySize
- {
- get
- {
- Contracts.Assert(_numClasses > 0, "_numClasses should already have been initialized when this property is called.");
- return _numClasses;
- }
- }
+ public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification;
public SdcaMultiClassTrainer(IHostEnvironment env, Arguments args)
: base(args, env, LoadNameValue)
@@ -70,8 +57,6 @@ public SdcaMultiClassTrainer(IHostEnvironment env, Arguments args)
_args = args;
}
- public override bool NeedCalibration { get { return false; } }
-
///
protected override void TrainWithoutLock(IProgressChannelProvider progress, FloatLabelCursor.Factory cursorFactory, IRandom rand,
IdToIdxLookup idToIdx, int numThreads, DualsTableBase duals, Float[] biasReg, Float[] invariants, Float lambdaNInv,
@@ -82,11 +67,9 @@ protected override void TrainWithoutLock(IProgressChannelProvider progress, Floa
Contracts.AssertValueOrNull(idToIdx);
Contracts.AssertValueOrNull(invariants);
Contracts.AssertValueOrNull(featureNormSquared);
- int weightArraySize = WeightArraySize;
- Contracts.Assert(weightArraySize == _numClasses);
- Contracts.Assert(Utils.Size(weights) == weightArraySize);
- Contracts.Assert(Utils.Size(biasReg) == weightArraySize);
- Contracts.Assert(Utils.Size(biasUnreg) == weightArraySize);
+ int numClasses = Utils.Size(weights);
+ Contracts.Assert(Utils.Size(biasReg) == numClasses);
+ Contracts.Assert(Utils.Size(biasUnreg) == numClasses);
int maxUpdateTrials = 2 * numThreads;
var l1Threshold = _args.L1Threshold.Value;
@@ -101,11 +84,11 @@ protected override void TrainWithoutLock(IProgressChannelProvider progress, Floa
if (pch != null)
pch.SetHeader(new ProgressHeader("examples"), e => e.SetProgress(0, rowCount));
- Func getIndexFromId = GetIndexFromIdGetter(idToIdx);
+ Func getIndexFromId = GetIndexFromIdGetter(idToIdx, biasReg.Length);
while (cursor.MoveNext())
{
long idx = getIndexFromId(cursor.Id);
- long dualIndexInitPos = idx * weightArraySize;
+ long dualIndexInitPos = idx * numClasses;
var features = cursor.Features;
var label = (int)cursor.Label;
Float invariant;
@@ -139,7 +122,7 @@ protected override void TrainWithoutLock(IProgressChannelProvider progress, Floa
Float labelAdjustment = 0;
// Iterates through all classes.
- for (int iClass = 0; iClass < _numClasses; iClass++)
+ for (int iClass = 0; iClass < numClasses; iClass++)
{
// Skip the dual/weights/bias update for label class. Will be taken care of at the end.
if (iClass == label)
@@ -161,9 +144,7 @@ protected override void TrainWithoutLock(IProgressChannelProvider progress, Floa
dualUpdate -= adjustment;
bool success = false;
duals.ApplyAt(dualIndex, (long index, ref Float value) =>
- {
- success = Interlocked.CompareExchange(ref value, dual + dualUpdate, dual) == dual;
- });
+ success = Interlocked.CompareExchange(ref value, dual + dualUpdate, dual) == dual);
if (success)
{
@@ -251,24 +232,23 @@ protected override bool CheckConvergence(
{
Contracts.AssertValue(weights);
Contracts.AssertValue(duals);
- Contracts.Assert(weights.Length == _numClasses);
- Contracts.Assert(duals.Length >= _numClasses * count);
+ int numClasses = weights.Length;
+ Contracts.Assert(duals.Length >= numClasses * count);
Contracts.AssertValueOrNull(idToIdx);
- int weightArraySize = WeightArraySize;
- Contracts.Assert(weightArraySize == _numClasses);
- Contracts.Assert(Utils.Size(weights) == weightArraySize);
- Contracts.Assert(Utils.Size(biasReg) == weightArraySize);
- Contracts.Assert(Utils.Size(biasUnreg) == weightArraySize);
+ Contracts.Assert(Utils.Size(weights) == numClasses);
+ Contracts.Assert(Utils.Size(biasReg) == numClasses);
+ Contracts.Assert(Utils.Size(biasUnreg) == numClasses);
Contracts.Assert(Utils.Size(metrics) == 6);
var reportedValues = new Double?[metrics.Length + 1];
reportedValues[metrics.Length] = iter;
var lossSum = new CompensatedSum();
var dualLossSum = new CompensatedSum();
+ int numFeatures = weights[0].Length;
using (var cursor = cursorFactory.Create())
{
long row = 0;
- Func getIndexFromIdAndRow = GetIndexFromIdAndRowGetter(idToIdx);
+ Func getIndexFromIdAndRow = GetIndexFromIdAndRowGetter(idToIdx, biasReg.Length);
// Iterates through data to compute loss function.
while (cursor.MoveNext())
{
@@ -279,8 +259,8 @@ protected override bool CheckConvergence(
Double subLoss = 0;
Double subDualLoss = 0;
long idx = getIndexFromIdAndRow(cursor.Id, row);
- long dualIndex = idx * _numClasses;
- for (int iClass = 0; iClass < _numClasses; iClass++)
+ long dualIndex = idx * numClasses;
+ for (int iClass = 0; iClass < numClasses; iClass++)
{
if (iClass == label)
{
@@ -290,7 +270,7 @@ protected override bool CheckConvergence(
var currentClassOutput = WDot(ref features, ref weights[iClass], biasReg[iClass] + biasUnreg[iClass]);
subLoss += _loss.Loss(labelOutput - currentClassOutput, 1);
- Contracts.Assert(dualIndex == iClass + idx * _numClasses);
+ Contracts.Assert(dualIndex == iClass + idx * numClasses);
var dual = duals[dualIndex++];
subDualLoss += _loss.DualLoss(1, dual);
}
@@ -300,7 +280,7 @@ protected override bool CheckConvergence(
row++;
}
- Host.Assert(idToIdx == null || row * WeightArraySize == duals.Length);
+ Host.Assert(idToIdx == null || row * numClasses == duals.Length);
}
Contracts.Assert(_args.L2Const.HasValue);
@@ -311,7 +291,7 @@ protected override bool CheckConvergence(
Double weightsL1Norm = 0;
Double weightsL2NormSquared = 0;
Double biasRegularizationAdjustment = 0;
- for (int iClass = 0; iClass < _numClasses; iClass++)
+ for (int iClass = 0; iClass < numClasses; iClass++)
{
weightsL1Norm += VectorUtils.L1Norm(ref weights[iClass]) + Math.Abs(biasReg[iClass]);
weightsL2NormSquared += VectorUtils.NormSquared(weights[iClass]) + biasReg[iClass] * biasReg[iClass];
@@ -330,13 +310,14 @@ protected override bool CheckConvergence(
metrics[(int)MetricKind.DualityGap] = dualityGap;
metrics[(int)MetricKind.BiasUnreg] = biasUnreg[0];
metrics[(int)MetricKind.BiasReg] = biasReg[0];
- metrics[(int)MetricKind.L1Sparsity] = _args.L1Threshold == 0 ? 1 : (Double)weights.Sum(weight => weight.Values.Count(w => w != 0)) / (_numClasses * NumFeatures);
+ metrics[(int)MetricKind.L1Sparsity] = _args.L1Threshold == 0 ? 1 : weights.Sum(
+ weight => weight.Values.Count(w => w != 0)) / (numClasses * numFeatures);
bool converged = dualityGap / newLoss < _args.ConvergenceTolerance;
if (metrics[(int)MetricKind.Loss] < bestPrimalLoss)
{
- for (int iClass = 0; iClass < _numClasses; iClass++)
+ for (int iClass = 0; iClass < numClasses; iClass++)
{
// Maintain a copy of weights and bias with best primal loss thus far.
// This is some extra work and uses extra memory, but it seems worth doing it.
@@ -358,14 +339,19 @@ protected override bool CheckConvergence(
return converged;
}
- public override TVectorPredictor CreatePredictor()
+ protected override TVectorPredictor CreatePredictor(VBuffer[] weights, Float[] bias)
{
- return new MulticlassLogisticRegressionPredictor(Host, Weights, Bias, _numClasses, NumFeatures, null, stats: null);
+ Host.CheckValue(weights, nameof(weights));
+ Host.CheckValue(bias, nameof(bias));
+ Host.CheckParam(weights.Length > 0, nameof(weights));
+ Host.CheckParam(weights.Length == bias.Length, nameof(weights));
+
+ return new MulticlassLogisticRegressionPredictor(Host, weights, bias, bias.Length, weights[0].Length, null, stats: null);
}
- protected override void CheckLabel(RoleMappedData examples)
+ protected override void CheckLabel(RoleMappedData examples, out int weightSetCount)
{
- examples.CheckMultiClassLabel(out _numClasses);
+ examples.CheckMultiClassLabel(out weightSetCount);
}
protected override Float[] InitializeFeatureNormSquared(int length)
diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs
index 512818bba7..466625a679 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs
@@ -26,7 +26,7 @@ namespace Microsoft.ML.Runtime.Learners
using TScalarPredictor = IPredictorWithFeatureWeights;
///
- public sealed class SdcaRegressionTrainer : SdcaTrainerBase, ITrainer, ITrainerEx
+ public sealed class SdcaRegressionTrainer : SdcaTrainerBase
{
public const string LoadNameValue = "SDCAR";
public const string UserNameValue = "Fast Linear Regression (SA-SDCA)";
@@ -51,11 +51,7 @@ public Arguments()
private readonly ISupportSdcaRegressionLoss _loss;
private readonly Arguments _args;
- public override PredictionKind PredictionKind { get { return PredictionKind.Regression; } }
-
- public override bool NeedCalibration { get { return false; } }
-
- protected override int WeightArraySize { get { return 1; } }
+ public override PredictionKind PredictionKind => PredictionKind.Regression;
public SdcaRegressionTrainer(IHostEnvironment env, Arguments args)
: base(args, env, LoadNameValue)
@@ -66,22 +62,16 @@ public SdcaRegressionTrainer(IHostEnvironment env, Arguments args)
_args = args;
}
- public override IPredictor CreatePredictor()
- {
- Contracts.Assert(WeightArraySize == 1);
- Contracts.Assert(Utils.Size(Weights) == 1);
- Contracts.Assert(Utils.Size(Bias) == 1);
- Host.Check(Weights[0].Length > 0);
- VBuffer maybeSparseWeights = VBufferUtils.CreateEmpty(Weights[0].Length);
- VBufferUtils.CreateMaybeSparseCopy(ref Weights[0], ref maybeSparseWeights, Conversions.Instance.GetIsDefaultPredicate(NumberType.Float));
- return new LinearRegressionPredictor(Host, ref maybeSparseWeights, Bias[0]);
- }
-
- TScalarPredictor ITrainer.CreatePredictor()
+ protected override TScalarPredictor CreatePredictor(VBuffer[] weights, Float[] bias)
{
- var predictor = CreatePredictor() as TScalarPredictor;
- Contracts.AssertValue(predictor);
- return predictor;
+ Host.CheckParam(Utils.Size(weights) == 1, nameof(weights));
+ Host.CheckParam(Utils.Size(bias) == 1, nameof(bias));
+ Host.CheckParam(weights[0].Length > 0, nameof(weights));
+
+ VBuffer maybeSparseWeights = default;
+ VBufferUtils.CreateMaybeSparseCopy(ref weights[0], ref maybeSparseWeights,
+ Conversions.Instance.GetIsDefaultPredicate(NumberType.Float));
+ return new LinearRegressionPredictor(Host, ref maybeSparseWeights, bias[0]);
}
protected override Float GetInstanceWeight(FloatLabelCursor cursor)
@@ -89,9 +79,10 @@ protected override Float GetInstanceWeight(FloatLabelCursor cursor)
return cursor.Weight;
}
- protected override void CheckLabel(RoleMappedData examples)
+ protected override void CheckLabel(RoleMappedData examples, out int weightSetCount)
{
examples.CheckRegressionLabel();
+ weightSetCount = 1;
}
// REVIEW: No extra benefits from using more threads in training.
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs
index 364e29b877..abfff554c9 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs
@@ -38,7 +38,7 @@ namespace Microsoft.ML.Runtime.Learners
///
/// A trainer that trains a predictor that returns random values
///
- public sealed class RandomTrainer : TrainerBase
+ public sealed class RandomTrainer : TrainerBase
{
internal const string LoadNameValue = "RandomPredictor";
internal const string UserNameValue = "Random Predictor";
@@ -54,29 +54,20 @@ public class Arguments
public bool BooleanArg = false;
}
- private Arguments _args;
+ public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
+
+ private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false);
+ public override TrainerInfo Info => _info;
public RandomTrainer(IHostEnvironment env, Arguments args)
: base(env, LoadNameValue)
{
- _args = args;
- }
-
- public override PredictionKind PredictionKind
- { get { return PredictionKind.BinaryClassification; } }
- public override bool NeedNormalization
- { get { return false; } }
- public override bool NeedCalibration
- { get { return false; } }
- public override bool WantCaching
- { get { return false; } }
-
- public override void Train(RoleMappedData data)
- {
+ Host.CheckValue(args, nameof(args));
}
- public override RandomPredictor CreatePredictor()
+ public override RandomPredictor Train(TrainContext context)
{
+ Host.CheckValue(context, nameof(context));
return new RandomPredictor(Host, Host.Rand.Next());
}
}
@@ -107,16 +98,10 @@ private static VersionInfo GetVersionInfo()
private readonly object _instanceLock;
private readonly Random _random;
- private readonly ColumnType _inputType;
-
- public override PredictionKind PredictionKind
- { get { return PredictionKind.BinaryClassification; } }
- public ColumnType InputType
- { get { return _inputType; } }
- public ColumnType OutputType
- { get { return NumberType.Float; } }
- public ColumnType DistType
- { get { return NumberType.Float; } }
+ public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
+ public ColumnType InputType { get; }
+ public ColumnType OutputType => NumberType.Float;
+ public ColumnType DistType => NumberType.Float;
public RandomPredictor(IHostEnvironment env, int seed)
: base(env, LoaderSignature)
@@ -126,7 +111,7 @@ public RandomPredictor(IHostEnvironment env, int seed)
_instanceLock = new object();
_random = new Random(_seed);
- _inputType = new VectorType(NumberType.Float);
+ InputType = new VectorType(NumberType.Float);
}
///
@@ -211,7 +196,7 @@ private void MapDist(ref VBuffer src, ref Float score, ref Float prob)
}
// Learns the prior distribution for 0/1 class labels and just outputs that.
- public sealed class PriorTrainer : TrainerBase
+ public sealed class PriorTrainer : TrainerBase
{
internal const string LoadNameValue = "PriorPredictor";
internal const string UserNameValue = "Prior Predictor";
@@ -220,26 +205,22 @@ public sealed class Arguments
{
}
- private Float _prob;
+ public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
- public override PredictionKind PredictionKind
- { get { return PredictionKind.BinaryClassification; } }
- public override bool NeedNormalization
- { get { return false; } }
- public override bool NeedCalibration
- { get { return false; } }
- public override bool WantCaching
- { get { return false; } }
+ private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false);
+ public override TrainerInfo Info => _info;
public PriorTrainer(IHostEnvironment env, Arguments args)
: base(env, LoadNameValue)
{
- _prob = Float.NaN;
+ Host.CheckValue(args, nameof(args));
}
- public override void Train(RoleMappedData data)
+ public override PriorPredictor Train(TrainContext context)
{
- Contracts.CheckValue(data, nameof(data));
+ Contracts.CheckValue(context, nameof(context));
+ var data = context.TrainingSet;
+ data.CheckBinaryLabel();
Contracts.CheckParam(data.Schema.Label != null, nameof(data), "Missing Label column");
Contracts.CheckParam(data.Schema.Label.Type == NumberType.Float, nameof(data), "Invalid type for Label column");
@@ -248,11 +229,11 @@ public override void Train(RoleMappedData data)
int col = data.Schema.Label.Index;
int colWeight = -1;
- if (data.Schema.Weight != null && data.Schema.Weight.Type == NumberType.Float)
+ if (data.Schema.Weight?.Type == NumberType.Float)
colWeight = data.Schema.Weight.Index;
using (var cursor = data.Data.GetRowCursor(c => c == col || c == colWeight))
{
- var getLab = cursor.GetGetter(col);
+ var getLab = cursor.GetLabelFloatGetter(data);
var getWeight = colWeight >= 0 ? cursor.GetGetter(colWeight) : null;
Float lab = default(Float);
Float weight = 1;
@@ -274,13 +255,8 @@ public override void Train(RoleMappedData data)
}
}
- if (pos + neg > 0)
- _prob = (Float)(pos / (pos + neg));
- }
-
- public override PriorPredictor CreatePredictor()
- {
- return new PriorPredictor(Host, _prob);
+ Float prob = pos + neg > 0 ? (Float)(pos / (pos + neg)) : Float.NaN;
+ return new PriorPredictor(Host, prob);
}
}
@@ -304,8 +280,6 @@ private static VersionInfo GetVersionInfo()
private readonly Float _prob;
private readonly Float _raw;
- private readonly ColumnType _inputType;
-
public PriorPredictor(IHostEnvironment env, Float prob)
: base(env, LoaderSignature)
{
@@ -314,7 +288,7 @@ public PriorPredictor(IHostEnvironment env, Float prob)
_prob = prob;
_raw = 2 * _prob - 1; // This could be other functions -- logodds for instance
- _inputType = new VectorType(NumberType.Float);
+ InputType = new VectorType(NumberType.Float);
}
private PriorPredictor(IHostEnvironment env, ModelLoadContext ctx)
@@ -328,7 +302,7 @@ private PriorPredictor(IHostEnvironment env, ModelLoadContext ctx)
_raw = 2 * _prob - 1;
- _inputType = new VectorType(NumberType.Float);
+ InputType = new VectorType(NumberType.Float);
}
public static PriorPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
@@ -353,12 +327,9 @@ protected override void SaveCore(ModelSaveContext ctx)
public override PredictionKind PredictionKind
{ get { return PredictionKind.BinaryClassification; } }
- public ColumnType InputType
- { get { return _inputType; } }
- public ColumnType OutputType
- { get { return NumberType.Float; } }
- public ColumnType DistType
- { get { return NumberType.Float; } }
+ public ColumnType InputType { get; }
+ public ColumnType OutputType => NumberType.Float;
+ public ColumnType DistType => NumberType.Float;
public ValueMapper GetMapper()
{
diff --git a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs
index 391b102cb0..2351454709 100644
--- a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs
+++ b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs
@@ -142,9 +142,8 @@ private FastForestRegressionPredictor FitModel(IEnumerable previousR
args.MinDocumentsInLeafs = _args.NMinForSplit;
// Train random forest.
- FastForestRegression trainer = new FastForestRegression(_host, args);
- trainer.Train(data);
- FastForestRegressionPredictor predictor = trainer.CreatePredictor();
+ var trainer = new FastForestRegression(_host, args);
+ var predictor = trainer.Train(data);
// Return random forest predictor.
ch.Done();
diff --git a/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs b/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs
index 637d75250b..c2b2bead79 100644
--- a/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs
+++ b/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs
@@ -33,8 +33,8 @@ public sealed class Arguments
public int? NumSlotsToKeep;
[Argument(ArgumentType.Multiple, HelpText = "Filter", ShortName = "f", SortOrder = 1)]
- public SubComponent>, SignatureFeatureScorerTrainer> Filter =
- new SubComponent>, SignatureFeatureScorerTrainer>("SDCA");
+ public SubComponent>, SignatureFeatureScorerTrainer> Filter =
+ new SubComponent>, SignatureFeatureScorerTrainer>("SDCA");
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for features", ShortName = "feat,col", SortOrder = 3, Purpose = SpecialPurpose.ColumnName)]
public string FeatureColumn = DefaultColumnNames.Features;
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
index afd488915e..b0bc269164 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
@@ -685,8 +685,7 @@ public void EntryPointCalibrate()
// This tests that the SchemaBindableCalibratedPredictor doesn't get confused if its sub-predictor is already calibrated.
var fastForest = new FastForestClassification(Env, new FastForestClassification.Arguments());
var rmd = new RoleMappedData(splitOutput.TrainData[0], "Label", "Features");
- fastForest.Train(rmd);
- var ffModel = new PredictorModel(Env, rmd, splitOutput.TrainData[0], fastForest.CreatePredictor());
+ var ffModel = new PredictorModel(Env, rmd, splitOutput.TrainData[0], fastForest.Train(rmd));
var calibratedFfModel = Calibrate.Platt(Env,
new Calibrate.NoArgumentsInput() { Data = splitOutput.TestData[0], UncalibratedPredictorModel = ffModel }).PredictorModel;
var twiceCalibratedFfModel = Calibrate.Platt(Env,
@@ -1219,9 +1218,8 @@ public void EntryPointMulticlassPipelineEnsemble()
var mlr = new MulticlassLogisticRegression(Env, new MulticlassLogisticRegression.Arguments());
var rmd = new RoleMappedData(data, "Label", "Features");
- mlr.Train(rmd);
- predictorModels[i] = new PredictorModel(Env, rmd, data, mlr.CreatePredictor());
+ predictorModels[i] = new PredictorModel(Env, rmd, data, mlr.Train(rmd));
var transformModel = new TransformModel(Env, data, splitOutput.TrainData[i]);
predictorModels[i] = ModelOperations.CombineTwoModels(Env,
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs
index 20649b5b25..0535f2b15d 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs
@@ -70,15 +70,14 @@ public void TrainAndPredictIrisModelUsingDirectInstantiationTest()
trans = NormalizeTransform.CreateMinMaxNormalizer(env, trans, "Features");
// Train
- var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments());
+ var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments() { NumThreads = 1 } );
// Explicity adding CacheDataView since caching is not working though trainer has 'Caching' On/Auto
var cached = new CacheDataView(env, trans, prefetch: null);
var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
- trainer.Train(trainRoles);
+ var pred = trainer.Train(trainRoles);
// Get scorer and evaluate the predictions from test data
- var pred = trainer.CreatePredictor();
IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
var metrics = Evaluate(env, testDataScorer);
CompareMatrics(metrics);
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs
index da208cb3f0..87c3ab4beb 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs
@@ -79,10 +79,9 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTest()
});
var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
- trainer.Train(trainRoles);
+ var pred = trainer.Train(trainRoles);
// Get scorer and evaluate the predictions from test data
- var pred = trainer.CreatePredictor();
IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
var metrics = EvaluateBinary(env, testDataScorer);
ValidateBinaryMetrics(metrics);