Commit

Merge branch 'main' into callback-tests
irenedea authored Oct 11, 2024
2 parents b3c2578 + 85b251f commit 1d52ee8
Showing 7 changed files with 95 additions and 85 deletions.
78 changes: 36 additions & 42 deletions llmfoundry/callbacks/curriculum_learning_callback.py
@@ -1,16 +1,16 @@
-# Copyright 2022 MosaicML LLM Foundry authors
+# Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0

-"""Enable curriculum learning by resuming with a different dataset.
+"""Enable curriculum learning by specifying a schedule of datasets to train on.

-This callback is currently experimental. The API may change without warning in
-the future.
+This module provides a CurriculumLearning callback that allows for dynamic
+dataset switching during training based on a predefined schedule.
 """

 import copy
 import logging
-import warnings
-from typing import Any, Optional, Union
+from dataclasses import dataclass
+from typing import Any, Union

 from composer import DataSpec
 from composer.core import State, Time, TimeUnit, ensure_time
@@ -24,13 +24,18 @@
     BaseContextualError,
     TrainDataLoaderLocation,
 )
-from llmfoundry.utils.warnings import VersionedDeprecationWarning

 log = logging.getLogger(__name__)

 __all__ = ['CurriculumLearning']


+@dataclass
+class CurriculumLearningState:
+    schedule: list[dict[str, Any]]
+    schedule_index: int
+
+
 class CurriculumLearning(CallbackWithConfig):
     """Starts an epoch with a different dataset when resuming from a checkpoint.
@@ -57,8 +62,8 @@ class CurriculumLearning(CallbackWithConfig):
             being used. Note that this is the full train config and must
             contain the 'train_loader', 'device_train_batch_size', and
             'tokenizer' keys.
-        duration (Union[Time, str, int], optional): The duration of the first datamix
-            (which corresponds to the train_loader). Defaults to None.
+        duration (Union[Time, str, int]): The duration of the first datamix
+            (which corresponds to the train_loader).
         schedule (list[dict[str, Any]]): The list of datamixes to use and their
             durations. Duration units must match max_duration and be in terms of
             a TimeUnit that is supported by Iteration. The duration values must
@@ -73,29 +78,17 @@ def __init__(
         self,
         train_config: dict[str, Any],
         schedule: list[dict[str, Any]],
-        duration: Optional[Union[Time, str, int]] = None,
+        duration: Union[Time, str, int],
     ):
-        if duration is None:
-            warnings.warn(
-                VersionedDeprecationWarning(
-                    'Specifying the full schedule in the CurriculumLearning ' +
-                    'callback is deprecated. Please specify the duration of ' +
-                    'the first datamix separately and change the schedule ' +
-                    'use datasets instead of dataloaders.',
-                    remove_version='0.15.0',
-                ),
-            )
-
         # Ensure all duration units are in epochs or tokens and values are positive
         self._schedule = schedule
         if len(self._schedule) == 0:
             raise ValueError('The schedule must have at least one datamix.')
-        if duration is not None:
-            first_datamix = {
-                'duration': duration,
-                'dataset': train_config['train_loader']['dataset'],
-            }
-            self._schedule.insert(0, first_datamix)
+        first_datamix = {
+            'duration': duration,
+            'dataset': train_config['train_loader']['dataset'],
+        }
+        self._schedule.insert(0, first_datamix)
         for datamix in self._schedule:
             self._validate_datamix(datamix)

@@ -167,10 +160,7 @@ def iteration_start(self, state: State, logger: Logger):
         clean_stale_shared_memory()
         datamix = copy.deepcopy(self._schedule[self._schedule_index])
         train_loader_config = copy.deepcopy(self._train_loader_config)
-        if 'dataset' in datamix:
-            train_loader_config['dataset'].update(datamix['dataset'])
-        else:
-            train_loader_config = datamix['train_loader']
+        train_loader_config['dataset'].update(datamix['dataset'])
         data_spec = self._build_train_loader(
             train_loader_config=train_loader_config,
             logger=logger,
@@ -193,29 +183,33 @@ def iteration_end(self, state: State, logger: Logger):

     def state_dict(self):
         return {
-            'schedule': self._schedule,
-            'schedule_index': self._schedule_index,
+            'state':
+                CurriculumLearningState(
+                    schedule=self._schedule,
+                    schedule_index=self._schedule_index,
+                ),
         }

     def load_state_dict(self, state: dict[str, Any]):
-        self._schedule_index = state['schedule_index']
+        schedule = state['state'].schedule
+        self._schedule_index = state['state'].schedule_index

         # Ensure that the schedule has not changed on previously trained datamixes
-        for idx in range(state['schedule_index']):
-            if self._schedule[idx] != state['schedule'][idx]:
+        for idx in range(self._schedule_index):
+            if self._schedule[idx] != schedule[idx]:
                 raise ValueError((
                     f'Previous datamixes must stay the same across ',
-                    f'resumptions. Expected {state["schedule"][idx]} but got ',
+                    f'resumptions. Expected {schedule[idx]} but got ',
                     f'{self._schedule[idx]}',
                 ))

         # Ensure that the datamix has not changed on the current datamix
-        current_loader = self._schedule[self._schedule_index]['train_loader']
-        saved_loader = state['schedule'][self._schedule_index]['train_loader']
-        if current_loader != saved_loader:
+        current_dataset = self._schedule[self._schedule_index]['dataset']
+        saved_dataset = schedule[self._schedule_index]['dataset']
+        if current_dataset != saved_dataset:
             raise ValueError((
                 f'The current datamix must stay the same across resumptions. ',
-                f'Expected {saved_loader} but got {current_loader}',
+                f'Expected {saved_dataset} but got {current_dataset}',
             ))

         # Ensure that the current datamix duration is in the correct units
@@ -282,5 +276,5 @@ def _validate_datamix(self, datamix: dict[str, Any]):
                     'Schedules can only be defined in terms of epochs or tokens.',
                 )

-        if 'train_loader' not in datamix and 'dataset' not in datamix:
+        if 'dataset' not in datamix:
             raise ValueError('Each datamix must have a dataset.')
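Taken together, the callback changes above replace per-datamix train_loader configs with a dataset plus duration, and make the first datamix's duration a required constructor argument. A minimal usage sketch under the new API — the dataset remotes, batch size, tokenizer name, and durations below are invented for illustration, not taken from the repo:

from llmfoundry.callbacks.curriculum_learning_callback import CurriculumLearning

# Hypothetical train config; only the keys the callback reads are shown.
train_config = {
    'train_loader': {
        'name': 'text',
        'dataset': {'remote': 's3://my-bucket/phase-0', 'split': 'train'},
    },
    'device_train_batch_size': 8,
    'tokenizer': {'name': 'EleutherAI/gpt-neox-20b'},
}

callback = CurriculumLearning(
    train_config=train_config,
    duration='10000000tok',  # how long to train on the train_loader's own dataset
    schedule=[
        {
            'duration': '20000000tok',
            'dataset': {'remote': 's3://my-bucket/phase-1', 'split': 'train'},
        },
        {
            'duration': '30000000tok',
            'dataset': {'remote': 's3://my-bucket/phase-2', 'split': 'train'},
        },
    ],
)

Each entry's dataset dict is merged into the train_loader's dataset config at the start of the corresponding iteration (via update), so only the fields that change between phases need to be listed.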
64 changes: 31 additions & 33 deletions llmfoundry/command_utils/data_prep/convert_text_to_mds.py
@@ -240,41 +240,39 @@ def download_and_convert(
     object_store = maybe_create_object_store_from_uri(input_folder)

     # Download file_names
-    with tempfile.TemporaryDirectory() as tmp_dir:
-        log.info(f'Created temporary directory: {tmp_dir}')
-        downloading_iter = DownloadingIterable(
-            object_names=file_names,
-            output_folder=tmp_dir,
-            object_store=object_store,
-        )
-        log.info(f'Initializing tokenizer: {tokenizer_name}')
-        tokenizer = AutoTokenizer.from_pretrained(
-            tokenizer_name,
-            trust_remote_code=trust_remote_code,
-        )
-        tokenizer.model_max_length = 5000000000  # Hack to prevent warnings from HuggingFace
-
-        # Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up
-        # to the maximum sequence length
-        dataset = ConcatTokensFromFilesDataset(
-            files=downloading_iter,
-            max_length=concat_tokens,
-            tokenizer=tokenizer,
-            eos_text=eos_text,
-            bos_text=bos_text,
-            no_wrap=no_wrap,
-        )
+    downloading_iter = DownloadingIterable(
+        object_names=file_names,
+        output_folder=None,  # Downloads to temporary files.
+        object_store=object_store,
+    )
+    log.info(f'Initializing tokenizer: {tokenizer_name}')
+    tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer_name,
+        trust_remote_code=trust_remote_code,
+    )
+    tokenizer.model_max_length = 5000000000  # Hack to prevent warnings from HuggingFace
+
+    # Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up
+    # to the maximum sequence length
+    dataset = ConcatTokensFromFilesDataset(
+        files=downloading_iter,
+        max_length=concat_tokens,
+        tokenizer=tokenizer,
+        eos_text=eos_text,
+        bos_text=bos_text,
+        no_wrap=no_wrap,
+    )

-        columns = {'tokens': 'ndarray:int32'}
+    columns = {'tokens': 'ndarray:int32'}

-        log.info('Converting to MDS format...')
-        with MDSWriter(
-            out=output_folder,
-            columns=columns,
-            compression=compression,
-        ) as out:
-            for sample in tqdm(dataset):
-                out.write(sample)
+    log.info('Converting to MDS format...')
+    with MDSWriter(
+        out=output_folder,
+        columns=columns,
+        compression=compression,
+    ) as out:
+        for sample in tqdm(dataset):
+            out.write(sample)

     log.info(f'Completed download and conversion for {len(file_names)} files')

3 changes: 3 additions & 0 deletions llmfoundry/command_utils/train.py
@@ -234,6 +234,9 @@ def train(cfg: DictConfig) -> Trainer:
         logging.getLogger(__name__).setLevel(
             train_cfg.python_log_level.upper(),
         )  # Train script
+        logging.getLogger('streaming').setLevel(
+            train_cfg.python_log_level.upper(),
+        )  # Streaming module

     _initialize_dist_with_barrier(dist_timeout=train_cfg.dist_timeout)

1 change: 1 addition & 0 deletions llmfoundry/data/finetuning/dataloader.py
@@ -590,6 +590,7 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
                 ]
                 raise FinetuningFileNotFoundError(
                     files_searched=files_searched,
+                    supported_extensions=SUPPORTED_EXTENSIONS,
                 ) from e
             else:
                 log.debug(
8 changes: 5 additions & 3 deletions llmfoundry/utils/data_prep_utils.py
@@ -3,6 +3,7 @@

 import json
 import os
+import tempfile
 from glob import glob
 from typing import Optional

@@ -105,7 +106,7 @@ class DownloadingIterable:
     def __init__(
         self,
         object_names: list[str],
-        output_folder: str,
+        output_folder: Optional[str],
         object_store: Optional[ObjectStore],
     ):
         """Iterable that downloads files before yielding the local filename.
@@ -114,7 +115,7 @@ def __init__(
         Args:
             object_names (List[str]): Names of objects to download
-            output_folder (str): Local folder to write downloaded files to
+            output_folder (Optional[str]): Local folder to write downloaded files to. If none, uses a temporary folder.
             object_store (Optional[ObjectStore]): Object store to download from
         """
         self.object_names = object_names
@@ -131,7 +132,8 @@ def __iter__(self):
             output_filename = os.path.join(
                 self.output_folder,
                 object_name.strip('/'),
-            )
+            ) if self.output_folder is not None else tempfile.NamedTemporaryFile(
+            ).name

             download_file(
                 object_store=self.object_store,
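Since output_folder is now optional, callers of DownloadingIterable no longer have to manage a temporary directory themselves. A rough sketch of the new behavior, with an invented bucket and object names (not from the repo):

from composer.utils import maybe_create_object_store_from_uri

from llmfoundry.utils.data_prep_utils import DownloadingIterable

object_store = maybe_create_object_store_from_uri('s3://my-bucket/raw-text/')
files = DownloadingIterable(
    object_names=['raw-text/shard-0.txt', 'raw-text/shard-1.txt'],
    output_folder=None,  # each object now lands in a tempfile-backed path
    object_store=object_store,
)
for local_path in files:
    print(local_path)  # local filename of the downloaded temporary file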
10 changes: 7 additions & 3 deletions llmfoundry/utils/exceptions.py
@@ -486,16 +486,20 @@ def __str__(self):
 class FinetuningFileNotFoundError(UserError):
     """Error thrown when a file can't be found with any supported extension."""

-    def __init__(self, files_searched: list[str]) -> None:
-        from llmfoundry.data.finetuning.tasks import SUPPORTED_EXTENSIONS
+    def __init__(
+        self,
+        files_searched: list[str],
+        supported_extensions: list[str],
+    ) -> None:
         message = (
             f'Could not find a file with any of ' + \
-            f'the supported extensions: {SUPPORTED_EXTENSIONS}\n' + \
+            f'the supported extensions: {supported_extensions}\n' + \
             f'at {files_searched}'
         )
         super().__init__(
             message,
             files_searched=files_searched,
+            supported_extensions=supported_extensions,
         )


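Because the exception no longer imports SUPPORTED_EXTENSIONS itself, every raise site now has to pass the extension list in explicitly (as the dataloader change above does). A hypothetical call under the new signature, with an illustrative extension list and file paths:

from llmfoundry.utils.exceptions import FinetuningFileNotFoundError

supported_extensions = ['.csv', '.json', '.jsonl', '.parquet']  # illustrative values
raise FinetuningFileNotFoundError(
    files_searched=[f's3://bucket/data/train{ext}' for ext in supported_extensions],
    supported_extensions=supported_extensions,
)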
16 changes: 12 additions & 4 deletions tests/callbacks/test_curriculum_learning_callback.py
@@ -13,6 +13,8 @@
 from omegaconf import OmegaConf as om
 from torch.utils.data import DataLoader

+from llmfoundry.callbacks.curriculum_learning_callback import \
+    CurriculumLearningState
 from llmfoundry.data.text_data import StreamingTextDataset
 from llmfoundry.utils.builders import build_callback

@@ -237,8 +239,11 @@ def test_curriculum_learning_callback_state_dict(build_tiny_mpt: Callable,):
     callback.iteration_start(state, logger)
     callback.iteration_end(state, logger)
     assert callback.state_dict() == {
-        'schedule': kwargs['schedule'],
-        'schedule_index': 1,
+        'state':
+            CurriculumLearningState(
+                schedule=kwargs['schedule'],
+                schedule_index=1,
+            ),
     }


@@ -280,8 +285,11 @@ def test_curriculum_learning_callback_load_state_dict(
     callback.iteration_start(state, logger)
     callback.iteration_end(state, logger)
     assert callback.state_dict() == {
-        'schedule': kwargs['schedule'],
-        'schedule_index': 1,
+        'state':
+            CurriculumLearningState(
+                schedule=kwargs['schedule'],
+                schedule_index=1,
+            ),
     }


