Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
monoxgas committed Feb 1, 2024
1 parent b809e8a commit 10f6af5
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions tests/test_generator_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from rigging.generator import GenerateParams, LiteLLMGenerator, get_generator


@pytest.mark.parametrize("identifier", ["test_model", "litellm:test_model"])
@pytest.mark.parametrize("identifier", ["test_model", "litellm!test_model"])
def test_get_generator_default_is_litellm(identifier: str) -> None:
generator = get_generator(identifier)
assert isinstance(generator, LiteLLMGenerator)
assert generator.model == "test_model"


@pytest.mark.parametrize("identifier", ["invalid:testing", "no_exist:stuff,args=123"])
@pytest.mark.parametrize("identifier", ["invalid!testing", "no_exist!stuff,args=123"])
def test_get_generator_invalid_provider(identifier: str) -> None:
with pytest.raises(InvalidModelSpecifiedError):
get_generator(identifier)
Expand All @@ -20,9 +20,9 @@ def test_get_generator_invalid_provider(identifier: str) -> None:
@pytest.mark.parametrize(
"identifier, valid_params",
[
("litellm:test_model,max_tokens=123,top_p=10", GenerateParams(max_tokens=123, top_p=10)),
("litellm:test_model,temperature=0.5", GenerateParams(temperature=0.5)),
("litellm:test_model,max_tokens=100,temperature=1.0", GenerateParams(max_tokens=100, temperature=1.0)),
("litellm!test_model,max_tokens=123,top_p=10", GenerateParams(max_tokens=123, top_p=10)),
("litellm!test_model,temperature=0.5", GenerateParams(temperature=0.5)),
("test_model,max_tokens=100,temperature=1.0", GenerateParams(max_tokens=100, temperature=1.0)),
],
)
def test_get_generator_with_params(identifier: str, valid_params: GenerateParams) -> None:
Expand Down

0 comments on commit 10f6af5

Please sign in to comment.