diff --git a/tests/test_generator_creation.py b/tests/test_generator_creation.py index e5cb560..e1767f4 100644 --- a/tests/test_generator_creation.py +++ b/tests/test_generator_creation.py @@ -4,14 +4,14 @@ from rigging.generator import GenerateParams, LiteLLMGenerator, get_generator -@pytest.mark.parametrize("identifier", ["test_model", "litellm:test_model"]) +@pytest.mark.parametrize("identifier", ["test_model", "litellm!test_model"]) def test_get_generator_default_is_litellm(identifier: str) -> None: generator = get_generator(identifier) assert isinstance(generator, LiteLLMGenerator) assert generator.model == "test_model" -@pytest.mark.parametrize("identifier", ["invalid:testing", "no_exist:stuff,args=123"]) +@pytest.mark.parametrize("identifier", ["invalid!testing", "no_exist!stuff,args=123"]) def test_get_generator_invalid_provider(identifier: str) -> None: with pytest.raises(InvalidModelSpecifiedError): get_generator(identifier) @@ -20,9 +20,9 @@ def test_get_generator_invalid_provider(identifier: str) -> None: @pytest.mark.parametrize( "identifier, valid_params", [ - ("litellm:test_model,max_tokens=123,top_p=10", GenerateParams(max_tokens=123, top_p=10)), - ("litellm:test_model,temperature=0.5", GenerateParams(temperature=0.5)), - ("litellm:test_model,max_tokens=100,temperature=1.0", GenerateParams(max_tokens=100, temperature=1.0)), + ("litellm!test_model,max_tokens=123,top_p=10", GenerateParams(max_tokens=123, top_p=10)), + ("litellm!test_model,temperature=0.5", GenerateParams(temperature=0.5)), + ("test_model,max_tokens=100,temperature=1.0", GenerateParams(max_tokens=100, temperature=1.0)), ], ) def test_get_generator_with_params(identifier: str, valid_params: GenerateParams) -> None: