Skip to content

Commit

Permalink
Fix evaluators actually pulling eval metrics (#1006)
Browse files Browse the repository at this point in the history
* fix bug on metrics

* lint

* lint

* add unit test

* lint
  • Loading branch information
mvpatel2000 authored Mar 5, 2024
1 parent cbdddf0 commit 09ff550
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 4 deletions.
10 changes: 7 additions & 3 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from composer.loggers import MosaicMLLogger
from composer.loggers.mosaicml_logger import (MOSAICML_ACCESS_TOKEN_ENV_VAR,
MOSAICML_PLATFORM_ENV_VAR)
from composer.metrics.nlp import InContextLearningMetric
from composer.profiler import (JSONTraceHandler, Profiler, TraceHandler,
cyclic_schedule)
from composer.utils import dist, get_device, reproducibility
Expand Down Expand Up @@ -538,9 +539,12 @@ def main(cfg: DictConfig) -> Trainer:

# Now add the eval metrics
if eval_loader_config is not None and not use_async_eval:
train_metrics = model.get_metrics(is_train=True)
evaluators = add_metrics_to_eval_loaders(evaluators,
list(train_metrics.keys()))
eval_metrics = model.get_metrics(is_train=False)
non_icl_metrics = [
metric_name for metric_name, metric in eval_metrics.items()
if not isinstance(metric, InContextLearningMetric)
]
evaluators = add_metrics_to_eval_loaders(evaluators, non_icl_metrics)

# Build the Trainer
log.info('Building trainer...')
Expand Down
29 changes: 28 additions & 1 deletion tests/a_scripts/train/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ def test_train_multi_eval(tmp_path: pathlib.Path):
inmemorylogger = trainer.logger.destinations[
0] # pyright: ignore [reportGeneralTypeIssues]
assert isinstance(inmemorylogger, InMemoryLogger)
print(inmemorylogger.data.keys())

# Checks for first eval dataloader
assert 'metrics/eval/c4/LanguageCrossEntropy' in inmemorylogger.data.keys()
Expand All @@ -143,3 +142,31 @@ def test_train_multi_eval(tmp_path: pathlib.Path):
assert isinstance(
inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1],
tuple)


def test_eval_metrics_with_no_train_metrics(tmp_path: pathlib.Path):
    """Eval metrics must still be logged when ``use_train_metrics`` is off.

    Builds a tiny GPT config with a single eval loader labelled ``c4``,
    sets ``model.use_train_metrics`` to ``False``, runs ``main`` for one
    batch, and checks that the in-memory logger holds a non-empty
    ``LanguageCrossEntropy`` entry for that eval loader.

    Args:
        tmp_path: Temporary directory passed to the dataset-creation helper.
    """
    dataset_name = create_c4_dataset_xxsmall(tmp_path)
    test_cfg = gpt_tiny_cfg(dataset_name, 'cpu')

    # Re-wrap the single eval loader as a one-element list with a label.
    c4_loader = test_cfg.eval_loader
    c4_loader.label = 'c4'
    test_cfg.eval_loader = om.create([c4_loader])

    # One training batch with an eval after it.
    test_cfg.eval_subset_num_batches = 1  # -1 to evaluate on all batches
    test_cfg.max_duration = '1ba'
    test_cfg.eval_interval = '1ba'
    test_cfg.loggers = DictConfig({'inmemory': DictConfig({})})

    # The setting under test: train metrics disabled on the model.
    test_cfg.model['use_train_metrics'] = False
    trainer = main(test_cfg)

    # The in-memory logger is the only logger configured above.
    inmemorylogger = trainer.logger.destinations[
        0]  # pyright: ignore [reportGeneralTypeIssues]
    assert isinstance(inmemorylogger, InMemoryLogger)

    # Check eval metrics exist
    metric_key = 'metrics/eval/c4/LanguageCrossEntropy'
    assert metric_key in inmemorylogger.data.keys()

    logged_series = inmemorylogger.data[metric_key]
    assert isinstance(logged_series, list)

    # The most recent entry must be a non-empty tuple.
    assert len(logged_series[-1]) > 0
    assert isinstance(logged_series[-1], tuple)

0 comments on commit 09ff550

Please sign in to comment.