【快乐开源】CINN编译器符号推导扩量 #66444
Labels
PFCC
Paddle Framework Contributor Club,https://github.com/PaddlePaddle/community/tree/master/pfcc
status/new-issue
新建
type/others
其他问题
任务划分🧾
快乐开源2024
符号推导按照接口实现难度划分为“简单”、“中等”和“复杂”三个等级,前期只开放简单任务,后续会逐步开放更具有挑战性的任务。任务列表见:InferSymbolicShape接口实现任务列表🧾
认领方式
请大家直接在👆的excel表中认领任务,如:
一、需求背景
深度学习模型常需要处理各种形状和尺寸的数据,支持动态 Shape 特性的深度学习框架允许模型在训练和推理过程中适应不同尺寸的输入,从而提高模型的灵活性和通用性。动态 Shape 功能允许模型在训练和推理时延迟计算张量的部分或全部维度,直到运行时再确定,因此可以根据实际的输入尺寸选择合适的优化策略,以达到最佳性能。
以 paddle.reshape()为例,现在的动态维度用 "-1" 表示,信息量少表示能力较弱,很多约束信息没法表示出来。而在CINN编译器引入了Shape Dialect之后,CINN能够直接基于动态Shape语义进行编译与优化。在CINN内部能够直接使用“S0”、“S1”这样的符号表示张量的维度信息,并能够通过添加一些维度约束限制为简化符号推导过程和后续编译优化提供指导信息。
二、参考文档
在实现具体算子的符号推导接口时,需要了解符号的表示和推导相关的基础概念
2.1 动态Shape符号表示
在CINN中,符号通过DimExpr、ShapeOrData、ShapeOrDataDimExprs三个不同的抽象层次进行表示
2.1.1 DimExpr
DimExpr是Shape Dialect最底层的数据结构,用于表示单个维度对应的符号信息。目前符号表示的语法支持int64_t的整数,string(一般为符号推导产生的从"S0","S1","S2"...一系列的新符号)以及加减乘除等复合语法,符号的操作也支持加减乘除运算以及相等判断。
2.1.2 ShapeOrData
基于单个维度的抽象表示DimExpr,Shape Dialect使用ShapeOrData表示Tensor对应的符号维度信息,如[1, S0, S1]。
// paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h std::vector<T> shape_; std::optional<std::vector<T>> data_;
解释说明:一般来说,Tensor对应的符号维度信息用vector数据结构表示就够了,但对于一些特殊情况,vector的表示能力会有所不足,例如 :
在这个例子中,Tensor y 是存储了Tensor x 的 shape 信息,y 本身的shape为[2]。单算子的符号推导是根据操作数的shape信息和attribute信息实现的,此时如果要实现reshape op的符号推导接口就会发现我们真正需要的是Tensor y本身存储的信息[S0, S1],而不是 shape of y: [2] ,因此CINN中设计了data区以提升张量形状信息的表示能力。
2.1.3 ShapeOrDataDimExprs
ShapeOrDataDimExprs用于表示value的符号信息。由于value不仅可以是DenseTensorType还可能是VectorType,相应地ShapeOrDataDimExprs也需要是TensorShapeOrDataDimExprs或TensorListShapeOrDataDimExprs,实现中使用std::variant支持多种类型控制。
2.2 符号推导重要组件
在实现单算子符号推导接口时,除了基础的符号表示,还需要掌握 CINN 编译器的 Shape Dialect 中提供的两个重要组件:符号推导的上下文管理器 InferSymbolicShapeContext 和 符号间的约束信息管理 ConstraintsManager
2.2.1 InferSymbolicShapeContext
InferSymbolicShapeContext是符号推导的一个上下文环境类,单算子符号推导接口开发主要会用到以下接口:
2.2.2 符号约束
建立约束的核心设计理念是,减少新符号,提升性能。目前约束包括Equal、GTOne(大于 1)、Broadcastable。实现具体Op的符号推导接口时无需关注ConstraintsManager的具体实现,只需了解上诉三个约束的添加方法:
其中的Broadcastable约束,将从后向前比较两个 Tensor 的形状,需要满足如下至少一个条件才能进行广播:
因此如果能知道某个维度的值大于1,那么它在参与Broadcast时候一定与最终的Broadcast结果相同;如果两个参与Broadcast的维度都大于1,那么这两个维度一定相等且于最终的Broadcast结果相同。
2.2.3 单算子符号推导
算子继承InferSymbolicShapeInterface接口来实现符号推导,该接口传入符号推导上下文并对齐进行修改。具体来说就是从符号推导上下文中获取所需输入value的符号信息然后通过符号计算得到并在符号推导上下文中设置输入value的符号信息。除此之外,需要根据算子的计算特点在符号推导上下文中加入符号的约束关系。以Matmul算子的符号推导接口为例:
三、开发流程和示例
3.1 开发流程
编译和测试命令
编译命令(以python3.9为例):
测试命令(以test_reshape_op.py为例):
3.2 添加接口示例PR
#65880
#65889
3.3 常见Debug问题
在本地Op级别测试和CI验证的Debug阶段可能会遇到如下类型的问题:
The text was updated successfully, but these errors were encountered: