[Dy2Stat]Refactor convert_shape transformer logic #43846

Merged
merged 2 commits into from
Jun 27, 2022

@Aurelius84 Aurelius84 commented Jun 27, 2022

PR types

Others

PR changes

Others

Describe

What's New?

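Refactored the dynamic-to-static (Dy2Stat) rewriting logic for x.shape / paddle.shape, upgrading it to a decision made dynamically at JIT runtime. This resolves one of the most common pain points with the shape API: users no longer need to adjust such code by hand, which greatly improves the conversion experience.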

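1. Background

In earlier dynamic-to-static scenarios, user model code often contained statements such as B, C, H, W = x.shape. When x has a dynamic shape (some dimensions are -1 or None), that statement returns -1 for those dimensions after conversion to a static graph, which loses the dynamic semantics. Dy2Stat therefore ran a static analysis and rewrote the statement as B, C, H, W = paddle.shape(x). That approach has the following drawbacks:

  • Dependence on static type analysis. Whether B, C, H, W are used in control flow is decided by parsing the AST, so corner cases always remain.
  • Static shape values also become Tensors. If H and W are static values, the old rewrite to paddle.shape + slice made them dynamic as well. When a downstream API parameter does not accept a Tensor, export fails with an error.
  • Complex logic. The old implementation was AST-based, which made the rewriting and parsing logic complicated and hard to maintain, with mediocre robustness and usability.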
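
The two failure modes above can be pictured with a small pure-Python sketch. FakeStaticVar and fake_shape_op are hypothetical stand-ins invented for illustration; they are not Paddle APIs and only mimic the behaviour described in this section.

```python
# Toy stand-ins; none of these names are Paddle APIs.
class FakeStaticVar:
    """Mimics a static-graph variable whose batch size is unknown at build time."""

    def __init__(self, static_shape):
        self.shape = tuple(static_shape)  # -1 marks an unknown dimension


def fake_shape_op(var):
    """Mimics paddle.shape(x): every entry is a runtime (tensor-like) value."""
    return [("runtime_dim", i) for i in range(len(var.shape))]


x = FakeStaticVar([-1, 3, 224, 224])

# Plain attribute access at graph-build time: the batch size is just -1,
# so the dynamic meaning of B is lost.
B, C, H, W = x.shape

# Old rewrite: every value goes through the shape op, so even the
# statically known H and W stop being plain ints.
B2, C2, H2, W2 = fake_shape_op(x)

print(B, H)    # build-time values
print(B2, H2)  # tensor-like placeholders
```

In this toy model, B is -1 in the first case, and H2 is no longer an int in the second, which is the situation where a downstream API that expects an int would reject it.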

2. New approach

This PR turns the Tensor shape logic into a decision made dynamically at JIT runtime: B, C, H, W = x.shape is now uniformly rewritten as B, C, H, W = _jst.convert_shape(x):

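def convert_shape(x):
    """
    A function representation of the shape of variable.
    """

    if isinstance(x, Variable):
        # Static-graph variable: start from the build-time shape.
        values = list(x.shape)
        if has_negative(values):
            # Query the runtime shape once, and only when a dimension is unknown.
            shape_tensor = nn.shape(x)
            for i, v in enumerate(values):
                # Replace only the unknown dimensions; static ones stay as ints.
                if v is None or v < 0:
                    values[i] = shape_tensor[i]
        return values
    else:
        # Not a static-graph Variable: return its shape unchanged.
        return x.shape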
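
The source-level rewrite from x.shape to _jst.convert_shape(x) can be sketched with Python's standard ast module. This is a toy pass written for illustration, not the transformer shipped in this PR; the class name ShapeAttrToCall and the helper rewrite are hypothetical. It also naively treats every .shape attribute read as a shape access, so a call such as paddle.shape(x) would need extra handling that is omitted here.

```python
import ast  # ast.unparse requires Python 3.9+


class ShapeAttrToCall(ast.NodeTransformer):
    """Rewrite every `<expr>.shape` read into `_jst.convert_shape(<expr>)`."""

    def visit_Attribute(self, node):
        self.generic_visit(node)  # rewrite nested expressions first
        if node.attr == "shape" and isinstance(node.ctx, ast.Load):
            call = ast.Call(
                func=ast.Attribute(
                    value=ast.Name(id="_jst", ctx=ast.Load()),
                    attr="convert_shape",
                    ctx=ast.Load(),
                ),
                args=[node.value],
                keywords=[],
            )
            return ast.copy_location(call, node)
        return node


def rewrite(source):
    """Parse source, apply the toy pass, and return the new source text."""
    tree = ShapeAttrToCall().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)


print(rewrite("B, C, H, W = x.shape"))
```

Because the rewrite is unconditional, the pass needs no knowledge of whether the unpacked values later reach control flow; that decision is deferred to convert_shape at run time.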

In short, there are two cases:

  • x.shape = [2, 3, 4]: the values are returned unchanged; only the type is converted from tuple to list
  • x.shape = [2, -1, 4]: because it contains -1, the result is [2, var, 4], where var = paddle.shape(x)[1]
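
Both cases can be reproduced without Paddle in a minimal sketch. FakeVariable, SymbolicDim, and toy_convert_shape are hypothetical names made up for this example; SymbolicDim merely stands in for an element of the shape tensor, paddle.shape(x)[i].

```python
class FakeVariable:
    """Hypothetical stand-in for a static-graph Variable."""

    def __init__(self, static_shape, runtime_shape):
        self.shape = tuple(static_shape)          # known at graph-build time
        self.runtime_shape = list(runtime_shape)  # only known when executed


class SymbolicDim:
    """Placeholder for one element of the shape tensor."""

    def __init__(self, var, index):
        self.var, self.index = var, index

    def resolve(self):
        # Look the real size up at "run time".
        return self.var.runtime_shape[self.index]


def toy_convert_shape(x):
    """Keep static dimensions as ints; swap unknown ones for symbolic values."""
    values = list(x.shape)
    for i, v in enumerate(values):
        if v is None or v < 0:
            values[i] = SymbolicDim(x, i)
    return values


fixed = FakeVariable([2, 3, 4], [2, 3, 4])
dynamic = FakeVariable([2, -1, 4], [2, 7, 4])

print(toy_convert_shape(fixed))           # [2, 3, 4]
out = toy_convert_shape(dynamic)
print(out[0], out[1].resolve(), out[2])   # 2 7 4
```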

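3. Advantages

The upgraded approach has the following advantages:

  • Concise code and better robustness. Because the decision is now made at run time, it no longer depends on static type analysis. Users no longer need to adjust their code by hand and can simply write x.shape everywhere.
  • Dynamic shape only where necessary. When only some dimensions of x are -1, this approach still returns a list and converts only the dynamic dimensions to var, which is more flexible and keeps dynamic-graph and static-graph behaviour as consistent as possible.
  • Better performance. The old approach could call paddle.shape(x) repeatedly; this one calls paddle.shape at most once, which benefits execution performance.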
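
The "at most once" property can be checked with a call counter. The sketch below is an assumption-laden illustration: CountingVar, per_dim_lookup, and single_lookup are hypothetical, and per_dim_lookup is only one guess at what a repeated-call strategy might look like, not the code removed by this PR.

```python
class CountingVar:
    """Hypothetical variable that counts how often its shape op is invoked."""

    def __init__(self, static_shape):
        self.shape = tuple(static_shape)
        self.shape_op_calls = 0

    def shape_op(self):
        # Stand-in for paddle.shape(x).
        self.shape_op_calls += 1
        return [("dyn", i) for i in range(len(self.shape))]


def per_dim_lookup(x):
    """One shape op per dimension (illustrative repeated-call strategy)."""
    return [x.shape_op()[i] for i in range(len(x.shape))]


def single_lookup(x):
    """Mirrors the convert_shape pattern: at most one shape op, only if needed."""
    values = list(x.shape)
    if any(v is None or v < 0 for v in values):
        shape_tensor = x.shape_op()
        for i, v in enumerate(values):
            if v is None or v < 0:
                values[i] = shape_tensor[i]
    return values


a = CountingVar([-1, 3, -1, 8])
per_dim_lookup(a)
b = CountingVar([-1, 3, -1, 8])
single_lookup(b)
c = CountingVar([2, 3])
single_lookup(c)
print(a.shape_op_calls, b.shape_op_calls, c.shape_op_calls)
```

A fully static shape triggers no shape op at all in this model, and a shape with several unknown dimensions still triggers only one.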

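4. Main changes

  • Refactored the code logic and deleted the old redundant logic;
  • Deleted unit tests that are no longer needed (the old interface code has been removed)

TODO:

  • Remove the convert prefix in convert_operation to slim down the rewritten code and improve readability
  • Delete unused public helper functions related to Tensor shape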

@paddle-bot-old

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@2742195759 2742195759 left a comment

LGTM

@2742195759 2742195759 self-requested a review June 27, 2022 08:30

@2742195759 2742195759 left a comment

LGTM

@Aurelius84 Aurelius84 merged commit d82d5b8 into PaddlePaddle:develop Jun 27, 2022