diff --git a/llmfoundry/callbacks/curriculum_learning_callback.py b/llmfoundry/callbacks/curriculum_learning_callback.py index 70e996e494..7478496666 100644 --- a/llmfoundry/callbacks/curriculum_learning_callback.py +++ b/llmfoundry/callbacks/curriculum_learning_callback.py @@ -1,16 +1,16 @@ -# Copyright 2022 MosaicML LLM Foundry authors +# Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -"""Enable curriculum learning by resuming with a different dataset. +"""Enable curriculum learning by specifying a schedule of datasets to train on. -This callback is currently experimental. The API may change without warning in -the future. +This module provides a CurriculumLearning callback that allows for dynamic +dataset switching during training based on a predefined schedule. """ import copy import logging -import warnings -from typing import Any, Optional, Union +from dataclasses import dataclass +from typing import Any, Union from composer import DataSpec from composer.core import State, Time, TimeUnit, ensure_time @@ -24,13 +24,18 @@ BaseContextualError, TrainDataLoaderLocation, ) -from llmfoundry.utils.warnings import VersionedDeprecationWarning log = logging.getLogger(__name__) __all__ = ['CurriculumLearning'] +@dataclass +class CurriculumLearningState: + schedule: list[dict[str, Any]] + schedule_index: int + + class CurriculumLearning(CallbackWithConfig): """Starts an epoch with a different dataset when resuming from a checkpoint. @@ -57,8 +62,8 @@ class CurriculumLearning(CallbackWithConfig): being used. Note that this is the full train config and must contain the 'train_loader', 'device_train_batch_size', and 'tokenizer' keys. - duration (Union[Time, str, int], optional): The duration of the first datamix - (which corresponds to the train_loader). Defaults to None. + duration (Union[Time, str, int]): The duration of the first datamix + (which corresponds to the train_loader). schedule (list[dict[str, Any]]): The list of datamixes to use and their durations. Duration units must match max_duration and be in terms of a TimeUnit that is supported by Iteration. The duration values must @@ -73,29 +78,17 @@ def __init__( self, train_config: dict[str, Any], schedule: list[dict[str, Any]], - duration: Optional[Union[Time, str, int]] = None, + duration: Union[Time, str, int], ): - if duration is None: - warnings.warn( - VersionedDeprecationWarning( - 'Specifying the full schedule in the CurriculumLearning ' + - 'callback is deprecated. Please specify the duration of ' + - 'the first datamix separately and change the schedule ' + - 'use datasets instead of dataloaders.', - remove_version='0.15.0', - ), - ) - # Ensure all duration units are in epochs or tokens and values are positive self._schedule = schedule if len(self._schedule) == 0: raise ValueError('The schedule must have at least one datamix.') - if duration is not None: - first_datamix = { - 'duration': duration, - 'dataset': train_config['train_loader']['dataset'], - } - self._schedule.insert(0, first_datamix) + first_datamix = { + 'duration': duration, + 'dataset': train_config['train_loader']['dataset'], + } + self._schedule.insert(0, first_datamix) for datamix in self._schedule: self._validate_datamix(datamix) @@ -167,10 +160,7 @@ def iteration_start(self, state: State, logger: Logger): clean_stale_shared_memory() datamix = copy.deepcopy(self._schedule[self._schedule_index]) train_loader_config = copy.deepcopy(self._train_loader_config) - if 'dataset' in datamix: - train_loader_config['dataset'].update(datamix['dataset']) - else: - train_loader_config = datamix['train_loader'] + train_loader_config['dataset'].update(datamix['dataset']) data_spec = self._build_train_loader( train_loader_config=train_loader_config, logger=logger, @@ -193,25 +183,29 @@ def iteration_end(self, state: State, logger: Logger): def state_dict(self): return { - 'schedule': self._schedule, - 'schedule_index': self._schedule_index, + 'state': + CurriculumLearningState( + schedule=self._schedule, + schedule_index=self._schedule_index, + ), } def load_state_dict(self, state: dict[str, Any]): - self._schedule_index = state['schedule_index'] + schedule = state['state'].schedule + self._schedule_index = state['state'].schedule_index # Ensure that the schedule has not changed on previously trained datamixes - for idx in range(state['schedule_index']): - if self._schedule[idx] != state['schedule'][idx]: + for idx in range(self._schedule_index): + if self._schedule[idx] != schedule[idx]: raise ValueError(( f'Previous datamixes must stay the same across ', - f'resumptions. Expected {state["schedule"][idx]} but got ', + f'resumptions. Expected {schedule[idx]} but got ', f'{self._schedule[idx]}', )) # Ensure that the datamix has not changed on the current datamix current_loader = self._schedule[self._schedule_index]['train_loader'] - saved_loader = state['schedule'][self._schedule_index]['train_loader'] + saved_loader = schedule[self._schedule_index]['train_loader'] if current_loader != saved_loader: raise ValueError(( f'The current datamix must stay the same across resumptions. ', @@ -282,5 +276,5 @@ def _validate_datamix(self, datamix: dict[str, Any]): 'Schedules can only be defined in terms of epochs or tokens.', ) - if 'train_loader' not in datamix and 'dataset' not in datamix: + if 'dataset' not in datamix: raise ValueError('Each datamix must have a dataset.') diff --git a/tests/callbacks/test_curriculum_learning_callback.py b/tests/callbacks/test_curriculum_learning_callback.py index 0e6a6c1efe..618aab456b 100644 --- a/tests/callbacks/test_curriculum_learning_callback.py +++ b/tests/callbacks/test_curriculum_learning_callback.py @@ -13,6 +13,8 @@ from omegaconf import OmegaConf as om from torch.utils.data import DataLoader +from llmfoundry.callbacks.curriculum_learning_callback import \ + CurriculumLearningState from llmfoundry.data.text_data import StreamingTextDataset from llmfoundry.utils.builders import build_callback @@ -237,8 +239,11 @@ def test_curriculum_learning_callback_state_dict(build_tiny_mpt: Callable,): callback.iteration_start(state, logger) callback.iteration_end(state, logger) assert callback.state_dict() == { - 'schedule': kwargs['schedule'], - 'schedule_index': 1, + 'state': + CurriculumLearningState( + schedule=kwargs['schedule'], + schedule_index=1, + ), } @@ -280,8 +285,11 @@ def test_curriculum_learning_callback_load_state_dict( callback.iteration_start(state, logger) callback.iteration_end(state, logger) assert callback.state_dict() == { - 'schedule': kwargs['schedule'], - 'schedule_index': 1, + 'state': + CurriculumLearningState( + schedule=kwargs['schedule'], + schedule_index=1, + ), }