handle uneven heads across ranks when combining state_dicts; resolves P…
lxuechen authored Aug 20, 2023
1 parent d431f16 commit 25d6b1d
Showing 1 changed file with 38 additions and 20 deletions.
58 changes: 38 additions & 20 deletions flash_attn/models/gpt.py
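For context on the commit title: "uneven heads" means `n_head` (or `n_head_kv`) is not divisible by the tensor-parallel world size, so each rank owns a different number of heads. The sketch below illustrates the split convention this diff appears to rely on through `get_dim_for_local_rank` (the first `dim % world_size` ranks get one extra head); the helper name `heads_for_rank` and the numbers are illustrative, not part of the library.

# Illustrative sketch only: assumes the "first ranks get one extra" convention
# that the per-rank head counts in this diff appear to follow.
def heads_for_rank(dim: int, world_size: int, local_rank: int) -> int:
    base, remainder = divmod(dim, world_size)
    return base + (1 if local_rank < remainder else 0)

n_head, world_size = 14, 4
per_rank = [heads_for_rank(n_head, world_size, r) for r in range(world_size)]
print(per_rank)  # [4, 4, 3, 3] -- uneven per rank, but sums back to 14
assert sum(per_rank) == n_head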
@@ -20,16 +20,12 @@
from flash_attn.modules.block import Block, ParallelBlock
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.modules.mha import MHA, ParallelMHA
-from flash_attn.modules.mlp import (
-    FusedMLP,
-    GatedMlp,
-    Mlp,
-    ParallelFusedMLP,
-    ParallelGatedMlp,
-    ParallelMLP,
-)
+from flash_attn.modules.mlp import (FusedMLP, GatedMlp, Mlp, ParallelFusedMLP,
+                                    ParallelGatedMlp, ParallelMLP)
from flash_attn.ops.activations import sqrelu_fwd
-from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
+from flash_attn.utils.distributed import (all_gather_raw,
+                                          get_dim_for_local_rank,
+                                          sync_shared_params)
from flash_attn.utils.generation import GenerationMixin
from flash_attn.utils.pretrained import state_dict_from_pretrained

@@ -44,7 +40,8 @@
dropout_add_layer_norm = None

try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
+    from flash_attn.ops.layer_norm import \
+        dropout_add_layer_norm_parallel_residual
except ImportError:
dropout_add_layer_norm_parallel_residual = None

@@ -673,6 +670,8 @@ def load_state_dict(self, state_dict, strict=True):
def shard_state_dict_tp(state_dict, config, world_size, rank):
"""Convert the state_dict of a standard GPT model to the state_dict of a GPT model
with tensor parallel.
+    This function modifies state_dict in place.
"""
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
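A quick illustration of the padding arithmetic above (example numbers only, assuming GPT-2's 50257-token vocabulary and a padding multiple of 8): the vocabulary is rounded up to the next multiple of pad_vocab_size_multiple before the embedding is sharded.

import math

vocab_size, pad_vocab_size_multiple = 50257, 8  # assumed example values
padded = math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
print(padded)  # 50264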
@@ -784,11 +783,14 @@ def shard_qkv_headdim(state_dict, key):
return state_dict


-def combine_state_dicts_tp(state_dicts, config):
-    """Convert the state_dict of a GPT model with tensor parallel to the state_dict of a
-    standard GPT model.
+def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: GPT2Config):
+    """Convert the list of sharded state_dict of a GPT model with tensor parallel to
+    the state_dict of a standard GPT model.
+    This function is meant to be the "reverse" of shard_state_dict_tp.
+    Precondition:
+        - state_dicts should be ordered in the same way as the shards were created.
"""
world_size = len(state_dicts)
keys = state_dicts[0].keys()
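A minimal usage sketch of the round trip this docstring describes, based only on the signatures visible in this diff and assuming combine_state_dicts_tp returns the merged dict as its docstring implies; deep copies are taken because, per the note added earlier in this diff, shard_state_dict_tp modifies its input in place.

import copy

from flash_attn.models.gpt import combine_state_dicts_tp, shard_state_dict_tp

def roundtrip(full_state_dict, config, world_size):
    # Build shards in rank order, matching the precondition on combine_state_dicts_tp.
    shards = [
        shard_state_dict_tp(copy.deepcopy(full_state_dict), config, world_size, rank)
        for rank in range(world_size)
    ]
    return combine_state_dicts_tp(shards, config)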
@@ -812,9 +814,6 @@ def combine_dim(state_dicts, state_dict, key, dim=-1):
def combine_qkv_headdim(state_dicts, state_dict, key):
n_head = config.n_head
n_head_kv = getattr(config, "n_head_kv", n_head)
-        assert n_head % world_size == 0 and n_head_kv % world_size == 0
-        n_head_per_rank = n_head // world_size
-        n_head_kv_per_rank = n_head_kv // world_size
if key in state_dict:
if n_head_kv == n_head:
xs = [
@@ -830,18 +829,37 @@ def combine_qkv_headdim(state_dicts, state_dict, key):
    )
    for s in state_dicts
]
+n_head_each_rank = [
+    get_dim_for_local_rank(n_head, world_size, local_rank)
+    for local_rank in range(world_size)
+]
+n_head_kv_each_rank = [
+    get_dim_for_local_rank(n_head_kv, world_size, local_rank)
+    for local_rank in range(world_size)
+]
state_dict[key] = rearrange(
    torch.cat(
        [
-            torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
+            torch.cat(
+                [x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0
+            ),
+            torch.cat(
+                [
+                    x[
+                        n_head_each_rank[rank] : n_head_each_rank[rank]
+                        + n_head_kv_each_rank[rank]
+                    ]
+                    for rank, x in enumerate(xs)
+                ],
+                dim=0,
+            ),
            torch.cat(
                [
-                    x[n_head_per_rank : n_head_per_rank + n_head_kv_per_rank]
-                    for x in xs
+                    x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
+                    for rank, x in enumerate(xs)
                ],
                dim=0,
            ),
-            torch.cat([x[-n_head_kv_per_rank:] for x in xs], dim=0),
        ],
        dim=0,
    ),
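To see why the per-rank slicing above works, the toy example below builds fake shards in the packed per-rank layout this code assumes (each rank stacks its query heads, then its key heads, then its value heads along the leading dimension) and recovers the full Q, K, V. One row stands in for one head, the real code additionally folds in the head dimension via rearrange, and all numbers here are made up.

import torch

# Assumed uneven split of 5 query heads and 3 kv heads over 2 ranks.
head_dim = 2
n_head, n_head_kv, world_size = 5, 3, 2
n_head_each_rank = [3, 2]
n_head_kv_each_rank = [2, 1]

def fake_shard(rank):
    q = torch.full((n_head_each_rank[rank], head_dim), float(rank))
    k = torch.full((n_head_kv_each_rank[rank], head_dim), 10.0 + rank)
    v = torch.full((n_head_kv_each_rank[rank], head_dim), 20.0 + rank)
    return torch.cat([q, k, v], dim=0)  # per-rank packed layout: [q | k | v]

xs = [fake_shard(rank) for rank in range(world_size)]
q = torch.cat([x[: n_head_each_rank[r]] for r, x in enumerate(xs)], dim=0)
k = torch.cat(
    [
        x[n_head_each_rank[r] : n_head_each_rank[r] + n_head_kv_each_rank[r]]
        for r, x in enumerate(xs)
    ],
    dim=0,
)
v = torch.cat([x[n_head_each_rank[r] + n_head_kv_each_rank[r] :] for r, x in enumerate(xs)], dim=0)
assert q.shape == (n_head, head_dim) and k.shape == (n_head_kv, head_dim) == v.shape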
