Skip to content

Commit

Permalink
Remove duplicate value-to-key mapping transform for multiclass string…
Browse files Browse the repository at this point in the history
… labels (dotnet#283)
  • Loading branch information
daholste authored and Dmitry-A committed Aug 22, 2019
1 parent 86f86c1 commit 27e2e57
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 38 deletions.
19 changes: 1 addition & 18 deletions src/Microsoft.ML.Auto/PipelineSuggesters/PipelineSuggester.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ public static SuggestedPipeline GetNextInferredPipeline(MLContext context,
{
var availableTrainers = RecipeInference.AllowedTrainers(context, task,
ColumnInformationUtil.BuildColumnInfo(columns), trainerWhitelist);
var transforms = CalculateTransforms(context, columns, task);
//var transforms = TransformInferenceApi.InferTransforms(context, columns, task);
var transforms = TransformInferenceApi.InferTransforms(context, task, columns);

// if we haven't run all pipelines once
if (history.Count() < availableTrainers.Count())
Expand Down Expand Up @@ -213,21 +212,5 @@ private static bool SampleHyperparameters(MLContext context, SuggestedTrainer tr

return true;
}

/// <summary>
/// Builds the list of suggested transforms for the given columns: the transforms
/// returned by <c>TransformInferenceApi.InferTransforms</c>, plus, for multiclass
/// classification tasks, one extra value-to-key mapping transform on the label column.
/// </summary>
/// <param name="context">ML context forwarded to the inference API and to the transform factory.</param>
/// <param name="columns">Column tuples; the third item is the column purpose and the first is the column name.</param>
/// <param name="task">Task kind; only <c>TaskKind.MulticlassClassification</c> triggers the extra label transform.</param>
/// <returns>The inferred transforms, followed by the label value-to-key transform when the task is multiclass.</returns>
/// <exception cref="InvalidOperationException">
/// Thrown by <c>Enumerable.First</c> when the task is multiclass and no column has purpose <c>Label</c>.
/// </exception>
private static IEnumerable<SuggestedTransform> CalculateTransforms(
    MLContext context,
    (string, DataViewType, ColumnPurpose, ColumnDimensions)[] columns,
    TaskKind task)
{
    List<SuggestedTransform> inferred = TransformInferenceApi.InferTransforms(context, columns).ToList();

    // Nothing extra is needed unless this is a multiclass task.
    if (task != TaskKind.MulticlassClassification)
    {
        return inferred;
    }

    // Work-around for the ML.NET bug tracked by
    // https://github.com/dotnet/machinelearning/issues/1969:
    // append a value-to-key mapping over the label column, passing the label
    // column name for both column-name arguments.
    string labelName = columns.First(column => column.Item3 == ColumnPurpose.Label).Item1;
    inferred.Add(ValueToKeyMappingExtension.CreateSuggestedTransform(context, labelName, labelName));

    return inferred;
}
}
}
33 changes: 19 additions & 14 deletions src/Microsoft.ML.Auto/TransformInference/TransformInference.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,12 @@ public bool Equals(ColumnRoutingStructure obj)

internal interface ITransformInferenceExpert
{
IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns);
IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns, TaskKind task);
}

public abstract class TransformInferenceExpertBase : ITransformInferenceExpert
{
public abstract IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns);
public abstract IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns, TaskKind task);

protected readonly MLContext Context;

Expand All @@ -137,8 +137,8 @@ private static IEnumerable<ITransformInferenceExpert> GetExperts(MLContext conte
// The experts work independently of each other; the sequence is irrelevant
// (it only determines the sequence of resulting transforms).

// For text labels, convert to categories.
yield return new Experts.AutoLabel(context);
// For multiclass tasks, convert label column to key
yield return new Experts.Label(context);

// For boolean columns use convert transform
yield return new Experts.Boolean(context);
Expand All @@ -155,21 +155,26 @@ private static IEnumerable<ITransformInferenceExpert> GetExperts(MLContext conte

internal static class Experts
{
internal sealed class AutoLabel : TransformInferenceExpertBase
internal sealed class Label : TransformInferenceExpertBase
{
public AutoLabel(MLContext context) : base(context)
public Label(MLContext context) : base(context)
{
}

public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns)
public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns, TaskKind task)
{
if (task != TaskKind.MulticlassClassification)
{
yield break;
}

var lastLabelColId = Array.FindLastIndex(columns, x => x.Purpose == ColumnPurpose.Label);
if (lastLabelColId < 0)
yield break;

var col = columns[lastLabelColId];

if (col.Type.IsText())
if (!col.Type.IsKey())
{
yield return ValueToKeyMappingExtension.CreateSuggestedTransform(Context, col.ColumnName, col.ColumnName);
}
Expand All @@ -182,7 +187,7 @@ public Categorical(MLContext context) : base(context)
{
}

public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns)
public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns, TaskKind task)
{
bool foundCat = false;
bool foundCatHash = false;
Expand Down Expand Up @@ -232,7 +237,7 @@ public Boolean(MLContext context) : base(context)
{
}

public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns)
public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns, TaskKind task)
{
var newColumns = new List<string>();

Expand Down Expand Up @@ -260,7 +265,7 @@ public Text(MLContext context) : base(context)
{
}

public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns)
public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns, TaskKind task)
{
var featureCols = new List<string>();

Expand All @@ -286,7 +291,7 @@ public NumericMissing(MLContext context) : base(context)
{
}

public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns)
public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns, TaskKind task)
{
var columnsWithMissing = new List<string>();
foreach (var column in columns)
Expand All @@ -313,7 +318,7 @@ public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] colum
/// <summary>
/// Automatically infer transforms for the data view
/// </summary>
public static SuggestedTransform[] InferTransforms(MLContext context, (string, DataViewType, ColumnPurpose, ColumnDimensions)[] columns)
public static SuggestedTransform[] InferTransforms(MLContext context, TaskKind task, (string, DataViewType, ColumnPurpose, ColumnDimensions)[] columns)
{
var intermediateCols = columns.Where(c => c.Item3 != ColumnPurpose.Ignore)
.Select(c => new IntermediateColumn(c.Item1, c.Item2, c.Item3, c.Item4))
Expand All @@ -322,7 +327,7 @@ public static SuggestedTransform[] InferTransforms(MLContext context, (string, D
var suggestedTransforms = new List<SuggestedTransform>();
foreach (var expert in GetExperts(context))
{
SuggestedTransform[] suggestions = expert.Apply(intermediateCols).ToArray();
SuggestedTransform[] suggestions = expert.Apply(intermediateCols, task).ToArray();
suggestedTransforms.AddRange(suggestions);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ namespace Microsoft.ML.Auto
{
internal static class TransformInferenceApi
{
public static IEnumerable<SuggestedTransform> InferTransforms(MLContext context, (string, DataViewType, ColumnPurpose, ColumnDimensions)[] columns)
public static IEnumerable<SuggestedTransform> InferTransforms(MLContext context, TaskKind task, (string, DataViewType, ColumnPurpose, ColumnDimensions)[] columns)
{
return TransformInference.InferTransforms(context, columns);
return TransformInference.InferTransforms(context, task, columns);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ public static bool IsVector(this DataViewType columnType)
return columnType is VectorType;
}

/// <summary>
/// Returns true when <paramref name="columnType"/> is a <c>KeyType</c>; false otherwise
/// (including when it is null).
/// </summary>
public static bool IsKey(this DataViewType columnType) => columnType is KeyType;

public static bool IsKnownSizeVector(this DataViewType columnType)
{
var vector = columnType as VectorType;
Expand Down
9 changes: 5 additions & 4 deletions src/Test/TransformInferenceTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ public void TransformInferenceCustomLabelCol()
}

[TestMethod]
public void TransformInferenceCustomTextLabelCol()
public void TransformInferenceCustomTextLabelColMulticlass()
{
TransformInferenceTestCore(new (string, DataViewType, ColumnPurpose, ColumnDimensions)[]
{
Expand All @@ -663,7 +663,7 @@ public void TransformInferenceCustomTextLabelCol()
],
""Properties"": {}
}
]");
]", TaskKind.MulticlassClassification);
}

[TestMethod]
Expand Down Expand Up @@ -727,9 +727,10 @@ public void TransformInferenceMissingNameCollision()

private static void TransformInferenceTestCore(
(string name, DataViewType type, ColumnPurpose purpose, ColumnDimensions dimensions)[] columns,
string expectedJson)
string expectedJson,
TaskKind task = TaskKind.BinaryClassification)
{
var transforms = TransformInferenceApi.InferTransforms(new MLContext(), columns);
var transforms = TransformInferenceApi.InferTransforms(new MLContext(), task, columns);
TestApplyTransformsToRealDataView(transforms, columns);
var pipelineNodes = transforms.Select(t => t.PipelineNode);
Util.AssertObjectMatchesJson(expectedJson, pipelineNodes);
Expand Down

0 comments on commit 27e2e57

Please sign in to comment.