Skip to content

Commit

Permalink
Remove performance issues from freezing MXNet (#2394)
Browse files Browse the repository at this point in the history
* Remove performance issues from freezing MXNet

In #2360, the behavior of using pre-trained models was to freeze parameters.
However, freezing the parameters on MXNet seems to cause a significant
performance regression for training. This removes those changes as a temporary
workaround until a deeper investigation can take place.

Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
  • Loading branch information
zachgk and frankfliu authored Feb 14, 2023
1 parent 7ef827a commit 04cf346
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ public class MxModel extends BaseModel {
* @throws IOException Exception for file loading
*/
@Override
@SuppressWarnings("PMD.EmptyControlStatement")
public void load(Path modelPath, String prefix, Map<String, ?> options)
throws IOException, MalformedModelException {
setModelDir(modelPath);
Expand Down Expand Up @@ -143,12 +144,15 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
boolean trainParam =
options != null && Boolean.parseBoolean((String) options.get("trainParam"));
if (!trainParam) {
block.freezeParameters(true);
// TODO: See https://github.com/deepjavalibrary/djl/pull/2360
// NOPMD
// block.freezeParameters(true);
}
}

/** {@inheritDoc} */
@Override
@SuppressWarnings("PMD.EmptyControlStatement")
public Trainer newTrainer(TrainingConfig trainingConfig) {
PairList<Initializer, Predicate<Parameter>> initializer = trainingConfig.getInitializers();
if (block == null) {
Expand All @@ -157,7 +161,8 @@ public Trainer newTrainer(TrainingConfig trainingConfig) {
}
if (wasLoaded) {
// Unfreeze parameters if training directly
block.freezeParameters(false);
// TODO: See https://github.com/deepjavalibrary/djl/pull/2360
// block.freezeParameters(false);
}
for (Pair<Initializer, Predicate<Parameter>> pair : initializer) {
if (pair.getKey() != null && pair.getValue() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ private static Model getModel(Arguments arguments)
SequentialBlock newBlock = new SequentialBlock();
SymbolBlock block = (SymbolBlock) model.getBlock();
block.removeLastBlock();
block.freezeParameters(false);
newBlock.add(block);
// the original model doesn't include the flatten
// so apply the flatten here
Expand Down

0 comments on commit 04cf346

Please sign in to comment.