Commit d8c8cd7

Use vanilla fp8 math

1 parent 2ecfe92 commit d8c8cd7

1 file changed: 124 additions and 51 deletions

tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
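This commit drops the fused jax_fused_moe_func_padded path in apply() and computes the MoE layer with plain ("vanilla") dot products on the fp8 expert weights, multiplying each matmul result by the per-output-channel weight scale instead of dequantizing the weights up front. Below is a standalone sketch (not part of the diff, toy shapes only) of the identity that makes this equivalent: for a per-output-channel scale, scaling the output of x @ w.T equals dequantizing w first.

# Standalone sketch with hypothetical toy shapes; not the diff's code.
import jax.numpy as jnp

t, i, o = 4, 8, 16                          # tokens, hidden dim, output dim
x = jnp.ones((t, i), jnp.bfloat16)
w = jnp.ones((o, i), jnp.bfloat16)          # stands in for an fp8 weight cast up
s = jnp.full((o,), 0.5, jnp.bfloat16)       # one scale per output channel

scale_after = (x @ w.T) * s                 # scale the matmul result, as apply() does
dequant_first = x @ (w * s[:, None]).T      # dequantize the weight, then matmul
assert jnp.allclose(scale_after, dequant_first)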
@@ -1,6 +1,7 @@
 from collections.abc import Callable

 import torch
+import torch.nn.functional as F
 from compressed_tensors.quantization import QuantizationStrategy

 from jax.sharding import PartitionSpec as P
@@ -58,7 +59,7 @@
 from jax.experimental.layout import Format, Layout
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torch.nn.parameter import Parameter
-from torchax.interop import jax_view, torch_view
+from torchax.interop import jax_view, torch_view, call_jax
 from torchax.ops.mappings import t2j
 from vllm.attention.layer import Attention
 from vllm.logger import init_logger
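The only import change here is pulling in call_jax alongside jax_view/torch_view; the rewritten apply() uses it to run jax.lax contractions directly on the layer's torch-view tensors. A minimal sketch of that calling pattern, assuming an active torchax environment (as in these vLLM TPU layers); the helper name and shapes are illustrative, not from the diff.

import jax
import jax.numpy as jnp
from torchax.interop import call_jax

def scaled_matmul(x, w, scale):
    # Contract x's last dim with w's last dim ("ti,oi->to"), accumulate in
    # bf16, then apply a per-output-channel scale -- the same call shape the
    # diff uses for each expert projection.
    y = call_jax(
        jax.lax.dot_general, x, w,
        dimension_numbers=(((1,), (1,)), ((), ())),
        preferred_element_type=jnp.bfloat16.dtype,
    )
    return y * scale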
@@ -139,33 +140,47 @@ def __init__(
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         assert isinstance(layer, FusedMoE)

+        intermediate_size = layer.w13_weight.shape[1] // 2
+        w1_weight = layer.w13_weight[:, :intermediate_size]
+        w3_weight = layer.w13_weight[:, intermediate_size:]
+        w1_weight_scale = layer.w13_weight_scale[:, :intermediate_size]
+        w3_weight_scale = layer.w13_weight_scale[:, intermediate_size:]
+
         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
-        w13_weight = t2j(layer.w13_weight, use_dlpack=False)
-        w13_weight_scale = t2j(layer.w13_weight_scale, use_dlpack=False)
         w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
+        w1_weight = t2j(w1_weight, use_dlpack=False)
+        w1_weight_scale = t2j(w1_weight_scale, use_dlpack=False)
+        w3_weight = t2j(w3_weight, use_dlpack=False)
+        w3_weight_scale = t2j(w3_weight_scale, use_dlpack=False)

         if layer.use_ep:
             format = Format(
                 Layout((0, 1, 2)), NamedSharding(self.mesh, P("model", None, None))
             )
-            w13_weight = jax.device_put(w13_weight, format)
-            w13_weight_scale = jax.device_put(w13_weight_scale, format)
+            w1_weight = jax.device_put(w1_weight, format)
+            w1_weight_scale = jax.device_put(w1_weight_scale, format)
+            w3_weight = jax.device_put(w3_weight, format)
+            w3_weight_scale = jax.device_put(w3_weight_scale, format)
             w2_weight = jax.device_put(w2_weight, format)
             w2_weight_scale = jax.device_put(w2_weight_scale, format)
         else:
-            intermediate_size = w13_weight.shape[1] // 2
             assert intermediate_size == w2_weight.shape[-1]
             output_sizes = [intermediate_size, intermediate_size]
             n_shards = self.mesh.shape["model"]
             assert intermediate_size % n_shards == 0
-            w13_weight = reorder_concatenated_tensor_for_sharding(
-                w13_weight, output_sizes, n_shards, dim=1
-            )
+
+            # TODO: enable this if using fused weights
+            # w13_weight = reorder_concatenated_tensor_for_sharding(
+            #     w13_weight, output_sizes, n_shards, dim=1
+            # )
+
             w13_format = Format(
                 Layout((0, 1, 2)), NamedSharding(self.mesh, P(None, "model", None))
             )
-            w13_weight = jax.device_put(w13_weight, w13_format)
-            w13_weight_scale = jax.device_put(w13_weight_scale, w13_format)
+            w1_weight = jax.device_put(w1_weight, w13_format)
+            w1_weight_scale = jax.device_put(w1_weight_scale, w13_format)
+            w3_weight = jax.device_put(w3_weight, w13_format)
+            w3_weight_scale = jax.device_put(w3_weight_scale, w13_format)
             w2_weight = jax.device_put(
                 w2_weight,
                 Format(
@@ -176,15 +191,21 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 w2_weight_scale,
                 Format(Layout((0, 1, 2)), NamedSharding(self.mesh, P())),
             )  # replicate
-        w13_weight = Parameter(torch_view(w13_weight), requires_grad=False)
+
+        w1_weight = Parameter(torch_view(w1_weight), requires_grad=False)
+        w1_weight_scale = Parameter(torch_view(w1_weight_scale), requires_grad=False)
         w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
-        w13_weight_scale = Parameter(torch_view(w13_weight_scale), requires_grad=False)
         w2_weight_scale = Parameter(torch_view(w2_weight_scale), requires_grad=False)
+        w3_weight = Parameter(torch_view(w3_weight), requires_grad=False)
+        w3_weight_scale = Parameter(torch_view(w3_weight_scale), requires_grad=False)

-        layer.w13_weight = w13_weight
+        # TODO: don't reuse the w13 variable names for the split w1 tensors
+        layer.w13_weight = w1_weight
+        layer.w13_weight_scale = w1_weight_scale
         layer.w2_weight = w2_weight
-        layer.w13_weight_scale = w13_weight_scale
         layer.w2_weight_scale = w2_weight_scale
+        layer.w3_weight = w3_weight
+        layer.w3_weight_scale = w3_weight_scale

     def apply(
         self,
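Checkpoints keep the gate (w1) and up (w3) projections fused along dim 1 of w13_weight, with matching per-row scales in w13_weight_scale. The hunks above split both tensors on the torch side before the t2j conversion and sharding, so apply() can scale each half independently. A sketch of the split with hypothetical sizes (160 experts comes from the diff's own comment; the other dimensions are made up):

# Hypothetical shapes only; not the diff's code.
import torch

num_experts, intermediate, hidden = 160, 1024, 512
w13_weight = torch.zeros(num_experts, 2 * intermediate, hidden)
w13_weight_scale = torch.ones(num_experts, 2 * intermediate, 1)

intermediate_size = w13_weight.shape[1] // 2
w1_weight = w13_weight[:, :intermediate_size]              # gate projection half
w3_weight = w13_weight[:, intermediate_size:]              # up projection half
w1_weight_scale = w13_weight_scale[:, :intermediate_size]  # one scale per output row
w3_weight_scale = w13_weight_scale[:, intermediate_size:]

assert w1_weight.shape == w3_weight.shape == (num_experts, intermediate, hidden)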
@@ -215,45 +236,97 @@ def apply(
         if scoring_func != "softmax":
             raise NotImplementedError("Only softmax is supported for scoring_func")

-        import sys
-
-        sys.stdin = open(0)
-        breakpoint()
-
-        _fused_moe_func = functools.partial(
-            jax.jit(
-                jax_fused_moe_func_padded,
-                static_argnames=[
-                    "topk",
-                    "global_num_experts",
-                    "renormalize",
-                    "reduce_results",
-                    "mesh",
-                    "use_ep",
-                ],
-            ),
-            topk=top_k,
-            global_num_experts=global_num_experts,
-            renormalize=renormalize,
-            reduce_results=layer.reduce_results,
-            mesh=self.mesh,
-            use_ep=layer.use_ep,
-        )
+        seqlen = x.shape[0]
+
+        # import sys
+
+        # sys.stdin = open(0)
+        # breakpoint()

-        output = _fused_moe_func(
-            jax_view(x),
-            (
-                jax_view(layer.w13_weight).astype(jnp.float32.dtype)
-                * jax_view(layer.w13_weight_scale)
-            ).astype(jnp.bfloat16.dtype),
-            (
-                jax_view(layer.w2_weight).astype(jnp.float32.dtype)
-                * jax_view(layer.w2_weight_scale)
-            ).astype(jnp.bfloat16.dtype),
-            jax_view(router_logits),
+        expert_weights = F.softmax(router_logits, dim=-1)
+        expert_weights, expert_indices = torch.topk(
+            expert_weights, top_k, dim=-1
         )
+        if renormalize:
+            expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
+
+        # Conditional FFN, computed densely over all experts:
+        #   e = total number of experts = 160
+        #   t = seqlen
+        #   o = config.intermediate_size
+        #   i = config.dim
+        # x1 = torch.einsum("ti, eoi -> teo", x, layer.w13_weight) * layer.w13_weight_scale
+        ux1 = call_jax(
+            jax.lax.dot_general, x, layer.w13_weight,
+            dimension_numbers=(((1,), (2,)), ((), ())),
+            preferred_element_type=jnp.bfloat16.dtype,
+        )
+        x1 = F.silu(ux1 * layer.w13_weight_scale.squeeze(2))
+
+        # x3 = torch.einsum("ti, eoi -> teo", x, layer.w3_weight) * layer.w3_weight_scale
+        x3 = call_jax(
+            jax.lax.dot_general, x, layer.w3_weight,
+            dimension_numbers=(((1,), (2,)), ((), ())),
+            preferred_element_type=jnp.bfloat16.dtype,
+        ) * layer.w3_weight_scale.squeeze(2)
+
+        # expert_outs = torch.einsum("teo, eio -> tei", x1 * x3, layer.w2_weight) * layer.w2_weight_scale
+        expert_outs = call_jax(
+            jax.lax.dot_general, x1 * x3, layer.w2_weight,
+            dimension_numbers=(((2,), (2,)), ((1,), (0,))),
+            preferred_element_type=jnp.bfloat16.dtype,
+        ).transpose(0, 1) * layer.w2_weight_scale.squeeze(2)
+
+        # Keep only each token's top-k experts, then mix with the routing weights.
+        seq_indexes = torch.arange(seqlen, device="jax").unsqueeze(1)
+        expert_outs = expert_outs[seq_indexes, expert_indices]
+
+        # out = torch.einsum("tai, ta -> ti", expert_outs, expert_weights)
+        out = call_jax(
+            jax.lax.dot_general, expert_outs, expert_weights,
+            dimension_numbers=(((1,), (1,)), ((0,), (0,))),
+            preferred_element_type=jnp.bfloat16.dtype,
+        )

-        return torch_view(output)
+        return out
+
+
+
+        # _fused_moe_func = functools.partial(
+        #     jax.jit(
+        #         jax_fused_moe_func_padded,
+        #         static_argnames=[
+        #             "topk",
+        #             "global_num_experts",
+        #             "renormalize",
+        #             "reduce_results",
+        #             "mesh",
+        #             "use_ep",
+        #         ],
+        #     ),
+        #     topk=top_k,
+        #     global_num_experts=global_num_experts,
+        #     renormalize=renormalize,
+        #     reduce_results=layer.reduce_results,
+        #     mesh=self.mesh,
+        #     use_ep=layer.use_ep,
+        # )
+
+        # output = _fused_moe_func(
+        #     jax_view(x),
+        #     (
+        #         jax_view(layer.w13_weight).astype(jnp.bfloat16.dtype)
+        #         * jax_view(layer.w13_weight_scale).astype(jnp.bfloat16.dtype)
+        #     ).astype(jnp.bfloat16.dtype),
+        #     (
+        #         jax_view(layer.w2_weight).astype(jnp.bfloat16.dtype)
+        #         * jax_view(layer.w2_weight_scale).astype(jnp.bfloat16.dtype)
+        #     ).astype(jnp.bfloat16.dtype),
+        #     jax_view(router_logits),
+        # )
+
+        # return torch_view(output)

     def create_weights(
         self,
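Taken together, the new apply() is a dense reference MoE forward: softmax-plus-top-k routing, every expert's gated FFN evaluated for every token via batched contractions (the einsums named in the comments, expressed as jax.lax.dot_general calls through call_jax), a gather of each token's top-k expert outputs, and a weighted sum with the routing weights. The pure-JAX sketch below mirrors that math on toy shapes with already-dequantized stand-in weights; it is an illustration, not the diff's code.

# Pure-JAX illustration with toy shapes; weights stand in for dequantized experts.
import jax
import jax.numpy as jnp

t, i, e, o, k = 4, 8, 3, 16, 2            # tokens, hidden, experts, intermediate, top-k
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (t, i), jnp.bfloat16)
w1 = jax.random.normal(key, (e, o, i), jnp.bfloat16)   # gate projections
w3 = jax.random.normal(key, (e, o, i), jnp.bfloat16)   # up projections
w2 = jax.random.normal(key, (e, i, o), jnp.bfloat16)   # down projections
router_logits = jax.random.normal(key, (t, e))

# Routing: softmax over experts, keep each token's top-k weights, renormalize.
probs = jax.nn.softmax(router_logits, axis=-1)
weights, indices = jax.lax.top_k(probs, k)
weights = weights / weights.sum(axis=-1, keepdims=True)

# Dense expert FFN for every (token, expert) pair; these einsums are the ones
# referenced in the diff's comments and match its dot_general dimension_numbers.
x1 = jax.nn.silu(jnp.einsum("ti,eoi->teo", x, w1))
x3 = jnp.einsum("ti,eoi->teo", x, w3)
expert_outs = jnp.einsum("teo,eio->tei", x1 * x3, w2)

# Keep only each token's top-k experts, then mix with the routing weights.
expert_outs = expert_outs[jnp.arange(t)[:, None], indices]   # (t, k, i)
out = jnp.einsum("tki,tk->ti", expert_outs, weights)
assert out.shape == (t, i)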
