From 071865c82f54c603b294ba73f9d3ea3499149b0d Mon Sep 17 00:00:00 2001
From: Justin Zhao
Date: Fri, 11 Aug 2023 13:59:21 -0400
Subject: [PATCH] Add check that the state of the model object is unchanged.

---
 tests/integration_tests/test_llm.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tests/integration_tests/test_llm.py b/tests/integration_tests/test_llm.py
index 90af329b5c7..cb9981329b8 100644
--- a/tests/integration_tests/test_llm.py
+++ b/tests/integration_tests/test_llm.py
@@ -29,6 +29,7 @@
 
 LOCAL_BACKEND = {"type": "local"}
 TEST_MODEL_NAME = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+MAX_NEW_TOKENS_TEST_DEFAULT = 5
 
 RAY_BACKEND = {
     "type": "ray",
@@ -84,7 +85,7 @@ def get_generation_config():
         "top_p": 0.75,
         "top_k": 40,
         "num_beams": 4,
-        "max_new_tokens": 5,
+        "max_new_tokens": MAX_NEW_TOKENS_TEST_DEFAULT,
     }
 
 
@@ -141,6 +142,7 @@ def test_llm_text_to_text(tmpdir, backend, ray_cluster_4cpu):
 
     # Check that in-line generation parameters are used. Original prediction uses max_new_tokens = 5.
     assert get_num_non_empty_tokens(preds["Answer_predictions"][0]) > 3
+    original_max_new_tokens = model.model.generation.max_new_tokens
 
     # This prediction uses max_new_tokens = 2.
     preds, _ = model.predict(
@@ -149,6 +151,9 @@
     preds = convert_preds(preds)
     assert get_num_non_empty_tokens(preds["Answer_predictions"][0]) < 3
 
+    # Check that the state of the model is unchanged.
+    assert model.model.generation.max_new_tokens == original_max_new_tokens
+
 
 @pytest.mark.llm
 @pytest.mark.parametrize(