@@ -114,6 +114,11 @@ public sealed class Options : TransformInputBase
114114 /// Gets or sets the weight decay in optimizer.
115115 /// </summary>
116116 public double WeightDecay = 0.0 ;
117+
118+ /// <summary>
119+ /// Gets or sets how often, in training updates, the loss is logged: the loss is reported whenever the update counter is a multiple of this value. Must be non-zero, because it is used as a modulo divisor.
120+ /// </summary>
121+ public int LogEveryNStep = 50 ;
117122 }
118123
119124 private protected readonly IHost Host ;
@@ -122,7 +127,7 @@ public sealed class Options : TransformInputBase
122127
123128 internal ObjectDetectionTrainer ( IHostEnvironment env , Options options )
124129 {
125- Host = Contracts . CheckRef ( env , nameof ( env ) ) . Register ( nameof ( NasBertTrainer ) ) ;
130+ Host = Contracts . CheckRef ( env , nameof ( env ) ) . Register ( nameof ( ObjectDetectionTrainer ) ) ;
126131 Contracts . Assert ( options . MaxEpoch > 0 ) ;
127132 Contracts . AssertValue ( options . BoundingBoxColumnName ) ;
128133 Contracts . AssertValue ( options . LabelColumnName ) ;
@@ -163,14 +168,21 @@ public ObjectDetectionTransformer Fit(IDataView input)
163168 using ( var ch = Host . Start ( "TrainModel" ) )
164169 using ( var pch = Host . StartProgressChannel ( "Training model" ) )
165170 {
166- var header = new ProgressHeader ( new [ ] { "Accuracy" } , null ) ;
171+ var header = new ProgressHeader ( new [ ] { "Loss" } , new [ ] { "total images" } ) ;
172+
167173 var trainer = new Trainer ( this , ch , input ) ;
168- pch . SetHeader ( header , e => e . SetMetric ( 0 , trainer . Accuracy ) ) ;
174+ pch . SetHeader ( header ,
175+ e =>
176+ {
177+ e . SetProgress ( 0 , trainer . Updates , trainer . RowCount ) ;
178+ e . SetMetric ( 0 , trainer . LossValue ) ;
179+ } ) ;
180+
169181 for ( int i = 0 ; i < Option . MaxEpoch ; i ++ )
170182 {
171183 ch . Trace ( $ "Starting epoch { i } ") ;
172184 Host . CheckAlive ( ) ;
173- trainer . Train ( Host , input ) ;
185+ trainer . Train ( Host , input , pch ) ;
174186 ch . Trace ( $ "Finished epoch { i } ") ;
175187 }
176188 var labelCol = input . Schema . GetColumnOrNull ( Option . LabelColumnName ) ;
@@ -191,17 +203,19 @@ internal class Trainer
191203 protected readonly ObjectDetectionTrainer Parent ;
192204 public FocalLoss Loss ;
193205 public int Updates ;
194- public float Accuracy ;
206+ public float LossValue ;
207+ public readonly int RowCount ;
208+ private readonly IChannel _channel ;
195209
196210 public Trainer ( ObjectDetectionTrainer parent , IChannel ch , IDataView input )
197211 {
198212 Parent = parent ;
199213 Updates = 0 ;
200- Accuracy = 0 ;
201-
214+ LossValue = 0 ;
215+ _channel = ch ;
202216
203217 // Get row count and figure out num of unique labels
204- var rowCount = GetRowCountAndSetLabelCount ( input ) ;
218+ RowCount = GetRowCountAndSetLabelCount ( input ) ;
205219 Device = TorchUtils . InitializeDevice ( Parent . Host ) ;
206220
207221 // Initialize the model and load pre-trained weights
@@ -274,7 +288,7 @@ private string GetModelPath()
274288 return relativeFilePath ;
275289 }
276290
277- public void Train ( IHost host , IDataView input )
291+ public void Train ( IHost host , IDataView input , IProgressChannel pch )
278292 {
279293 // Get the cursor and the correct columns based on the inputs
280294 DataViewRowCursor cursor = input . GetRowCursor ( input . Schema [ Parent . Option . LabelColumnName ] , input . Schema [ Parent . Option . BoundingBoxColumnName ] , input . Schema [ Parent . Option . ImageColumnName ] ) ;
@@ -302,7 +316,7 @@ public void Train(IHost host, IDataView input)
302316
303317 while ( cursorValid )
304318 {
305- cursorValid = TrainStep ( host , cursor , boundingBoxGetter , imageGetter , labelGetter ) ;
319+ cursorValid = TrainStep ( host , cursor , boundingBoxGetter , imageGetter , labelGetter , pch ) ;
306320 }
307321
308322 LearningRateScheduler . step ( ) ;
@@ -312,7 +326,8 @@ private bool TrainStep(IHost host,
312326 DataViewRowCursor cursor ,
313327 ValueGetter < VBuffer < float > > boundingBoxGetter ,
314328 ValueGetter < MLImage > imageGetter ,
315- ValueGetter < VBuffer < uint > > labelGetter )
329+ ValueGetter < VBuffer < uint > > labelGetter ,
330+ IProgressChannel pch )
316331 {
317332 using var disposeScope = torch . NewDisposeScope ( ) ;
318333 var cursorValid = true ;
@@ -343,6 +358,12 @@ private bool TrainStep(IHost host,
343358 Optimizer . step ( ) ;
344359 host . CheckAlive ( ) ;
345360
361+ if ( Updates % Parent . Option . LogEveryNStep == 0 )
362+ {
363+ pch . Checkpoint ( lossValue . ToDouble ( ) , Updates ) ;
364+ _channel . Info ( $ "Row: { Updates } , Loss: { lossValue . ToDouble ( ) } ") ;
365+ }
366+
346367 return cursorValid ;
347368 }
348369
0 commit comments