@@ -1240,15 +1240,20 @@ def fused_experts_impl(hidden_states: torch.Tensor,
 
     config = get_config_func(M)
 
-    intermediate_cache1 = torch.empty((M, top_k_num, N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
+    # We can reuse the memory between these because by the time we need
+    # cache3, we're done with cache1
+    cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]),
+                          device=hidden_states.device,
+                          dtype=hidden_states.dtype)
+    intermediate_cache1 = cache13[:M * top_k_num * N].view(
+        (M, topk_ids.shape[1], N))
+    intermediate_cache3 = cache13[:M * top_k_num * w2.shape[1]].view(
+        (M, topk_ids.shape[1], w2.shape[1]))
+
+    # This needs separate memory since it's used concurrently with cache1
     intermediate_cache2 = torch.empty((M * top_k_num, N // 2),
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
-    intermediate_cache3 = torch.empty((M, top_k_num, w2.shape[1]),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
 
     if hidden_states.dtype == torch.bfloat16:
         compute_type = tl.bfloat16
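
Not part of the diff: a minimal standalone sketch of the buffer-aliasing pattern the change uses. Two tensors whose lifetimes do not overlap (cache1 is fully consumed before cache3 is written) can share one flat allocation via .view(), cutting peak workspace memory. The sizes M, top_k, N, K below are illustrative placeholders, with K standing in for w2.shape[1].

# Hypothetical sizes, chosen only for illustration.
import torch

M, top_k, N, K = 4, 2, 256, 192

# One flat buffer sized for the larger of the two caches.
flat = torch.empty(M * top_k * max(N, K), dtype=torch.float16)

# Written first, then fully consumed before cache3 is produced.
cache1 = flat[:M * top_k * N].view(M, top_k, N)
# Reuses the same storage once cache1 is dead.
cache3 = flat[:M * top_k * K].view(M, top_k, K)

# Both views alias the same underlying memory.
assert cache1.data_ptr() == cache3.data_ptr() == flat.data_ptr()

This only works because the lifetimes are disjoint; intermediate_cache2 is read while cache1 is still live, which is why the diff keeps it as a separate allocation.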