Refactor Recipe creation flow: Directly convert Modifier Instances to Recipe #48

Open · wants to merge 3 commits into base: main
15 changes: 8 additions & 7 deletions src/llmcompressor/recipe/modifier.py
@@ -83,13 +83,14 @@ def create_modifier(self) -> "Modifier":
@model_validator(mode="before")
@classmethod
def extract_modifier_type(cls, values: Dict[str, Any]) -> Dict[str, Any]:
modifier = {"group": values.pop("group")}
assert len(values) == 1, "multiple key pairs found for modifier"
modifier_type, args = list(values.items())[0]

modifier["type"] = modifier_type
modifier["args"] = args
return modifier
if len(values) == 2:
# values contains only group and the Modifier type as keys
group = values.pop("group")
modifier_type, args = values.popitem()
return {"group": group, "type": modifier_type, "args": args}

# values already in the correct format
return values

def dict(self, *args, **kwargs) -> Dict[str, Any]:
"""
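Note on the validator change above: extract_modifier_type now handles both the raw mapping produced when a recipe is parsed from YAML (one "group" key plus one modifier-type key) and values that already carry "type" and "args". A minimal sketch of the two accepted shapes, with illustrative modifier arguments:

# Raw form coming from YAML parsing: "group" plus a single modifier-type key.
raw = {"group": "quant", "GPTQModifier": {"scheme": "W8A8", "targets": "Linear"}}
# -> normalized to {"group": "quant", "type": "GPTQModifier", "args": {...}}

# Already-normalized form, e.g. when a RecipeModifier is constructed directly:
normalized = {"group": "quant", "type": "GPTQModifier", "args": {"scheme": "W8A8"}}
# -> returned unchanged by the validator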
87 changes: 18 additions & 69 deletions src/llmcompressor/recipe/recipe.py
@@ -12,6 +12,7 @@
from llmcompressor.recipe.args import RecipeArgs
from llmcompressor.recipe.base import RecipeBase
from llmcompressor.recipe.metadata import RecipeMetaData
from llmcompressor.recipe.modifier import RecipeModifier
from llmcompressor.recipe.stage import RecipeStage

__all__ = ["Recipe", "RecipeTuple"]
@@ -55,20 +56,29 @@ def from_modifiers(
"""
logger.info("Creating recipe from modifiers")

# validate Modifiers
if isinstance(modifiers, Modifier):
modifiers: List[Modifier] = [modifiers]
modifiers = [modifiers]

if any(not isinstance(modifier, Modifier) for modifier in modifiers):
raise ValueError("modifiers must be a list of Modifier instances")

recipe_string: str = create_recipe_string_from_modifiers(
modifiers=modifiers,
modifier_group_name=modifier_group_name,
)
group_name = modifier_group_name or "default"

# modifier group name already included in the recipe string
return cls.create_instance(path_or_modifiers=recipe_string)
recipe_modifiers: List[RecipeModifier] = [
RecipeModifier(
type=modifier.__class__.__name__,
group=group_name,
args=modifier.model_dump(exclude_unset=True),
)
for modifier in modifiers
]
# assume one stage for modifier instances
stages: List[RecipeStage] = [
RecipeStage(group=group_name, modifiers=recipe_modifiers)
]
recipe = Recipe()
recipe.stages = stages
return recipe

@classmethod
def create_instance(
@@ -638,67 +648,6 @@ def _parse_recipe_from_md(file_path, yaml_str):
return yaml_str


def create_recipe_string_from_modifiers(
modifiers: List[Modifier],
modifier_group_name: Optional[str] = None,
) -> str:
"""
Create a recipe string from a list of Modifier instances

(Note: this pathway assumes there's only one stage in the recipe
associated by the modifier_group_name, if None, a dummy default
group_name will be assigned.)

:param modifiers: The list of Modifier instances
:param modifier_group_name: The stage_name of the recipe,
if `oneshot` or `train` the run_type of the recipe will be
inferred from the modifier_group_name, if None, a dummy default
group_name will be assigned.
:return: A string in yaml format from which the recipe can be created
"""

# Recipe(s) are yaml/json strings of the following format:
# run_type_stage: # should contain oneshot/train
# modifiers:
# ModifierTypeOne:
# start: 0.0
# end: 2.0
# ...
# ModifierTypeTwo:
# ...

# Create a recipe string from the modifiers
default_group_name: str = "DEFAULT"
modifier_group_name: str = modifier_group_name or default_group_name

recipe_dict = {
f"{modifier_group_name}_stage": {
f"{default_group_name}_modifiers": {
modifier.__class__.__name__: modifier.model_dump(exclude_unset=True)
for modifier in modifiers
}
}
}
recipe_str: str = yaml.dump(recipe_dict, sort_keys=False)
return recipe_str


def get_modifiers_dict(modifiers: List[Dict[str, Any]]) -> Dict[str, Any]:
group_dict = {}

for modifier in modifiers:
modifier_type = modifier["type"]
modifier_group = modifier["group"]

if modifier_group not in group_dict:
group_dict[modifier_group] = []

modifier_dict = {modifier_type: modifier["args"]}
group_dict[modifier_group].append(modifier_dict)

return group_dict


def get_yaml_serializable_stage_dict(modifiers: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
This function is used to convert a list of modifiers into a dictionary
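For context on the from_modifiers change above: the refactor builds the Recipe object directly from RecipeModifier and RecipeStage instances instead of round-tripping through a YAML string (the deleted create_recipe_string_from_modifiers path). A minimal usage sketch; the modifier arguments are illustrative and taken from the tests below:

from llmcompressor.modifiers.quantization.gptq import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.recipe import Recipe

modifiers = [
    SmoothQuantModifier(smoothing_strength=0.8),
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]

# Builds one RecipeStage named "quant" containing a RecipeModifier per instance,
# with args taken from model_dump(exclude_unset=True).
recipe = Recipe.from_modifiers(modifiers, modifier_group_name="quant")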
40 changes: 20 additions & 20 deletions src/llmcompressor/recipe/stage.py
@@ -161,26 +161,26 @@ def extract_dict_modifiers(values: Dict[str, Any]) -> List[Dict[str, Any]]:
"""

modifiers = []
remove_keys = []

if "modifiers" in values and values["modifiers"]:
remove_keys.append("modifiers")
for mod_key, mod_value in values["stages"].items():
modifier = {mod_key: mod_value}
modifier["group"] = "default"
modifiers.append(modifier)

for key, value in list(values.items()):
if key.endswith("_modifiers"):
remove_keys.append(key)
group = key.rsplit("_modifiers", 1)[0]
for mod_key, mod_value in value.items():
modifier = {mod_key: mod_value}
modifier["group"] = group
modifiers.append(modifier)

for key in remove_keys:
del values[key]

if "modifiers" in values:
modifier_values = values.pop("modifiers")
if "stages" in values:
for mod_key, mod_value in values.pop("stages").items():
modifiers.append({mod_key: mod_value, "group": "default"})
else:
values["default_stage"] = {
"default_modifiers": {mod.type: mod.args for mod in modifier_values}
}
modifiers.extend(
{mod.type: mod.args, "group": "default"} for mod in modifier_values
)

for key in [k for k in values if k.endswith("_modifiers")]:
group = key.rsplit("_modifiers", 1)[0]
modifiers.extend(
{mod_key: mod_value, "group": group}
for mod_key, mod_value in values.pop(key).items()
)

return modifiers

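The rewritten extract_dict_modifiers keeps the same output contract as before: a flat list of single-modifier dicts, each tagged with its "group". A small illustration of the suffix-based grouping, using a stage dict shaped like the test recipes in this PR:

# Hypothetical stage values as they appear after YAML parsing:
values = {
    "smoothquant_modifiers": {
        "SmoothQuantModifier": {"smoothing_strength": 0.5},
    },
}
# extract_dict_modifiers(values) yields:
# [{"SmoothQuantModifier": {"smoothing_strength": 0.5}, "group": "smoothquant"}]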
17 changes: 17 additions & 0 deletions tests/e2e/recipe.yaml
@@ -0,0 +1,17 @@
quant_stage:
quant_modifiers:
SmoothQuantModifier:
smoothing_strength: 0.8
mappings:
- - ['re:.*q_proj', 're:.*k_proj', 're:.*v_proj']
- re:.*input_layernorm
- - ['re:.*gate_proj', 're:.*up_proj']
- re:.*post_attention_layernorm
GPTQModifier:
sequential_update: false
ignore: [lm_head]
config_groups:
group_0:
weights: {num_bits: 8, type: int, symmetric: true, strategy: channel}
input_activations: {num_bits: 8, symmetric: false}
targets: [Linear]
94 changes: 94 additions & 0 deletions tests/e2e/test_recipe_parsing.py
@@ -0,0 +1,94 @@
from pathlib import Path

import pytest

from llmcompressor.core.session_functions import reset_session
from llmcompressor.modifiers.quantization.gptq import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.modifiers.smoothquant.base import DEFAULT_SMOOTHQUANT_MAPPINGS
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
from tests.testing_utils import requires_gpu


@pytest.fixture
def common_setup():
model_stub = "Xenova/llama2.c-stories15M"
model = SparseAutoModelForCausalLM.from_pretrained(
model_stub, device_map="auto", torch_dtype="auto"
)

dataset = "ultrachat-200k"
output_dir = "./test_output"
splits = {"calibration": "train_gen[:5%]"}
max_seq_length = 2048
pad_to_max_length = False
num_calibration_samples = 8

return (
model,
dataset,
output_dir,
splits,
max_seq_length,
pad_to_max_length,
num_calibration_samples,
)


def recipes():
modifier_objects = [
SmoothQuantModifier(
smoothing_strength=0.8, mappings=DEFAULT_SMOOTHQUANT_MAPPINGS
),
GPTQModifier(
targets="Linear", scheme="W8A8", ignore=["lm_head"], sequential_update=False
),
]

recipe_str = """
DEFAULT_stage:
DEFAULT_modifiers:
SmoothQuantModifier:
smoothing_strength: 0.8
mappings:
- - ['re:.*q_proj', 're:.*k_proj', 're:.*v_proj']
- re:.*input_layernorm
- - ['re:.*gate_proj', 're:.*up_proj']
- re:.*post_attention_layernorm
GPTQModifier:
sequential_update: false
targets: Linear
scheme: W8A8
"""

recipe_file = str(Path(__file__).parent / "recipe.yaml")

return [modifier_objects, recipe_str, recipe_file]


@requires_gpu
@pytest.mark.parametrize("recipe", recipes())
def test_oneshot(common_setup, recipe):
(
model,
dataset,
output_dir,
splits,
max_seq_length,
pad_to_max_length,
num_calibration_samples,
) = common_setup

oneshot(
model=model,
dataset=dataset,
recipe=recipe,
output_dir=output_dir,
splits=splits,
max_seq_length=max_seq_length,
pad_to_max_length=pad_to_max_length,
num_calibration_samples=num_calibration_samples,
save_compressed=True,
)

reset_session()
13 changes: 13 additions & 0 deletions tests/llmcompressor/helpers.py
@@ -1,3 +1,6 @@
# flake8: noqa


def valid_recipe_strings():
return [
"""
@@ -52,4 +55,14 @@ def valid_recipe_strings():
final_sparsity: 0.5
targets: __ALL_PRUNABLE__
""",
"""
test1_stage:
smoothquant_modifiers:
SmoothQuantModifier:
smoothing_strength: 0.5
mappings: [
[["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"],
[["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"]
]
""",
]
46 changes: 1 addition & 45 deletions tests/llmcompressor/recipe/test_recipe.py
@@ -3,10 +3,8 @@
import pytest
import yaml

from llmcompressor.modifiers import Modifier
Contributor: can we add a new test for a list of modifiers?

Collaborator (Author): Added Ben!

from llmcompressor.modifiers.obcq.base import SparseGPTModifier
from llmcompressor.recipe import Recipe
from llmcompressor.recipe.recipe import create_recipe_string_from_modifiers
from tests.llmcompressor.helpers import valid_recipe_strings


@@ -97,46 +95,4 @@ def test_recipe_can_be_created_from_modifier_instances():
actual_modifiers[0].modifiers, expected_modifiers[0].modifiers
):
assert isinstance(actual_modifier, type(expected_modifier))
assert actual_modifier.dict() == expected_modifier.dict()


class A_FirstDummyModifier(Modifier):
def on_initialize(self, *args, **kwargs) -> bool:
return True


class B_SecondDummyModifier(Modifier):
def on_initialize(self, *args, **kwargs) -> bool:
return True


def test_create_recipe_string_from_modifiers_with_default_group_name():
modifiers = [B_SecondDummyModifier(), A_FirstDummyModifier()]
expected_recipe_str = (
"DEFAULT_stage:\n"
" DEFAULT_modifiers:\n"
" B_SecondDummyModifier: {}\n"
" A_FirstDummyModifier: {}\n"
)
actual_recipe_str = create_recipe_string_from_modifiers(modifiers)
assert actual_recipe_str == expected_recipe_str


def test_create_recipe_string_from_modifiers_with_custom_group_name():
modifiers = [B_SecondDummyModifier(), A_FirstDummyModifier()]
group_name = "custom"
expected_recipe_str = (
"custom_stage:\n"
" DEFAULT_modifiers:\n"
" B_SecondDummyModifier: {}\n"
" A_FirstDummyModifier: {}\n"
)
actual_recipe_str = create_recipe_string_from_modifiers(modifiers, group_name)
assert actual_recipe_str == expected_recipe_str


def test_create_recipe_string_from_modifiers_with_empty_modifiers():
modifiers = []
expected_recipe_str = "DEFAULT_stage:\n" " DEFAULT_modifiers: {}\n"
actual_recipe_str = create_recipe_string_from_modifiers(modifiers)
assert actual_recipe_str == expected_recipe_str
assert actual_modifier.model_dump() == expected_modifier.model_dump()