Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ public static IEnumerable<BinaryLabelFloatFeatureVectorSample> GenerateBinaryLa
// Initialize an example with a random label and an empty feature vector.
var sample = new BinaryLabelFloatFeatureVectorSample() { Label = rnd.Next() % 2 == 0, Features = new float[_simpleBinaryClassSampleFeatureLength] };
// Fill the feature vector according to the assigned label.
for (int j = 0; j < 10; ++j)
for (int j = 0; j < _simpleBinaryClassSampleFeatureLength; ++j)
{
var value = (float)rnd.NextDouble();
// Positive class gets larger feature value.
Expand All @@ -271,6 +271,58 @@ public static IEnumerable<BinaryLabelFloatFeatureVectorSample> GenerateBinaryLa
return data;
}

public class FfmExample
{
public bool Label;

[VectorType(_simpleBinaryClassSampleFeatureLength)]
Copy link
Member

@abgoswam abgoswam Jan 22, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VectorType [](start = 13, length = 10)

am curious - is this attribute required ? #Closed

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is required.


In reply to: 249994977 [](ancestors = 249994977)

public float[] Field0;

[VectorType(_simpleBinaryClassSampleFeatureLength)]
public float[] Field1;

[VectorType(_simpleBinaryClassSampleFeatureLength)]
public float[] Field2;
}

public static IEnumerable<FfmExample> GenerateFfmSamples(int exampleCount)
{
var rnd = new Random(0);
var data = new List<FfmExample>();
for (int i = 0; i < exampleCount; ++i)
{
// Initialize an example with a random label and an empty feature vector.
var sample = new FfmExample() { Label = rnd.Next() % 2 == 0,
Field0 = new float[_simpleBinaryClassSampleFeatureLength],
Field1 = new float[_simpleBinaryClassSampleFeatureLength],
Field2 = new float[_simpleBinaryClassSampleFeatureLength] };
// Fill feature vector according the assigned label.
for (int j = 0; j < 10; ++j)
Copy link
Member

@abgoswam abgoswam Jan 22, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

10 [](start = 36, length = 2)

_simpleBinaryClassSampleFeatureLength ? #Closed

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch. Fixed.


In reply to: 249993941 [](ancestors = 249993941)

{
var value0 = (float)rnd.NextDouble();
// Positive class gets larger feature value.
if (sample.Label)
value0 += 0.2f;
sample.Field0[j] = value0;

var value1 = (float)rnd.NextDouble();
// Positive class gets smaller feature value.
if (sample.Label)
value1 -= 0.2f;
sample.Field1[j] = value1;

var value2 = (float)rnd.NextDouble();
// Positive class gets larger feature value.
if (sample.Label)
value2 += 0.8f;
sample.Field2[j] = value2;
}

data.Add(sample);
}
return data;
}

/// <summary>
/// feature vector's length in <see cref="MulticlassClassificationExample"/>.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,16 @@ public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase<FieldAwa
internal const string LoadName = "FieldAwareFactorizationMachine";
internal const string ShortName = "ffm";

public sealed class Arguments : LearnerInputBaseWithLabel
public sealed class Arguments : LearnerInputBaseWithWeight
{
/// <summary>
/// Columns to use for features. The i-th string in <see cref="FeatureColumn"/> stores the name of the features
/// from the i-th field.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Columns to use for feature vectors. The i-th specified string denotes the column containing features from the i-th field.",
ShortName = "feat", SortOrder = 2, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public new string[] FeatureColumn = { DefaultColumnNames.Features };

[Argument(ArgumentType.AtMostOnce, HelpText = "Initial learning rate", ShortName = "lr", SortOrder = 1)]
[TlcModule.SweepableFloatParam(0.001f, 1.0f, isLogScale: true)]
public float LearningRate = (float)0.1;
Expand Down Expand Up @@ -122,6 +130,14 @@ public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments arg
{
Initialize(env, args);
Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true);

FeatureColumns = new SchemaShape.Column[args.FeatureColumn.Length];

for (int i = 0; i < args.FeatureColumn.Length; i++)
FeatureColumns[i] = new SchemaShape.Column(args.FeatureColumn[i], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);

LabelColumn = new SchemaShape.Column(args.LabelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
WeightColumn = args.WeightColumn.IsExplicit ? new SchemaShape.Column(args.WeightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) : default;
}

/// <summary>
Expand Down
26 changes: 26 additions & 0 deletions test/Microsoft.ML.StaticPipelineTesting/Training.cs
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,32 @@ public void FfmBinaryClassification()
Assert.InRange(metrics.Auprc, 0, 1);
}

[Fact]
public void FfmBinaryClassificationWithAdvancedArguments()
{
var mlContext = new MLContext(seed: 0);
var data = DatasetUtils.GenerateFfmSamples(500);
var dataView = ComponentCreation.CreateDataView(mlContext, data.ToList());

var ffmArgs = new FieldAwareFactorizationMachineTrainer.Arguments();
// Customized field names
ffmArgs.FeatureColumn = new[]{
nameof(DatasetUtils.FfmExample.Field0),
nameof(DatasetUtils.FfmExample.Field1),
nameof(DatasetUtils.FfmExample.Field2) };
var pipeline = new FieldAwareFactorizationMachineTrainer(mlContext, ffmArgs);

var model = pipeline.Fit(dataView);
var prediction = model.Transform(dataView);

var metrics = mlContext.BinaryClassification.Evaluate(prediction);

// Run a sanity check against a few of the metrics.
Assert.InRange(metrics.Accuracy, 0.9, 1);
Assert.InRange(metrics.Auc, 0.9, 1);
Assert.InRange(metrics.Auprc, 0.9, 1);
}

[Fact]
public void SdcaMulticlass()
{
Expand Down