Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 2 additions & 18 deletions tests/models/phimoe/test_modeling_phimoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

"""Testing suite for the PyTorch PhiMoE model."""

import copy
import unittest

from parameterized import parameterized
Expand Down Expand Up @@ -59,6 +58,7 @@ def forward(
past_key_values=self.cache,
).logits

@torch.no_grad()
@staticmethod
def generate(model: PhimoeForCausalLM, prompt_tokens: torch.LongTensor, max_seq_len: int) -> list[int]:
model = PhimoeMiniWithStaticCache(model, 1, max_seq_len + prompt_tokens.shape[-1])
Expand Down Expand Up @@ -194,19 +194,6 @@ def test_phimoe_instruct_generation(self):

def test_phimoe_instruct_with_static_cache(self):
model = self.get_model()
# Can't run with the real checkpoint, even if offloaded. Let's just use a tiny dummy one
config = copy.deepcopy(model.config)
config.num_hidden_layers = 2
# make `head_dim = 128`
config.hidden_size = 512
config.num_attention_heads = 4
config.num_key_value_heads = 1
config.intermediate_size = 512
config.max_position_embeddinqgs = 64
config.num_local_experts = 4
torch.manual_seed(42)
model = PhimoeForCausalLM(config).to(torch_device)
model.eval()
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct")

messages = [
Expand All @@ -221,12 +208,9 @@ def test_phimoe_instruct_with_static_cache(self):
)

response_tokens = PhimoeMiniWithStaticCache.generate(model, inputs, max_seq_len=30)

output_text = tokenizer.batch_decode(torch.tensor([response_tokens], dtype=torch.long, device=torch_device))

# This is dummy outputs. We actually check if it could run with static cache, not the output quality.
EXPECTED_OUTPUT = [
"<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> awards"
"<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> C"
]

self.assertListEqual(output_text, EXPECTED_OUTPUT)