@@ -117,6 +117,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
117
117
* Returns a list of layer variables in this model.
118
118
*/
119
119
private fun layerVariables (): List <KVariable > = layers.variables()
120
+
120
121
/* *
121
122
* Returns a list of non-trainable, 'frozen' layer variables in this model.
122
123
*/
@@ -327,7 +328,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
327
328
val averageTrainingMetricAccum = FloatArray (metrics.size) { 0.0f }
328
329
329
330
while (batchIter.hasNext() && ! stopTraining) {
330
- fitCallbacks.forEach { it.onTrainBatchBegin(batchCounter, trainBatchSize, trainingHistory)}
331
+ fitCallbacks.forEach { it.onTrainBatchBegin(batchCounter, trainBatchSize, trainingHistory) }
331
332
val batch: DataBatch = batchIter.next()
332
333
333
334
val (xBatchShape, yBatchShape) = calculateXYShapes(batch)
@@ -370,12 +371,14 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
370
371
// TODO: create map (metric name and metric value)
371
372
logger.debug { " Batch stat: { lossValue: $lossValue metricValues: $metricValues }" }
372
373
373
- fitCallbacks.forEach { it.onTrainBatchEnd(
374
- batchCounter,
375
- trainBatchSize,
376
- batchTrainingEvent,
377
- trainingHistory
378
- ) }
374
+ fitCallbacks.forEach {
375
+ it.onTrainBatchEnd(
376
+ batchCounter,
377
+ trainBatchSize,
378
+ batchTrainingEvent,
379
+ trainingHistory
380
+ )
381
+ }
379
382
}
380
383
}
381
384
}
@@ -384,23 +387,29 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
384
387
}
385
388
386
389
val avgTrainingMetricValue = FloatArray (metrics.size) { 0.0f }
387
- averageTrainingMetricAccum.forEachIndexed { index, metricValue -> avgTrainingMetricValue[index] = metricValue / batchCounter}
390
+ averageTrainingMetricAccum.forEachIndexed { index, metricValue ->
391
+ avgTrainingMetricValue[index] = metricValue / batchCounter
392
+ }
388
393
389
394
val avgLossValue = (averageTrainingLossAccum / batchCounter)
390
395
391
396
val nanList = mutableListOf<Double >()
392
- for (j in 1 .. metrics.size) {
397
+ for (j in 1 .. metrics.size) {
393
398
nanList.add(Double .NaN )
394
399
}
395
400
396
401
val epochTrainingEvent = EpochTrainingEvent (
397
402
i,
398
- avgLossValue.toDouble(), avgTrainingMetricValue.map { it.toDouble() }.toMutableList(), Double .NaN , nanList
403
+ avgLossValue.toDouble(),
404
+ avgTrainingMetricValue.map { it.toDouble() }.toMutableList(),
405
+ Double .NaN ,
406
+ nanList
399
407
)
400
408
401
409
if (validationIsEnabled) {
402
410
val evaluationResult = evaluate(validationDataset!! , validationBatchSize!! , listOf ())
403
- val validationMetricValues = metrics.map { evaluationResult.metrics[Metrics .convertBack(it)] }.toList()// TODO: probably I should it by name, not by type
411
+ val validationMetricValues = metrics.map { evaluationResult.metrics[Metrics .convertBack(it)] }.toList()
412
+ // TODO: probably I should it by name, not by type
404
413
val validationLossValue = evaluationResult.lossValue
405
414
epochTrainingEvent.valLossValue = validationLossValue
406
415
epochTrainingEvent.valMetricValues = validationMetricValues!!
@@ -453,7 +462,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
453
462
val metricValues = mutableListOf<Float >()
454
463
455
464
check(tensorList.size == metricOps.size + 1 ) { " ${metricOps.size} metrics are monitored, but ${tensorList.size - 1 } metrics are returned!" }
456
- for (i in 1 .. metricOps.size) {
465
+ for (i in 1 .. metricOps.size) {
457
466
metricValues.add(tensorList[i].floatValue())
458
467
}
459
468
@@ -514,7 +523,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
514
523
val metricValues = mutableListOf<Float >()
515
524
516
525
check(lossAndMetricsTensors.size == metricOps.size + 1 ) { " ${metricOps.size} metrics are monitored, but ${lossAndMetricsTensors.size - 1 } metrics are returned!" }
517
- for (i in 1 .. metricOps.size) {
526
+ for (i in 1 .. metricOps.size) {
518
527
metricValues.add(lossAndMetricsTensors[i].floatValue())
519
528
}
520
529
@@ -523,10 +532,13 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
523
532
averageMetricAccum[i] + = metricValues[i]
524
533
}
525
534
526
- val batchEvent = BatchEvent (batchCounter, lossValue.toDouble(), averageMetricAccum.map { it.toDouble() })
535
+ val batchEvent = BatchEvent (batchCounter, lossValue.toDouble(),
536
+ averageMetricAccum.map { it.toDouble() })
527
537
evaluationHistory.appendBatch(batchEvent)
528
538
529
- callbacks.forEach { it.onTestBatchEnd(batchCounter, batchSize, batchEvent, evaluationHistory) }
539
+ callbacks.forEach {
540
+ it.onTestBatchEnd(batchCounter, batchSize, batchEvent, evaluationHistory)
541
+ }
530
542
}
531
543
}
532
544
@@ -537,7 +549,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
537
549
}
538
550
539
551
val avgMetricValue = FloatArray (metrics.size) { 0.0f }
540
- averageMetricAccum.forEachIndexed { index, metricValue -> avgMetricValue[index] = metricValue / batchCounter}
552
+ averageMetricAccum.forEachIndexed { index, metricValue -> avgMetricValue[index] = metricValue / batchCounter }
541
553
542
554
val avgLossValue = (averageLossAccum / batchCounter).toDouble()
543
555
0 commit comments