Skip to content

Commit

Permalink
Support freeze parameters
Browse files Browse the repository at this point in the history
This adds a method to support freezing and unfreezing a parameter for transfer
learning. There is also a helper to (un)freeze all paremters in a block, but
without filtering. For more advanced use cases of (un)freezing part of the
parameters in a block, it should be implemented using block.getChildren(),
block.getDirectParameters() and/or block.getParameters().
  • Loading branch information
zachgk committed Mar 24, 2022
1 parent 33c29ed commit b8594fe
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 0 deletions.
11 changes: 11 additions & 0 deletions api/src/main/java/ai/djl/nn/Block.java
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,17 @@ default NDList forward(
void loadParameters(NDManager manager, DataInputStream is)
throws IOException, MalformedModelException;

/**
* Freezes or unfreezes all parameters inside the block for training.
*
* @see Parameter#freeze(Boolean)
*/
default void freezeParameters(Boolean freeze) {
for (Parameter parameter : getParameters().values()) {
parameter.freeze(freeze);
}
}

/**
* Validates that actual layout matches the expected layout.
*
Expand Down
16 changes: 16 additions & 0 deletions api/src/main/java/ai/djl/nn/Parameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,22 @@ public boolean requiresGradient() {
return requiresGrad;
}

/**
* Freezes or unfreezes the parameter for training.
*
* <p>Sometimes during training, especially during transfer learning, it is typical to train
* only part of the model. For this, the freeze can be used to prevent certain parts from being
* trained.
*
* <p>This modifies the {@link #requiresGradient()} of the parameter.
*
* @param freeze whether the parameter should be frozen ({@code freeze == !requiresGradient()})
*/
public void freeze(Boolean freeze) {
requiresGrad = !freeze;
array.setRequiresGradient(requiresGrad);
}

/**
* Checks if this {@code Parameter} is initialized.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@
package ai.djl.integration.tests.training;

import ai.djl.Model;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.Parameter;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.testing.Assertions;
import ai.djl.testing.TestRequirements;
Expand Down Expand Up @@ -73,6 +77,61 @@ public void testAutograd() {
}
}

@Test
public void testFreezeParameters() {
try (Model model = Model.newInstance("model")) {
Block blockFrozen = new Mlp(10, 10, new int[] {10});
Block blockNormal = new Mlp(10, 10, new int[] {10});
Block combined = new SequentialBlock().add(blockFrozen).add(blockNormal);
model.setBlock(combined);

TrainingConfig config =
new DefaultTrainingConfig(Loss.l2Loss())
.optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);

try (Trainer trainer = model.newTrainer(config)) {
trainer.initialize(new Shape(1, 10));

blockFrozen.freezeParameters(true);

// Find total params
Float frozenVal =
blockFrozen.getParameters().valueAt(0).getArray().sum().getFloat();
Float normalVal =
blockNormal.getParameters().valueAt(0).getArray().sum().getFloat();

// Run training step
NDManager manager = trainer.getManager();
NDArray data = manager.arange(100.0f).reshape(new Shape(10, 10));
NDArray labels = manager.arange(100.0f).reshape(new Shape(10, 10));
Batch batch =
new Batch(
manager, new NDList(data), new NDList(labels), 1, null, null, 0, 1);
EasyTrain.trainBatch(trainer, batch);
trainer.step();

// Check updated total params
// The frozen one should not have changed, but normal one should
Float newFrozenVal =
blockFrozen.getParameters().valueAt(0).getArray().sum().getFloat();
Float newNormalVal =
blockNormal.getParameters().valueAt(0).getArray().sum().getFloat();
Assert.assertEquals(newFrozenVal, frozenVal);
Assert.assertNotEquals(newNormalVal, normalVal);

blockFrozen.freezeParameters(false);

// Check that unfreezing the block now makes it update
EasyTrain.trainBatch(trainer, batch);
trainer.step();

Float nowUnfrozenVal =
blockFrozen.getParameters().valueAt(0).getArray().sum().getFloat();
Assert.assertNotEquals(nowUnfrozenVal, frozenVal);
}
}
}

@Test
public void testTrain() throws IOException, TranslateException {
TestRequirements.nightly();
Expand Down

0 comments on commit b8594fe

Please sign in to comment.