Commit c48a430

tianyu-l authored and githubsgi committed
minor refactor over EP (pytorch#1854)
This PR: - let `ExpertParallel` handles indices permute / unpermute when EP is used - move `to_local` to model code to be more explicit - rename the `expert_parallel` wrapper which does permute / unpermute to `indices_permutation_wrapper` to be more accurate
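To make the "indices permute / unpermute" wording concrete: in an expert-parallel MoE layer, tokens are reordered so that tokens routed to the same expert are contiguous before the grouped expert computation, then restored to their original order afterwards. The sketch below is a minimal pure-Python illustration of that idea; the function names are hypothetical and are not torchtitan's API.

```python
# Hypothetical sketch of the permute / unpermute step around a grouped
# expert computation. Names (permute_by_expert, unpermute) are illustrative.

def permute_by_expert(tokens, expert_indices):
    """Sort tokens so tokens routed to the same expert are contiguous.

    Returns the permuted tokens and the permutation used, so the
    original order can be restored afterwards (the "unpermute" step).
    """
    # Stable sort keeps the relative order of tokens within each expert.
    perm = sorted(range(len(tokens)), key=lambda i: expert_indices[i])
    return [tokens[i] for i in perm], perm

def unpermute(tokens, perm):
    """Invert the permutation produced by permute_by_expert."""
    out = [None] * len(tokens)
    for dst, src in zip(perm, tokens):
        out[dst] = src
    return out

tokens = ["t0", "t1", "t2", "t3"]
expert_indices = [1, 0, 1, 0]  # expert id each token is routed to
permuted, perm = permute_by_expert(tokens, expert_indices)
# permuted: ["t1", "t3", "t0", "t2"], perm: [1, 3, 0, 2]
assert unpermute(permuted, perm) == tokens
```

In the refactored code this bookkeeping lives inside `ExpertParallel` rather than in a standalone wrapper, which is what the rename to `indices_permutation_wrapper` reflects.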
1 parent 18185d7 commit c48a430

File tree

1 file changed (+0, -1 lines)


torchtitan/distributed/expert_parallel.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -227,7 +227,6 @@ def __init__(self):
     def _prepare_inputput_fn(self, mod, inputs, device_mesh):
         # shape (batch_size*seq_len, top_k)
         top_scores, selected_experts_indices = inputs
-        num_tokens, _ = top_scores.shape

         # NOTE: If needed, we can pad tokens in case bs*slen is not divisible by TP degree
         # if top_scores.shape[0] % device_mesh.size() != 0:
```
