
[train] New training options for logging/validation based on number of steps #3379

Merged: 12 commits, Mar 8, 2021
2 changes: 1 addition & 1 deletion parlai/agents/test_agents/test_agents.py
@@ -168,7 +168,7 @@ def eval_step(self, batch):
 
 class MockTrainUpdatesAgent(MockTorchAgent):
     """
-    Simulate training updates
+    Simulate training updates.
     """
 
     def train_step(self, batch):
14 changes: 3 additions & 11 deletions parlai/nn/lr_scheduler.py
@@ -17,7 +17,6 @@
 import numpy as np
 
 from parlai.core.exceptions import StopTrainException
-import parlai.utils.logging as logging
 from parlai.utils.misc import warn_once
 
 
@@ -226,17 +225,10 @@ def lr_scheduler_factory(cls, opt, optimizer, states, hard_reset=False):
         warmup_updates = opt.get('warmup_updates', -1)
         warmup_rate = opt.get('warmup_rate', 1e-4)
         max_lr_steps = opt.get('max_train_steps', -1)
-        deprecated_max_lr_steps = opt.get('max_lr_steps', -1)
-        if deprecated_max_lr_steps > 0:
-            logging.warn(
-                '**DEPRECATED: --max-lr-steps is deprecated, please only specify '
-                '--max-train-steps instead'
+        if opt.get('max_lr_steps', -1) > 0:
+            raise ValueError(
+                '--max-lr-steps is **DEPRECATED**; please set --max-train-steps directly'
             )
-            if deprecated_max_lr_steps != max_lr_steps:
-                logging.warn(
-                    f'Setting max_lr_steps from {deprecated_max_lr_steps} to {max_lr_steps}'
-                )
-            max_lr_steps = deprecated_max_lr_steps
         invsqrt_lr_decay_gamma = opt.get('invsqrt_lr_decay_gamma', -1)
 
         if opt.get('lr_scheduler') == 'none':
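In effect this turns the soft deprecation into a hard error: previously a deprecated `--max-lr-steps` value was warned about and copied into `max_lr_steps`; now it aborts scheduler construction outright. A standalone sketch of the new guard for illustration only (the helper name and the plain dicts standing in for ParlAI's options mapping are hypothetical):

```python
def resolve_max_steps(opt: dict) -> int:
    """Hypothetical standalone version of the guard added above."""
    max_lr_steps = opt.get('max_train_steps', -1)
    if opt.get('max_lr_steps', -1) > 0:
        # fail fast instead of silently remapping the deprecated flag
        raise ValueError(
            '--max-lr-steps is **DEPRECATED**; please set --max-train-steps directly'
        )
    return max_lr_steps

assert resolve_max_steps({'max_train_steps': 1000}) == 1000  # new flag: accepted
try:
    resolve_max_steps({'max_lr_steps': 1000})  # deprecated flag: now a hard error
except ValueError as err:
    print(err)
```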
20 changes: 17 additions & 3 deletions parlai/scripts/train_model.py
@@ -704,13 +704,27 @@ def train(self):
                 exs_per_epoch = world.num_examples()
                 self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))
                 # and use the primary worker's timings for everything
-                train_time, log_time, validate_time = sync_object(
-                    (
+                if any(
+                    getattr(self, k) < float('inf')
+                    for k in [
+                        'max_train_time',
+                        'log_every_n_secs',
+                        'validation_every_n_secs',
+                    ]
+                ):
+                    train_time, log_time, validate_time = sync_object(
+                        (
+                            self.train_time.time(),
+                            self.log_time.time(),
+                            self.validate_time.time(),
+                        )
+                    )
+                else:
+                    train_time, log_time, validate_time = (
                         self.train_time.time(),
                         self.log_time.time(),
                         self.validate_time.time(),
                     )
-                )
 
                 # check counters and timers
                 if self._total_epochs >= self.max_num_epochs:
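The point of the branch: `sync_object` broadcasts values across workers in distributed training, so it is only worth paying for when some time-based trigger (`max_train_time`, `log_every_n_secs`, `validation_every_n_secs`) can actually fire. With all of them left at their default of infinity, training is driven purely by step counts and each worker's local timers suffice. A minimal sketch of the same pattern outside ParlAI (all names here are hypothetical):

```python
import math

def read_timers(trainer, sync_object):
    """Return (train, log, validate) elapsed times, syncing only if needed."""
    timers = (
        trainer.train_time.time(),
        trainer.log_time.time(),
        trainer.validate_time.time(),
    )
    time_limits = (
        trainer.max_train_time,
        trainer.log_every_n_secs,
        trainer.validation_every_n_secs,
    )
    if any(limit < math.inf for limit in time_limits):
        # a time-based trigger is live: keep all workers on the primary's clock
        return sync_object(timers)
    # purely step/epoch-driven run: skip the distributed round trip
    return timers
```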
2 changes: 1 addition & 1 deletion tests/test_train_model.py
@@ -114,7 +114,7 @@ def test_multitasking_id_overlap(self):
 
     def _test_opt_step_opts(self, update_freq: int):
         """
-        Test -tstep, -vstep, -lstep
+        Test -tstep, -vstep, -lstep.
 
         :param update_freq:
            update frequency
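For context, the short flags this test exercises are the step-based counterparts of the existing time/epoch triggers. Only `--max-train-steps` appears verbatim elsewhere in this diff; the other long forms and all values below are inferred from the `-tstep`/`-vstep`/`-lstep` aliases and should be treated as assumptions:

```python
# Assumed invocation; flag long forms other than max_train_steps are
# inferred from the aliases in the test docstring, not confirmed by the diff.
from parlai.scripts.train_model import TrainModel

TrainModel.main(
    task='integration_tests',          # hypothetical task
    model='test_agents/unigram',       # hypothetical model
    max_train_steps=100,               # -tstep: stop after N updates
    validation_every_n_steps=10,       # -vstep: validate every N updates
    log_every_n_steps=5,               # -lstep: log every N updates
)
```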