Skip to content

Commit

Permalink
add trainer extension tests, & misc fixes (dotnet#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
daholste authored Jan 24, 2019
1 parent f7e6376 commit 3d3567c
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 25 deletions.
6 changes: 0 additions & 6 deletions src/AutoML/TrainerExtensions/MultiTrainerExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<Sweepab

internal class LightGbmMultiExtension : ITrainerExtension
{
private static readonly ITrainerExtension _binaryLearnerCatalogItem = new LightGbmBinaryExtension();

public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
{
return SweepableParams.BuildLightGbmParams();
Expand Down Expand Up @@ -80,8 +78,6 @@ public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<Sweepab

internal class SdcaMultiExtension : ITrainerExtension
{
private static readonly ITrainerExtension _binaryLearnerCatalogItem = new SdcaBinaryExtension();

public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
{
return SweepableParams.BuildSdcaParams();
Expand Down Expand Up @@ -161,8 +157,6 @@ public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<Sweepab

internal class LogisticRegressionMultiExtension : ITrainerExtension
{
private static readonly ITrainerExtension _binaryLearnerCatalogItem = new LogisticRegressionBinaryExtension();

public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
{
return SweepableParams.BuildLogisticRegressionParams();
Expand Down
2 changes: 1 addition & 1 deletion src/AutoML/TrainerExtensions/TrainerExtensionCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public static IEnumerable<ITrainerExtension> GetTrainers(TaskKind task, int maxI
{
return GetBinaryLearners(maxIterations);
}
else if (task == TaskKind.BinaryClassification)
else if (task == TaskKind.MulticlassClassification)
{
return GetMultiLearners(maxIterations);
}
Expand Down
8 changes: 0 additions & 8 deletions src/AutoML/Utils/SweepableParamAttributes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,6 @@ public override void SetUsingValueText(string valueText)
RawValue = i;
}

public int IndexOf(object option)
{
for (int i = 0; i < Options.Length; i++)
if (option == Options[i])
return i;
return -1;
}

private static string TranslateOption(object o)
{
switch (o)
Expand Down
36 changes: 26 additions & 10 deletions src/Test/SweeperTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,35 @@ namespace Microsoft.ML.Auto.Test
public class SweeperTests
{
[TestMethod]
public void Smac3ParamsTest()
public void SmacQuickRunTest()
{
var numInitialPopulation = 10;

var floatValueGenerator = new FloatValueGenerator(new FloatParamArguments() { Name = "float", Min = 1, Max = 1000 });
var floatLogValueGenerator = new FloatValueGenerator(new FloatParamArguments() { Name = "floatLog", Min = 1, Max = 1000, LogBase = true });
var longValueGenerator = new LongValueGenerator(new LongParamArguments() { Name = "long", Min = 1, Max = 1000 });
var longLogValueGenerator = new LongValueGenerator(new LongParamArguments() { Name = "longLog", Min = 1, Max = 1000, LogBase = true });
var discreteValueGeneator = new DiscreteValueGenerator(new DiscreteParamArguments() { Name = "discrete", Values = new[] { "200", "400", "600", "800" } });

var sweeper = new SmacSweeper(new SmacSweeper.Arguments()
{
SweptParameters = new IValueGenerator[] {
new FloatValueGenerator(new FloatParamArguments() { Name = "x1", Min = 1, Max = 1000}),
new LongValueGenerator(new LongParamArguments() { Name = "x2", Min = 1, Max = 1000}),
new DiscreteValueGenerator(new DiscreteParamArguments() { Name = "x3", Values = new[] { "200", "400", "600", "800" } }),
floatValueGenerator,
floatLogValueGenerator,
longValueGenerator,
longLogValueGenerator,
discreteValueGeneator
},
NumberInitialPopulation = numInitialPopulation
});

// sanity check grid
Assert.IsNotNull(floatValueGenerator[0].ValueText);
Assert.IsNotNull(floatLogValueGenerator[0].ValueText);
Assert.IsNotNull(longValueGenerator[0].ValueText);
Assert.IsNotNull(longLogValueGenerator[0].ValueText);
Assert.IsNotNull(discreteValueGeneator[0].ValueText);

List<RunResult> results = new List<RunResult>();

RunResult bestResult = null;
Expand All @@ -31,12 +46,13 @@ public void Smac3ParamsTest()

foreach (ParameterSet p in pars)
{
float x1 = (p["x1"] as FloatParameterValue).Value;
float x2 = (p["x2"] as LongParameterValue).Value;
float x3 = float.Parse(p["x3"].ValueText);
float x1 = float.Parse(p["float"].ValueText);
float x2 = float.Parse(p["floatLog"].ValueText);
long x3 = long.Parse(p["long"].ValueText);
long x4 = long.Parse(p["longLog"].ValueText);
int x5 = int.Parse(p["discrete"].ValueText);

double metric = -200 * (Math.Abs(100 - x1) +
Math.Abs(300 - x2) + Math.Abs(500 - x3));
double metric = x1 + x2 + x3 + x4 + x5;

RunResult result = new RunResult(p, metric, true);
if (bestResult == null || bestResult.MetricValue < metric)
Expand All @@ -53,7 +69,7 @@ public void Smac3ParamsTest()
Console.WriteLine($"Best: {bestResult.MetricValue}");

Assert.IsNotNull(bestResult);
Assert.IsTrue(bestResult.MetricValue != 0);
Assert.IsTrue(bestResult.MetricValue > 0);
}


Expand Down
46 changes: 46 additions & 0 deletions src/Test/TrainerExtensionsTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
using System;
using System.Linq;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace Microsoft.ML.Auto.Test
{
[TestClass]
public class TrainerExtensionsTests
{
[TestMethod]
public void TrainerExtensionInstanceTests()
{
var context = new MLContext();
var trainerNames = Enum.GetValues(typeof(TrainerName)).Cast<TrainerName>();
foreach(var trainerName in trainerNames)
{
var extension = TrainerExtensionCatalog.GetTrainerExtension(trainerName);
var instance = extension.CreateInstance(context, null);
Assert.IsNotNull(instance);
var sweepParams = extension.GetHyperparamSweepRanges();
Assert.IsNotNull(sweepParams);
}
}

[TestMethod]
public void GetTrainersByMaxIterations()
{
var tasks = new TaskKind[] { TaskKind.BinaryClassification,
TaskKind.MulticlassClassification, TaskKind.Regression };

foreach(var task in tasks)
{
var trainerSet10 = TrainerExtensionCatalog.GetTrainers(task, 10);
var trainerSet50 = TrainerExtensionCatalog.GetTrainers(task, 50);
var trainerSet100 = TrainerExtensionCatalog.GetTrainers(task, 100);

Assert.IsNotNull(trainerSet10);
Assert.IsNotNull(trainerSet50);
Assert.IsNotNull(trainerSet100);

Assert.IsTrue(trainerSet10.Count() < trainerSet50.Count());
Assert.IsTrue(trainerSet50.Count() < trainerSet100.Count());
}
}
}
}
11 changes: 11 additions & 0 deletions src/Test/UserInputValidationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,17 @@ public void ValidateAutoFitArgsPurposeOverrideDuplicateCol()
});
}

[TestMethod]
public void ValidateAutoFitArgsPurposeOverrideSuccess()
{
UserInputValidationUtil.ValidateAutoFitArgs(DatasetUtil.GetUciAdultDataView(),
DatasetUtil.UciAdultLabel, DatasetUtil.GetUciAdultDataView(),
null, new List<(string, ColumnPurpose)>()
{
("Workclass", ColumnPurpose.CategoricalFeature)
});
}

[TestMethod]
[ExpectedException(typeof(ArgumentException))]
public void ValidateAutoFitArgsTrainValidColCountMismatch()
Expand Down

0 comments on commit 3d3567c

Please sign in to comment.