Is there a method for accessing the learning rate being used by the optimizer at each step during training from the schedule.py schedules?

Replies: 1 comment
To access the learning rate during training, you could read the current step count out of the optimizer state and evaluate the schedule with it:
```python
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)
  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      count = opt_state.inner_state[0].count  # get current step
      lr = schedule(count)                    # get learning rate from schedule
      print(f'Step {i:3}, Loss: {loss_value:.3f}, Learning rate: {lr:.9f}')
  return params


params = fit(initial_params, optimizer)
```
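The snippet above relies on a schedule, a step function and training data that are not shown in the reply. A minimal, hypothetical setup could look like the sketch below (the schedule, model, data and shapes are all made up). Note that the exact path to the step count depends on how the optimizer is built: with a plain optax.adamw(learning_rate=schedule) the count lives at opt_state[0].count, while the opt_state.inner_state[0].count path used above applies when the optimizer is wrapped, for example by optax.inject_hyperparams as in the next snippet.

```python
# A minimal, hypothetical setup for the snippet above; the schedule, model,
# data and step function are stand-ins, not part of the original reply.
import jax
import jax.numpy as jnp
import optax

schedule = optax.exponential_decay(
    init_value=1e-3, transition_steps=100, decay_rate=0.99)
optimizer = optax.adamw(learning_rate=schedule)

TRAINING_DATA = [jnp.ones((8, 2))] * 1000   # toy batches
LABELS = [jnp.zeros((8, 1))] * 1000         # toy targets
initial_params = {'w': jnp.zeros((2, 1)), 'b': jnp.zeros((1,))}


def loss_fn(params, batch, labels):
  preds = batch @ params['w'] + params['b']
  return jnp.mean((preds - labels) ** 2)


@jax.jit
def step(params, opt_state, batch, labels):
  loss_value, grads = jax.value_and_grad(loss_fn)(params, batch, labels)
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss_value

# With this plain optax.adamw, the step count is at opt_state[0].count;
# opt_state.inner_state[0].count applies when the optimizer is wrapped.
```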
Alternatively, you can wrap the optimizer with optax.inject_hyperparams, which stores the current numeric hyperparameter values (including the scheduled learning rate) in the optimizer state:

```python
# Wrap the optimizer to inject the hyperparameters
optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=schedule)


def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)
  # Since we injected hyperparams, we can access them directly here
  print(f'Available hyperparams: {" ".join(opt_state.hyperparams.keys())}\n')
  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      # Get the updated learning rate
      lr = opt_state.hyperparams['learning_rate']
      print(f'Step {i:3}, Loss: {loss_value:.3f}, Learning rate: {lr:.3f}')
  return params


params = fit(initial_params, optimizer)
```
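A note on how inject_hyperparams behaves: the hyperparams dict holds the values actually used for the latest update, and numeric hyperparameters that are not driven by a schedule can also be overridden there between steps (scheduled ones, like learning_rate above, are re-evaluated on every update, so manual writes to them are overwritten). A minimal sketch, reusing the hypothetical schedule and initial_params from the setup above and an arbitrary weight decay:

```python
# Hypothetical: inject a scheduled learning rate plus a static weight decay.
optimizer = optax.inject_hyperparams(optax.adamw)(
    learning_rate=schedule, weight_decay=1e-4)
opt_state = optimizer.init(initial_params)

# Static numeric hyperparameters can be overridden in place; the new value
# is used by subsequent optimizer.update() calls.
opt_state.hyperparams['weight_decay'] = 1e-3

# Scheduled hyperparameters (learning_rate here) are recomputed from the
# schedule at each update, so overriding them has no lasting effect.
```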
For further discussion, see #206.