Skip to content

Commit

Permalink
Pass name of layer to init scale.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 555526825
  • Loading branch information
The paxml Authors committed Aug 10, 2023
1 parent 14ad181 commit 71b71da
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion paxml/learners.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def _adjust_var(old_var, transformed_grad, is_learnable, var_wh, var_key):
target_magnitude = 1.0
logging.info('var: %s is a scalar', var_key)
else:
target_magnitude = base_layer.var_init_scale(var_wh)
target_magnitude = base_layer.var_init_scale(var_wh, var_key)
logging.info(
'var: %s, target_magnitude: %f', var_key, target_magnitude
)
Expand Down

0 comments on commit 71b71da

Please sign in to comment.