Skip to content

Commit

Permalink
Refactor Recipe creation flow: Directly convert Modifier Instances to…
Browse files Browse the repository at this point in the history
… Recipe
  • Loading branch information
rahul-tuli committed Aug 1, 2024
1 parent daad632 commit a1a2bcf
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 139 deletions.
15 changes: 8 additions & 7 deletions src/llmcompressor/recipe/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,14 @@ def create_modifier(self) -> "Modifier":
@model_validator(mode="before")
@classmethod
def extract_modifier_type(cls, values: Dict[str, Any]) -> Dict[str, Any]:
modifier = {"group": values.pop("group")}
assert len(values) == 1, "multiple key pairs found for modifier"
modifier_type, args = list(values.items())[0]

modifier["type"] = modifier_type
modifier["args"] = args
return modifier
if len(values) == 2:
# values contains only group and the Modifier type as keys
group = values.pop("group")
modifier_type, args = values.popitem()
return {"group": group, "type": modifier_type, "args": args}

# values already in the correct format
return values

def dict(self, *args, **kwargs) -> Dict[str, Any]:
"""
Expand Down
87 changes: 18 additions & 69 deletions src/llmcompressor/recipe/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from llmcompressor.recipe.args import RecipeArgs
from llmcompressor.recipe.base import RecipeBase
from llmcompressor.recipe.metadata import RecipeMetaData
from llmcompressor.recipe.modifier import RecipeModifier
from llmcompressor.recipe.stage import RecipeStage

__all__ = ["Recipe", "RecipeTuple"]
Expand Down Expand Up @@ -55,20 +56,29 @@ def from_modifiers(
"""
logger.info("Creating recipe from modifiers")

# validate Modifiers
if isinstance(modifiers, Modifier):
modifiers: List[Modifier] = [modifiers]
modifiers = [modifiers]

if any(not isinstance(modifier, Modifier) for modifier in modifiers):
raise ValueError("modifiers must be a list of Modifier instances")

recipe_string: str = create_recipe_string_from_modifiers(
modifiers=modifiers,
modifier_group_name=modifier_group_name,
)
group_name = modifier_group_name or "default"

# modifier group name already included in the recipe string
return cls.create_instance(path_or_modifiers=recipe_string)
recipe_modifiers: List[RecipeModifier] = [
RecipeModifier(
type=modifier.__class__.__name__,
group=group_name,
args=modifier.model_dump(exclude_unset=True),
)
for modifier in modifiers
]
# assume one stage for modifier instances
stages: List[RecipeStage] = [
RecipeStage(group=group_name, modifiers=recipe_modifiers)
]
recipe = Recipe()
recipe.stages = stages
return recipe

@classmethod
def create_instance(
Expand Down Expand Up @@ -631,67 +641,6 @@ def _parse_recipe_from_md(file_path, yaml_str):
return yaml_str


def create_recipe_string_from_modifiers(
modifiers: List[Modifier],
modifier_group_name: Optional[str] = None,
) -> str:
"""
Create a recipe string from a list of Modifier instances
(Note: this pathway assumes there's only one stage in the recipe
associated by the modifier_group_name, if None, a dummy default
group_name will be assigned.)
:param modifiers: The list of Modifier instances
:param modifier_group_name: The stage_name of the recipe,
if `oneshot` or `train` the run_type of the recipe will be
inferred from the modifier_group_name, if None, a dummy default
group_name will be assigned.
:return: A string in yaml format from which the recipe can be created
"""

# Recipe(s) are yaml/json strings of the following format:
# run_type_stage: # should contain oneshot/train
# modifiers:
# ModifierTypeOne:
# start: 0.0
# end: 2.0
# ...
# ModifierTypeTwo:
# ...

# Create a recipe string from the modifiers
default_group_name: str = "DEFAULT"
modifier_group_name: str = modifier_group_name or default_group_name

recipe_dict = {
f"{modifier_group_name}_stage": {
f"{default_group_name}_modifiers": {
modifier.__class__.__name__: modifier.model_dump(exclude_unset=True)
for modifier in modifiers
}
}
}
recipe_str: str = yaml.dump(recipe_dict, sort_keys=False)
return recipe_str


def get_modifiers_dict(modifiers: List[Dict[str, Any]]) -> Dict[str, Any]:
group_dict = {}

for modifier in modifiers:
modifier_type = modifier["type"]
modifier_group = modifier["group"]

if modifier_group not in group_dict:
group_dict[modifier_group] = []

modifier_dict = {modifier_type: modifier["args"]}
group_dict[modifier_group].append(modifier_dict)

return group_dict


def get_yaml_serializable_stage_dict(modifiers: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
This function is used to convert a list of modifiers into a dictionary
Expand Down
40 changes: 20 additions & 20 deletions src/llmcompressor/recipe/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,26 +161,26 @@ def extract_dict_modifiers(values: Dict[str, Any]) -> List[Dict[str, Any]]:
"""

modifiers = []
remove_keys = []

if "modifiers" in values and values["modifiers"]:
remove_keys.append("modifiers")
for mod_key, mod_value in values["stages"].items():
modifier = {mod_key: mod_value}
modifier["group"] = "default"
modifiers.append(modifier)

for key, value in list(values.items()):
if key.endswith("_modifiers"):
remove_keys.append(key)
group = key.rsplit("_modifiers", 1)[0]
for mod_key, mod_value in value.items():
modifier = {mod_key: mod_value}
modifier["group"] = group
modifiers.append(modifier)

for key in remove_keys:
del values[key]

if "modifiers" in values:
modifier_values = values.pop("modifiers")
if "stages" in values:
for mod_key, mod_value in values.pop("stages").items():
modifiers.append({mod_key: mod_value, "group": "default"})
else:
values["default_stage"] = {
"default_modifiers": {mod.type: mod.args for mod in modifier_values}
}
modifiers.extend(
{mod.type: mod.args, "group": "default"} for mod in modifier_values
)

for key in [k for k in values if k.endswith("_modifiers")]:
group = key.rsplit("_modifiers", 1)[0]
modifiers.extend(
{mod_key: mod_value, "group": group}
for mod_key, mod_value in values.pop(key).items()
)

return modifiers

Expand Down
44 changes: 1 addition & 43 deletions tests/llmcompressor/recipe/test_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
import pytest
import yaml

from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.obcq.base import SparseGPTModifier
from llmcompressor.recipe import Recipe
from llmcompressor.recipe.recipe import create_recipe_string_from_modifiers
from tests.llmcompressor.helpers import valid_recipe_strings


Expand Down Expand Up @@ -97,44 +95,4 @@ def test_recipe_can_be_created_from_modifier_instances():
actual_modifiers[0].modifiers, expected_modifiers[0].modifiers
):
assert isinstance(actual_modifier, type(expected_modifier))
assert actual_modifier.dict() == expected_modifier.dict()


class A_FirstDummyModifier(Modifier):
pass


class B_SecondDummyModifier(Modifier):
pass


def test_create_recipe_string_from_modifiers_with_default_group_name():
modifiers = [B_SecondDummyModifier(), A_FirstDummyModifier()]
expected_recipe_str = (
"DEFAULT_stage:\n"
" DEFAULT_modifiers:\n"
" B_SecondDummyModifier: {}\n"
" A_FirstDummyModifier: {}\n"
)
actual_recipe_str = create_recipe_string_from_modifiers(modifiers)
assert actual_recipe_str == expected_recipe_str


def test_create_recipe_string_from_modifiers_with_custom_group_name():
modifiers = [B_SecondDummyModifier(), A_FirstDummyModifier()]
group_name = "custom"
expected_recipe_str = (
"custom_stage:\n"
" DEFAULT_modifiers:\n"
" B_SecondDummyModifier: {}\n"
" A_FirstDummyModifier: {}\n"
)
actual_recipe_str = create_recipe_string_from_modifiers(modifiers, group_name)
assert actual_recipe_str == expected_recipe_str


def test_create_recipe_string_from_modifiers_with_empty_modifiers():
modifiers = []
expected_recipe_str = "DEFAULT_stage:\n" " DEFAULT_modifiers: {}\n"
actual_recipe_str = create_recipe_string_from_modifiers(modifiers)
assert actual_recipe_str == expected_recipe_str
assert actual_modifier.model_dump() == expected_modifier.model_dump()

0 comments on commit a1a2bcf

Please sign in to comment.