Skip to content

Commit

Permalink
Add automated unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed Jul 18, 2024
1 parent dbc0035 commit 487e19f
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions tests/llmcompressor/recipe/test_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import pytest
import yaml

from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.obcq.base import SparseGPTModifier
from llmcompressor.recipe import Recipe
from llmcompressor.recipe.recipe import create_recipe_string_from_modifiers
from tests.llmcompressor.helpers import valid_recipe_strings


Expand Down Expand Up @@ -96,3 +98,45 @@ def test_recipe_can_be_created_from_modifier_instances():
):
assert isinstance(actual_modifier, type(expected_modifier))
assert actual_modifier.dict() == expected_modifier.dict()


class A_FirstDummyModifier(Modifier):
def model_dump(self):
return {}


class B_SecondDummyModifier(Modifier):
def model_dump(self):
return {}


def test_create_recipe_string_from_modifiers_with_default_group_name():
modifiers = [B_SecondDummyModifier(), A_FirstDummyModifier()]
expected_recipe_str = (
"DEFAULT_stage:\n"
" DEFAULT_modifiers:\n"
" B_SecondDummyModifier: {}\n"
" A_FirstDummyModifier: {}\n"
)
actual_recipe_str = create_recipe_string_from_modifiers(modifiers)
assert actual_recipe_str == expected_recipe_str


def test_create_recipe_string_from_modifiers_with_custom_group_name():
modifiers = [B_SecondDummyModifier(), A_FirstDummyModifier()]
group_name = "custom"
expected_recipe_str = (
"custom_stage:\n"
" DEFAULT_modifiers:\n"
" B_SecondDummyModifier: {}\n"
" A_FirstDummyModifier: {}\n"
)
actual_recipe_str = create_recipe_string_from_modifiers(modifiers, group_name)
assert actual_recipe_str == expected_recipe_str


def test_create_recipe_string_from_modifiers_with_empty_modifiers():
modifiers = []
expected_recipe_str = "DEFAULT_stage:\n" " DEFAULT_modifiers: {}\n"
actual_recipe_str = create_recipe_string_from_modifiers(modifiers)
assert actual_recipe_str == expected_recipe_str

0 comments on commit 487e19f

Please sign in to comment.