From 60a496eb5494289800d1711ee0b18c9175a5a810 Mon Sep 17 00:00:00 2001
From: Justin Zhao
Date: Fri, 11 Aug 2023 13:55:54 -0400
Subject: [PATCH 1/3] Add test for generation config

---
 tests/integration_tests/test_llm.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/tests/integration_tests/test_llm.py b/tests/integration_tests/test_llm.py
index ce852cd1e07..90af329b5c7 100644
--- a/tests/integration_tests/test_llm.py
+++ b/tests/integration_tests/test_llm.py
@@ -46,6 +46,11 @@
 }
 
 
+def get_num_non_empty_tokens(iterable):
+    """Returns the number of non-empty tokens."""
+    return len(list(filter(bool, iterable)))
+
+
 @pytest.fixture(scope="module")
 def local_backend():
     return LOCAL_BACKEND
@@ -134,6 +139,16 @@ def test_llm_text_to_text(tmpdir, backend, ray_cluster_4cpu):
     assert preds["Answer_probabilities"]
     assert preds["Answer_probability"]
 
+    # Check that in-line generation parameters are used. Original prediction uses max_new_tokens = 5.
+    assert get_num_non_empty_tokens(preds["Answer_predictions"][0]) > 3
+
+    # This prediction uses max_new_tokens = 2.
+    preds, _ = model.predict(
+        dataset=dataset_filename, output_directory=str(tmpdir), split="test", generation_config={"max_new_tokens": 2}
+    )
+    preds = convert_preds(preds)
+    assert get_num_non_empty_tokens(preds["Answer_predictions"][0]) < 3
+
 
 @pytest.mark.llm
 @pytest.mark.parametrize(


From 071865c82f54c603b294ba73f9d3ea3499149b0d Mon Sep 17 00:00:00 2001
From: Justin Zhao
Date: Fri, 11 Aug 2023 13:59:21 -0400
Subject: [PATCH 2/3] Add check that the state of the model object is unchanged.

---
 tests/integration_tests/test_llm.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tests/integration_tests/test_llm.py b/tests/integration_tests/test_llm.py
index 90af329b5c7..cb9981329b8 100644
--- a/tests/integration_tests/test_llm.py
+++ b/tests/integration_tests/test_llm.py
@@ -29,6 +29,7 @@
 
 LOCAL_BACKEND = {"type": "local"}
 TEST_MODEL_NAME = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+MAX_NEW_TOKENS_TEST_DEFAULT = 5
 
 RAY_BACKEND = {
     "type": "ray",
@@ -84,7 +85,7 @@ def get_generation_config():
         "top_p": 0.75,
         "top_k": 40,
         "num_beams": 4,
-        "max_new_tokens": 5,
+        "max_new_tokens": MAX_NEW_TOKENS_TEST_DEFAULT,
     }
 
 
@@ -141,6 +142,7 @@ def test_llm_text_to_text(tmpdir, backend, ray_cluster_4cpu):
 
     # Check that in-line generation parameters are used. Original prediction uses max_new_tokens = 5.
     assert get_num_non_empty_tokens(preds["Answer_predictions"][0]) > 3
+    original_max_new_tokens = model.model.generation.max_new_tokens
 
     # This prediction uses max_new_tokens = 2.
     preds, _ = model.predict(
@@ -149,6 +151,9 @@ def test_llm_text_to_text(tmpdir, backend, ray_cluster_4cpu):
     preds = convert_preds(preds)
     assert get_num_non_empty_tokens(preds["Answer_predictions"][0]) < 3
 
+    # Check that the state of the model is unchanged.
+    assert model.model.generation.max_new_tokens == original_max_new_tokens
+
 
 @pytest.mark.llm
 @pytest.mark.parametrize(


From 0c038e2800b5501f8cc805b061594c1a1797772f Mon Sep 17 00:00:00 2001
From: Justin Zhao
Date: Fri, 11 Aug 2023 16:07:27 -0400
Subject: [PATCH 3/3] PR feedback

---
 tests/integration_tests/test_llm.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/tests/integration_tests/test_llm.py b/tests/integration_tests/test_llm.py
index cb9981329b8..a95e9c1f195 100644
--- a/tests/integration_tests/test_llm.py
+++ b/tests/integration_tests/test_llm.py
@@ -141,15 +141,20 @@ def test_llm_text_to_text(tmpdir, backend, ray_cluster_4cpu):
     assert preds["Answer_probability"]
 
     # Check that in-line generation parameters are used. Original prediction uses max_new_tokens = 5.
-    assert get_num_non_empty_tokens(preds["Answer_predictions"][0]) > 3
+    assert get_num_non_empty_tokens(preds["Answer_predictions"][0]) <= MAX_NEW_TOKENS_TEST_DEFAULT
     original_max_new_tokens = model.model.generation.max_new_tokens
 
     # This prediction uses max_new_tokens = 2.
     preds, _ = model.predict(
-        dataset=dataset_filename, output_directory=str(tmpdir), split="test", generation_config={"max_new_tokens": 2}
+        dataset=dataset_filename,
+        output_directory=str(tmpdir),
+        split="test",
+        generation_config={"min_new_tokens": 2, "max_new_tokens": 3},
     )
     preds = convert_preds(preds)
-    assert get_num_non_empty_tokens(preds["Answer_predictions"][0]) < 3
+    print(preds["Answer_predictions"][0])
+    num_non_empty_tokens = get_num_non_empty_tokens(preds["Answer_predictions"][0])
+    assert 2 <= num_non_empty_tokens <= 3
 
     # Check that the state of the model is unchanged.
     assert model.model.generation.max_new_tokens == original_max_new_tokens
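Note: the final assertion above verifies that a generation_config passed to predict() applies only to that single call and leaves the model's stored generation settings untouched. Below is a minimal sketch of that override-and-restore pattern; it is illustrative only, assumes a hypothetical `generation` settings object with attribute-style fields such as max_new_tokens, and is not Ludwig's actual implementation.

    from contextlib import contextmanager


    @contextmanager
    def generation_overrides(generation, **overrides):
        """Temporarily apply per-call generation overrides, then restore the original values."""
        # Snapshot only the attributes we are about to change.
        originals = {key: getattr(generation, key) for key in overrides}
        try:
            for key, value in overrides.items():
                setattr(generation, key, value)
            yield generation
        finally:
            # Restore the snapshot even if prediction raised an exception.
            for key, value in originals.items():
                setattr(generation, key, value)


    # Hypothetical usage: the override is visible only inside the with-block.
    # with generation_overrides(generation_settings, max_new_tokens=2):
    #     run_prediction()
    # Afterwards, generation_settings.max_new_tokens equals its original value,
    # which is what the test's final assertion checks.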