@@ -161,6 +161,17 @@ def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
         assert self.local_sizes is not None
         return self.local_sizes
 
+    # Get the cumulative tokens across sequence parallel ranks.
+    # In this case the input to the MoEs will be distributed w.r.t. both
+    # DP and TP rank.
+    # When sp_size == 1, this is just the cumulative num tokens across DP.
+    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
+        num_tokens_across_sp_cpu = (
+            self.num_tokens_across_dp_cpu - 1 + sp_size
+        ) // sp_size
+        num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size)
+        return torch.cumsum(num_tokens_across_sp_cpu, dim=0)
+
 
 @dataclass
 class ForwardContext:
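
For reference, here is a standalone sketch of the arithmetic the new helper performs (not part of the diff; the per-rank token counts are invented for illustration):

```python
import torch

# Hypothetical per-DP-rank token counts (num_tokens_across_dp_cpu in the diff).
num_tokens_across_dp_cpu = torch.tensor([5, 7, 4])
sp_size = 2

# Ceiling-divide each DP rank's tokens across its SP ranks...
per_sp = (num_tokens_across_dp_cpu - 1 + sp_size) // sp_size  # tensor([3, 4, 2])
# ...then repeat each entry once per SP rank, giving one (padded) chunk
# size per (DP, SP) pair.
per_sp = per_sp.repeat_interleave(sp_size)                    # tensor([3, 3, 4, 4, 2, 2])
print(torch.cumsum(per_sp, dim=0))                            # tensor([ 3,  6, 10, 14, 16, 18])

# With sp_size == 1 this reduces to the plain cumulative count across DP:
print(torch.cumsum(num_tokens_across_dp_cpu, dim=0))          # tensor([ 5, 12, 16])
```

Note that the ceiling division rounds uneven token counts upward, so every SP rank within a DP rank is assigned the same chunk size and the cumulative offsets stay aligned across ranks.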