diff --git a/src/Microsoft.ML.AutoML/CodeGen/image_classification_search_space.json b/src/Microsoft.ML.AutoML/CodeGen/image_classification_search_space.json index 0359ff78d0..fec9b9016e 100644 --- a/src/Microsoft.ML.AutoML/CodeGen/image_classification_search_space.json +++ b/src/Microsoft.ML.AutoML/CodeGen/image_classification_search_space.json @@ -16,6 +16,21 @@ "name": "FeatureColumnName", "type": "string", "default": "Feature" + }, + { + "name": "Arch", + "type": "imageClassificationArchType", + "default": "ResnetV250" + }, + { + "name": "BatchSize", + "type": "integer", + "default": 10 + }, + { + "name": "Epoch", + "type": "integer", + "default": 200 } ] } diff --git a/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json b/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json index 935ff61643..0fa6fa9590 100644 --- a/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json +++ b/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json @@ -10,12 +10,27 @@ "type": "array", "items": { "type": "integer" } }, + "imageClassificationArchArray": { + "type": "array", + "items": { + "$ref": "#/definitions/imageClassificationArchType" + } + }, "dnnModelFactoryArray": { "type": "array", "items": { "$ref": "#/definitions/dnnModelFactoryType" } }, + "imageClassificationArchType": { + "type": "string", + "enum": [ + "InceptionV3", + "MobilenetV2", + "ResnetV2101", + "ResnetV250" + ] + }, "dnnModelFactoryType": { "type": "string", "enum": [ @@ -54,6 +69,9 @@ { "$ref": "#/definitions/dnnModelFactoryArray" }, + { + "$ref": "#/definitions/imageClassificationArchArray" + }, { "$ref": "#/definitions/boolArray" }, @@ -177,8 +195,10 @@ "Sentence2ColumnName", "BatchSize", "MaxEpochs", + "Epoch", "Architecture", - "AddKeyValueAnnotationsAsText" + "AddKeyValueAnnotationsAsText", + "Arch" ] }, "option_type": { @@ -195,7 +215,8 @@ "colorsOrder", "anchor", "dnnModelFactory", - "bertArchitecture" + "bertArchitecture", + "imageClassificationArchType" ] } }, diff --git a/src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/Images.cs b/src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/Images.cs index 8998151e5e..ec2281f58e 100644 --- a/src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/Images.cs +++ b/src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/Images.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. using System; +using Microsoft.ML.Vision; + namespace Microsoft.ML.AutoML.CodeGen { internal partial class LoadImages @@ -40,8 +42,16 @@ internal partial class ImageClassificationMulti { public override IEstimator BuildFromOption(MLContext context, ImageClassificationOption param) { + var option = new ImageClassificationTrainer.Options + { + Arch = param.Arch, + BatchSize = param.BatchSize, + LabelColumnName = param.LabelColumnName, + FeatureColumnName = param.FeatureColumnName, + ScoreColumnName = param.ScoreColumnName, + }; - return context.MulticlassClassification.Trainers.ImageClassification(param.LabelColumnName, param.FeatureColumnName, param.ScoreColumnName); + return context.MulticlassClassification.Trainers.ImageClassification(option); } } diff --git a/tools-local/Microsoft.ML.AutoML.SourceGenerator/SearchSpaceGenerator.cs b/tools-local/Microsoft.ML.AutoML.SourceGenerator/SearchSpaceGenerator.cs index c199a6faf0..af19e405c2 100644 --- a/tools-local/Microsoft.ML.AutoML.SourceGenerator/SearchSpaceGenerator.cs +++ b/tools-local/Microsoft.ML.AutoML.SourceGenerator/SearchSpaceGenerator.cs @@ -55,6 +55,7 @@ public void Execute(GeneratorExecutionContext context) "colorsOrder" => "ColorsOrder", "dnnModelFactory" => "string", "bertArchitecture" => "BertArchitecture", + "imageClassificationArchType" => "Microsoft.ML.Vision.ImageClassificationTrainer.Architecture", _ => throw new ArgumentException("unknown type"), }; @@ -72,6 +73,7 @@ public void Execute(GeneratorExecutionContext context) (_, "ColorBits") => defaultToken.GetValue(), (_, "ColorsOrder") => defaultToken.GetValue(), (_, "BertArchitecture") => defaultToken.GetValue(), + (_, "Microsoft.ML.Vision.ImageClassificationTrainer.Architecture") => defaultToken.GetValue(), (_, _) => throw new ArgumentException("unknown"), }; diff --git a/tools-local/Microsoft.ML.AutoML.SourceGenerator/Template/SearchSpace.cs b/tools-local/Microsoft.ML.AutoML.SourceGenerator/Template/SearchSpace.cs index 47a6a89a94..be05d616f2 100644 --- a/tools-local/Microsoft.ML.AutoML.SourceGenerator/Template/SearchSpace.cs +++ b/tools-local/Microsoft.ML.AutoML.SourceGenerator/Template/SearchSpace.cs @@ -33,7 +33,7 @@ public virtual string TransformText() using ResizingKind = Microsoft.ML.Transforms.Image.ImageResizingEstimator.ResizingKind; using Anchor = Microsoft.ML.Transforms.Image.ImageResizingEstimator.Anchor; using BertArchitecture = Microsoft.ML.TorchSharp.NasBert.BertArchitecture; - +using static Microsoft.ML.Vision.ImageClassificationTrainer.Architecture; #nullable enable namespace "); diff --git a/tools-local/Microsoft.ML.AutoML.SourceGenerator/Template/SearchSpace.tt b/tools-local/Microsoft.ML.AutoML.SourceGenerator/Template/SearchSpace.tt index 021508c3f8..b64990f3f2 100644 --- a/tools-local/Microsoft.ML.AutoML.SourceGenerator/Template/SearchSpace.tt +++ b/tools-local/Microsoft.ML.AutoML.SourceGenerator/Template/SearchSpace.tt @@ -11,7 +11,7 @@ using ColorBits = Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator.Co using ResizingKind = Microsoft.ML.Transforms.Image.ImageResizingEstimator.ResizingKind; using Anchor = Microsoft.ML.Transforms.Image.ImageResizingEstimator.Anchor; using BertArchitecture = Microsoft.ML.TorchSharp.NasBert.BertArchitecture; - +using static Microsoft.ML.Vision.ImageClassificationTrainer.Architecture; #nullable enable namespace <#=NameSpace#>