
Commit 0ebc9cf

tianyu-l authored and githubsgi committed
minor refactor over EP (pytorch#1854)
This PR:
- lets `ExpertParallel` handle indices permute / unpermute when EP is used
- moves `to_local` to model code to be more explicit
- renames the `expert_parallel` wrapper, which does permute / unpermute, to `indices_permutation_wrapper` to be more accurate
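
For readers unfamiliar with the permute / unpermute step mentioned above, here is a minimal pure-Python sketch of the general idea: token slots are reordered so that slots routed to the same expert are contiguous, and the same ordering is later used to scatter results back. This is not the torchtitan implementation; the function names (`permute_by_expert`, `apply_permutation`, `unpermute`) and the sample routing are hypothetical and chosen only for illustration.

```python
def permute_by_expert(expert_ids):
    """Return an ordering of slot indices that groups slots by expert id.

    Python's sort is stable, so slots routed to the same expert keep
    their original relative order.
    """
    return sorted(range(len(expert_ids)), key=lambda i: expert_ids[i])


def apply_permutation(values, order):
    """Gather values into the permuted (expert-grouped) order."""
    return [values[i] for i in order]


def unpermute(values, order):
    """Scatter permuted values back to their original positions."""
    restored = [None] * len(values)
    for permuted_pos, original_pos in enumerate(order):
        restored[original_pos] = values[permuted_pos]
    return restored


# Hypothetical routing: 6 token slots assigned to 3 experts.
expert_ids = [2, 0, 1, 0, 2, 1]
tokens = ["t0", "t1", "t2", "t3", "t4", "t5"]

order = permute_by_expert(expert_ids)      # slot indices grouped by expert
grouped = apply_permutation(tokens, order)  # what the experts would consume
restored = unpermute(grouped, order)        # back in original token order
```

Because `unpermute` inverts `apply_permutation` for the same `order`, the round trip returns the tokens in their original positions, which is the property any wrapper around the expert computation needs to preserve.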
1 parent 93f6513 commit 0ebc9cf

1 file changed: +1 -0 lines changed

torchtitan/distributed/expert_parallel.py

Lines changed: 1 addition & 0 deletions
@@ -223,6 +223,7 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
 class ReordererSequenceParallel(ParallelStyle):
     def __init__(self):
         super().__init__()
+        self.top_k = None
 
     def _prepare_inputput_fn(self, mod, inputs, device_mesh):
         # shape (batch_size*seq_len, top_k)
