
Commit b2b5044

ydshieh authored and ArthurZucker committed
Fix PhimoeIntegrationTest (#41007)
* fix
* fix
* fix
* fix
* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
1 parent d6d2d03 commit b2b5044

File tree

1 file changed: +57 −24 lines changed

tests/models/phimoe/test_modeling_phimoe.py

Lines changed: 57 additions & 24 deletions
@@ -14,12 +14,14 @@
 
 """Testing suite for the PyTorch PhiMoE model."""
 
+import copy
 import unittest
 
 from parameterized import parameterized
 
 from transformers import PhimoeConfig, StaticCache, is_torch_available
 from transformers.testing_utils import (
+    cleanup,
     require_torch,
     slow,
     torch_device,
@@ -130,31 +132,47 @@ def test_model_rope_scaling_from_config(self, scaling_type):
 @slow
 @require_torch
 class PhimoeIntegrationTest(unittest.TestCase):
-    def test_model_phimoe_instruct_logits(self):
-        input_ids = {
-            "input_ids": torch.tensor(
-                [[1212, 318, 281, 1672, 2643, 290, 428, 318, 257, 1332]], dtype=torch.long, device=torch_device
+    model = None
+
+    @classmethod
+    def get_model(cls):
+        if cls.model is None:
+            cls.model = PhimoeForCausalLM.from_pretrained(
+                "microsoft/Phi-3.5-MoE-instruct", dtype="auto", device_map="auto"
             )
-        }
+        return cls.model
+
+    @classmethod
+    def tearDownClass(cls):
+        del cls.model
+        cleanup(torch_device, gc_collect=True)
+
+    def setUp(self):
+        cleanup(torch_device, gc_collect=True)
+
+    def tearDown(self):
+        cleanup(torch_device, gc_collect=True)
 
-        model = PhimoeForCausalLM.from_pretrained("microsoft/Phi-3.5-MoE-instruct").to(torch_device)
+    def test_model_phimoe_instruct_logits(self):
+        input_ids = {"input_ids": torch.tensor([[1212, 318, 281, 1672]], dtype=torch.long, device=torch_device)}
+
+        model = self.get_model()
         model.eval()
 
-        output = model(**input_ids).logits
+        with torch.no_grad():
+            output = model(**input_ids).logits
 
-        EXPECTED_OUTPUT = torch.tensor([[-3.5312, -2.5000, -1.2734, 0.3555, -0.7578, -0.4727, 0.5977, -0.4316,
-            0.2256, -1.2188, -1.6797, 0.9961, 3.7656, 11.3125, -1.3828, -4.8438,
-            -5.7500, -1.9375, 0.7227, -0.3438, -0.2100, -0.4277, -0.0444, -0.5352,
-            -0.6406, -0.1016, -0.4258, -1.0234, 0.4297, -0.6250],
-            [-0.9883, 0.1455, -0.4902, 2.3594, 0.7031, 3.1406, 0.4375, 0.2559,
-            0.6172, -2.1094, -1.3359, 2.5938, 4.9062, 10.8125, -0.1094, 1.5781,
-            -4.9375, 0.7148, -0.0972, 1.7656, -0.0801, 0.2217, 0.1875, -0.4629,
-            1.5781, 0.3535, 0.0874, 0.6836, -0.0518, -1.2969]]).to(torch_device)  # fmt: skip
+        EXPECTED_OUTPUT = torch.tensor(
+            [
+                [-3.4844, -2.4531, -1.1719, 0.6055, -0.4922, -0.1001, 0.8086, -0.2422, 0.3477, -1.0078],
+                [-0.9766, 0.1631, -0.5508, 2.3594, 0.7031, 3.1719, 0.4141, 0.2305, 0.6055, -2.1250],
+            ]
+        ).to(device=torch_device, dtype=output.dtype)  # fmt: skip
 
-        torch.testing.assert_close(EXPECTED_OUTPUT, output[0, :2, :30], rtol=1e-4, atol=1e-4)
+        torch.testing.assert_close(output[0, :2, :10], EXPECTED_OUTPUT, rtol=1e-4, atol=1e-4)
 
     def test_phimoe_instruct_generation(self):
-        model = PhimoeForCausalLM.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
+        model = self.get_model()
         tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
 
         messages = [
@@ -166,17 +184,29 @@ def test_phimoe_instruct_generation(self):
         ]
         inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
 
-        outputs = model.generate(inputs, max_new_tokens=32)
+        outputs = model.generate(inputs, max_new_tokens=30)
         output_text = tokenizer.batch_decode(outputs)
 
         EXPECTED_OUTPUT = [
-            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits are both delicious and nutritious fruits that can be combined in various ways to create tast"
+            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits are both delicious and nutritious fruits that can be combined in various ways to create",
         ]
-
         self.assertListEqual(output_text, EXPECTED_OUTPUT)
 
     def test_phimoe_instruct_with_static_cache(self):
-        model = PhimoeForCausalLM.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
+        model = self.get_model()
+        # Can't run with the real checkpoint, even if offloaded. Let's just use a tiny dummy one.
+        config = copy.deepcopy(model.config)
+        config.num_hidden_layers = 2
+        # make `head_dim = 128`
+        config.hidden_size = 512
+        config.num_attention_heads = 4
+        config.num_key_value_heads = 1
+        config.intermediate_size = 512
+        config.max_position_embeddings = 64
+        config.num_local_experts = 4
+        torch.manual_seed(42)
+        model = PhimoeForCausalLM(config).to(torch_device)
+        model.eval()
         tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
 
         messages = [
@@ -186,14 +216,17 @@ def test_phimoe_instruct_with_static_cache(self):
             },
             {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
         ]
-        inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
+        inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(
+            torch_device
+        )
 
-        response_tokens = PhimoeMiniWithStaticCache.generate(model, inputs, 64)
+        response_tokens = PhimoeMiniWithStaticCache.generate(model, inputs, max_seq_len=30)
 
         output_text = tokenizer.batch_decode(torch.tensor([response_tokens], dtype=torch.long, device=torch_device))
 
+        # These are dummy outputs. We only check that it runs with a static cache, not the output quality.
         EXPECTED_OUTPUT = [
-            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits are both delicious and nutritious fruits that can"
+            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> awards"
        ]
 
         self.assertListEqual(output_text, EXPECTED_OUTPUT)
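
The heart of the fix is the lazy, class-level model cache: the heavyweight Phi-3.5-MoE checkpoint is loaded at most once per test class through get_model(), shared by all three tests, and released in tearDownClass(). A minimal, self-contained sketch of the same pattern with plain unittest, where ExpensiveResource is a hypothetical stand-in for the real checkpoint:

import unittest


class ExpensiveResource:
    """Hypothetical stand-in for a large model checkpoint."""

    def __init__(self):
        print("loading (expensive, happens at most once per class)")


class CachedResourceTest(unittest.TestCase):
    resource = None  # class-level slot, filled lazily on first use

    @classmethod
    def get_resource(cls):
        # Load once, then reuse across every test in this class.
        if cls.resource is None:
            cls.resource = ExpensiveResource()
        return cls.resource

    @classmethod
    def tearDownClass(cls):
        # Drop the only reference so the memory can be reclaimed
        # once the whole class has run.
        cls.resource = None

    def test_same_instance_is_reused(self):
        self.assertIs(self.get_resource(), self.get_resource())

    def test_resource_is_available(self):
        self.assertIsNotNone(self.get_resource())


if __name__ == "__main__":
    unittest.main()

In the commit itself, the per-test setUp/tearDown hooks additionally call cleanup(torch_device, gc_collect=True), a transformers.testing_utils helper that runs garbage collection and empties the accelerator cache, so one test's allocations cannot skew the next test's memory footprint.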

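The static-cache test uses a second trick worth noting: when the real checkpoint is too large to run, copy its config, shrink every size-related field, and build a randomly initialized model, so the test still exercises the exact code path (here, generation with a static cache) while the outputs themselves are meaningless noise. Below is a hedged sketch of that idea under two assumptions not taken from the commit: the tiny config is built directly as a PhimoeConfig rather than deep-copied from the downloaded checkpoint's config, and generation goes through generate()'s built-in cache_implementation="static" path instead of the test file's PhimoeMiniWithStaticCache helper. The field values mirror the diff; treat the rest as illustrative.

import torch

from transformers import PhimoeConfig, PhimoeForCausalLM

# Shrunken config mirroring the values used in the test; all other
# fields keep PhimoeConfig defaults.
config = PhimoeConfig(
    num_hidden_layers=2,
    hidden_size=512,  # with 4 attention heads this gives head_dim = 128
    num_attention_heads=4,
    num_key_value_heads=1,
    intermediate_size=512,
    max_position_embeddings=64,
    num_local_experts=4,
)

torch.manual_seed(42)  # deterministic random weights
model = PhimoeForCausalLM(config).eval()

# Smoke test: the token ids are arbitrary and the continuation is noise.
# All that matters is that the static-cache generation path runs.
input_ids = torch.tensor([[1, 2, 3, 4]], dtype=torch.long)
with torch.no_grad():
    out = model.generate(input_ids, max_new_tokens=8, cache_implementation="static")
print(out.shape)  # (1, 12): 4 prompt tokens + 8 generated tokens

The torch.manual_seed(42) call matches the diff: it makes the random weights, and therefore the otherwise meaningless expected output string (the "awards" continuation asserted above), reproducible across runs.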