Skip to content

Commit

Permalink
Flax will default to returning regular dicts instead of FrozenDicts i…
Browse files Browse the repository at this point in the history
…n the future. Some internal-only tests fail because they expect frozen params when creating TrainState.

PiperOrigin-RevId: 533567519
  • Loading branch information
chiamp authored and Flax Authors committed May 19, 2023
1 parent 380fef6 commit 82db69d
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions flax/training/train_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,12 @@ def apply_gradients(self, *, grads, **kwargs):
@classmethod
def create(cls, *, apply_fn, params, tx, **kwargs):
"""Creates a new instance with `step=0` and initialized `opt_state`."""
opt_state = tx.init(params)
frozen_params = core.freeze(params)
opt_state = tx.init(frozen_params)
return cls(
step=0,
apply_fn=apply_fn,
params=params,
params=frozen_params,
tx=tx,
opt_state=opt_state,
**kwargs,
Expand Down

0 comments on commit 82db69d

Please sign in to comment.