Skip to content

Commit

Permalink
CLI tool - make validation dataset optional and support for crossvali…
Browse files Browse the repository at this point in the history
…dation in generated code (dotnet#83)

* Added sequential grouping of columns

* reverted the file

* bug fixes, more logic to templates to support cross-validate

* formatting and fix type in consolehelper

* Added logic in templates

* revert settings
  • Loading branch information
srsaggam authored and Dmitry-A committed Aug 22, 2019
1 parent 6b179df commit fd71c47
Show file tree
Hide file tree
Showing 7 changed files with 221 additions and 168 deletions.
2 changes: 1 addition & 1 deletion src/mlnet/CodeGenerator/TrainerGenerators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ internal class LinearSvm : TrainerGeneratorBase
internal override string MethodName => "LinearSupportVectorMachines";

//Class name of the trainer's options type
internal override string OptionsName => "LinearSvm.Options";
internal override string OptionsName => "LinearSvmTrainer.Options";

//The named parameters to the trainer.
internal override IDictionary<string, string> NamedParameters
Expand Down
40 changes: 2 additions & 38 deletions src/mlnet/Commands/CommandDefinitions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,19 @@ public static System.CommandLine.Command New()
{
var newCommand = new System.CommandLine.Command("new", "ML.NET CLI tool for code generation",

handler: CommandHandler.Create</*FileInfo,*/ FileInfo,/* FileInfo,*/ FileInfo, TaskKind, string>((/*FileInfo dataset,*/ FileInfo trainDataset, /*FileInfo validationDataset,*/ FileInfo testDataset, TaskKind mlTask, string labelColumnName) =>
handler: CommandHandler.Create<FileInfo, FileInfo, TaskKind, string>((FileInfo trainDataset, FileInfo testDataset, TaskKind mlTask, string labelColumnName) =>
{
NewCommand.Run(new Options()
{
/*Dataset = dataset,*/
TrainDataset = trainDataset,
/*ValidationDataset = validationDataset,*/
TestDataset = testDataset,
MlTask = mlTask,
LabelName = labelColumnName
});
}))
{
//Dataset(),
TrainDataset(),
//ValidationDataset(),
TestDataset(),
MlTask(),
LabelName(),
Expand All @@ -51,10 +47,6 @@ public static System.CommandLine.Command New()
{
return "Option required : --train-dataset";
}
if (sym.Children["--test-dataset"] == null)
{
return "Option required : --test-dataset";
}
if (sym.Children["--ml-task"] == null)
{
return "Option required : --ml-task";
Expand All @@ -69,21 +61,14 @@ public static System.CommandLine.Command New()

return newCommand;

//Option Dataset() =>
// new Option("--dataset", "Dataset file path.",
// new Argument<FileInfo>().ExistingOnly());

Option TrainDataset() =>
new Option("--train-dataset", "Train dataset file path.",
new Argument<FileInfo>().ExistingOnly());

//Option ValidationDataset() =>
// new Option("--validation-dataset", "Test dataset file path.",
// new Argument<FileInfo>().ExistingOnly());

Option TestDataset() =>
new Option("--test-dataset", "Test dataset file path.",
new Argument<FileInfo>().ExistingOnly());
new Argument<FileInfo>(defaultValue: default(FileInfo)).ExistingOnly());

Option MlTask() =>
new Option("--ml-task", "Type of ML task.",
Expand All @@ -93,27 +78,6 @@ Option LabelName() =>
new Option("--label-column-name", "Name of the label column.",
new Argument<string>());

//Option ColumnSeperator() =>
// new Option("--column-separator", "Column separator in dataset file.",
// new Argument<string>(defaultValue: default(string)));

//Option ExplorationTimeout() =>
// new Option("--exploration-timeout", "Timeout for exploring the best models.",
// new Argument<int>(defaultValue: 10));

//Option Name() =>
// new Option("--name", "Name of the project file.",
// new Argument<string>(defaultValue: "SampleProject"));

//Option ShowOutput() =>
// new Option("--show-output", "Show output on the console",
// new Argument<bool>(defaultValue: true));

//Option LabelIndex() =>
// new Option("--label-column-index", "Index of the label column.",
// new Argument<int>(defaultValue: -1));


}

private static string[] GetMlTaskSuggestions()
Expand Down
18 changes: 11 additions & 7 deletions src/mlnet/Commands/NewCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,21 @@ internal static void Run(Options options)
// For version 0.1, the dataset is required to have a header.
var columnInference = context.Data.InferColumns(options.TrainDataset.FullName, label, true, groupColumns: false);
var textLoader = context.Data.CreateTextLoader(columnInference);
var trainData = textLoader.Read(options.TrainDataset.FullName);

var validationData = textLoader.Read(options.TestDataset.FullName);
Pipeline pipelineToDeconstruct = null;
IDataView trainData = textLoader.Read(options.TrainDataset.FullName);
IDataView validationData = options.TestDataset == null ? null : textLoader.Read(options.TestDataset.FullName);

var result = ExploreModels(options, context, label, trainData, validationData, pipelineToDeconstruct);
pipelineToDeconstruct = result.Item1;
//Explore the models
Pipeline pipeline = null;
var result = ExploreModels(options, context, label, trainData, validationData, pipeline);

//Get the best pipeline
pipeline = result.Item1;
var model = result.Item2;

//Path can be overridden from args
GenerateModel(model, @"./BestModel", "model.zip", context);
RunCodeGen(options, columnInference, pipelineToDeconstruct);
RunCodeGen(options, columnInference, pipeline);
}

private static void GenerateModel(ITransformer model, string ModelPath, string modelName, MLContext mlContext)
Expand Down Expand Up @@ -116,7 +120,7 @@ private static void RunCodeGen(Options options, ColumnInferenceResult columnInfe
MLCodeGen codeGen = new MLCodeGen()
{
Path = options.TrainDataset.FullName,
TestPath = options.TestDataset.FullName,
TestPath = options.TestDataset?.FullName,
Columns = columns,
Transforms = transforms,
HasHeader = columnInference.HasHeader,
Expand Down
Loading

0 comments on commit fd71c47

Please sign in to comment.