
Support optional residual add in fused_attention and fused_feedforward. #43474

Merged
4 commits merged into PaddlePaddle:develop on Jun 17, 2022

Conversation

Xreki
Contributor

@Xreki Xreki commented Jun 13, 2022

PR types

Function optimization

PR changes

OPs

Describe

The small-op composition equivalent to the fused_attention op on develop is as follows:

  // input: [batch_size, seq_len, embed_dim]
  // final_out: [batch_size, seq_len, num_heads, head_dim]
  if (pre_layernorm)
    query = layer_norm(input);
  out = compute_qkv(query) + qkv_bias;
  // fmha module
  {
    out = transpose(out, perm=[2, 0, 3, 1, 4]);
    out = q * k^t;
    out = attn_mask + out;
    out = softmax(out);
    out = dropout(out);
    out = out * v;
    out = transpose(out, perm=[0, 2, 1, 3]);
  }
  out = out_linear(out);
  if (pre_layernorm)
    final_out = residual + dropout(bias + out);
  else
    final_out = layer_norm(residual + dropout(bias + out));
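
For reference, the same computation can be written with small ops in Python. The following is an illustrative sketch under simplifying assumptions, not the in-tree implementation: it uses a plain [embed_dim, 3*embed_dim] QKV weight instead of the fused op's packed layout, and it includes the usual head_dim ** -0.5 scaling, which the pseudo-code above omits.

    import paddle
    import paddle.nn.functional as F

    def attention_reference(x, residual, qkv_weight, qkv_bias, out_weight, out_bias,
                            ln_weight, ln_bias, attn_mask, num_heads,
                            pre_layer_norm=True, dropout_p=0.0):
        # x, residual: [batch, seq_len, embed_dim]; qkv_weight: [embed_dim, 3*embed_dim]
        batch, seq_len, embed_dim = x.shape
        head_dim = embed_dim // num_heads
        query = F.layer_norm(x, [embed_dim], ln_weight, ln_bias) if pre_layer_norm else x
        qkv = paddle.matmul(query, qkv_weight) + qkv_bias              # [b, s, 3*d]
        qkv = qkv.reshape([batch, seq_len, 3, num_heads, head_dim])
        q, k, v = paddle.unbind(qkv.transpose([2, 0, 3, 1, 4]))        # each [b, h, s, hd]
        attn = paddle.matmul(q, k, transpose_y=True) * head_dim ** -0.5
        attn = F.dropout(F.softmax(attn + attn_mask, axis=-1), p=dropout_p)
        out = paddle.matmul(attn, v).transpose([0, 2, 1, 3])           # [b, s, h, hd]
        out = out.reshape([batch, seq_len, embed_dim])
        out = paddle.matmul(out, out_weight) + out_bias                # out_linear
        out = residual + F.dropout(out, p=dropout_p)
        return out if pre_layer_norm else F.layer_norm(out, [embed_dim], ln_weight, ln_bias)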

The Attention used in the CAE model differs slightly; the main points of its Attention structure are:

  1. head_dim = dim // num_heads if attn_head_dim is None else attn_head_dim: attn_head_dim can be specified as another value. The model uses the default.
  2. self.scale = qk_scale or head_dim ** -0.5: the scale factor applied to the QK result, which can also be specified as another value. The model uses the default.
  3. There are q_bias and v_bias, but no k_bias. Before each fused QKV matmul, the model sets k_bias to zero as follows:
        qkv_bias = None
        if self.q_bias is not None:
            k_bias = paddle.zeros_like(self.v_bias)
            k_bias.stop_gradient = True
            qkv_bias = paddle.concat((self.q_bias, k_bias, self.v_bias))

The way Attention is called also differs; the code is as follows:

        if self.gamma_1 is None:
            x = x + self.drop_path(self.attn(self.norm1(x), bool_masked_pos))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), bool_masked_pos))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))

When the dropout_prob of self.drop_path is 0 and self.gamma_1 and self.gamma_2 are None, x = x + self.attn(self.norm1(x), bool_masked_pos) and x = x + self.mlp(self.norm2(x)) can directly call the fused_attention and fused_feedforward fused ops.

In the actual model, the dropout_prob in self.drop_path is 0, but self.gamma_1 and self.gamma_2 are both not None. Using the fused_attention and fused_feedforward fused ops therefore requires changing the op functionality; there are two options:

  1. Add the multiply-by-gamma computation to fused_attention and fused_feedforward. The benefit is a larger fusion granularity; the downsides are that the fused ops become more complex, and the fused ops still cannot be used once the dropout_prob in self.drop_path is changed, so this option is less general.
  2. Add an add_residual attribute to fused_attention and fused_feedforward to control whether the residual add is performed in the last step.

This PR adopts option 2. It has been verified on the CAE model, improving model performance by 7%.
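
A minimal sketch of how the CAE block could then call the fused op under option 2 is shown below. Only fused_multi_head_attention and its new add_residual argument come from this PR; the helper names (attn_weights, norm1, gamma_1) and the remaining keyword arguments are illustrative assumptions, not code from the model.

    from paddle.incubate.nn import functional as incubate_f

    def cae_attn_block(x, attn_weights, norm1, gamma_1):
        # add_residual=False makes the fused op skip its final residual add,
        # so the gamma_1 scaling and the residual stay outside the fused kernel.
        attn_out = incubate_f.fused_multi_head_attention(
            x,
            qkv_weight=attn_weights["qkv_weight"],
            linear_weight=attn_weights["out_weight"],
            pre_layer_norm=True,
            pre_ln_scale=norm1.weight,
            pre_ln_bias=norm1.bias,
            qkv_bias=attn_weights["qkv_bias"],
            linear_bias=attn_weights["out_bias"],
            dropout_rate=0.0,
            attn_dropout_rate=0.0,
            add_residual=False)        # attribute added in this PR
        # drop_path is omitted because its dropout_prob is 0 in the model
        return x + gamma_1 * attn_out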

@paddle-bot-old

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@Xreki Xreki force-pushed the op/fused_attention_residual branch from 606c594 to 85d5041 on June 14, 2022 06:29
@Xreki Xreki force-pushed the op/fused_attention_residual branch from e09e57f to de26128 on June 14, 2022 06:59
@Xreki Xreki force-pushed the op/fused_attention_residual branch from da294c9 to eceed0b on June 16, 2022 10:12
Contributor

@zhangting2020 zhangting2020 left a comment


LGTM

@Xreki Xreki requested review from zkh2016 and limin2021 June 17, 2022 06:31
Contributor

@limin2021 limin2021 left a comment


LGTM for fused_attention.

@Xreki Xreki requested a review from lanxianghit June 17, 2022 07:09
@Xreki Xreki requested a review from qingqing01 June 17, 2022 07:15
@@ -454,6 +459,7 @@ def fused_multi_head_attention(x,
- train: out = input * mask
- inference: out = input * (1.0 - p)
ring_id (int, optional): For distributed forward in mp, only support NCCL and forward. Default is -1, means not using mp
add_residual (bool, optional): Whether add residual at the end. Default is True.
Contributor


  1. The current API is used in large-model inference; adding the attr has no impact on inference.
  2. Please also update the formula in the code-block:: python section of the docs above. Is the default True for the original behavior?

Contributor Author


Is the default True for the original behavior?

Yes.

Please also update the formula in the code-block:: python section of the docs above

I will update it in the next PR.

@Shixiaowei02
Contributor

@Xreki In the next PR, please add the .AsExtra() marker to the modified op attribute to meet the compatibility requirements of inference models. Thanks.

@Xreki Xreki requested a review from TCChenlong June 17, 2022 07:36
@Xreki Xreki merged commit 19e866f into PaddlePaddle:develop Jun 17, 2022
@Xreki Xreki deleted the op/fused_attention_residual branch June 17, 2022 10:10
@Xreki
Contributor Author

Xreki commented Jun 20, 2022

In the next PR, please add the .AsExtra() marker to the modified op attribute to meet the compatibility requirements of inference models.

@Shixiaowei02 @cyj1986 Confirmed: the new add_residual attribute affects the functionality of the fused ops, i.e. the ops behave differently depending on whether it is set, so the AsExtra marker should not be added.

zhangting2020 pushed a commit to zhangting2020/Paddle that referenced this pull request Jun 21, 2022
…d. (PaddlePaddle#43474)

* Support optional residual add in fused_attention and fused_feedforward.

* Add checkpoint and add the check of add_residual when pre_layer_norm is false.

* Add TODO and change the python api to add add_residual argument.
lanxianghit pushed a commit that referenced this pull request Jun 22, 2022
…rge tensor for cudnn_softmax (#43719)

 [cherry pick] Support optional residual add in fused ops and slice large tensor for cudnn_softmax

cherry-pick #43635 #43681 #43474
sneaxiy pushed a commit to sneaxiy/Paddle that referenced this pull request Jun 27, 2022
…d. (PaddlePaddle#43474)

* Support optional residual add in fused_attention and fused_feedforward.

* Add checkpoint and add the check of add_residual when pre_layer_norm is false.

* Add TODO and change the python api to add add_residual argument.