Support freeze parameters #1544

Merged (2 commits) on Mar 24, 2022
api/src/main/java/ai/djl/nn/Block.java (12 additions, 0 deletions)
@@ -268,6 +268,18 @@ default NDList forward(
    void loadParameters(NDManager manager, DataInputStream is)
            throws IOException, MalformedModelException;

    /**
     * Freezes or unfreezes all parameters inside the block for training.
     *
     * @param freeze true if the parameters should be frozen
     * @see Parameter#freeze(boolean)
     */
    default void freezeParameters(boolean freeze) {
        for (Parameter parameter : getParameters().values()) {
            parameter.freeze(freeze);
        }
    }

    /**
     * Validates that actual layout matches the expected layout.
     *
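A possible usage pattern for the new Block#freezeParameters in a transfer-learning setup, sketched below. The sketch is illustrative rather than part of the PR: the Mlp stands in for a pretrained feature extractor, the class name FreezeExample is hypothetical, and the parameters are initialized first (here via Trainer#initialize) before freezing, because Parameter#freeze updates the underlying NDArray.

    import ai.djl.Model;
    import ai.djl.basicmodelzoo.basic.Mlp;
    import ai.djl.ndarray.types.Shape;
    import ai.djl.nn.Block;
    import ai.djl.nn.Parameter;
    import ai.djl.nn.SequentialBlock;
    import ai.djl.nn.core.Linear;
    import ai.djl.training.DefaultTrainingConfig;
    import ai.djl.training.Trainer;
    import ai.djl.training.initializer.Initializer;
    import ai.djl.training.loss.Loss;

    public final class FreezeExample {
        public static void main(String[] args) {
            try (Model model = Model.newInstance("transfer")) {
                // Hypothetical setup: the Mlp stands in for a pretrained backbone.
                Block features = new Mlp(784, 256, new int[] {512});
                Block head = Linear.builder().setUnits(10).build();
                model.setBlock(new SequentialBlock().add(features).add(head));

                DefaultTrainingConfig config =
                        new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                                .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
                try (Trainer trainer = model.newTrainer(config)) {
                    // Initialize first: freeze(boolean) touches the parameter arrays.
                    trainer.initialize(new Shape(1, 784));

                    // Detach gradients from the backbone parameters, so only the
                    // head is updated by trainer.step().
                    features.freezeParameters(true);

                    // ... run training batches on the head ...

                    // Later, unfreeze to fine-tune the whole network.
                    features.freezeParameters(false);
                }
            }
        }
    }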
api/src/main/java/ai/djl/nn/Parameter.java (16 additions, 0 deletions)
@@ -133,6 +133,22 @@ public boolean requiresGradient() {
        return requiresGrad;
    }

    /**
     * Freezes or unfreezes the parameter for training.
     *
     * <p>Sometimes during training, especially during transfer learning, it is typical to train
     * only part of the model. Freezing can be used to prevent the remaining parts from being
     * trained.
     *
     * <p>This modifies the {@link #requiresGradient()} of the parameter.
     *
     * @param freeze true if the parameter should be frozen ({@code freeze == !requiresGradient()})
     */
    public void freeze(boolean freeze) {
        requiresGrad = !freeze;
        array.setRequiresGradient(requiresGrad);
    }

    /**
     * Checks if this {@code Parameter} is initialized.
     *
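To make the documented contract concrete, here is a minimal sketch (illustrative, not from this PR; the class name ParameterFreezeExample is hypothetical) showing freeze(true) flipping requiresGradient() to false and back. The block is initialized first so the parameter has a backing NDArray for freeze to update.

    import ai.djl.ndarray.NDManager;
    import ai.djl.ndarray.types.DataType;
    import ai.djl.ndarray.types.Shape;
    import ai.djl.nn.Block;
    import ai.djl.nn.Parameter;
    import ai.djl.nn.core.Linear;
    import ai.djl.training.initializer.Initializer;

    public final class ParameterFreezeExample {
        public static void main(String[] args) {
            try (NDManager manager = NDManager.newBaseManager()) {
                Block linear = Linear.builder().setUnits(4).build();
                // Initialize so that the parameter has a backing NDArray.
                linear.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
                linear.initialize(manager, DataType.FLOAT32, new Shape(1, 8));

                Parameter weight = linear.getParameters().valueAt(0);
                weight.freeze(true);
                System.out.println(weight.requiresGradient()); // false: excluded from training
                weight.freeze(false);
                System.out.println(weight.requiresGradient()); // true: trainable again
            }
        }
    }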
ai/djl/integration/tests/training/… (integration test)

@@ -13,13 +13,17 @@
package ai.djl.integration.tests.training;

import ai.djl.Model;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.Parameter;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.testing.Assertions;
import ai.djl.testing.TestRequirements;
@@ -73,6 +77,61 @@ public void testAutograd() {
        }
    }

    @Test
    public void testFreezeParameters() {
        try (Model model = Model.newInstance("model")) {
            Block blockFrozen = new Mlp(10, 10, new int[] {10});
            Block blockNormal = new Mlp(10, 10, new int[] {10});
            Block combined = new SequentialBlock().add(blockFrozen).add(blockNormal);
            model.setBlock(combined);

            TrainingConfig config =
                    new DefaultTrainingConfig(Loss.l2Loss())
                            .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);

            try (Trainer trainer = model.newTrainer(config)) {
                trainer.initialize(new Shape(1, 10));

                blockFrozen.freezeParameters(true);

                // Record the sum of the first parameter in each block
                Float frozenVal =
                        blockFrozen.getParameters().valueAt(0).getArray().sum().getFloat();
                Float normalVal =
                        blockNormal.getParameters().valueAt(0).getArray().sum().getFloat();

                // Run training step
                NDManager manager = trainer.getManager();
                NDArray data = manager.arange(100.0f).reshape(new Shape(10, 10));
                NDArray labels = manager.arange(100.0f).reshape(new Shape(10, 10));
                Batch batch =
                        new Batch(
                                manager, new NDList(data), new NDList(labels), 1, null, null, 0, 1);
                EasyTrain.trainBatch(trainer, batch);
                trainer.step();

                // Check the sums after the update: the frozen parameter should be
                // unchanged, but the normal one should differ
                Float newFrozenVal =
                        blockFrozen.getParameters().valueAt(0).getArray().sum().getFloat();
                Float newNormalVal =
                        blockNormal.getParameters().valueAt(0).getArray().sum().getFloat();
                Assert.assertEquals(newFrozenVal, frozenVal);
                Assert.assertNotEquals(newNormalVal, normalVal);

                blockFrozen.freezeParameters(false);

                // Check that unfreezing the block now makes it update
                EasyTrain.trainBatch(trainer, batch);
                trainer.step();

                Float nowUnfrozenVal =
                        blockFrozen.getParameters().valueAt(0).getArray().sum().getFloat();
                Assert.assertNotEquals(nowUnfrozenVal, frozenVal);
            }
        }
    }

    @Test
    public void testTrain() throws IOException, TranslateException {
        TestRequirements.nightly();