
Commit

Merge pull request vllm-project#2 from Starmys/dev/chengzhang/fix-sparsemixer

Fix sparse-mixer
xiaoxiawu-microsoft authored May 10, 2024
2 parents fc9dd73 + 9a33bb0 commit e1ce33a
Showing 2 changed files with 195 additions and 24 deletions.
15 changes: 10 additions & 5 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -2,7 +2,7 @@
import functools
import json
import os
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Callable

import torch
import triton
@@ -321,6 +321,7 @@ def fused_moe(
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
routing_func: Callable = torch.topk,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -362,12 +363,14 @@ def fused_moe(
M, _ = hidden_states.shape
E, N, _ = w1.shape

if is_hip():
if routing_func != torch.topk:
topk_weights, topk_ids = routing_func(gating_output, topk)
elif is_hip():
# The MoE kernels are not yet supported on ROCm.
routing_weights = torch.softmax(gating_output,
dim=-1,
dtype=torch.float32)
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
topk_weights, topk_ids = routing_func(routing_weights, topk)
else:
import vllm._moe_C as moe_kernels

@@ -433,6 +436,8 @@ def fused_moe(

sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config['BLOCK_SIZE_M'], E)
compute_type = (tl.bfloat16
if hidden_states.dtype == torch.bfloat16 else tl.float16)

invoke_fused_moe_kernel(hidden_states,
w1,
@@ -447,7 +452,7 @@ def fused_moe(
False,
topk_ids.shape[1],
config,
compute_type=tl.float16,
compute_type=compute_type,
use_fp8=use_fp8)

ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -465,7 +470,7 @@ def fused_moe(
True,
1,
config,
compute_type=tl.float16,
compute_type=compute_type,
use_fp8=use_fp8)

if inplace:
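Taken together, the hunks above let callers swap the expert-selection step of fused_moe: a new routing_func argument (defaulting to torch.topk) produces topk_weights and topk_ids, and the Triton compute_type now follows the activation dtype instead of being hard-coded to fp16. Below is a minimal sketch of calling fused_moe with a custom routing callable; the shapes, the my_routing helper, and the CUDA device are illustrative assumptions, not part of the commit, and vLLM's MoE kernels must be available.

import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

# Illustrative sizes only: 4 tokens, 8 experts, top-2 routing.
num_tokens, hidden, inter, n_experts, top_k = 4, 64, 128, 8, 2
hidden_states = torch.randn(num_tokens, hidden, dtype=torch.float16, device="cuda")
w1 = torch.randn(n_experts, 2 * inter, hidden, dtype=torch.float16, device="cuda")
w2 = torch.randn(n_experts, hidden, inter, dtype=torch.float16, device="cuda")
gating_output = torch.randn(num_tokens, n_experts, dtype=torch.float32, device="cuda")

def my_routing(gates: torch.Tensor, k: int):
    # Any callable returning (topk_weights, topk_ids) can be plugged in here;
    # passing torch.topk (the default) keeps the previous behaviour.
    return torch.topk(torch.softmax(gates, dim=-1), k, dim=-1)

out = fused_moe(hidden_states, w1, w2, gating_output, top_k,
                renormalize=True, routing_func=my_routing)
print(out.shape)  # torch.Size([4, 64])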
204 changes: 185 additions & 19 deletions vllm/model_executor/models/phimoe.py
@@ -25,7 +25,9 @@

import torch
from torch import nn
from transformers import MixtralConfig

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

from vllm import _custom_ops as ops
from vllm.attention import Attention, AttentionMetadata
@@ -52,6 +54,171 @@
from vllm.utils import print_warning_once


logger = logging.get_logger(__name__)


class PhiMoEConfig(PretrainedConfig):

model_type = "phi3_moe"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=1e6,
sliding_window=None,
attention_dropout=0.0,
num_experts_per_tok=2,
num_local_experts=8,
output_router_logits=False,
router_aux_loss_coef=0.001,
router_jitter_noise=0.0,
attention_bias=False,
lm_head_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
self.attention_bias = attention_bias
self.lm_head_bias = lm_head_bias
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads

self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout

self.num_experts_per_tok = num_experts_per_tok
self.num_local_experts = num_local_experts
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.router_jitter_noise = router_jitter_noise
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
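PhiMoEConfig tracks the Mixtral-style hyperparameters plus two extra switches, attention_bias and lm_head_bias, that the layers below read. A small sketch of constructing it directly; the overrides are illustrative and not taken from any released checkpoint.

config = PhiMoEConfig(
    num_local_experts=16,      # illustrative expert count
    num_experts_per_tok=2,
    attention_bias=True,       # adds bias to the qkv/o projections in PhiMoEAttention
)
print(config.model_type, config.hidden_size)  # phi3_moe 4096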


class mp(torch.autograd.Function):

@staticmethod
def forward(
ctx,
scores: torch.Tensor,
multiplier: torch.Tensor,
selected_experts: torch.Tensor,
masked_gates: torch.Tensor,
mask_for_one: torch.Tensor,
):
ctx.save_for_backward(multiplier, selected_experts, masked_gates)
return multiplier * mask_for_one

@staticmethod
def backward(
ctx,
grad_at_output: torch.Tensor,
):
multiplier, selected_experts, masked_gates = ctx.saved_tensors

grad_at_output = grad_at_output * multiplier

grad_at_scores_expanded = masked_gates * grad_at_output.mul(-1)
grad_at_scores_expanded.scatter_add_(
dim=-1,
index=selected_experts,
src=grad_at_output,
)

return (
grad_at_scores_expanded,
None,
None,
None,
None,
)
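mp is a custom autograd.Function carried over from the training-side sparse-mixer code (the sparsemixer function below does not call it in this inference path): its forward returns multiplier * mask_for_one, and its backward treats mask_for_one as a constant while propagating the gradient of gathering a softmax probability back to the raw router scores. A toy check of that behaviour, with hand-made tensors whose shapes are purely illustrative:

scores = torch.randn(2, 4, requires_grad=True)
masked_gates = torch.softmax(scores, dim=-1)
selected = masked_gates.argmax(dim=-1, keepdim=True)
multiplier = masked_gates.gather(dim=-1, index=selected)
mask_for_one = torch.ones_like(multiplier)  # treated as a constant by mp.backward

out = mp.apply(scores, multiplier, selected, masked_gates, mask_for_one)
out.sum().backward()
print(scores.grad.shape)  # torch.Size([2, 4])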


def sparsemixer(scores, top_k, jitter_eps=0.1):
assert top_k == 2

################ first expert ################

with torch.no_grad():
# compute mask for sparsity
mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True)
factor = scores.abs().clamp(min=mask_logits_threshold)
mask_logits_threshold = (
(mask_logits_threshold - scores) / factor
) > (2 * jitter_eps)

# apply mask
masked_gates = scores.masked_fill(mask_logits_threshold, float('-inf'))
selected_experts = max_ind

# compute scores for gradients
masked_gates = torch.softmax(masked_gates, dim=-1)
multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)

multiplier = multiplier_o

# mask out the first expert
masked_scores = torch.scatter(
scores,
-1,
selected_experts,
float('-inf'),
)
with torch.no_grad():
# compute mask for sparsity
mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True)
factor = scores.abs().clamp(min=mask_logits_threshold)
mask_logits_threshold = (
(mask_logits_threshold - scores) / factor
) > (2 * jitter_eps)

# apply mask
masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, float('-inf'))
selected_experts_top2 = max_ind
# compute scores for gradients
masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1)
multiplier_top2 = masked_gates_top2.gather(dim=-1, index=selected_experts_top2)

multiplier = torch.concat((multiplier, multiplier_top2), dim=-1)
selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1)

return (
multiplier,
selected_experts,
)
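sparsemixer selects the top-2 experts greedily: it takes the argmax expert, computes its weight from a softmax restricted to experts within a relative jitter_eps band of the maximum, then masks the chosen expert out of the scores and repeats once for the second expert. The result is consumed below via routing_func=sparsemixer in PhiMoE.forward. A quick illustrative call on random logits; the token and expert counts are assumptions, not model values.

gating_output = torch.randn(5, 16)                 # 5 tokens, 16 experts, illustrative only
weights, experts = sparsemixer(gating_output, top_k=2)
print(weights.shape, experts.shape)                # torch.Size([5, 2]) torch.Size([5, 2])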


class PhiMoE(nn.Module):
"""A tensor-parallel MoE implementation for PhiMoE that shards each expert
across all ranks.
@@ -174,10 +341,7 @@ def process_weights_after_loading(self):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
#print (self.ws)
#import pdb;pdb.set_trace()
final_hidden_states = fused_moe(hidden_states,
self.ws,
self.w2s,
@@ -189,7 +353,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
w1_scale=self.ws_scale,
w2_scale=self.w2s_scale,
a1_scale=self.as_scale,
a2_scale=self.a2s_scale)
a2_scale=self.a2s_scale,
routing_func=sparsemixer)

if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
@@ -201,6 +366,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class PhiMoEAttention(nn.Module):

def __init__(self,
config: PhiMoEConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
@@ -243,21 +409,21 @@ def __init__(self,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=True,
bias=config.attention_bias,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=True,
bias=config.attention_bias,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=int(self.rope_theta),
# is_neox_style=True,
is_neox_style=True,
)

self.attn = Attention(
@@ -287,29 +453,29 @@ class PhiMoEDecoderLayer(nn.Module):

def __init__(
self,
config: MixtralConfig,
config: PhiMoEConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = 10000.0 #getattr(config, "rope_theta", 10000)
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = PhiMoEAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
quant_config=quant_config)

self.block_sparse_moe = PhiMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config)


self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.rms_norm_eps,
@@ -326,20 +492,19 @@ def forward(
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
# if residual is None:
residual = hidden_states

# Self Attention
hidden_states = self.input_layernorm(hidden_states)
# else:
# hidden_states, residual = self.input_layernorm(
# hidden_states, residual)

hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = hidden_states + residual

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
Expand All @@ -353,7 +518,7 @@ class PhiMoEModel(nn.Module):

def __init__(
self,
config: MixtralConfig,
config: PhiMoEConfig,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
@@ -419,7 +584,7 @@ class PhiMoEForCausalLM(nn.Module):

def __init__(
self,
config: MixtralConfig,
config: PhiMoEConfig,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
@@ -446,6 +611,7 @@ def __init__(
config.vocab_size)
self.sampler = Sampler()


def forward(
self,
input_ids: torch.Tensor,
