From a788c1cdd3f1eb88520819b16f1d3c0be1d778a2 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Sat, 22 Jun 2024 12:00:30 +0000 Subject: [PATCH] fix test --- tests/utils/test_modeling_utils.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 134b4758e637..c86c340017b0 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1424,20 +1424,15 @@ def test_pretrained_low_mem_new_config(self): self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__) def test_generation_config_is_loaded_with_model(self): - # Note: `joaogante/tiny-random-gpt2-with-generation-config` has a `generation_config.json` containing a dummy - # `transformers_version` field set to `foo`. If loading the file fails, this test also fails. + # Note: `TinyLlama/TinyLlama-1.1B-Chat-v1.0` has a `generation_config.json` containing `max_length: 2048` # 1. Load without further parameters - model = AutoModelForCausalLM.from_pretrained( - "joaogante/tiny-random-gpt2-with-generation-config", use_safetensors=False - ) - self.assertEqual(model.generation_config.transformers_version, "foo") + model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") + self.assertEqual(model.generation_config.max_length, 2048) # 2. Load with `device_map` - model = AutoModelForCausalLM.from_pretrained( - "joaogante/tiny-random-gpt2-with-generation-config", device_map="auto", use_safetensors=False - ) - self.assertEqual(model.generation_config.transformers_version, "foo") + model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device_map="auto") + self.assertEqual(model.generation_config.max_length, 2048) @require_safetensors def test_safetensors_torch_from_torch(self):