Skip to content

Commit

Permalink
Bugfix: Statistics inherits n_correct from previous instance
Browse files Browse the repository at this point in the history
The default value must be either zero or None, depending on whether
accuracy is reported or not.
  • Loading branch information
Waino committed Oct 7, 2024
1 parent 7229141 commit 203d4d5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
7 changes: 4 additions & 3 deletions mammoth/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,9 @@ def train(
else:
logger.info('Start training loop and validate every %d steps...', valid_steps)

total_stats = mammoth.utils.Statistics()
report_stats = mammoth.utils.Statistics()
n_correct = 0 if self.report_training_accuracy else None
total_stats = mammoth.utils.Statistics(n_correct=n_correct)
report_stats = mammoth.utils.Statistics(n_correct=n_correct)
self._start_report_manager(start_time=total_stats.start_time)
self.optim.zero_grad()

Expand Down Expand Up @@ -385,7 +386,7 @@ def validate(self, valid_iter, moving_average=None, task=None):

for batch, metadata, _ in valid_iter:
if stats is None:
stats = mammoth.utils.Statistics()
stats = mammoth.utils.Statistics(n_correct=0)

stats.n_src_words += batch.src.mask.sum().item()
src = batch.src.tensor
Expand Down
8 changes: 5 additions & 3 deletions mammoth/utils/report_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def report_training(
if optimizer is not None:
for line in optimizer.report_steps():
logger.info(line)
return mammoth.utils.Statistics()
n_correct = None if report_stats.n_correct is None else 0
return mammoth.utils.Statistics(n_correct=n_correct)
else:
return report_stats

Expand Down Expand Up @@ -156,7 +157,8 @@ def _report_training(self, step, num_steps, learning_rate, patience, report_stat
report_stats.output(step, num_steps, learning_rate, self.start_time)

self.maybe_log_tensorboard(report_stats, "progress", learning_rate, patience, step)
report_stats = mammoth.utils.Statistics()
n_correct = None if report_stats.n_correct is None else 0
report_stats = mammoth.utils.Statistics(n_correct=n_correct)

total = sum(sampled_task_counts.values())
logger.info(f'Task sampling distribution: (total {total})')
Expand All @@ -183,7 +185,7 @@ def _report_step(self, lr, patience, step, train_stats=None, valid_stats=None):
structured_logging({
'type': 'validation',
'step': step,
'learning_rate': lr,
# 'learning_rate': lr,
'perplexity': ppl,
'accuracy': acc,
'crossentropy': valid_stats.xent(),
Expand Down

0 comments on commit 203d4d5

Please sign in to comment.