support when num_heads is not divisible by world_size; resolves PaddlePaddle#459 (PaddlePaddle#461)

* unequal rank.

* trim.

* enable passing in number of heads for each rank.

* simplify.

* simplify.

* cleanup.

* fix col parallel.

* fix bug with row parallel.

* fit out proj.

* refac.

* fix sharding logic.

* refac sharding.

* refac.

* support multiple of.

* make fn reusable.

* fix bug in dimensions.

* scaffold.

* test uneven heads.

* fix test by adding barrier.

* refac.

* reuse code.

* clean up.
lxuechen authored Aug 18, 2023
1 parent ada4710 commit bb4cded
Showing 5 changed files with 161 additions and 45 deletions.
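
For context, the distribution rule this commit implements: attention heads are split as evenly as possible across tensor-parallel ranks, with the first num_heads % world_size ranks each taking one extra head. Below is a minimal illustrative sketch of that rule; it is not code from the diff (the function name here is made up) and simply mirrors the get_dim_for_local_rank helper added in flash_attn/utils/distributed.py further down.

# Illustrative only: mirrors the per-rank head-split rule used by this commit.
def heads_for_rank(num_heads: int, world_size: int, rank: int) -> int:
    div, mod = divmod(num_heads, world_size)
    # The first `mod` ranks each take one extra head.
    return div + int(rank < mod)

# 3 attention heads on 2 tensor-parallel ranks -> [2, 1]
print([heads_for_rank(3, 2, r) for r in range(2)])  # [2, 1]
# 7 heads on 4 ranks -> [2, 2, 2, 1]
print([heads_for_rank(7, 4, r) for r in range(4)])  # [2, 2, 2, 1]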
67 changes: 38 additions & 29 deletions flash_attn/models/gpt.py
@@ -27,7 +27,7 @@
ParallelMLP,
)
from flash_attn.ops.activations import sqrelu_fwd
from flash_attn.utils.distributed import all_gather_raw, sync_shared_params
from flash_attn.utils.distributed import all_gather_raw, sync_shared_params, get_dim_for_local_rank
from flash_attn.utils.generation import GenerationMixin
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config
@@ -62,7 +62,6 @@
except ImportError:
FusedDenseSqreluDense = None


logger = logging.getLogger(__name__)


@@ -681,41 +680,58 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
assert inner_dim % world_size == 0

n_head = config.n_head
n_head_kv = getattr(config, "n_head_kv", n_head)

embed_dim = config.hidden_size
head_dim = embed_dim // n_head

def shard_first_dim(state_dict, key):
if key in state_dict:
x = state_dict[key]
dim = x.shape[0] // world_size
state_dict[key] = x[rank * dim : (rank + 1) * dim]

def shard_last_dim(state_dict, key):
def shard_last_dim(state_dict, key, multiple_of=1):
if key in state_dict:
x = state_dict[key]
dim = x.shape[-1] // world_size
state_dict[key] = x[..., rank * dim : (rank + 1) * dim]
dim_each_rank = [
get_dim_for_local_rank(x.size(-1), world_size, local_rank, multiple_of)
for local_rank in range(world_size)
]
beg, end = tuple(sum(dim_each_rank[:pos]) for pos in (rank, rank + 1))
state_dict[key] = x[..., beg:end]

def shard_gatedmlp_fc1_dim(state_dict, key):
if key in state_dict:
x = state_dict[key]
dim = x.shape[0] // world_size // 2
state_dict[key] = rearrange(
rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim: (rank + 1) * dim],
"two o ... -> (two o) ...",
)

def shard_qkv_headdim(state_dict, key):
if key in state_dict:
n_head = config.n_head
n_head_kv = getattr(config, "n_head_kv", n_head)
assert n_head % world_size == 0 and n_head_kv % world_size == 0
n_head_each_rank = [
get_dim_for_local_rank(n_head, world_size, local_rank) for local_rank in range(world_size)
]
n_head_kv_each_rank = [
get_dim_for_local_rank(n_head_kv, world_size, local_rank) for local_rank in range(world_size)
]

beg_n_head = sum(n_head_each_rank[:rank])
end_n_head = sum(n_head_each_rank[: rank + 1])

beg_n_head_kv = sum(n_head_kv_each_rank[:rank])
end_n_head_kv = sum(n_head_kv_each_rank[: rank + 1])

if n_head_kv == n_head:
x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
dim = x.shape[1] // world_size
state_dict[key] = rearrange(
x[:, rank * dim : (rank + 1) * dim], "three d ... -> (three d) ..."
x[:, beg_n_head * head_dim : end_n_head * head_dim], "three d ... -> (three d) ..."
)
else:
n_head_per_rank = n_head // world_size
n_head_kv_per_rank = n_head_kv // world_size
x = rearrange(
state_dict[key],
"(nheadqkv headdim) ... -> nheadqkv headdim ...",
@@ -724,19 +740,9 @@ def shard_qkv_headdim(state_dict, key):
state_dict[key] = rearrange(
torch.cat(
[
x[rank * n_head_per_rank : (rank + 1) * n_head_per_rank],
x[
n_head
+ rank * n_head_kv_per_rank : n_head
+ (rank + 1) * n_head_kv_per_rank
],
x[
n_head
+ n_head_kv
+ rank * n_head_kv_per_rank : n_head
+ n_head_kv
+ (rank + 1) * n_head_kv_per_rank
],
x[beg_n_head:end_n_head],
x[n_head + beg_n_head_kv : n_head + end_n_head_kv],
x[n_head + n_head_kv + beg_n_head_kv : n_head + n_head_kv + end_n_head_kv],
],
dim=0,
),
@@ -751,7 +757,9 @@ def shard_qkv_headdim(state_dict, key):
for i in range(config.num_hidden_layers):
shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
shard_last_dim(state_dict, f"transformer.layers.{i}.mixer.out_proj.weight")
shard_last_dim(
state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", multiple_of=head_dim
)
if rank != 0:
state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias", None)
if config.activation_function in ["glu", "swiglu", "geglu"]:
@@ -816,7 +824,7 @@ def combine_qkv_headdim(state_dicts, state_dict, key):
torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
torch.cat(
[
x[n_head_per_rank : n_head_per_rank + n_head_kv_per_rank]
for x in xs
],
dim=0,
@@ -922,6 +930,7 @@ def key_mapping_transformer(key):
return key

state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())

# Word embedding and position embedding
def key_mapping_pos_emb(key):
return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)
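
The offset arithmetic in the new shard_qkv_headdim (for the grouped-query case, n_head_kv != n_head) slices a packed weight laid out as [query heads | key heads | value heads] along the head dimension. The sketch below is an illustrative reimplementation with made-up head counts, not code from the diff:

# Illustrative only: per-rank head slices taken from a weight packed as
# [n_head query heads | n_head_kv key heads | n_head_kv value heads].
def qkv_head_slices(n_head, n_head_kv, world_size, rank):
    def dim_for_rank(dim, r):
        div, mod = divmod(dim, world_size)
        return div + int(r < mod)

    beg_q = sum(dim_for_rank(n_head, r) for r in range(rank))
    end_q = beg_q + dim_for_rank(n_head, rank)
    beg_kv = sum(dim_for_rank(n_head_kv, r) for r in range(rank))
    end_kv = beg_kv + dim_for_rank(n_head_kv, rank)
    # Head (not element) indices owned by this rank in the packed layout.
    return (
        slice(beg_q, end_q),                                              # query heads
        slice(n_head + beg_kv, n_head + end_kv),                          # key heads
        slice(n_head + n_head_kv + beg_kv, n_head + n_head_kv + end_kv),  # value heads
    )

# 6 query heads, 3 kv heads, 2 ranks: here the kv split is the uneven one.
print(qkv_head_slices(6, 3, 2, 0))  # (slice(0, 3), slice(6, 8), slice(9, 11))
print(qkv_head_slices(6, 3, 2, 1))  # (slice(3, 6), slice(8, 9), slice(11, 12))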
20 changes: 11 additions & 9 deletions flash_attn/modules/mha.py
@@ -5,9 +5,10 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from flash_attn.utils.distributed import get_dim_for_local_rank

try:
from flash_attn import (
flash_attn_kvpacked_func,
@@ -720,22 +721,21 @@ def __init__(
self.use_flash_attn = use_flash_attn
self.checkpointing = checkpointing
self.process_group = process_group
self.world_size = process_group.size() if process_group is not None else 1
self.world_size = process_group.size()
self.local_rank = torch.distributed.get_rank(process_group)

self.num_heads = num_heads
assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"

self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
self.num_heads_per_rank = num_heads // self.world_size
self.num_heads_kv_per_rank = self.num_heads_kv // self.world_size
assert (
self.num_heads % self.num_heads_kv == 0
), "num_heads must be divisible by num_heads_kv"
assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
assert (
self.num_heads_kv % self.world_size == 0
), "num_heads_kv must be divisible by world_size"

self.num_heads_per_rank = get_dim_for_local_rank(self.num_heads, self.world_size, self.local_rank)
self.num_heads_kv_per_rank = get_dim_for_local_rank(self.num_heads_kv, self.world_size, self.local_rank)
self.head_dim = self.embed_dim // num_heads
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
kv_dim = 2 * self.head_dim * self.num_heads_kv

if self.rotary_emb_dim > 0:
assert RotaryEmbedding is not None, "rotary_emb is not installed"
@@ -755,6 +755,7 @@ def __init__(
process_group,
bias=qkv_proj_bias,
sequence_parallel=sequence_parallel,
multiple_of=self.head_dim * 3,
**factory_kwargs,
)
inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
@@ -771,6 +772,7 @@ def __init__(
process_group,
bias=out_proj_bias,
sequence_parallel=sequence_parallel,
multiple_of=self.head_dim,
**factory_kwargs,
)

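
With the multiple_of arguments added above, the column-parallel Wqkv and row-parallel out_proj are sized in whole heads per rank rather than by plain division. A quick shape sketch using the new helper (illustrative numbers; assumes a flash_attn build that includes this commit):

# Illustrative only: per-rank layer sizes for embed_dim=192, 3 heads, head_dim=64,
# world_size=2, num_heads == num_heads_kv (so the packed qkv_dim == 3 * embed_dim).
from flash_attn.utils.distributed import get_dim_for_local_rank

embed_dim, num_heads, world_size = 192, 3, 2
head_dim = embed_dim // num_heads  # 64
qkv_dim = 3 * embed_dim            # packed q, k, v

for rank in range(world_size):
    wqkv_out = get_dim_for_local_rank(qkv_dim, world_size, rank, multiple_of=head_dim * 3)
    out_proj_in = get_dim_for_local_rank(embed_dim, world_size, rank, multiple_of=head_dim)
    print(rank, wqkv_out, out_proj_in)
# rank 0 -> Wqkv out_features 384 (2 heads), out_proj in_features 128
# rank 1 -> Wqkv out_features 192 (1 head),  out_proj in_features 64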
2 changes: 1 addition & 1 deletion flash_attn/ops/fused_dense.py
@@ -226,7 +226,7 @@ def __init__(
local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
# Only rank 0 will have bias
super().__init__(
in_features // world_size,
local_multiple * multiple_of,
out_features,
bias=bias and rank == 0,
device=device,
12 changes: 12 additions & 0 deletions flash_attn/utils/distributed.py
@@ -125,3 +125,15 @@ def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: Proc
torch.distributed.all_reduce(coalesced, group=process_group)
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)


def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
"""Get the dim for the local rank derived from splitting dim on world_size processes.
The split may not be even across the world_size processes.
"""
multiple = dim // multiple_of
div = multiple // world_size
mod = multiple % world_size
local_multiple = div + int(local_rank < mod)
return local_multiple * multiple_of
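
A short usage sketch for the new helper, with expected values worked out by hand (not part of the diff):

# Example usage of get_dim_for_local_rank; expected values worked out by hand.
from flash_attn.utils.distributed import get_dim_for_local_rank

# Splitting 3 heads over 2 ranks: the first rank gets the extra head.
assert [get_dim_for_local_rank(3, 2, r) for r in range(2)] == [2, 1]

# With multiple_of, the split happens in whole units (here, units of size 2).
assert [get_dim_for_local_rank(10, 3, r, multiple_of=2) for r in range(3)] == [4, 4, 2]

# The per-rank dims sum back to the full dim.
assert sum(get_dim_for_local_rank(10, 3, r, multiple_of=2) for r in range(3)) == 10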
105 changes: 99 additions & 6 deletions tests/models/test_llama.py
@@ -16,7 +16,7 @@

from einops import rearrange

from transformers import LlamaTokenizer
from transformers import LlamaTokenizer, LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
@@ -255,7 +255,6 @@ def test_llama_generation(model_name, checkpoint_format):
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1].to(device=device)
del model_ref


pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
checkpoint_path, model_name, config, checkpoint_format
)
@@ -297,8 +296,8 @@ def test_llama_generation(model_name, checkpoint_format):
hf_error = (logits_hf - logits_ref).abs().max().item()

print(f'HF fp16 logits max diff: {hf_error}')
print(f'Logits max diff: {(logits - logits_ref).abs().max().item() }')
print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }')
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')

assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
@@ -410,7 +409,101 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):

hf_error = (logits_hf - logits_ref).abs().max().item()
print(f'HF fp16 logits max diff: {hf_error}')
print(f'Logits max diff: {(logits - logits_ref).abs().max().item() }')
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }')
print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')
assert torch.equal(logits_cg, logits)


@torch.no_grad()
@pytest.mark.parametrize('world_size', [2])
def test_llama_parallel_uneven_num_heads(world_size):
from apex.transformer import parallel_state

checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR', current_dir.parent.parent / 'checkpoints')) / 'llama'
num_attention_heads = world_size + 1
model_name = f'teeny-{num_attention_heads}-heads'

if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()

dtype = torch.float16
llama_config = LlamaConfig(
hidden_size=256 * num_attention_heads, # ParallelGatedMlp hidden_features must be divisible by 256
intermediate_size=256 * num_attention_heads * 4,
num_hidden_layers=4,
num_attention_heads=num_attention_heads,
initializer_range=0.5, # Set crazy init range so we don't have near zero weights implying a vacuous test.
)
config = llama_config_to_gpt2_config(llama_config)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused GatedMLP yet
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True

torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device=device)

# Create a shared test model.
if rank == 0:
LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f"{model_name}-hf")
torch.distributed.barrier()

# Run the standard forward pass test.
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
checkpoint_path, model_name, config, checkpoint_format="hf"
)
model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
model.eval()

# TODO: Avoid duplicate code. Modularize the comparison of two forward pass diffs.
out = model.transformer(input_ids)
out, _ = all_gather_raw(out, process_group=process_group)
out = rearrange(out, "(b s) d -> b s d", b=batch_size)
logits = model(input_ids).logits
logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
logits, _ = all_gather_raw(logits, process_group)
logits = rearrange(logits, '(n b) ... d -> b ... (n d)', b=batch_size)

if rank == 0:
model_ref = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f'{model_name}-hf', device_map="auto"
)
model_ref.eval()
out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
logits_ref = model_ref(input_ids).logits.to(device=device)
del model_ref

model_hf = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f'{model_name}-hf', torch_dtype=dtype, device_map="auto"
)
model_hf.eval()
out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
logits_hf = model_hf(input_ids).logits.to(device=device)
del model_hf

print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()

print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()

import shutil
shutil.rmtree(checkpoint_path / f'{model_name}-hf')
