@@ -57,17 +57,19 @@ def _partition_fn(self, name, module, device_mesh):
         # w1 shape = (experts, out_dim, in_dim)
         module.register_parameter(
             "w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(1)]))
-        )
+        )  # Rowwise sharding
+
         # w2 shape = (experts, in_dim, out_dim)
         module.register_parameter(
             "w2",
             nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(2)])),
-        )
+        )  # Columnwise sharding
+
         # w3 shape = (experts, out_dim, in_dim)
         module.register_parameter(
             "w3",
             nn.Parameter(distribute_tensor(module.w3, device_mesh, [Shard(1)])),
-        )
+        )  # Rowwise sharding
 
     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         return distribute_module(
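
Note on the new comments in this hunk: the labels map onto each expert's matrix literally — `Shard(1)` on a `(experts, out_dim, in_dim)` weight splits that expert's rows across the TP mesh, and `Shard(2)` on a `(experts, in_dim, out_dim)` weight splits its columns (so w1 and w3, which share shape and placement, get the same "Rowwise" label). Below is a minimal standalone sketch of that behavior, not code from this commit: the 2-rank mesh, shapes, and variable names are illustrative, and it assumes a recent PyTorch where `distribute_tensor` and `Shard` are exposed under `torch.distributed.tensor`.

```python
# Hypothetical sketch of the 1-D placements above; launch with e.g.
# `torchrun --nproc_per_node=2 sketch.py`. All shapes are made up.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cpu", (2,))  # 1-D TP mesh over 2 ranks

experts, out_dim, in_dim = 4, 16, 8
w1 = torch.randn(experts, out_dim, in_dim)
w2 = torch.randn(experts, in_dim, out_dim)

# Shard(1): split dim 1 (each expert's rows) across the mesh.
w1_d = distribute_tensor(w1, mesh, [Shard(1)])
# Shard(2): split dim 2 (each expert's columns) across the mesh.
w2_d = distribute_tensor(w2, mesh, [Shard(2)])

# Every rank holds all experts but only half of the sharded matrix dim.
assert w1_d.to_local().shape == (experts, out_dim // 2, in_dim)
assert w2_d.to_local().shape == (experts, in_dim, out_dim // 2)
```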
@@ -230,17 +232,19 @@ def _partition_fn_2d(self, name, mod, ep_tp_mesh):
         mod.register_parameter(
             "w1",
             nn.Parameter(distribute_tensor(mod.w1, ep_tp_mesh, [Shard(0), Shard(1)])),
-        )
+        )  # Rowwise sharding
+
         # w2 shape = (experts, in_dim, out_dim)
         mod.register_parameter(
             "w2",
             nn.Parameter(distribute_tensor(mod.w2, ep_tp_mesh, [Shard(0), Shard(2)])),
-        )
+        )  # Columnwise sharding
+
         # w3 shape = (experts, out_dim, in_dim)
         mod.register_parameter(
             "w3",
             nn.Parameter(distribute_tensor(mod.w3, ep_tp_mesh, [Shard(0), Shard(1)])),
-        )
+        )  # Rowwise sharding
 
     def _token_combine(self, mod, routed_output, device_mesh):
         # token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
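
The 2-D variant stacks expert parallelism on top: the extra leading `Shard(0)` splits the expert dimension over the `ep` mesh axis, while the second placement shards each local expert's matrix over `tp` exactly as in the 1-D case. A similar hedged sketch under the same assumptions (a hypothetical 4-rank job arranged as a 2×2 `(ep, tp)` mesh; names and sizes are again illustrative):

```python
# Hypothetical sketch of the 2-D [Shard(0), Shard(*)] placements above;
# launch with e.g. `torchrun --nproc_per_node=4 sketch_2d.py`.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

ep_tp_mesh = init_device_mesh("cpu", (2, 2), mesh_dim_names=("ep", "tp"))

experts, out_dim, in_dim = 4, 16, 8
w1 = torch.randn(experts, out_dim, in_dim)
w2 = torch.randn(experts, in_dim, out_dim)

# [Shard(0), Shard(1)]: experts split over "ep", rows split over "tp".
w1_d = distribute_tensor(w1, ep_tp_mesh, [Shard(0), Shard(1)])
# [Shard(0), Shard(2)]: experts split over "ep", columns split over "tp".
w2_d = distribute_tensor(w2, ep_tp_mesh, [Shard(0), Shard(2)])

# Every rank holds experts // 2 experts and half of each expert's matrix.
assert w1_d.to_local().shape == (experts // 2, out_dim // 2, in_dim)
assert w2_d.to_local().shape == (experts // 2, in_dim, out_dim // 2)
```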