recompute core attn #635

Merged · 4 commits · Aug 22, 2022
2 changes: 1 addition & 1 deletion examples/gpt/hybrid_parallel/README.md
@@ -76,7 +76,7 @@ Data:
| type_vocab_size | Vocabulary type |
| initializer_range | Range used for parameter initialization |
| use_recompute | Whether to train with recompute |
| recompute_granularity | Granularity of recompute training; options are `full` and `only_attn`. `full` recomputes the entire transformer, while `only_attn` recomputes only the self attention part |
| recompute_granularity | Granularity of recompute training; options are `full`, `full_attn` and `core_attn`. `full` recomputes the entire transformer, `full_attn` recomputes only the self attention blocks, and `core_attn` recomputes only the `softmax(qkT)v` part. Note: in terms of GPU memory usage, `core_attn` > `full_attn` > `full`; if the chosen strategy triggers an OOM error, adjust recompute_granularity accordingly |
| fused_linear | Whether to use fused_linear instead of the standard Linear to speed up training. Note: this feature requires Paddle compiled with CUDA 11.6 or later. |


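For readers configuring the option documented above, here is a minimal sketch of the relevant fields. The flat dict form and the example values are illustrative only, not taken from this PR; in the repo these settings live in a yaml config parsed by `examples/gpt/tools.py`.

```python
# Illustrative only: a flat dict mirroring the documented fields.
model_config = {
    "use_recompute": True,
    # "full":      recompute the whole transformer layer (most memory saved)
    # "full_attn": recompute only the self-attention block
    # "core_attn": recompute only the softmax(q k^T) v part (least memory saved)
    "recompute_granularity": "core_attn",
    "fused_linear": False,  # needs Paddle built against CUDA 11.6+
}
```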
2 changes: 1 addition & 1 deletion examples/gpt/single/README.md
@@ -68,7 +68,7 @@ Data:
| type_vocab_size | Vocabulary type |
| initializer_range | Range used for parameter initialization |
| use_recompute | Whether to train with recompute |
| recompute_granularity | Granularity of recompute training; options are `full` and `only_attn`. `full` recomputes the entire transformer, while `only_attn` recomputes only the self attention part |
| recompute_granularity | Granularity of recompute training; options are `full`, `full_attn` and `core_attn`. `full` recomputes the entire transformer, `full_attn` recomputes only the self attention blocks, and `core_attn` recomputes only the `softmax(qkT)v` part. Note: in terms of GPU memory usage, `core_attn` > `full_attn` > `full`; if the chosen strategy triggers an OOM error, adjust recompute_granularity accordingly |
| fused_linear | Whether to use fused_linear instead of the standard Linear to speed up training. Note: this feature requires Paddle compiled with CUDA 11.6 or later. |

### Optimizer
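The OOM note in the table above implies a natural fallback order. The helper below is a hypothetical sketch (not part of this repo) that encodes that ordering for scripts that retry training with a coarser granularity.

```python
# Granularities ordered from least to most memory saved, per the note above.
FALLBACK_ORDER = ["core_attn", "full_attn", "full"]

def coarser_granularity(current: str) -> str:
    """Return the next granularity to try after an OOM, or 'full' if already coarsest."""
    idx = FALLBACK_ORDER.index(current)
    return FALLBACK_ORDER[min(idx + 1, len(FALLBACK_ORDER) - 1)]
```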
4 changes: 2 additions & 2 deletions examples/gpt/tools.py
@@ -113,9 +113,9 @@ def process_model_configs(yaml_dict):
configs['ffn_hidden_size'] = 4 * configs['hidden_size']

if configs['use_recompute']:
assert configs['recompute_granularity'] in ["full", "only_attn"], \
assert configs['recompute_granularity'] in ["full", "full_attn", "core_attn"], \
"recompute_granularity can be only chosen from " \
"'full' or 'only_attn', but received '{}'".format(configs['recompute_granularity'])
"'full', 'full_attn' or 'core_attn', but received '{}'".format(configs['recompute_granularity'])

if configs['fused_linear'] and not is_fused_matmul_bias_supported():
configs['fused_linear'] = False
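A hedged sketch of how the tightened check behaves for configs that still use the removed `only_attn` value. The dict literal is hypothetical, and the migration mapping is inferred from the modeling.py diff below: the old `only_attn` path wrapped `self_attn` in recompute, which is what `full_attn` now does.

```python
# Hypothetical config fragment; only the keys touched by the assertion are shown.
configs = {"use_recompute": True, "recompute_granularity": "only_attn"}

if configs["use_recompute"]:
    valid = ["full", "full_attn", "core_attn"]
    if configs["recompute_granularity"] not in valid:
        # Old "only_attn" recomputed the whole self-attention block,
        # which corresponds to the new "full_attn" setting.
        configs["recompute_granularity"] = "full_attn"
```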
80 changes: 46 additions & 34 deletions fleetx/models/gpt_model/modeling.py
@@ -54,8 +54,10 @@ def __init__(self,
need_weights=False,
weight_attr=None,
bias_attr=None,
fuse=False,
fused_linear=False):
fuse=True,
fused_linear=False,
use_recompute=False,
recompute_granularity="full"):
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
@@ -64,6 +66,8 @@ def __init__(self,
self.dropout = dropout
self.need_weights = need_weights
self.fuse = fuse
self.use_recompute = use_recompute
self.recompute_granularity = recompute_granularity

self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
@@ -164,6 +168,31 @@ def gen_cache(self, key, value=None, type=Cache):
# incremental_state with initial value, mainly for usage like UniLM
return self.Cache(key, value)

def core_attn(self, q, k, v):
# scale dot product attention
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim ** -0.5)

# TODO(liuyuang): support softmax_mask_fuse_upper_triangle for generation task
weights = F.softmax(product)

# weights = incubate.softmax_mask_fuse_upper_triangle(product)

if self.dropout:
weights = F.dropout(
weights,
self.dropout,
training=self.training,
mode="upscale_in_train")

out = tensor.matmul(weights, v)

# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

return out, weights

def forward(self,
query,
key,
@@ -187,30 +216,11 @@ def forward(self,
else:
q, k, v, cache = self._prepare_qkv(query, key, value, use_cache,
cache)
# scale dot product attention
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)

# if attn_mask is not None:
# product = product + attn_mask

# TODO(liuyuang): support softmax_mask_fuse_upper_triangle for generation task
weights = F.softmax(product)

# weights = incubate.softmax_mask_fuse_upper_triangle(product)

if self.dropout:
weights = F.dropout(
weights,
self.dropout,
training=self.training,
mode="upscale_in_train")

out = tensor.matmul(weights, v)

# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
if self.use_recompute and self.recompute_granularity == "core_attn":
out, weights = recompute(self.core_attn, q, k, v)
else:
out, weights = self.core_attn(q, k, v)

# project to output
out = self.out_proj(out)
@@ -323,7 +333,8 @@ def __init__(self,
weight_attr=None,
bias_attr=None,
fused_linear=False,
recompute_attn=False):
use_recompute=False,
recompute_granularity="full"):
self._config = locals()
self._config.pop("self")
self._config.pop("__class__", None) # py3
@@ -332,7 +343,8 @@ def __init__(self,
attn_dropout = dropout if attn_dropout is None else attn_dropout
act_dropout = dropout if act_dropout is None else act_dropout
self.normalize_before = normalize_before
self.recompute_attn = recompute_attn
self.use_recompute = use_recompute
self.recompute_granularity = recompute_granularity

weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
@@ -345,7 +357,9 @@ def __init__(self,
dropout=attn_dropout,
weight_attr=weight_attrs[0],
bias_attr=bias_attrs[0],
fused_linear=fused_linear)
fused_linear=fused_linear,
use_recompute=use_recompute,
recompute_granularity=recompute_granularity)
self.linear1 = Linear(
d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2])
self.linear2 = Linear(
@@ -364,9 +378,8 @@ def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None):
tgt = self.norm1(tgt)

if use_cache is False:
if self.recompute_attn:
tgt = recompute(self.self_attn, tgt, None, None, tgt_mask,
use_cache, cache)
if self.use_recompute and self.recompute_granularity == "full_attn":
tgt = recompute(self.self_attn, tgt, None, None, tgt_mask, use_cache, cache)
else:
tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
else:
@@ -454,8 +467,6 @@ def __init__(self,

super(GPTModel, self).__init__()

recompute_attn = use_recompute and recompute_granularity == "only_attn"

self.initializer_range = initializer_range
self.hidden_size = hidden_size
self.vocab_size = vocab_size
@@ -480,7 +491,8 @@ def __init__(self,
mean=0.0, std=self.initializer_range)),
bias_attr=None,
fused_linear=fused_linear,
recompute_attn=recompute_attn))
use_recompute=use_recompute,
recompute_granularity=recompute_granularity))

self.decoder = TransformerDecoder(
decoder_layers,
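To show what wrapping `core_attn` in `recompute` buys, here is a standalone sketch. It is not code from this PR: it assumes Paddle's `paddle.distributed.fleet.utils.recompute` can be called directly in dygraph mode, and the attention math mirrors the new `core_attn` method above without the mask and dropout.

```python
import paddle
import paddle.nn.functional as F
from paddle.distributed.fleet.utils import recompute

def core_attn(q, k, v):
    # softmax(q k^T / sqrt(head_dim)) v -- the "core_attn" granularity in this PR
    scale = q.shape[-1] ** -0.5
    weights = F.softmax(paddle.matmul(q, k, transpose_y=True) * scale)
    return paddle.matmul(weights, v)

# [batch, heads, seq_len, head_dim]
q = paddle.randn([2, 8, 128, 64])
k = paddle.randn([2, 8, 128, 64])
v = paddle.randn([2, 8, 128, 64])
q.stop_gradient = False  # recompute needs at least one input requiring grad

# The seq_len x seq_len attention weights produced inside core_attn are freed
# after the forward pass and re-computed during backward, trading FLOPs for memory.
out = recompute(core_attn, q, k, v)
out.sum().backward()
```

Since only the attention weights are recomputed, this is the lightest of the three granularities in extra compute, which matches the memory ordering documented in the README changes above.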
89 changes: 52 additions & 37 deletions fleetx/models/gpt_model/modeling_hybrid.py
@@ -79,7 +79,9 @@ def __init__(self,
bias_attr=None,
fuse=True,
num_partitions=1,
fused_linear=False):
fused_linear=False,
use_recompute=False,
recompute_granularity="full"):
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
Expand All @@ -88,6 +90,8 @@ def __init__(self,
self.dropout = dropout
self.need_weights = need_weights
self.fuse = fuse
self.use_recompute = use_recompute
self.recompute_granularity = recompute_granularity

self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
@@ -219,6 +223,34 @@ def gen_cache(self, key, value=None, type=Cache):
# incremental_state with initial value, mainly for usage like UniLM
return self.Cache(key, value)

def core_attn(self, q, k, v):
# scale dot product attention
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim ** -0.5)

# if attn_mask is not None:
# product = product + attn_mask

# weights = F.softmax(product)

weights = incubate.softmax_mask_fuse_upper_triangle(product)

if self.dropout:
with get_rng_state_tracker().rng_state('local_seed'):
weights = F.dropout(
weights,
self.dropout,
training=self.training,
mode="upscale_in_train")

out = tensor.matmul(weights, v)

# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

return out, weights

def forward(self,
query,
key,
@@ -242,30 +274,11 @@ def forward(self,
else:
q, k, v, cache = self._prepare_qkv(query, key, value, use_cache,
cache)
# scale dot product attention
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)

# if attn_mask is not None:
# product = product + attn_mask

# weights = F.softmax(product)

weights = incubate.softmax_mask_fuse_upper_triangle(product)

if self.dropout:
with get_rng_state_tracker().rng_state('local_seed'):
weights = F.dropout(
weights,
self.dropout,
training=self.training,
mode="upscale_in_train")

out = tensor.matmul(weights, v)

# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
if self.use_recompute and self.recompute_granularity == "core_attn":
out, weights = recompute(self.core_attn, q, k, v)
else:
out, weights = self.core_attn(q, k, v)

# project to output
out = self.out_proj(out)
@@ -380,7 +393,9 @@ def __init__(self,
bias_attr=None,
num_partitions=1,
fused_linear=False,
recompute_attn=False):
recompute_attn=False,
use_recompute=False,
recompute_granularity="full"):
self._config = locals()
self._config.pop("self")
self._config.pop("__class__", None) # py3
@@ -389,7 +404,8 @@ def __init__(self,
attn_dropout = dropout if attn_dropout is None else attn_dropout
act_dropout = dropout if act_dropout is None else act_dropout
self.normalize_before = normalize_before
self.recompute_attn = recompute_attn
self.use_recompute = use_recompute
self.recompute_granularity = recompute_granularity

weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
@@ -401,7 +417,9 @@ def __init__(self,
weight_attr=weight_attrs[0],
bias_attr=bias_attrs[0],
num_partitions=num_partitions,
fused_linear=fused_linear)
fused_linear=fused_linear,
use_recompute=use_recompute,
recompute_granularity=recompute_granularity)

self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
d_model,
@@ -437,9 +455,8 @@ def forward(self,
tgt = self.norm1(tgt)

if use_cache is False:
if self.recompute_attn:
tgt = recompute(self.self_attn, tgt, None, None, tgt_mask,
use_cache, cache)
if self.use_recompute and self.recompute_granularity == "full_attn":
tgt = recompute(self.self_attn, tgt, None, None, tgt_mask, use_cache, cache)
else:
tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
else:
@@ -535,8 +552,6 @@ def __init__(self,

super(GPTModel, self).__init__()

recompute_attn = use_recompute and recompute_granularity == "only_attn"

self.initializer_range = initializer_range
self.hidden_size = hidden_size
self.vocab_size = vocab_size
@@ -562,7 +577,8 @@ def __init__(self,
bias_attr=None,
num_partitions=num_partitions,
fused_linear=fused_linear,
recompute_attn=recompute_attn))
use_recompute=use_recompute,
recompute_granularity=recompute_granularity))

self.decoder = TransformerDecoder(
decoder_layers,
@@ -765,8 +781,6 @@ def __init__(self,
fused_linear=False,
recompute_granularity="full"):

recompute_attn = use_recompute and recompute_granularity == "only_attn"

# forward desc
self.descs = []

@@ -799,7 +813,8 @@ def __init__(self,
bias_attr=None,
num_partitions=num_partitions,
fused_linear=fused_linear,
recompute_attn=recompute_attn))
use_recompute=use_recompute,
recompute_granularity=recompute_granularity))

self.descs.append(
LayerDesc(
@@ -822,7 +837,7 @@ def _logits_helper(embedding, output):
initializer_range=0.02))

recompute_interval = 0
if recompute and not recompute_attn:
if recompute and recompute_granularity == "full":
recompute_interval = 1

super().__init__(