diff --git a/scripts/train/train.py b/scripts/train/train.py index 87217702e5..b31c15467e 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -33,7 +33,16 @@ def validate_config(cfg: DictConfig): """Validates compatible model and dataloader selection.""" loaders = [cfg.train_loader] if 'eval_loader' in cfg: - loaders.append(cfg.eval_loader) + eval_loader = cfg.eval_loader + if isinstance(eval_loader, ListConfig): + for loader in eval_loader: + if loader.label is None: + raise ValueError( + 'When specifying multiple evaluation datasets, each one must include the \ + `label` attribute.') + loaders.append(loader) + else: + loaders.append(eval_loader) for loader in loaders: if loader.name == 'text': if cfg.model.name in ['hf_prefix_lm', 'hf_t5']: @@ -245,10 +254,8 @@ def main(cfg: DictConfig) -> Trainer: must_exist=False, default_value=None, convert=True) - eval_loader_config: Optional[DictConfig] = pop_config(cfg, - 'eval_loader', - must_exist=False, - default_value=None) + eval_loader_config: Optional[Union[DictConfig, ListConfig]] = pop_config( + cfg, 'eval_loader', must_exist=False, default_value=None) icl_tasks_config: Optional[Union[ListConfig, str]] = pop_config(cfg, 'icl_tasks', @@ -466,15 +473,21 @@ def main(cfg: DictConfig) -> Trainer: ## Evaluation print('Building eval loader...') evaluators = [] - eval_loader = None + eval_loaders = [] if eval_loader_config is not None: - eval_dataloader = build_dataloader(eval_loader_config, tokenizer, - device_eval_batch_size) - eval_loader = Evaluator( - label='eval', - dataloader=eval_dataloader, - metric_names=[], # we will add these after model is created - ) + is_multi_eval = isinstance(eval_loader_config, ListConfig) + eval_configs = eval_loader_config if is_multi_eval else [ + eval_loader_config + ] + for eval_config in eval_configs: + eval_dataloader = build_dataloader(eval_config, tokenizer, + device_eval_batch_size) + eval_loader = Evaluator( + label=f'eval/{eval_config.label}' if is_multi_eval else 'eval', + dataloader=eval_dataloader, + metric_names=[], # we will add these after model is created + ) + eval_loaders.append(eval_loader) eval_gauntlet_callback = None @@ -514,11 +527,11 @@ def main(cfg: DictConfig) -> Trainer: # Now add the eval metrics if eval_loader_config is not None: - assert eval_loader is not None assert model.train_metrics is not None eval_metric_names = list(model.train_metrics.keys()) - eval_loader.metric_names = eval_metric_names - evaluators.insert(0, eval_loader) # Put the base eval_loader first + for eval_loader in eval_loaders: + eval_loader.metric_names = eval_metric_names + evaluators.insert(0, eval_loader) # Put the base eval_loaders first # Build the Trainer print('Building trainer...') diff --git a/tests/test_data_prep_scripts.py b/tests/test_data_prep_scripts.py index 52ab42806f..4c555ea9a2 100644 --- a/tests/test_data_prep_scripts.py +++ b/tests/test_data_prep_scripts.py @@ -37,13 +37,13 @@ def test_download_script_from_api(): def test_json_script_from_api(): # test calling it directly - path = os.path.join(os.getcwd(), 'my-copy-c4-3') + path = os.path.join(os.getcwd(), 'my-copy-arxiv-1') shutil.rmtree(path, ignore_errors=True) main_json( Namespace( **{ 'path': 'scripts/data_prep/example_data/arxiv.jsonl', - 'out_root': './my-copy-c4-3', + 'out_root': './my-copy-arxiv-1', 'compression': None, 'split': 'train', 'concat_tokens': None, diff --git a/tests/test_train_inputs.py b/tests/test_train_inputs.py index 2f29f6e7b5..bf90f48ef0 100644 --- a/tests/test_train_inputs.py +++ b/tests/test_train_inputs.py @@ -158,3 +158,23 @@ def test_invalid_name_in_scheduler_cfg_errors(self, main(cfg) assert str(exception_info.value ) == 'Not sure how to build scheduler: invalid-scheduler' + + def test_no_label_multiple_eval_datasets(self, cfg: DictConfig) -> None: + data_local = './my-copy-c4-multi-eval' + make_fake_index_file(f'{data_local}/train/index.json') + make_fake_index_file(f'{data_local}/val/index.json') + cfg.train_loader.dataset.local = data_local + # Set up multiple eval datasets + first_eval_loader = cfg.eval_loader + first_eval_loader.dataset.local = data_local + second_eval_loader = copy.deepcopy(first_eval_loader) + # Set the first eval dataloader to have no label + first_eval_loader.label = None + second_eval_loader.label = 'eval_1' + cfg.eval_loader = om.create([first_eval_loader, second_eval_loader]) + with pytest.raises(ValueError) as exception_info: + main(cfg) + assert str( + exception_info.value + ) == 'When specifying multiple evaluation datasets, each one must include the \ + `label` attribute.' diff --git a/tests/test_training.py b/tests/test_training.py index e03703c859..9d40fc2a78 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -1,6 +1,8 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import copy import os +import pathlib import shutil import sys from argparse import Namespace @@ -16,13 +18,14 @@ sys.path.append(repo_dir) from scripts.data_prep.convert_dataset_hf import main as main_hf # noqa: E402 +from scripts.data_prep.convert_dataset_json import \ + main as main_json # noqa: E402 from scripts.train.train import main # noqa: E402 -def create_c4_dataset_xsmall(prefix: str) -> str: +def create_c4_dataset_xsmall(path: pathlib.Path) -> str: """Creates a small mocked version of the C4 dataset.""" - c4_dir = os.path.join(os.getcwd(), f'my-copy-c4-{prefix}') - shutil.rmtree(c4_dir, ignore_errors=True) + c4_dir = os.path.join(path, f'my-copy-c4') downloaded_split = 'val_xsmall' # very fast to convert # Hyperparameters from https://github.com/mosaicml/llm-foundry/blob/340a56658560ebceb2a3aa69d6e37813e415acd0/README.md#L188 @@ -52,6 +55,28 @@ def create_c4_dataset_xsmall(prefix: str) -> str: return c4_dir +def create_arxiv_dataset(path: pathlib.Path) -> str: + """Creates an arxiv dataset.""" + arxiv_dir = os.path.join(path, f'my-copy-arxiv') + downloaded_split = 'train' + + main_json( + Namespace( + **{ + 'path': 'data_prep/example_data/arxiv.jsonl', + 'out_root': arxiv_dir, + 'compression': None, + 'split': downloaded_split, + 'concat_tokens': None, + 'bos_text': None, + 'eos_text': None, + 'no_wrap': False, + 'num_workers': None + })) + + return arxiv_dir + + def gpt_tiny_cfg(dataset_name: str, device: str): """Create gpt tiny cfg.""" conf_path: str = os.path.join(repo_dir, @@ -89,9 +114,9 @@ def set_correct_cwd(): os.chdir('..') -def test_train_gauntlet(set_correct_cwd: Any): +def test_train_gauntlet(set_correct_cwd: Any, tmp_path: pathlib.Path): """Test training run with a small dataset.""" - dataset_name = create_c4_dataset_xsmall('cpu-gauntlet') + dataset_name = create_c4_dataset_xsmall(tmp_path) test_cfg = gpt_tiny_cfg(dataset_name, 'cpu') test_cfg.icl_tasks = ListConfig([ DictConfig({ @@ -150,3 +175,52 @@ def test_train_gauntlet(set_correct_cwd: Any): inmemorylogger.data['icl/metrics/eval_gauntlet/average'][-1], tuple) assert inmemorylogger.data['icl/metrics/eval_gauntlet/average'][-1][-1] == 0 + + +def test_train_multi_eval(set_correct_cwd: Any, tmp_path: pathlib.Path): + """Test training run with multiple eval datasets.""" + c4_dataset_name = create_c4_dataset_xsmall(tmp_path) + test_cfg = gpt_tiny_cfg(c4_dataset_name, 'cpu') + # Set up multiple eval dataloaders + first_eval_loader = test_cfg.eval_loader + first_eval_loader.label = 'c4' + # Create second eval dataloader using the arxiv dataset. + second_eval_loader = copy.deepcopy(first_eval_loader) + arxiv_dataset_name = create_arxiv_dataset(tmp_path) + second_eval_loader.data_local = arxiv_dataset_name + second_eval_loader.label = 'arxiv' + test_cfg.eval_loader = om.create([first_eval_loader, second_eval_loader]) + test_cfg.eval_subset_num_batches = 1 # -1 to evaluate on all batches + + test_cfg.max_duration = '1ba' + test_cfg.eval_interval = '1ba' + test_cfg.loggers = DictConfig({'inmemory': DictConfig({})}) + trainer = main(test_cfg) + + assert isinstance(trainer.logger.destinations, tuple) + + assert len(trainer.logger.destinations) > 0 + inmemorylogger = trainer.logger.destinations[ + 0] # pyright: ignore [reportGeneralTypeIssues] + assert isinstance(inmemorylogger, InMemoryLogger) + print(inmemorylogger.data.keys()) + + # Checks for first eval dataloader + assert 'metrics/eval/c4/LanguageCrossEntropy' in inmemorylogger.data.keys() + assert isinstance( + inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'], list) + assert len( + inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1]) > 0 + assert isinstance( + inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], tuple) + + # Checks for second eval dataloader + assert 'metrics/eval/arxiv/LanguageCrossEntropy' in inmemorylogger.data.keys( + ) + assert isinstance( + inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'], list) + assert len( + inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1]) > 0 + assert isinstance( + inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1], + tuple)