@@ -134,7 +134,7 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
         return 0
 
     # for n_groups == 1, this is exactly tp_size - n_groups
-    return tp_size - ngroups
+    return tp_size - ngroups
 
 
 def mamba_v2_sharded_weight_loader(
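As a point of reference (not part of the patch), here is a self-contained sketch of what this helper computes, assuming the semantics implied by the comments: when the groups do not divide the TP world size, only the single-group case is supported, and that group is replicated once per head shard.

```python
# Standalone sketch mirroring the hunk above (illustrative only).
def extra_groups_for_head_shards(ngroups: int, tp_size: int) -> int:
    # Groups already split evenly across TP ranks: nothing to replicate.
    if ngroups % tp_size == 0:
        return 0
    # Otherwise only ngroups == 1 is supported: the single group is replicated
    # so each of the tp_size head shards gets its own copy, i.e. tp_size - 1
    # extra groups.
    return tp_size - ngroups


# 8 groups over 4 ranks shard evenly; 1 group over 4 ranks needs 3 extra copies.
assert extra_groups_for_head_shards(8, 4) == 0
assert extra_groups_for_head_shards(1, 4) == 3
```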
@@ -168,12 +168,9 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
             # - compute the rank into the loaded shard.
             # - if there is replication, different TP shards will
             #   take from the same rank.
-            if duplicate_groups:
-                # NOTE: currently we only support duplication
-                # in the case where num_groups == 1
-                rank = 0
-            else:
-                rank = tp_rank
+            # NOTE: currently we only support duplication
+            # in the case where num_groups == 1
+            rank = 0 if duplicate_groups else tp_rank
 
             # - leftmost boundary index into loaded weight.
             loaded_skip = rank * shard_size
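For illustration, the refactored branch reduces to the following offset rule. This is a minimal sketch using a hypothetical helper name, assuming shard_size = full_dim // tp_size as in the surrounding loader.

```python
# Hypothetical helper sketching the rank/offset logic shown above.
def loaded_start_offset(duplicate_groups: bool, tp_rank: int,
                        shard_size: int) -> int:
    # With replication (single group), every TP rank reads the same slice of
    # the checkpoint tensor; without it, each rank reads its own slice.
    rank = 0 if duplicate_groups else tp_rank
    # Leftmost boundary index into the loaded weight.
    return rank * shard_size


assert loaded_start_offset(duplicate_groups=True, tp_rank=3, shard_size=64) == 0
assert loaded_start_offset(duplicate_groups=False, tp_rank=3, shard_size=64) == 192
```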
@@ -247,7 +244,7 @@ def __init__(self,
         assert num_heads % self.tp_size == 0, \
             "Tensor parallel world size must divide num heads."
 
-
+
         assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
             (
                 "If tensor parallel world size does not divide num_heads, "