Fix typos in callbacks with configs (#1146)
dakinggg authored Apr 29, 2024
1 parent 8be3254 commit 704a90a
Showing 9 changed files with 65 additions and 29 deletions.
llmfoundry/callbacks/async_eval_callback.py (6 additions, 6 deletions)
@@ -210,28 +210,28 @@ class AsyncEval(CallbackWithConfig):
 
     def __init__(
         self,
-        training_params: Dict[str, Any],
+        train_config: Dict[str, Any],
         interval: Union[str, int, Time],
         eval_run_config: Optional[Dict[str, Any]] = None,
     ):
 
         # Run these during init to fail fast in any of the error cases
         for required in ('save_interval', 'save_folder'):
-            if required not in training_params:
+            if required not in train_config:
                 raise ValueError(f'{required} required for async eval')
 
-        if '/' in training_params.get('save_filename', ''):
+        if '/' in train_config.get('save_filename', ''):
             raise ValueError(
                 'AsyncEval not supported for save_filename that includes a path'
             )
 
-        self.checkpoint_save_folder = training_params['save_folder']
-        self.training_params = training_params
+        self.checkpoint_save_folder = train_config['save_folder']
+        self.training_params = train_config
         self.eval_run_config = validate_eval_run_config(eval_run_config)
 
         self.current_run = self._get_current_run()
         get_eval_parameters(
-            parameters=training_params,
+            parameters=train_config,
             checkpoint='test',
             training_run_name=self.current_run.name,
         )
llmfoundry/callbacks/curriculum_learning_callback.py (2 additions, 2 deletions)
@@ -33,13 +33,13 @@ class CurriculumLearning(CallbackWithConfig):
     being used.
     """
 
-    def __init__(self, dataset_index: int, train_config: Dict):
+    def __init__(self, train_config: Dict, dataset_index: int):
         self.dataset_index = dataset_index
         self.saved_dataset_index = 0
         self.all_dataset_configs = []
         self.current_dataset_state = {}
         # The current dataset config is resolved and passed in train.py
-        self.current_dataset_config = train_config['dataloader']
+        self.current_dataset_config = train_config['train_loader']
 
     def before_load(self, state: State, logger: Logger):
         del logger
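Both renamed constructors now take the resolved training config under the single name train_config. A minimal sketch, not part of this commit, of a callback that follows the same convention; the class name and log_interval option are hypothetical, and the CallbackWithConfig import path is assumed:

from typing import Any, Dict

from llmfoundry.interfaces import CallbackWithConfig  # import path assumed


class DatasetLoggingCallback(CallbackWithConfig):
    """Hypothetical callback illustrating the train_config convention."""

    def __init__(self, train_config: Dict[str, Any], log_interval: int = 1):
        # train_config is the full resolved training config, injected by the
        # builder; per-callback options such as log_interval come from the
        # callback's own kwargs.
        self.train_loader_config = train_config['train_loader']
        self.log_interval = log_interval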
llmfoundry/utils/builders.py (4 additions, 4 deletions)
@@ -217,18 +217,18 @@ def build_composer_model(
 def build_callback(
     name: str,
     kwargs: Optional[Dict[str, Any]] = None,
-    config: Any = None,
+    train_config: Any = None,
 ) -> Callback:
     """Builds a callback from the registry."""
     registry_to_use = registry.callbacks
     if name in registry.callbacks_with_config:
         if kwargs is None:
             kwargs = {}
-        if 'config' in kwargs:
+        if 'train_config' in kwargs:
             raise ValueError(
-                f'`config` is a reserved keyword for callbacks with config. Please remove it from the kwargs.'
+                f'`train_config` is a reserved keyword for callbacks with config. Please remove it from the kwargs.'
             )
-        kwargs['config'] = config
+        kwargs['train_config'] = train_config
         registry_to_use = registry.callbacks_with_config
 
     return construct_from_registry(name=name,
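With the config parameter renamed to train_config, the build_callback calling convention is the same at every call site. A short usage sketch, mirroring the tests added in this commit:

from llmfoundry.utils.builders import build_callback

# Plain callback: only a name (and optional kwargs) is needed.
tok_per_expert = build_callback(name='mbmoe_tok_per_expert')

# Callback with config: the full training config is forwarded under the
# reserved train_config keyword; putting 'train_config' inside kwargs
# raises a ValueError instead.
curriculum = build_callback(
    name='curriculum_learning',
    kwargs={'dataset_index': 0},
    train_config={'train_loader': {}},
)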
scripts/eval/eval.py (1 addition, 1 deletion)
@@ -78,7 +78,7 @@ def evaluate_model(
 
     # Callbacks
     callbacks: List[Callback] = [
-        build_callback(str(name), callback_cfg)
+        build_callback(name=str(name), kwargs=callback_cfg)
         for name, callback_cfg in callback_configs.items()
     ] if callback_configs else []
 
scripts/train/train.py (3 additions, 1 deletion)
@@ -481,7 +481,9 @@ def main(cfg: DictConfig) -> Trainer:
 
     # Callbacks
     callbacks: List[Callback] = [
-        build_callback(str(name), callback_cfg, om.to_container(logged_cfg))
+        build_callback(name=str(name),
+                       kwargs=callback_cfg,
+                       train_config=om.to_container(logged_cfg))
         for name, callback_cfg in callback_configs.items()
     ] if callback_configs else []
 
tests/callbacks/test_async_eval_callback.py (23 additions)
@@ -13,6 +13,7 @@
     get_run_name,
     validate_eval_run_config,
     validate_interval)
+from llmfoundry.utils.builders import build_callback
 from mcli import Run, RunConfig, RunStatus
 
 RUN_NAME = 'foo_bar-1234'
@@ -242,6 +243,28 @@ def test_validate_eval_run_config():
     )
 
 
+@patch('llmfoundry.callbacks.async_eval_callback.get_run',
+       return_value=FAKE_RUN)
+def test_async_eval_callback_builds(mock_get_run: MagicMock):
+    kwargs = {'interval': 1}
+    config = {
+        'save_folder': 'foo',
+        'save_interval': 1,
+        'device_eval_batch_size': 2,
+        'max_seq_len': 3,
+        'model': {
+            'name': 'foo',
+        },
+        'tokenizer': {},
+        'icl_tasks': [],
+    }
+    callback = build_callback('async_eval', kwargs=kwargs, train_config=config)
+    assert isinstance(callback, AsyncEval)
+    assert callback.current_run.name == RUN_NAME
+    assert mock_get_run.call_count == 1
+    assert mock_get_run.call_args[0][0] == RUN_NAME
+
+
 @patch('llmfoundry.callbacks.async_eval_callback.get_run',
        return_value=FAKE_RUN)
 @patch('llmfoundry.callbacks.async_eval_callback.create_run',
tests/callbacks/test_curriculum_learning_callback.py (new file, 12 additions)
@@ -0,0 +1,12 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from llmfoundry.utils.builders import build_callback
+
+
+def test_curriculum_learning_callback_builds():
+    kwargs = {'dataset_index': 0}
+    callback = build_callback('curriculum_learning',
+                              kwargs=kwargs,
+                              train_config={'train_loader': {}})
+    assert callback is not None
tests/callbacks/test_mbmoe_tok_per_expert_callback.py (1 addition, 1 deletion)
@@ -6,6 +6,6 @@
 
 def test_mbmoe_tok_per_expert_builds():
     """Test that the callback can be built."""
-    callback = build_callback('mbmoe_tok_per_expert')
+    callback = build_callback(name='mbmoe_tok_per_expert')
     assert callback is not None
     assert callback.__class__.__name__ == 'MegaBlocksMoE_TokPerExpert'
tests/utils/test_builders.py (13 additions, 14 deletions)
@@ -56,7 +56,7 @@ def test_tokenizer_no_EOS():
 
 def test_build_callback_fails():
     with pytest.raises(ValueError):
-        build_callback('nonexistent_callback', {}, {})
+        build_callback(name='nonexistent_callback', kwargs={}, train_config={})
 
 
 @pytest.mark.parametrize(
@@ -72,14 +72,13 @@ def test_build_generate_callback(
                autospec=True) as mock_generate:
         mock_generate.return_value = None
         build_callback(
-            'generate_callback',
-            {
+            name='generate_callback',
+            kwargs={
                 'prompts': ['hello'],
                 interval_key: interval_value,
                 'foo': 'bar',
                 'something': 'else',
             },
-            {},
         )
 
     assert mock_generate.call_count == 1
@@ -96,13 +95,12 @@ def test_build_generate_callback_unspecified_interval():
                autospec=True) as mock_generate:
         mock_generate.return_value = None
         build_callback(
-            'generate_callback',
-            {
+            name='generate_callback',
+            kwargs={
                 'prompts': ['hello'],
                 'foo': 'bar',
                 'something': 'else',
             },
-            {},
         )
 
 
@@ -120,13 +118,14 @@ def test_build_hf_checkpointer_callback():
             'task': 'llm/v1/completions'
         }
     }
-    build_callback(name='hf_checkpointer',
-                   kwargs={
-                       'save_folder': save_folder,
-                       'save_interval': save_interval,
-                       'mlflow_logging_config': mlflow_logging_config_dict
-                   },
-                   config={})
+    build_callback(
+        name='hf_checkpointer',
+        kwargs={
+            'save_folder': save_folder,
+            'save_interval': save_interval,
+            'mlflow_logging_config': mlflow_logging_config_dict
+        },
+    )
 
     assert mock_hf_checkpointer.call_count == 1
     _, _, kwargs = mock_hf_checkpointer.mock_calls[0]
