
Commit 125a5b4: Fix more tests.
stephenroller committed Aug 30, 2020 (1 parent: abc3f2d)
Showing 3 changed files with 4 additions and 3 deletions.
parlai/nn/lr_scheduler.py (2 additions, 2 deletions)

@@ -60,7 +60,7 @@ def _init_warmup_scheduler(self, optimizer, states):
             self.warmup_scheduler = optim.lr_scheduler.LambdaLR(
                 optimizer, self._warmup_lr
             )
-            if 'warmup_scheduler' in states:
+            if states.get('warmup_scheduler'):
                 self.warmup_scheduler.load_state_dict(states['warmup_scheduler'])
         else:
             self.warmup_scheduler = None
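
A note on this hunk: the old membership test passes even when the checkpoint stores the key with a None value, in which case load_state_dict(None) would fail; the truthiness check skips that case. A minimal sketch of the difference (the states dict here is a hypothetical stand-in, not taken from the commit):

# 'in' only tests key presence; .get() also treats a stored None (or other
# falsy value) as "nothing to restore".
states = {'warmup_scheduler': None}

print('warmup_scheduler' in states)          # True  -> old check would call load_state_dict(None)
print(bool(states.get('warmup_scheduler')))  # False -> new check skips restoring
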
@@ -307,7 +307,7 @@ def step(self, num_steps):
         Override this method to override the behavior for training schedulers.
         """
-        self._number_training_updates += 1
+        self._number_training_updates = num_steps
         if self._is_lr_warming_up():
             self.warmup_scheduler.step()
         else:
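
Assigning num_steps, rather than self-incrementing, keeps the scheduler's internal counter in lockstep with the step count the trainer passes in, so the two cannot drift apart (for example, after resuming from a checkpoint). A simplified, hypothetical stand-in for the scheduler illustrates the behavior:

class SchedulerCounter:
    """Toy model of the counter; not the real ParlAI class."""

    def __init__(self):
        self._number_training_updates = 0

    def step(self, num_steps):
        # Mirror the caller's global step count instead of self-incrementing.
        self._number_training_updates = num_steps


sched = SchedulerCounter()
for global_step in (1, 2, 5):  # even if some calls are skipped...
    sched.step(global_step)
assert sched._number_training_updates == 5  # ...the counter stays in sync
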
tests/test_lr_schedulers.py (1 addition, 1 deletion)

@@ -96,7 +96,7 @@ def test_invsqrt(self):
         steps = self._run_pass(
             lr_scheduler='invsqrt', warmup_updates=50, invsqrt_lr_decay_gamma=1
         )
-        self.assertAlmostEquals(steps[-1], 0.0324272)
+        self.assertAlmostEquals(steps[-1], 0.0324443)

         # decay very slowly
         steps = self._run_pass(
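
The updated constant is consistent with the off-by-one removed above, assuming the invsqrt schedule scales the learning rate roughly as sqrt(gamma / n) for n effective decay steps (an assumption about ParlAI's formula, not stated on this page). Under that reading, the counter fix shifts n by exactly one:

import math

print(math.sqrt(1 / 951))  # ~0.0324272 -> the old expected value
print(math.sqrt(1 / 950))  # ~0.0324443 -> the new expected value
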
tests/test_transformers.py (1 addition)

@@ -797,6 +797,7 @@ def _test_learning_rate_resuming(self, args):
                 init_model=os.path.join(tmpdir, 'model'),
                 model_file=os.path.join(tmpdir, 'newmodel2'),
                 lr_scheduler='reduceonplateau',
+                log_every_n_secs=0.001,
                 **args,
             )
         )
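
Setting log_every_n_secs to a millisecond effectively forces a log report on every training step, which appears to be what this resuming test needs in order to observe learning-rate values within a very short run. A hedged usage sketch with ParlAI's script API (the task, time budget, and model file below are illustrative, not from this commit):

from parlai.scripts.train_model import TrainModel

TrainModel.main(
    task='integration_tests',        # illustrative toy task
    model='transformer/generator',
    lr_scheduler='reduceonplateau',
    log_every_n_secs=0.001,          # report (and record the lr) ~every step
    max_train_time=2,                # keep the sketch fast
    model_file='/tmp/sketch_model',
)
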
