This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Full Tensorboard metric titles #3534

Merged: 7 commits, Mar 18, 2021
1 change: 1 addition & 0 deletions docs/Makefile
@@ -32,5 +32,6 @@ clean:
cd "${SOURCEDIR}"; python generate_task_list.py
cd "${SOURCEDIR}"; python generate_zoo_list.py
cd "${SOURCEDIR}"; python generate_mutator_list.py
cd "${SOURCEDIR}"; python generate_metric_list.py
cd "${SOURCEDIR}"; python generate_cli.py
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
17 changes: 17 additions & 0 deletions docs/source/generate_metric_list.py
@@ -0,0 +1,17 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from parlai.core.metrics import METRICS_DISPLAY_DATA


fout = open('metric_list.inc', 'w')

fout.write('| Metric | Explanation |\n')
fout.write('| ------ | ----------- |\n')
for metric, display in sorted(METRICS_DISPLAY_DATA.items()):
fout.write(f'| `{metric}` | {display.description} |\n')

fout.close()
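For reference, given the METRICS_DISPLAY_DATA entries added in parlai/core/metrics.py below, the generated metric_list.inc should start roughly like this (alphabetical by metric, per the sorted() call):

```
| Metric | Explanation |
| ------ | ----------- |
| `accuracy` | Exact match text accuracy |
| `bleu-4` | BLEU-4 of the generation, under a standardized (model-independent) tokenizer |
| `clen` | Average length of context in number of tokens |
```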
35 changes: 2 additions & 33 deletions docs/source/tutorial_metrics.md
@@ -417,36 +417,5 @@ If you find a metric not listed here,
please [file an issue on GitHub](https://github.com/facebookresearch/ParlAI/issues/new?assignees=&labels=Docs,Metrics&template=other.md).
:::

| Metric | Explanation |
| ----------------------- | ------------ |
| `accuracy` | Exact match text accuracy |
| `bleu-4` | BLEU-4 of the generation, under a standardized (model-independent) tokenizer |
| `clen` | Average length of context in number of tokens |
| `clip` | Fraction of batches with clipped gradients |
| `ctpb` | Context tokens per batch |
| `ctps` | Context tokens per second |
| `ctrunc` | Fraction of samples with some context truncation |
| `context_average_tokens_truncated` | Average length of context tokens truncated |
| `exps` | Examples per second |
| `exs` | Number of examples processed since last print |
| `f1` | Unigram F1 overlap, under a standardized (model-independent) tokenizer |
| `gnorm` | Gradient norm |
| `gpu_mem` | Fraction of GPU memory used. May slightly underestimate true value. |
| `hits@1`, `hits@5`, ... | Fraction of correct choices in K guesses. (Similar to recall@K) |
| `interdistinct-1`, `interdictinct-2` | Fraction of n-grams unique across _all_ generations |
| `intradistinct-1`, `intradictinct-2` | Fraction of n-grams unique _within_ each utterance |
| `jga` | Joint Goal Accuracy |
| `llen` | Average length of label in number of tokens |
| `loss` | Loss |
| `lr` | The most recent learning rate applied |
| `ltpb` | Label tokens per batch |
| `ltps` | Label tokens per second |
| `ltrunc` | Fraction of samples with some label truncation |
| `label_average_tokens_truncated` | Average length of label tokens truncated |
| `rouge-1`, `rouge-1`, `rouge-L` | ROUGE metrics |
| `token_acc` | Token-wise accuracy (generative only) |
| `token_em` | Utterance-level token accuracy. Roughly corresponds to perfection under greedy search (generative only) |
| `total_train_updates` | Number of SGD steps taken across all batches |
| `tpb` | Total tokens (context + label) per batch |
| `tps` | Total tokens (context + label) per second |
| `ups` | Updates per second (approximate) |
```{include} metric_list.inc
```
18 changes: 12 additions & 6 deletions parlai/core/logs.py
@@ -21,7 +21,7 @@
import numbers
import datetime
from parlai.core.opt import Opt
from parlai.core.metrics import Metric, dict_report
from parlai.core.metrics import Metric, dict_report, get_metric_display_data
from parlai.utils.io import PathManager
import parlai.utils.logging as logging

@@ -87,12 +87,18 @@ def log_metrics(self, setting, step, report):
The report to log
"""
for k, v in report.items():
if isinstance(v, numbers.Number):
self.writer.add_scalar(f'{k}/{setting}', v, global_step=step)
elif isinstance(v, Metric):
self.writer.add_scalar(f'{k}/{setting}', v.value(), global_step=step)
else:
v = v.value() if isinstance(v, Metric) else v
if not isinstance(v, numbers.Number):
logging.error(f'k {k} v {v} is not a number')
continue
display = get_metric_display_data(metric=k)
self.writer.add_scalar(
f'{k}/{setting}',
v,
global_step=step,
display_name=f"{display.title}",
summary_description=display.description,
)

def flush(self):
self.writer.flush()
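As a minimal sketch (not part of this diff) of how a report entry flows through the new path: Metric objects are unwrapped via .value(), non-numbers are skipped, and the display metadata is attached to the scalar. The commented-out writer call assumes a SummaryWriter in the tensorboardX style whose add_scalar accepts display_name and summary_description keyword arguments; plain torch.utils.tensorboard does not.

```python
from parlai.core.metrics import AverageMetric, Metric, get_metric_display_data

report = {'token_acc': AverageMetric(42, 50), 'ppl': 12.3}
for k, v in report.items():
    # unwrap Metric objects into plain numbers, as log_metrics now does
    v = v.value() if isinstance(v, Metric) else v
    display = get_metric_display_data(metric=k)
    print(k, v, display.title, display.description)
    # writer.add_scalar(
    #     f'{k}/valid', v, global_step=100,
    #     display_name=display.title,
    #     summary_description=display.description,
    # )
```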
114 changes: 113 additions & 1 deletion parlai/core/metrics.py
@@ -15,7 +15,17 @@
import functools
import datetime
import math
from typing import Union, List, Optional, Tuple, Set, Any, Dict, Counter as TCounter
from typing import (
Any,
Counter as TCounter,
Dict,
List,
NamedTuple,
Optional,
Set,
Tuple,
Union,
)

import torch

@@ -35,6 +45,108 @@
ALL_METRICS = DEFAULT_METRICS | ROUGE_METRICS | BLEU_METRICS | DISTINCT_METRICS


class MetricDisplayData(NamedTuple):
title: str
description: str


METRICS_DISPLAY_DATA = {
"accuracy": MetricDisplayData("Accuracy", "Exact match text accuracy"),
"bleu-4": MetricDisplayData(
"BLEU-4",
"BLEU-4 of the generation, under a standardized (model-independent) tokenizer",
),
"clen": MetricDisplayData(
"Context Length", "Average length of context in number of tokens"
),
"clip": MetricDisplayData(
"Clipped Gradients", "Fraction of batches with clipped gradients"
),
"ctpb": MetricDisplayData("Context Tokens Per Batch", "Context tokens per batch"),
"ctps": MetricDisplayData("Context Tokens Per Second", "Context tokens per second"),
"ctrunc": MetricDisplayData(
"Context Truncation", "Fraction of samples with some context truncation"
),
"ctrunclen": MetricDisplayData(
"Context Truncation Length", "Average length of context tokens truncated"
),
"exps": MetricDisplayData("Examples Per Second", "Examples per second"),
"exs": MetricDisplayData(
"Examples", "Number of examples processed since last print"
),
"f1": MetricDisplayData(
"F1", "Unigram F1 overlap, under a standardized (model-independent) tokenizer"
),
"gnorm": MetricDisplayData("Gradient Norm", "Gradient norm"),
"gpu_mem": MetricDisplayData(
"GPU Memory",
"Fraction of GPU memory used. May slightly underestimate true value.",
),
"hits@1": MetricDisplayData(
"Hits@1", "Fraction of correct choices in 1 guess. (Similar to recall@K)"
),
"hits@5": MetricDisplayData(
"Hits@5", "Fraction of correct choices in 5 guesses. (Similar to recall@K)"
),
"interdistinct-1": MetricDisplayData(
"Interdistinct-1", "Fraction of n-grams unique across _all_ generations"
),
"interdistinct-2": MetricDisplayData(
"Interdistinct-1", "Fraction of n-grams unique across _all_ generations"
),
"intradistinct-1": MetricDisplayData(
"Intradictinct-1", "Fraction of n-grams unique _within_ each utterance"
),
"intradictinct-2": MetricDisplayData(
"Intradictinct-2", "Fraction of n-grams unique _within_ each utterance"
),
"jga": MetricDisplayData("Joint Goal Accuracy", "Joint Goal Accuracy"),
"llen": MetricDisplayData(
"Label Length", "Average length of label in number of tokens"
),
"loss": MetricDisplayData("Loss", "Loss"),
"lr": MetricDisplayData("Learning Rate", "The most recent learning rate applied"),
"ltpb": MetricDisplayData("Label Tokens Per Batch", "Label tokens per batch"),
"ltps": MetricDisplayData("Label Tokens Per Second", "Label tokens per second"),
"ltrunc": MetricDisplayData(
"Label Truncation", "Fraction of samples with some label truncation"
),
"ltrunclen": MetricDisplayData(
"Label Truncation Length", "Average length of label tokens truncated"
),
"rouge-1": MetricDisplayData("ROUGE-1", "ROUGE metrics"),
"rouge-2": MetricDisplayData("ROUGE-2", "ROUGE metrics"),
"rouge-L": MetricDisplayData("ROUGE-L", "ROUGE metrics"),
"token_acc": MetricDisplayData(
"Token Accuracy", "Token-wise accuracy (generative only)"
),
"token_em": MetricDisplayData(
"Token Exact Match",
"Utterance-level token accuracy. Roughly corresponds to perfection under greedy search (generative only)",
),
"total_train_updates": MetricDisplayData(
"Total Train Updates", "Number of SGD steps taken across all batches"
),
"tpb": MetricDisplayData(
"Tokens Per Batch", "Total tokens (context + label) per batch"
),
"tps": MetricDisplayData(
"Tokens Per Second", "Total tokens (context + label) per second"
),
"ups": MetricDisplayData("Updates Per Second", "Updates per second (approximate)"),
}


def get_metric_display_data(metric: str) -> MetricDisplayData:
return METRICS_DISPLAY_DATA.get(
metric,
MetricDisplayData(
title=metric,
description="No description provided. Please add it to metrics.py if this is an official metric in ParlAI.",
),
)

Review thread on get_metric_display_data:

Contributor: maybe as a utility of MetricsDisplayData

Contributor (Author): It's kind of nice to keep this functional, though, since there isn't any state we should be keeping around. Also, it needs access to METRICS_DISPLAY_DATA which I think makes more sense scoped to the namespace than a class. What's the advantage you see from putting it in MetricDisplayData?

Contributor: I was thinking of a classmethod (and maybe the global too), just to keep everything in a tight namespace.

Contributor (@stephenroller, Mar 18, 2021): lol prolly the global can't be in there so long as it's self-typed...

anyway saul goodman

Contributor (Author): Yeah I think if we went that path, the metrics/titles/descriptions would live in a separate json/yaml file. And we'd have a separate function that loads them up as MetricDisplayDatas into a global. But then there'd be a disconnect between the source of truth and the global which is a little weird. I guess we could make MetricDisplayData a singleton and load them up on instantiation, but then we have to instantiate an object just to get this static list of strings which feels heavy.

Another way to keep them in a tight namespace would be to just create a metrics_list.py module.

Idk let me know if any of those options sound better, I see plenty of advantages and disadvantages to each so not super opinionated lol
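A quick, hypothetical usage check (not part of the diff) of the lookup and its fallback behavior:

```python
from parlai.core.metrics import get_metric_display_data

known = get_metric_display_data('f1')
assert known.title == 'F1'

# metrics without an entry fall back to using the raw key as the title
unknown = get_metric_display_data('my_custom_metric')
assert unknown.title == 'my_custom_metric'
assert unknown.description.startswith('No description provided')
```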


re_art = re.compile(r'\b(a|an|the)\b')
re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')

6 changes: 2 additions & 4 deletions parlai/core/torch_agent.py
@@ -2131,8 +2131,7 @@ def batch_act(self, observations):
)
if batch._context_truncated_length is not None:
self.record_local_metric(
'context_average_tokens_truncated',
AverageMetric.many(batch._context_truncated_length),
'ctrunclen', AverageMetric.many(batch._context_truncated_length)
)
if batch._label_original_length is not None:
self.record_local_metric(
@@ -2143,8 +2142,7 @@
)
if batch._label_truncated_length is not None:
self.record_local_metric(
'label_average_tokens_truncated',
AverageMetric.many(batch._label_truncated_length),
'ltrunclen', AverageMetric.many(batch._label_truncated_length)
)

self.global_metrics.add('exps', GlobalTimerMetric(batch.batchsize))
8 changes: 2 additions & 6 deletions tests/test_torch_agent.py
@@ -1048,9 +1048,5 @@ def test_truncate_metrics(self):
self.assertEqual(agent._local_metrics['ltrunc'][0].value(), 1.0)
self.assertEqual(agent._local_metrics['clen'][0].value(), 9)
self.assertEqual(agent._local_metrics['llen'][0].value(), 11)
self.assertEqual(
agent._local_metrics['context_average_tokens_truncated'][0].value(), 4
)
self.assertEqual(
agent._local_metrics['label_average_tokens_truncated'][0].value(), 6
)
self.assertEqual(agent._local_metrics['ctrunclen'][0].value(), 4)
self.assertEqual(agent._local_metrics['ltrunclen'][0].value(), 6)
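For context on the expected values above, a rough sketch of AverageMetric.many semantics as I understand them: each element of the input becomes a per-example AverageMetric with denominator 1, and adding metrics pools numerators and denominators, so the renamed ctrunclen/ltrunclen metrics report the average number of truncated tokens per example.

```python
from parlai.core.metrics import AverageMetric

per_example = AverageMetric.many([4, 0, 2])  # one metric per example
assert per_example[0].value() == 4.0         # matches the ctrunclen check above

# pooling across examples averages the raw counts: (4 + 0 + 2) / 3
pooled = per_example[0] + per_example[1] + per_example[2]
assert pooled.value() == 2.0
```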