
Commit

Add check that the state of the model object is unchanged.
justinxzhao committed Aug 11, 2023
1 parent 60a496e commit 071865c
Showing 1 changed file with 6 additions and 1 deletion.
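
The new assertion guards a specific invariant: running a prediction with an in-line generation override (max_new_tokens = 2 in the test) must not permanently mutate the model's stored generation config. Below is a minimal, self-contained sketch of that save-and-compare pattern; GenerationState, TinyModel, and its predict method are hypothetical stand-ins for illustration, not Ludwig's actual API.

from dataclasses import dataclass


@dataclass
class GenerationState:
    max_new_tokens: int = 5  # mirrors MAX_NEW_TOKENS_TEST_DEFAULT in the test


class TinyModel:
    def __init__(self):
        self.generation = GenerationState()

    def predict(self, prompt, generation_config=None):
        # Read the per-call override without persisting it on the model object.
        max_new_tokens = (generation_config or {}).get("max_new_tokens", self.generation.max_new_tokens)
        return prompt.split()[:max_new_tokens]


model = TinyModel()
original_max_new_tokens = model.generation.max_new_tokens  # capture state before predicting

# Per-call override, analogous to the test's "max_new_tokens = 2" prediction.
model.predict("one two three four five six", generation_config={"max_new_tokens": 2})

# The check this commit adds: the model object is left exactly as it was.
assert model.generation.max_new_tokens == original_max_new_tokens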
7 changes: 6 additions & 1 deletion tests/integration_tests/test_llm.py
@@ -29,6 +29,7 @@
 
 LOCAL_BACKEND = {"type": "local"}
 TEST_MODEL_NAME = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+MAX_NEW_TOKENS_TEST_DEFAULT = 5
 
 RAY_BACKEND = {
     "type": "ray",
@@ -84,7 +85,7 @@ def get_generation_config():
         "top_p": 0.75,
         "top_k": 40,
         "num_beams": 4,
-        "max_new_tokens": 5,
+        "max_new_tokens": MAX_NEW_TOKENS_TEST_DEFAULT,
     }
 
 
@@ -141,6 +142,7 @@ def test_llm_text_to_text(tmpdir, backend, ray_cluster_4cpu):
 
     # Check that in-line generation parameters are used. Original prediction uses max_new_tokens = 5.
     assert get_num_non_empty_tokens(preds["Answer_predictions"][0]) > 3
+    original_max_new_tokens = model.model.generation.max_new_tokens
 
     # This prediction uses max_new_tokens = 2.
     preds, _ = model.predict(
@@ -149,6 +151,9 @@ def test_llm_text_to_text(tmpdir, backend, ray_cluster_4cpu):
     preds = convert_preds(preds)
     assert get_num_non_empty_tokens(preds["Answer_predictions"][0]) < 3
 
+    # Check that the state of the model is unchanged.
+    assert model.model.generation.max_new_tokens == original_max_new_tokens
+
 
 @pytest.mark.llm
 @pytest.mark.parametrize(
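The diff's comments note that the first prediction uses the default max_new_tokens = 5 while the follow-up prediction passes an in-line max_new_tokens = 2; the new assertion verifies the override does not leak into the model object. A common way to make such a per-call override safe is to restore the saved values in a finally block or context manager. The sketch below illustrates that pattern under the assumption of a simple attribute-based generation config; it is not Ludwig's actual implementation.

from contextlib import contextmanager


@contextmanager
def temporary_generation_override(generation, **overrides):
    # Save the attributes we are about to change, apply the overrides,
    # and restore the originals even if generation raises.
    saved = {name: getattr(generation, name) for name in overrides}
    try:
        for name, value in overrides.items():
            setattr(generation, name, value)
        yield generation
    finally:
        for name, value in saved.items():
            setattr(generation, name, value)


# Usage with the TinyModel sketch above: the override is visible only inside the block.
# with temporary_generation_override(model.generation, max_new_tokens=2):
#     model.predict("one two three four five six")
# assert model.generation.max_new_tokens == 5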

0 comments on commit 071865c
