@@ -4,6 +4,7 @@
 import jax
 import jax.numpy as jnp
 import torch
+from jax.experimental.layout import Format, Layout
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torch.nn.parameter import Parameter
 from torchax.interop import jax_view, torch_view
@@ -177,9 +178,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 
         if layer.use_ep:
             w13_weight = jax.device_put(
-                w13_weight, NamedSharding(self.mesh, P('model', None, None)))
+                w13_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P('model', None, None))))
             w2_weight = jax.device_put(
-                w2_weight, NamedSharding(self.mesh, P('model', None, None)))
+                w2_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P('model', None, None))))
         else:
             intermediate_size = w13_weight.shape[1] // 2
             assert intermediate_size == w2_weight.shape[-1]
@@ -191,9 +196,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                                        n_shards,
                                        dim=1)
             w13_weight = jax.device_put(
-                w13_weight, NamedSharding(self.mesh, P(None, 'model', None)))
+                w13_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P(None, 'model', None))))
             w2_weight = jax.device_put(
-                w2_weight, NamedSharding(self.mesh, P(None, None, 'model')))
+                w2_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P(None, None, 'model'))))
             w13_weight = Parameter(torch_view(w13_weight), requires_grad=False)
             w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
 
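The change pairs each sharding with an explicit device-local layout: Layout((0, 1, 2)) names the major-to-minor dimension order (plain row-major for a 3-D array), and wrapping it in a Format together with the NamedSharding makes jax.device_put commit to both placement and layout, rather than leaving the layout choice to the compiler. A minimal, self-contained sketch of the same Format/Layout usage; the 'model' mesh axis name and the toy weight shape below are illustrative assumptions, not taken from this commit:

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.layout import Format, Layout
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Toy setup: a 1-D device mesh with a 'model' axis (assumed for illustration).
mesh = Mesh(np.array(jax.devices()), ('model',))

# Toy expert weight (num_experts, hidden, intermediate), sized so the
# leading dim divides evenly across the available devices.
w = jnp.zeros((jax.device_count() * 2, 8, 16))

# Layout((0, 1, 2)) is the major-to-minor dim order (row-major); Format
# couples that device-local layout with the sharding so device_put fixes
# both at once instead of letting the compiler pick a layout later.
w = jax.device_put(
    w,
    Format(Layout((0, 1, 2)),
           NamedSharding(mesh, P('model', None, None))))

print(w.sharding)  # NamedSharding(..., spec=PartitionSpec('model', None, None))

Presumably pinning the row-major layout up front avoids a relayout copy when the torch_view of the weight is later consumed by compiled code, though the commit itself does not state the motivation.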