|
| 1 | +// <SnippetAddUsings> |
| 2 | +using System; |
| 3 | +using System.Collections.Generic; |
| 4 | +using System.IO; |
| 5 | +using System.Linq; |
| 6 | +using Microsoft.Data.DataView; |
| 7 | +using Microsoft.ML; |
| 8 | +using Microsoft.ML.Core.Data; |
| 9 | +using Microsoft.ML.Data; |
| 10 | +using Microsoft.ML.ImageAnalytics; |
| 11 | +using Microsoft.ML.Trainers; |
| 12 | +// </SnippetAddUsings> |
| 13 | + |
| 14 | +namespace TransferLearningTF |
| 15 | +{ |
| 16 | + class Program |
| 17 | + { |
| 18 | + // <SnippetDeclareGlobalVariables> |
| 19 | + static readonly string _assetsPath = Path.Combine(Environment.CurrentDirectory, "assets"); |
| 20 | + static readonly string _trainTagsTsv = Path.Combine(_assetsPath, "inputs-train", "data", "tags.tsv"); |
| 21 | + static readonly string _predictTagsTsv = Path.Combine(_assetsPath, "inputs-predict", "data", "tags.tsv"); |
| 22 | + static readonly string _trainImagesFolder = Path.Combine(_assetsPath, "inputs-train", "data"); |
| 23 | + static readonly string _predictImagesFolder = Path.Combine(_assetsPath, "inputs-predict", "data"); |
| 24 | + static readonly string _inceptionPb = Path.Combine(_assetsPath, "inputs-train", "inception", "tensorflow_inception_graph.pb"); |
| 25 | + static readonly string _inputImageClassifierZip = Path.Combine(_assetsPath, "inputs-predict", "imageClassifier.zip"); |
| 26 | + static readonly string _outputImageClassifierZip = Path.Combine(_assetsPath, "outputs", "imageClassifier.zip"); |
| 27 | + private static string LabelTokey = nameof(LabelTokey); |
| 28 | + private static string ImageReal = nameof(ImageReal); |
| 29 | + private static string PredictedLabelValue = nameof(PredictedLabelValue); |
| 30 | + // </SnippetDeclareGlobalVariables> |
| 31 | + |
| 32 | + static void Main(string[] args) |
| 33 | + { |
| 34 | + // Create MLContext to be shared across the model creation workflow objects |
| 35 | + // <SnippetCreateMLContext> |
| 36 | + MLContext mlContext = new MLContext(seed:1); |
| 37 | + // </SnippetCreateMLContext> |
| 38 | + |
| 39 | + // <SnippetCallReuseAndTuneInceptionModel> |
| 40 | + ReuseAndTuneInceptionModel(mlContext, _trainTagsTsv, _trainImagesFolder, _inceptionPb, _outputImageClassifierZip); |
| 41 | + // </CallSnippetReuseAndTuneInceptionModel> |
| 42 | + |
| 43 | + // <SnippetCallClassifyImages> |
| 44 | + ClassifyImages(mlContext, _predictTagsTsv, _predictImagesFolder, _outputImageClassifierZip); |
| 45 | + // </SnippetCallClassifyImages> |
| 46 | + } |
| 47 | + |
| 48 | + // <SnippetInceptionSettings> |
| 49 | + private struct InceptionSettings |
| 50 | + { |
| 51 | + public const int ImageHeight = 224; |
| 52 | + public const int ImageWidth = 224; |
| 53 | + public const float Mean = 117; |
| 54 | + public const float Scale = 1; |
| 55 | + public const bool ChannelsLast = true; |
| 56 | + } |
| 57 | + // </SnippetInceptionSettings> |
| 58 | + |
| 59 | + // Build and train model |
| 60 | + public static void ReuseAndTuneInceptionModel(MLContext mlContext, string dataLocation, string imagesFolder, string inputModelLocation, string outputModelLocation) |
| 61 | + { |
| 62 | + |
| 63 | + Console.WriteLine("Read model"); |
| 64 | + Console.WriteLine($"Model location: {inputModelLocation}"); |
| 65 | + Console.WriteLine($"Images folder: {_trainImagesFolder}"); |
| 66 | + Console.WriteLine($"Training file: {dataLocation}"); |
| 67 | + Console.WriteLine($"Default parameters: image size=({InceptionSettings.ImageWidth},{InceptionSettings.ImageHeight}), image mean: {InceptionSettings.Mean}"); |
| 68 | + |
| 69 | + // <SnippetLoadData> |
| 70 | + var data = mlContext.Data.ReadFromTextFile<ImageData>(path: dataLocation, hasHeader: true); |
| 71 | + // </SnippetLoadData> |
| 72 | + |
| 73 | + // <SnippetMapValueToKey1> |
| 74 | + var estimator = mlContext.Transforms.Conversion.MapValueToKey(outputColumnName: LabelTokey, inputColumnName: DefaultColumnNames.Label) |
| 75 | + // </SnippetMapValueToKey1> |
| 76 | + // The image transforms transform the images into the model's expected format. |
| 77 | + // <SnippetImageTransforms> |
| 78 | + .Append(mlContext.Transforms.LoadImages(_trainImagesFolder, (ImageReal, nameof(ImageData.ImagePath)))) |
| 79 | + .Append(mlContext.Transforms.Resize(outputColumnName: ImageReal, imageWidth: InceptionSettings.ImageWidth, imageHeight: InceptionSettings.ImageHeight, inputColumnName: ImageReal)) |
| 80 | + .Append(mlContext.Transforms.ExtractPixels(new ImagePixelExtractorTransformer.ColumnInfo(name: "input", inputColumnName: ImageReal, interleave: InceptionSettings.ChannelsLast, offset: InceptionSettings.Mean))) |
| 81 | + // </SnippetImageTransforms> |
| 82 | + // The ScoreTensorFlowModel transform scores the TensorFlow model and allows communication |
| 83 | + // <SnippetScoreTensorFlowModel> |
| 84 | + .Append(mlContext.Transforms.ScoreTensorFlowModel(modelLocation: inputModelLocation, outputColumnNames: new[] { "softmax2_pre_activation" }, inputColumnNames: new[] { "input" })) |
| 85 | + // </SnippetScoreTensorFlowModel> |
| 86 | + // <SnippetAddTrainer> |
| 87 | + .Append(mlContext.MulticlassClassification.Trainers.LogisticRegression(labelColumn: LabelTokey, featureColumn: "softmax2_pre_activation")) |
| 88 | + // </SnippetAddTrainer> |
| 89 | + // <SnippetMapValueToKey2> |
| 90 | + .Append(mlContext.Transforms.Conversion.MapKeyToValue((PredictedLabelValue, DefaultColumnNames.PredictedLabel))); |
| 91 | + // </SnippetMapValueToKey2> |
| 92 | + |
| 93 | + // Train the model |
| 94 | + Console.WriteLine("=============== Training classification model ==============="); |
| 95 | + // Create and train the model based on the dataset that has been loaded, transformed. |
| 96 | + // <SnippetTrainModel> |
| 97 | + ITransformer model = estimator.Fit(data); |
| 98 | + // </SnippetTrainModel> |
| 99 | + |
| 100 | + // Process the training data through the model |
| 101 | + // This is an optional step, but it's useful for debugging issues |
| 102 | + // <SnippetTransformData> |
| 103 | + var predictions = model.Transform(data); |
| 104 | + // </SnippetTransformData> |
| 105 | + |
| 106 | + // Create enumerables for both the ImageData and ImagePrediction DataViews |
| 107 | + // for displaying results |
| 108 | + // <SnippetEnumerateDataViews> |
| 109 | + var imageData = mlContext.CreateEnumerable<ImageData>(data, false, true); |
| 110 | + var imagePredictionData = mlContext.CreateEnumerable<ImagePrediction>(predictions, false, true); |
| 111 | + // </SnippetEnumerateDataViews> |
| 112 | + |
| 113 | + // Read the tags.tsv file and add the filepath to the image file name |
| 114 | + // before loading into ImageData |
| 115 | + // <SnippetCallPairAndDisplayResults1> |
| 116 | + PairAndDisplayResults(imageData, imagePredictionData); |
| 117 | + // </SnippetCallPairAndDisplayResults1> |
| 118 | + |
| 119 | + // Get some performance metrics on the model using training data |
| 120 | + Console.WriteLine("=============== Classification metrics ==============="); |
| 121 | + |
| 122 | + // <SnippetEvaluate> |
| 123 | + var regressionContext = new MulticlassClassificationCatalog(mlContext); |
| 124 | + var metrics = regressionContext.Evaluate(predictions, label: LabelTokey, predictedLabel: DefaultColumnNames.PredictedLabel); |
| 125 | + // </SnippetEvaluate> |
| 126 | + |
| 127 | + //<SnippetDisplayMetrics> |
| 128 | + Console.WriteLine($"LogLoss is: {metrics.LogLoss}"); |
| 129 | + Console.WriteLine($"PerClassLogLoss is: {String.Join(" , ", metrics.PerClassLogLoss.Select(c => c.ToString()))}"); |
| 130 | + //</SnippetDisplayMetrics> |
| 131 | + |
| 132 | + // Save the model to assets/outputs |
| 133 | + Console.WriteLine("=============== Save model to local file ==============="); |
| 134 | + |
| 135 | + // <SnippetSaveModel> |
| 136 | + using (var fileStream = new FileStream(outputModelLocation, FileMode.Create)) |
| 137 | + mlContext.Model.Save(model, fileStream); |
| 138 | + // </SnippetSaveModel> |
| 139 | + |
| 140 | + Console.WriteLine($"Model saved: {outputModelLocation}"); |
| 141 | + } |
| 142 | + |
| 143 | + public static void ClassifyImages(MLContext mlContext, string dataLocation, string imagesFolder, string outputModelLocation) |
| 144 | + { |
| 145 | + Console.WriteLine($"=============== Loading model ==============="); |
| 146 | + Console.WriteLine($"Model loaded: {outputModelLocation}"); |
| 147 | + |
| 148 | + // Load the model |
| 149 | + // <SnippetLoadModel> |
| 150 | + ITransformer loadedModel; |
| 151 | + using (var fileStream = new FileStream(outputModelLocation, FileMode.Open)) |
| 152 | + loadedModel = mlContext.Model.Load(fileStream); |
| 153 | + // </SnippetLoadModel> |
| 154 | + |
| 155 | + // Read the tags.tsv file and add the filepath to the image file name |
| 156 | + // before loading into ImageData |
| 157 | + // <SnippetReadFromTSV> |
| 158 | + var imageData = ReadFromTsv(dataLocation, imagesFolder); |
| 159 | + var imageDataView = mlContext.Data.ReadFromEnumerable<ImageData>(imageData); |
| 160 | + // </SnippetReadFromTSV> |
| 161 | + |
| 162 | + // <SnippetPredict> |
| 163 | + var predictions = loadedModel.Transform(imageDataView); |
| 164 | + var imagePredictionData = mlContext.CreateEnumerable<ImagePrediction>(predictions, false,true); |
| 165 | + // </SnippetPredict> |
| 166 | + |
| 167 | + Console.WriteLine("=============== Making classifications ==============="); |
| 168 | + // <SnippetCallPairAndDisplayResults2> |
| 169 | + PairAndDisplayResults(imageData, imagePredictionData); |
| 170 | + // </SnippetCallPairAndDisplayResults2> |
| 171 | + |
| 172 | + } |
| 173 | + |
| 174 | + private static void PairAndDisplayResults(IEnumerable<ImageData> imageNetData, IEnumerable<ImagePrediction> imageNetPredictionData) |
| 175 | + { |
| 176 | + // Builds pairs of (image, prediction) to sync up for display |
| 177 | + // <SnippetBuildImagePredictionPairs> |
| 178 | + IEnumerable<(ImageData image, ImagePrediction prediction)> imagesAndPredictions = imageNetData.Zip(imageNetPredictionData, (image, prediction) => (image, prediction)); |
| 179 | + // </SnippetBuildImagePredictionPairs> |
| 180 | + |
| 181 | + // <SnippetDisplayPredictions> |
| 182 | + foreach ((ImageData image, ImagePrediction prediction) item in imagesAndPredictions) |
| 183 | + { |
| 184 | + Console.WriteLine($"Image: {Path.GetFileName(item.image.ImagePath)} predicted as: {item.prediction.PredictedLabelValue} with score: {item.prediction.Score.Max()} "); |
| 185 | + } |
| 186 | + // </SnippetDisplayPredictions> |
| 187 | + } |
| 188 | + |
| 189 | + public static IEnumerable<ImageData> ReadFromTsv(string file, string folder) |
| 190 | + { |
| 191 | + //Need to parse through the tags.tsv file to combine the file path to the |
| 192 | + // image name for the ImagePath property so that the image file can be found. |
| 193 | + |
| 194 | + // <SnippetReadFromTsv> |
| 195 | + return File.ReadAllLines(file) |
| 196 | + .Select(line => line.Split('\t')) |
| 197 | + .Select(line => new ImageData() |
| 198 | + { |
| 199 | + ImagePath = Path.Combine(folder, line[0]), |
| 200 | + Label = line[1], |
| 201 | + }); |
| 202 | + // </SnippetReadFromTsv> |
| 203 | + } |
| 204 | + } |
| 205 | + |
| 206 | +} |
0 commit comments