Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added more logging to OBJ-DET #6646

Merged
merged 2 commits into from
May 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 32 additions & 11 deletions src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ public sealed class Options : TransformInputBase
/// Gets or sets the weight decay in optimizer.
/// </summary>
public double WeightDecay = 0.0;

/// <summary>
/// How often to log the loss: a message is logged whenever the update count is a multiple of this value.
/// </summary>
public int LogEveryNStep = 50;
}

private protected readonly IHost Host;
Expand All @@ -122,7 +127,7 @@ public sealed class Options : TransformInputBase

internal ObjectDetectionTrainer(IHostEnvironment env, Options options)
{
Host = Contracts.CheckRef(env, nameof(env)).Register(nameof(NasBertTrainer));
Host = Contracts.CheckRef(env, nameof(env)).Register(nameof(ObjectDetectionTrainer));
Contracts.Assert(options.MaxEpoch > 0);
Contracts.AssertValue(options.BoundingBoxColumnName);
Contracts.AssertValue(options.LabelColumnName);
Expand Down Expand Up @@ -163,14 +168,21 @@ public ObjectDetectionTransformer Fit(IDataView input)
using (var ch = Host.Start("TrainModel"))
using (var pch = Host.StartProgressChannel("Training model"))
{
var header = new ProgressHeader(new[] { "Accuracy" }, null);
var header = new ProgressHeader(new[] { "Loss" }, new[] { "total images" });

var trainer = new Trainer(this, ch, input);
pch.SetHeader(header, e => e.SetMetric(0, trainer.Accuracy));
pch.SetHeader(header,
e =>
{
e.SetProgress(0, trainer.Updates, trainer.RowCount);
e.SetMetric(0, trainer.LossValue);
});

for (int i = 0; i < Option.MaxEpoch; i++)
{
ch.Trace($"Starting epoch {i}");
Host.CheckAlive();
trainer.Train(Host, input);
trainer.Train(Host, input, pch);
ch.Trace($"Finished epoch {i}");
}
var labelCol = input.Schema.GetColumnOrNull(Option.LabelColumnName);
Expand All @@ -191,17 +203,19 @@ internal class Trainer
protected readonly ObjectDetectionTrainer Parent;
public FocalLoss Loss;
public int Updates;
public float Accuracy;
public float LossValue;
public readonly int RowCount;
private readonly IChannel _channel;

public Trainer(ObjectDetectionTrainer parent, IChannel ch, IDataView input)
{
Parent = parent;
Updates = 0;
Accuracy = 0;

LossValue = 0;
_channel = ch;

// Get the row count and determine the number of unique labels
var rowCount = GetRowCountAndSetLabelCount(input);
RowCount = GetRowCountAndSetLabelCount(input);
Device = TorchUtils.InitializeDevice(Parent.Host);

// Initialize the model and load pre-trained weights
Expand Down Expand Up @@ -274,7 +288,7 @@ private string GetModelPath()
return relativeFilePath;
}

public void Train(IHost host, IDataView input)
public void Train(IHost host, IDataView input, IProgressChannel pch)
{
// Get the cursor and the correct columns based on the inputs
DataViewRowCursor cursor = input.GetRowCursor(input.Schema[Parent.Option.LabelColumnName], input.Schema[Parent.Option.BoundingBoxColumnName], input.Schema[Parent.Option.ImageColumnName]);
Expand Down Expand Up @@ -302,7 +316,7 @@ public void Train(IHost host, IDataView input)

while (cursorValid)
{
cursorValid = TrainStep(host, cursor, boundingBoxGetter, imageGetter, labelGetter);
cursorValid = TrainStep(host, cursor, boundingBoxGetter, imageGetter, labelGetter, pch);
}

LearningRateScheduler.step();
Expand All @@ -312,7 +326,8 @@ private bool TrainStep(IHost host,
DataViewRowCursor cursor,
ValueGetter<VBuffer<float>> boundingBoxGetter,
ValueGetter<MLImage> imageGetter,
ValueGetter<VBuffer<uint>> labelGetter)
ValueGetter<VBuffer<uint>> labelGetter,
IProgressChannel pch)
{
using var disposeScope = torch.NewDisposeScope();
var cursorValid = true;
Expand Down Expand Up @@ -343,6 +358,12 @@ private bool TrainStep(IHost host,
Optimizer.step();
host.CheckAlive();

if (Updates % Parent.Option.LogEveryNStep == 0)
{
pch.Checkpoint(lossValue.ToDouble(), Updates);
_channel.Info($"Row: {Updates}, Loss: {lossValue.ToDouble()}");
}

return cursorValid;
}

Expand Down
20 changes: 17 additions & 3 deletions test/Microsoft.ML.Tests/ObjectDetectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.RunTests;
using Microsoft.ML.Transforms.Image;
using Microsoft.VisualBasic;
using Microsoft.ML.TorchSharp;
using Xunit;
using Xunit.Abstractions;
using Microsoft.ML.TorchSharp.AutoFormerV2;
using Microsoft.ML.Runtime;
using System.Collections.Generic;

namespace Microsoft.ML.Tests
{
Expand Down Expand Up @@ -50,13 +51,13 @@ public void SimpleObjDetectionTest()
.Append(ML.MulticlassClassification.Trainers.ObjectDetection("Labels", boundingBoxColumnName: "Box", maxEpoch: 1))
.Append(ML.Transforms.Conversion.MapKeyToValue("PredictedLabel"));


var options = new ObjectDetectionTrainer.Options()
{
LabelColumnName = "Labels",
BoundingBoxColumnName = "Box",
ScoreThreshold = .5,
MaxEpoch = 1
MaxEpoch = 1,
LogEveryNStep = 1,
};

var pipeline = ML.Transforms.Text.TokenizeIntoWords("Labels", separators: new char[] { ',' })
Expand All @@ -67,13 +68,26 @@ public void SimpleObjDetectionTest()
.Append(ML.MulticlassClassification.Trainers.ObjectDetection(options))
.Append(ML.Transforms.Conversion.MapKeyToValue("PredictedLabel"));

var logs = new List<LoggingEventArgs>();

ML.Log += (o, e) =>
{
if (e.Source.StartsWith("ObjectDetectionTrainer") && e.Kind == ChannelMessageKind.Info && e.Message.Contains("Loss:"))
{
logs.Add(e);
}
};

var model = pipeline.Fit(data);
var idv = model.Transform(data);
// Make sure the metrics work.
var metrics = ML.MulticlassClassification.EvaluateObjectDetection(idv, idv.Schema[2], idv.Schema["Box"], idv.Schema["PredictedLabel"], idv.Schema["PredictedBoundingBoxes"], idv.Schema["Score"]);
Assert.True(!float.IsNaN(metrics.MAP50));
Assert.True(!float.IsNaN(metrics.MAP50_95));

// We aren't doing enough training to get a consistent loss, so just make sure it's being logged
Assert.True(logs.Count > 0);

// Make sure the filtered pipeline can run without any columns but image column AFTER training
var dataFiltered = TextLoader.Create(ML, new TextLoader.Options()
{
Expand Down
Binary file modified test/data/images/object-detection/fruit0.png
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'll have to resize the boxes too :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you try adding resize code to the test instead?

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified test/data/images/object-detection/fruit1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified test/data/images/object-detection/fruit10.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified test/data/images/object-detection/fruit100.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified test/data/images/object-detection/fruit101.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified test/data/images/object-detection/fruit102.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified test/data/images/object-detection/fruit103.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified test/data/images/object-detection/fruit104.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified test/data/images/object-detection/fruit105.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified test/data/images/object-detection/fruit106.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.