Fix pytorch checkpointing for CL callback (#1581)
b-chu authored Oct 10, 2024
1 parent 5b0a53e commit 1654827
Showing 2 changed files with 45 additions and 43 deletions.
72 changes: 33 additions & 39 deletions llmfoundry/callbacks/curriculum_learning_callback.py
@@ -1,16 +1,16 @@
-# Copyright 2022 MosaicML LLM Foundry authors
+# Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-"""Enable curriculum learning by resuming with a different dataset.
+"""Enable curriculum learning by specifying a schedule of datasets to train on.
 
-This callback is currently experimental. The API may change without warning in
-the future.
+This module provides a CurriculumLearning callback that allows for dynamic
+dataset switching during training based on a predefined schedule.
 """
 
 import copy
 import logging
-import warnings
-from typing import Any, Optional, Union
+from dataclasses import dataclass
+from typing import Any, Union
 
 from composer import DataSpec
 from composer.core import State, Time, TimeUnit, ensure_time
@@ -24,13 +24,18 @@
     BaseContextualError,
     TrainDataLoaderLocation,
 )
-from llmfoundry.utils.warnings import VersionedDeprecationWarning
 
 log = logging.getLogger(__name__)
 
 __all__ = ['CurriculumLearning']
 
 
+@dataclass
+class CurriculumLearningState:
+    schedule: list[dict[str, Any]]
+    schedule_index: int
+
+
 class CurriculumLearning(CallbackWithConfig):
     """Starts an epoch with a different dataset when resuming from a checkpoint.
@@ -57,8 +62,8 @@ class CurriculumLearning(CallbackWithConfig):
             being used. Note that this is the full train config and must
             contain the 'train_loader', 'device_train_batch_size', and
             'tokenizer' keys.
-        duration (Union[Time, str, int], optional): The duration of the first datamix
-            (which corresponds to the train_loader). Defaults to None.
+        duration (Union[Time, str, int]): The duration of the first datamix
+            (which corresponds to the train_loader).
         schedule (list[dict[str, Any]]): The list of datamixes to use and their
             durations. Duration units must match max_duration and be in terms of
             a TimeUnit that is supported by Iteration. The duration values must
@@ -73,29 +78,17 @@ def __init__(
         self,
         train_config: dict[str, Any],
         schedule: list[dict[str, Any]],
-        duration: Optional[Union[Time, str, int]] = None,
+        duration: Union[Time, str, int],
     ):
-        if duration is None:
-            warnings.warn(
-                VersionedDeprecationWarning(
-                    'Specifying the full schedule in the CurriculumLearning ' +
-                    'callback is deprecated. Please specify the duration of ' +
-                    'the first datamix separately and change the schedule ' +
-                    'use datasets instead of dataloaders.',
-                    remove_version='0.15.0',
-                ),
-            )
-
         # Ensure all duration units are in epochs or tokens and values are positive
         self._schedule = schedule
         if len(self._schedule) == 0:
             raise ValueError('The schedule must have at least one datamix.')
-        if duration is not None:
-            first_datamix = {
-                'duration': duration,
-                'dataset': train_config['train_loader']['dataset'],
-            }
-            self._schedule.insert(0, first_datamix)
+        first_datamix = {
+            'duration': duration,
+            'dataset': train_config['train_loader']['dataset'],
+        }
+        self._schedule.insert(0, first_datamix)
         for datamix in self._schedule:
             self._validate_datamix(datamix)
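
With duration now required, __init__ unconditionally synthesizes the first datamix from train_config['train_loader'] and prepends it to the schedule. A hypothetical instantiation under the new signature (every config value below is made up for illustration; in practice the callback is built from the train config via build_callback):

from llmfoundry.callbacks.curriculum_learning_callback import CurriculumLearning

train_config = {
    'train_loader': {
        'name': 'text',
        'dataset': {'remote': 's3://bucket/datamix-0', 'shuffle': True},
    },
    'device_train_batch_size': 8,
    'tokenizer': {'name': 'EleutherAI/gpt-neox-20b'},
}

callback = CurriculumLearning(
    train_config=train_config,
    duration='10000tok',  # duration of the first datamix (the train_loader itself)
    schedule=[
        {'duration': '20000tok', 'dataset': {'remote': 's3://bucket/datamix-1'}},
        {'duration': '30000tok', 'dataset': {'remote': 's3://bucket/datamix-2'}},
    ],
)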

@@ -167,10 +160,7 @@ def iteration_start(self, state: State, logger: Logger):
         clean_stale_shared_memory()
         datamix = copy.deepcopy(self._schedule[self._schedule_index])
         train_loader_config = copy.deepcopy(self._train_loader_config)
-        if 'dataset' in datamix:
-            train_loader_config['dataset'].update(datamix['dataset'])
-        else:
-            train_loader_config = datamix['train_loader']
+        train_loader_config['dataset'].update(datamix['dataset'])
         data_spec = self._build_train_loader(
             train_loader_config=train_loader_config,
             logger=logger,
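
The collapsed branch means a scheduled datamix is always a set of dataset overrides that are shallow-merged onto a deep copy of the base train_loader config; keys the datamix does not mention are inherited. In plain Python terms:

# Shallow merge of datamix dataset overrides onto the base dataset config.
base = {'remote': 's3://bucket/datamix-0', 'shuffle': True, 'max_seq_len': 2048}
override = {'remote': 's3://bucket/datamix-1'}
base.update(override)
# base == {'remote': 's3://bucket/datamix-1', 'shuffle': True, 'max_seq_len': 2048}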
@@ -193,25 +183,29 @@ def iteration_end(self, state: State, logger: Logger):
 
     def state_dict(self):
         return {
-            'schedule': self._schedule,
-            'schedule_index': self._schedule_index,
+            'state':
+                CurriculumLearningState(
+                    schedule=self._schedule,
+                    schedule_index=self._schedule_index,
+                ),
         }
 
     def load_state_dict(self, state: dict[str, Any]):
-        self._schedule_index = state['schedule_index']
+        schedule = state['state'].schedule
+        self._schedule_index = state['state'].schedule_index
 
         # Ensure that the schedule has not changed on previously trained datamixes
-        for idx in range(state['schedule_index']):
-            if self._schedule[idx] != state['schedule'][idx]:
+        for idx in range(self._schedule_index):
+            if self._schedule[idx] != schedule[idx]:
                 raise ValueError((
                     f'Previous datamixes must stay the same across ',
-                    f'resumptions. Expected {state["schedule"][idx]} but got ',
+                    f'resumptions. Expected {schedule[idx]} but got ',
                     f'{self._schedule[idx]}',
                 ))
 
         # Ensure that the datamix has not changed on the current datamix
         current_loader = self._schedule[self._schedule_index]['train_loader']
-        saved_loader = state['schedule'][self._schedule_index]['train_loader']
+        saved_loader = schedule[self._schedule_index]['train_loader']
         if current_loader != saved_loader:
             raise ValueError((
                 f'The current datamix must stay the same across resumptions. ',
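
load_state_dict now unpacks both fields from the CurriculumLearningState object before validating that previously trained datamixes are unchanged. A hedged sketch of that guard, continuing from the instantiation above (copy.deepcopy stands in for a real serialize/deserialize round trip, and private attributes are touched only for illustration):

import copy

# Assumes the checkpoint was taken after the first datamix finished,
# i.e. the saved schedule_index is at least 1.
saved = copy.deepcopy(callback.state_dict())

# Tamper with a datamix that was already trained on, then try to resume.
callback._schedule[0]['dataset'] = {'remote': 's3://bucket/changed'}
try:
    callback.load_state_dict(saved)
except ValueError:
    pass  # 'Previous datamixes must stay the same across resumptions. ...'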
@@ -282,5 +276,5 @@ def _validate_datamix(self, datamix: dict[str, Any]):
                 'Schedules can only be defined in terms of epochs or tokens.',
             )
 
-        if 'train_loader' not in datamix and 'dataset' not in datamix:
+        if 'dataset' not in datamix:
             raise ValueError('Each datamix must have a dataset.')
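
Net effect on the schedule schema: the deprecated 'train_loader' form of a datamix is rejected, and every entry must provide 'dataset' alongside 'duration'. A hypothetical before/after for a single schedule entry:

# Before this commit, a schedule entry could carry a full dataloader config:
datamix_old = {
    'duration': '1ep',
    'train_loader': {'name': 'text', 'dataset': {'remote': 's3://bucket/datamix-1'}},
}

# Now an entry carries only dataset overrides, merged onto the base
# train_loader config at iteration start:
datamix_new = {
    'duration': '1ep',
    'dataset': {'remote': 's3://bucket/datamix-1'},
}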
16 changes: 12 additions & 4 deletions tests/callbacks/test_curriculum_learning_callback.py
@@ -13,6 +13,8 @@
 from omegaconf import OmegaConf as om
 from torch.utils.data import DataLoader
 
+from llmfoundry.callbacks.curriculum_learning_callback import \
+    CurriculumLearningState
 from llmfoundry.data.text_data import StreamingTextDataset
 from llmfoundry.utils.builders import build_callback

@@ -237,8 +239,11 @@ def test_curriculum_learning_callback_state_dict(build_tiny_mpt: Callable,):
     callback.iteration_start(state, logger)
     callback.iteration_end(state, logger)
     assert callback.state_dict() == {
-        'schedule': kwargs['schedule'],
-        'schedule_index': 1,
+        'state':
+            CurriculumLearningState(
+                schedule=kwargs['schedule'],
+                schedule_index=1,
+            ),
     }


@@ -280,8 +285,11 @@ def test_curriculum_learning_callback_load_state_dict(
     callback.iteration_start(state, logger)
     callback.iteration_end(state, logger)
     assert callback.state_dict() == {
-        'schedule': kwargs['schedule'],
-        'schedule_index': 1,
+        'state':
+            CurriculumLearningState(
+                schedule=kwargs['schedule'],
+                schedule_index=1,
+            ),
     }


