Skip to content

Commit 922f4aa

Browse files
authored
[Torchax] Optimize MoE weight layout (#778)
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
1 parent 32720d7 commit 922f4aa

File tree

1 file changed

+13 additions, -4 deletions (lines changed)

1 file changed

+13 additions, -4 deletions (lines changed)

tpu_commons/models/vllm/quantization/unquantized.py

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import jax
55
import jax.numpy as jnp
66
import torch
7+
from jax.experimental.layout import Format, Layout
78
from jax.sharding import Mesh, NamedSharding, PartitionSpec
89
from torch.nn.parameter import Parameter
910
from torchax.interop import jax_view, torch_view
@@ -177,9 +178,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
177178

178179
if layer.use_ep:
179180
w13_weight = jax.device_put(
180-
w13_weight, NamedSharding(self.mesh, P('model', None, None)))
181+
w13_weight,
182+
Format(Layout((0, 1, 2)),
183+
NamedSharding(self.mesh, P('model', None, None))))
181184
w2_weight = jax.device_put(
182-
w2_weight, NamedSharding(self.mesh, P('model', None, None)))
185+
w2_weight,
186+
Format(Layout((0, 1, 2)),
187+
NamedSharding(self.mesh, P('model', None, None))))
183188
else:
184189
intermediate_size = w13_weight.shape[1] // 2
185190
assert intermediate_size == w2_weight.shape[-1]
@@ -191,9 +196,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
191196
n_shards,
192197
dim=1)
193198
w13_weight = jax.device_put(
194-
w13_weight, NamedSharding(self.mesh, P(None, 'model', None)))
199+
w13_weight,
200+
Format(Layout((0, 1, 2)),
201+
NamedSharding(self.mesh, P(None, 'model', None))))
195202
w2_weight = jax.device_put(
196-
w2_weight, NamedSharding(self.mesh, P(None, None, 'model')))
203+
w2_weight,
204+
Format(Layout((0, 1, 2)),
205+
NamedSharding(self.mesh, P(None, None, 'model'))))
197206
w13_weight = Parameter(torch_view(w13_weight), requires_grad=False)
198207
w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
199208

0 commit comments

Comments
 (0)