Add deprecation warning to fsdp_config #1530

Merged · 16 commits · Sep 17, 2024
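
In short: the fsdp_config argument to evaluate_model is deprecated in favor of parallelism_config, which nests the same FSDP options under an 'fsdp' key, and the tests now construct Trainer with parallelism_config={'fsdp': fsdp_config} as well. A minimal migration sketch in Python (the activation_checkpointing option is illustrative; the elided arguments are whatever the caller already passes):

# Migration sketch (not taken from the PR): the same FSDP options move from the
# top-level fsdp_config argument into parallelism_config under the 'fsdp' key.
fsdp_config = {'activation_checkpointing': True}  # illustrative FSDP option

# Old style, which now emits VersionedDeprecationWarning (removal planned for 0.13.0):
#   evaluate_model(..., fsdp_config=fsdp_config)
# New style:
#   evaluate_model(..., parallelism_config={'fsdp': fsdp_config})
parallelism_config = {'fsdp': fsdp_config}
print(parallelism_config)  # {'fsdp': {'activation_checkpointing': True}}
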
35 changes: 32 additions & 3 deletions llmfoundry/command_utils/eval.py
@@ -4,14 +4,15 @@
import logging
import os
import time
import warnings
from typing import Any, Optional, Union

import pandas as pd
import torch
from composer.core import Callback
from composer.loggers.logger_destination import LoggerDestination
from composer.trainer import Trainer
from composer.utils import dist, get_device, reproducibility
from composer.utils import dist, get_device, parallelism, reproducibility
from omegaconf import DictConfig
from omegaconf import OmegaConf as om

@@ -36,6 +37,7 @@
process_init_device,
)
from llmfoundry.utils.registry_utils import import_file
from llmfoundry.utils.warnings import VersionedDeprecationWarning

log = logging.getLogger(__name__)

@@ -52,7 +54,6 @@ def evaluate_model(
device_eval_batch_size: Union[int, float],
eval_gauntlet_config: Optional[Union[str, dict[str, Any]]],
eval_loader_config: Optional[Union[dict[str, Any], list[dict[str, Any]]]],
fsdp_config: Optional[dict[str, Any]],
loggers: list[LoggerDestination],
python_log_level: Optional[str],
precision: str,
@@ -62,9 +63,33 @@
callback_configs: Optional[dict[str, Any]],
metadata: Optional[dict[str, str]],
logged_config: dict[str, Any],
fsdp_config: Optional[dict[str, Any]] = None,
parallelism_config: Optional[dict[str, Any]] = None,
should_log_config: bool = True,
load_path: Optional[str] = None,
):
if parallelism_config:
deprecated_fsdp_args = list(
parallelism.FSDPConfig.__annotations__.keys(),
)
for deprecated_arg in deprecated_fsdp_args:
if deprecated_arg in parallelism_config:
raise ValueError(
'parallelism_config cannot contain deprecated fsdp_config arguments.',
)

if fsdp_config:
warnings.warn(
VersionedDeprecationWarning(
'The argument fsdp_config is deprecated. Please use parallelism_config instead.',
remove_version='0.13.0',
),
)
if fsdp_config and parallelism_config:
raise ValueError(
'Both fsdp_config and parallelism_config cannot be provided at the same time. Please use parallelism_config.',
)

log.info(f'Evaluating model: {model_name}')
# Build tokenizer and model
tokenizer_cfg = tokenizer
@@ -99,6 +124,10 @@
mosaicml_logger.log_metrics(metadata)
mosaicml_logger._flush_metadata(force_flush=True)

fsdp_config = parallelism_config.get(
'fsdp_config',
None,
) if parallelism_config else fsdp_config
if fsdp_config and model.get('load_in_8bit', False):
raise ValueError(
'The FSDP config block is not supported when loading ' +
@@ -146,7 +175,7 @@
callbacks=callbacks,
loggers=loggers,
precision=precision,
fsdp_config=fsdp_config,
parallelism_config={'fsdp': fsdp_config},
load_path=load_path,
load_weights_only=True,
progress_bar=False,
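A note on the deprecated-argument check added above: FSDPConfig lives in composer.utils.parallelism, and its __annotations__ keys enumerate the declared FSDP field names, so any of those names found at the top level of parallelism_config (rather than nested under 'fsdp') is rejected as a legacy flat-style config. A small sketch of that lookup, assuming Composer is installed (the example field name comes from the tests below):

from composer.utils import parallelism

# The declared field names of FSDPConfig double as the list of deprecated
# top-level keys; 'activation_checkpointing' is one example exercised in the tests.
deprecated_fsdp_args = list(parallelism.FSDPConfig.__annotations__.keys())
print(deprecated_fsdp_args)
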
5 changes: 3 additions & 2 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -1042,7 +1042,8 @@ def test_huggingface_conversion_callback(
model=original_model,
device='gpu',
precision=trainer_precision,
fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None,
parallelism_config={'fsdp': fsdp_config}
if fsdp_state_dict_type is not None else None,
train_dataloader=train_dataloader,
save_folder=os.path.join(tmp_path, 'checkpoints'),
save_interval=save_interval,
@@ -1469,7 +1470,7 @@ def test_mptmoe_huggingface_conversion_callback(
trainer = Trainer(
model=original_model,
device='gpu',
fsdp_config=fsdp_config,
parallelism_config={'fsdp': fsdp_config},
train_dataloader=train_dataloader,
save_folder=os.path.join(tmp_path, 'checkpoints'),
save_interval=save_interval,
125 changes: 125 additions & 0 deletions tests/eval/test_eval_deprecation.py
@@ -0,0 +1,125 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import unittest
import warnings

from llmfoundry.command_utils.eval import evaluate_model
from llmfoundry.utils.warnings import VersionedDeprecationWarning


class TestEvaluateModelDeprecation(unittest.TestCase):

def setUp(self):
self.common_args = { # type: ignore
'tokenizer': {
'name': 'test_tokenizer',
},
'model': {
'name': 'test_model',
},
'model_name': 'test',
'dist_timeout': 60,
'run_name': 'test_run',
'seed': 42,
'icl_tasks': [],
'max_seq_len': 512,
'device_eval_batch_size': 1,
'eval_gauntlet_config': None,
'eval_loader_config': None,
'loggers': [],
'python_log_level': None,
'precision': 'fp32',
'eval_gauntlet_df': None,
'eval_subset_num_batches': 1,
'icl_subset_num_batches': None,
'callback_configs': None,
'metadata': None,
'logged_config': {},
}

def test_no_deprecation_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
import composer.utils.parallelism
deprecated_fsdp_args = list(
composer.utils.parallelism.FSDPConfig.__annotations__.keys(),
)
print(deprecated_fsdp_args)

try:
parallelism_config = {'fsdp': {'verbose': True}}
evaluate_model(
**self.common_args,
parallelism_config=parallelism_config,
)
except ValueError as ve:
if 'parallelism_config cannot contain deprecated fsdp_config arguments.' in str(
ve,
):
self.fail(
'Raised ValueError about deprecated fsdp_config arguments',
)
elif 'Both fsdp_config and parallelism_config cannot be provided at the same time.' in str(
ve,
):
self.fail(
'Raised ValueError about both configs being provided',
)
except Exception:
pass

deprecation_warnings = [
warning for warning in w
if isinstance(warning.message, VersionedDeprecationWarning)
]
if deprecation_warnings:
self.fail('VersionedDeprecationWarning was raised')

def test_deprecation_warning_with_deprecated_arg(self):
# Use assertRaises to catch the expected ValueError
with self.assertRaises(ValueError) as context:
# Directly call evaluate_model; do not use try-except here
evaluate_model(
**self.common_args,
parallelism_config={'activation_checkpointing': True},
)

# Assert that the correct error message is in the exception
self.assertIn(
'parallelism_config cannot contain deprecated fsdp_config arguments.',
str(context.exception),
)

def test_deprecation_warning_with_fsdp_config(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')

try:
evaluate_model(
**self.common_args,
parallelism_config=None,
fsdp_config={'verbose': True},
)
except Exception:
pass

self.assertTrue(
any(
issubclass(warning.category, VersionedDeprecationWarning)
for warning in w
),
)

def test_error_with_both_fsdp_and_parallelism_config(self):
with self.assertRaises(ValueError) as context:
evaluate_model(
**self.common_args,
parallelism_config={'some_arg': True},
fsdp_config={'some_arg': True},
)

self.assertIn(
'Both fsdp_config and parallelism_config cannot be provided at the same time.',
str(context.exception),
)
2 changes: 1 addition & 1 deletion tests/models/hf/test_fsdp_weight_tying.py
@@ -91,7 +91,7 @@ def test_fsdp_weight_tying(
trainer = Trainer(
model=original_model,
device='gpu',
fsdp_config=fsdp_config,
parallelism_config={'fsdp': fsdp_config},
train_dataloader=[],
device_train_microbatch_size=1,
)
2 changes: 1 addition & 1 deletion tests/models/hf/test_hf_peft_wrapping.py
@@ -125,7 +125,7 @@ def test_lora_mixed_init(
trainer = Trainer(
model=original_model,
device='gpu',
fsdp_config=fsdp_config,
parallelism_config={'fsdp': fsdp_config},
train_dataloader=[],
device_train_microbatch_size=1,
)
2 changes: 1 addition & 1 deletion tests/models/test_fsdp_act_checkpoint.py
@@ -59,7 +59,7 @@ def test_fsdp_act_checkpoint(
trainer = Trainer(
model=model,
device='gpu',
fsdp_config=fsdp_config,
parallelism_config={'fsdp': fsdp_config},
)

assert trainer.state.fsdp_enabled