Add LoadPlanner and SavePlanner registries #1358

Merged · 13 commits · Jul 18, 2024
27 changes: 27 additions & 0 deletions llmfoundry/command_utils/train.py
@@ -36,8 +36,10 @@
build_callback,
build_composer_model,
build_evaluators,
build_load_planner,
build_logger,
build_optimizer,
build_save_planner,
build_scheduler,
build_tokenizer,
)
@@ -256,6 +258,31 @@ def train(cfg: DictConfig) -> Trainer:
    # Optional fsdp data, fine-tuning, and eval configs
    fsdp_config: Optional[Dict[str, Any]] = train_cfg.fsdp_config

    if fsdp_config is not None:
        if 'load_planner' in fsdp_config:
            load_planners = list(fsdp_config['load_planner'].items())
            if len(load_planners) > 1:
                raise ValueError(
                    'Only one load planner can be specified in the config.',
                )
            load_planner_name, load_planner_config = load_planners[0]
            fsdp_config['load_planner'] = build_load_planner(
                load_planner_name,
                **load_planner_config,
            )

        if 'save_planner' in fsdp_config:
            save_planners = list(fsdp_config['save_planner'].items())
            if len(save_planners) > 1:
                raise ValueError(
                    'Only one save planner can be specified in the config.',
                )
            save_planner_name, save_planner_config = save_planners[0]
            fsdp_config['save_planner'] = build_save_planner(
                save_planner_name,
                **save_planner_config,
            )

    eval_loader_config = train_cfg.eval_loader if train_cfg.eval_loader is not None else train_cfg.eval_loaders
    icl_tasks_config = train_cfg.icl_tasks or train_cfg.icl_tasks_str
    eval_gauntlet_config = train_cfg.eval_gauntlet or train_cfg.eval_gauntlet_str
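For reference, a minimal sketch of how a planner might be selected through the FSDP section of a training config, assuming a planner has already been registered under the hypothetical name 'my_load_planner' (that name and its strict option are illustrative, not part of this PR). The train() code above replaces the nested mapping with a constructed LoadPlanner instance before the Trainer is built.

# Hypothetical fsdp_config fragment, written as the Python dict OmegaConf would yield.
# Exactly one entry is allowed under 'load_planner' (and likewise under 'save_planner').
fsdp_config = {
    'sharding_strategy': 'FULL_SHARD',
    'load_planner': {
        'my_load_planner': {   # registry name (assumed to be registered elsewhere)
            'strict': False,   # kwargs forwarded to the planner constructor
        },
    },
}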
37 changes: 37 additions & 0 deletions llmfoundry/registry.py
@@ -6,6 +6,7 @@
from composer.loggers import LoggerDestination
from composer.models import ComposerModel
from composer.optim import ComposerScheduler
from torch.distributed.checkpoint import LoadPlanner, SavePlanner
from torch.optim import Optimizer
from torch.utils.data import DataLoader as TorchDataloader
from torch.utils.data import Dataset
@@ -339,6 +340,42 @@
description=_config_transforms_description,
)

_load_planners_description = (
    """The load_planners registry is used to register classes that implement the LoadPlanner interface.

    The LoadPlanner will be passed as part of the FSDP config arg of the Trainer. It will be used to load distributed checkpoints.

    Returns:
        LoadPlanner: The load planner.
    """
)

load_planners = create_registry(
    'llmfoundry',
    'load_planners',
    generic_type=Type[LoadPlanner],
    entry_points=True,
    description=_load_planners_description,
)

_save_planners_description = (
    """The save_planners registry is used to register classes that implement the SavePlanner interface.

    The SavePlanner will be passed as part of the FSDP config arg of the Trainer. It will be used to save distributed checkpoints.

    Returns:
        SavePlanner: The save planner.
    """
)

save_planners = create_registry(
    'llmfoundry',
    'save_planners',
    generic_type=Type[SavePlanner],
    entry_points=True,
    description=_save_planners_description,
)

__all__ = [
'loggers',
'callbacks',
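As a rough usage sketch of the new registries (the CustomLoadPlanner class and the name 'my_load_planner' are hypothetical, not shipped by this PR), a project could register its own planner and then refer to it by name from a config:

from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner

from llmfoundry.registry import load_planners


class CustomLoadPlanner(DefaultLoadPlanner):
    # Hypothetical planner that customizes how distributed checkpoints are loaded.

    def __init__(self, strict: bool = True):
        super().__init__()
        self.strict = strict


# Register under a name that an fsdp_config entry can reference.
load_planners.register('my_load_planner', func=CustomLoadPlanner)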
39 changes: 39 additions & 0 deletions llmfoundry/utils/builders.py
@@ -27,6 +27,7 @@
from composer.utils import dist
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from torch.distributed.checkpoint import LoadPlanner, SavePlanner
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric
from transformers import AutoTokenizer, PreTrainedTokenizerBase
@@ -187,6 +188,44 @@ def build_icl_data_and_gauntlet(
return icl_evaluators, logger_keys, eval_gauntlet_cb


def build_load_planner(name: str, **kwargs: Any) -> LoadPlanner:
    """Builds a load planner from the registry.

    Args:
        name: Name of the load planner to build.
        **kwargs: Other relevant keyword arguments for the load planner.

    Returns:
        LoadPlanner: The load planner.
    """
    return construct_from_registry(
        name=name,
        registry=registry.load_planners,
        partial_function=True,
        pre_validation_function=LoadPlanner,
        post_validation_function=None,
        kwargs=kwargs,
    )


def build_save_planner(name: str, **kwargs: Any) -> SavePlanner:
    """Builds a save planner from the registry.

    Args:
        name: Name of the save planner to build.
        **kwargs: Other relevant keyword arguments for the save planner.

    Returns:
        SavePlanner: The save planner.
    """
    return construct_from_registry(
        name=name,
        registry=registry.save_planners,
        partial_function=True,
        pre_validation_function=SavePlanner,
        post_validation_function=None,
        kwargs=kwargs,
    )


def build_composer_model(
name: str,
cfg: Dict[str, Any],
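A short sketch of the resulting call site behavior, reusing the hypothetical 'my_load_planner' registered above (the strict argument is likewise illustrative):

from llmfoundry.utils.builders import build_load_planner

# Given a config entry such as {'load_planner': {'my_load_planner': {'strict': False}}},
# train() ends up doing the equivalent of:
load_planner = build_load_planner('my_load_planner', strict=False)
fsdp_config['load_planner'] = load_planner  # the Trainer then receives a LoadPlanner instance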
2 changes: 2 additions & 0 deletions tests/test_registry.py
@@ -44,6 +44,8 @@ def test_expected_registries_exist():
'fcs',
'icl_datasets',
'config_transforms',
'load_planners',
'save_planners',
}

assert existing_registries == expected_registry_names
35 changes: 35 additions & 0 deletions tests/utils/test_builders.py
@@ -13,17 +13,24 @@
from composer.callbacks import Generate
from composer.core import Evaluator
from composer.loggers import WandBLogger
from torch.distributed.checkpoint.default_planner import (
DefaultLoadPlanner,
DefaultSavePlanner,
)
from transformers import PreTrainedTokenizerBase

from llmfoundry.callbacks import HuggingFaceCheckpointer
from llmfoundry.registry import load_planners, save_planners
from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
from llmfoundry.utils.builders import (
add_metrics_to_eval_loaders,
build_callback,
build_eval_loaders,
build_evaluators,
build_load_planner,
build_logger,
build_optimizer,
build_save_planner,
build_tokenizer,
)

@@ -345,6 +352,34 @@ def test_build_eval_loaders(monkeypatch: pytest.MonkeyPatch):
assert eval_loaders2[1].metric_names == []


def test_build_load_planner():
    # Dummy LoadPlanner for testing
    class DummyLoadPlanner(DefaultLoadPlanner):

        def __init__(self, is_test: bool):
            self.is_test = is_test

    load_planners.register('dummy', func=DummyLoadPlanner)
    load_planner = build_load_planner('dummy', is_test=True)

    assert isinstance(load_planner, DummyLoadPlanner)
    assert load_planner.is_test is True


def test_build_save_planner():
    # Dummy SavePlanner for testing
    class DummySavePlanner(DefaultSavePlanner):

        def __init__(self, is_test: bool):
            self.is_test = is_test

    save_planners.register('dummy', func=DummySavePlanner)
    save_planner = build_save_planner('dummy', is_test=True)

    assert isinstance(save_planner, DummySavePlanner)
    assert save_planner.is_test is True


def test_add_metrics_to_eval_loaders():
evaluators = [
Evaluator(