Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bring deepspeed_main up-to-date #746

Merged
merged 5 commits into from
Dec 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ repos:
hooks:
- id: codespell
args: [
'--ignore-words-list=reord', # Word used in error messages that need rewording
'--ignore-words-list=reord,dout', # Word used in error messages that need rewording
--check-filenames,
--check-hidden,
]
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ from the repository root.
</aside>


### Flash Attention

To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` and set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details.


### Containerized Setup

We also provide a Dockerfile if you prefer to run NeoX in a container. To use this option, first build an image named `gpt-neox` from the repository root directory with `docker build -t gpt-neox -f Dockerfile .`. We also host pre-built images on Docker Hub at `leogao2/gpt-neox`.
Expand Down
40 changes: 37 additions & 3 deletions configs/neox_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = 5bba068
Default = 12f6f76

current git hash of repository

Expand Down Expand Up @@ -797,6 +797,14 @@ Misc. Arguments



- **save_iters**: list

Default = None

Set during training



- **global_num_gpus**: int

Default = None
Expand Down Expand Up @@ -1132,11 +1140,37 @@ Training Arguments



- **save_interval**: int
- **checkpoint_scale**: typing.Literal['linear', 'log']

Default = linear

How the steps at which checkpoints are saved should scale. "linear" implies that one checkpoint will be saved at every multiple of `checkpoint-factor`,
while "log" implies that the number of steps between each checkpoint will be multiplied by `checkpoint-factor` at each step, starting from step 1.



- **checkpoint_factor**: int

Default = None

Acts as a multiplier on either the "log" or "linear" checkpoint spacing.

With `checkpoint-scale="linear"`, `checkpoint-factor=20`, and `train-iters=100`, checkpoints will be saved at
steps [20, 40, 60, 80, 100].

With `checkpoint-scale="log"`, `checkpoint-factor=2`, and `train-iters=100`, checkpoints will be saved at
steps [1, 2, 4, 8, 16, 32, 64, 100].

Note that the last checkpoint step is always saved.



- **extra_save_iters**: list

Default = None

Number of iterations between checkpoint saves.
Additional iterations when a checkpoint should be saved.
Must be a list of ints or `None`.



Expand Down
185 changes: 185 additions & 0 deletions megatron/model/flash_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# Based on: https://github.com/HazyResearch/flash-attention/blob/4a6eaa9f27df6fff7ffb2c24e894938a687dd870/flash_attn/flash_attn_interface.py

import torch
import torch.nn as nn
import torch.nn.functional as F

import flash_attn_cuda


def _flash_attn_forward(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    return_softmax,
    num_splits=0,
    generator=None,
):
    """Invoke the fused flash-attention forward kernel.

    All arguments are forwarded positionally to ``flash_attn_cuda.fwd``.
    ``out`` is supplied by the caller and handed back unchanged as the first
    element of the result (presumably the kernel fills it in place -- confirm
    against the extension's documentation).

    num_splits: how much to parallelize over the seqlen_q dimension.
    num_splits=0 means it will be set by an internal heuristic. It is exposed
    mostly for benchmarking; don't change it unless you know what you're doing.

    Returns:
        A tuple ``(out, softmax_lse, S_dmask)``. ``S_dmask`` is the first
        additional kernel output when ``return_softmax`` is truthy, and
        ``None`` otherwise.
    """
    softmax_lse, *extra_outputs = flash_attn_cuda.fwd(
        q,
        k,
        v,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        # NOTE(review): hard-coded positional flag; presumably the kernel's
        # "zero tensors" option -- confirm against the upstream interface.
        False,
        causal,
        return_softmax,
        num_splits,
        generator,
    )
    if return_softmax:
        return out, softmax_lse, extra_outputs[0]
    return out, softmax_lse, None


def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    num_splits=0,
    generator=None,
):
    """Invoke the fused flash-attention backward kernel.

    All arguments are forwarded positionally to ``flash_attn_cuda.bwd``.
    ``dq``, ``dk`` and ``dv`` are supplied by the caller and handed back
    unchanged (presumably the kernel writes the gradients into them in
    place -- confirm against the extension's documentation).

    num_splits: whether to parallelize over the seqlen_k dimension
    (num_splits > 1) or not (num_splits = 1). num_splits=0 means it will be
    set by an internal heuristic. Any value above 1 will call the same kernel
    (i.e. num_splits=2 would call the same kernel as num_splits=3), so
    effectively the choices are 0, 1, and 2. This hyperparameter can be tuned
    for performance, but the default value (heuristic) should work fine.

    Returns:
        A tuple ``(dq, dk, dv, softmax_d)`` where ``softmax_d`` is the fourth
        value returned by the kernel.
    """
    # The kernel returns exactly four values; only the last one is needed
    # here because the first three slots are discarded in favour of the
    # caller-provided dq/dk/dv buffers.
    _unused_dq, _unused_dk, _unused_dv, softmax_d = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        # NOTE(review): hard-coded positional flag; presumably the kernel's
        # "zero tensors" option -- confirm against the upstream interface.
        False,
        causal,
        num_splits,
        generator,
    )
    return dq, dk, dv, softmax_d


class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function running flash attention on a packed QKV tensor.

    ``qkv`` holds query, key and value at indices 0, 1 and 2 of its second
    dimension. The same ``cu_seqlens`` / ``max_seqlen`` are used for both the
    query and the key side.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        return_softmax,
    ):
        """Run the forward kernel and stash what backward needs on ``ctx``.

        Returns ``out`` alone, or ``(out, softmax_lse, S_dmask)`` when
        ``return_softmax`` is truthy.
        """
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            # Default scaling: 1 / sqrt(size of the last qkv dimension).
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, softmax_lse, S_dmask = _flash_attn_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            torch.empty_like(qkv[:, 0]),
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax,
        )
        ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute the gradient w.r.t. ``qkv``; all other inputs get ``None``.

        Extra incoming gradients (``*args``, present when forward returned
        the softmax outputs as well) are ignored.
        """
        qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        replay_rng = rng_state is not None
        dqkv = torch.empty_like(qkv)
        if replay_rng:
            # Swap in the forward-time RNG state so the kernel regenerates
            # the same dropout mask, remembering the current state first.
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        try:
            _flash_attn_backward(
                dout,
                qkv[:, 0],
                qkv[:, 1],
                qkv[:, 2],
                out,
                softmax_lse,
                dqkv[:, 0],
                dqkv[:, 1],
                dqkv[:, 2],
                cu_seqlens,
                cu_seqlens,
                ctx.max_seqlen,
                ctx.max_seqlen,
                ctx.dropout_p,
                ctx.softmax_scale,
                ctx.causal,
            )
        finally:
            # Always hand the caller's RNG state back, even if the kernel
            # raised; otherwise the global CUDA RNG stays rewound.
            if replay_rng:
                torch.cuda.set_rng_state(cur_rng_state)
        # One gradient slot per forward() input after ctx.
        return dqkv, None, None, None, None, None, None


def flash_attn_unpadded_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
):
    """Functional entry point for ``FlashAttnQKVPackedFunc``.

    Forwards every argument, in order, to ``FlashAttnQKVPackedFunc.apply``;
    ``return_attn_probs`` is passed in the ``return_softmax`` position.

    Returns:
        Whatever ``FlashAttnQKVPackedFunc.apply`` returns for these arguments.
    """
    autograd_args = (
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        return_attn_probs,
    )
    return FlashAttnQKVPackedFunc.apply(*autograd_args)
86 changes: 75 additions & 11 deletions megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ def __init__(
self.rotary_emb = None

self.attention_type = neox_args.attention_config[layer_number]
self.sparse = self.attention_type != "global"
self.use_flash_attention = self.attention_type == "flash"
self.sparse = self.attention_type != "global" and not self.use_flash_attention
if self.sparse:
self.sparse_attn = configure_sparse_attention(
neox_args,
Expand All @@ -268,19 +269,31 @@ def __init__(
mpu=mpu,
)
else:
self.scale_mask_softmax = FusedScaleMaskSoftmax(
input_in_fp16=self.fp16,
input_in_bf16=self.bf16,
fusion_type=get_fusion_type(neox_args),
mask_func=self.attention_mask_func,
softmax_in_fp32=self.attention_softmax_in_fp32,
scale=coeff,
)
if self.use_flash_attention:
from megatron.model.flash_attention import (
flash_attn_unpadded_qkvpacked_func,
)

self.flash_attention_function = flash_attn_unpadded_qkvpacked_func
if self.pos_emb == "alibi":
raise ValueError(
"Flash attention is currently not compatible with AliBi positional embeddings. Use sinuisoidal, learned, or rotary embeddings instead."
)
else:
self.scale_mask_softmax = FusedScaleMaskSoftmax(
input_in_fp16=self.fp16,
input_in_bf16=self.bf16,
fusion_type=get_fusion_type(neox_args),
mask_func=self.attention_mask_func,
softmax_in_fp32=self.attention_softmax_in_fp32,
scale=coeff,
)

# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = nn.Dropout(neox_args.attention_dropout)
self.dropout_p = neox_args.attention_dropout
self.attention_dropout = nn.Dropout(self.dropout_p)

# Output.
self.dense = mpu.RowParallelLinear(
Expand Down Expand Up @@ -396,6 +409,55 @@ def attention(
context_layer = context_layer.view(*output_size)
return context_layer

def flash_attention(self, query_layer, key_layer, value_layer):
    """Run causal flash attention over query/key/value laid out as [s, b, np, hn].

    The three inputs are flattened to [b * s, 1, np, hn], packed into a single
    [b * s, 3, np, hn] tensor and passed to ``self.flash_attention_function``
    together with cumulative sequence offsets that describe ``b`` sequences of
    equal length. Dropout (``self.dropout_p``) is applied only while
    ``self.training`` is truthy.

    NOTE(review): the sequence offsets are derived from the query length only,
    and packing requires query and key/value to flatten to the same size, so
    this presumably assumes sq == sk -- confirm against callers.

    Returns:
        The attention output rearranged to [b, np, sq, hn].
    """
    seq_len_q = query_layer.size(0)
    batch_size = query_layer.size(1)
    num_heads = query_layer.size(2)
    seq_len_k = key_layer.size(0)

    # [s, b, np, hn] -> [b, s, np, hn] -> [b * s, 1, np, hn]
    query_layer = query_layer.transpose(0, 1).reshape(
        batch_size * seq_len_q, 1, num_heads, -1
    )
    key_layer = key_layer.transpose(0, 1).reshape(
        batch_size * seq_len_k, 1, num_heads, -1
    )
    value_layer = value_layer.transpose(0, 1).reshape(
        batch_size * seq_len_k, 1, num_heads, -1
    )

    # Combined q/k/v into [b * s, 3, np, hn].
    # torch.cat rather than its newer alias torch.concat, so that older
    # PyTorch releases without the alias keep working.
    qkv = torch.cat([query_layer, key_layer, value_layer], dim=1)

    # Offsets of each sequence start in the flattened batch:
    # [0, s, 2 * s, ..., b * s].
    cu_seqlens = torch.arange(
        0,
        (batch_size + 1) * seq_len_q,
        step=seq_len_q,
        dtype=torch.int32,
        device=qkv.device,
    )
    output = self.flash_attention_function(
        qkv,
        cu_seqlens,
        seq_len_q,
        self.dropout_p if self.training else 0.0,
        softmax_scale=None,
        causal=True,
    )
    # [b * sq, np, hn] -> [b, sq, np, hn]
    matmul_result = output.view(
        batch_size, seq_len_q, output.shape[1], output.shape[2]
    )
    # [b, sq, np, hn] -> [b, np, sq, hn]
    return matmul_result.transpose(1, 2)

def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask):
# TODO: sparse attn dropout?
# TODO: pad to block size
Expand Down Expand Up @@ -483,7 +545,9 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
if self.use_cache:
present = torch.stack((key_layer, value_layer))

if not self.sparse:
if self.use_flash_attention:
context_layer = self.flash_attention(query_layer, key_layer, value_layer)
elif not self.sparse:
context_layer = self.attention(
query_layer, key_layer, value_layer, layer_past, attention_mask
)
Expand Down
Loading