Skip to content

Commit 044a6d3

Browse files
authored
FAFM to estimator (#912)
* FAFM to extend TrainerEstimatorBase * Fixing the creation of getters on the Fafm predictor.
1 parent a627d5b commit 044a6d3

26 files changed

+643
-123
lines changed

src/Microsoft.ML.Core/Data/IEstimator.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,14 +229,14 @@ public interface IDataReaderEstimator<in TSource, out TReader>
229229

230230
/// <summary>
231231
/// The transformer is a component that transforms data.
232-
/// It also supports 'schema propagation' to answer the question of 'how the data with this schema look after you transform it?'.
232+
/// It also supports 'schema propagation' to answer the question of 'how will the data with this schema look, after you transform it?'.
233233
/// </summary>
234234
public interface ITransformer
235235
{
236236
/// <summary>
237237
/// Schema propagation for transformers.
238238
/// Returns the output schema of the data, if the input schema is like the one provided.
239-
/// Throws <see cref="SchemaException"/> iff the input schema is not valid for the transformer.
239+
/// Throws <see cref="SchemaException"/> if the input schema is not valid for the transformer.
240240
/// </summary>
241241
ISchema GetOutputSchema(ISchema inputSchema);
242242

src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,37 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System.Collections.Generic;
56
using Microsoft.ML.Core.Data;
67
using Microsoft.ML.Runtime.Data;
7-
using Microsoft.ML.Runtime.Internal.Calibration;
8-
using System;
9-
using System.Collections.Generic;
10-
using System.Text;
118

129
namespace Microsoft.ML.Runtime
1310
{
11+
/// <summary>
12+
/// An interface for all the transformer that can transform data based on the <see cref="IPredictor"/> field.
13+
/// The implemendations of this interface either have no feature column, or have more than one feature column, and cannot implement the
14+
/// <see cref="ISingleFeaturePredictionTransformer{TModel}"/>, which most of the ML.Net tranformer implement.
15+
/// </summary>
16+
/// <typeparam name="TModel">The <see cref="IPredictor"/> used for the data transformation.</typeparam>
1417
public interface IPredictionTransformer<out TModel> : ITransformer
1518
where TModel : IPredictor
1619
{
20+
TModel Model { get; }
21+
}
22+
23+
/// <summary>
24+
/// An ISingleFeaturePredictionTransformer contains the name of the <see cref="FeatureColumn"/>
25+
/// and its type, <see cref="FeatureColumnType"/>. Implementations of this interface, have the ability
26+
/// to score the data of an input <see cref="IDataView"/> through the <see cref="ITransformer.Transform(IDataView)"/>
27+
/// </summary>
28+
/// <typeparam name="TModel">The <see cref="IPredictor"/> used for the data transformation.</typeparam>
29+
public interface ISingleFeaturePredictionTransformer<out TModel> : IPredictionTransformer<TModel>
30+
where TModel : IPredictor
31+
{
32+
/// <summary>The name of the feature column.</summary>
1733
string FeatureColumn { get; }
1834

35+
/// <summary>Holds information about the type of the feature column.</summary>
1936
ColumnType FeatureColumnType { get; }
20-
21-
TModel Model { get; }
2237
}
2338
}

src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs

Lines changed: 111 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System;
65
using System.IO;
76
using Microsoft.ML.Runtime;
87
using Microsoft.ML.Runtime.Data;
@@ -23,54 +22,49 @@
2322

2423
namespace Microsoft.ML.Runtime.Data
2524
{
26-
public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer<TModel>, ICanSaveModel
25+
26+
/// <summary>
27+
/// Base class for transformers with no feature column, or more than one feature columns.
28+
/// </summary>
29+
/// <typeparam name="TModel"></typeparam>
30+
public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer<TModel>
2731
where TModel : class, IPredictor
2832
{
29-
private const string DirModel = "Model";
30-
private const string DirTransSchema = "TrainSchema";
33+
/// <summary>
34+
/// The model.
35+
/// </summary>
36+
public TModel Model { get; }
3137

38+
protected const string DirModel = "Model";
39+
protected const string DirTransSchema = "TrainSchema";
3240
protected readonly IHost Host;
33-
protected readonly ISchemaBindableMapper BindableMapper;
34-
protected readonly ISchema TrainSchema;
35-
36-
public string FeatureColumn { get; }
37-
38-
public ColumnType FeatureColumnType { get; }
41+
protected ISchemaBindableMapper BindableMapper;
42+
protected ISchema TrainSchema;
3943

40-
public TModel Model { get; }
41-
42-
public PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn)
44+
protected PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema)
4345
{
4446
Contracts.CheckValue(host, nameof(host));
45-
Contracts.CheckValueOrNull(featureColumn);
47+
4648
Host = host;
4749
Host.CheckValue(trainSchema, nameof(trainSchema));
4850

4951
Model = model;
50-
FeatureColumn = featureColumn;
51-
if (featureColumn == null)
52-
FeatureColumnType = null;
53-
else if (!trainSchema.TryGetColumnIndex(featureColumn, out int col))
54-
throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn);
55-
else
56-
FeatureColumnType = trainSchema.GetColumnType(col);
57-
5852
TrainSchema = trainSchema;
59-
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
6053
}
6154

62-
internal PredictionTransformerBase(IHost host, ModelLoadContext ctx)
55+
protected PredictionTransformerBase(IHost host, ModelLoadContext ctx)
56+
6357
{
6458
Host = host;
6559

66-
ctx.LoadModel<TModel, SignatureLoadModel>(host, out TModel model, DirModel);
67-
Model = model;
68-
6960
// *** Binary format ***
7061
// model: prediction model.
7162
// stream: empty data view that contains train schema.
7263
// id of string: feature column.
7364

65+
ctx.LoadModel<TModel, SignatureLoadModel>(host, out TModel model, DirModel);
66+
Model = model;
67+
7468
// Clone the stream with the schema into memory.
7569
var ms = new MemoryStream();
7670
ctx.TryLoadBinaryStream(DirTransSchema, reader =>
@@ -81,19 +75,90 @@ internal PredictionTransformerBase(IHost host, ModelLoadContext ctx)
8175
ms.Position = 0;
8276
var loader = new BinaryLoader(host, new BinaryLoader.Arguments(), ms);
8377
TrainSchema = loader.Schema;
78+
}
79+
80+
/// <summary>
81+
/// Gets the output schema resulting from the <see cref="Transform(IDataView)"/>
82+
/// </summary>
83+
/// <param name="inputSchema">The <see cref="ISchema"/> of the input data.</param>
84+
/// <returns>The resulting <see cref="ISchema"/>.</returns>
85+
public abstract ISchema GetOutputSchema(ISchema inputSchema);
86+
87+
/// <summary>
88+
/// Transforms the input data.
89+
/// </summary>
90+
/// <param name="input">The input data.</param>
91+
/// <returns>The transformed <see cref="IDataView"/></returns>
92+
public abstract IDataView Transform(IDataView input);
93+
94+
protected void SaveModel(ModelSaveContext ctx)
95+
{
96+
// *** Binary format ***
97+
// <base info>
98+
// stream: empty data view that contains train schema.
8499

100+
ctx.SaveModel(Model, DirModel);
101+
ctx.SaveBinaryStream(DirTransSchema, writer =>
102+
{
103+
using (var ch = Host.Start("Saving train schema"))
104+
{
105+
var saver = new BinarySaver(Host, new BinarySaver.Arguments { Silent = true });
106+
DataSaverUtils.SaveDataView(ch, saver, new EmptyDataView(Host, TrainSchema), writer.BaseStream);
107+
}
108+
});
109+
}
110+
}
111+
112+
/// <summary>
113+
/// The base class for all the transformers implementing the <see cref="ISingleFeaturePredictionTransformer{TModel}"/>.
114+
/// Those are all the transformers that work with one feature column.
115+
/// </summary>
116+
/// <typeparam name="TModel">The model used to transform the data.</typeparam>
117+
public abstract class SingleFeaturePredictionTransformerBase<TModel> : PredictionTransformerBase<TModel>, ISingleFeaturePredictionTransformer<TModel>, ICanSaveModel
118+
where TModel : class, IPredictor
119+
{
120+
/// <summary>
121+
/// The name of the feature column used by the prediction transformer.
122+
/// </summary>
123+
public string FeatureColumn { get; }
124+
125+
/// <summary>
126+
/// The type of the prediction transformer
127+
/// </summary>
128+
public ColumnType FeatureColumnType { get; }
129+
130+
public SingleFeaturePredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn)
131+
:base(host, model, trainSchema)
132+
{
133+
FeatureColumn = featureColumn;
134+
135+
FeatureColumn = featureColumn;
136+
if (featureColumn == null)
137+
FeatureColumnType = null;
138+
else if (!trainSchema.TryGetColumnIndex(featureColumn, out int col))
139+
throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn);
140+
else
141+
FeatureColumnType = trainSchema.GetColumnType(col);
142+
143+
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
144+
}
145+
146+
internal SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx)
147+
:base(host, ctx)
148+
{
85149
FeatureColumn = ctx.LoadStringOrNull();
150+
86151
if (FeatureColumn == null)
87152
FeatureColumnType = null;
88153
else if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col))
89154
throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn);
90155
else
91156
FeatureColumnType = TrainSchema.GetColumnType(col);
92157

93-
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
158+
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model);
94159
}
95160

96-
public ISchema GetOutputSchema(ISchema inputSchema)
161+
public override ISchema GetOutputSchema(ISchema inputSchema)
97162
{
98163
Host.CheckValue(inputSchema, nameof(inputSchema));
99164

@@ -108,8 +173,6 @@ public ISchema GetOutputSchema(ISchema inputSchema)
108173
return Transform(new EmptyDataView(Host, inputSchema)).Schema;
109174
}
110175

111-
public abstract IDataView Transform(IDataView input);
112-
113176
public void Save(ModelSaveContext ctx)
114177
{
115178
Host.CheckValue(ctx, nameof(ctx));
@@ -119,26 +182,16 @@ public void Save(ModelSaveContext ctx)
119182

120183
protected virtual void SaveCore(ModelSaveContext ctx)
121184
{
122-
// *** Binary format ***
123-
// model: prediction model.
124-
// stream: empty data view that contains train schema.
125-
// id of string: feature column.
126-
127-
ctx.SaveModel(Model, DirModel);
128-
ctx.SaveBinaryStream(DirTransSchema, writer =>
129-
{
130-
using (var ch = Host.Start("Saving train schema"))
131-
{
132-
var saver = new BinarySaver(Host, new BinarySaver.Arguments { Silent = true });
133-
DataSaverUtils.SaveDataView(ch, saver, new EmptyDataView(Host, TrainSchema), writer.BaseStream);
134-
}
135-
});
136-
185+
SaveModel(ctx);
137186
ctx.SaveStringOrNull(FeatureColumn);
138187
}
139188
}
140189

141-
public sealed class BinaryPredictionTransformer<TModel> : PredictionTransformerBase<TModel>
190+
/// <summary>
191+
/// Base class for the <see cref="ISingleFeaturePredictionTransformer{TModel}"/> working on binary classification tasks.
192+
/// </summary>
193+
/// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
194+
public sealed class BinaryPredictionTransformer<TModel> : SingleFeaturePredictionTransformerBase<TModel>
142195
where TModel : class, IPredictorProducing<float>
143196
{
144197
private readonly BinaryClassifierScorer _scorer;
@@ -207,7 +260,11 @@ private static VersionInfo GetVersionInfo()
207260
}
208261
}
209262

210-
public sealed class MulticlassPredictionTransformer<TModel> : PredictionTransformerBase<TModel>
263+
/// <summary>
264+
/// Base class for the <see cref="ISingleFeaturePredictionTransformer{TModel}"/> working on multi-class classification tasks.
265+
/// </summary>
266+
/// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
267+
public sealed class MulticlassPredictionTransformer<TModel> : SingleFeaturePredictionTransformerBase<TModel>
211268
where TModel : class, IPredictorProducing<VBuffer<float>>
212269
{
213270
private readonly MultiClassClassifierScorer _scorer;
@@ -268,7 +325,11 @@ private static VersionInfo GetVersionInfo()
268325
}
269326
}
270327

271-
public sealed class RegressionPredictionTransformer<TModel> : PredictionTransformerBase<TModel>
328+
/// <summary>
329+
/// Base class for the <see cref="ISingleFeaturePredictionTransformer{TModel}"/> working on regression tasks.
330+
/// </summary>
331+
/// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
332+
public sealed class RegressionPredictionTransformer<TModel> : SingleFeaturePredictionTransformerBase<TModel>
272333
where TModel : class, IPredictorProducing<float>
273334
{
274335
private readonly GenericScorer _scorer;
@@ -314,7 +375,7 @@ private static VersionInfo GetVersionInfo()
314375
}
315376
}
316377

317-
public sealed class RankingPredictionTransformer<TModel> : PredictionTransformerBase<TModel>
378+
public sealed class RankingPredictionTransformer<TModel> : SingleFeaturePredictionTransformerBase<TModel>
318379
where TModel : class, IPredictorProducing<float>
319380
{
320381
private readonly GenericScorer _scorer;

src/Microsoft.ML.Data/Training/ITrainerEstimator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
namespace Microsoft.ML.Runtime.Training
88
{
99
public interface ITrainerEstimator<out TTransformer, out TPredictor>: IEstimator<TTransformer>
10-
where TTransformer: IPredictionTransformer<TPredictor>
10+
where TTransformer: ISingleFeaturePredictionTransformer<TPredictor>
1111
where TPredictor: IPredictor
1212
{
1313
TrainerInfo Info { get; }

src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ namespace Microsoft.ML.Runtime.Training
1515
/// It produces a 'prediction transformer'.
1616
/// </summary>
1717
public abstract class TrainerEstimatorBase<TTransformer, TModel> : ITrainerEstimator<TTransformer, TModel>, ITrainer<TModel>
18-
where TTransformer : IPredictionTransformer<TModel>
18+
where TTransformer : ISingleFeaturePredictionTransformer<TModel>
1919
where TModel : IPredictor
2020
{
2121
/// <summary>
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Text;
8+
using Microsoft.ML.Runtime;
9+
using Microsoft.ML.Runtime.Data;
10+
using Microsoft.ML.Runtime.Training;
11+
12+
namespace Microsoft.ML.Core.Prediction
13+
{
14+
/// <summary>
15+
/// Holds information relevant to trainers. It is passed to the constructor of the<see cref="ITrainerEstimator{IPredictionTransformer, IPredictor}"/>
16+
/// holding additional data needed to fit the estimator. The additional data can be a validation set or an initial model.
17+
/// This holds at least a training set, as well as optioonally a predictor.
18+
/// </summary>
19+
public class TrainerEstimatorContext
20+
{
21+
/// <summary>
22+
/// The validation set. Can be <c>null</c>. Note that passing a non-<c>null</c> validation set into
23+
/// a trainer that does not support validation sets should not be considered an error condition. It
24+
/// should simply be ignored in that case.
25+
/// </summary>
26+
public IDataView ValidationSet { get; }
27+
28+
/// <summary>
29+
/// The initial predictor, for incremental training. Note that if a <see cref="ITrainerEstimator{IPredictionTransformer, IPredictor}"/> implementor
30+
/// does not support incremental training, then it can ignore it similarly to how one would ignore
31+
/// <see cref="ValidationSet"/>. However, if the trainer does support incremental training and there
32+
/// is something wrong with a non-<c>null</c> value of this, then the trainer ought to throw an exception.
33+
/// </summary>
34+
public IPredictor InitialPredictor { get; }
35+
36+
/// <summary>
37+
/// Initializes a new instance of <see cref="TrainerEstimatorContext"/>, given a training set and optional other arguments.
38+
/// </summary>
39+
/// <param name="validationSet">Will set <see cref="ValidationSet"/> to this value if specified</param>
40+
/// <param name="initialPredictor">Will set <see cref="InitialPredictor"/> to this value if specified</param>
41+
public TrainerEstimatorContext(IDataView validationSet = null, IPredictor initialPredictor = null)
42+
{
43+
Contracts.CheckValueOrNull(validationSet);
44+
Contracts.CheckValueOrNull(initialPredictor);
45+
46+
ValidationSet = validationSet;
47+
InitialPredictor = initialPredictor;
48+
}
49+
}
50+
}

src/Microsoft.ML.FastTree/BoostingFastTree.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
namespace Microsoft.ML.Runtime.FastTree
1515
{
1616
public abstract class BoostingFastTreeTrainerBase<TArgs, TTransformer, TModel> : FastTreeTrainerBase<TArgs, TTransformer, TModel>
17-
where TTransformer : IPredictionTransformer<TModel>
17+
where TTransformer : ISingleFeaturePredictionTransformer<TModel>
1818
where TArgs : BoostedTreeArgs, new()
1919
where TModel : IPredictorProducing<Float>
2020
{

0 commit comments

Comments
 (0)