Skip to content

Commit

Permalink
Flax will default to returning regular dicts instead of FrozenDicts i…
Browse files Browse the repository at this point in the history
…n the future. Some internal-only tests fail because they expect frozen params when creating TrainState.

PiperOrigin-RevId: 533567519
  • Loading branch information
chiamp authored and Flax Authors committed May 31, 2023
1 parent a1a669a commit 9a98529
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions flax/training/train_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,12 @@ def apply_gradients(self, *, grads, **kwargs):
@classmethod
def create(cls, *, apply_fn, params, tx, **kwargs):
"""Creates a new instance with `step=0` and initialized `opt_state`."""
opt_state = tx.init(params)
frozen_params = core.freeze(params) if isinstance(params, dict) else params
opt_state = tx.init(frozen_params)
return cls(
step=0,
apply_fn=apply_fn,
params=params,
params=frozen_params,
tx=tx,
opt_state=opt_state,
**kwargs,
Expand Down

0 comments on commit 9a98529

Please sign in to comment.