From 9a9852995d3a2a7982394b5ecb7e52878c928b9e Mon Sep 17 00:00:00 2001 From: Marcus Chiam Date: Fri, 19 May 2023 15:34:46 -0700 Subject: [PATCH] Flax will default to returning regular dicts instead of FrozenDicts in the future. Some internal-only tests fail because they expect frozen params when creating TrainState. PiperOrigin-RevId: 533567519 --- flax/training/train_state.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/flax/training/train_state.py b/flax/training/train_state.py index d3b5641dfd..89bedee164 100644 --- a/flax/training/train_state.py +++ b/flax/training/train_state.py @@ -83,11 +83,12 @@ def apply_gradients(self, *, grads, **kwargs): @classmethod def create(cls, *, apply_fn, params, tx, **kwargs): """Creates a new instance with `step=0` and initialized `opt_state`.""" - opt_state = tx.init(params) + frozen_params = core.freeze(params) if isinstance(params, dict) else params + opt_state = tx.init(frozen_params) return cls( step=0, apply_fn=apply_fn, - params=params, + params=frozen_params, tx=tx, opt_state=opt_state, **kwargs,