Add transformation hooks to hf_causal_lm (#1383)
irenedea authored Jul 23, 2024
1 parent d2d29ad commit cefd616
Showing 3 changed files with 139 additions and 51 deletions.
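
This commit replaces the inline config-override logic in hf_causal_lm.py with a shared set_config_overrides helper and exposes two overridable hooks on ComposerHFCausalLM: transform_model, which runs after the base model is constructed but before prepare_inner_model is called, and get_peft_config, which is now a public instance method. Below is a minimal sketch, not part of the commit, of how a downstream subclass might use the new hook; the class name and the embedding-freezing policy are illustrative assumptions.

# Minimal sketch (not part of this commit): a hypothetical subclass that uses the
# new transform_model hook. The class name and freezing policy are assumptions.
from transformers import PreTrainedModel

from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM


class FrozenEmbeddingHFCausalLM(ComposerHFCausalLM):
    """Freezes the input embeddings after the base model is constructed."""

    def transform_model(self, model: PreTrainedModel) -> PreTrainedModel:
        # The hook runs before prepare_inner_model, so parameter flags set here
        # are in place before any FSDP wrapping happens.
        model.get_input_embeddings().weight.requires_grad_(False)
        return model

Such a subclass takes the same constructor arguments as ComposerHFCausalLM (pretrained_model_name_or_path, tokenizer, an optional peft_config, and so on), as exercised by the new test file below.
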
llmfoundry/models/hf/hf_causal_lm.py (72 changes: 21 additions & 51 deletions)
@@ -11,7 +11,6 @@
Any,
Dict,
List,
Mapping,
Optional,
Tuple,
Union,
@@ -23,7 +22,6 @@
from transformers import (
AutoConfig,
AutoModelForCausalLM,
PretrainedConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
)
@@ -36,7 +34,7 @@
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP
from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.utils import init_empty_weights
from llmfoundry.utils.config_utils import get_hf_config_value
from llmfoundry.utils.config_utils import set_config_overrides

if TYPE_CHECKING:
from peft import PeftConfig, PeftModel
@@ -105,9 +103,13 @@ def __init__(
config_overrides=config_overrides,
load_in_8bit=load_in_8bit,
pretrained=pretrained,
prepare_for_fsdp=True,
prepare_for_fsdp=False,
)

model = self.transform_model(model)

ComposerHFCausalLM.prepare_inner_model(model, init_device)

train_metrics, eval_metrics = ComposerHFCausalLM.build_metrics(
use_train_metrics=use_train_metrics,
additional_train_metrics=additional_train_metrics,
@@ -121,7 +123,7 @@

peft_config_object = None
if peft_config is not None:
peft_config_object = self._get_peft_config(peft_config)
peft_config_object = self.get_peft_config(peft_config)

# Set up config args for the model construction and base classes
super().__init__(
@@ -135,6 +137,17 @@
should_save_peft_only=should_save_peft_only,
)

def transform_model(self, model: PreTrainedModel) -> PreTrainedModel:
"""Transforms the model after initialization.

Args:
model (PreTrainedModel): The model to transform.

Returns:
PreTrainedModel: The transformed model.
"""
return model

@staticmethod
def build_metrics(
use_train_metrics: bool,
@@ -259,50 +272,7 @@ def _autoset_attn_implementation_monkeypatch(
_autoset_attn_implementation_monkeypatch,
)

# set config overrides
for k, v in config_overrides.items():
if not hasattr(config, k):
raise ValueError(
f'config does not have attribute "{k}" to override ({k}: {v}).',
)

attr = getattr(config, k)
# attempt to disallow typos in nested configs
if isinstance(attr, Mapping):
extra_keys = [_k for _k in v.keys() if _k not in attr.keys()]
if extra_keys:
raise ValueError(
f'Config dict override got unknown keys. ' +
f'Extra keys: {extra_keys}. ' +
f'Expected (a subset of) keys: {list(attr.keys())}.',
)
getattr(config, k).update(v)
# necessary case to allow for rope_scaling to be overridden in llama config
elif attr is None and isinstance(v, Mapping):
setattr(config, k, {})
getattr(config, k).update(v)
elif isinstance(attr, PretrainedConfig):
if not isinstance(v, Mapping):
raise ValueError(
f'Expected a dictionary for config override {k}, but got {v}.',
)

for _k, _v in v.items():
if not hasattr(attr, _k):
raise ValueError(
f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).',
)
setattr(attr, _k, _v)
else:
setattr(config, k, v)

if hasattr(config, 'attn_config') and get_hf_config_value(
config.attn_config,
'seq_parallel_world_size',
) is not None:
raise NotImplementedError(
'Sequence Parallelism is not supported for HuggingFace models.',
)
set_config_overrides(config, config_overrides)

# We need to have all non-zero local ranks be not-pretrained
# Rank 0 will still be pretrained, and distribute the weights appropriately
@@ -395,10 +365,10 @@ def _autoset_attn_implementation_monkeypatch(

if prepare_for_fsdp:
ComposerHFCausalLM.prepare_inner_model(model, init_device)

return model

@staticmethod
def _get_peft_config(peft_config_dict: Dict[str, Any]) -> 'PeftConfig':
def get_peft_config(self, peft_config_dict: Dict[str, Any]) -> 'PeftConfig':
if peft_installed:
from peft import LoraConfig
peft_type = peft_config_dict.get('peft_type', '')
llmfoundry/utils/config_utils.py (42 changes: 42 additions & 0 deletions)
@@ -812,3 +812,45 @@ def _verify_uc_path(path: str) -> bool:
f'but your `UCVolumeDatasetSource` might be invalid.',
)
return False


def set_config_overrides(
config: PretrainedConfig,
config_overrides: Dict[str, Any],
):
# set config overrides
for k, v in config_overrides.items():
if not hasattr(config, k):
raise ValueError(
f'config does not have attribute "{k}" to override ({k}: {v}).',
)

attr = getattr(config, k)
# attempt to disallow typos in nested configs
if isinstance(attr, Mapping):
extra_keys = [_k for _k in v.keys() if _k not in attr.keys()]
if extra_keys:
raise ValueError(
f'Config dict override got unknown keys. ' +
f'Extra keys: {extra_keys}. ' +
f'Expected (a subset of) keys: {list(attr.keys())}.',
)
getattr(config, k).update(v)
# necessary case to allow for rope_scaling to be overridden in llama config
elif attr is None and isinstance(v, Mapping):
setattr(config, k, {})
getattr(config, k).update(v)
elif isinstance(attr, PretrainedConfig):
if not isinstance(v, Mapping):
raise ValueError(
f'Expected a dictionary for config override {k}, but got {v}.',
)

for _k, _v in v.items():
if not hasattr(attr, _k):
raise ValueError(
f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).',
)
setattr(attr, _k, _v)
else:
setattr(config, k, v)
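
set_config_overrides preserves the validation behavior of the code it replaces: flat overrides must name an existing attribute, dict overrides of nested configs reject unknown keys, and an attribute that defaults to None (such as rope_scaling on Llama configs) is replaced by the override dict. A minimal usage sketch, not part of the commit, assuming transformers is installed; LlamaConfig is used purely for illustration:

# Minimal sketch (not from this commit) of calling the new helper directly.
from transformers import LlamaConfig

from llmfoundry.utils.config_utils import set_config_overrides

config = LlamaConfig()

# Flat override of an existing attribute.
set_config_overrides(config, {'num_hidden_layers': 2})

# rope_scaling defaults to None, so the dict override replaces it wholesale.
set_config_overrides(config, {'rope_scaling': {'type': 'linear', 'factor': 2.0}})

assert config.num_hidden_layers == 2
assert config.rope_scaling == {'type': 'linear', 'factor': 2.0}

# A misspelled attribute name raises ValueError instead of being silently ignored.
try:
    set_config_overrides(config, {'num_hidden_layerz': 4})
except ValueError as err:
    print(err)
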
tests/models/hf/test_hf_transform.py (76 changes: 76 additions & 0 deletions)
@@ -0,0 +1,76 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, Optional

import pytest
from composer.models.huggingface import maybe_get_underlying_model
from peft import PeftConfig, PeftModel
from transformers import LlamaForCausalLM, PreTrainedModel

from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM
from llmfoundry.models.utils import init_empty_weights


@pytest.mark.gpu
@pytest.mark.parametrize(
'peft_config',
[
None,
{
'peft_type': 'LORA',
'task_type': 'CAUSAL_LM',
'lora_alpha': 32,
'r': 2,
'target_modules': [
'q_proj',
'k_proj',
'v_proj',
],
},
],
)
def test_hf_transform(peft_config: Optional[dict]):
model_cfg = {
'pretrained_model_name_or_path': 'codellama/CodeLlama-7b-hf',
'config_overrides': {
'num_hidden_layers': 2,
'hidden_size': 32,
'intermediate_size': 64,
},
'pretrained': False,
'peft_config': peft_config,
'init_device': 'meta',
'tokenizer': 'codellama/CodeLlama-7b-hf',
}

class TransformedHFCausalLM(ComposerHFCausalLM):

def transform_model(self, model: PreTrainedModel) -> PreTrainedModel:
assert isinstance(model, LlamaForCausalLM)
with init_empty_weights():
model.config.num_hidden_layers = 1
new_model = type(model)(model.config)
return new_model

def get_peft_config(
self,
peft_config_dict: Dict[str, Any],
) -> PeftConfig:
peft_config_dict['target_modules'] = ['o_proj']
return super().get_peft_config(peft_config_dict)

composer_model = TransformedHFCausalLM(**model_cfg)
model = composer_model.model
inner_model = maybe_get_underlying_model(model)

if peft_config:
peft_model = composer_model.model
assert isinstance(peft_model, PeftModel)

target_modules = peft_model.peft_config[peft_model.active_adapter
].target_modules
assert list(target_modules) == ['o_proj']

assert isinstance(inner_model, LlamaForCausalLM)
assert inner_model.config.num_hidden_layers == 1
