[Model] Adding Granite MoE. (vllm-project#8206)
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
2 people authored and sumitd2 committed Nov 14, 2024
1 parent 5510852 commit c3852ee
Showing 4 changed files with 492 additions and 3 deletions.
39 changes: 39 additions & 0 deletions tests/models/decoder_only/language/test_granitemoe.py
@@ -0,0 +1,39 @@
"""Compare the outputs of HF and vLLM for Granite models using greedy sampling.
Run `pytest tests/models/test_granite.py`.
"""
import pytest

from ...utils import check_logprobs_close

MODELS = [
"ibm/PowerMoE-3b",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
@@ -32,6 +32,7 @@
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
"GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
"GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
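With this registry entry in place, vLLM can resolve the `GraniteMoeForCausalLM` architecture declared in a checkpoint's config to the new `granitemoe` module. The sketch below is not part of the commit; it only illustrates, assuming the model loads like any other registered architecture, how the test checkpoint could be exercised through vLLM's offline `LLM` API (the prompt and sampling settings are arbitrary).

# Illustrative sketch, not from this commit: load the newly registered
# GraniteMoE architecture through vLLM's offline inference API.
from vllm import LLM, SamplingParams

llm = LLM(model="ibm/PowerMoE-3b", dtype="bfloat16")      # same checkpoint as the test above
params = SamplingParams(temperature=0.0, max_tokens=64)   # greedy decoding, as in the test

outputs = llm.generate(["The capital of France is"], params)
for request_output in outputs:
    print(request_output.outputs[0].text)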
7 changes: 4 additions & 3 deletions vllm/model_executor/models/granite.py
@@ -404,9 +404,12 @@ def __init__(
                 self.lm_head.weight = self.model.embed_tokens.weight
 
             logit_scale = getattr(config, "logit_scale", 1.0)
+
+            if hasattr(config, "logits_scaling"):
+                logit_scale /= config.logits_scaling
             self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                     config.vocab_size,
-                                                    logit_scale)
+                                                    scale=logit_scale)
             self.sampler = Sampler()
         else:
             self.lm_head = PPMissingLayer()
@@ -428,8 +431,6 @@ def compute_logits(
                        sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
-        if logits is not None:
-            logits /= self.config.logits_scaling
         return logits
 
     def sample(
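The granite.py change folds the Granite `logits_scaling` divisor into the `LogitsProcessor` at construction time (`logit_scale /= config.logits_scaling`, passed as `scale=logit_scale`) instead of dividing the computed logits inside `compute_logits`; it also makes the division conditional on the config actually defining `logits_scaling`. Since vLLM's `LogitsProcessor` multiplies the logits by its `scale`, the two forms are equivalent, as the toy check below illustrates (not vLLM code; the `logits_scaling` value is made up).

# Toy equivalence check, not vLLM code: dividing the logits afterwards equals
# folding the divisor into the multiplicative scale applied to the logits.
import torch

logits = torch.randn(2, 8)   # stand-in for lm_head output
logit_scale = 1.0            # default of getattr(config, "logit_scale", 1.0)
logits_scaling = 16.0        # hypothetical config.logits_scaling value

old_way = (logits * logit_scale) / logits_scaling   # old: divide in compute_logits
new_way = logits * (logit_scale / logits_scaling)   # new: scale passed to LogitsProcessor
assert torch.allclose(old_way, new_way)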
(The fourth changed file, presumably the new granitemoe model implementation referenced by the registry entry above, did not render on this page.)
