diff --git a/examples/gpt/hybrid_parallel/README.md b/examples/gpt/hybrid_parallel/README.md
index e54e47e86..9d3e4b1a8 100644
--- a/examples/gpt/hybrid_parallel/README.md
+++ b/examples/gpt/hybrid_parallel/README.md
@@ -76,7 +76,7 @@ Data:
 | type_vocab_size | 词表类型 |
 | initializer_range | 参数初始化的范围 |
 | use_recompute | 是否使用recompute训练 |
-| recompute_granularity | recompute训练的粒度,可选 `full` `only_attn`,full即recompute全部transformer,only_attn表明只recompute self attention部分 |
+| recompute_granularity | recompute训练的粒度,可选 `full` `full_attn` `core_attn`,full即recompute全部transformer,full_attn表明只recompute所有self attention部分,core_attn表明只recompute `softmax(qkT)v` 部分。注:显存占用方面,`core_attn` > `full_attn` > `full`,若所选策略产生OOM错误,可以适当更改recompute_granularity |
 | fused_linear | 是否使用fused_linear代替传统Linear加速训练。注:该功能需要cuda 11.6及以上编译的paddle支持。 |
diff --git a/examples/gpt/single/README.md b/examples/gpt/single/README.md
index fdfb28099..41aeade19 100644
--- a/examples/gpt/single/README.md
+++ b/examples/gpt/single/README.md
@@ -68,7 +68,7 @@ Data:
 | type_vocab_size | 词表类型 |
 | initializer_range | 参数初始化的范围 |
 | use_recompute | 是否使用recompute训练 |
-| recompute_granularity | recompute训练的粒度,可选 `full` `only_attn`,full即recompute全部transformer,only_attn表明只recompute self attention部分 |
+| recompute_granularity | recompute训练的粒度,可选 `full` `full_attn` `core_attn`,full即recompute全部transformer,full_attn表明只recompute所有self attention部分,core_attn表明只recompute `softmax(qkT)v` 部分。注:显存占用方面,`core_attn` > `full_attn` > `full`,若所选策略产生OOM错误,可以适当更改recompute_granularity |
 | fused_linear | 是否使用fused_linear代替传统Linear加速训练。注:该功能需要cuda 11.6及以上编译的paddle支持。 |

 ### 优化器
diff --git a/examples/gpt/tools.py b/examples/gpt/tools.py
index f1e936850..3ef259f22 100644
--- a/examples/gpt/tools.py
+++ b/examples/gpt/tools.py
@@ -113,9 +113,9 @@ def process_model_configs(yaml_dict):
         configs['ffn_hidden_size'] = 4 * configs['hidden_size']

     if configs['use_recompute']:
-        assert configs['recompute_granularity'] in ["full", "only_attn"], \
+        assert configs['recompute_granularity'] in ["full", "full_attn", "core_attn"], \
             "recompute_granularity can be only chosen from " \
-            "'full' or 'only_attn', but received '{}'".format(configs['recompute_granularity'])
+            "'full', 'full_attn' or 'core_attn', but received '{}'".format(configs['recompute_granularity'])

     if configs['fused_linear'] and not is_fused_matmul_bias_supported():
         configs['fused_linear'] = False
diff --git a/fleetx/models/gpt_model/modeling.py b/fleetx/models/gpt_model/modeling.py
index 7a59740cd..d992af611 100644
--- a/fleetx/models/gpt_model/modeling.py
+++ b/fleetx/models/gpt_model/modeling.py
@@ -54,8 +54,10 @@ def __init__(self,
                  need_weights=False,
                  weight_attr=None,
                  bias_attr=None,
-                 fuse=False,
-                 fused_linear=False):
+                 fuse=True,
+                 fused_linear=False,
+                 use_recompute=False,
+                 recompute_granularity="full"):
         super(MultiHeadAttention, self).__init__()
         self.embed_dim = embed_dim
         self.kdim = kdim if kdim is not None else embed_dim
@@ -64,6 +66,8 @@ def __init__(self,
         self.dropout = dropout
         self.need_weights = need_weights
         self.fuse = fuse
+        self.use_recompute = use_recompute
+        self.recompute_granularity = recompute_granularity

         self.head_dim = embed_dim // num_heads
         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
@@ -164,6 +168,31 @@ def gen_cache(self, key, value=None, type=Cache):
             # incremental_state with initial value, mainly for usage like UniLM
             return self.Cache(key, value)

+    def core_attn(self, q, k, v):
+        # scale dot product attention
+        product = layers.matmul(
+            x=q, y=k, transpose_y=True, alpha=self.head_dim ** -0.5)
+
+        # TODO(liuyuang): support softmax_mask_fuse_upper_triangle for generation task
+        weights = F.softmax(product)
+
+        # weights = incubate.softmax_mask_fuse_upper_triangle(product)
+
+        if self.dropout:
+            weights = F.dropout(
+                weights,
+                self.dropout,
+                training=self.training,
+                mode="upscale_in_train")
+
+        out = tensor.matmul(weights, v)
+
+        # combine heads
+        out = tensor.transpose(out, perm=[0, 2, 1, 3])
+        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
+
+        return out, weights
+
     def forward(self,
                 query,
                 key,
@@ -187,30 +216,11 @@ def forward(self,
         else:
             q, k, v, cache = self._prepare_qkv(query, key, value, use_cache, cache)
-        # scale dot product attention
-        product = layers.matmul(
-            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
-
-        # if attn_mask is not None:
-        #     product = product + attn_mask
-
-        # TODO(liuyuang): support softmax_mask_fuse_upper_triangle for generation task
-        weights = F.softmax(product)
-
-        # weights = incubate.softmax_mask_fuse_upper_triangle(product)
-
-        if self.dropout:
-            weights = F.dropout(
-                weights,
-                self.dropout,
-                training=self.training,
-                mode="upscale_in_train")
-        out = tensor.matmul(weights, v)
-
-        # combine heads
-        out = tensor.transpose(out, perm=[0, 2, 1, 3])
-        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
+        if self.use_recompute and self.recompute_granularity == "core_attn":
+            out, weights = recompute(self.core_attn, q, k, v)
+        else:
+            out, weights = self.core_attn(q, k, v)

         # project to output
         out = self.out_proj(out)
@@ -323,7 +333,8 @@ def __init__(self,
                  weight_attr=None,
                  bias_attr=None,
                  fused_linear=False,
-                 recompute_attn=False):
+                 use_recompute=False,
+                 recompute_granularity="full"):
         self._config = locals()
         self._config.pop("self")
         self._config.pop("__class__", None)  # py3
@@ -332,7 +343,8 @@ def __init__(self,
         attn_dropout = dropout if attn_dropout is None else attn_dropout
         act_dropout = dropout if act_dropout is None else act_dropout
         self.normalize_before = normalize_before
-        self.recompute_attn = recompute_attn
+        self.use_recompute = use_recompute
+        self.recompute_granularity = recompute_granularity

         weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
         bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
@@ -345,7 +357,9 @@ def __init__(self,
             dropout=attn_dropout,
             weight_attr=weight_attrs[0],
             bias_attr=bias_attrs[0],
-            fused_linear=fused_linear)
+            fused_linear=fused_linear,
+            use_recompute=use_recompute,
+            recompute_granularity=recompute_granularity)
         self.linear1 = Linear(
             d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2])
         self.linear2 = Linear(
@@ -364,9 +378,8 @@ def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None):
             tgt = self.norm1(tgt)

         if use_cache is False:
-            if self.recompute_attn:
-                tgt = recompute(self.self_attn, tgt, None, None, tgt_mask,
-                                use_cache, cache)
+            if self.use_recompute and self.recompute_granularity == "full_attn":
+                tgt = recompute(self.self_attn, tgt, None, None, tgt_mask, use_cache, cache)
             else:
                 tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
         else:
@@ -454,8 +467,6 @@ def __init__(self,

         super(GPTModel, self).__init__()

-        recompute_attn = use_recompute and recompute_granularity == "only_attn"
-
         self.initializer_range = initializer_range
         self.hidden_size = hidden_size
         self.vocab_size = vocab_size
@@ -480,7 +491,8 @@ def __init__(self,
                             mean=0.0, std=self.initializer_range)),
                     bias_attr=None,
                     fused_linear=fused_linear,
-                    recompute_attn=recompute_attn))
+                    use_recompute=use_recompute,
+                    recompute_granularity=recompute_granularity))

         self.decoder = TransformerDecoder(
             decoder_layers,
diff --git a/fleetx/models/gpt_model/modeling_hybrid.py b/fleetx/models/gpt_model/modeling_hybrid.py
index 690478966..0f78987a1 100644
--- a/fleetx/models/gpt_model/modeling_hybrid.py
+++ b/fleetx/models/gpt_model/modeling_hybrid.py
@@ -79,7 +79,9 @@ def __init__(self,
                  bias_attr=None,
                  fuse=True,
                  num_partitions=1,
-                 fused_linear=False):
+                 fused_linear=False,
+                 use_recompute=False,
+                 recompute_granularity="full"):
         super(MultiHeadAttention, self).__init__()
         self.embed_dim = embed_dim
         self.kdim = kdim if kdim is not None else embed_dim
@@ -88,6 +90,8 @@ def __init__(self,
         self.dropout = dropout
         self.need_weights = need_weights
         self.fuse = fuse
+        self.use_recompute = use_recompute
+        self.recompute_granularity = recompute_granularity

         self.head_dim = embed_dim // num_heads
         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
@@ -219,6 +223,34 @@ def gen_cache(self, key, value=None, type=Cache):
             # incremental_state with initial value, mainly for usage like UniLM
             return self.Cache(key, value)

+    def core_attn(self, q, k, v):
+        # scale dot product attention
+        product = layers.matmul(
+            x=q, y=k, transpose_y=True, alpha=self.head_dim ** -0.5)
+
+        # if attn_mask is not None:
+        #     product = product + attn_mask
+
+        # weights = F.softmax(product)
+
+        weights = incubate.softmax_mask_fuse_upper_triangle(product)
+
+        if self.dropout:
+            with get_rng_state_tracker().rng_state('local_seed'):
+                weights = F.dropout(
+                    weights,
+                    self.dropout,
+                    training=self.training,
+                    mode="upscale_in_train")
+
+        out = tensor.matmul(weights, v)
+
+        # combine heads
+        out = tensor.transpose(out, perm=[0, 2, 1, 3])
+        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
+
+        return out, weights
+
     def forward(self,
                 query,
                 key,
@@ -242,30 +274,11 @@ def forward(self,
         else:
             q, k, v, cache = self._prepare_qkv(query, key, value, use_cache, cache)
-        # scale dot product attention
-        product = layers.matmul(
-            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
-
-        # if attn_mask is not None:
-        #     product = product + attn_mask
-
-        # weights = F.softmax(product)
-
-        weights = incubate.softmax_mask_fuse_upper_triangle(product)
-
-        if self.dropout:
-            with get_rng_state_tracker().rng_state('local_seed'):
-                weights = F.dropout(
-                    weights,
-                    self.dropout,
-                    training=self.training,
-                    mode="upscale_in_train")
-
-        out = tensor.matmul(weights, v)
-        # combine heads
-        out = tensor.transpose(out, perm=[0, 2, 1, 3])
-        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
+        if self.use_recompute and self.recompute_granularity == "core_attn":
+            out, weights = recompute(self.core_attn, q, k, v)
+        else:
+            out, weights = self.core_attn(q, k, v)

         # project to output
         out = self.out_proj(out)
@@ -380,7 +393,9 @@ def __init__(self,
                  bias_attr=None,
                  num_partitions=1,
                  fused_linear=False,
-                 recompute_attn=False):
+                 recompute_attn=False,
+                 use_recompute=False,
+                 recompute_granularity="full"):
         self._config = locals()
         self._config.pop("self")
         self._config.pop("__class__", None)  # py3
@@ -389,7 +404,8 @@ def __init__(self,
         attn_dropout = dropout if attn_dropout is None else attn_dropout
         act_dropout = dropout if act_dropout is None else act_dropout
         self.normalize_before = normalize_before
-        self.recompute_attn = recompute_attn
+        self.use_recompute = use_recompute
+        self.recompute_granularity = recompute_granularity

         weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
         bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
@@ -401,7 +417,9 @@ def __init__(self,
             weight_attr=weight_attrs[0],
             bias_attr=bias_attrs[0],
             num_partitions=num_partitions,
-            fused_linear=fused_linear)
+            fused_linear=fused_linear,
+            use_recompute=use_recompute,
+            recompute_granularity=recompute_granularity)

         self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
             d_model,
@@ -437,9 +455,8 @@ def forward(self,
             tgt = self.norm1(tgt)

         if use_cache is False:
-            if self.recompute_attn:
-                tgt = recompute(self.self_attn, tgt, None, None, tgt_mask,
-                                use_cache, cache)
+            if self.use_recompute and self.recompute_granularity == "full_attn":
+                tgt = recompute(self.self_attn, tgt, None, None, tgt_mask, use_cache, cache)
             else:
                 tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
         else:
@@ -535,8 +552,6 @@ def __init__(self,

         super(GPTModel, self).__init__()

-        recompute_attn = use_recompute and recompute_granularity == "only_attn"
-
         self.initializer_range = initializer_range
         self.hidden_size = hidden_size
         self.vocab_size = vocab_size
@@ -562,7 +577,8 @@ def __init__(self,
                     bias_attr=None,
                     num_partitions=num_partitions,
                     fused_linear=fused_linear,
-                    recompute_attn=recompute_attn))
+                    use_recompute=use_recompute,
+                    recompute_granularity=recompute_granularity))

         self.decoder = TransformerDecoder(
             decoder_layers,
@@ -765,8 +781,6 @@ def __init__(self,
                  fused_linear=False,
                  recompute_granularity="full"):

-        recompute_attn = use_recompute and recompute_granularity == "only_attn"
-
         # forward desc
         self.descs = []

@@ -799,7 +813,8 @@ def __init__(self,
                     bias_attr=None,
                     num_partitions=num_partitions,
                     fused_linear=fused_linear,
-                    recompute_attn=recompute_attn))
+                    use_recompute=use_recompute,
+                    recompute_granularity=recompute_granularity))

         self.descs.append(
             LayerDesc(
@@ -822,7 +837,7 @@ def _logits_helper(embedding, output):
                 initializer_range=0.02))

         recompute_interval = 0
-        if recompute and not recompute_attn:
+        if recompute and recompute_granularity == "full":
            recompute_interval = 1

         super().__init__(
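
For context on how the three granularities split the work (this note is not part of the patch): `full` recomputes whole transformer layers (e.g. via `recompute_interval` in the pipeline model), `full_attn` wraps only the `self_attn` call, and `core_attn` wraps only the new `core_attn` method, i.e. the `softmax(qkT)v` block. The toy module below is a minimal sketch of the `core_attn` dispatch pattern; `ToyAttention` and its shapes are invented for illustration, and the recompute wrapper is assumed to be `paddle.distributed.fleet.utils.recompute`, matching the call sites added above.

```python
# Illustrative sketch only -- not part of this patch.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.distributed.fleet.utils import recompute


class ToyAttention(nn.Layer):
    """Minimal attention layer showing the recompute_granularity dispatch."""

    def __init__(self, hidden_size, num_heads,
                 use_recompute=False, recompute_granularity="full"):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.use_recompute = use_recompute
        self.recompute_granularity = recompute_granularity

    def _split_heads(self, t):
        # [batch, seq, hidden] -> [batch, heads, seq, head_dim]
        t = paddle.reshape(t, [0, 0, self.num_heads, self.head_dim])
        return paddle.transpose(t, perm=[0, 2, 1, 3])

    def core_attn(self, q, k, v):
        # softmax(q k^T / sqrt(d)) v -- the only part recomputed for "core_attn"
        product = paddle.matmul(q, k, transpose_y=True) * self.head_dim**-0.5
        weights = F.softmax(product)
        out = paddle.matmul(weights, v)
        # combine heads back to [batch, seq, hidden]
        out = paddle.transpose(out, perm=[0, 2, 1, 3])
        return paddle.reshape(out, [0, 0, out.shape[2] * out.shape[3]])

    def forward(self, x):
        q, k, v = paddle.split(self.qkv_proj(x), 3, axis=-1)
        q, k, v = map(self._split_heads, (q, k, v))
        if self.use_recompute and self.recompute_granularity == "core_attn":
            # activations inside core_attn are freed and recomputed in backward
            out = recompute(self.core_attn, q, k, v)
        else:
            out = self.core_attn(q, k, v)
        return self.out_proj(out)
```

With this split, `full` frees the most activation memory and `core_attn` the least, which is why the updated README note suggests moving from `core_attn` toward `full` when the chosen setting runs out of memory.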