Multi eval dataset logging #603

Merged (10 commits), Sep 27, 2023
45 changes: 29 additions & 16 deletions scripts/train/train.py
@@ -33,7 +33,16 @@ def validate_config(cfg: DictConfig):
"""Validates compatible model and dataloader selection."""
loaders = [cfg.train_loader]
if 'eval_loader' in cfg:
loaders.append(cfg.eval_loader)
eval_loader = cfg.eval_loader
if isinstance(eval_loader, ListConfig):
for loader in eval_loader:
if loader.label is None:
raise ValueError(
'When specifying multiple evaluation datasets, each one must include the \
`label` attribute.')
loaders.append(loader)
else:
loaders.append(eval_loader)
for loader in loaders:
if loader.name == 'text':
if cfg.model.name in ['hf_prefix_lm', 'hf_t5']:
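
Aside (not part of the diff): a minimal sketch of the two `eval_loader` shapes this validation accepts, assuming OmegaConf is available and loader fields mirror those used elsewhere in llm-foundry configs ('name', 'dataset.local', 'dataset.split'). The paths, labels, and error message below are illustrative, not taken from this PR.

from omegaconf import OmegaConf as om

# Single eval dataset: a plain mapping, no `label` required.
single_eval = om.create({
    'name': 'text',
    'dataset': {'local': './my-copy-c4', 'split': 'val'},
})

# Multiple eval datasets: a list, and every entry must carry a `label`.
multi_eval = om.create([
    {'name': 'text', 'label': 'c4',
     'dataset': {'local': './my-copy-c4', 'split': 'val'}},
    {'name': 'text', 'label': 'arxiv',
     'dataset': {'local': './my-copy-arxiv', 'split': 'train'}},
])

# Mirrors the check above (message paraphrased): reject entries without a label.
for loader in multi_eval:
    if loader.get('label') is None:
        raise ValueError('When specifying multiple evaluation datasets, '
                         'each one must include the `label` attribute.')
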
@@ -245,10 +254,8 @@ def main(cfg: DictConfig) -> Trainer:
                                                       must_exist=False,
                                                       default_value=None,
                                                       convert=True)
-    eval_loader_config: Optional[DictConfig] = pop_config(cfg,
-                                                          'eval_loader',
-                                                          must_exist=False,
-                                                          default_value=None)
+    eval_loader_config: Optional[Union[DictConfig, ListConfig]] = pop_config(
+        cfg, 'eval_loader', must_exist=False, default_value=None)
    icl_tasks_config: Optional[Union[ListConfig,
                                     str]] = pop_config(cfg,
                                                        'icl_tasks',
@@ -466,15 +473,21 @@ def main(cfg: DictConfig) -> Trainer:
    ## Evaluation
    print('Building eval loader...')
    evaluators = []
-    eval_loader = None
+    eval_loaders = []
    if eval_loader_config is not None:
-        eval_dataloader = build_dataloader(eval_loader_config, tokenizer,
-                                           device_eval_batch_size)
-        eval_loader = Evaluator(
-            label='eval',
-            dataloader=eval_dataloader,
-            metric_names=[],  # we will add these after model is created
-        )
+        is_multi_eval = isinstance(eval_loader_config, ListConfig)
+        eval_configs = eval_loader_config if is_multi_eval else [
+            eval_loader_config
+        ]
+        for eval_config in eval_configs:
+            eval_dataloader = build_dataloader(eval_config, tokenizer,
+                                               device_eval_batch_size)
+            eval_loader = Evaluator(
+                label=f'eval/{eval_config.label}' if is_multi_eval else 'eval',
+                dataloader=eval_dataloader,
+                metric_names=[],  # we will add these after model is created
+            )
+            eval_loaders.append(eval_loader)

    eval_gauntlet_callback = None

@@ -514,11 +527,11 @@ def main(cfg: DictConfig) -> Trainer:

    # Now add the eval metrics
    if eval_loader_config is not None:
-        assert eval_loader is not None
        assert model.train_metrics is not None
        eval_metric_names = list(model.train_metrics.keys())
-        eval_loader.metric_names = eval_metric_names
-        evaluators.insert(0, eval_loader)  # Put the base eval_loader first
+        for eval_loader in eval_loaders:
+            eval_loader.metric_names = eval_metric_names
+            evaluators.insert(0, eval_loader)  # Put the base eval_loaders first

    # Build the Trainer
    print('Building trainer...')
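
Aside (not part of the diff): with the `eval/{label}` evaluator names built above, each eval dataset's metrics end up under their own key prefix in the logger, which is what `test_train_multi_eval` in tests/test_training.py below asserts. A small sketch of the expected key layout; the labels and metric name are the ones used in that test.

# Evaluator labels produced by the loop above for two labeled eval datasets.
labels = ['eval/c4', 'eval/arxiv']
metric_name = 'LanguageCrossEntropy'

# Metric keys as they show up in the InMemoryLogger, per the assertions in
# test_train_multi_eval below.
expected_keys = [f'metrics/{label}/{metric_name}' for label in labels]
assert expected_keys == [
    'metrics/eval/c4/LanguageCrossEntropy',
    'metrics/eval/arxiv/LanguageCrossEntropy',
]
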
4 changes: 2 additions & 2 deletions tests/test_data_prep_scripts.py
@@ -37,13 +37,13 @@ def test_download_script_from_api():

def test_json_script_from_api():
    # test calling it directly
-    path = os.path.join(os.getcwd(), 'my-copy-c4-3')
+    path = os.path.join(os.getcwd(), 'my-copy-arxiv-1')
    shutil.rmtree(path, ignore_errors=True)
    main_json(
        Namespace(
            **{
                'path': 'scripts/data_prep/example_data/arxiv.jsonl',
-                'out_root': './my-copy-c4-3',
+                'out_root': './my-copy-arxiv-1',
                'compression': None,
                'split': 'train',
                'concat_tokens': None,
20 changes: 20 additions & 0 deletions tests/test_train_inputs.py
@@ -158,3 +158,23 @@ def test_invalid_name_in_scheduler_cfg_errors(self,
            main(cfg)
        assert str(exception_info.value
                  ) == 'Not sure how to build scheduler: invalid-scheduler'
+
+    def test_no_label_multiple_eval_datasets(self, cfg: DictConfig) -> None:
+        data_local = './my-copy-c4-multi-eval'
+        make_fake_index_file(f'{data_local}/train/index.json')
+        make_fake_index_file(f'{data_local}/val/index.json')
+        cfg.train_loader.dataset.local = data_local
+        # Set up multiple eval datasets
+        first_eval_loader = cfg.eval_loader
+        first_eval_loader.dataset.local = data_local
+        second_eval_loader = copy.deepcopy(first_eval_loader)
+        # Set the first eval dataloader to have no label
+        first_eval_loader.label = None
+        second_eval_loader.label = 'eval_1'
+        cfg.eval_loader = om.create([first_eval_loader, second_eval_loader])
+        with pytest.raises(ValueError) as exception_info:
+            main(cfg)
+        assert str(
+            exception_info.value
+        ) == 'When specifying multiple evaluation datasets, each one must include the \
+            `label` attribute.'
84 changes: 79 additions & 5 deletions tests/test_training.py
@@ -1,6 +1,8 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
+import copy
import os
+import pathlib
import shutil
import sys
from argparse import Namespace
@@ -16,13 +18,14 @@
sys.path.append(repo_dir)

from scripts.data_prep.convert_dataset_hf import main as main_hf # noqa: E402
+from scripts.data_prep.convert_dataset_json import \
+    main as main_json  # noqa: E402
from scripts.train.train import main # noqa: E402


-def create_c4_dataset_xsmall(prefix: str) -> str:
+def create_c4_dataset_xsmall(path: pathlib.Path) -> str:
    """Creates a small mocked version of the C4 dataset."""
-    c4_dir = os.path.join(os.getcwd(), f'my-copy-c4-{prefix}')
-    shutil.rmtree(c4_dir, ignore_errors=True)
+    c4_dir = os.path.join(path, f'my-copy-c4')
    downloaded_split = 'val_xsmall'  # very fast to convert

    # Hyperparameters from https://github.com/mosaicml/llm-foundry/blob/340a56658560ebceb2a3aa69d6e37813e415acd0/README.md#L188
@@ -52,6 +55,28 @@ def create_c4_dataset_xsmall(prefix: str) -> str:
    return c4_dir


+def create_arxiv_dataset(path: pathlib.Path) -> str:
+    """Creates an arxiv dataset."""
+    arxiv_dir = os.path.join(path, f'my-copy-arxiv')
+    downloaded_split = 'train'
+
+    main_json(
+        Namespace(
+            **{
+                'path': 'data_prep/example_data/arxiv.jsonl',
+                'out_root': arxiv_dir,
+                'compression': None,
+                'split': downloaded_split,
+                'concat_tokens': None,
+                'bos_text': None,
+                'eos_text': None,
+                'no_wrap': False,
+                'num_workers': None
+            }))
+
+    return arxiv_dir
+
+
def gpt_tiny_cfg(dataset_name: str, device: str):
    """Create gpt tiny cfg."""
    conf_path: str = os.path.join(repo_dir,
@@ -89,9 +114,9 @@ def set_correct_cwd():
        os.chdir('..')


-def test_train_gauntlet(set_correct_cwd: Any):
+def test_train_gauntlet(set_correct_cwd: Any, tmp_path: pathlib.Path):
    """Test training run with a small dataset."""
-    dataset_name = create_c4_dataset_xsmall('cpu-gauntlet')
+    dataset_name = create_c4_dataset_xsmall(tmp_path)
    test_cfg = gpt_tiny_cfg(dataset_name, 'cpu')
    test_cfg.icl_tasks = ListConfig([
        DictConfig({
@@ -150,3 +175,52 @@ def test_train_gauntlet(set_correct_cwd: Any):
        inmemorylogger.data['icl/metrics/eval_gauntlet/average'][-1], tuple)

    assert inmemorylogger.data['icl/metrics/eval_gauntlet/average'][-1][-1] == 0
+
+
+def test_train_multi_eval(set_correct_cwd: Any, tmp_path: pathlib.Path):
+    """Test training run with multiple eval datasets."""
+    c4_dataset_name = create_c4_dataset_xsmall(tmp_path)
+    test_cfg = gpt_tiny_cfg(c4_dataset_name, 'cpu')
+    # Set up multiple eval dataloaders
+    first_eval_loader = test_cfg.eval_loader
+    first_eval_loader.label = 'c4'
+    # Create second eval dataloader using the arxiv dataset.
+    second_eval_loader = copy.deepcopy(first_eval_loader)
+    arxiv_dataset_name = create_arxiv_dataset(tmp_path)
+    second_eval_loader.data_local = arxiv_dataset_name
+    second_eval_loader.label = 'arxiv'
+    test_cfg.eval_loader = om.create([first_eval_loader, second_eval_loader])
+    test_cfg.eval_subset_num_batches = 1  # -1 to evaluate on all batches
+
+    test_cfg.max_duration = '1ba'
+    test_cfg.eval_interval = '1ba'
+    test_cfg.loggers = DictConfig({'inmemory': DictConfig({})})
+    trainer = main(test_cfg)
+
+    assert isinstance(trainer.logger.destinations, tuple)
+
+    assert len(trainer.logger.destinations) > 0
+    inmemorylogger = trainer.logger.destinations[
+        0]  # pyright: ignore [reportGeneralTypeIssues]
+    assert isinstance(inmemorylogger, InMemoryLogger)
+    print(inmemorylogger.data.keys())
+
+    # Checks for first eval dataloader
+    assert 'metrics/eval/c4/LanguageCrossEntropy' in inmemorylogger.data.keys()
+    assert isinstance(
+        inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'], list)
+    assert len(
+        inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1]) > 0
+    assert isinstance(
+        inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], tuple)
+
+    # Checks for second eval dataloader
+    assert 'metrics/eval/arxiv/LanguageCrossEntropy' in inmemorylogger.data.keys(
+    )
+    assert isinstance(
+        inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'], list)
+    assert len(
+        inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1]) > 0
+    assert isinstance(
+        inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1],
+        tuple)