102 changes: 6 additions & 96 deletions pytorch_translate/evals.py
@@ -160,61 +160,6 @@ def eval_tune_loss(args, trainer, task, subset, extra_state):
    return extra_state, stop_due_to_tune_loss


-def evaluate_bleu(
-    args, task, extra_state: Dict[str, Any], trainer, averaged_params: OrderedDict
-) -> Tuple[Dict[str, Any], bool, bool, List]:
-    if args.disable_eval_bleu:
-        extra_state["tune_bleu"]["current"] = 0.0
-        return (extra_state, False, True, None)
-    epoch, offset = extra_state["epoch"], extra_state["batch_offset"]
-    if args.log_verbose:
-        print(
-            f"| Preparing to calculate BLEU score for epoch {epoch}, offset {offset}."
-        )
-    extra_state["tune_bleu"]["current"], translation_samples = calculate_bleu_on_subset(
-        args=args,
-        task=task,
-        epoch_str=f"{epoch:03d}",
-        offset=offset,
-        dataset_split=args.valid_subset,
-        trainer=trainer,
-        model_params=averaged_params,
-    )
-    if args.log_verbose:
-        print(f"| Finished calculating BLEU score for epoch {epoch}, offset {offset}.")
-
-    new_best_averaged_checkpoint = False
-    if (
-        extra_state["tune_bleu"]["best"] is None
-        or extra_state["tune_bleu"]["current"] > extra_state["tune_bleu"]["best"]
-    ):
-        extra_state["tune_bleu"]["best"] = extra_state["tune_bleu"]["current"]
-        extra_state["tune_bleu"]["best_epoch"] = epoch
-        extra_state["tune_bleu"]["num_since_best"] = 0
-        new_best_averaged_checkpoint = True
-    else:
-        extra_state["tune_bleu"]["num_since_best"] += 1
-
-    stop_due_to_tune_bleu = False
-    if (
-        args.stop_no_best_bleu_eval >= 0
-        and extra_state["tune_bleu"]["num_since_best"] > args.stop_no_best_bleu_eval
-    ):
-        stop_due_to_tune_bleu = True
-        print(
-            f"Stopping training due to BLEU score stagnation on tune set - "
-            f"last best BLEU score of {extra_state['tune_bleu']['best']} "
-            f"(current score: {extra_state['tune_bleu']['current']}) was "
-            f"{extra_state['tune_bleu']['num_since_best']} evals ago."
-        )
-    return (
-        extra_state,
-        stop_due_to_tune_bleu,
-        new_best_averaged_checkpoint,
-        translation_samples,
-    )
-
-
def calculate_bleu_on_subset(
    args,
    task,
@@ -336,7 +281,7 @@ def save_and_eval(
        args.save_interval_updates <= 0
        or (extra_state["num_iterations"] % args.save_interval_updates != 0)
    ):
-        return extra_state, stop_due_to_time_limit, None
+        return extra_state, stop_due_to_time_limit

    # Update training time before saving the checkpoint.
    time_now: float = time.time()
Expand Down Expand Up @@ -369,34 +314,11 @@ def save_and_eval(
f"have a checkpoint_manager defined."
)

# trick to prepare the task for evaluation, e.g. in latent variable model we need to set eval_key in RoundRobinZipDataset
if hasattr(task, "prepare_for_eval") and callable(task.prepare_for_eval):
task.prepare_for_eval()
# Only save checkpoints and eval tune BLEU on the master - all other
# processes will just get the results from the master.
translation_samples: Optional[List] = None
if is_master:
averaged_params: OrderedDict = checkpoint_manager.get_averaged_params(
new_params=trainer.get_model().state_dict()
)

# TODO: fix after masked lm work completes
if "save_only" not in args or not args.save_only:
(
extra_state,
stop_due_to_tune_bleu,
new_best_averaged_checkpoint,
translation_samples,
) = evaluate_bleu(
args=args,
task=task,
extra_state=extra_state,
trainer=trainer,
averaged_params=averaged_params,
)
else:
new_best_averaged_checkpoint = True
stop_due_to_tune_bleu = False
new_best_averaged_checkpoint = extra_state["tune_eval"]["num_since_best"] == 0
# checkpoint_manager takes ownership of averaged_params.
extra_state = checkpoint_manager.save(
args=args,
@@ -408,33 +330,21 @@
        checkpoint_manager.save_best_averaged_checkpoint(
            args=args, trainer=trainer, extra_state=extra_state
        )
-    if hasattr(task, "prepare_for_train") and callable(task.prepare_for_train):
-        task.prepare_for_train()

-    # extra_state["tune_bleu"] needs to be sync'ed between master and workers
-    # since we only do BLEU eval on master, but then need that info for
-    # determining when to do lr_shrink on all workers.
-    master_tune_bleu = None
    master_stop_training = None
    if is_master:
-        master_tune_bleu = extra_state["tune_bleu"]
        master_stop_training = (
-            stop_due_to_time_limit
-            or stop_due_to_tune_loss
-            or stop_due_to_tune_bleu
-            or stop_due_to_max_update
+            stop_due_to_time_limit or stop_due_to_tune_loss or stop_due_to_max_update
        )
-    tune_bleu, stop_training = pytorch_translate_utils.all_gather_from_master(
-        args=args, data=[master_tune_bleu, master_stop_training]
+    stop_training = pytorch_translate_utils.all_gather_from_master(
+        args=args, data=[master_stop_training]
    )
-    extra_state["tune_bleu"] = tune_bleu

    # TODO: fix after masked lm work completes
    if "save_only" not in args or not args.save_only:
        # Basic sanity checks that extra_state is populated correctly.
        assert (
            extra_state["tune_eval"]["loss"] is not None
            and extra_state["tune_eval"]["perplexity"] is not None
-            and extra_state["tune_bleu"]["current"] is not None
        )
-    return extra_state, stop_training, translation_samples
+    return extra_state, stop_training
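
Net effect of the evals.py changes: evaluate_bleu is gone, save_and_eval returns a plain (extra_state, stop_training) pair, and "new best" is read off the tune-loss tracker that eval_tune_loss already maintains. The sketch below restates that decision logic as a standalone function; the function name is hypothetical (the PR keeps the logic inline in save_and_eval), and it assumes eval_tune_loss resets tune_eval["num_since_best"] to 0 whenever the tune loss improves.

from typing import Any, Dict, Tuple


def stop_and_new_best(
    extra_state: Dict[str, Any],
    stop_due_to_time_limit: bool,
    stop_due_to_tune_loss: bool,
    stop_due_to_max_update: bool,
) -> Tuple[bool, bool]:
    # The counter is reset to 0 on every tune-loss improvement, so a zero
    # value means the params just averaged correspond to a new best.
    new_best_averaged_checkpoint = extra_state["tune_eval"]["num_since_best"] == 0
    # Only time-, loss-, and update-based stop conditions remain;
    # stop_due_to_tune_bleu disappears along with evaluate_bleu.
    stop_training = (
        stop_due_to_time_limit or stop_due_to_tune_loss or stop_due_to_max_update
    )
    return new_best_averaged_checkpoint, stop_training

One consequence of this is that the master only has to broadcast the boolean stop decision, which is why the all_gather_from_master payload shrinks to a single element.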
4 changes: 2 additions & 2 deletions pytorch_translate/options.py
@@ -555,12 +555,12 @@ def expand_optimization_args(group):
        "in the first place. A value of < 0 disables this.",
    )
    group.add_argument(
-        "--shrink-lr-no-best-bleu-eval",
+        "--shrink-lr-no-best-tune-loss",
        default=5,
        type=int,
        metavar="N",
        help="Decay learning rate after N evals have been run without "
-        "achieving a better BLEU score than before. This is to achieve "
+        "achieving a lower tune loss than before. This is to achieve "
        "decay lr within an epoch, independent of lr_scheduler. "
        "Note that this is affected by --save-interval-updates in "
        "how frequently we run BLEU eval in the first place. "
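
The renamed flag keeps its default of 5; only the stagnation signal changes from BLEU to tune loss (note that the unchanged tail of the help string still mentions BLEU eval). A worked example with assumed values — everything below except the default of 5 is hypothetical:

# With one tune-loss eval every 500 updates, a tolerance of 5 means the
# learning rate first decays after 6 evals (~3000 updates) without a new best.
save_interval_updates = 500       # assumed eval cadence
shrink_lr_no_best_tune_loss = 5   # default from the diff above
lr, lr_shrink = 0.25, 0.5         # assumed optimizer settings

num_since_best = 6  # first value past the tolerated count
if shrink_lr_no_best_tune_loss > 0 and num_since_best > shrink_lr_no_best_tune_loss:
    lr *= lr_shrink
print(lr)  # 0.125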
32 changes: 6 additions & 26 deletions pytorch_translate/train.py
@@ -109,12 +109,6 @@ def default_extra_state(args) -> Dict[str, Any]:
            "lowest_loss": None,
            "num_since_best": 0,
        },
-        "tune_bleu": {
-            "current": None,
-            "best": None,
-            "best_epoch": None,
-            "num_since_best": 0,
-        },
        # The list of checkpoint files is actually managed by the
        # CheckpointManager, which overwrites this placeholder when it saves
        # checkpoints.
@@ -136,8 +130,8 @@ def update_output(
        num_updates,
        {
            "train_ppl": train_ppl,
+            "tune_loss": extra_state["tune_eval"]["loss"],
            "tune_ppl": extra_state["tune_eval"]["perplexity"],
-            "tune_bleu": extra_state["tune_bleu"]["current"],
            "wps": wps,
            # translation_samples isn't currently used by the queue reader,
            # so just pass None for now until we start needing it.
Expand All @@ -159,7 +153,6 @@ def clear_per_step_extra_state(extra_state: Dict[str, Any]) -> Dict[str, Any]:
"""
extra_state["tune_eval"]["loss"] = None
extra_state["tune_eval"]["perplexity"] = None
extra_state["tune_bleu"]["current"] = None
return extra_state


@@ -566,11 +559,7 @@ def train(
            # any case where extra_case does not get populated correctly.
            extra_state = clear_per_step_extra_state(extra_state)
            extra_state["batch_offset"] = i + 1
-            (
-                extra_state,
-                stop_training_mid_epoch,
-                translation_samples,
-            ) = evals.save_and_eval(
+            extra_state, stop_training_mid_epoch = evals.save_and_eval(
                args=args,
                trainer=trainer,
                task=task,
@@ -607,9 +596,9 @@
            hasattr(args, "lr_shrink")
            and args.save_interval_updates > 0
            and extra_state["num_iterations"] % args.save_interval_updates == 0
-            and args.shrink_lr_no_best_bleu_eval > 0
-            and extra_state["tune_bleu"]["num_since_best"]
-            > args.shrink_lr_no_best_bleu_eval
+            and args.shrink_lr_no_best_tune_loss > 0
+            and extra_state["tune_eval"]["num_since_best"]
+            > args.shrink_lr_no_best_tune_loss
        ):
            current_lr = trainer.optimizer.get_lr()
            trainer.optimizer.set_lr(current_lr * args.lr_shrink)
@@ … @@

        # batch_offset being None denotes the end of an epoch.
        extra_state["batch_offset"] = None
-        (
-            extra_state,
-            stop_training_end_of_epoch,
-            translation_samples,
-        ) = evals.save_and_eval(
+        extra_state, stop_training_end_of_epoch = evals.save_and_eval(
            args=args,
            trainer=trainer,
            task=task,
@@ -662,11 +647,6 @@
    if checkpoint_manager:
        checkpoint_manager.remove_all_checkpoints()

-    print(
-        f"| Best BLEU score of {extra_state['tune_bleu']['best']} was from "
-        f"epoch {extra_state['tune_bleu']['best_epoch']}"
-    )
-


def setup_epoch(args, epoch_itr, trainer):
    """Sets up data and progress meters for one epoch."""
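
For reference, the lr-decay check in train() reads as follows once the -/+ lines above are merged; maybe_shrink_lr is a hypothetical wrapper name (the PR leaves the condition inline in the training loop), and trainer.optimizer is assumed to expose fairseq-style get_lr/set_lr.

def maybe_shrink_lr(args, extra_state, trainer) -> None:
    # Shrink only on iterations where a tune-loss eval just ran, and only
    # once the stagnation tolerance has been exceeded.
    if (
        hasattr(args, "lr_shrink")
        and args.save_interval_updates > 0
        and extra_state["num_iterations"] % args.save_interval_updates == 0
        and args.shrink_lr_no_best_tune_loss > 0
        and extra_state["tune_eval"]["num_since_best"]
        > args.shrink_lr_no_best_tune_loss
    ):
        current_lr = trainer.optimizer.get_lr()
        trainer.optimizer.set_lr(current_lr * args.lr_shrink)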