@@ -113,12 +113,9 @@ def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         super().process_weights_after_loading(layer)
 
-        layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
-            layer.w13_weight.data),
-                                              requires_grad=False)
-        layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
-            layer.w2_weight.data),
-                                             requires_grad=False)
+        # Padding the weight for better performance on ROCm
+        layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
+        layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
         # Lazy import to avoid importing triton.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
             is_rocm_aiter_moe_enabled, shuffle_weights)
@@ -127,10 +124,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         shuffled_w13, shuffled_w2 = shuffle_weights(
             layer.w13_weight.data, layer.w2_weight.data)
 
-        layer.w13_weight = torch.nn.Parameter(shuffled_w13,
-                                              requires_grad=False)
-        layer.w2_weight = torch.nn.Parameter(shuffled_w2,
-                                             requires_grad=False)
+        layer.w13_weight.data = shuffled_w13
+        layer.w2_weight.data = shuffled_w2
 
         if current_platform.is_cpu():
             if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
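
Both hunks make the same change: instead of wrapping the transformed tensor in a fresh `torch.nn.Parameter`, they assign to the existing parameter's `.data`. A minimal sketch of why that distinction can matter (the `Linear` layer here is only an illustrative stand-in for the MoE layer, not vLLM code):

```python
import torch

layer = torch.nn.Linear(4, 4, bias=False)
ref = layer.weight  # stand-in for any external holder of the Parameter

# Rebinding the attribute creates a brand-new Parameter object;
# the previously captured reference goes stale.
layer.weight = torch.nn.Parameter(layer.weight.data.clone(),
                                  requires_grad=False)
assert ref is not layer.weight

# Assigning to .data instead mutates the existing Parameter in place,
# so every reference captured earlier still sees the updated tensor.
layer = torch.nn.Linear(4, 4, bias=False)
ref = layer.weight
layer.weight.data = layer.weight.data * 2
assert ref is layer.weight
```

It also avoids re-registering the parameter on the module and keeps whatever attributes (e.g. weight-loader metadata) were attached to the original Parameter object.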
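The padding helper itself is outside this hunk. Purely as a sketch of the kind of padding the new comment refers to (the helper name, the 256-byte pad, and the pad-then-slice trick are all assumptions here, not the actual `_maybe_pad_weight` implementation): padding the innermost dimension and slicing it back off leaves the logical shape untouched but enlarges the row stride, spacing rows further apart in memory.

```python
import torch
import torch.nn.functional as F

def pad_moe_weight(weight: torch.Tensor, pad_bytes: int = 256) -> torch.Tensor:
    # Hypothetical helper: pad the innermost dimension, then slice the
    # padding back off. The slice keeps the logical shape, but the storage
    # row pitch stays at the padded width, so rows sit further apart.
    num_pad = pad_bytes // weight.element_size()
    padded = F.pad(weight, (0, num_pad), "constant", 0)
    return padded[..., :-num_pad]

w = torch.randn(8, 64, 128)             # float32: element_size() == 4, num_pad == 64
w_padded = pad_moe_weight(w)
assert w_padded.shape == w.shape        # logical shape unchanged
assert w_padded.stride(-2) == 128 + 64  # row stride grew by 64 elements
```

The enlarged row pitch, rather than the (all-zero, sliced-off) pad values, is presumably the property the ROCm kernels benefit from.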