22// The .NET Foundation licenses this file to you under the MIT license.
33// See the LICENSE file in the project root for more information.
44
5- using System ;
65using System . IO ;
76using Microsoft . ML . Runtime ;
87using Microsoft . ML . Runtime . Data ;
2322
2423namespace Microsoft . ML . Runtime . Data
2524{
26- public abstract class PredictionTransformerBase < TModel > : IPredictionTransformer < TModel > , ICanSaveModel
25+
26+ /// <summary>
27+ /// Base class for transformers with no feature column, or more than one feature columns.
28+ /// </summary>
29+ /// <typeparam name="TModel"></typeparam>
30+ public abstract class PredictionTransformerBase < TModel > : IPredictionTransformer < TModel >
2731 where TModel : class , IPredictor
2832 {
29- private const string DirModel = "Model" ;
30- private const string DirTransSchema = "TrainSchema" ;
33+ /// <summary>
34+ /// The model.
35+ /// </summary>
36+ public TModel Model { get ; }
3137
38+ protected const string DirModel = "Model" ;
39+ protected const string DirTransSchema = "TrainSchema" ;
3240 protected readonly IHost Host ;
33- protected readonly ISchemaBindableMapper BindableMapper ;
34- protected readonly ISchema TrainSchema ;
35-
36- public string FeatureColumn { get ; }
37-
38- public ColumnType FeatureColumnType { get ; }
41+ protected ISchemaBindableMapper BindableMapper ;
42+ protected ISchema TrainSchema ;
3943
40- public TModel Model { get ; }
41-
42- public PredictionTransformerBase ( IHost host , TModel model , ISchema trainSchema , string featureColumn )
44+ protected PredictionTransformerBase ( IHost host , TModel model , ISchema trainSchema )
4345 {
4446 Contracts . CheckValue ( host , nameof ( host ) ) ;
45- Contracts . CheckValueOrNull ( featureColumn ) ;
47+
4648 Host = host ;
4749 Host . CheckValue ( trainSchema , nameof ( trainSchema ) ) ;
4850
4951 Model = model ;
50- FeatureColumn = featureColumn ;
51- if ( featureColumn == null )
52- FeatureColumnType = null ;
53- else if ( ! trainSchema . TryGetColumnIndex ( featureColumn , out int col ) )
54- throw Host . ExceptSchemaMismatch ( nameof ( featureColumn ) , RoleMappedSchema . ColumnRole . Feature . Value , featureColumn ) ;
55- else
56- FeatureColumnType = trainSchema . GetColumnType ( col ) ;
57-
5852 TrainSchema = trainSchema ;
59- BindableMapper = ScoreUtils . GetSchemaBindableMapper ( Host , model ) ;
6053 }
6154
62- internal PredictionTransformerBase ( IHost host , ModelLoadContext ctx )
55+ protected PredictionTransformerBase ( IHost host , ModelLoadContext ctx )
56+
6357 {
6458 Host = host ;
6559
66- ctx . LoadModel < TModel , SignatureLoadModel > ( host , out TModel model , DirModel ) ;
67- Model = model ;
68-
6960 // *** Binary format ***
7061 // model: prediction model.
7162 // stream: empty data view that contains train schema.
7263 // id of string: feature column.
7364
65+ ctx . LoadModel < TModel , SignatureLoadModel > ( host , out TModel model , DirModel ) ;
66+ Model = model ;
67+
7468 // Clone the stream with the schema into memory.
7569 var ms = new MemoryStream ( ) ;
7670 ctx . TryLoadBinaryStream ( DirTransSchema , reader =>
@@ -81,19 +75,90 @@ internal PredictionTransformerBase(IHost host, ModelLoadContext ctx)
8175 ms . Position = 0 ;
8276 var loader = new BinaryLoader ( host , new BinaryLoader . Arguments ( ) , ms ) ;
8377 TrainSchema = loader . Schema ;
78+ }
79+
80+ /// <summary>
81+ /// Gets the output schema resulting from the <see cref="Transform(IDataView)"/>
82+ /// </summary>
83+ /// <param name="inputSchema">The <see cref="ISchema"/> of the input data.</param>
84+ /// <returns>The resulting <see cref="ISchema"/>.</returns>
85+ public abstract ISchema GetOutputSchema ( ISchema inputSchema ) ;
86+
87+ /// <summary>
88+ /// Transforms the input data.
89+ /// </summary>
90+ /// <param name="input">The input data.</param>
91+ /// <returns>The transformed <see cref="IDataView"/></returns>
92+ public abstract IDataView Transform ( IDataView input ) ;
93+
94+ protected void SaveModel ( ModelSaveContext ctx )
95+ {
96+ // *** Binary format ***
97+ // <base info>
98+ // stream: empty data view that contains train schema.
8499
100+ ctx . SaveModel ( Model , DirModel ) ;
101+ ctx . SaveBinaryStream ( DirTransSchema , writer =>
102+ {
103+ using ( var ch = Host . Start ( "Saving train schema" ) )
104+ {
105+ var saver = new BinarySaver ( Host , new BinarySaver . Arguments { Silent = true } ) ;
106+ DataSaverUtils . SaveDataView ( ch , saver , new EmptyDataView ( Host , TrainSchema ) , writer . BaseStream ) ;
107+ }
108+ } ) ;
109+ }
110+ }
111+
112+ /// <summary>
113+ /// The base class for all the transformers implementing the <see cref="ISingleFeaturePredictionTransformer{TModel}"/>.
114+ /// Those are all the transformers that work with one feature column.
115+ /// </summary>
116+ /// <typeparam name="TModel">The model used to transform the data.</typeparam>
117+ public abstract class SingleFeaturePredictionTransformerBase < TModel > : PredictionTransformerBase < TModel > , ISingleFeaturePredictionTransformer < TModel > , ICanSaveModel
118+ where TModel : class , IPredictor
119+ {
120+ /// <summary>
121+ /// The name of the feature column used by the prediction transformer.
122+ /// </summary>
123+ public string FeatureColumn { get ; }
124+
125+ /// <summary>
126+ /// The type of the prediction transformer
127+ /// </summary>
128+ public ColumnType FeatureColumnType { get ; }
129+
130+ public SingleFeaturePredictionTransformerBase ( IHost host , TModel model , ISchema trainSchema , string featureColumn )
131+ : base ( host , model , trainSchema )
132+ {
133+ FeatureColumn = featureColumn ;
134+
135+ FeatureColumn = featureColumn ;
136+ if ( featureColumn == null )
137+ FeatureColumnType = null ;
138+ else if ( ! trainSchema . TryGetColumnIndex ( featureColumn , out int col ) )
139+ throw Host . ExceptSchemaMismatch ( nameof ( featureColumn ) , RoleMappedSchema . ColumnRole . Feature . Value , featureColumn ) ;
140+ else
141+ FeatureColumnType = trainSchema . GetColumnType ( col ) ;
142+
143+ BindableMapper = ScoreUtils . GetSchemaBindableMapper ( Host , model ) ;
144+ }
145+
146+ internal SingleFeaturePredictionTransformerBase ( IHost host , ModelLoadContext ctx )
147+ : base ( host , ctx )
148+ {
85149 FeatureColumn = ctx . LoadStringOrNull ( ) ;
150+
86151 if ( FeatureColumn == null )
87152 FeatureColumnType = null ;
88153 else if ( ! TrainSchema . TryGetColumnIndex ( FeatureColumn , out int col ) )
89154 throw Host . ExceptSchemaMismatch ( nameof ( FeatureColumn ) , RoleMappedSchema . ColumnRole . Feature . Value , FeatureColumn ) ;
90155 else
91156 FeatureColumnType = TrainSchema . GetColumnType ( col ) ;
92157
93- BindableMapper = ScoreUtils . GetSchemaBindableMapper ( Host , model ) ;
158+ BindableMapper = ScoreUtils . GetSchemaBindableMapper ( Host , Model ) ;
94159 }
95160
96- public ISchema GetOutputSchema ( ISchema inputSchema )
161+ public override ISchema GetOutputSchema ( ISchema inputSchema )
97162 {
98163 Host . CheckValue ( inputSchema , nameof ( inputSchema ) ) ;
99164
@@ -108,8 +173,6 @@ public ISchema GetOutputSchema(ISchema inputSchema)
108173 return Transform ( new EmptyDataView ( Host , inputSchema ) ) . Schema ;
109174 }
110175
111- public abstract IDataView Transform ( IDataView input ) ;
112-
113176 public void Save ( ModelSaveContext ctx )
114177 {
115178 Host . CheckValue ( ctx , nameof ( ctx ) ) ;
@@ -119,26 +182,16 @@ public void Save(ModelSaveContext ctx)
119182
120183 protected virtual void SaveCore ( ModelSaveContext ctx )
121184 {
122- // *** Binary format ***
123- // model: prediction model.
124- // stream: empty data view that contains train schema.
125- // id of string: feature column.
126-
127- ctx . SaveModel ( Model , DirModel ) ;
128- ctx . SaveBinaryStream ( DirTransSchema , writer =>
129- {
130- using ( var ch = Host . Start ( "Saving train schema" ) )
131- {
132- var saver = new BinarySaver ( Host , new BinarySaver . Arguments { Silent = true } ) ;
133- DataSaverUtils . SaveDataView ( ch , saver , new EmptyDataView ( Host , TrainSchema ) , writer . BaseStream ) ;
134- }
135- } ) ;
136-
185+ SaveModel ( ctx ) ;
137186 ctx . SaveStringOrNull ( FeatureColumn ) ;
138187 }
139188 }
140189
141- public sealed class BinaryPredictionTransformer < TModel > : PredictionTransformerBase < TModel >
190+ /// <summary>
191+ /// Base class for the <see cref="ISingleFeaturePredictionTransformer{TModel}"/> working on binary classification tasks.
192+ /// </summary>
193+ /// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
194+ public sealed class BinaryPredictionTransformer < TModel > : SingleFeaturePredictionTransformerBase < TModel >
142195 where TModel : class , IPredictorProducing < float >
143196 {
144197 private readonly BinaryClassifierScorer _scorer ;
@@ -207,7 +260,11 @@ private static VersionInfo GetVersionInfo()
207260 }
208261 }
209262
210- public sealed class MulticlassPredictionTransformer < TModel > : PredictionTransformerBase < TModel >
263+ /// <summary>
264+ /// Base class for the <see cref="ISingleFeaturePredictionTransformer{TModel}"/> working on multi-class classification tasks.
265+ /// </summary>
266+ /// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
267+ public sealed class MulticlassPredictionTransformer < TModel > : SingleFeaturePredictionTransformerBase < TModel >
211268 where TModel : class , IPredictorProducing < VBuffer < float > >
212269 {
213270 private readonly MultiClassClassifierScorer _scorer ;
@@ -268,7 +325,11 @@ private static VersionInfo GetVersionInfo()
268325 }
269326 }
270327
271- public sealed class RegressionPredictionTransformer < TModel > : PredictionTransformerBase < TModel >
328+ /// <summary>
329+ /// Base class for the <see cref="ISingleFeaturePredictionTransformer{TModel}"/> working on regression tasks.
330+ /// </summary>
331+ /// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
332+ public sealed class RegressionPredictionTransformer < TModel > : SingleFeaturePredictionTransformerBase < TModel >
272333 where TModel : class , IPredictorProducing < float >
273334 {
274335 private readonly GenericScorer _scorer ;
@@ -314,7 +375,7 @@ private static VersionInfo GetVersionInfo()
314375 }
315376 }
316377
317- public sealed class RankingPredictionTransformer < TModel > : PredictionTransformerBase < TModel >
378+ public sealed class RankingPredictionTransformer < TModel > : SingleFeaturePredictionTransformerBase < TModel >
318379 where TModel : class , IPredictorProducing < float >
319380 {
320381 private readonly GenericScorer _scorer ;
0 commit comments