Fix pytorch checkpointing for CL callback
b-chu committed Oct 10, 2024
1 parent 813d50e commit 7f8e3e7
Showing 1 changed file with 30 additions and 42 deletions.
72 changes: 30 additions & 42 deletions llmfoundry/callbacks/curriculum_learning_callback.py
@@ -1,16 +1,13 @@
# Copyright 2022 MosaicML LLM Foundry authors
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Enable curriculum learning by resuming with a different dataset.
This callback is currently experimental. The API may change without warning in
the future.
"""Enable curriculum learning by specifying a schedule of datasets to train on.
"""

import copy
from dataclasses import dataclass
import logging
import warnings
from typing import Any, Optional, Union
from typing import Any, Union

from composer import DataSpec
from composer.core import State, Time, TimeUnit, ensure_time
@@ -24,13 +21,18 @@
BaseContextualError,
TrainDataLoaderLocation,
)
from llmfoundry.utils.warnings import VersionedDeprecationWarning

log = logging.getLogger(__name__)

__all__ = ['CurriculumLearning']


@dataclass
class CurriculumLearningState:
schedule: list[dict[str, Any]]
schedule_index: int


class CurriculumLearning(CallbackWithConfig):
"""Starts an epoch with a different dataset when resuming from a checkpoint.
@@ -57,8 +59,8 @@ class CurriculumLearning(CallbackWithConfig):
being used. Note that this is the full train config and must
contain the 'train_loader', 'device_train_batch_size', and
'tokenizer' keys.
duration (Union[Time, str, int], optional): The duration of the first datamix
(which corresponds to the train_loader). Defaults to None.
duration (Union[Time, str, int]): The duration of the first datamix
(which corresponds to the train_loader).
schedule (list[dict[str, Any]]): The list of datamixes to use and their
durations. Duration units must match max_duration and be in terms of
a TimeUnit that is supported by Iteration. The duration values must
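
The rest of this docstring is collapsed in the diff, but based on the visible Args a schedule is simply a list of datamix dicts. A hedged sketch with hypothetical dataset fields (none of these values appear in the commit):

    # Hypothetical schedule: every datamix needs a positive 'duration' in epochs
    # or tokens (matching max_duration) and a 'dataset' config fragment.
    schedule = [
        {
            'duration': '20000tok',
            'dataset': {'remote': 's3://my-bucket/phase-2/', 'split': 'train'},
        },
        {
            'duration': '1ep',
            'dataset': {'remote': 's3://my-bucket/phase-3/', 'split': 'train'},
        },
    ]
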
@@ -73,29 +75,17 @@ def __init__(
self,
train_config: dict[str, Any],
schedule: list[dict[str, Any]],
duration: Optional[Union[Time, str, int]] = None,
duration: Union[Time, str, int],
):
if duration is None:
warnings.warn(
VersionedDeprecationWarning(
'Specifying the full schedule in the CurriculumLearning ' +
'callback is deprecated. Please specify the duration of ' +
'the first datamix separately and change the schedule ' +
'use datasets instead of dataloaders.',
remove_version='0.15.0',
),
)

# Ensure all duration units are in epochs or tokens and values are positive
self._schedule = schedule
if len(self._schedule) == 0:
raise ValueError('The schedule must have at least one datamix.')
if duration is not None:
first_datamix = {
'duration': duration,
'dataset': train_config['train_loader']['dataset'],
}
self._schedule.insert(0, first_datamix)
first_datamix = {
'duration': duration,
'dataset': train_config['train_loader']['dataset'],
}
self._schedule.insert(0, first_datamix)
for datamix in self._schedule:
self._validate_datamix(datamix)
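
The constructor change above makes duration required and always folds the train_loader's dataset in as the first datamix. A standalone replay of that logic with hypothetical values (not taken from this commit):

    # The train_loader's dataset becomes datamix 0 with the given duration.
    train_config = {
        'train_loader': {
            'dataset': {'remote': 's3://my-bucket/phase-1/', 'split': 'train'},
        },
    }
    duration = '10000tok'
    schedule = [
        {'duration': '20000tok', 'dataset': {'remote': 's3://my-bucket/phase-2/', 'split': 'train'}},
    ]

    first_datamix = {
        'duration': duration,
        'dataset': train_config['train_loader']['dataset'],
    }
    schedule.insert(0, first_datamix)
    # schedule[0] now points at the phase-1 dataset for the first 10000 tokens.
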

@@ -167,10 +157,7 @@ def iteration_start(self, state: State, logger: Logger):
clean_stale_shared_memory()
datamix = copy.deepcopy(self._schedule[self._schedule_index])
train_loader_config = copy.deepcopy(self._train_loader_config)
if 'dataset' in datamix:
train_loader_config['dataset'].update(datamix['dataset'])
else:
train_loader_config = datamix['train_loader']
train_loader_config['dataset'].update(datamix['dataset'])
data_spec = self._build_train_loader(
train_loader_config=train_loader_config,
logger=logger,
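
With the branch removed, every datamix is a partial 'dataset' config that is shallow-merged onto a copy of the base train_loader dataset config. A small illustration of that merge with made-up keys:

    # dict.update is a shallow merge: datamix keys override base keys,
    # anything the datamix does not mention is kept from the base config.
    base_dataset = {'remote': 's3://my-bucket/phase-1/', 'split': 'train', 'shuffle': True}
    datamix_dataset = {'remote': 's3://my-bucket/phase-2/'}
    base_dataset.update(datamix_dataset)
    # base_dataset == {'remote': 's3://my-bucket/phase-2/', 'split': 'train', 'shuffle': True}
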
@@ -192,26 +179,27 @@ def iteration_end(self, state: State, logger: Logger):
self._schedule_index += 1

def state_dict(self):
return {
'schedule': self._schedule,
'schedule_index': self._schedule_index,
}
return {'state': CurriculumLearningState(
schedule=self._schedule,
schedule_index=self._schedule_index,
)}

def load_state_dict(self, state: dict[str, Any]):
self._schedule_index = state['schedule_index']
schedule = state['state'].schedule
self._schedule_index = state['state'].schedule_index

# Ensure that the schedule has not changed on previously trained datamixes
for idx in range(state['schedule_index']):
if self._schedule[idx] != state['schedule'][idx]:
for idx in range(self._schedule_index):
if self._schedule[idx] != schedule[idx]:
raise ValueError((
f'Previous datamixes must stay the same across ',
f'resumptions. Expected {state["schedule"][idx]} but got ',
f'resumptions. Expected {schedule[idx]} but got ',
f'{self._schedule[idx]}',
))

# Ensure that the datamix has not changed on the current datamix
current_loader = self._schedule[self._schedule_index]['train_loader']
saved_loader = state['schedule'][self._schedule_index]['train_loader']
saved_loader = schedule[self._schedule_index]['train_loader']
if current_loader != saved_loader:
raise ValueError((
f'The current datamix must stay the same across resumptions. ',
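
Taken together, state_dict now wraps the schedule and index in a single CurriculumLearningState object under the 'state' key, and load_state_dict reads both fields back out before checking that already-trained datamixes are unchanged. A minimal round-trip sketch, bypassing Composer's actual checkpoint machinery (schedule values are hypothetical):

    from llmfoundry.callbacks.curriculum_learning_callback import CurriculumLearningState

    # Shape of what the callback now puts into the checkpoint.
    saved = {
        'state': CurriculumLearningState(
            schedule=[
                {'duration': '10000tok', 'dataset': {'remote': 's3://my-bucket/phase-1/'}},
                {'duration': '20000tok', 'dataset': {'remote': 's3://my-bucket/phase-2/'}},
            ],
            schedule_index=1,
        ),
    }

    # What load_state_dict pulls back out on resumption.
    schedule = saved['state'].schedule
    schedule_index = saved['state'].schedule_index
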
@@ -282,5 +270,5 @@ def _validate_datamix(self, datamix: dict[str, Any]):
'Schedules can only be defined in terms of epochs or tokens.',
)

if 'train_loader' not in datamix and 'dataset' not in datamix:
if 'dataset' not in datamix:
raise ValueError('Each datamix must have a dataset.')
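
The final change tightens validation: a datamix can no longer satisfy the check by providing only a 'train_loader'; it must carry a 'dataset'. A small illustration with hypothetical values:

    # Accepted: has a 'dataset' and a positive duration in epochs or tokens.
    ok_datamix = {'duration': '1ep', 'dataset': {'remote': 's3://my-bucket/phase-3/', 'split': 'train'}}

    # Rejected after this commit: no 'dataset' key, so _validate_datamix raises
    # ValueError('Each datamix must have a dataset.').
    bad_datamix = {'duration': '1ep', 'train_loader': {'name': 'text'}}
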
