Skip to content

Commit

Permalink
add batchsize and arch to imageClassificationSweepableTrainer
Browse files Browse the repository at this point in the history
  • Loading branch information
LittleLittleCloud committed Mar 14, 2023
1 parent f5776b0 commit 0fc2d86
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,21 @@
"name": "FeatureColumnName",
"type": "string",
"default": "Feature"
},
{
"name": "Arch",
"type": "imageClassificationArchType",
"default": "ResnetV250"
},
{
"name": "BatchSize",
"type": "integer",
"default": 10
},
{
"name": "Epoch",
"type": "integer",
"default": 200
}
]
}
25 changes: 23 additions & 2 deletions src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,27 @@
"type": "array",
"items": { "type": "integer" }
},
"imageClassificationArchArray": {
"type": "array",
"items": {
"$ref": "#/definitions/imageClassificationArchType"
}
},
"dnnModelFactoryArray": {
"type": "array",
"items": {
"$ref": "#/definitions/dnnModelFactoryType"
}
},
"imageClassificationArchType": {
"type": "string",
"enum": [
"InceptionV3",
"MobilenetV2",
"ResnetV2101",
"ResnetV250"
]
},
"dnnModelFactoryType": {
"type": "string",
"enum": [
Expand Down Expand Up @@ -54,6 +69,9 @@
{
"$ref": "#/definitions/dnnModelFactoryArray"
},
{
"$ref": "#/definitions/imageClassificationArchArray"
},
{
"$ref": "#/definitions/boolArray"
},
Expand Down Expand Up @@ -177,8 +195,10 @@
"Sentence2ColumnName",
"BatchSize",
"MaxEpochs",
"Epoch",
"Architecture",
"AddKeyValueAnnotationsAsText"
"AddKeyValueAnnotationsAsText",
"Arch"
]
},
"option_type": {
Expand All @@ -195,7 +215,8 @@
"colorsOrder",
"anchor",
"dnnModelFactory",
"bertArchitecture"
"bertArchitecture",
"imageClassificationArchType"
]
}
},
Expand Down
12 changes: 11 additions & 1 deletion src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/Images.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using Microsoft.ML.Vision;

namespace Microsoft.ML.AutoML.CodeGen
{
internal partial class LoadImages
Expand Down Expand Up @@ -40,8 +42,16 @@ internal partial class ImageClassificationMulti
{
public override IEstimator<ITransformer> BuildFromOption(MLContext context, ImageClassificationOption param)
{
var option = new ImageClassificationTrainer.Options
{
Arch = param.Arch,
BatchSize = param.BatchSize,
LabelColumnName = param.LabelColumnName,
FeatureColumnName = param.FeatureColumnName,
ScoreColumnName = param.ScoreColumnName,
};

return context.MulticlassClassification.Trainers.ImageClassification(param.LabelColumnName, param.FeatureColumnName, param.ScoreColumnName);
return context.MulticlassClassification.Trainers.ImageClassification(option);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public void Execute(GeneratorExecutionContext context)
"colorsOrder" => "ColorsOrder",
"dnnModelFactory" => "string",
"bertArchitecture" => "BertArchitecture",
"imageClassificationArchType" => "Microsoft.ML.Vision.ImageClassificationTrainer.Architecture",
_ => throw new ArgumentException("unknown type"),
};
Expand All @@ -72,6 +73,7 @@ public void Execute(GeneratorExecutionContext context)
(_, "ColorBits") => defaultToken.GetValue<string>(),
(_, "ColorsOrder") => defaultToken.GetValue<string>(),
(_, "BertArchitecture") => defaultToken.GetValue<string>(),
(_, "Microsoft.ML.Vision.ImageClassificationTrainer.Architecture") => defaultToken.GetValue<string>(),
(_, _) => throw new ArgumentException("unknown"),
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public virtual string TransformText()
using ResizingKind = Microsoft.ML.Transforms.Image.ImageResizingEstimator.ResizingKind;
using Anchor = Microsoft.ML.Transforms.Image.ImageResizingEstimator.Anchor;
using BertArchitecture = Microsoft.ML.TorchSharp.NasBert.BertArchitecture;
using static Microsoft.ML.Vision.ImageClassificationTrainer.Architecture;
#nullable enable
namespace ");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using ColorBits = Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator.Co
using ResizingKind = Microsoft.ML.Transforms.Image.ImageResizingEstimator.ResizingKind;
using Anchor = Microsoft.ML.Transforms.Image.ImageResizingEstimator.Anchor;
using BertArchitecture = Microsoft.ML.TorchSharp.NasBert.BertArchitecture;

using static Microsoft.ML.Vision.ImageClassificationTrainer.Architecture;
#nullable enable

namespace <#=NameSpace#>
Expand Down

0 comments on commit 0fc2d86

Please sign in to comment.