Skip to content

Commit e8111f5

Browse files
committed
Reformat code and optimize imports
1 parent e77595b commit e8111f5

File tree

83 files changed

+400
-247
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

83 files changed

+400
-247
lines changed

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Functional.kt

+4-3
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ package org.jetbrains.kotlinx.dl.api.core
77

88
import org.jetbrains.kotlinx.dl.api.core.layer.Layer
99
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
10-
import org.jetbrains.kotlinx.dl.api.core.layer.weights
1110
import org.jetbrains.kotlinx.dl.api.core.layer.freeze
11+
import org.jetbrains.kotlinx.dl.api.core.layer.weights
1212
import org.jetbrains.kotlinx.dl.api.core.util.sortTopologically
1313
import org.jetbrains.kotlinx.dl.api.inference.keras.*
1414
import org.tensorflow.Operand
@@ -90,13 +90,14 @@ public class Functional(vararg layers: Layer) : GraphTrainableModel(*layers) {
9090
layers += pretrainedLayers
9191

9292
val topLayers = topModel.layers
93-
layers+= topLayers
93+
layers += topLayers
9494
topLayers[0].inboundLayers.add(pretrainedLayers.last())
9595

9696
if (topModel is Sequential && layers.size > 1) {
9797
// establish edges in DAG
9898
topLayers.subList(1, topLayers.size).forEachIndexed { index, layer ->
99-
val topLayersIndex = index - 1 + 1 // shift -1 to take previous, but shift +1 because it's an index in subList, started from 1
99+
val topLayersIndex = index - 1 + 1
100+
// shift -1 to take previous, but shift +1 because it's an index in subList, started from 1
100101
layer.inboundLayers.add(topLayers[topLayersIndex])
101102
}
102103
}

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel.kt

+28-16
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
117117
* Returns a list of layer variables in this model.
118118
*/
119119
private fun layerVariables(): List<KVariable> = layers.variables()
120+
120121
/**
121122
* Returns a list of non-trainable, 'frozen' layer variables in this model.
122123
*/
@@ -327,7 +328,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
327328
val averageTrainingMetricAccum = FloatArray(metrics.size) { 0.0f }
328329

329330
while (batchIter.hasNext() && !stopTraining) {
330-
fitCallbacks.forEach { it.onTrainBatchBegin(batchCounter, trainBatchSize, trainingHistory)}
331+
fitCallbacks.forEach { it.onTrainBatchBegin(batchCounter, trainBatchSize, trainingHistory) }
331332
val batch: DataBatch = batchIter.next()
332333

333334
val (xBatchShape, yBatchShape) = calculateXYShapes(batch)
@@ -370,12 +371,14 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
370371
// TODO: create map (metric name and metric value)
371372
logger.debug { "Batch stat: { lossValue: $lossValue metricValues: $metricValues }" }
372373

373-
fitCallbacks.forEach { it.onTrainBatchEnd(
374-
batchCounter,
375-
trainBatchSize,
376-
batchTrainingEvent,
377-
trainingHistory
378-
) }
374+
fitCallbacks.forEach {
375+
it.onTrainBatchEnd(
376+
batchCounter,
377+
trainBatchSize,
378+
batchTrainingEvent,
379+
trainingHistory
380+
)
381+
}
379382
}
380383
}
381384
}
@@ -384,23 +387,29 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
384387
}
385388

386389
val avgTrainingMetricValue = FloatArray(metrics.size) { 0.0f }
387-
averageTrainingMetricAccum.forEachIndexed { index, metricValue -> avgTrainingMetricValue[index] = metricValue / batchCounter}
390+
averageTrainingMetricAccum.forEachIndexed { index, metricValue ->
391+
avgTrainingMetricValue[index] = metricValue / batchCounter
392+
}
388393

389394
val avgLossValue = (averageTrainingLossAccum / batchCounter)
390395

391396
val nanList = mutableListOf<Double>()
392-
for(j in 1 .. metrics.size) {
397+
for (j in 1..metrics.size) {
393398
nanList.add(Double.NaN)
394399
}
395400

396401
val epochTrainingEvent = EpochTrainingEvent(
397402
i,
398-
avgLossValue.toDouble(), avgTrainingMetricValue.map { it.toDouble() }.toMutableList(), Double.NaN, nanList
403+
avgLossValue.toDouble(),
404+
avgTrainingMetricValue.map { it.toDouble() }.toMutableList(),
405+
Double.NaN,
406+
nanList
399407
)
400408

401409
if (validationIsEnabled) {
402410
val evaluationResult = evaluate(validationDataset!!, validationBatchSize!!, listOf())
403-
val validationMetricValues = metrics.map { evaluationResult.metrics[Metrics.convertBack(it)] }.toList()// TODO: probably I should it by name, not by type
411+
val validationMetricValues = metrics.map { evaluationResult.metrics[Metrics.convertBack(it)] }.toList()
412+
// TODO: probably I should do it by name, not by type
404413
val validationLossValue = evaluationResult.lossValue
405414
epochTrainingEvent.valLossValue = validationLossValue
406415
epochTrainingEvent.valMetricValues = validationMetricValues!!
@@ -453,7 +462,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
453462
val metricValues = mutableListOf<Float>()
454463

455464
check(tensorList.size == metricOps.size + 1) { "${metricOps.size} metrics are monitored, but ${tensorList.size - 1} metrics are returned!" }
456-
for (i in 1 .. metricOps.size) {
465+
for (i in 1..metricOps.size) {
457466
metricValues.add(tensorList[i].floatValue())
458467
}
459468

@@ -514,7 +523,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
514523
val metricValues = mutableListOf<Float>()
515524

516525
check(lossAndMetricsTensors.size == metricOps.size + 1) { "${metricOps.size} metrics are monitored, but ${lossAndMetricsTensors.size - 1} metrics are returned!" }
517-
for (i in 1 .. metricOps.size) {
526+
for (i in 1..metricOps.size) {
518527
metricValues.add(lossAndMetricsTensors[i].floatValue())
519528
}
520529

@@ -523,10 +532,13 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
523532
averageMetricAccum[i] += metricValues[i]
524533
}
525534

526-
val batchEvent = BatchEvent(batchCounter, lossValue.toDouble(), averageMetricAccum.map { it.toDouble() })
535+
val batchEvent = BatchEvent(batchCounter, lossValue.toDouble(),
536+
averageMetricAccum.map { it.toDouble() })
527537
evaluationHistory.appendBatch(batchEvent)
528538

529-
callbacks.forEach { it.onTestBatchEnd(batchCounter, batchSize, batchEvent, evaluationHistory) }
539+
callbacks.forEach {
540+
it.onTestBatchEnd(batchCounter, batchSize, batchEvent, evaluationHistory)
541+
}
530542
}
531543
}
532544

@@ -537,7 +549,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
537549
}
538550

539551
val avgMetricValue = FloatArray(metrics.size) { 0.0f }
540-
averageMetricAccum.forEachIndexed { index, metricValue -> avgMetricValue[index] = metricValue / batchCounter}
552+
averageMetricAccum.forEachIndexed { index, metricValue -> avgMetricValue[index] = metricValue / batchCounter }
541553

542554
val avgLossValue = (averageLossAccum / batchCounter).toDouble()
543555

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/KGraph.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
2+
* Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
33
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
44
*/
55

@@ -78,7 +78,7 @@ public class KGraph(graphDef: ByteArray, prefix: String) : AutoCloseable {
7878

7979
while (operations.hasNext()) {
8080
val operation = operations.next() as GraphOperation
81-
if(operation.type().equals("VariableV2")) {
81+
if (operation.type().equals("VariableV2")) {
8282
variableNames.add(operation.name())
8383
}
8484
}

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Sequential.kt

+5-6
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.Layer
99
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
1010
import org.jetbrains.kotlinx.dl.api.core.layer.weights
1111
import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
12-
import org.jetbrains.kotlinx.dl.api.inference.keras.deserializeSequentialModel
13-
import org.jetbrains.kotlinx.dl.api.inference.keras.loadSequentialModelLayers
14-
import org.jetbrains.kotlinx.dl.api.inference.keras.loadSerializedModel
15-
import org.jetbrains.kotlinx.dl.api.inference.keras.serializeModel
12+
import org.jetbrains.kotlinx.dl.api.inference.keras.*
1613
import org.tensorflow.Operand
1714
import org.tensorflow.Shape
1815
import java.io.File
@@ -78,7 +75,7 @@ public class Sequential(vararg layers: Layer) : GraphTrainableModel(*layers) {
7875
public fun loadModelConfiguration(configuration: File, inputShape: IntArray? = null): Sequential {
7976
require(configuration.isFile) { "${configuration.absolutePath} is not a file. Should be a .json file with configuration." }
8077

81-
return org.jetbrains.kotlinx.dl.api.inference.keras.loadSequentialModelConfiguration(configuration, inputShape)
78+
return loadSequentialModelConfiguration(configuration, inputShape)
8279
}
8380

8481
/**
@@ -88,7 +85,9 @@ public class Sequential(vararg layers: Layer) : GraphTrainableModel(*layers) {
8885
* @return Pair of <input layer; list of layers>.
8986
*/
9087
@JvmStatic
91-
public fun loadModelLayersFromConfiguration(configuration: File, inputShape: IntArray? = null): Pair<Input, MutableList<Layer>> {
88+
public fun loadModelLayersFromConfiguration(configuration: File,
89+
inputShape: IntArray? = null
90+
): Pair<Input, MutableList<Layer>> {
9291
require(configuration.isFile) { "${configuration.absolutePath} is not a file. Should be a .json file with configuration." }
9392

9493
val config = loadSerializedModel(configuration)

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/history/History.kt

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
2+
* Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
33
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
44
*/
55

@@ -55,7 +55,10 @@ public class History {
5555
* @param lossValue Final value of loss function.
5656
* @param metricValues Final values of chosen metrics.
5757
*/
58-
public class BatchEvent(public val batchIndex: Int, public val lossValue: Double, public val metricValues: List<Double>) {
58+
public class BatchEvent(public val batchIndex: Int,
59+
public val lossValue: Double,
60+
public val metricValues: List<Double>
61+
) {
5962
override fun toString(): String {
6063
return "BatchEvent(batchIndex=$batchIndex, lossValue=$lossValue, metricValues=$metricValues)"
6164
}

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/history/TrainingHistory.kt

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
/*
2-
* Copyright 2020-2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
2+
* Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
33
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
44
*/
55

66
package org.jetbrains.kotlinx.dl.api.core.history
77

88
import java.util.*
9-
import kotlin.reflect.KProperty1
109

1110
/**
1211
* Contains all recorded batch events as a list of [BatchTrainingEvent] objects and epoch events as a list of [EpochTrainingEvent] objects.

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/initializer/Initializer.kt

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
2+
* Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
33
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
44
*/
55

@@ -11,7 +11,6 @@ import org.jetbrains.kotlinx.dl.api.core.util.defaultAssignOpName
1111
import org.jetbrains.kotlinx.dl.api.core.util.defaultInitializerOpName
1212
import org.tensorflow.Operand
1313
import org.tensorflow.op.Ops
14-
import org.tensorflow.op.core.Assign
1514

1615
/**
1716
* Initializer base class: all initializers inherit this class.

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/initializer/Orthogonal.kt

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
2+
* Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
33
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
44
*/
55

@@ -36,7 +36,10 @@ public class Orthogonal(
3636
require(dimsShape >= 2) { "The tensor to initialize must be at least two-dimensional" }
3737

3838
// Generate a random matrix
39-
val distOpND: Operand<Float> = tf.random.statelessRandomNormal(shape, tf.constant(longArrayOf(seed, 0L)), getDType())
39+
val distOpND: Operand<Float> = tf.random.statelessRandomNormal(
40+
shape,
41+
tf.constant(longArrayOf(seed, 0L)), getDType()
42+
)
4043

4144
// Flatten the generated random matrix with the last dimension remaining
4245
// its original shape, so it works for conv2d

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/ParametrizedLayer.kt

+3-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ internal fun List<Layer>.frozenVariables(): List<KVariable> {
5252
/**
5353
* Initializes this layers variables using provided initializer operands.
5454
*/
55-
public fun ParametrizedLayer.initialize(session: Session): Unit = variables.map { it.initializerOperation }.init(session)
55+
public fun ParametrizedLayer.initialize(session: Session) {
56+
variables.map { it.initializerOperation }.init(session)
57+
}
5658

5759
/**
5860
* Initializes variables for [ParametrizedLayer] instances using provided initializer operands.

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/UpSampling2D.kt

+4-2
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,10 @@ public class UpSampling2D(
6666
)
6767
return when (interpolation) {
6868
InterpolationMethod.NEAREST -> tf.image.resizeNearestNeighbor(input, newSize)
69-
InterpolationMethod.BILINEAR ->
70-
tf.image.resizeBilinear(input, newSize, ResizeBilinear.halfPixelCenters(true))
69+
InterpolationMethod.BILINEAR -> tf.image.resizeBilinear(
70+
input, newSize,
71+
ResizeBilinear.halfPixelCenters(true)
72+
)
7173
else -> throw IllegalArgumentException("The interpolation type interpolation is not supported.")
7274
}
7375
}

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/AdaDelta.kt

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
2+
* Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
33
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
44
*/
55

@@ -15,8 +15,6 @@ import org.tensorflow.op.core.Constant
1515
import org.tensorflow.op.core.Gradients
1616
import org.tensorflow.op.core.Variable
1717
import org.tensorflow.op.train.ApplyAdadelta
18-
import org.tensorflow.op.train.ApplyAdam
19-
import java.util.*
2018

2119
private const val ACCUMULATOR = "accum"
2220
private const val ACCUMULATOR_UPDATE = "accum_update"

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/AdaGrad.kt

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
2+
* Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
33
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
44
*/
55

@@ -15,7 +15,6 @@ import org.tensorflow.op.core.Constant
1515
import org.tensorflow.op.core.Gradients
1616
import org.tensorflow.op.core.Variable
1717
import org.tensorflow.op.train.ApplyAdagrad
18-
import java.util.*
1918

2019
private const val ACCUMULATOR = "accumulator"
2120

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/AdaGradDA.kt

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
2+
* Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
33
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
44
*/
55

@@ -19,7 +19,6 @@ import org.tensorflow.op.core.Constant
1919
import org.tensorflow.op.core.Gradients
2020
import org.tensorflow.op.core.Variable
2121
import org.tensorflow.op.train.ApplyAdagradDa
22-
import java.util.*
2322

2423
private val GLOBAL_STEP = defaultOptimizerVariableName("adagrad-da-global-step")
2524
private const val ACCUMULATOR = "gradient_accumulator"

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/Adam.kt

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
2+
* Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
33
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
44
*/
55

@@ -19,7 +19,6 @@ import org.tensorflow.op.core.Constant
1919
import org.tensorflow.op.core.Gradients
2020
import org.tensorflow.op.core.Variable
2121
import org.tensorflow.op.train.ApplyAdam
22-
import java.util.*
2322

2423
private const val FIRST_MOMENT = "m"
2524
private const val SECOND_MOMENT = "v"

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/Adamax.kt

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
2+
* Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
33
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
44
*/
55

@@ -20,7 +20,6 @@ import org.tensorflow.op.core.Constant
2020
import org.tensorflow.op.core.Gradients
2121
import org.tensorflow.op.core.Variable
2222
import org.tensorflow.op.train.ApplyAdaMax
23-
import java.util.*
2423

2524
private const val FIRST_MOMENT = "m"
2625
private const val SECOND_MOMENT = "v"

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/Ftrl.kt

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
2+
* Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
33
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
44
*/
55

@@ -15,7 +15,6 @@ import org.tensorflow.op.core.Constant
1515
import org.tensorflow.op.core.Gradients
1616
import org.tensorflow.op.core.Variable
1717
import org.tensorflow.op.train.ApplyFtrl
18-
import java.util.*
1918

2019
private const val ACCUMULATOR = "gradient_accumulator"
2120
private const val LINEAR_ACCUMULATOR = "linear_accumulator"

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/Momentum.kt

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
2+
* Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
33
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
44
*/
55

@@ -14,7 +14,6 @@ import org.tensorflow.op.core.Constant
1414
import org.tensorflow.op.core.Gradients
1515
import org.tensorflow.op.core.Variable
1616
import org.tensorflow.op.train.ApplyMomentum
17-
import java.util.*
1817

1918
private const val MOMENTUM = "momentum"
2019

0 commit comments

Comments
 (0)