Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow train.py-like config for eval.py #1351

Merged
merged 14 commits into from
Jul 23, 2024
49 changes: 49 additions & 0 deletions llmfoundry/command_utils/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,54 @@ def evaluate_model(
return (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df)


def allow_toplevel_keys(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Transform the config to allow top-level keys for model configuration.

    This function allows users to use the 'train.py' syntax in 'eval.py'.
    It converts a config with top-level 'model', 'tokenizer', and (optionally)
    'model_name' and 'load_path' keys into the nested 'models' list format
    required by 'eval.py'.

    Input config format (train.py style):
    ```yaml
    model:
      <model_kwargs>
    load_path: /path/to/checkpoint
    tokenizer:
      <tokenizer_kwargs>
    ```

    Output config format (eval.py style):
    ```yaml
    models:
      - model:
          <model_kwargs>
        tokenizer:
          <tokenizer_kwargs>
        load_path: /path/to/checkpoint
    ```

    Args:
        cfg: The config mapping. It is modified in place when a top-level
            'model' key is present; otherwise it is left as is.

    Returns:
        The same ``cfg`` object. If a top-level 'model' key was present, the
        'model', 'tokenizer', 'model_name' and 'load_path' keys are moved into
        a single entry under 'models'. 'model_name' defaults to the 'name'
        field of the model config when not given explicitly.

    Raises:
        ValueError: If both 'model' and 'models' are present, or if 'model'
            is present without a non-None 'tokenizer'. ``cfg`` is not modified
            when this is raised.
    """
    if 'model' not in cfg:
        return cfg

    if 'models' in cfg:
        raise ValueError(
            'Please specify either model or models in the config, not both',
        )
    # Validate before popping anything so a rejected config is left untouched
    # and can still be inspected or corrected by the caller.
    if cfg.get('tokenizer') is None:
        raise ValueError(
            'When specifying model, "tokenizer" must be provided in the config',
        )

    # Read the fallback name before 'model' is popped out of the config.
    default_name = cfg['model'].get('name')
    model_cfg = {
        'model': cfg.pop('model'),
        'tokenizer': cfg.pop('tokenizer'),
        'model_name': cfg.pop('model_name', default_name),
    }
    if 'load_path' in cfg:
        model_cfg['load_path'] = cfg.pop('load_path')
    cfg['models'] = [model_cfg]

    return cfg


def evaluate(cfg: DictConfig) -> Tuple[list[Trainer], pd.DataFrame]:
# Run user provided code if specified
for code_path in cfg.get('code_paths', []):
Expand All @@ -184,6 +232,7 @@ def evaluate(cfg: DictConfig) -> Tuple[list[Trainer], pd.DataFrame]:
cfg,
EvalConfig,
EVAL_CONFIG_KEYS,
transforms=[allow_toplevel_keys],
icl_tasks_required=True,
)

Expand Down
Loading