Skip to content

Commit 053062b

Browse files
authored
Multiple feature columns in FFM (#2205)
* Multiple feature columns in FFM * Update src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs
1 parent efe9b96 commit 053062b

File tree

5 files changed

+136
-3
lines changed

5 files changed

+136
-3
lines changed

src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ namespace Microsoft.ML.Trainers.FastTree.Internal
1818
{
1919
public class TreeEnsemble
2020
{
21+
/// <summary>
22+
/// String appended to the text representation of <see cref="TreeEnsemble"/>. This is mainly used in <see cref="ToTreeEnsembleIni"/>.
23+
/// </summary>
2124
private readonly string _firstInputInitializationContent;
2225
private readonly List<RegressionTree> _trees;
2326

src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ public static IEnumerable<BinaryLabelFloatFeatureVectorSample> GenerateBinaryLa
257257
// Initialize an example with a random label and an empty feature vector.
258258
var sample = new BinaryLabelFloatFeatureVectorSample() { Label = rnd.Next() % 2 == 0, Features = new float[_simpleBinaryClassSampleFeatureLength] };
259259
// Fill the feature vector according to the assigned label.
260-
for (int j = 0; j < 10; ++j)
260+
for (int j = 0; j < _simpleBinaryClassSampleFeatureLength; ++j)
261261
{
262262
var value = (float)rnd.NextDouble();
263263
// Positive class gets larger feature value.
@@ -271,6 +271,58 @@ public static IEnumerable<BinaryLabelFloatFeatureVectorSample> GenerateBinaryLa
271271
return data;
272272
}
273273

274+
/// <summary>
/// A labeled example with three separate feature columns, used as input for
/// field-aware factorization machine samples. Each feature column is declared as a
/// vector of length <see cref="_simpleBinaryClassSampleFeatureLength"/>.
/// </summary>
public class FfmExample
{
    /// <summary>Binary label of the example.</summary>
    public bool Label;

    /// <summary>Feature vector of the first field.</summary>
    [VectorType(_simpleBinaryClassSampleFeatureLength)]
    public float[] Field0;

    /// <summary>Feature vector of the second field.</summary>
    [VectorType(_simpleBinaryClassSampleFeatureLength)]
    public float[] Field1;

    /// <summary>Feature vector of the third field.</summary>
    [VectorType(_simpleBinaryClassSampleFeatureLength)]
    public float[] Field2;
}
287+
288+
/// <summary>
/// Generates a reproducible list of <see cref="FfmExample"/> (the random generator is seeded with 0).
/// Each example gets a random binary label and three feature vectors whose values are shifted
/// depending on that label.
/// </summary>
/// <param name="exampleCount">Number of examples to generate.</param>
/// <returns>The generated examples.</returns>
public static IEnumerable<FfmExample> GenerateFfmSamples(int exampleCount)
{
    var rnd = new Random(0);
    var data = new List<FfmExample>();
    for (int i = 0; i < exampleCount; ++i)
    {
        // Initialize an example with a random label and one empty feature vector per field.
        var sample = new FfmExample()
        {
            Label = rnd.Next() % 2 == 0,
            Field0 = new float[_simpleBinaryClassSampleFeatureLength],
            Field1 = new float[_simpleBinaryClassSampleFeatureLength],
            Field2 = new float[_simpleBinaryClassSampleFeatureLength]
        };

        // Fill every slot of the feature vectors according to the assigned label. The bound is the
        // declared vector length rather than a literal so it stays in sync with the allocations above.
        // Values are drawn in the order Field0, Field1, Field2 for each index.
        for (int j = 0; j < _simpleBinaryClassSampleFeatureLength; ++j)
        {
            var value0 = (float)rnd.NextDouble();
            // Positive class gets larger feature value.
            if (sample.Label)
                value0 += 0.2f;
            sample.Field0[j] = value0;

            var value1 = (float)rnd.NextDouble();
            // Positive class gets smaller feature value.
            if (sample.Label)
                value1 -= 0.2f;
            sample.Field1[j] = value1;

            var value2 = (float)rnd.NextDouble();
            // Positive class gets larger feature value.
            if (sample.Label)
                value2 += 0.8f;
            sample.Field2[j] = value2;
        }

        data.Add(sample);
    }
    return data;
}
325+
274326
/// <summary>
275327
/// feature vector's length in <see cref="MulticlassClassificationExample"/>.
276328
/// </summary>

src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase<FieldAwa
4040
internal const string LoadName = "FieldAwareFactorizationMachine";
4141
internal const string ShortName = "ffm";
4242

43-
public sealed class Arguments : LearnerInputBaseWithLabel
43+
public sealed class Arguments : LearnerInputBaseWithWeight
4444
{
4545
[Argument(ArgumentType.AtMostOnce, HelpText = "Initial learning rate", ShortName = "lr", SortOrder = 1)]
4646
[TlcModule.SweepableFloatParam(0.001f, 1.0f, isLogScale: true)]
@@ -65,6 +65,15 @@ public sealed class Arguments : LearnerInputBaseWithLabel
6565
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to normalize the input vectors so that the concatenation of all fields' feature vectors is unit-length", ShortName = "norm", SortOrder = 6)]
6666
public bool Norm = true;
6767

68+
/// <summary>
69+
/// Extra feature column names. The column named <see cref="LearnerInputBase.FeatureColumn"/> stores features from the first field.
70+
/// The i-th string in <see cref="ExtraFeatureColumns"/> stores the name of the (i+1)-th field's feature column.
71+
/// </summary>
72+
[Argument(ArgumentType.Multiple, HelpText = "Extra columns to use for feature vectors. The i-th specified string denotes the column containing features form the (i+1)-th field." +
73+
" Note that the first field is specified by \"feat\" instead of \"exfeat\".",
74+
ShortName = "exfeat", SortOrder = 7)]
75+
public string[] ExtraFeatureColumns;
76+
6877
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to shuffle for each training iteration", ShortName = "shuf", SortOrder = 90)]
6978
public bool Shuffle = true;
7079

@@ -122,13 +131,26 @@ public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments arg
122131
{
123132
Initialize(env, args);
124133
Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true);
134+
135+
// There can be multiple feature columns in FFM, jointly specified by args.FeatureColumn and args.ExtraFeatureColumns.
// args.ExtraFeatureColumns is optional and null unless the caller sets it, so a null value is treated as
// "no extra fields" before it is used to size the array.
int extraFeatureColumnCount = args.ExtraFeatureColumns?.Length ?? 0;
FeatureColumns = new SchemaShape.Column[1 + extraFeatureColumnCount];

// Treat the default feature column as the 1st field.
FeatureColumns[0] = new SchemaShape.Column(args.FeatureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);

// Add 2nd, 3rd, and other fields from a FFM-specific argument, args.ExtraFeatureColumns.
for (int i = 0; i < extraFeatureColumnCount; i++)
    FeatureColumns[i + 1] = new SchemaShape.Column(args.ExtraFeatureColumns[i], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
144+
145+
LabelColumn = new SchemaShape.Column(args.LabelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
146+
WeightColumn = args.WeightColumn.IsExplicit ? new SchemaShape.Column(args.WeightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) : default;
125147
}
126148

127149
/// <summary>
128150
/// Initializing a new instance of <see cref="FieldAwareFactorizationMachineTrainer"/>.
129151
/// </summary>
130152
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
131-
/// <param name="featureColumns">The name of column hosting the features.</param>
153+
/// <param name="featureColumns">The names of the columns hosting the features. The i-th element is the name of the feature column of the i-th field.</param>
132154
/// <param name="labelColumn">The name of the label column.</param>
133155
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
134156
/// <param name="weights">The name of the optional weights' column.</param>

test/BaselineOutput/Common/EntryPoints/core_manifest.json

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10222,6 +10222,18 @@
1022210222
"IsLogScale": true
1022310223
}
1022410224
},
10225+
{
10226+
"Name": "WeightColumn",
10227+
"Type": "String",
10228+
"Desc": "Column to use for example weight",
10229+
"Aliases": [
10230+
"weight"
10231+
],
10232+
"Required": false,
10233+
"SortOrder": 4.0,
10234+
"IsNullable": false,
10235+
"Default": "Weight"
10236+
},
1022510237
{
1022610238
"Name": "LambdaLatent",
1022710239
"Type": "Float",
@@ -10292,6 +10304,21 @@
1029210304
"IsNullable": false,
1029310305
"Default": "Auto"
1029410306
},
10307+
{
10308+
"Name": "ExtraFeatureColumns",
10309+
"Type": {
10310+
"Kind": "Array",
10311+
"ItemType": "String"
10312+
},
10313+
"Desc": "Extra columns to use for feature vectors. The i-th specified string denotes the column containing features form the (i+1)-th field. Note that the first field is specified by \"feat\" instead of \"exfeat\".",
10314+
"Aliases": [
10315+
"exfeat"
10316+
],
10317+
"Required": false,
10318+
"SortOrder": 7.0,
10319+
"IsNullable": false,
10320+
"Default": null
10321+
},
1029510322
{
1029610323
"Name": "Shuffle",
1029710324
"Type": "Bool",
@@ -10342,6 +10369,7 @@
1034210369
}
1034310370
],
1034410371
"InputKind": [
10372+
"ITrainerInputWithWeight",
1034510373
"ITrainerInputWithLabel",
1034610374
"ITrainerInput"
1034710375
],

test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,43 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System.Linq;
56
using Microsoft.ML.Data;
67
using Microsoft.ML.FactorizationMachine;
78
using Microsoft.ML.RunTests;
9+
using Microsoft.ML.SamplesUtils;
810
using Xunit;
911

1012
namespace Microsoft.ML.Tests.TrainerEstimators
1113
{
1214
public partial class TrainerEstimators : TestDataPipeBase
1315
{
16+
/// <summary>
/// Trains a field-aware factorization machine on three feature columns configured through
/// <see cref="FieldAwareFactorizationMachineTrainer.Arguments"/> and sanity-checks the
/// resulting binary classification metrics on the training data.
/// </summary>
[Fact]
public void FfmBinaryClassificationWithAdvancedArguments()
{
    var mlContext = new MLContext(seed: 0);
    var examples = DatasetUtils.GenerateFfmSamples(500).ToList();
    var trainingData = ComponentCreation.CreateDataView(mlContext, examples);

    // Customize the field names: the default feature column carries the first field,
    // the extra feature columns carry the second and third fields.
    var ffmArgs = new FieldAwareFactorizationMachineTrainer.Arguments
    {
        FeatureColumn = nameof(DatasetUtils.FfmExample.Field0),
        ExtraFeatureColumns = new[]
        {
            nameof(DatasetUtils.FfmExample.Field1),
            nameof(DatasetUtils.FfmExample.Field2)
        }
    };

    var trainer = new FieldAwareFactorizationMachineTrainer(mlContext, ffmArgs);

    // Fit on the generated data and score that same data.
    var trainedModel = trainer.Fit(trainingData);
    var scoredData = trainedModel.Transform(trainingData);

    var metrics = mlContext.BinaryClassification.Evaluate(scoredData);

    // Run a sanity check against a few of the metrics.
    Assert.InRange(metrics.Accuracy, 0.9, 1);
    Assert.InRange(metrics.Auc, 0.9, 1);
    Assert.InRange(metrics.Auprc, 0.9, 1);
}
41+
1442
[Fact]
1543
public void FieldAwareFactorizationMachine_Estimator()
1644
{

0 commit comments

Comments
 (0)