from collections.abc import Callable

import torch
+ import torch.nn.functional as F
from compressed_tensors.quantization import QuantizationStrategy

from jax.sharding import PartitionSpec as P
@@ -58,7 +59,7 @@
from jax.experimental.layout import Format, Layout
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from torch.nn.parameter import Parameter
- from torchax.interop import jax_view, torch_view
+ from torchax.interop import jax_view, torch_view, call_jax
from torchax.ops.mappings import t2j
from vllm.attention.layer import Attention
from vllm.logger import init_logger
@@ -139,33 +140,47 @@ def __init__(
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        assert isinstance(layer, FusedMoE)

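+         # Assumed layout (not verified here): layer.w13_weight is the fused
+         # gate/up projection with shape (num_experts, 2 * intermediate_size,
+         # hidden_size), and w13_weight_scale is sliced the same way, so the
+         # first half of dim 1 is w1 (gate) and the second half is w3 (up).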
+         intermediate_size = layer.w13_weight.shape[1] // 2
+         w1_weight = layer.w13_weight[:, :intermediate_size]
+         w3_weight = layer.w13_weight[:, intermediate_size:]
+         w1_weight_scale = layer.w13_weight_scale[:, :intermediate_size]
+         w3_weight_scale = layer.w13_weight_scale[:, intermediate_size:]
+
        w2_weight = t2j(layer.w2_weight, use_dlpack=False)
-         w13_weight = t2j(layer.w13_weight, use_dlpack=False)
-         w13_weight_scale = t2j(layer.w13_weight_scale, use_dlpack=False)
        w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
+         w1_weight = t2j(w1_weight, use_dlpack=False)
+         w1_weight_scale = t2j(w1_weight_scale, use_dlpack=False)
+         w3_weight = t2j(w3_weight, use_dlpack=False)
+         w3_weight_scale = t2j(w3_weight_scale, use_dlpack=False)

        if layer.use_ep:
            format = Format(
                Layout((0, 1, 2)), NamedSharding(self.mesh, P("model", None, None))
            )
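            # Sharding note (assumption): P("model", None, None) shards dim 0
            # (the expert dimension) across the "model" mesh axis, so each shard
            # holds a subset of experts (expert parallelism).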
-             w13_weight = jax.device_put(w13_weight, format)
-             w13_weight_scale = jax.device_put(w13_weight_scale, format)
+             w1_weight = jax.device_put(w1_weight, format)
+             w1_weight_scale = jax.device_put(w1_weight_scale, format)
+             w3_weight = jax.device_put(w3_weight, format)
+             w3_weight_scale = jax.device_put(w3_weight_scale, format)
            w2_weight = jax.device_put(w2_weight, format)
            w2_weight_scale = jax.device_put(w2_weight_scale, format)
        else:
-             intermediate_size = w13_weight.shape[1] // 2
            assert intermediate_size == w2_weight.shape[-1]
            output_sizes = [intermediate_size, intermediate_size]
            n_shards = self.mesh.shape["model"]
            assert intermediate_size % n_shards == 0
+
+             # TODO: enable this if using fused weights
+             # w13_weight = reorder_concatenated_tensor_for_sharding(
+             #     w13_weight, output_sizes, n_shards, dim=1
+             # )
+
            w13_format = Format(
                Layout((0, 1, 2)), NamedSharding(self.mesh, P(None, "model", None))
            )
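            # Sharding note (assumption): P(None, "model", None) shards dim 1
            # (the intermediate channels) across the "model" mesh axis, which is
            # why intermediate_size % n_shards == 0 is asserted above.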
-             w13_weight = jax.device_put(w13_weight, w13_format)
-             w13_weight_scale = jax.device_put(w13_weight_scale, w13_format)
+             w1_weight = jax.device_put(w1_weight, w13_format)
+             w1_weight_scale = jax.device_put(w1_weight_scale, w13_format)
+             w3_weight = jax.device_put(w3_weight, w13_format)
+             w3_weight_scale = jax.device_put(w3_weight_scale, w13_format)
            w2_weight = jax.device_put(
                w2_weight,
                Format(
@@ -176,15 +191,21 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                w2_weight_scale,
                Format(Layout((0, 1, 2)), NamedSharding(self.mesh, P())),
            )  # replicate
-         w13_weight = Parameter(torch_view(w13_weight), requires_grad=False)
+
+         w1_weight = Parameter(torch_view(w1_weight), requires_grad=False)
+         w1_weight_scale = Parameter(torch_view(w1_weight_scale), requires_grad=False)
        w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
-         w13_weight_scale = Parameter(torch_view(w13_weight_scale), requires_grad=False)
        w2_weight_scale = Parameter(torch_view(w2_weight_scale), requires_grad=False)
+         w3_weight = Parameter(torch_view(w3_weight), requires_grad=False)
+         w3_weight_scale = Parameter(torch_view(w3_weight_scale), requires_grad=False)

-         layer.w13_weight = w13_weight
+         # TODO: don't reuse the w13_* attribute names; they now hold only w1.
+         layer.w13_weight = w1_weight
+         layer.w13_weight_scale = w1_weight_scale
        layer.w2_weight = w2_weight
-         layer.w13_weight_scale = w13_weight_scale
        layer.w2_weight_scale = w2_weight_scale
+         layer.w3_weight = w3_weight
+         layer.w3_weight_scale = w3_weight_scale
    def apply(
        self,
@@ -215,45 +236,97 @@ def apply(
        if scoring_func != "softmax":
            raise NotImplementedError("Only softmax is supported for scoring_func")

-         import sys
-
-         sys.stdin = open(0)
-         breakpoint()
-
-         _fused_moe_func = functools.partial(
-             jax.jit(
-                 jax_fused_moe_func_padded,
-                 static_argnames=[
-                     "topk",
-                     "global_num_experts",
-                     "renormalize",
-                     "reduce_results",
-                     "mesh",
-                     "use_ep",
-                 ],
-             ),
-             topk=top_k,
-             global_num_experts=global_num_experts,
-             renormalize=renormalize,
-             reduce_results=layer.reduce_results,
-             mesh=self.mesh,
-             use_ep=layer.use_ep,
-         )
+         seqlen = x.shape[0]
+

-         output = _fused_moe_func(
-             jax_view(x),
-             (
-                 jax_view(layer.w13_weight).astype(jnp.float32.dtype)
-                 * jax_view(layer.w13_weight_scale)
-             ).astype(jnp.bfloat16.dtype),
-             (
-                 jax_view(layer.w2_weight).astype(jnp.float32.dtype)
-                 * jax_view(layer.w2_weight_scale)
-             ).astype(jnp.bfloat16.dtype),
-             jax_view(router_logits),
+         expert_weights = F.softmax(router_logits, dim=-1)
+         expert_weights, expert_indices = torch.topk(
+             expert_weights, top_k, dim=-1
        )
+         if renormalize:
+             expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
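+         # Routing shapes: expert_weights and expert_indices are (seqlen, top_k);
+         # softmax runs over all experts, then the selected top-k weights are
+         # optionally renormalized to sum to 1 per token.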
+
+         # cond ffn
+         # e = total num of experts = 160
+         # t = seqlen
+         # o = config.intermediate size
+         # i = config.dim
+         # ux1 = torch.einsum("ti, eoi -> teo", x, layer.w13_weight) * layer.w13_weight_scale
+         ux1 = call_jax(
+             jax.lax.dot_general, x, layer.w13_weight,
+             dimension_numbers=(((1,), (2,)), ((), ())),
+             preferred_element_type=jnp.bfloat16.dtype,
+         )
+         x1 = F.silu(ux1 * layer.w13_weight_scale.squeeze(2))
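+         # The dot_general contracts x's dim 1 (i) with w13_weight's dim 2 (i)
+         # with no batch dims, giving ux1 of shape (t, e, o), i.e. the
+         # "ti, eoi -> teo" einsum above. Assumes w13_weight_scale is (e, o, 1),
+         # hence the squeeze(2) before broadcasting against ux1.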
+
+         # x3 = torch.einsum("ti, eoi -> teo", x, layer.w3_weight) * layer.w3_weight_scale
+         x3 = call_jax(
+             jax.lax.dot_general, x, layer.w3_weight,
+             dimension_numbers=(((1,), (2,)), ((), ())),
+             preferred_element_type=jnp.bfloat16.dtype,
+         ) * layer.w3_weight_scale.squeeze(2)
+
+         # expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), layer.w2_weight) * layer.w2_weight_scale
+         expert_outs = call_jax(
+             jax.lax.dot_general, x1 * x3, layer.w2_weight,
+             dimension_numbers=(((2,), (2,)), ((1,), (0,))),
+             preferred_element_type=jnp.bfloat16.dtype,
+         ).transpose(0, 1) * layer.w2_weight_scale.squeeze(2)
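+         # Here dot_general batches over experts (lhs dim 1, rhs dim 0) and
+         # contracts over o (dim 2 of both), returning (e, t, i); transpose(0, 1)
+         # restores (t, e, i) to match "teo, eio -> tei". Assumes w2_weight_scale
+         # is (e, i, 1).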
+
+         seq_indexes = torch.arange(seqlen, device='jax').unsqueeze(1)
+         expert_outs = expert_outs[seq_indexes, expert_indices]
+
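+         # Advanced indexing with seq_indexes (t, 1) and expert_indices (t, top_k)
+         # gathers each token's selected expert outputs, so expert_outs goes from
+         # (t, e, i) to (t, top_k, i).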
+         # out = torch.einsum("tai, ta -> ti", expert_outs, expert_weights)
+         out = call_jax(
+             jax.lax.dot_general, expert_outs, expert_weights,
+             dimension_numbers=(((1,), (1,)), ((0,), (0,))),
+             preferred_element_type=jnp.bfloat16.dtype,
+         )
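+         # Final dot_general batches over tokens (dim 0 of both) and contracts
+         # over the top_k axis (dim 1 of both), producing the weighted sum of
+         # expert outputs with shape (t, i).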

-         return torch_view(output)
+         return out
+
+         # _fused_moe_func = functools.partial(
+         #     jax.jit(
+         #         jax_fused_moe_func_padded,
+         #         static_argnames=[
+         #             "topk",
+         #             "global_num_experts",
+         #             "renormalize",
+         #             "reduce_results",
+         #             "mesh",
+         #             "use_ep",
+         #         ],
+         #     ),
+         #     topk=top_k,
+         #     global_num_experts=global_num_experts,
+         #     renormalize=renormalize,
+         #     reduce_results=layer.reduce_results,
+         #     mesh=self.mesh,
+         #     use_ep=layer.use_ep,
+         # )
+
+         # output = _fused_moe_func(
+         #     jax_view(x),
+         #     (
+         #         jax_view(layer.w13_weight).astype(jnp.bfloat16.dtype)
+         #         * jax_view(layer.w13_weight_scale).astype(jnp.bfloat16.dtype)
+         #     ).astype(jnp.bfloat16.dtype),
+         #     (
+         #         jax_view(layer.w2_weight).astype(jnp.bfloat16.dtype)
+         #         * jax_view(layer.w2_weight_scale).astype(jnp.bfloat16.dtype)
+         #     ).astype(jnp.bfloat16.dtype),
+         #     jax_view(router_logits),
+         # )
+
+         # return torch_view(output)

    def create_weights(
        self,