Skip to content

Commit 1f0ae96

Browse files
committed
add test
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent f5ebcb1 commit 1f0ae96

File tree

1 file changed

+71
-0
lines changed

1 file changed

+71
-0
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import tempfile
2+
import unittest
3+
4+
from transformers import LlavaConfig
5+
6+
7+
class LlavaConfigTest(unittest.TestCase):
    """Round-trip serialization tests for ``LlavaConfig``.

    Each test builds a ``LlavaConfig`` (default, pixtral-composed, or with
    arbitrary sub-config model types), saves it with ``save_pretrained``,
    reloads it with ``from_pretrained``, and asserts the two configs
    serialize to identical dicts.
    """

    def test_llava_reload(self):
        """
        Simple test for reloading default llava configs.

        Saves a default ``LlavaConfig`` to a temporary directory, reloads
        it, and checks the round trip is lossless via ``to_dict``.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            config = LlavaConfig()
            config.save_pretrained(tmp_dir)

            reloaded = LlavaConfig.from_pretrained(tmp_dir)
            assert config.to_dict() == reloaded.to_dict()

    def test_pixtral_reload(self):
        """
        Simple test for reloading pixtral configs.

        Composes a ``LlavaConfig`` from explicit pixtral vision/text
        sub-config dicts and checks the save/load round trip is lossless.
        """
        vision_config = {
            "model_type": "pixtral",
            "head_dim": 64,
            "hidden_act": "silu",
            "image_size": 1024,
            "is_composition": True,
            "patch_size": 16,
            "rope_theta": 10000.0,
            "tie_word_embeddings": False,
        }

        text_config = {
            # Pixtral's text backbone is registered under its own config
            # type, "pixtral_text" (not "mistral").
            "model_type": "pixtral_text",
            "hidden_size": 5120,
            "head_dim": 128,
            "num_attention_heads": 32,
            "intermediate_size": 14336,
            "is_composition": True,
            "max_position_embeddings": 1024000,
            "num_hidden_layers": 40,
            "num_key_value_heads": 8,
            "rms_norm_eps": 1e-05,
            "rope_theta": 1000000000.0,
            "sliding_window": None,
            "vocab_size": 131072,
        }

        with tempfile.TemporaryDirectory() as tmp_dir:
            config = LlavaConfig(vision_config=vision_config, text_config=text_config)
            config.save_pretrained(tmp_dir)

            reloaded = LlavaConfig.from_pretrained(tmp_dir)
            assert config.to_dict() == reloaded.to_dict()

    def test_arbitrary_reload(self):
        """
        Simple test for reloading arbitrarily composed subconfigs.

        Starts from the default config dict, swaps the vision/text
        sub-config ``model_type`` values to unrelated architectures, and
        checks the save/load round trip is still lossless.
        """
        default_values = LlavaConfig().to_dict()
        default_values["vision_config"]["model_type"] = "qwen2_vl"
        default_values["text_config"]["model_type"] = "opt"

        with tempfile.TemporaryDirectory() as tmp_dir:
            config = LlavaConfig(**default_values)
            config.save_pretrained(tmp_dir)

            reloaded = LlavaConfig.from_pretrained(tmp_dir)
            assert config.to_dict() == reloaded.to_dict()

0 commit comments

Comments
 (0)